inference.py
import argparse
import os

import torch

import data as Dataset
from config import Config
from util.distributed import init_dist
from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer
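
# Note: `data`, `config`, and `util` above are modules local to this
# repository; the script assumes it is run from the repo root so that
# they resolve on the import path.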


def parse_args():
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--config', default='./config/fashion_512.yaml',
                        help='Path to the experiment config file.')
    parser.add_argument('--name', default=None,
                        help='Name of the experiment.')
    parser.add_argument('--checkpoints_dir', default='result',
                        help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--which_iter', type=int, default=None,
                        help='Iteration of the checkpoint to load.')
    parser.add_argument('--no_resume', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--output_dir', type=str,
                        help='Root dir for the generated results.')
    args = parser.parse_args()
    return args
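

# A sketch of how this script might be launched (the experiment name and
# output path below are placeholder assumptions, not values from this repo):
#
#   Single GPU:
#       python inference.py --config ./config/fashion_512.yaml \
#           --name fashion_512 --output_dir ./eval_results --single_gpu
#
#   Multi-GPU, assuming the legacy torch.distributed.launch helper, which
#   passes --local_rank to each process for init_dist:
#       python -m torch.distributed.launch --nproc_per_node=2 inference.py \
#           --config ./config/fashion_512.yaml --name fashion_512 \
#           --output_dir ./eval_results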


if __name__ == '__main__':
    # Get inference options.
    args = parse_args()
    set_random_seed(args.seed)

    opt = Config(args.config, args, is_train=False)
    if not args.single_gpu:
        opt.local_rank = args.local_rank
        init_dist(opt.local_rank)
    opt.device = torch.cuda.current_device()
    opt.logdir = os.path.join(opt.checkpoints_dir, opt.name)

    # Create the dataloaders; only the test split is used below.
    test_dataset, train_dataset = Dataset.get_train_val_dataloader(opt.data)

    # Create the model, optimizers, and learning-rate schedulers, then
    # wrap them in a trainer and restore the requested checkpoint.
    net_G, net_D, net_G_ema, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(opt)
    trainer = get_trainer(opt, net_G, net_D, net_G_ema,
                          opt_G, opt_D, sch_G, sch_D,
                          test_dataset)
    current_epoch, current_iteration = trainer.load_checkpoint(
        opt, args.which_iter)

    if not args.single_gpu:
        # Keep the DistributedSampler's shuffling in sync with the
        # loaded checkpoint's epoch.
        test_dataset.sampler.set_epoch(current_epoch)

    # Write results to a per-checkpoint subdirectory, e.g.
    # <output_dir>/epoch_00005_iteration_000100000.
    output_dir = os.path.join(
        args.output_dir,
        'epoch_{:05}_iteration_{:09}'.format(current_epoch, current_iteration)
    )
    trainer.test(test_dataset, output_dir, current_iteration)