diff --git a/tools/visualize_skel_movement.py b/tools/visualize_skel_movement.py index 6242f26..dbd85e6 100644 --- a/tools/visualize_skel_movement.py +++ b/tools/visualize_skel_movement.py @@ -13,6 +13,8 @@ argparser.add_argument('--dataset', type=str, required=True, help='Path to the dataset.') argparser.add_argument('--sample', type=int, required=True, help='Sample index to visualize.') argparser.add_argument('--cache-file', type=str, required=False, help='Path to the cache file.') +argparser.add_argument('--checkpoint', type=str, required=False, help='Path to the checkpoint file.', + default='/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') args = argparser.parse_args() @@ -30,7 +32,7 @@ # checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') # model_state_dict = checkpoint['model_state_dict'] # model.load_state_dict(model_state_dict) - model.from_pretrained('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') + model.from_pretrained(args.checkpoint) model = model.to('cpu') skeleton, label = dataset[args.sample] seq_len, n_bodies, n_joints, n_dims = skeleton.shape