# eval_ofa_net.py
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import torch
import argparse
from ofa.imagenet_classification.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_classification.run_manager import ImagenetRunConfig, RunManager
from ofa.model_zoo import ofa_net
# Command-line options of the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, default='/dataset/imagenet',
                    help='The path of imagenet')
# 'all' (the default) or a comma-separated list of GPU ids.
parser.add_argument('-g', '--gpu', type=str, default='all',
                    help='The gpu(s) to use')
# Per-device value; it is multiplied by the number of selected devices later.
parser.add_argument('-b', '--batch-size', type=int, default=100,
                    help='The batch on every device for validation')
parser.add_argument('-j', '--workers', type=int, default=20,
                    help='Number of workers')
parser.add_argument(
    '-n', '--net',
    metavar='OFANET',
    choices=[
        'ofa_mbv3_d234_e346_k357_w1.0',
        'ofa_mbv3_d234_e346_k357_w1.2',
        'ofa_proxyless_d234_e346_k357_w1.3',
        'ofa_resnet50',
    ],
    default='ofa_resnet50',
    help='OFA networks',
)

args = parser.parse_args()
# Resolve which GPUs to evaluate on and scale the batch size accordingly.
if args.gpu == 'all':
    # torch.cuda.device_count() already honours any CUDA_VISIBLE_DEVICES the
    # caller exported, and the indices it yields are relative to that visible
    # set. Writing them back into CUDA_VISIBLE_DEVICES would reinterpret them
    # as physical ids (e.g. an exported "2,3" would become "0,1"), so the
    # environment is deliberately left untouched in this branch.
    device_list = list(range(torch.cuda.device_count()))
else:
    # Skip blank entries so values such as "0, 1," are accepted; int() still
    # raises ValueError for anything that is not an integer id. An empty
    # value selects no GPU at all.
    device_list = [int(tok) for tok in args.gpu.split(',') if tok.strip()]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(idx) for idx in device_list)

# Keep args.gpu as the normalised comma-separated id list.
args.gpu = ','.join(str(idx) for idx in device_list)
# --batch-size is a per-device value: multiply by the device count, but never
# by zero when no GPU is selected.
args.batch_size = args.batch_size * max(len(device_list), 1)
# Point the ImageNet data provider at the directory given on the command line.
ImagenetDataProvider.DEFAULT_PATH = args.path

# Build the requested OFA network and the run configuration used for testing.
ofa_network = ofa_net(args.net, pretrained=True)
run_config = ImagenetRunConfig(
    test_batch_size=args.batch_size,
    n_worker=args.workers,
)

# Randomly sample a sub-network.
# The sub-network can also be set manually, e.g.:
#     ofa_network.set_active_subnet(ks=7, e=6, d=4)
ofa_network.sample_active_subnet()
subnet = ofa_network.get_active_subnet(preserve_weight=True)

# Test the sampled subnet.
run_manager = RunManager('.tmp/eval_subnet', subnet, run_config, init=False)
# Assign image size: 128, 132, ..., 224
run_config.data_provider.assign_active_img_size(224)
# NOTE(review): presumably recalibrates the subnet's normalisation statistics
# before validation -- confirm against RunManager.
run_manager.reset_running_statistics(net=subnet)

print('Test random subnet:')
print(subnet.module_str)

# validate() yields the loss plus a (top1, top5) pair.
loss, accuracies = run_manager.validate(net=subnet)
top1, top5 = accuracies
report = 'Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (loss, top1, top5)
print(report)