From 51e2aeae66ea8b83962079b2b664be7687ade87b Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 22 Aug 2024 17:37:45 +0000
Subject: [PATCH] fix multi gpu hivemind ckpt resume

---
 open_diloco/train_fsdp.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index ab4efe2..3c2b031 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -272,6 +272,13 @@ def scheduler_fn(opt):
     )
     del fake_optimizer
 
+    if config.ckpt.resume:
+        base_path = config.ckpt.get_resume_path()
+        if config.hv is not None:
+            ckpt_path = os.path.join(base_path, get_diloco_rank_dir_name(config.hv.world_rank))
+        else:
+            ckpt_path = base_path
+
     if world_messenger_hv:
         diloco_args = dict(
             dht=dht,
@@ -305,9 +312,7 @@ def scheduler_fn(opt):
 
         if config.ckpt.resume:
             last_loss = load_checkpoint(
-                checkpoint_path=os.path.join(
-                    config.ckpt.get_resume_path(), get_diloco_rank_dir_name(config.hv.world_rank)
-                ),
+                checkpoint_path=ckpt_path,
                 model=model,
                 optimizer=optimizer.inner_optimizer,
                 scheduler=scheduler,
@@ -324,7 +329,7 @@ def scheduler_fn(opt):
         scheduler = scheduler_fn(optimizer)
         if config.ckpt.resume:
             last_loss = load_checkpoint(
-                checkpoint_path=config.ckpt.get_resume_path(),
+                checkpoint_path=ckpt_path,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
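
---

Note, not part of the applied patch: the change hoists the resume-path computation above the world_messenger_hv branch, so ckpt_path is computed once and both the hivemind (DiLoCo) path and the plain FSDP path resume from a consistent location. Below is a minimal, self-contained sketch of that resolution logic; CkptConfig, HvConfig, resolve_ckpt_path, and the diloco_rank_{n} directory format are illustrative assumptions, not the exact classes and helper in train_fsdp.py.

    import os
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class HvConfig:  # hypothetical stand-in for the hivemind config
        world_rank: int

    @dataclass
    class CkptConfig:  # hypothetical stand-in for the checkpoint config
        resume: bool
        resume_path: str

        def get_resume_path(self) -> str:
            return self.resume_path

    def get_diloco_rank_dir_name(world_rank: int) -> str:
        # Assumed per-worker directory naming; the real helper lives in open_diloco.
        return f"diloco_rank_{world_rank}"

    def resolve_ckpt_path(ckpt: CkptConfig, hv: Optional[HvConfig]) -> Optional[str]:
        # Same branching as the hoisted block in the patch.
        if not ckpt.resume:
            return None
        base_path = ckpt.get_resume_path()
        if hv is not None:
            # hivemind run: each DiLoCo worker resumes from its own subdirectory
            return os.path.join(base_path, get_diloco_rank_dir_name(hv.world_rank))
        # plain FSDP run: resume directly from the base path
        return base_path

    assert resolve_ckpt_path(CkptConfig(True, "ckpts"), HvConfig(3)) == os.path.join("ckpts", "diloco_rank_3")
    assert resolve_ckpt_path(CkptConfig(True, "ckpts"), None) == "ckpts"

Before the patch, the hivemind branch built the per-rank join inline while the non-hivemind branch passed config.ckpt.get_resume_path() directly; computing ckpt_path once keeps the two branches from drifting apart.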