Skip to content

Commit

Permalink
Merge branch 'DeepLearnPhysics:develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-drielsma authored Sep 12, 2024
2 parents 7841df6 + 2e147c6 commit b002d8b
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions bin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from spine.main import run


def main(config, source, source_list, output, n, nskip, detect_anomaly):
def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir, weight_prefix, weight_path):
"""Main driver for training/validation/inference/analysis.
Performs these basic functions:
Expand All @@ -48,6 +48,12 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly):
Number of iterations to skip
detect_anomaly : bool
Whether to turn on anomaly detection in torch
log_dir : str
Path to the directory for storing the training log
weight_prefix : str
Path to the directory for storing the training weights
weight_path : str
Path string a weight file or pattern for multiple weight files to load the model weights
"""
# Try to find configuration file using the absolute path or under
# the 'config' directory of the parent SPINE repository
Expand Down Expand Up @@ -101,6 +107,17 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly):
if output is not None and 'writer' in cfg['io']:
cfg['io']['writer']['file_name'] = output

if log_dir is not None:
cfg['base']['log_dir'] = log_dir

if weight_prefix is not None:
if not 'train' in cfg['base']:
raise KeyError('--weight_prefix flag provided: must specify `train` in the `base` block.')
cfg['base']['train']['weight_prefix']=weight_prefix

if weight_path is not None:
cfg['model']['weight_path']=weight_path

# Turn on PyTorch anomaly detection, if requested
if detect_anomaly is not None:
assert 'model' in cfg, (
Expand Down Expand Up @@ -149,8 +166,20 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly):
help='Turns on autograd.detect_anomaly for debugging',
action='store_const', const=True)

parser.add_argument('--log_dir',
help='Log directory',
type=str, default=None)

parser.add_argument('--weight_prefix',
help='Prefix for weight files',
type=str, default=None)

parser.add_argument('--weight_path',
help='Path to a weight file (or pattern to multiple weight files)',
type=str, default=None)

args = parser.parse_args()

# Execute the main function
main(args.config, args.source, args.source_list, args.output, args.n,
args.nskip, args.detect_anomaly)
args.nskip, args.detect_anomaly, args.log_dir, args.weight_prefix, args.weight_path)

0 comments on commit b002d8b

Please sign in to comment.