diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index 43df77b..ccba643 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -1,6 +1,7 @@
 import os
 from contextlib import nullcontext
 import datetime
+from typing import Literal
 
 import torch
 import torch.distributed as dist
@@ -21,11 +22,14 @@
 )
 from torch.distributed.device_mesh import init_device_mesh
 from hivemind.optim.optimizer import logger
-
 from open_diloco.utils import (
     FakeTokenizedDataset,
     get_sharding_strategy,
+    WandbLogger,
+    DummyLogger,
 )
+from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
 TEST_VOCAB_SIZE = 1024
@@ -61,10 +65,31 @@ class Config(BaseConfig):
     warmup_steps: int = 1000
     total_steps: int = 88_000
     sharding_strategy: str = "FULL_SHARD"
+    project: str = "debug"
+    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
+    fake_data: bool = False
+
+
+def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader:
+    if config.fake_data:
+        train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE)
+    else:
+        ds = load_dataset(config.dataset_name_or_path, "en", streaming=True)
 
+        def tokenize_function(data):
+            outputs = tokenizer(
+                data["text"],
+                truncation=True,
+                max_length=config.seq_length,
+                padding="max_length",
+            )
+            return outputs
 
-def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader:
-    train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE)
+        tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[
+            "train"
+        ]
+
+        train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)
 
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
@@ -141,6 +166,10 @@ def train(config: Config):
 
     model.train()
 
+    if rank == 0:
+        logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
+        metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)
+
     loss_batch = 0
 
     train_dataloader_iterator = iter(train_dataloader)
@@ -169,9 +198,18 @@
             inner_optimizer.zero_grad()
 
             if rank == 0:
-                log(
-                    f"step: {outer_step} inner: {inner_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in inner_optimizer.param_groups][0]}"
-                )
+                real_step = outer_step * config.diloco.local_steps + inner_step + 1
+                inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
+
+                metrics = {
+                    "Loss": loss_batch.item(),
+                    "step": real_step,
+                    "inner_lr": inner_lr,
+                }
+
+                metric_logger.log(metrics)
+
+                log(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")
 
             loss_batch = 0
 
@@ -194,6 +232,9 @@ def train(config: Config):
 
         outer_step += 1
 
+    if rank == 0:
+        metric_logger.finish()
+
 
 if __name__ == "__main__":
     # Allow eager fallback during production so that the training runs don't die
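
A note on the data path: load_dataset(..., streaming=True) keeps the dataset lazy, the map call tokenizes on the fly, and split_dataset_by_node (from the Hugging Face datasets library) gives each rank its own slice of the stream: whole shards per rank when the shard count divides world_size, otherwise every world_size-th example. The remove_columns list matches the raw C4 fields (text, timestamp, url). A minimal standalone sketch of the same pipeline, using allenai/c4 purely as an illustrative dataset name:

    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    # Stream the dataset so nothing is materialized up front; "en" picks the
    # English configuration, mirroring the call in get_dataloader above.
    ds = load_dataset("allenai/c4", "en", streaming=True)["train"]

    # Rank 0 of a 2-process job keeps its share of the shards (or every 2nd
    # example when the shard count is not divisible by world_size).
    shard = split_dataset_by_node(ds, world_size=2, rank=0)
    print(next(iter(shard))["text"][:80])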
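
A note on the logger imports: WandbLogger and DummyLogger come from open_diloco.utils, whose definitions are not part of this diff. The call sites above only pin down the interface (logger_cls(project=..., config=..., resume=False), metric_logger.log(metrics), metric_logger.finish()). A rough sketch of that contract, where the class bodies are illustrative assumptions rather than the repo's actual implementation:

    from typing import Any

    import wandb


    class DummyLogger:
        """No-op logger so the rank-0 logging path works without wandb."""

        def __init__(self, project: str, config: dict[str, Any], resume: bool):
            self.project = project
            self.config = config

        def log(self, metrics: dict[str, Any]) -> None:
            pass  # metrics are intentionally discarded

        def finish(self) -> None:
            pass


    class WandbLogger:
        """Thin wrapper that forwards metrics to Weights & Biases."""

        def __init__(self, project: str, config: dict[str, Any], resume: bool):
            wandb.init(project=project, config=config, resume=resume)

        def log(self, metrics: dict[str, Any]) -> None:
            wandb.log(metrics)

        def finish(self) -> None:
            wandb.finish()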
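
A note on real_step: the old log line reported outer_step and inner_step separately, and inner_step resets to zero at every DiLoCo synchronization round. real_step = outer_step * config.diloco.local_steps + inner_step + 1 flattens the pair into one monotonic counter, so loss curves plot cleanly across rounds. For example, with local_steps = 500, outer_step = 2 and inner_step = 10 yield 2 * 500 + 10 + 1 = 1011.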
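
To exercise both new code paths locally (the flag spellings below are assumed from the Config field names and the repo's pydantic-based CLI parsing, so adjust them to whatever the parser actually accepts):

    torchrun --nproc_per_node 2 open_diloco/train_pure_fsdp.py --fake_data --metric_logger_type dummy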