Skip to content

Commit

Permalink
add log_dir and weight_prefix option flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Kazuhiro Terao committed Aug 30, 2024
1 parent 0aea949 commit 03b6c0b
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions bin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from spine.main import run


def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir, weight_prefix):
def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir, weight_prefix, weight_path):
"""Main driver for training/validation/inference/analysis.
Performs these basic functions:
Expand All @@ -48,6 +48,12 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
Number of iterations to skip
detect_anomaly : bool
Whether to turn on anomaly detection in torch
log_dir : str
Path to the directory for storing the training log
weight_prefix : str
Path to the directory for storing the training weights
weight_path : str
Path string a weight file or pattern for multiple weight files to load the model weights
"""
# Try to find configuration file using the absolute path or under
# the 'config' directory of the parent SPINE repository
Expand Down Expand Up @@ -109,6 +115,9 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
raise KeyError('--weight_prefix flag provided: must specify `train` in the `base` block.')
cfg['base']['train']['weight_prefix']=weight_prefix

if weight_path is not None:
cfg['model']['weight_path']=weight_path

# Turn on PyTorch anomaly detection, if requested
if detect_anomaly is not None:
assert 'model' in cfg, (
Expand Down Expand Up @@ -165,8 +174,12 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
help='Prefix for weight files',
type=str, default=None)

parser.add_argument('--weight_path',
help='Path to a weight file (or pattern to multiple weight files)',
type=str, default=None)

args = parser.parse_args()

# Execute the main function
main(args.config, args.source, args.source_list, args.output, args.n,
args.nskip, args.detect_anomaly, args.log_dir, args.weight_prefix)
args.nskip, args.detect_anomaly, args.log_dir, args.weight_prefix, args.weight_path)

0 comments on commit 03b6c0b

Please sign in to comment.