utils.py
import os
import random
import shutil
import warnings

import numpy as np

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: lets figures be saved without a display
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer
from torchvision import utils
from PIL import Image

import flow_transforms
from model import *
from nyu_dataset_loader import *
from weights import load_weights
# Pre-trained NYU ResNet-UpProjection weights, loaded via weights.load_weights.
weights_file = "NYU_ResNet-UpProj.npy"
# Per-class RGB palette used to colour semantic segmentation label maps.
color = np.array([
    (0, 0, 0),        # black
    (0, 0, 255),      # blue
    (255, 0, 0),      # red
    (0, 255, 0),      # lime
    (255, 255, 0),    # yellow
    (255, 0, 255),    # magenta
    (192, 192, 192),  # silver
    (128, 128, 128),  # gray
    (128, 0, 0),      # maroon
    (128, 128, 0),    # olive
    (0, 128, 0),      # green
    (128, 0, 128),    # purple
    (0, 128, 128),    # teal
    (65, 105, 225),   # royal blue
    (255, 250, 205),  # lemon chiffon
    (255, 20, 147),   # deep pink
    (218, 112, 214),  # orchid
    (135, 206, 250),  # light sky blue
    (127, 255, 212),  # aquamarine
    (0, 255, 127),    # spring green
    (255, 215, 0),    # gold
    (165, 42, 42),    # brown
    (148, 0, 211),    # dark violet
    (210, 105, 30),   # chocolate
    (244, 164, 96),   # sandy brown
    (240, 255, 240),  # honeydew
    (112, 128, 144),  # slate gray
    (64, 224, 208),   # turquoise
    (100, 149, 237),  # cornflower blue
    (30, 144, 255),   # dodger blue
    (221, 160, 221),  # plum
    (205, 133, 63),   # peru
    (255, 240, 245),  # lavender blush
    (255, 255, 240),  # ivory
    (255, 165, 0),    # orange
    (255, 160, 122),  # light salmon
    (205, 92, 92),    # indian red
    (240, 248, 255),  # alice blue
])
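
# Example: colouring a label map is a single fancy-indexing operation, since
# the integer class map indexes the palette along axis 0. A minimal sketch
# (the label array here is made up purely for illustration):
#
#   labels = np.random.randint(0, len(color), size=(480, 640))
#   rgb = color[labels].astype(np.uint8)   # shape (480, 640, 3)
#   plt.imsave('palette_demo.png', rgb)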

def run_epoch(model, loss_fn, loader, optimizer, dtype):
    """
    Train the model for one epoch and return the average loss.
    """
    # Set the model to training mode (enables dropout / batch-norm updates).
    model.train()
    log_softmax = nn.LogSoftmax(dim=1)
    running_loss = 0
    count = 0
    for x, z, y in loader:
        x_var = Variable(x.type(dtype))         # RGB input
        z_var = Variable(z.type(dtype))         # depth input
        y_var = Variable(y.type(dtype).long())  # semantic labels
        pred_depth, pred_labels = model(x_var, z_var)
        y_var = y_var.squeeze()
        loss = loss_fn(log_softmax(pred_labels), y_var)
        running_loss += loss.data.cpu().numpy()
        count += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return running_loss / count
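
# Note: run_epoch applies LogSoftmax to the logits itself, so the loss_fn
# argument is expected to consume log-probabilities. A minimal sketch of a
# compatible criterion (the actual choice is made by the calling script):
#
#   loss_fn = nn.NLLLoss()  # LogSoftmax + NLLLoss == CrossEntropyLoss on raw logits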

def check_accuracy(model, loader, epoch, dtype, visualize=False):
    """
    Compute pixel-wise classification accuracy of the model over the loader.
    """
    # Set the model to evaluation mode (disables dropout, uses running
    # batch-norm statistics).
    model.eval()
    num_correct, num_samples = 0, 0
    for x, z, y in loader:
        x_var = Variable(x.type(dtype), volatile=True)
        z_var = Variable(z.type(dtype), volatile=True)
        pred_depth, pred_labels = model(x_var, z_var)
        _, preds = pred_labels.data.cpu().max(1)
        if visualize:
            # Save the input RGB image, ground-truth depth map, coloured
            # ground-truth segmentation map, coloured predicted segmentation
            # map and predicted depth map for one image in the current batch.
            input_rgb_image = x_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            plt.imsave('input_rgb_epoch_{}.png'.format(epoch), input_rgb_image)

            input_gt_depth_image = z_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            plt.imsave('input_gt_depth_epoch_{}.png'.format(epoch), input_gt_depth_image)

            colored_gt_label = color[y[0].squeeze().cpu().numpy().astype(int)].astype(np.uint8)
            plt.imsave('gt_label_epoch_{}.png'.format(epoch), colored_gt_label)

            colored_pred_label = color[preds[0].squeeze().cpu().numpy().astype(int)].astype(np.uint8)
            plt.imsave('pred_label_epoch_{}.png'.format(epoch), colored_pred_label)

            pred_depth_image = pred_depth[0].data.squeeze().cpu().numpy().astype(np.uint8)
            plt.imsave('pred_depth_epoch_{}.png'.format(epoch), pred_depth_image, cmap='gray')

        # Pixel-wise accuracy: count matching class indices over all pixels.
        num_correct += (preds.long() == y.long()).sum()
        num_samples += preds.numel()
    acc = float(num_correct) / num_samples
    return acc

def plot_performance_curves(loss_history, train_acc_history, val_acc_history,
                            epoch_history, train_on, batch_size, num_epochs,
                            resumed_file):
    """
    Plot and save the loss and accuracy curves over epochs.
    """
    # Loss curve
    plt.figure()
    plt.plot(np.array(epoch_history), np.array(loss_history))
    plt.ylabel('Loss')
    plt.xlabel('Number of Epochs')
    plt.title('Loss history for training model on {} examples with batch size of {}'.format(train_on, batch_size))
    if not resumed_file:
        plt.savefig('loss_plot_train_on_{}_batch_size_{}.png'.format(train_on, batch_size))
    else:
        plt.savefig('loss_plot_train_on_{}_batch_size_{}_resumed.png'.format(train_on, batch_size))

    # Accuracy curves (training vs. validation)
    plt.figure()
    plt.plot(np.array(epoch_history), np.array(train_acc_history), label='Training accuracy')
    plt.plot(np.array(epoch_history), np.array(val_acc_history), label='Validation accuracy')
    plt.title('Accuracy history for training model on {} examples with batch size of {}'.format(train_on, batch_size))
    plt.ylabel('Accuracy')
    plt.xlabel('Number of Epochs')
    plt.legend()
    if not resumed_file:
        plt.savefig('acc_plots_train_on_{}_batch_size_{}.png'.format(train_on, batch_size))
    else:
        plt.savefig('acc_plots_train_on_{}_batch_size_{}_resumed.png'.format(train_on, batch_size))

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Save a training checkpoint and keep a copy of the best model so far.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
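
# A minimal sketch of how these utilities fit together in a training loop.
# MyModel, train_loader, val_loader and num_epochs are hypothetical
# stand-ins; the real entry point lives elsewhere in this repository.
#
#   dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#   model = MyModel().type(dtype)
#   loss_fn = nn.NLLLoss()
#   optimizer = optim.Adam(model.parameters(), lr=1e-4)
#   best_acc = 0.0
#   for epoch in range(num_epochs):
#       loss = run_epoch(model, loss_fn, train_loader, optimizer, dtype)
#       val_acc = check_accuracy(model, val_loader, epoch, dtype, visualize=True)
#       is_best = val_acc > best_acc
#       best_acc = max(val_acc, best_acc)
#       save_checkpoint({'epoch': epoch,
#                        'state_dict': model.state_dict(),
#                        'best_acc': best_acc}, is_best)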