-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
40 lines (31 loc) · 1.09 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import warnings
from pathlib import Path
from omegaconf import OmegaConf
from dask.distributed import Client
# filter some warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*Detected KeyboardInterrupt, attempting graceful shutdown....*")
def main():
    """Build the run configuration from the command line and launch the runs.

    Command-line arguments are parsed as OmegaConf overrides (``key=value``),
    handed to ``get_config``, and then resolved in place. The configured log
    directory is created if it does not exist and its path is printed.

    Returns:
        Whatever ``run_all(config, manager=manager)`` returns.
    """
    # Kept as function-local imports, as in the original script.
    import sys

    from uq.configs.config import get_config
    from uq import utils
    from uq.runner import run_all

    # Skip argv[0]: it is the script path, not an override. Passing the full
    # argv would merge the script name into the config as a key set to None.
    config = OmegaConf.from_cli(sys.argv[1:])
    config = get_config(config)
    # Always overridden here, whatever was given on the command line.
    # NOTE(review): presumably tells the runner to discard earlier results;
    # confirm against uq.runner.
    config.clean_previous = True
    OmegaConf.resolve(config)

    Path(config.log_dir).mkdir(parents=True, exist_ok=True)
    print(config.log_dir)

    # Optionally pretty-print the config (uses the Rich library).
    if config.get("print_config"):
        utils.print_config(config, resolve=True)

    # Choose the parallelization backend: a single worker runs sequentially,
    # anything else goes through joblib.
    manager = 'joblib'
    if config.nb_workers == 1:
        manager = 'sequential'
    # Not reachable with the default above; this only takes effect if the
    # default manager is edited to 'dask'.
    if manager == 'dask':
        Client(n_workers=config.nb_workers, threads_per_worker=1, memory_limit=None)

    # Train model
    return run_all(config, manager=manager)
# Script entry point: start a run only when executed directly, not on import.
if __name__ == "__main__":
    main()