Skip to content

Commit

Permalink
For distributed training, make the batch size the global batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Nov 12, 2024
1 parent 6a07e88 commit 9be9887
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions src/qusi/internal/lightning_train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import logging
import math
from pathlib import Path
from warnings import warn

Expand Down Expand Up @@ -75,8 +76,29 @@ def train_session(
logging_metrics = [BinaryAccuracy(), BinaryAUROC()]

set_up_default_logger()

sessions_directory_path = Path(f'sessions')
session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}'
sessions_directory_path.mkdir(exist_ok=True, parents=True)
loggers = [
CSVLogger(save_dir=sessions_directory_path, name=session_name),
WandbLogger(save_dir=sessions_directory_path, name=session_name)]
trainer = lightning.Trainer(
max_epochs=hyperparameter_configuration.cycles,
limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
log_every_n_steps=0,
accelerator=system_configuration.accelerator,
logger=loggers,
)

train_dataset = InterleavedDataset.new(*train_datasets)
workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process

local_batch_size = round(hyperparameter_configuration.batch_size / trainer.world_size)
if local_batch_size == 0:
local_batch_size = 1

if workers_per_dataloader == 0:
prefetch_factor = None
persistent_workers = False
Expand All @@ -85,7 +107,7 @@ def train_session(
persistent_workers = True
train_dataloader = DataLoader(
train_dataset,
batch_size=hyperparameter_configuration.batch_size,
batch_size=local_batch_size,
pin_memory=True,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
Expand All @@ -95,7 +117,7 @@ def train_session(
for validation_dataset in validation_datasets:
validation_dataloader = DataLoader(
validation_dataset,
batch_size=hyperparameter_configuration.batch_size,
batch_size=local_batch_size,
pin_memory=True,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
Expand All @@ -105,18 +127,4 @@ def train_session(

lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
logging_metrics=logging_metrics)
sessions_directory_path = Path(f'sessions')
session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}'
sessions_directory_path.mkdir(exist_ok=True, parents=True)
loggers = [
CSVLogger(save_dir=sessions_directory_path, name=session_name),
WandbLogger(save_dir=sessions_directory_path, name=session_name)]
trainer = lightning.Trainer(
max_epochs=hyperparameter_configuration.cycles,
limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
log_every_n_steps=0,
accelerator=system_configuration.accelerator,
logger=loggers,
)
trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)

0 comments on commit 9be9887

Please sign in to comment.