Skip to content

Commit

Permalink
fix data
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 11, 2024
1 parent c217dd4 commit 94bc0ae
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion open_diloco/train_pure_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Config(BaseConfig):
project: str = "debug"
metric_logger_type: Literal["wandb", "dummy"] = "wandb"
fake_data: bool = False
+ dataset_name_or_path: str = "allenai/c4"


def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader:
Expand Down Expand Up @@ -127,7 +128,7 @@ def train(config: Config):
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
tokenizer.pad_token = "</s>" # Ensure pad token is set for models that need it

- train_dataloader = get_dataloader(tokenizer, world_size, rank, local_rank, config)
+ train_dataloader = get_dataloader(tokenizer, world_size, rank, config)

model = get_model(config)
model = model.to(local_rank)
Expand Down

0 comments on commit 94bc0ae

Please sign in to comment.