Skip to content

Commit

Permalink
diloco ckpt check fix
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 8, 2024
1 parent ba63895 commit a53b294
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def log(message):
logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


def check_checkpoint_path_access(checkpoint_path: str, rank: int):
dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt")
def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None):
    """Write a small dummy file under ``checkpoint_path`` to test that it is writable.

    Args:
        checkpoint_path: Base checkpoint location; the dummy file path built
            from it is opened through ``fsspec``.
        rank: Rank used to make the dummy file name unique
            (``dummy_file_{rank}.txt``).
        world_rank_hv: Optional world rank. When it is not ``None`` the dummy
            file is placed inside the sub-directory named by
            ``get_diloco_rank_dir_name(world_rank_hv)`` instead of directly
            under ``checkpoint_path``.
    """
    path_parts = [checkpoint_path]
    # Compare against None explicitly: world rank 0 is falsy, so a plain
    # truthiness test would skip the per-rank directory for that worker.
    if world_rank_hv is not None:
        path_parts.append(get_diloco_rank_dir_name(world_rank_hv))
    dummy_file_path = os.path.join(*path_parts, f"dummy_file_{rank}.txt")

    with fsspec.open(dummy_file_path, "w") as f:
        f.write("This is a dummy file for testing access.")
    gfs = GenericFileSystem()
Expand Down Expand Up @@ -221,7 +226,7 @@ def train(config: Config):
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False)

if local_rank == 0:
check_checkpoint_path_access(config.checkpoint_path, rank)
check_checkpoint_path_access(config.checkpoint_path, rank, config.hv.world_rank if config.hv else None)

# DataLoader preparation
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
Expand Down

0 comments on commit a53b294

Please sign in to comment.