make lightning module get world size from trainer
cw-tan committed Nov 14, 2024
1 parent e90c610 commit da321ec
Showing 2 changed files with 7 additions and 7 deletions.
nequip/scripts/train.py (3 additions, 0 deletions)

@@ -198,6 +198,9 @@ def training_data_stats(stat_name: str):
         info_dict=info_dict,
     )

+    # pass world size from trainer to NequIPLightningModule
+    nequip_module.world_size = trainer.world_size
+
     # === loop of run types ===
     # restart behavior is such that
     # - train from ckpt uses the correct ckpt file to restore training state (so it is given a specific `ckpt_path`)
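
For context, a minimal sketch of the wiring pattern this commit adopts, assuming the lightning.pytorch namespace; ToyModule is an illustrative stand-in, not nequip's NequIPLightningModule. The launch script owns the Trainer, so it can read trainer.world_size (the attribute the diff itself uses) and hand it to the module before fit() is called.

import lightning.pytorch as pl


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # single-process default; the launch script overwrites this
        self.world_size = 1


trainer = pl.Trainer(accelerator="cpu", devices=1, max_epochs=1, logger=False)
module = ToyModule()
# Trainer.world_size reflects the number of processes of the configured strategy
module.world_size = trainer.world_size
print(module.world_size)  # 1 for a single-device run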

nequip/train/lightning.py (4 additions, 7 deletions)

@@ -88,11 +88,8 @@ def __init__(
         # == DDP concerns for loss ==

         # to account for loss contributions from multiple ranks later on
-        self.world_size = (
-            torch.distributed.get_world_size(torch.distributed.group.WORLD)
-            if torch.distributed.is_initialized()
-            else 1
-        )
+        # NOTE: this must be updated externally by the script that sets up the training run
+        self.world_size = 1

         # add dist_sync_on_step for loss metrics
         for metric_dict in loss["metrics"]:
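
The removed branch queried torch.distributed directly at construction time. One plausible reason for moving this responsibility to the script, an assumption not stated in the commit, is that the module may be built before any process group exists, in which case the old query silently reports 1 even for a run that will later use several ranks; a minimal sketch:

import torch.distributed as dist

# the removed logic, evaluated before any process group has been initialized
world_size_at_init = (
    dist.get_world_size(dist.group.WORLD) if dist.is_initialized() else 1
)
print(world_size_at_init)  # 1, regardless of how many ranks the run will use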

@@ -171,9 +168,9 @@ def training_step(
         )
         self.log_dict(loss_dict)
         # In DDP training, because gradients are averaged rather than summed over nodes,
-        # we get an effective factor of 1/n_rank applied to the loss. Because our loss already
+        # we get an effective factor of 1/n_rank applied to the loss. Because our loss already
         # manages correct accumulation of the metric over ranks, we want to cancel out this
-        # unnecessary 1/n_rank term. If DDP is disabled, this is 1 and has no effect.
+        # unnecessary 1/n_rank term. If DDP is disabled, this is 1 and has no effect.
         loss = (
             loss_dict[f"train_loss_step{self.logging_delimiter}weighted_sum"]
             * self.world_size
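
A toy sketch, not part of the commit, of the arithmetic behind the comment above, using two hypothetical per-rank losses: DDP's gradient averaging contributes a 1/world_size factor relative to the summed per-rank gradients, and scaling by world_size, as the weighted loss is scaled in training_step, cancels it.

import torch

world_size = 2
w = torch.tensor(1.0, requires_grad=True)

# stand-ins for each rank's loss; the desired gradient is the sum of their gradients
rank_losses = [3.0 * w, 5.0 * w]
per_rank_grads = [
    torch.autograd.grad(loss_r, w, retain_graph=True)[0] for loss_r in rank_losses
]

# what DDP's averaging effectively computes, carrying an unwanted 1/world_size
ddp_averaged = sum(per_rank_grads) / world_size                      # 4.0
# scaling each rank's loss (hence its gradient) by world_size cancels the factor
rescaled = sum(g * world_size for g in per_rank_grads) / world_size  # 8.0 = 3 + 5

print(ddp_averaged.item(), rescaled.item())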
