From e1f2fbd05c44dbfa98ade952d9132be60576e4db Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 23 Aug 2024 10:07:51 +0000
Subject: [PATCH 1/3] feat: allow passing --ckpt.resume and start from scratch if no ckpt files are present

---
 open_diloco/ckpt_utils.py | 34 ++++++++++++++++++++--------------
 open_diloco/train_fsdp.py | 22 +++++++++++-----------
 2 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/open_diloco/ckpt_utils.py b/open_diloco/ckpt_utils.py
index 261b7c4..c978f76 100644
--- a/open_diloco/ckpt_utils.py
+++ b/open_diloco/ckpt_utils.py
@@ -6,6 +6,7 @@ import os
 from torchdata.stateful_dataloader import StatefulDataLoader
 from fsspec.generic import GenericFileSystem
+from hivemind.optim.optimizer import logger
 
 GLOBAL_STATE_FILE = "global_state_dict.pt"
@@ -18,21 +19,26 @@ class CkptConfig(BaseConfig):
     path: str = "outputs"
     topk: int | None = None  # how many checkpoints to keep
 
-    def get_resume_path(self):
-        if self.resume is None:
-            raise ValueError("Resume path is not set")
-        elif isinstance(self.resume, bool):
-            # Using fsspec to list directory contents
-            fs = GenericFileSystem()
-            ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]
-            if len(ckpt_files) == 0:
-                raise ValueError(f"No checkpoints found in {self.path}")
-
-            latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
-            return latest_ckpt
-
-        return self.resume
+def get_resume_info(ckpt_config: CkptConfig) -> tuple[bool, str | None]:
+    """
+    Check whether training should resume from a checkpoint. Returns (resume, path): path is the checkpoint to load from when resuming, otherwise None.
+    """
+    if ckpt_config.resume is None:
+        return False, None
+    elif isinstance(ckpt_config.resume, bool):
+        # Using fsspec to list directory contents
+        fs = GenericFileSystem()
+        ckpt_files = [f for f in fs.ls(ckpt_config.path, detail=False) if filter_ckpt_files(f)]
+
+        if len(ckpt_files) == 0:
+            logger.info(f"No checkpoints found in {ckpt_config.path}, starting from scratch")
+            return False, None
+
+        latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
+        return True, latest_ckpt
+    else:
+        return True, ckpt_config.resume
 
 
 def save_checkpoint(
diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 3c2b031..4fb3676 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -41,6 +41,7 @@
     check_checkpoint_path_access,
     delete_old_checkpoints,
     get_diloco_rank_dir_name,
+    get_resume_info,
     load_checkpoint,
     save_checkpoint,
 )
@@ -256,28 +257,27 @@ def scheduler_fn(opt):
         num_training_steps=config.total_steps,
     )
 
+    resume_from_ckpt, resume_path = get_resume_info(config.ckpt)
+
     if config.hv is not None:
-        if config.ckpt.resume:
+        if resume_from_ckpt:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
             # This is because the DiLoCoOptimizer makes a copy of the model parameters for the state averager which is hard to update later
             # We also need to do this on follower workers so that the world_messenger has friends to talk to when it does its two loads
             # Otherwise the world messenger will get lonely and hang
             fake_optimizer = inner_optimizer(model.parameters())
             last_loss = load_checkpoint(
-                checkpoint_path=os.path.join(
-                    config.ckpt.get_resume_path(), get_diloco_rank_dir_name(config.hv.world_rank)
-                ),
+                checkpoint_path=os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank)),
                 model=model,
                 optimizer=fake_optimizer,
             )
             del fake_optimizer
 
-    if config.ckpt.resume:
-        base_path = config.ckpt.get_resume_path()
+    if resume_from_ckpt:
         if config.hv is not None:
-            ckpt_path = os.path.join(base_path, get_diloco_rank_dir_name(config.hv.world_rank))
+            ckpt_path = os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank))
         else:
-            ckpt_path = base_path
+            ckpt_path = resume_path
 
     if world_messenger_hv:
         diloco_args = dict(
@@ -310,7 +310,7 @@ def scheduler_fn(opt):
             optimizer.inner_optimizer
         )  # scheduler(optimizer) should work but better to make it explicit here
 
-        if config.ckpt.resume:
+        if resume_from_ckpt:
             last_loss = load_checkpoint(
                 checkpoint_path=ckpt_path,
                 model=model,
@@ -327,7 +327,7 @@ def scheduler_fn(opt):
     else:
         optimizer = inner_optimizer(model.parameters())
         scheduler = scheduler_fn(optimizer)
-        if config.ckpt.resume:
+        if resume_from_ckpt:
             last_loss = load_checkpoint(
                 checkpoint_path=ckpt_path,
                 model=model,
@@ -340,7 +340,7 @@ def scheduler_fn(opt):
     else:
         start_step = 0
 
-    if config.ckpt.resume:
+    if resume_from_ckpt:
         log(f"Resumed from checkpoint at step {start_step} with loss {last_loss}")
 
     model.train()

From 2acc597b31eb4142b7cf9180eb0d5489970d01cf Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 23 Aug 2024 10:29:32 +0000
Subject: [PATCH 2/3] feat: do not fail if the ckpt folder does not exist yet

---
 open_diloco/ckpt_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/open_diloco/ckpt_utils.py b/open_diloco/ckpt_utils.py
index c978f76..cc917cd 100644
--- a/open_diloco/ckpt_utils.py
+++ b/open_diloco/ckpt_utils.py
@@ -29,7 +29,11 @@ def get_resume_info(ckpt_config: CkptConfig) -> tuple[bool, str | None]:
     elif isinstance(ckpt_config.resume, bool):
         # Using fsspec to list directory contents
         fs = GenericFileSystem()
-        ckpt_files = [f for f in fs.ls(ckpt_config.path, detail=False) if filter_ckpt_files(f)]
+        try:
+            ckpt_files = [f for f in fs.ls(ckpt_config.path, detail=False) if filter_ckpt_files(f)]
+        except FileNotFoundError:
+            logger.info(f"Checkpoint path {ckpt_config.path} not found, starting from scratch")
+            return False, None
 
         if len(ckpt_files) == 0:
             logger.info(f"No checkpoints found in {ckpt_config.path}, starting from scratch")

From 383df3417eb6d982286d8bf9598f24ac93741170 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Fri, 23 Aug 2024 10:48:28 +0000
Subject: [PATCH 3/3] feat: use wandb auto-resume feature

---
 open_diloco/train_fsdp.py | 7 ++++---
 open_diloco/utils.py      | 8 +++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 4fb3676..4d5ef3e 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -193,9 +193,11 @@ def train(config: Config):
         sharding_strategy = ShardingStrategy.NO_SHARD
         log("Hivemind is used, ShardingStrategy.NO_SHARD is used")
 
+    resume_from_ckpt, resume_path = get_resume_info(config.ckpt)
+
     if rank == 0:
         logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
-        metric_logger = logger_cls(project=config.project, config=config.model_dump())
+        metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=resume_from_ckpt)
 
     if config.hv is not None:
         log("hivemind diloco enabled")
@@ -257,8 +259,6 @@ def scheduler_fn(opt):
         num_training_steps=config.total_steps,
     )
 
-    resume_from_ckpt, resume_path = get_resume_info(config.ckpt)
-
     if config.hv is not None:
         if resume_from_ckpt:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -510,6 +510,7 @@ def scheduler_fn(opt):
         if config.max_steps is not None and real_step >= config.max_steps:
             break
 
+    log("Training completed.")
     if rank == 0:
         metric_logger.finish()
diff --git a/open_diloco/utils.py b/open_diloco/utils.py
index 9940747..0ced218 100644
--- a/open_diloco/utils.py
+++ b/open_diloco/utils.py
@@ -176,8 +176,10 @@ def finish(self): ...
 
 
 class WandbLogger:
-    def __init__(self, project, config):
-        wandb.init(project=project, config=config)
+    def __init__(self, project, config, resume: bool):
+        wandb.init(
+            project=project, config=config, resume="auto" if resume else None
+        )  # make wandb reuse the same run id if possible
 
     def log(self, metrics: dict[str, Any]):
         wandb.log(metrics)
@@ -187,7 +189,7 @@ def finish(self):
 
 
 class DummyLogger:
-    def __init__(self, project, config):
+    def __init__(self, project, config, *args, **kwargs):
         self.project = project
         self.config = config
         open(project, "a").close()  # Create an empty file at the project path
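
Taken together, the three patches make resume resolution behave as in the minimal sketch below. This is not part of the patch series: it assumes the patched open_diloco package is importable and that CkptConfig exposes a resume field next to path (only path and topk are visible in the hunks above); the checkpoint directory names are purely illustrative.

# Usage sketch of the new get_resume_info helper (assumptions noted above).
from open_diloco.ckpt_utils import CkptConfig, get_resume_info

# Bare flag (--ckpt.resume -> resume=True): use the latest checkpoint found
# under `path`, or fall back to (False, None) if the folder is missing or empty.
resume_from_ckpt, resume_path = get_resume_info(CkptConfig(path="outputs", resume=True))

# Explicit path (e.g. --ckpt.resume outputs/model_step_500): resume from that
# checkpoint as given, without scanning the folder.
resume_from_ckpt, resume_path = get_resume_info(
    CkptConfig(path="outputs", resume="outputs/model_step_500")
)

# Flag omitted (resume=None): always (False, None), i.e. train from scratch.
resume_from_ckpt, resume_path = get_resume_info(CkptConfig(path="outputs", resume=None))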