# train.py
# Forked from rohban-lab/Knowledge_Distillation_AD.
from test import *
from utils.utils import *
from dataloader import *
from pathlib import Path
from torch.autograd import Variable
import pickle
from test_functions import detection_test
from loss_functions import *
# Command-line interface: a single option naming the training configuration file.
parser = ArgumentParser()
parser.add_argument(
    '--config',
    default='configs/config.yaml',
    type=str,
    help="training configuration",
)
def train(config):
    """Train ``model`` so that its output matches ``vgg``'s output.

    Both networks come from ``get_networks(config)``; only ``model``'s
    parameters are handed to the optimizer, so ``vgg`` is never updated here.

    Args:
        config: Mapping read for the keys ``direction_loss_only``,
            ``normal_class``, ``learning_rate``, ``num_epochs``, ``lamda``,
            ``continue_train``, ``last_checkpoint``, ``experiment_name`` and
            ``dataset_name``. It is also passed on unchanged to
            ``load_data``, ``get_networks`` and ``detection_test``.

    Side effects:
        * Creates ``./outputs/<experiment_name>/<dataset_name>/checkpoints/``.
        * Prints the summed batch loss after every epoch.
        * Every 10th epoch, runs ``detection_test`` and records its result.
        * Every 50th epoch, writes the model state, the optimizer state and
          the pickled list of recorded ROC-AUC values to the checkpoint
          directory, with ``normal_class`` and the epoch in the file names.

    Returns:
        None.
    """
    direction_loss_only = config["direction_loss_only"]
    normal_class = config["normal_class"]
    learning_rate = float(config['learning_rate'])
    num_epochs = config["num_epochs"]
    lamda = config['lamda']
    continue_train = config['continue_train']
    last_checkpoint = config['last_checkpoint']

    checkpoint_path = "./outputs/{}/{}/checkpoints/".format(
        config['experiment_name'], config['dataset_name'])
    # Create the checkpoint directory (and any missing parents) up front.
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

    train_dataloader, test_dataloader = load_data(config)

    if continue_train:
        vgg, model = get_networks(config, load_checkpoint=True)
    else:
        vgg, model = get_networks(config)

    # Criterion and optimizer.
    if direction_loss_only:
        criterion = DirectionOnlyLoss()
    else:
        criterion = MseDirectionLoss(lamda)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    roc_aucs = []
    if continue_train:
        # Restore the optimizer state and the ROC-AUC history that were saved
        # alongside checkpoint ``last_checkpoint``.
        optimizer.load_state_dict(torch.load(
            '{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint)))
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only resume from checkpoint directories you trust.
        with open('{}Auc_{}_epoch_{}.pickle'.format(
                checkpoint_path, normal_class, last_checkpoint), 'rb') as f:
            roc_aucs = pickle.load(f)

    # The loop runs num_epochs + 1 times (epoch 0 .. num_epochs inclusive), so
    # the evaluation/checkpoint steps also fire on the final epoch index.
    # NOTE(review): the epoch counter starts at 0 even when continue_train is
    # set, so a resumed run writes to the same epoch-numbered file names as the
    # run it continues. Confirm whether resuming should start after
    # last_checkpoint instead.
    total_epochs = num_epochs + 1
    for epoch in range(total_epochs):
        model.train()
        epoch_loss = 0
        for data in train_dataloader:
            X = data[0]
            # Single-channel batches are tiled to three channels.
            if X.shape[1] == 1:
                X = X.repeat(1, 3, 1, 1)
            X = X.cuda()

            # Call the module itself (not .forward) so registered hooks run.
            output_pred = model(X)
            # vgg is not optimized here, so no autograd graph is needed for it.
            with torch.no_grad():
                output_real = vgg(X)

            total_loss = criterion(output_pred, output_real)
            epoch_loss += total_loss.item()

            # Clear the previous gradients, backpropagate, then update weights.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, total_epochs, epoch_loss))

        if epoch % 10 == 0:
            roc_auc = detection_test(model, vgg, test_dataloader, config)
            roc_aucs.append(roc_auc)
            print("RocAUC at epoch {}:".format(epoch), roc_auc)

        if epoch % 50 == 0:
            torch.save(model.state_dict(),
                       '{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
            torch.save(optimizer.state_dict(),
                       '{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
            with open('{}Auc_{}_epoch_{}.pickle'.format(checkpoint_path, normal_class, epoch),
                      'wb') as f:
                pickle.dump(roc_aucs, f)
def main():
    """Parse the command line, load the configuration it names, and train."""
    cli_args = parser.parse_args()
    train(get_config(cli_args.config))
# Start training only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()