Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to pass --ckpt.resume and start from scratch if no ckpt f… #26

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions open_diloco/ckpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from torchdata.stateful_dataloader import StatefulDataLoader
from fsspec.generic import GenericFileSystem
from hivemind.optim.optimizer import logger


GLOBAL_STATE_FILE = "global_state_dict.pt"
Expand All @@ -18,21 +19,30 @@ class CkptConfig(BaseConfig):
path: str = "outputs"
topk: int | None = None # how many checkpoints to keep

def get_resume_path(self):
if self.resume is None:
raise ValueError("Resume path is not set")
elif isinstance(self.resume, bool):
# Using fsspec to list directory contents
fs = GenericFileSystem()
ckpt_files = [f for f in fs.ls(self.path, detail=False) if filter_ckpt_files(f)]

if len(ckpt_files) == 0:
raise ValueError(f"No checkpoints found in {self.path}")
def get_resume_info(ckpt_config: CkptConfig) -> tuple[bool, str | None]:
"""
check if we should resume from a checkpoint, if yes return the path to the checkpoint, otherwise return None
"""
if ckpt_config.resume is None:
return False, None
elif isinstance(ckpt_config.resume, bool):
# Using fsspec to list directory contents
fs = GenericFileSystem()
try:
ckpt_files = [f for f in fs.ls(ckpt_config.path, detail=False) if filter_ckpt_files(f)]
except FileNotFoundError:
logger.info(f"Checkpoint path {ckpt_config.path} not found, starting from scratch")
return False, None

latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
return latest_ckpt
if len(ckpt_files) == 0:
logger.info(f"No checkpoints found in {ckpt_config.path}, starting from scratch")
return False, None

return self.resume
latest_ckpt = max(ckpt_files, key=lambda f: int(f.split("_")[-1]))
return True, latest_ckpt
else:
return True, ckpt_config.resume


def save_checkpoint(
Expand Down
25 changes: 13 additions & 12 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
check_checkpoint_path_access,
delete_old_checkpoints,
get_diloco_rank_dir_name,
get_resume_info,
load_checkpoint,
save_checkpoint,
)
Expand Down Expand Up @@ -192,9 +193,11 @@ def train(config: Config):
sharding_strategy = ShardingStrategy.NO_SHARD
log("Hivemind is used, ShardingStrategy.NO_SHARD is used")

resume_from_ckpt, resume_path = get_resume_info(config.ckpt)

if rank == 0:
logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
metric_logger = logger_cls(project=config.project, config=config.model_dump())
metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=resume_from_ckpt)

if config.hv is not None:
log("hivemind diloco enabled")
Expand Down Expand Up @@ -257,27 +260,24 @@ def scheduler_fn(opt):
)

if config.hv is not None:
if config.ckpt.resume:
if resume_from_ckpt:
# We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
# This is because the DiLoCoOptimizer makes a copy of the model parameters for the state averager which is hard to update later
# We also need to do this on follower workers so that the world_messenger has friends to talk to when it does its two loads
# Otherwise the world messenger will get lonely and hang
fake_optimizer = inner_optimizer(model.parameters())
last_loss = load_checkpoint(
checkpoint_path=os.path.join(
config.ckpt.get_resume_path(), get_diloco_rank_dir_name(config.hv.world_rank)
),
checkpoint_path=os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank)),
model=model,
optimizer=fake_optimizer,
)
del fake_optimizer

if config.ckpt.resume:
base_path = config.ckpt.get_resume_path()
if resume_from_ckpt:
if config.hv is not None:
ckpt_path = os.path.join(base_path, get_diloco_rank_dir_name(config.hv.world_rank))
ckpt_path = os.path.join(resume_path, get_diloco_rank_dir_name(config.hv.world_rank))
else:
ckpt_path = base_path
ckpt_path = resume_path

if world_messenger_hv:
diloco_args = dict(
Expand Down Expand Up @@ -310,7 +310,7 @@ def scheduler_fn(opt):
optimizer.inner_optimizer
) # scheduler(optimizer) should work but better to make it explicit here

if config.ckpt.resume:
if resume_from_ckpt:
last_loss = load_checkpoint(
checkpoint_path=ckpt_path,
model=model,
Expand All @@ -327,7 +327,7 @@ def scheduler_fn(opt):
else:
optimizer = inner_optimizer(model.parameters())
scheduler = scheduler_fn(optimizer)
if config.ckpt.resume:
if resume_from_ckpt:
last_loss = load_checkpoint(
checkpoint_path=ckpt_path,
model=model,
Expand All @@ -340,7 +340,7 @@ def scheduler_fn(opt):
else:
start_step = 0

if config.ckpt.resume:
if resume_from_ckpt:
log(f"Resumed from checkpoint at step {start_step} with loss {last_loss}")

model.train()
Expand Down Expand Up @@ -510,6 +510,7 @@ def scheduler_fn(opt):

if config.max_steps is not None and real_step >= config.max_steps:
break

log("Training completed.")
if rank == 0:
metric_logger.finish()
Expand Down
8 changes: 5 additions & 3 deletions open_diloco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ def finish(self): ...


class WandbLogger:
def __init__(self, project, config):
wandb.init(project=project, config=config)
def __init__(self, project, config, resume: bool):
wandb.init(
project=project, config=config, resume="auto" if resume else None
) # make wandb reuse the same run id if possible

def log(self, metrics: dict[str, Any]):
wandb.log(metrics)
Expand All @@ -187,7 +189,7 @@ def finish(self):


class DummyLogger:
def __init__(self, project, config):
def __init__(self, project, config, *args, **kwargs):
self.project = project
self.config = config
open(project, "a").close() # Create an empty file at the project path
Expand Down
Loading