-
Notifications
You must be signed in to change notification settings - Fork 24
/
yolo_v3.py
292 lines (219 loc) · 10.3 KB
/
yolo_v3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
# Short alias for the TF-Slim layer library used by every builder below.
slim = tf.contrib.slim

# Batch-norm hyper-parameters; passed to slim.batch_norm as 'decay' and
# 'epsilon' through the normalizer_params dict built in yolo_v3().
_BATCH_NORM_DECAY = 0.9
_BATCH_NORM_EPSILON = 1e-05
# Alpha handed to tf.nn.leaky_relu for the convolution activation.
_LEAKY_RELU = 0.1

# Nine anchor pairs. yolo_v3() feeds them to the detection layers three at a
# time: [6:9] to the first head, [3:6] to the second, [0:3] to the third.
# _detection_layer divides them by the stride and multiplies them back, so
# they are expressed in input-image units.
# NOTE(review): presumably (width, height) pairs, as in the reference YOLOv3
# configuration - confirm against the weights being loaded.
_ANCHORS = [(10, 13), (16, 30), (33, 23),
            (30, 61), (62, 45), (59, 119),
            (116, 90), (156, 198), (373, 326)]
def darknet53(inputs):
    """
    Assembles the Darknet-53 feature extractor.

    The network is a 32-filter stem convolution followed by five stride-2
    convolutions (64, 128, 256, 512, 1024 filters), each trailed by
    1, 2, 8, 8 and 4 residual blocks respectively.

    :param inputs: tensor fed to the first convolution.
    :return: tuple (route_1, route_2, net) - the feature maps produced after
        the 256-filter stage, after the 512-filter stage, and at the end of
        the network.
    """
    # Stem.
    net = _conv2d_fixed_padding(inputs, 32, 3)

    # Stage 1: one residual block.
    net = _conv2d_fixed_padding(net, 64, 3, strides=2)
    net = _darknet53_block(net, 32)

    # Stage 2: two residual blocks.
    net = _conv2d_fixed_padding(net, 128, 3, strides=2)
    for _ in range(2):
        net = _darknet53_block(net, 64)

    # Stage 3: eight residual blocks; its output is the first route.
    net = _conv2d_fixed_padding(net, 256, 3, strides=2)
    for _ in range(8):
        net = _darknet53_block(net, 128)
    route_1 = net

    # Stage 4: eight residual blocks; its output is the second route.
    net = _conv2d_fixed_padding(net, 512, 3, strides=2)
    for _ in range(8):
        net = _darknet53_block(net, 256)
    route_2 = net

    # Stage 5: four residual blocks.
    net = _conv2d_fixed_padding(net, 1024, 3, strides=2)
    for _ in range(4):
        net = _darknet53_block(net, 512)

    return route_1, route_2, net
def _conv2d_fixed_padding(inputs, filters, kernel_size, strides=1):
    """
    Applies slim.conv2d with padding that does not depend on the input size.

    With strides == 1 the convolution uses 'SAME' padding. For strides > 1
    the input is first padded explicitly by _fixed_padding and the
    convolution then runs with 'VALID' padding.

    :param inputs: input tensor.
    :param filters: number of output filters.
    :param kernel_size: convolution kernel size.
    :param strides: convolution stride.
    :return: the convolution output tensor.
    """
    conv_padding = 'SAME' if strides == 1 else 'VALID'
    if strides > 1:
        inputs = _fixed_padding(inputs, kernel_size)
    return slim.conv2d(inputs, filters, kernel_size,
                       stride=strides, padding=conv_padding)
def _darknet53_block(inputs, filters):
    """
    Residual unit: a 1x1 convolution with `filters` outputs, a 3x3 convolution
    with `filters * 2` outputs, and the block input added to the result.

    :param inputs: input tensor; it is added to the convolution output, so
        the two must have compatible shapes.
    :param filters: number of filters of the 1x1 convolution.
    :return: sum of the convolution branch and the block input.
    """
    branch = _conv2d_fixed_padding(inputs, filters, 1)
    branch = _conv2d_fixed_padding(branch, filters * 2, 3)
    return branch + inputs
def _spp_block(inputs, data_format='NCHW'):
    """
    Spatial pyramid pooling block.

    Max-pools the input with 13x13, 9x9 and 5x5 windows (stride 1, 'SAME'
    padding, so the spatial size is unchanged) and concatenates the three
    pooled maps with the input itself along the channel axis.

    :param inputs: 4-D feature map laid out according to `data_format`.
    :param data_format: 'NCHW' or 'NHWC'.
    :return: tensor with four times the channels of `inputs`.
    """
    # data_format is passed explicitly: slim.max_pool2d defaults to NHWC, so
    # without it an NCHW tensor would be pooled over (channels, height)
    # instead of (height, width).
    pooled = [
        slim.max_pool2d(inputs, window, 1, 'SAME', data_format=data_format)
        for window in (13, 9, 5)
    ]
    channel_axis = 1 if data_format == 'NCHW' else 3
    return tf.concat(pooled + [inputs], axis=channel_axis)
@tf.contrib.framework.add_arg_scope
def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
    """
    Pads the two spatial dimensions by an amount that depends only on the
    kernel size, never on the input size.

    The total padding per spatial dimension is kernel_size - 1, split so that
    the leading side gets the floor of half and the trailing side the rest.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      kernel_size: The kernel to be used in the conv2d or max_pool2d
        operation. Should be a positive integer.
      mode: The mode for tf.pad.
      **kwargs: Must contain 'data_format' ('NHWC' or 'NCHW'); it is read
        from the keyword arguments (e.g. supplied through slim.arg_scope)
        and a KeyError is raised when it is missing. Other keys are ignored.

    Returns:
      A tensor with the same format as the input with the data either intact
      (if kernel_size == 1) or padded (if kernel_size > 1).
    """
    total = kernel_size - 1
    leading = total // 2
    trailing = total - leading

    spatial = [leading, trailing]
    if kwargs['data_format'] == 'NCHW':
        paddings = [[0, 0], [0, 0], spatial, spatial]
    else:
        paddings = [[0, 0], spatial, spatial, [0, 0]]

    return tf.pad(inputs, paddings, mode=mode)
def _yolo_block(inputs, filters, data_format='NCHW', with_spp=False):
    """
    Builds the convolution stack that precedes a detection layer.

    The stack alternates 1x1 convolutions with `filters` outputs and 3x3
    convolutions with `filters * 2` outputs. When `with_spp` is set, an SPP
    block followed by an extra 1x1 convolution is inserted after the third
    convolution.

    :param inputs: input feature map.
    :param filters: number of filters of the 1x1 convolutions.
    :param data_format: 'NCHW' or 'NHWC'; only forwarded to the SPP block.
    :param with_spp: whether to insert the SPP block.
    :return: tuple (route, outputs) - `route` is the output of the last 1x1
        convolution, `outputs` is `route` passed through a final 3x3
        convolution.
    """
    net = _conv2d_fixed_padding(inputs, filters, 1)
    net = _conv2d_fixed_padding(net, filters * 2, 3)
    net = _conv2d_fixed_padding(net, filters, 1)

    if with_spp:
        net = _spp_block(net, data_format)
        # Extra 1x1 convolution applied to the concatenated SPP output.
        net = _conv2d_fixed_padding(net, filters, 1)

    net = _conv2d_fixed_padding(net, filters * 2, 3)
    route = _conv2d_fixed_padding(net, filters, 1)
    outputs = _conv2d_fixed_padding(route, filters * 2, 3)
    return route, outputs
def _get_size(shape, data_format):
if len(shape) == 4:
shape = shape[1:]
return shape[1:3] if data_format == 'NCHW' else shape[0:2]
def _detection_layer(inputs, num_classes, anchors, img_size, data_format):
    """
    Builds one detection head and decodes its raw output into boxes.

    A 1x1 convolution (no batch norm, no activation, zero-initialised bias)
    produces num_anchors * (5 + num_classes) values per grid cell. They are
    rearranged into one row per (cell, anchor) pair and decoded: centres are
    sigmoid-activated, offset by the cell position and scaled by the stride;
    sizes are exponentiated and scaled by the anchor; confidence and class
    scores are sigmoid-activated.

    :param inputs: feature map laid out according to `data_format`; its
        spatial dimensions must be statically known.
    :param num_classes: number of predicted classes.
    :param anchors: sequence of anchor pairs for this head, in input-image
        units, in the same order as the two box-size outputs.
    :param img_size: (height, width) of the network input image.
    :param data_format: 'NCHW' or 'NHWC'.
    :return: tensor of shape [batch, grid_h * grid_w * num_anchors,
        5 + num_classes] holding (center_x, center_y, size_0, size_1,
        confidence, class scores...) with box values in input-image units.
    """
    num_anchors = len(anchors)
    bbox_attrs = 5 + num_classes
    predictions = slim.conv2d(inputs, num_anchors * bbox_attrs, 1,
                              stride=1, normalizer_fn=None,
                              activation_fn=None,
                              biases_initializer=tf.zeros_initializer())

    shape = predictions.get_shape().as_list()
    grid_h, grid_w = _get_size(shape, data_format)
    dim = grid_h * grid_w

    if data_format == 'NCHW':
        # Move the channel axis last so both layouts share the final reshape.
        predictions = tf.reshape(
            predictions, [-1, num_anchors * bbox_attrs, dim])
        predictions = tf.transpose(predictions, [0, 2, 1])

    # Row index is cell * num_anchors + anchor, cells enumerated row-major
    # (cell = y * grid_w + x).
    predictions = tf.reshape(predictions, [-1, num_anchors * dim, bbox_attrs])

    # Stride is kept in (x, y) order so it lines up with the (x, y) box
    # centres and with the anchor pairs it is multiplied into below.
    stride = (img_size[1] // grid_w, img_size[0] // grid_h)
    anchors = [(a[0] / stride[0], a[1] / stride[1]) for a in anchors]

    box_centers, box_sizes, confidence, classes = tf.split(
        predictions, [2, 2, 1, num_classes], axis=-1)

    box_centers = tf.nn.sigmoid(box_centers)
    confidence = tf.nn.sigmoid(confidence)

    # x runs over the grid width and y over the grid height. tf.meshgrid
    # returns arrays of shape (len(y), len(x)), so flattening them yields
    # offsets in the same row-major cell order as `predictions`, including
    # for non-square grids.
    grid_x = tf.range(grid_w, dtype=tf.float32)
    grid_y = tf.range(grid_h, dtype=tf.float32)
    a, b = tf.meshgrid(grid_x, grid_y)

    x_offset = tf.reshape(a, (-1, 1))
    y_offset = tf.reshape(b, (-1, 1))

    # Repeat each cell offset once per anchor.
    x_y_offset = tf.concat([x_offset, y_offset], axis=-1)
    x_y_offset = tf.reshape(tf.tile(x_y_offset, [1, num_anchors]), [1, -1, 2])

    box_centers = box_centers + x_y_offset
    box_centers = box_centers * stride

    # Repeat the anchor list once per cell to match the row layout.
    anchors = tf.tile(anchors, [dim, 1])
    box_sizes = tf.exp(box_sizes) * anchors
    box_sizes = box_sizes * stride

    detections = tf.concat([box_centers, box_sizes, confidence], axis=-1)

    classes = tf.nn.sigmoid(classes)
    predictions = tf.concat([detections, classes], axis=-1)
    return predictions
def _upsample(inputs, out_shape, data_format='NCHW'):
    """
    Resizes a feature map to a target spatial size with nearest-neighbour
    interpolation.

    :param inputs: 4-D feature map laid out according to `data_format`.
    :param out_shape: full 4-entry shape of the tensor whose spatial size
        should be matched, in the same `data_format` layout as `inputs`.
    :param data_format: 'NCHW' or 'NHWC'.
    :return: the resized tensor, in the same layout as `inputs`, wrapped in
        an identity op named 'upsampled'.
    """
    if data_format == 'NCHW':
        # tf.image.resize_nearest_neighbor accepts input in format NHWC.
        inputs = tf.transpose(inputs, [0, 2, 3, 1])
        # NCHW: height is entry 2, width is entry 3.
        new_height, new_width = out_shape[2], out_shape[3]
    else:
        # NHWC: height is entry 1, width is entry 2.
        new_height, new_width = out_shape[1], out_shape[2]

    inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width))

    # Back to NCHW if needed.
    if data_format == 'NCHW':
        inputs = tf.transpose(inputs, [0, 3, 1, 2])

    return tf.identity(inputs, name='upsampled')
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False):
    """
    Creates YOLO v3 model.

    Builds the Darknet-53 backbone under the 'darknet-53' variable scope and
    three detection heads under the 'yolo-v3' variable scope. The second and
    third heads work on an upsampled copy of the previous head's route
    concatenated with a backbone route.

    :param inputs: a 4-D tensor of size [batch_size, height, width, channels].
        Dimension batch_size may be undefined. The channel order is RGB.
        Height and width are read from the static shape.
    :param num_classes: number of predicted classes.
    :param is_training: whether is training or not (forwarded to batch norm).
    :param data_format: data format NCHW or NHWC used inside the network;
        `inputs` itself is always given as NHWC.
    :param reuse: whether or not the network and its variables should be reused.
    :param with_spp: whether or not is using spp layer (first head only).
    :return: tensor named 'detections': the outputs of the three detection
        layers concatenated along axis 1.
    """
    # Input image size, taken before any transpose while the layout is
    # still NHWC; the detection layers use it to compute their strides.
    img_size = inputs.get_shape().as_list()[1:3]

    # transpose the inputs to NCHW
    if data_format == 'NCHW':
        inputs = tf.transpose(inputs, [0, 3, 1, 2])

    # normalize values to range [0..1]
    # NOTE(review): assumes pixel values arrive in [0, 255] - confirm caller.
    inputs = inputs / 255

    # set batch norm params
    batch_norm_params = {
        'decay': _BATCH_NORM_DECAY,
        'epsilon': _BATCH_NORM_EPSILON,
        'scale': True,
        'is_training': is_training,
        'fused': None,  # Use fused batch norm if possible.
    }

    # Outer scope: layout and reuse flag for conv, batch norm and padding.
    # Inner scope: every slim.conv2d gets batch norm, no bias and leaky ReLU
    # unless a call overrides them (the detection layers do).
    with slim.arg_scope([slim.conv2d, slim.batch_norm, _fixed_padding], data_format=data_format, reuse=reuse):
        with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            biases_initializer=None,
                            activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=_LEAKY_RELU)):
            with tf.variable_scope('darknet-53'):
                route_1, route_2, inputs = darknet53(inputs)

            with tf.variable_scope('yolo-v3'):
                # Head 1: final backbone output, last three anchors.
                route, inputs = _yolo_block(inputs, 512, data_format, with_spp)

                detect_1 = _detection_layer(
                    inputs, num_classes, _ANCHORS[6:9], img_size, data_format)
                detect_1 = tf.identity(detect_1, name='detect_1')

                # Head 2: upsample the head-1 route to route_2's size and
                # concatenate along the channel axis; middle three anchors.
                inputs = _conv2d_fixed_padding(route, 256, 1)
                upsample_size = route_2.get_shape().as_list()
                inputs = _upsample(inputs, upsample_size, data_format)
                inputs = tf.concat([inputs, route_2],
                                   axis=1 if data_format == 'NCHW' else 3)

                # NOTE(review): data_format is not forwarded here, so
                # _yolo_block falls back to its 'NCHW' default. That argument
                # is only read when with_spp is set, which it is not here.
                route, inputs = _yolo_block(inputs, 256)

                detect_2 = _detection_layer(
                    inputs, num_classes, _ANCHORS[3:6], img_size, data_format)
                detect_2 = tf.identity(detect_2, name='detect_2')

                # Head 3: same pattern with route_1; first three anchors.
                inputs = _conv2d_fixed_padding(route, 128, 1)
                upsample_size = route_1.get_shape().as_list()
                inputs = _upsample(inputs, upsample_size, data_format)
                inputs = tf.concat([inputs, route_1],
                                   axis=1 if data_format == 'NCHW' else 3)

                _, inputs = _yolo_block(inputs, 128)

                detect_3 = _detection_layer(
                    inputs, num_classes, _ANCHORS[0:3], img_size, data_format)
                detect_3 = tf.identity(detect_3, name='detect_3')

                # Stack the rows of all three heads.
                detections = tf.concat([detect_1, detect_2, detect_3], axis=1)
                detections = tf.identity(detections, name='detections')
                return detections
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
    """
    Creates the YOLO v3 model with the SPP block enabled.

    Thin wrapper around yolo_v3 that forwards every argument unchanged and
    sets with_spp=True.

    :param inputs: a 4-D tensor of size [batch_size, height, width, channels].
        Dimension batch_size may be undefined. The channel order is RGB.
    :param num_classes: number of predicted classes.
    :param is_training: whether is training or not.
    :param data_format: data format NCHW or NHWC.
    :param reuse: whether or not the network and its variables should be reused.
    :return: whatever yolo_v3 returns for the same arguments with SPP enabled.
    """
    return yolo_v3(
        inputs,
        num_classes,
        is_training=is_training,
        data_format=data_format,
        reuse=reuse,
        with_spp=True,
    )