-
Notifications
You must be signed in to change notification settings - Fork 5
/
validation.py
44 lines (34 loc) · 1.47 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from core.utils import AverageMeter, process_data_item, run_model, calculate_accuracy
import os
import time
import torch
def val_epoch(epoch, data_loader, model, criterion, opt, writer, optimizer):
    """Run one validation pass over ``data_loader`` and save a checkpoint.

    The model is switched to eval mode, every batch is pushed through
    ``run_model`` with gradient tracking disabled, and the running loss and
    accuracy are logged to ``writer`` under ``val/loss`` and ``val/acc`` and
    printed. Afterwards the model and optimizer state dicts are written to
    ``<opt.ckpt_path>/save_<epoch>.pt``.

    Args:
        epoch: Index of the current epoch; used as the logging step and in
            the checkpoint file name. The checkpoint stores ``epoch + 1``.
        data_loader: Iterable yielding items accepted by
            ``process_data_item``.
        model: Module to evaluate; left in eval mode on return.
        criterion: Loss callable forwarded to ``run_model``.
        opt: Options object; ``opt.ckpt_path`` is read here, and the whole
            object is forwarded to ``process_data_item`` / ``run_model``.
        writer: Object exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm at the
            call site).
        optimizer: Optimizer whose ``state_dict()`` is stored in the
            checkpoint.

    Returns:
        None.
    """
    print("# ---------------------------------------------------------------------- #")
    print('Validation at epoch {}'.format(epoch))
    model.eval()

    # NOTE(review): batch_time / data_time are updated but never reported in
    # this function; they are kept so the AverageMeter call pattern is
    # unchanged.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, data_item in enumerate(data_loader):
        # The fourth element (visualization item) is not needed for validation.
        visual, target, audio, _, batch_size = process_data_item(opt, data_item)
        data_time.update(time.time() - end_time)

        # No gradients are needed for validation.
        with torch.no_grad():
            output, loss = run_model(opt, [visual, target, audio], model, criterion, i)

        acc = calculate_accuracy(output, target)

        # batch_size is passed as the sample count; presumably AverageMeter
        # keeps a count-weighted running mean -- confirm in core.utils.
        losses.update(loss.item(), batch_size)
        accuracies.update(acc, batch_size)

        batch_time.update(time.time() - end_time)
        end_time = time.time()

    writer.add_scalar('val/loss', losses.avg, epoch)
    writer.add_scalar('val/acc', accuracies.avg, epoch)
    print("Val loss: {:.4f}".format(losses.avg))
    print("Val acc: {:.4f}".format(accuracies.avg))

    save_file_path = os.path.join(opt.ckpt_path, 'save_{}.pt'.format(epoch))

    # Create the checkpoint directory if it is missing, so the results of a
    # full validation pass are not lost to a save error. An empty directory
    # component means "current directory" and needs no creation.
    ckpt_dir = os.path.dirname(save_file_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    states = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(states, save_file_path)