"""Train module.
This module contains the training code for the segmentation task.
(c) 2023 Bhimraj Yadav. All rights reserved.
"""
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm

from src.config import (
    CHANNELS,
    DEVICE,
    LEARNING_RATE,
    MODEL_PATH,
    NUM_EPOCHS,
    OUT_CHANNELS,
)
from src.dataset import train_dataloader, val_dataloader
from src.model import UNet

torchvision.disable_beta_transforms_warning()

model = UNet(channels=CHANNELS, out_channels=OUT_CHANNELS)
model.to(DEVICE)

def unet_loss(outputs, targets, alpha=0.5, beta=1.5):
    """U-Net loss: class-weighted binary cross-entropy plus a soft Dice term.

    Per-pixel weights (``alpha`` on foreground, ``beta`` on background)
    balance the classes; the Dice term penalizes fragmented segmentations.
    """
    weights = alpha * targets + beta * (1 - targets)
    bce = F.binary_cross_entropy_with_logits(
        outputs, targets, weights, reduction="mean"
    )
    # The Dice term operates on probabilities in [0, 1], so squash the raw
    # logits with a sigmoid before computing the overlap.
    probs = torch.sigmoid(outputs)
    intersection = torch.sum(probs * targets * weights)
    union = torch.sum(probs * weights) + torch.sum(targets * weights)
    dice = 1 - 2 * (intersection + 1) / (union + 1)
    return bce + dice
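
# A quick illustrative check of unet_loss (the tensor shapes here are
# assumptions, not taken from the real dataset); logits and binary masks
# must share a shape:
#
#   logits = torch.randn(4, 1, 128, 128)                    # raw model outputs
#   masks = torch.randint(0, 2, (4, 1, 128, 128)).float()   # binary targets
#   unet_loss(logits, masks)                                # scalar tensor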

def accuracy(outputs, targets):
    """Pixel accuracy: thresholds sigmoid probabilities at 0.5 and returns
    the fraction of pixels that match the targets.
    """
    outputs = torch.sigmoid(outputs)
    outputs = (outputs > 0.5).float()
    return torch.mean((outputs == targets).float())
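
# Illustrative sanity check for accuracy (shapes are assumptions): confidently
# correct logits should score 1.0:
#
#   masks = torch.randint(0, 2, (4, 1, 128, 128)).float()
#   perfect_logits = (masks * 2 - 1) * 10    # large positive/negative logits
#   accuracy(perfect_logits, masks)          # tensor(1.)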

def train_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> Tuple[float, float]:
    """Trains a PyTorch model for a single epoch.

    Turns a target PyTorch model to training mode and then runs through
    all of the required training steps (forward pass, loss calculation,
    optimizer step). The target device is taken from ``DEVICE`` in
    ``src.config``.

    Args:
        model: A PyTorch model to be trained.
        dataloader: A DataLoader instance for the model to be trained on.
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.

    Returns:
        A tuple of training loss and training accuracy metrics,
        in the form (train_loss, train_accuracy). For example:
        (0.1112, 0.8743)
    """
    # Set model to training mode
    model.train()

    # Set up train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader batches
    for X, y in dataloader:
        # Send data to target device
        X, y = X.to(DEVICE), y.to(DEVICE)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        # Calculate and accumulate accuracy metric across all batches
        train_acc += accuracy(y_pred, y).item()

    # Adjust metrics to get average loss and accuracy per batch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc
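
# Illustrative single-epoch call (the Adam learning rate here is an example,
# not the project's LEARNING_RATE):
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   epoch_loss, epoch_acc = train_step(
#       model=model,
#       dataloader=train_dataloader,
#       loss_fn=torch.nn.BCEWithLogitsLoss(),
#       optimizer=optimizer,
#   )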

def val_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
) -> Tuple[float, float]:
    """Validates a PyTorch model for a single epoch.

    Turns a target PyTorch model to "eval" mode and then performs a
    forward pass on the validation dataset. The target device is taken
    from ``DEVICE`` in ``src.config``.

    Args:
        model: A PyTorch model to be validated.
        dataloader: A DataLoader instance for the model to be validated on.
        loss_fn: A PyTorch loss function to calculate loss on the validation data.

    Returns:
        A tuple of validation loss and validation accuracy metrics,
        in the form (val_loss, val_accuracy). For example:
        (0.0223, 0.8985)
    """
    # Set model to eval mode
    model.eval()

    # Set up validation loss and accuracy values
    val_loss, val_acc = 0, 0

    # Turn on inference context manager (no gradients, less memory)
    with torch.inference_mode():
        # Loop through DataLoader batches
        for X, y in dataloader:
            # Send data to target device
            X, y = X.to(DEVICE), y.to(DEVICE)

            # 1. Forward pass
            val_pred_logits = model(X)

            # 2. Calculate and accumulate loss
            loss = loss_fn(val_pred_logits, y)
            val_loss += loss.item()

            # Calculate and accumulate accuracy
            val_acc += accuracy(val_pred_logits, y).item()

    # Adjust metrics to get average loss and accuracy per batch
    val_loss = val_loss / len(dataloader)
    val_acc = val_acc / len(dataloader)
    return val_loss, val_acc
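
# Illustrative validation call; gradients are disabled inside val_step via
# torch.inference_mode(), so no optimizer is required:
#
#   val_loss, val_acc = val_step(
#       model=model, dataloader=val_dataloader, loss_fn=torch.nn.BCEWithLogitsLoss()
#   )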

def train(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    val_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: torch.nn.Module,
    epochs: int,
) -> Dict[str, List]:
    """Trains and validates a PyTorch model.

    Passes a target PyTorch model through train_step() and val_step()
    for a number of epochs, training and validating the model in the
    same epoch loop. Calculates, prints and stores evaluation metrics
    throughout, and saves the final model weights to ``MODEL_PATH``.

    Args:
        model: A PyTorch model to be trained and validated.
        train_dataloader: A DataLoader instance for the model to be trained on.
        val_dataloader: A DataLoader instance for the model to be validated on.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        loss_fn: A PyTorch loss function to calculate loss on both datasets.
        epochs: An integer indicating how many epochs to train for.

    Returns:
        A dictionary of training and validation loss as well as training
        and validation accuracy metrics. Each metric has a value in a
        list for each epoch, in the form:
        {"train_loss": [...],
         "train_acc": [...],
         "val_loss": [...],
         "val_acc": [...]}
        For example, if training for epochs=2:
        {"train_loss": [2.0616, 1.0537],
         "train_acc": [0.3945, 0.3945],
         "val_loss": [1.2641, 1.5706],
         "val_acc": [0.3400, 0.2973]}
    """
    # Create empty results dictionary
    results = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    # Loop through training and validation steps for a number of epochs
    for epoch in tqdm(range(epochs), desc="Epochs"):
        train_loss, train_acc = train_step(
            model=model,
            dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
        )
        val_loss, val_acc = val_step(
            model=model, dataloader=val_dataloader, loss_fn=loss_fn
        )

        # Print out what's happening
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"val_loss: {val_loss:.4f} | "
            f"val_acc: {val_acc:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)

    # Save the trained model and return the filled results
    torch.save(model.state_dict(), MODEL_PATH)
    print(f"Saved model to {MODEL_PATH}")
    return results
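
# Illustrative use of the returned history (keys match the results dictionary
# built above), e.g. to report the epoch with the lowest validation loss:
#
#   results = train(model, train_dataloader, val_dataloader, optimizer, loss_fn, epochs=5)
#   best = min(range(len(results["val_loss"])), key=results["val_loss"].__getitem__)
#   print(f"Best epoch: {best + 1}")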

if __name__ == "__main__":
    # Note: this script trains with plain BCEWithLogitsLoss; the custom
    # unet_loss defined above could be swapped in as an alternative criterion.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=NUM_EPOCHS,
    )