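"""Evaluation script for PET (Point Query Transformer).

Builds the model and the validation dataloader, optionally loads a checkpoint
given via --resume, runs evaluation, and reports the MAE / MSE counting metrics.
"""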
import argparse
import random
from pathlib import Path
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler

import datasets
from datasets import build_dataset
import util.misc as utils
from engine import evaluate
from models import build_model


def get_args_parser():
    parser = argparse.ArgumentParser('Set Point Query Transformer', add_help=False)

    # model parameters
    # - backbone
    parser.add_argument('--backbone', default='vgg16_bn', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned', 'fourier'),
                        help="Type of positional embedding to use on top of the image features")

    # - transformer
    parser.add_argument('--dec_layers', default=2, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=512, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.0, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")

    # loss parameters
    # - matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_point', default=0.05, type=float,
                        help="SmoothL1 point coefficient in the matching cost")
    # - loss coefficients
    parser.add_argument('--ce_loss_coef', default=1.0, type=float)     # classification loss coefficient
    parser.add_argument('--point_loss_coef', default=5.0, type=float)  # regression loss coefficient
    parser.add_argument('--eos_coef', default=0.5, type=float,
                        help="Relative classification weight of the no-object class")  # cross-entropy weights

    # dataset parameters
    parser.add_argument('--dataset_file', default="SHA")
    parser.add_argument('--data_path', default="./data/ShanghaiTech/PartA", type=str)

    # misc parameters
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--vis_dir', default="")
    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def main(args):
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model, criterion = build_model(args)
    model.to(device)
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('params (M):', n_parameters / 1e6)

    # build validation dataset
    val_image_set = 'val'
    dataset_val = build_dataset(image_set=val_image_set, args=args)
    if args.distributed:
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = DataLoader(dataset_val, 1, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    # load pretrained model
    cur_epoch = 0  # default when no checkpoint is resumed
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        cur_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0

    # evaluation
    vis_dir = None if args.vis_dir == "" else args.vis_dir
    test_stats = evaluate(model, data_loader_val, device, vis_dir=vis_dir)
    mae, mse = test_stats['mae'], test_stats['mse']
    line = f'\nepoch: {cur_epoch}, mae: {mae}, mse: {mse}'
    print(line)
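

# Example invocation (the checkpoint and output paths below are illustrative; adjust them to your setup):
#   python eval.py --dataset_file SHA --data_path ./data/ShanghaiTech/PartA \
#       --resume ./outputs/checkpoint.pth --vis_dir ./vis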

if __name__ == '__main__':
    parser = argparse.ArgumentParser('PET evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)