-
Notifications
You must be signed in to change notification settings - Fork 482
/
neural_artistic_style.py
executable file
·138 lines (111 loc) · 4.91 KB
/
neural_artistic_style.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python
import os
import argparse
import numpy as np
import scipy.misc
import deeppy as dp
from matconvnet import vgg_net
from style_network import StyleNetwork
def weight_tuple(s):
    """Parse a ``conv_idx,weight`` command-line token into a tuple.

    The layer index is returned as an ``int`` so that it can be used
    directly as an array index.  Integral float spellings such as ``"3.0"``
    are still accepted for backward compatibility.

    Args:
        s: Text of the form ``"<conv_idx>,<weight>"``.

    Returns:
        A ``(conv_idx, weight)`` tuple of ``(int, float)``.

    Raises:
        argparse.ArgumentTypeError: If ``s`` is not exactly two
            comma-separated numbers, or the index is not a whole number.
    """
    try:
        idx_text, weight_text = s.split(',')
        idx_value = float(idx_text)
        # int() raises for nan/inf; the comparison rejects values like 2.5.
        if idx_value != int(idx_value):
            raise ValueError('layer index must be a whole number')
        return int(idx_value), float(weight_text)
    except (AttributeError, TypeError, ValueError, OverflowError):
        raise argparse.ArgumentTypeError('weights must be "int,float"')
def float_range(x):
    """Convert ``x`` to a float and require it to lie within ``[0, 1]``.

    Args:
        x: Any value accepted by ``float()``, typically a command-line string.

    Returns:
        The converted float.

    Raises:
        argparse.ArgumentTypeError: If the value is outside ``[0, 1]`` or is
            NaN.
        ValueError: If ``x`` cannot be converted by ``float()``.
    """
    value = float(x)
    # Phrased as a positive containment test: NaN compares false against
    # everything, so an "x < 0 or x > 1" check would silently accept it.
    if not (0.0 <= value <= 1.0):
        raise argparse.ArgumentTypeError("%r not in range [0, 1]" % value)
    return value
def weight_array(weights):
    """Build a length-19 weight vector from ``(index, weight)`` pairs.

    Each pair stores ``weight`` at position ``index``; a later pair with the
    same index overwrites an earlier one.  When the resulting entries sum to
    a positive value the vector is scaled so that it sums to one; otherwise
    it is returned unscaled.

    Args:
        weights: Iterable of ``(index, weight)`` pairs.  Indices may be ints
            or whole-number floats.

    Returns:
        A 1-D ``numpy`` array with 19 entries.

    Raises:
        ValueError: If an index is not a whole number.
        IndexError: If an index falls outside the array.
    """
    array = np.zeros(19)
    for idx, weight in weights:
        # Indices may arrive as floats (e.g. 9.0); NumPy rejects float
        # indices, so whole numbers are coerced to int before indexing.
        layer = int(idx)
        if layer != idx:
            raise ValueError('layer index %r is not a whole number' % (idx,))
        array[layer] = weight
    norm = np.sum(array)
    if norm > 0:
        array /= norm
    return array
def imread(path):
    """Load the image stored at ``path`` and cast it to ``dp.float_``.

    NOTE(review): ``scipy.misc.imread`` is not available in recent SciPy
    releases -- confirm which SciPy version this script is pinned to.
    """
    pixels = scipy.misc.imread(path)
    return pixels.astype(dp.float_)
def imsave(path, img):
    """Write ``img`` to ``path`` after limiting it to the 8-bit range.

    Values are clipped to ``[0, 255]`` and cast to ``uint8`` before being
    handed to ``scipy.misc.imsave``.
    """
    clipped = np.clip(img, 0, 255)
    as_bytes = clipped.astype(np.uint8)
    scipy.misc.imsave(path, as_bytes)
def to_bc01(img):
    """Move the last of three axes to the front and prepend a unit axis.

    An input of shape ``(a, b, c)`` comes back with shape ``(1, c, a, b)``.
    Presumably "bc01" stands for batch, channel, row, column -- confirm
    against the network code.
    """
    channels_first = np.transpose(img, axes=(2, 0, 1))
    return channels_first[np.newaxis]
def to_rgb(img):
    """Take the first entry along axis 0 and move its leading axis last.

    An input of shape ``(n, c, a, b)`` comes back with shape ``(a, b, c)``,
    i.e. this undoes ``to_bc01`` for the first item.
    """
    first = img[0]
    return np.transpose(first, axes=(1, 2, 0))
def _build_arg_parser():
    """Create the argument parser describing every command-line option."""
    parser = argparse.ArgumentParser(
        description='Neural artistic style. Generates an image by combining '
                    'the subject from one image and the style from another.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--subject', required=True, type=str,
                        help='Subject image.')
    parser.add_argument('--style', required=True, type=str,
                        help='Style image.')
    parser.add_argument('--output', default='out.png', type=str,
                        help='Output image.')
    parser.add_argument('--init', default=None, type=str,
                        help='Initial image. Subject is chosen as default.')
    parser.add_argument('--init-noise', default=0.0, type=float_range,
                        help='Weight between [0, 1] to adjust the noise level '
                             'in the initial image.')
    parser.add_argument('--random-seed', default=None, type=int,
                        help='Random state.')
    parser.add_argument('--animation', default='animation', type=str,
                        help='Output animation directory.')
    parser.add_argument('--iterations', default=500, type=int,
                        help='Number of iterations to run.')
    parser.add_argument('--learn-rate', default=2.0, type=float,
                        help='Learning rate.')
    parser.add_argument('--smoothness', type=float, default=5e-8,
                        help='Weight of smoothing scheme.')
    parser.add_argument('--subject-weights', nargs='*', type=weight_tuple,
                        default=[(9, 1)],
                        help='List of subject weights (conv_idx,weight).')
    parser.add_argument('--style-weights', nargs='*', type=weight_tuple,
                        default=[(0, 1), (2, 1), (4, 1), (8, 1), (12, 1)],
                        help='List of style weights (conv_idx,weight).')
    parser.add_argument('--subject-ratio', type=float, default=2e-2,
                        help='Weight of subject relative to style.')
    parser.add_argument('--pool-method', default='avg', type=str,
                        choices=['max', 'avg'], help='Subsampling scheme.')
    parser.add_argument('--network', default='imagenet-vgg-verydeep-19.mat',
                        type=str, help='Network in MatConvNet format.')
    return parser


def run():
    """Parse the command line and run the style-transfer optimisation.

    Loads the network and the subject/style (and optional init) images,
    blends noise into the initial image according to ``--init-noise``, then
    iterates ``--iterations`` times.  Every iteration first writes the
    current image as a numbered PNG into the ``--animation`` directory,
    then updates the network and applies one Adam step to each parameter.
    The final image is written to ``--output``.
    """
    args = _build_arg_parser().parse_args()

    if args.random_seed is not None:
        np.random.seed(args.random_seed)

    layers, pixel_mean = vgg_net(args.network, pool_method=args.pool_method)

    # Inputs
    style_img = imread(args.style) - pixel_mean
    subject_img = imread(args.subject) - pixel_mean
    if args.init is None:
        init_img = subject_img
    else:
        init_img = imread(args.init) - pixel_mean
    # Noise is scaled to a tenth of the initial image's standard deviation
    # and mixed in linearly; an --init-noise of 0 leaves the image unchanged.
    noise = np.random.normal(size=init_img.shape, scale=np.std(init_img)*1e-1)
    init_img = init_img * (1 - args.init_noise) + noise * args.init_noise

    # Setup network
    subject_weights = weight_array(args.subject_weights) * args.subject_ratio
    style_weights = weight_array(args.style_weights)
    net = StyleNetwork(layers, to_bc01(init_img), to_bc01(subject_img),
                       to_bc01(style_img), subject_weights, style_weights,
                       args.smoothness)

    # Repaint image
    def net_img():
        return to_rgb(net.image) + pixel_mean

    # makedirs (unlike mkdir) also creates missing parent directories, so a
    # nested --animation path works.
    if not os.path.isdir(args.animation):
        os.makedirs(args.animation)

    params = net.params
    learn_rule = dp.Adam(learn_rate=args.learn_rate)
    learn_rule_states = [learn_rule.init_state(p) for p in params]
    for i in range(args.iterations):
        # The frame is saved before the update, so frame 0 is the initial
        # image.
        imsave(os.path.join(args.animation, '%.4d.png' % i), net_img())
        cost = np.mean(net.update())
        for param, state in zip(params, learn_rule_states):
            learn_rule.step(param, state)
        print('Iteration: %i, cost: %.4f' % (i, cost))
    imsave(args.output, net_img())
# Run the style transfer only when this file is executed as a script, so
# importing the module has no side effects.
if __name__ == "__main__":
    run()