add wandb and real data
samsja committed Sep 11, 2024
1 parent ee89d33 commit c217dd4
Showing 1 changed file with 47 additions and 6 deletions: open_diloco/train_pure_fsdp.py
@@ -1,6 +1,7 @@
import os
from contextlib import nullcontext
import datetime
from typing import Literal

import torch
import torch.distributed as dist
@@ -21,11 +22,14 @@
)
from torch.distributed.device_mesh import init_device_mesh
from hivemind.optim.optimizer import logger

from open_diloco.utils import (
    FakeTokenizedDataset,
    get_sharding_strategy,
    WandbLogger,
    DummyLogger,
)
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120)
TEST_VOCAB_SIZE = 1024
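split_dataset_by_node (imported above from datasets.distributed) shards one Hugging Face dataset across data-parallel ranks so each rank iterates over a disjoint slice; for a streaming dataset it hands out the underlying shards (or interleaves examples) between ranks. A standalone sketch of the call, using a hypothetical dataset name and world size:

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Hypothetical setup: 4 data-parallel ranks, this process being rank 0.
ds = load_dataset("allenai/c4", "en", streaming=True)["train"]
rank_ds = split_dataset_by_node(ds, rank=0, world_size=4)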
@@ -61,10 +65,31 @@ class Config(BaseConfig):
    warmup_steps: int = 1000
    total_steps: int = 88_000
    sharding_strategy: str = "FULL_SHARD"
    project: str = "debug"
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
    fake_data: bool = False


def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader:
    if config.fake_data:
        train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE)
    else:
        ds = load_dataset(config.dataset_name_or_path, "en", streaming=True)

        def tokenize_function(data):
            outputs = tokenizer(
                data["text"],
                truncation=True,
                max_length=config.seq_length,
                padding="max_length",
            )
            return outputs

        tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[
            "train"
        ]

        train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
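The rest of get_dataloader is collapsed in this view. A minimal sketch of how the dataset and collator would typically be wrapped in a StatefulDataLoader (build_loader is a hypothetical helper; the batch size and worker count are placeholder values, not taken from the diff):

from torchdata.stateful_dataloader import StatefulDataLoader

def build_loader(train_dataset, data_collator, batch_size: int = 8) -> StatefulDataLoader:
    # StatefulDataLoader behaves like torch.utils.data.DataLoader but also exposes
    # state_dict()/load_state_dict(), so the position in the data stream can be checkpointed.
    return StatefulDataLoader(
        train_dataset,
        collate_fn=data_collator,
        batch_size=batch_size,
        num_workers=2,
    )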

@@ -141,6 +166,10 @@ def train(config: Config):

    model.train()

    if rank == 0:
        logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
        metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)

    loss_batch = 0

    train_dataloader_iterator = iter(train_dataloader)
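WandbLogger and DummyLogger live in open_diloco/utils.py and are not part of this diff. Based on how they are called here (constructor with project/config/resume, then log() and finish()), a no-op logger consistent with that interface could look like the sketch below; this is an assumed shape, not the repository's implementation, and the wandb-backed variant would forward to wandb.init, wandb.log and wandb.finish.

class DummyLogger:
    """No-op metric logger matching the interface used in train()."""

    def __init__(self, project: str, config: dict, resume: bool = False):
        self.project = project
        self.config = config
        self.resume = resume

    def log(self, metrics: dict) -> None:
        pass  # intentionally drops metrics; handy for local debugging

    def finish(self) -> None:
        pass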
@@ -169,9 +198,18 @@ def train(config: Config):
            inner_optimizer.zero_grad()

            if rank == 0:
                real_step = outer_step * config.diloco.local_steps + inner_step + 1
                inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

                metrics = {
                    "Loss": loss_batch.item(),
                    "step": real_step,
                    "inner_lr": inner_lr,
                }

                metric_logger.log(metrics)

                log(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")

            loss_batch = 0
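real_step flattens the outer/inner loop into a single global step counter, so the logged step keeps increasing across DiLoCo outer steps. For example, assuming config.diloco.local_steps = 500 (a hypothetical value):

# outer_step = 2, inner_step = 9, local_steps = 500
real_step = 2 * 500 + 9 + 1  # = 1010, i.e. the 1010th inner optimizer step overall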

@@ -194,6 +232,9 @@ def train(config: Config):

        outer_step += 1

    if rank == 0:
        metric_logger.finish()


if __name__ == "__main__":
    # Allow eager fallback during production so that the training runs don't die
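The lines behind this comment are collapsed in the diff. One common way to get eager fallback with torch.compile is Dynamo's suppress_errors switch; the following is a sketch of that mechanism, not necessarily the exact code in this file:

import torch._dynamo

# Fall back to eager execution instead of raising when torch.compile/Dynamo
# hits an unsupported construct, so a long training run is not killed.
torch._dynamo.config.suppress_errors = True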
