-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
67 lines (58 loc) · 2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Entry point: main.py
This code is forked from https://github.com/hugochan/IDGL, which provides a flexible way
to configure hyper-parameters and evaluate model performance. Many thanks to its authors.
"""
import argparse
import yaml
import numpy as np
from model import ClfHandler
from utils.func import args_grid, print_config
def main(handler, config):
    """Build one model from ``config`` and run it a single time.

    Args:
        handler: Model-handler class (e.g. ``ClfHandler``). It is instantiated
            with ``config`` and must provide ``exec()`` and ``exec_test()``.
        config: Parsed configuration dict. When its ``'test'`` entry is truthy
            the handler's ``exec_test()`` is run, otherwise ``exec()``.
            A missing ``'test'`` key is treated as False.

    Returns:
        Whatever the chosen ``exec()`` / ``exec_test()`` call returned; the
        value is also printed.
    """
    model = handler(config)
    # Test mode is opt-in: a config without a 'test' key takes the normal
    # path instead of failing with KeyError.
    if config.get('test', False):
        metrics = model.exec_test()
    else:
        metrics = model.exec()
    print('[INFO] Metrics:', metrics)
    return metrics
def multi_run_main(handler, config):
    """Run the handler once for every configuration produced by ``args_grid``.

    Every top-level key of ``config`` whose value is a list is treated as a
    swept hyper-parameter. ``args_grid`` (from ``utils.func``) is presumably
    expected to expand ``config`` into one dict per combination; confirm
    against its implementation. For each run, a ``-<key>_<value>`` suffix
    per swept key is appended to ``'save_path'``, presumably so that runs
    write to distinct locations.

    Args:
        handler: Model-handler class. It is instantiated with each run config
            and must provide ``exec()`` and ``exec_test()``.
        config: Parsed configuration dict. Each run config must contain
            ``'save_path'``, plus ``'test_save_path'`` for runs in test mode.

    Returns:
        List of the metrics returned by each run, in ``args_grid`` order.
    """
    # Collect the swept keys before expansion so every run's save path can
    # be tagged with the values chosen for that run.
    hyperparams = [k for k, v in config.items() if isinstance(v, list)]
    all_metrics = []
    for cnf in args_grid(config):
        print('\n')
        for k in hyperparams:
            cnf['save_path'] += '-{}_{}'.format(k, cnf[k])
        model = handler(cnf)
        # Test mode is opt-in, matching main(): a missing 'test' key trains.
        if cnf.get('test', False):
            print(cnf['test_save_path'])
            metrics = model.exec_test()
        else:
            print(cnf['save_path'])
            metrics = model.exec()
        print('[INFO] Metrics:', metrics)
        all_metrics.append(metrics)
    return all_metrics
def get_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Useful for programmatic calls and tests.

    Returns:
        Dict with keys ``'config'`` (config-file path), ``'handler'``
        (handler name) and ``'multi_run'`` (bool).
    """
    parser = argparse.ArgumentParser()
    # These options are not marked required=True: argparse never applies the
    # default of a required option, so the defaults below would be dead code.
    parser.add_argument('--config', '-f', type=str, default='config/cfg_clf_mix.yml', help='path to the config file')
    parser.add_argument('--handler', '-d', type=str, default='clf', help='model handler (clf or others)')
    parser.add_argument('--multi_run', action='store_true', help='if execute multi-runs')
    return vars(parser.parse_args(argv))
def get_config(config_path="config/config.yml"):
    """Load a YAML configuration file and return the parsed document.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The object produced by ``yaml.load``. The rest of this script
        indexes it like a dict of settings.
    """
    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so configs with non-ASCII values load identically on every OS.
    # NOTE(review): FullLoader accepts more than plain YAML data; only use it
    # on trusted files, or switch to yaml.safe_load if configs use plain types.
    with open(config_path, "r", encoding="utf-8") as setting:
        config = yaml.load(setting, Loader=yaml.FullLoader)
    return config
if __name__ == '__main__':
    cfg = get_args()
    # Map CLI handler names to handler classes. An unknown name is rejected
    # up front with a clear message; passing None on would otherwise fail
    # later with an opaque "'NoneType' object is not callable" TypeError.
    handlers = {'clf': ClfHandler}
    handler = handlers.get(cfg['handler'])
    if handler is None:
        raise SystemExit('Unknown handler {!r}; expected one of: {}'.format(
            cfg['handler'], ', '.join(sorted(handlers))))
    config = get_config(cfg['config'])
    print_config(config)
    if cfg['multi_run']:
        multi_run_main(handler, config)
    else:
        main(handler, config)