-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
191 lines (162 loc) · 8.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import argparse
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
from BraTS import *
from networks.Unet import UNet
from utils import Loss,cal_dice,cosine_scheduler
def train_loop(model, optimizer, scheduler, criterion, train_loader, device, epoch):
    """Train ``model`` for one epoch and return batch-averaged loss and Dice.

    Before every optimisation step the learning rate is read from
    ``scheduler`` at the global iteration index
    ``len(train_loader) * epoch + step``; ``main`` builds that schedule with
    ``niter_per_ep=len(train_loader)`` so the indices line up.

    Args:
        model: segmentation network, called as ``model(images)``.
        optimizer: optimizer whose ``lr`` is overwritten on every step.
        scheduler: indexable sequence of per-iteration learning rates.
        criterion: loss callable, ``criterion(outputs, masks)``.
        train_loader: sized iterable of ``(images, masks)`` batches.
        device: device each batch is moved to.
        epoch: zero-based epoch index, used to offset into ``scheduler``.

    Returns:
        dict with keys ``'loss'``, ``'dice1'``, ``'dice2'``, ``'dice3'``: the
        per-batch values averaged over the epoch (every batch weighs the same).

    Raises:
        ValueError: if ``len(train_loader)`` is 0.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot run a training epoch")
    model.train()
    running_loss = 0.0
    dice1_train = dice2_train = dice3_train = 0.0
    pbar = tqdm(train_loader)
    for step, (images, masks) in enumerate(pbar):
        # The global iteration index across epochs selects this step's lr.
        lr = scheduler[num_batches * epoch + step]
        # Apply it to every parameter group, not just the first one.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # NOTE(review): presumably images are (B, 4, D, H, W) -- 4 input
        # channels to match UNet(in_channels=4) in main -- and masks are
        # (B, D, H, W) label maps; confirm against the BraTS dataset class.
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        # Dice is only a metric: detach so no autograd graph is kept for it.
        dice1, dice2, dice3 = cal_dice(outputs.detach(), masks)
        loss_value = loss.item()
        pbar.desc = "loss: {:.3f} ".format(loss_value)
        running_loss += loss_value
        dice1_train += dice1.item()
        dice2_train += dice2.item()
        dice3_train += dice3.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return {'loss': running_loss / num_batches,
            'dice1': dice1_train / num_batches,
            'dice2': dice2_train / num_batches,
            'dice3': dice3_train / num_batches}
def val_loop(model, criterion, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: segmentation network, switched to eval mode here.
        criterion: loss callable, ``criterion(outputs, masks)``.
        val_loader: sized iterable of ``(images, masks)`` batches.
        device: device each batch is moved to.

    Returns:
        dict with keys ``'loss'``, ``'dice1'``, ``'dice2'``, ``'dice3'``: the
        per-batch values averaged over the loader (every batch weighs the
        same, regardless of its size).

    Raises:
        ValueError: if ``len(val_loader)`` is 0.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; nothing to evaluate")
    model.eval()
    running_loss = 0.0
    dice1_val = dice2_val = dice3_val = 0.0
    with torch.no_grad():
        for images, masks in tqdm(val_loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            dice1, dice2, dice3 = cal_dice(outputs, masks)
            running_loss += loss.item()
            dice1_val += dice1.item()
            dice2_val += dice2.item()
            dice3_val += dice3.item()
    return {'loss': running_loss / num_batches,
            'dice1': dice1_val / num_batches,
            'dice2': dice2_val / num_batches,
            'dice3': dice3_val / num_batches}
def train(model, optimizer, scheduler, criterion, train_loader,
          val_loader, epochs, device, train_log, valid_loss_min=999.0,
          save_path=None, best_path='results/UNet.pth'):
    """Run ``epochs`` rounds of training + validation with logging and checkpoints.

    After each epoch the train/validation metrics are printed and appended to
    ``train_log``. If the validation loss beats ``valid_loss_min``, the model
    and optimizer state dicts are saved to ``best_path``. Otherwise they are
    saved to ``<save_path>/checkpoint<epoch>.pth`` (1-based epoch number).

    Args:
        model, optimizer, scheduler, criterion: forwarded to ``train_loop``.
        train_loader: batches used for training.
        val_loader: batches used for validation after every epoch.
        epochs: number of epochs to run.
        device: device batches are moved to.
        train_log: text file that per-epoch metrics are appended to.
        valid_loss_min: validation loss that must be beaten before
            ``best_path`` is written.
        save_path: directory for per-epoch checkpoints. ``None`` keeps the
            original behaviour of reading ``args.save_path`` from the
            module-level ``args`` that this script's ``__main__`` block creates.
        best_path: file the best-so-far checkpoint is written to.

    Returns:
        The lowest validation loss seen (``valid_loss_min`` if never beaten).

    Raises:
        ValueError: if ``save_path`` is None and no module-level ``args`` exists.
    """
    if save_path is None:
        try:
            save_path = args.save_path
        except NameError as err:
            raise ValueError(
                "save_path must be passed when train() is not called from "
                "this script's __main__ block") from err
    # Create every output directory once, up front. The parents of train_log
    # and best_path are not guaranteed to exist; without this, open() and
    # torch.save() fail on a fresh run.
    os.makedirs(save_path, exist_ok=True)
    for path in (train_log, best_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
    for e in range(epochs):
        train_metrics = train_loop(model, optimizer, scheduler, criterion,
                                   train_loader, device, e)
        val_metrics = val_loop(model, criterion, val_loader, device)
        info1 = "Epoch:[{}/{}] train_loss: {:.3f} valid_loss: {:.3f} ".format(e+1,epochs,train_metrics["loss"],val_metrics["loss"])
        info2 = "Train--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(train_metrics['dice1'],train_metrics['dice2'],train_metrics['dice3'])
        info3 = "Valid--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(val_metrics['dice1'],val_metrics['dice2'],val_metrics['dice3'])
        print(info1)
        print(info2)
        print(info3)
        with open(train_log, 'a') as f:
            f.write(info1 + '\n' + info2 + ' ' + info3 + '\n')
        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict()}
        if val_metrics['loss'] < valid_loss_min:
            valid_loss_min = val_metrics['loss']
            torch.save(save_file, best_path)
        else:
            torch.save(save_file, os.path.join(save_path, 'checkpoint{}.pth'.format(e + 1)))
    print("Finished Training!")
    return valid_loss_min
def _seed_everything(seed):
    """Seed the CPU and all CUDA RNGs and request deterministic cuDNN kernels.

    Args:
        seed: integer seed for ``torch.manual_seed`` / ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN auto-tune algorithms per input shape, which
    # can pick nondeterministic kernels and undo the determinism requested
    # above, so it stays off.
    torch.backends.cudnn.benchmark = False


def _make_loader(dataset, batch_size, shuffle, num_workers=12):
    """Wrap ``dataset`` in the pinned-memory DataLoader used for every split."""
    return DataLoader(dataset=dataset, batch_size=batch_size,
                      num_workers=num_workers, shuffle=shuffle,
                      pin_memory=True)


def main(args):
    """Build data, model, loss and optimizer from ``args``; train; then evaluate.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.num_classes`` differs from the number of
            hard-coded per-class loss weights (4).
    """
    # Must happen before anything initialises CUDA. setdefault keeps a GPU
    # selection the user already exported instead of silently overriding it.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    _seed_everything(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Per-class loss weights: one entry per output class, so --num_classes
    # must agree with it (checked before any data is loaded).
    class_weights = torch.tensor([0.2, 0.3, 0.25, 0.25])
    num_classes = getattr(args, 'num_classes', len(class_weights))
    if num_classes != len(class_weights):
        raise ValueError(
            "--num_classes={} does not match the {} hard-coded class weights".format(
                num_classes, len(class_weights)))

    # data info
    patch_size = (160, 160, 128)
    train_dataset = BraTS(args.data_path, args.train_txt, transform=transforms.Compose([
        RandomRotFlip(),
        RandomCrop(patch_size),
        GaussianNoise(p=0.1),
        ToTensor()
    ]))
    val_dataset = BraTS(args.data_path, args.valid_txt, transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))
    test_dataset = BraTS(args.data_path, args.test_txt, transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))
    train_loader = _make_loader(train_dataset, args.batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, args.batch_size, shuffle=False)
    test_loader = _make_loader(test_dataset, args.batch_size, shuffle=False)
    print("using {} device.".format(device))
    print("using {} images for training, {} images for validation.".format(len(train_dataset), len(val_dataset)))

    # Labels: 1 = necrotic tumour core (NT), 2 = peritumoral edema (ED),
    # 4 = enhancing tumour (ET).
    # Metrics: ET (label 4), TC (labels 1 + 4), WT (labels 1 + 2 + 4).
    model = UNet(in_channels=4, num_classes=num_classes).to(device)
    criterion = Loss(n_classes=num_classes, weight=class_weights).to(device)
    # lr=0 is only a placeholder: train_loop overwrites the learning rate on
    # every iteration from the schedule below.
    optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0, weight_decay=5e-4)
    scheduler = cosine_scheduler(base_value=args.lr, final_value=args.min_lr, epochs=args.epochs,
                                 niter_per_ep=len(train_loader), warmup_epochs=args.warmup_epochs,
                                 start_warmup_value=5e-4)

    # Resume from an existing checkpoint. The default --weights path is
    # 'results/UNet.pth', where train() saves the best model, so re-running
    # the script resumes from it.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    if os.path.exists(args.weights):
        weight_dict = torch.load(args.weights, map_location=device)
        model.load_state_dict(weight_dict['model'])
        optimizer.load_state_dict(weight_dict['optimizer'])
        print('Successfully loading checkpoint.')

    train(model, optimizer, scheduler, criterion, train_loader, val_loader, args.epochs, device,
          train_log=args.train_log)

    # Final evaluation uses the parameters at the END of training, not the
    # best checkpoint saved during training.
    metrics2 = val_loop(model, criterion, val_loader, device)
    metrics3 = val_loop(model, criterion, test_loader, device)
    print("Valid -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics2['loss'], metrics2['dice1'], metrics2['dice2'], metrics2['dice3']))
    print("Test -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics3['loss'], metrics3['dice1'], metrics3['dice2'], metrics3['dice3']))
if __name__ == '__main__':
    # Command-line options as (flag, type, default). Declaration order is
    # preserved, so --help lists them in this order.
    cli_options = (
        ('--num_classes', int, 4),
        ('--seed', int, 21),
        ('--epochs', int, 60),
        ('--warmup_epochs', int, 10),
        ('--batch_size', int, 1),
        ('--lr', float, 0.004),
        ('--min_lr', float, 0.002),
        ('--data_path', str, '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/dataset'),
        ('--train_txt', str, '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/train.txt'),
        ('--valid_txt', str, '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/valid.txt'),
        ('--test_txt', str, '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/test.txt'),
        ('--train_log', str, 'results/UNet.txt'),
        ('--weights', str, 'results/UNet.pth'),
        ('--save_path', str, 'checkpoint/UNet'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    # ``args`` is deliberately a module-level global: ``train`` may read
    # ``args.save_path`` from it.
    args = parser.parse_args()
    main(args)