Skip to content

Commit

Permalink
remove ddp autograd monkeypatch and impose minimum torchmetrics ver…
Browse files Browse the repository at this point in the history
…sion of 1.6.0
  • Loading branch information
cw-tan committed Nov 14, 2024
1 parent 5e210e2 commit e90c610
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 74 deletions.
7 changes: 7 additions & 0 deletions nequip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import packaging.version

import torch
import torchmetrics

# torch version checks
torch_version = packaging.version.parse(torch.__version__)
Expand All @@ -12,6 +13,12 @@
"1.13"
), f"NequIP supports 1.13.* or later, but {torch_version} found"

# torchmetrics >= 1.6.0 for ddp autograd
# https://github.com/Lightning-AI/torchmetrics/releases/tag/v1.6.0
torchmetrics_version = packaging.version.parse(torchmetrics.__version__)
assert torchmetrics_version >= packaging.version.parse(
"1.6.0"
), f"NequIP requires torchmetrics>=1.6.0 for ddp training but {torchmetrics_version} found"

# Load all installed nequip extension packages
# This allows installed extensions to register themselves in
Expand Down
69 changes: 0 additions & 69 deletions nequip/train/_metrics_utils.py

This file was deleted.

4 changes: 0 additions & 4 deletions nequip/train/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from hydra.utils import instantiate
from nequip.data import AtomicDataDict
from nequip.utils import RankedLogger
from ._metrics_utils import gather_all_tensors
import warnings
from typing import Optional, Dict

Expand Down Expand Up @@ -99,9 +98,6 @@ def __init__(
for metric_dict in loss["metrics"]:
# silently ensure that dist_sync_on_step is true for loss metrics
metric_dict["metric"]["dist_sync_on_step"] = True
# TODO: remove following once torchmetrics PR is merged
# https://github.com/Lightning-AI/torchmetrics/pull/2754
metric_dict["metric"]["dist_sync_fn"] = gather_all_tensors

# == instantiate loss ==
self.loss = instantiate(loss, type_names=self.model.type_names)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"importlib_metadata; python_version<'3.10'",
"hydra-core",
"lightning",
"torchmetrics",
"torchmetrics>=1.6.0",
]

[project.urls]
Expand Down

0 comments on commit e90c610

Please sign in to comment.