-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
84 lines (69 loc) · 2.5 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os
import torch
from dataset import create_dataset
from metrics import compute_model_metrics
from models import create_model
from utils.config import read_config_from_file
from utils.plot import display_confusion_matrix
from utils.torch import test_model
# Point torch's cache directory (TORCH_HOME) at a project-local folder.
# The path is relative, so it resolves against the current working directory.
os.environ.update({"TORCH_HOME": "./.cache"})
def parse_arguments(argv=None):
    """Parse the evaluation command line and load the experiment config.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            command-line behavior; passing a list lets callers drive the
            script programmatically.

    Returns:
        A ``(args, experiment_config)`` tuple: the parsed
        ``argparse.Namespace`` and whatever ``read_config_from_file``
        returns for the ``--experiment_cfg`` path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_folder",
        type=str,
        required=True,
        help="folder containing the dataset",
    )
    parser.add_argument(
        "--val_samples_file",
        type=str,
        required=True,
        help="samples file for the validation split",
    )
    parser.add_argument(
        "--test_samples_file",
        type=str,
        required=True,
        help="samples file for the test split",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="path to the trained model weights to evaluate",
    )
    parser.add_argument(
        "--experiment_cfg",
        type=str,
        required=True,
        help="path to the experiment config file",
    )
    # NOTE(review): this option is not read by this evaluation script as far
    # as the visible code shows. It is kept so existing command lines that
    # pass it still parse.
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="path to save model checkpoints",
        default="checkpoints/",
    )
    # parse_args(None) falls back to sys.argv[1:], matching the old behavior.
    args = parser.parse_args(argv)
    experiment_config = read_config_from_file(args.experiment_cfg)
    return args, experiment_config
def _build_eval_dataset(samples_file, dataset_folder, data_kwargs, split):
    """Create the dataset for one evaluation split, without upsampling.

    Centralizing the call guarantees the validation and test splits share
    the same ``batch_size``, ``many_to_one_setting`` and ``image_size``
    taken from ``data_kwargs``.
    """
    return create_dataset(
        samples_file,
        dataset_folder,
        data_kwargs.batch_size,
        data_kwargs.many_to_one_setting,
        data_kwargs.image_size,
        upsample=False,
        split=split,
    )


def main():
    """Evaluate a trained checkpoint on the validation and test splits.

    Reads the CLI arguments and experiment config, builds both datasets,
    loads the model weights from ``--checkpoint_path``, and calls
    ``test_model`` once per split. Each run is named
    ``<split>_<encoder name>_<temporal name>``.
    """
    args, experiment_config = parse_arguments()
    data_kwargs = experiment_config.data_kwargs

    # Both datasets are created before the model is built, preserving the
    # original script's order of operations.
    val_dataset = _build_eval_dataset(
        args.val_samples_file, args.dataset_folder, data_kwargs, "val"
    )
    test_dataset = _build_eval_dataset(
        args.test_samples_file, args.dataset_folder, data_kwargs, "test"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(experiment_config).to(device)
    # NOTE(review): torch.load unpickles the file, so only point
    # --checkpoint_path at trusted checkpoints. Only a state_dict is consumed
    # here, so weights_only=True (PyTorch >= 1.13) would be a safer choice
    # once the minimum supported torch version is confirmed.
    pretrained_weights = torch.load(args.checkpoint_path, map_location=device)
    model.load_state_dict(pretrained_weights)
    print('loaded weights')

    # NOTE(review): model.eval() is not called here; presumably test_model
    # switches the model to eval mode. Confirm.
    model_kwargs = experiment_config.model_kwargs
    run_suffix = f"{model_kwargs.encoder.name}_{model_kwargs.temporal.name}"
    for prefix, dataset in (("validation", val_dataset), ("test", test_dataset)):
        test_model(
            experiment_config,
            dataset,
            model,
            device,
            f"{prefix}_{run_suffix}",
            compute_model_metrics,
            display_confusion_matrix,
        )


if __name__ == "__main__":
    main()