toolkits.py
import os
import numpy as np

def initialize_GPU(args):
    # Initialize a GPU session; this uses the TensorFlow 1.x API
    # (under TF 2.x the same calls live in tf.compat.v1).
    import tensorflow as tf
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    return session

def get_chunks(l, n):
    # Yield successive chunks of n items from list l.
    for i in range(0, len(l), n):
        yield l[i:i + n]

def debug_generator(generator):
    # Dump one batch of images from a data generator to ../sample/ for inspection.
    import cv2
    G = next(generator)
    for i, img in enumerate(G[0]):
        path = '../sample/{}.jpg'.format(i)
        # Shift values back into [0, 255] and reverse the channel order for OpenCV.
        img = np.asarray(img[:, :, ::-1] + 128.0, dtype='uint8')
        cv2.imwrite(path, img)

# Set up a multiprocessing pool whose workers ignore SIGINT,
# so Ctrl-C is handled by the parent process only.
def set_mp(processes=8):
    import multiprocessing as mp

    def init_worker():
        import signal
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    global pool
    try:
        # Terminate a previously created pool, if any.
        pool.terminate()
    except (NameError, AttributeError):
        pass

    if processes:
        pool = mp.Pool(processes=processes, initializer=init_worker)
    else:
        pool = None
    return pool

# vggface2 dataset
def get_vggface2_imglist(args):
    def get_datalist(s):
        # Each line of the meta file: "<relative/image/path> <label>"
        with open(s, 'r') as f:
            datalist = f.readlines()
        imglist = []
        labellist = []
        for line in datalist:
            linesplit = line.split(' ')
            imglist.append(linesplit[0])
            labellist.append(int(linesplit[1][:-1]))
        return imglist, labellist

    print('==> calculating image lists...')
    # Prepare training data.
    imgs_list_trn, lbs_list_trn = get_datalist(args.trn_meta)
    imgs_list_trn = np.array([os.path.join(args.data_path, i) for i in imgs_list_trn])
    lbs_list_trn = np.array(lbs_list_trn)

    # Prepare validation data.
    imgs_list_val, lbs_list_val = get_datalist(args.val_meta)
    imgs_list_val = np.array([os.path.join(args.data_path, i) for i in imgs_list_val])
    lbs_list_val = np.array(lbs_list_val)

    return imgs_list_trn, lbs_list_trn, imgs_list_val, lbs_list_val

def get_imagenet_imglist(args, trn_meta_path='', val_meta_path=''):
    # Meta files list "<path> <label>" pairs; only the last four path
    # components are kept and re-rooted at args.data_path.
    with open(trn_meta_path) as f:
        strings = f.readlines()
    trn_list = np.array([os.path.join(args.data_path, '/'.join(string.split()[0].split(os.sep)[-4:]))
                         for string in strings])
    trn_lb = np.array([int(string.split()[1]) for string in strings])

    with open(val_meta_path) as f:
        strings = f.readlines()
    val_list = np.array([os.path.join(args.data_path, '/'.join(string.split()[0].split(os.sep)[-4:]))
                         for string in strings])
    val_lb = np.array([int(string.split()[1]) for string in strings])

    return trn_list, trn_lb, val_list, val_lb

def get_voxceleb2_datalist(args, path):
    # Meta file lines: "<relative/audio/path> <speaker label>"
    with open(path) as f:
        strings = f.readlines()
    audiolist = np.array([os.path.join(args.data_path, string.split()[0]) for string in strings])
    labellist = np.array([int(string.split()[1]) for string in strings])
    return audiolist, labellist

def calculate_eer(y, y_score):
    # y: ground-truth binary labels, y_score: prediction scores.
    # The EER is the point on the ROC curve where FPR == 1 - TPR.
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d
    from sklearn.metrics import roc_curve
    fpr, tpr, thresholds = roc_curve(y, y_score, pos_label=1)
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    thresh = interp1d(fpr, thresholds)(eer)
    return eer, thresh

def sync_model(src_model, tgt_model):
    # Copy weights from src_model into the identically named layers of tgt_model.
    print('==> synchronizing the model weights.')
    params = {}
    for layer in src_model.layers:
        params[layer.name] = layer.get_weights()
    for layer in tgt_model.layers:
        if len(layer.get_weights()) > 0:
            layer.set_weights(params[layer.name])
    return tgt_model
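

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, self-contained example of the pure-Python helper above; the
# dataset and TensorFlow utilities need metadata files and a GPU/TF 1.x
# environment, so they are not exercised here.
if __name__ == '__main__':
    # Split a toy list into chunks of three items each.
    for chunk in get_chunks(list(range(10)), 3):
        print(chunk)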