-
Notifications
You must be signed in to change notification settings - Fork 0
/
layer_graph.py
492 lines (444 loc) · 18.8 KB
/
layer_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
import networkx as nx
from enum import Enum
import matplotlib.pyplot as plt
import random
import numpy as np
import copy
# Closed set of layer types that a graph node's 'type' attribute may hold.
LAYERS = Enum('layers', ('conv3', 'conv5', 'conv7', 'maxpool', 'avgpool', 'fc', 'ip', 'op', 'softmax', 'batchnorm', 'resnet'))
from distance import get_distance, clear_distance
import pickle
# Number of members declared in LAYERS above.
layers_type_num = 11
# Id that will be given to the next registered Layer_graph.
layer_graph_count = 0
# Every registered Layer_graph; Layer_graph.finish() calls get_distance against each entry.
layer_graph_table = []
# mut_dup_path and mut_wedge_layer return without mutating once layer_count reaches this value.
LAYERS_UPPER_BOUND = 200
# Degree limit that mut_dup_path and mut_skip test with out_degree / in_degree.
DEGREE_UPPER_BOUND = 20
class Layer_graph(object):
    """Directed graph of network layers plus random mutation operators.

    Each node is an integer id whose attribute dict holds ``type`` (a LAYERS
    member), ``num_of_filters``, ``stride`` and ``layer_mass``.
    The input layer does not need to be specified.
    """

    def __init__(self, input_unit=1):
        """Create an empty graph and register it in ``layer_graph_table``.

        Args:
            input_unit: stored on the instance as ``input_unit``; no method of
                this class reads it.
        """
        self._graph = nx.DiGraph()
        self.layer_count = 0
        self.total_lm = 0
        self.input_unit = input_unit
        self.renew_id()
        self._init_layers()

    def renew_id(self):
        """Give this graph the next global id and append it to the global table."""
        global layer_graph_table, layer_graph_count
        self.id = layer_graph_count
        layer_graph_count += 1
        layer_graph_table.append(self)

    def add_node(self, type, num_of_filters=1, stride=2):
        """Add an unconnected node and return its id.

        Args:
            type: layer type stored under the ``type`` attribute. The name
                shadows the builtin but is part of the keyword interface.
            num_of_filters: stored under ``num_of_filters``.
            stride: stored under ``stride``.

        Returns:
            The new node id: one more than the largest existing id, or 0 for
            an empty graph.
        """
        node_idx = max(self._graph.nodes, default=-1) + 1
        self._graph.add_node(node_idx, type=type, num_of_filters=num_of_filters,
                             layer_mass=0, stride=stride)
        self.layer_count += 1
        return node_idx

    def add_edge(self, f, t):
        """Add the directed edge ``f -> t``."""
        self._graph.add_edge(f, t)

    def remove_edge(self, f, t):
        """Remove the directed edge ``f -> t``."""
        self._graph.remove_edge(f, t)

    def append(self, type, num_of_filters=1, stride=2, append_to=None):
        """Add a node and connect ``append_to -> new node``.

        Args:
            type, num_of_filters, stride: forwarded to ``add_node``.
            append_to: id of the parent node. When omitted, the node with the
                largest id is used; on an empty graph no edge is created.

        Returns:
            The id of the new node.
        """
        # The largest id equals layer_count - 1 while ids are contiguous, and
        # unlike layer_count - 1 it always names an existing node after removals.
        if append_to is None and len(self._graph.nodes):
            append_to = max(self._graph.nodes)
        new_node = self.add_node(type, num_of_filters, stride)
        if append_to is not None:
            self.add_edge(append_to, new_node)
        return new_node

    def get_node_attr(self, n, attr='type'):
        """Return attribute ``attr`` of node ``n``."""
        return self._graph.nodes[n][attr]

    def get_graph(self):
        """Return the underlying networkx graph object."""
        return self._graph

    def get_nodes(self):
        """Return the node ids in topological order, as produced by networkx."""
        return nx.topological_sort(self._graph)

    def get_node(self, idx):
        """Return the attribute dict of node ``idx``."""
        return self.get_graph().nodes[idx]

    def finish(self, zeta1=0.1, zeta2=0.1):
        """Recompute every layer mass and refresh distances to all graphs.

        Args:
            zeta1: each ip/op node gets ``zeta1 * (sum of the other masses)``.
            zeta2: the decision nodes share ``zeta2 * (sum of the other masses)``
                equally.
        """
        graph = self._graph
        ipop = []
        dl = []
        pl_lm = 0  # total mass of all nodes that are neither ip/op nor decision
        for node in self.get_nodes():
            attrs = graph.nodes[node]
            if attrs['type'] in self.iop_layers:
                ipop.append(node)
            elif attrs['type'] in self.decision_layers:
                dl.append(node)
            else:
                total_filters = sum(graph.nodes[p]['num_of_filters']
                                    for p in graph.predecessors(node))
                k = 0.1 if attrs['type'] == LAYERS.fc else 1
                attrs['layer_mass'] = k * total_filters * attrs['num_of_filters']
                pl_lm += attrs['layer_mass']
        self.total_lm = pl_lm
        for node in ipop:
            graph.nodes[node]['layer_mass'] = zeta1 * pl_lm
            self.total_lm += graph.nodes[node]['layer_mass']
        for node in dl:
            graph.nodes[node]['layer_mass'] = zeta2 * pl_lm / len(dl)
            self.total_lm += graph.nodes[node]['layer_mass']
        # update distance
        for other in layer_graph_table:
            get_distance(other, self, update=True)

    def _bypass_and_remove(self, node):
        """Reconnect the neighbours of ``node`` around it, then delete the node."""
        graph = self._graph
        parents = list(graph.predecessors(node))
        children = list(graph.successors(node))
        if children:
            for parent in parents:
                # ``node`` is this parent's only child: hand it node's first child.
                if graph.out_degree(parent) == 1:
                    self.add_edge(parent, children[0])
        if parents:
            for child in children:
                # ``node`` is this child's only parent: hand it node's first parent.
                if graph.in_degree(child) == 1:
                    self.add_edge(parents[0], child)
        graph.remove_node(node)
        self.layer_count -= 1

    def remove_a_pool(self, node):
        """Remove the designated node, bypassing it where a neighbour would be cut off."""
        print('removing...')
        self._bypass_and_remove(node)

    def remove_pool(self, node):
        """Walk upwards from ``node`` and remove the first pool layer on each branch.

        The walk stops at a pool layer (which is removed) or at node id 0.
        """
        # An earlier branch of the same walk may already have deleted this node.
        if node not in self._graph:
            return
        if self._graph.nodes[node]['type'] in self.pool_layers:
            self.remove_a_pool(node)
        elif node != 0:
            for parent in list(self._graph.predecessors(node)):
                self.remove_pool(parent)

    def _counts_as_pool(self, node):
        """Return True if ``node`` is a pool layer or a conv layer with stride 2."""
        attrs = self._graph.nodes[node]
        if attrs['type'] in self.pool_layers:
            return True
        return attrs['type'] in self.conv_layers and attrs['stride'] == 2

    def update_pool(self, after_add):
        """Equalise the pooling count of all paths that merge at a node.

        Args:
            after_add: True to insert a random pool layer on every path with
                fewer pooling steps than the maximum; False to remove pool
                layers from paths with more than the minimum.
        """
        topo_nodes = list(self.get_nodes())
        if not topo_nodes:
            return
        # topo_cnt[n] is the pooling count recorded for node n during this pass.
        topo_cnt = np.zeros(max(topo_nodes) + 1)
        for update_node in topo_nodes[1:]:  # skip the first node
            parents = list(self._graph.predecessors(update_node))
            if not parents:
                continue
            pool_cnt = np.array([topo_cnt[p] + self._counts_as_pool(p) for p in parents])
            changed = False
            if after_add:
                target = pool_cnt.max()
                for parent, count in zip(parents, pool_cnt):
                    if count < target:
                        # add ONE pool on this path: parent -> pool -> update_node
                        new_pool = self.add_node(random.choice(self.pool_layers))
                        self.add_edge(parent, new_pool)
                        self.add_edge(new_pool, update_node)
                        self.remove_edge(parent, update_node)
                        changed = True
            else:
                target = pool_cnt.min()
                for parent, count in zip(parents, pool_cnt):
                    if count <= target or parent not in self._graph:
                        continue
                    before = self.layer_count
                    self.remove_pool(parent)
                    # remove_pool finds nothing when the excess comes from a
                    # stride-2 conv; restarting on an unchanged graph would
                    # recurse forever.
                    if self.layer_count != before:
                        changed = True
            topo_cnt[update_node] = target
            if changed:
                # The graph changed, so the recorded counts are stale: start over.
                return self.update_pool(after_add)

    def get_num_layers(self):
        """Return the current number of layers."""
        return self.layer_count

    def get_total_mass(self):
        """Return the total layer mass computed by the last ``finish`` call."""
        return self.total_lm

    def mut_dup_path(self):
        """Duplicate a random downstream path and rejoin it with the original.

        Starting at a random processing node, up to ``stop_count`` nodes along
        a random chain of children are copied; the copy hangs off the start
        node and rejoins at the first successor of the last copied original.
        Does nothing when the layer limit is reached or no start node exists.
        """
        # NOTE(review): the original author reported occasional cycles after
        # this mutation without a known cause. Only the start node's
        # out-degree is bounded; the rejoin node's in-degree is not checked.
        if self.layer_count >= LAYERS_UPPER_BOUND:
            return
        max_steps = min(self.layer_count - 1, LAYERS_UPPER_BOUND - self.layer_count)
        if max_steps < 1:
            return
        graph = self._graph
        candidates = [n for n in graph.nodes
                      if graph.nodes[n]['type'] in self.process_layers
                      and graph.out_degree(n) < DEGREE_UPPER_BOUND]
        if not candidates:
            return
        stop_count = random.randint(1, max_steps)
        start = random.choice(candidates)
        head = start      # walks the original path
        new_head = start  # walks the copied path
        for _ in range(stop_count):
            # reach the end
            if graph.nodes[head]['type'] == LAYERS.fc:
                break
            children = list(graph.successors(head))
            if not children:
                break
            picked = random.choice(children)
            # A node without successors offers no place to rejoin; stop before it.
            if graph.out_degree(picked) == 0:
                break
            attrs = graph.nodes[picked]
            clone = self.add_node(attrs['type'], attrs['num_of_filters'], attrs['stride'])
            self.add_edge(new_head, clone)
            new_head = clone
            head = picked
        if new_head == start:
            return  # nothing was copied
        # converge
        self.add_edge(new_head, next(graph.successors(head)))

    def mut_remove_layer(self):
        """Remove one random conv or pool layer (at most 10 random picks).

        After removing a pool layer, ``update_pool(after_add=False)`` is run.
        """
        nodes = list(self.get_nodes())
        if not nodes:
            return
        removable = [*self.pool_layers, *self.conv_layers]
        node = None
        for _ in range(10):
            candidate = random.choice(nodes)
            if self._graph.nodes[candidate]['type'] in removable:
                node = candidate
                break
        if node is None:
            return
        is_pool = self._graph.nodes[node]['type'] in self.pool_layers
        print('removing: ', self._graph.nodes[node])
        self._bypass_and_remove(node)
        # remove pool requires update for other paths
        if is_pool:
            self.update_pool(after_add=False)

    def processing_nodes(self):
        """Return the ids of all nodes whose type is in ``process_layers``."""
        return [node_idx for node_idx in self.get_graph().nodes
                if self.get_node(node_idx)['type'] in self.process_layers]

    def mut_alt_single(self, portion):
        """Scale the filter count of one random processing node by ``1 + portion``.

        The result is truncated to int and never drops below 1. Does nothing
        when the graph has no processing node.
        """
        candidates = self.processing_nodes()
        if not candidates:
            return
        attrs = self.get_node(random.choice(candidates))
        attrs['num_of_filters'] = max(1, int(attrs['num_of_filters'] * (1 + portion)))

    def mut_dec_single(self):
        """Decrease one random processing node's filter count by 1/8."""
        self.mut_alt_single(-1/8)

    def mut_inc_single(self):
        """Increase one random processing node's filter count by 1/8."""
        self.mut_alt_single(1/8)

    def mut_alt_en_masse(self, portion):
        """Scale the filter count of several random processing nodes by ``1 + portion``.

        With n processing nodes, n // 8 are changed when n > 8, n // 4 when
        n > 4 and n // 2 otherwise. Results are truncated to int and never
        drop below 1.
        """
        candidates = self.processing_nodes()
        num_of_nodes = len(candidates)
        rate = 1 + portion
        if num_of_nodes > 8:
            num_of_mut = num_of_nodes // 8
        elif num_of_nodes > 4:
            num_of_mut = num_of_nodes // 4
        else:
            num_of_mut = num_of_nodes // 2
        for node_idx in random.sample(candidates, num_of_mut):
            attrs = self.get_node(node_idx)
            attrs['num_of_filters'] = max(1, int(attrs['num_of_filters'] * rate))

    def mut_dec_en_masse(self):
        """Decrease several random processing nodes' filter counts by 1/8."""
        self.mut_alt_en_masse(-1/8)

    def mut_inc_en_masse(self):
        """Increase several random processing nodes' filter counts by 1/8."""
        self.mut_alt_en_masse(1/8)

    def show_graph(self, title=None, node_size=1000):
        """Draw the graph in a new matplotlib figure (``plt.show`` is not called).

        Each node is labelled ``<id>*<type> <stride> <num_of_filters>``.
        """
        plt.figure()
        if title is not None:
            plt.title(title)
        labels = {
            n: str(n) + '*' + str(t['type']) + ' ' + str(t['stride']) + ' ' + str(t['num_of_filters'])
            for n, t in self._graph.nodes(data=True)
        }
        nx.draw_kamada_kawai(self._graph, labels=labels, node_size=node_size)

    def mut_skip(self):
        """Add a skip edge between two random processing nodes, then rebalance pools.

        The edge runs from the node that comes first in topological order to
        the later one. Gives up after 20 attempts.
        """
        graph = self._graph
        order = list(self.get_nodes())
        rank = {n: i for i, n in enumerate(order)}
        candidates = [n for n in order if graph.nodes[n]['type'] in self.process_layers]
        if len(candidates) < 2:
            return
        for _ in range(20):
            src, dst = sorted(random.sample(candidates, 2), key=rank.__getitem__)
            # Check the bounds on the final orientation of the edge.
            if graph.out_degree(src) >= DEGREE_UPPER_BOUND:
                continue
            if graph.in_degree(dst) >= DEGREE_UPPER_BOUND:
                continue
            if graph.has_edge(src, dst):
                continue
            break
        else:
            return
        self.add_edge(src, dst)
        self.update_pool(after_add=True)

    def mut_swap_label(self):
        """Replace the type, filter count and stride of one random processing node.

        Both the node and the new type are drawn from ``process_layers``, so
        softmax is never picked and never assigned.
        """
        candidates = self.processing_nodes()
        if not candidates:
            return
        node = random.choice(candidates)
        new_type = random.choice(self.process_layers)
        if new_type in self.conv_layers or new_type == LAYERS.fc:
            num_of_filters = int(np.random.choice([64, 128, 256, 512], 1, p=[0.4, 0.3, 0.2, 0.1])[0])
            stride = 2 if new_type == LAYERS.fc else random.choice([1, 2])
        else:
            num_of_filters = 1
            stride = 2
        attrs = self._graph.nodes[node]
        attrs['stride'] = stride
        attrs['num_of_filters'] = num_of_filters
        attrs['type'] = new_type

    def mut_wedge_layer(self):
        """Insert a new random processing layer in the middle of a random edge.

        Only edges with at least one processing endpoint are eligible. Does
        nothing when the layer limit is reached or no edge is eligible.
        """
        if self.layer_count >= LAYERS_UPPER_BOUND:
            return
        graph = self._graph

        def is_process(n):
            return graph.nodes[n]['type'] in self.process_layers

        edges = [(u, v) for u, v in graph.edges() if is_process(u) or is_process(v)]
        if not edges:
            return
        src, dst = random.choice(edges)
        new_type = random.choice(self.process_layers)
        if new_type in self.conv_layers or new_type == LAYERS.fc:
            # Start from the mean of both endpoints, rounded up to an even number.
            num_of_filters = int((graph.nodes[src]['num_of_filters']
                                  + graph.nodes[dst]['num_of_filters']) / 2)
            if num_of_filters % 2:
                num_of_filters += 1
            if num_of_filters < 16:
                num_of_filters = int(np.random.choice([64, 128, 256, 512], 1, p=[0.4, 0.3, 0.2, 0.1])[0])
            stride = 2 if new_type == LAYERS.fc else random.choice([1, 2])
        else:
            num_of_filters = 1
            stride = 2
        graph.remove_edge(src, dst)
        new_node = self.add_node(new_type, num_of_filters, stride)
        self.add_edge(src, new_node)
        self.add_edge(new_node, dst)

    def mut_step(self):
        """Apply one uniformly chosen mutation operator and print its name."""
        mut_op = random.choice([self.mut_dup_path, self.mut_remove_layer,
                                self.mut_dec_single, self.mut_inc_single,
                                self.mut_swap_label, self.mut_wedge_layer,
                                self.mut_inc_en_masse, self.mut_dec_en_masse,
                                self.mut_skip])
        print(mut_op.__name__)
        mut_op()

    def mutate(self):
        """Apply 1 to 5 random mutation steps, drawing the graph after each one."""
        num_of_steps = np.random.choice([1, 2, 3, 4, 5], 1, p=[0.5, 0.25, 0.125, 0.075, 0.05])[0]
        print('num_step: ', num_of_steps)
        for i in range(num_of_steps):
            self.mut_step()
            self.show_graph('mutate step ' + str(i))
        # NOTE(review): the original indentation was lost, so finish() may have
        # run after every step instead of once here - confirm.
        self.finish()
        plt.close('all')

    def copy(self):
        """Return a registered deep copy of this graph with a fresh id."""
        global layer_graph_count, layer_graph_table
        clone = copy.deepcopy(self)
        layer_graph_table.append(clone)
        clone.id = layer_graph_count
        clone.finish()
        layer_graph_count += 1
        return clone

    def elim_LAYERS(self):
        """Run before save: replace LAYERS members by their int values.

        Also deletes the layer-category lists; ``rec_LAYERS`` restores both.
        """
        for node in self._graph.nodes:
            self._graph.nodes[node]['type'] = int(self._graph.nodes[node]['type'].value)
        del self.conv_layers, self.pool_layers, self.process_layers, \
            self.decision_layers, self.iop_layers

    def rec_LAYERS(self):
        """Run after read: turn int node types back into LAYERS members.

        Also rebuilds the layer-category lists and calls ``finish``.
        """
        for node in self._graph.nodes:
            self._graph.nodes[node]['type'] = LAYERS(self._graph.nodes[node]['type'])
        self._init_layers()
        self.finish()

    def _init_layers(self):
        """Build the lists that group LAYERS members by role."""
        self.conv_layers = [LAYERS.conv3, LAYERS.conv5, LAYERS.conv7]
        self.pool_layers = [LAYERS.maxpool, LAYERS.avgpool]
        self.process_layers = [*self.conv_layers, *self.pool_layers, LAYERS.fc, LAYERS.batchnorm, LAYERS.resnet]
        self.decision_layers = [LAYERS.softmax]
        self.iop_layers = [LAYERS.ip, LAYERS.op]
def write(lg_object, path='graph'):
    """Pickle ``lg_object`` to ``path``.

    ``elim_LAYERS`` is called on the object first and is not undone here, so
    the object stays in that state after the call.

    Args:
        lg_object: object providing ``elim_LAYERS()``.
        path: destination file, opened in binary write mode.
    """
    lg_object.elim_LAYERS()
    # The context manager closes the file even if pickling fails.
    with open(path, 'wb') as handle:
        pickle.dump(lg_object, handle)
def read(path='graph'):
    """Load a pickled object from ``path`` and call ``rec_LAYERS`` on it.

    Args:
        path: source file, opened in binary read mode.

    Returns:
        The unpickled object after ``rec_LAYERS()`` has run.
    """
    # SECURITY: pickle.load can execute arbitrary code; only read trusted files.
    with open(path, 'rb') as handle:
        lg_object = pickle.load(handle)
    lg_object.rec_LAYERS()
    return lg_object
def clear_layers():
    """Reset the global graph registry and id counter, then clear the distances."""
    global layer_graph_count, layer_graph_table
    # Rebind to a fresh list instead of clearing the existing one in place.
    layer_graph_table, layer_graph_count = [], 0
    clear_distance()