-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_train.py
236 lines (185 loc) · 12.1 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
""" Tool for training the models"""
import argparse
import os
import pickle
import sys

from models.AAD import AAD
from models.StyleAAD import StyleAAD
from models.DCGAN import DCGAN
from models.stylegan import StyleGAN_G
from models.stylegan import StyleGAN_D
from utils.PreProcessing import load_test_data
#----------------------------------------------------------------------------
def train_gan(latentDim, epochs, dataDir, resultsDir):
    """Train a DCGAN from scratch and pickle the trained object.

    Args:
        latentDim: Latent space dimension of the GAN's generator.
        epochs: Number of training epochs, forwarded to ``DCGAN.train``.
        dataDir: Path of the training dataset, forwarded to
            ``DCGAN.preprocessing``.
        resultsDir: Directory handed to ``DCGAN`` for its results; the
            pickled model is also written there as ``gan.pkl``.
    """
    gan = DCGAN(latentDim, resultsDir)
    gan.preprocessing(dataDir)
    gan.train(n_epochs=epochs)

    # Save the object state. The context manager guarantees the handle is
    # flushed and closed even if pickling fails part-way (the previous bare
    # ``open()`` was never closed). ``os.path.join`` also avoids writing to
    # the filesystem root ("/gan.pkl") when resultsDir is the CLI default ''.
    with open(os.path.join(resultsDir, "gan.pkl"), "wb") as fh:
        pickle.dump(gan, fh)
def train_anomaly_detector(generatorDir, discriminatorDir, latentDim, reconstructionError, dicriminatorError, epochs, dataDir, resultsDir):
    """Train the Adversarial Anomaly Detector (AAD) on a pretrained DCGAN.

    The DCGAN generator/discriminator weights are restored from disk and
    passed, together with the two error weights, to ``AAD``, which is then
    preprocessed on ``dataDir`` and trained.

    Args:
        generatorDir: Path of the DCGAN generator weights file.
        discriminatorDir: Path of the DCGAN discriminator weights file.
        latentDim: Latent space dimension of the GAN's generator.
        reconstructionError: Reconstruction error weight.
        dicriminatorError: Discriminator error weight (the misspelled name
            is kept because the CLI passes it as a keyword argument).
        epochs: Number of training epochs for the AAD.
        dataDir: Path of the training dataset.
        resultsDir: Path where the results will be stored.
    """
    pretrained = DCGAN(latentDim, resultsDir)
    pretrained.load({"generator": generatorDir, "discriminator": discriminatorDir})

    detector = AAD(
        pretrained.get_generator(),
        pretrained.get_discriminator(),
        resultsDir,
        latentDim,
        reconstructionError,
        dicriminatorError,
    )
    detector.preprocessing(dataDir)
    detector.train(n_epochs=epochs)
    # NOTE(review): unlike train_gan, the trained detector is not pickled
    # here; a pickle.dump to "<resultsDir>/add.pkl" had been commented out.
def train_style_anomaly_detector(generatorDir, discriminatorDir, latentDim, reconstructionError, dicriminatorError, epochs, dataDir, resultsDir):
    """Train the StyleGAN-based Adversarial Anomaly Detector (StyleAAD).

    Fresh StyleGAN generator and discriminator networks are built, their
    weights are loaded from disk, and both are handed to ``StyleAAD``, which
    is then preprocessed on ``dataDir`` and trained.

    Args:
        generatorDir: Path of the StyleGAN generator weights file.
        discriminatorDir: Path of the StyleGAN discriminator weights file.
        latentDim: Latent space dimension of the GAN's generator.
        reconstructionError: Reconstruction error weight.
        dicriminatorError: Discriminator error weight (misspelled name kept
            for compatibility with the CLI keyword).
        epochs: Number of training epochs for the detector.
        dataDir: Path of the training dataset.
        resultsDir: Path where the results will be stored.
    """
    generator = StyleGAN_G()
    generator.load_weights(generatorDir)
    discriminator = StyleGAN_D()
    discriminator.load_weights(discriminatorDir)

    detector = StyleAAD(
        generator,
        discriminator,
        resultsDir,
        latentDim,
        reconstructionError,
        dicriminatorError,
    )
    detector.preprocessing(dataDir)
    detector.train(n_epochs=epochs)
def train_anomaly_grid_search(generatorDir, discriminatorDir, anomalyDetectorDir, latentDim, normlaDataDir, anomalyDataDir, resultsDir):
    """Fit the AAD classifier with ``AAD.train_svm_with_grid_search`` (DCGAN backbone).

    The pretrained DCGAN and AAD weights are restored, the healthy and
    anomalous datasets are loaded with ``load_test_data``, and both sets
    are passed to the detector's grid-search SVM training.

    Args:
        generatorDir: Path of the GAN's generator weights.
        discriminatorDir: Path of the GAN's discriminator weights.
        anomalyDetectorDir: Path of the AnomalyDetector weights.
        latentDim: Latent dimension of the GAN.
        normlaDataDir: Path of the dataset with healthy samples (misspelled
            name kept because the CLI passes it as a keyword argument).
        anomalyDataDir: Path of the dataset with anomalous samples.
        resultsDir: Path where the results will be stored.
    """
    backbone = DCGAN(latentDim, resultsDir)
    backbone.load(dict(generator=generatorDir, discriminator=discriminatorDir))

    # Only the AAD weights are restored; the "svc" and "scaler" paths are
    # left empty — presumably because this command is what fits them
    # (confirm against AAD.load).
    saved_models = dict(aad=anomalyDetectorDir, svc='', scaler='')
    healthy = load_test_data(normlaDataDir)
    anomalous = load_test_data(anomalyDataDir)

    detector = AAD(backbone.get_generator(), backbone.get_discriminator(), resultsDir, latentDim)
    detector.load(saved_models)
    detector.train_svm_with_grid_search(healthy, anomalous)
def train_style_anomaly_grid_search(generatorDir, discriminatorDir, anomalyDetectorDir, latentDim, normlaDataDir, anomalyDataDir, resultsDir):
    """Fit the StyleAAD classifier with ``StyleAAD.train_svm_with_grid_search``.

    The StyleGAN networks and the trained StyleAAD weights are restored,
    both datasets are loaded with ``load_test_data``, and the detector's
    grid-search SVM training is run on them.

    Args:
        generatorDir: Path of the GAN's generator weights.
        discriminatorDir: Path of the GAN's discriminator weights.
        anomalyDetectorDir: Path of the AnomalyDetector weights.
        latentDim: Latent dimension of the GAN.
        normlaDataDir: Path of the dataset with healthy samples (misspelled
            name kept for compatibility with the CLI keyword).
        anomalyDataDir: Path of the dataset with anomalous samples.
        resultsDir: Path where the results will be stored.
    """
    generator = StyleGAN_G()
    generator.load_weights(generatorDir)
    discriminator = StyleGAN_D()
    discriminator.load_weights(discriminatorDir)

    saved_models = dict(aad=anomalyDetectorDir, svc='', scaler='')
    # NOTE(review): forwarded verbatim to load_test_data; presumably the image
    # size and the tensor layout expected by the StyleGAN — confirm.
    loader_args = (64, 'channels_first')
    healthy = load_test_data(normlaDataDir, *loader_args)
    anomalous = load_test_data(anomalyDataDir, *loader_args)

    detector = StyleAAD(generator, discriminator, resultsDir, latentDim)
    detector.load(saved_models)
    detector.train_svm_with_grid_search(healthy, anomalous)
def train_anomaly_classifier(generatorDir, discriminatorDir, anomalyDetectorDir, latentDim, C, gamma, kernel, degree, normlaDataDir, anomalyDataDir, resultsDir):
    """Fit the AAD classifier with explicit SVM parameters (DCGAN backbone).

    Args:
        generatorDir: Path of the GAN's generator weights.
        discriminatorDir: Path of the GAN's discriminator weights.
        anomalyDetectorDir: Path of the AnomalyDetector weights.
        latentDim: Latent dimension of the GAN.
        C: SVM C parameter.
        gamma: SVM gamma parameter.
        kernel: SVM kernel parameter.
        degree: SVM degree parameter.
        normlaDataDir: Path of the dataset with healthy samples (misspelled
            name kept because the CLI passes it as a keyword argument).
        anomalyDataDir: Path of the dataset with anomalous samples.
        resultsDir: Path where the results will be stored.
    """
    backbone = DCGAN(latentDim, resultsDir)
    backbone.load(dict(generator=generatorDir, discriminator=discriminatorDir))

    saved_models = dict(aad=anomalyDetectorDir, svc='', scaler='')
    healthy = load_test_data(normlaDataDir)
    anomalous = load_test_data(anomalyDataDir)

    detector = AAD(backbone.get_generator(), backbone.get_discriminator(), resultsDir, latentDim)
    detector.load(saved_models)
    # NOTE(review): degree is passed before kernel — the reverse of this
    # function's own parameter order; presumably this matches the signature
    # of AAD.train_svm — confirm.
    detector.train_svm(C, gamma, degree, kernel, healthy, anomalous)
def train_style_anomaly_classifier(generatorDir, discriminatorDir, anomalyDetectorDir, latentDim, C, gamma, kernel, degree, normlaDataDir, anomalyDataDir, resultsDir):
    """Fit the StyleAAD classifier with explicit SVM parameters.

    Args:
        generatorDir: Path of the GAN's generator weights.
        discriminatorDir: Path of the GAN's discriminator weights.
        anomalyDetectorDir: Path of the AnomalyDetector weights.
        latentDim: Latent dimension of the GAN.
        C: SVM C parameter.
        gamma: SVM gamma parameter.
        kernel: SVM kernel parameter.
        degree: SVM degree parameter.
        normlaDataDir: Path of the dataset with healthy samples (misspelled
            name kept for compatibility with the CLI keyword).
        anomalyDataDir: Path of the dataset with anomalous samples.
        resultsDir: Path where the results will be stored.
    """
    generator = StyleGAN_G()
    generator.load_weights(generatorDir)
    discriminator = StyleGAN_D()
    discriminator.load_weights(discriminatorDir)

    saved_models = dict(aad=anomalyDetectorDir, svc='', scaler='')
    # NOTE(review): presumably image size and tensor layout for the
    # StyleGAN inputs — confirm against utils.PreProcessing.load_test_data.
    loader_args = (64, 'channels_first')
    healthy = load_test_data(normlaDataDir, *loader_args)
    anomalous = load_test_data(anomalyDataDir, *loader_args)

    detector = StyleAAD(generator, discriminator, resultsDir, latentDim)
    detector.load(saved_models)
    # NOTE(review): degree precedes kernel here, unlike this function's
    # parameter order; presumably matches StyleAAD.train_svm — confirm.
    detector.train_svm(C, gamma, degree, kernel, healthy, anomalous)
#----------------------------------------------------------------------------
def cmdline(argv):
    """Parse *argv* and run the requested training sub-command.

    Every sub-command is named after a module-level function of this file.
    The parsed options (minus ``command``) are passed to that function as
    keyword arguments, so each option's ``dest`` must match one of its
    parameter names exactly, including the historical spellings
    ``dicriminatorError`` and ``normlaDataDir``. The correctly spelled flags
    ``--discriminatorError`` and ``--normalDataDir`` are accepted as aliases.

    Args:
        argv: Full argument vector, program name first (e.g. ``sys.argv``).
            When only the program name is given, the top-level help is shown.

    Raises:
        SystemExit: Raised by ``argparse`` after printing help (``-h`` or no
            sub-command) or when the arguments are invalid.
    """
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog=prog,
        description='Tool for training the models in the Adversarial Anomaly Detector.',
        epilog='Type "%s <command> -h" for more information.' % prog)

    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    def add_command(cmd, desc, example=None):
        """Register sub-command *cmd*; its help lists every option's default."""
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(
            cmd, description=desc, help=desc, epilog=epilog,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    def add_detector_args(p, latent_dim):
        """Add the options of the two ``train_*anomaly_detector`` commands."""
        p.add_argument('--generatorDir', help='Path of the generator h5 file', default='')
        p.add_argument('--discriminatorDir', help='Path of the discriminator h5 file', default='')
        p.add_argument('--latentDim', help='Latent space dimension of the GAN\'s generator', type=int, default=latent_dim)
        p.add_argument('--reconstructionError', help='Reconstruction error weight', type=float, default=0.90)
        # dest keeps the historical misspelling: it must match the parameter name.
        p.add_argument('--dicriminatorError', '--discriminatorError', dest='dicriminatorError',
                       help='Discriminator error weight', type=float, default=0.10)
        p.add_argument('--epochs', help='Number of epochs for the training', type=int, default=20)
        p.add_argument('--dataDir', help='Path of the dataset', default='')
        p.add_argument('--resultsDir', help='Path where the results will be stored', default='')

    def add_classifier_args(p, latent_dim, svm_params):
        """Add the options shared by the grid-search and classifier commands.

        When *svm_params* is true, the explicit SVM parameters are added too.
        """
        p.add_argument('--generatorDir', help='Path of the GAN\'s generator weights', default='')
        p.add_argument('--discriminatorDir', help='Path of the GAN\'s discriminator weights', default='')
        p.add_argument('--anomalyDetectorDir', help='Path of the AnomalyDetector weights', default='')
        p.add_argument('--latentDim', help='Latent dimension of the GAN', type=int, default=latent_dim)
        if svm_params:
            # NOTE(review): -1 is not a usable SVM C value by itself; presumably
            # the detector's train_svm treats it as a sentinel — confirm.
            p.add_argument('--C', help='SVM C parameter', type=float, default=-1)
            p.add_argument('--gamma', help='SVM gamma parameter', type=float, default=0.001)
            p.add_argument('--kernel', help='SVM kernel parameter', default='rbf')
            p.add_argument('--degree', help='SVM degree parameter', type=int, default=3)
        # dest keeps the historical misspelling: it must match the parameter name.
        p.add_argument('--normlaDataDir', '--normalDataDir', dest='normlaDataDir',
                       help='Path of the dataset with healthy samples', default='')
        p.add_argument('--anomalyDataDir', help='Path of the dataset with anomalous samples', default='')
        p.add_argument('--resultsDir', help='Path where the results will be stored', default='')

    p = add_command('train_gan', 'Training of the DCGAN model.')
    p.add_argument('--latentDim', help='Latent space dimension of the GAN\'s generator', type=int, default=100)
    p.add_argument('--epochs', help='Number of epochs for the training', type=int, default=20)
    p.add_argument('--dataDir', help='Path of the dataset', default='')
    p.add_argument('--resultsDir', help='Path where the results will be stored', default='')

    add_detector_args(
        add_command('train_anomaly_detector',
                    'Training of the Adversarial Anomaly Detector model.'),
        latent_dim=100)
    add_detector_args(
        add_command('train_style_anomaly_detector',
                    'Training of the Adversarial Anomaly Detector model with a StyleGAN.'),
        latent_dim=512)

    # The StyleGAN-based commands share train_style_anomaly_detector's latent
    # size (512) so a detector trained with defaults can be reused with
    # defaults; these commands previously defaulted to 100.
    add_classifier_args(
        add_command('train_anomaly_grid_search',
                    'Grid search of the SVM parameters of the Adversarial Anomaly Detector classifier.'),
        latent_dim=100, svm_params=False)
    add_classifier_args(
        add_command('train_style_anomaly_grid_search',
                    'Grid search of the SVM parameters of the Adversarial Anomaly Detector classifier with a StyleGAN.'),
        latent_dim=512, svm_params=False)
    add_classifier_args(
        add_command('train_anomaly_classifier',
                    'Training of the Adversarial Anomaly Detector classifier.'),
        latent_dim=100, svm_params=True)
    add_classifier_args(
        add_command('train_style_anomaly_classifier',
                    'Training of the Adversarial Anomaly Detector classifier with a StyleGAN.'),
        latent_dim=512, svm_params=True)

    args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
    # Sub-command names are restricted by argparse's choices, so this lookup
    # can only resolve to one of the train_* functions registered above.
    func = globals()[args.command]
    del args.command
    func(**vars(args))
#----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: hand the raw process arguments (program name
    # first) to the sub-command dispatcher.
    cmdline(argv=sys.argv)
#----------------------------------------------------------------------------