
Fix model loading in pruning script
andreysher committed Aug 16, 2023
1 parent e213c3e commit cecec57
Showing 1 changed file with 1 addition and 18 deletions.
19 changes: 1 addition & 18 deletions prune.py
@@ -129,30 +129,13 @@ def measure_flops(model):

 train_loader = get_train_loader(cfg)
 resume_epoch = 0
-# resume now work as model ckpt
 if cfg.model_ckpt is not None:
     net = torch.load(cfg.model_ckpt, map_location="cpu")["model_ckpt"]
-if cfg.resume is not None:
-    dist_print("==> Resume model from " + cfg.resume)
-    resume_dict = torch.load(cfg.resume, map_location="cpu")
-    net = resume_dict["model"]
-    resume_epoch = 0
 else:
-    net = get_model(cfg)
+    raise ValueError("Pruning works only for model_ckpt")

 if distributed:
     net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank])
 optimizer = get_optimizer(net, cfg)

-if cfg.finetune is not None:
-    dist_print("finetune from ", cfg.finetune)
-    state_all = torch.load(cfg.finetune, map_location="cpu")["model"]
-    state_clip = {}  # only use backbone parameters
-    for k, v in state_all.items():
-        if "model" in k:
-            state_clip[k] = v
-    net.load_state_dict(state_clip, strict=False)

 scheduler = get_scheduler(optimizer, cfg, len(train_loader))
 dist_print(len(train_loader))
 metric_dict = get_metric_dict(cfg)
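
For readers following along, here is a minimal sketch of what the model-loading path in prune.py looks like after this change. The helper names (get_train_loader, dist_print, get_optimizer) and the cfg/args objects come from the diff above; the wrapper function name is mine, the checkpoint layout (a dict whose "model_ckpt" key holds the full pickled model) is inferred from the torch.load(...)["model_ckpt"] call, and the save-side example path is hypothetical, not code from this repository.

import torch

def load_model_for_pruning(cfg, args, distributed):
    # Post-commit behaviour: the resume/finetune fallbacks are gone, so
    # pruning refuses to run unless a full model checkpoint is supplied.
    if cfg.model_ckpt is None:
        raise ValueError("Pruning works only for model_ckpt")

    # The checkpoint is assumed to be a dict storing the whole pickled
    # nn.Module under "model_ckpt", so no get_model() rebuild or
    # load_state_dict() call is needed afterwards.
    net = torch.load(cfg.model_ckpt, map_location="cpu")["model_ckpt"]

    # Mirrors the DDP wrapping shown in the diff.
    if distributed:
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[args.local_rank]
        )
    return net

# A compatible checkpoint could be produced like this (hypothetical path):
# torch.save({"model_ckpt": net}, "checkpoints/model_before_pruning.pth")

Loading the serialized module directly, rather than rebuilding it with get_model(cfg) and restoring a state_dict, is presumably what lets the pruning script pick up an architecture that no longer matches the stock model definition.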
