From f594081a3ec0cb18ed1a234379c590256f50da3d Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 20 May 2021 13:34:50 -0400 Subject: [PATCH 001/126] add callbacks --- configs/full.yaml | 10 ++++++---- nequip/train/early_stopping.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/configs/full.yaml b/configs/full.yaml index 7eea0f53..945bb070 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -93,14 +93,16 @@ ema_use_num_updates: true early_stopping_patiences: # stop early if a metric value stopped decreasing for n epochs Validation_loss: 50 # Training_loss: 100 # - e_mae: 100 # + e_mae: 100 # early_stopping_delta: # If delta is defined, a tiny decrease smaller than delta will not be considered as a decrease Training_loss: 0.005 # early_stopping_cumulative_delta: false # If True, the minimum value recorded will not be updated when the decrease is smaller than delta early_stopping_lower_bounds: # stop early if a metric value is lower than the bound - LR: 1.0e-10 # + LR: 1.0e-10 # early_stopping_upper_bounds: # stop early if a metric value is higher than the bound - wall: 1.0e+100 # + wall: 1.0e+100 # +end_of_epoch_callbacks: # call back functions to adjust hyper-parameters +- !!python/name:nequip.train.callbacks.cos_sin # two examples (equal_loss and cos_sin) are listed in nequip.train.callbacks # loss function loss_coeffs: # different weights to use in a weighted loss functions @@ -203,4 +205,4 @@ trainable_global_rescale_scale: false # Options for e3nn's set_optimization_defaults. A dict: # e3nn_optimization_defaults -# ... \ No newline at end of file +# ... diff --git a/nequip/train/early_stopping.py b/nequip/train/early_stopping.py index c1aee9e5..be5a67ef 100644 --- a/nequip/train/early_stopping.py +++ b/nequip/train/early_stopping.py @@ -86,11 +86,13 @@ def __call__(self, metrics) -> None: self.counters[key] = 0 for key, bound in self.lower_bounds.items(): + print(key, bound, type(bound), metrics[key], type(metrics[key])) if metrics[key] < bound: stop_args += f" {key} is smaller than {bound}" stop = True for key, bound in self.upper_bounds.items(): + print(key, bound, type(bound), metrics[key], type(metrics[key])) if metrics[key] > bound: stop_args += f" {key} is larger than {bound}" stop = True From 324015af684e61ce2ce19c23919abf06174fd831 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 20 May 2021 13:38:46 -0400 Subject: [PATCH 002/126] log change --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48bbf0b5..e91ab5a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. 
## [Unreleased] +- examples of `end_of_epoch_callbacks` are listed in nequip.train.callbacks ## [0.3.1] ### Fixed From 95bd2810910de3dd87fd2a3949728bc4b93ea3ed Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 25 May 2021 09:21:16 -0400 Subject: [PATCH 003/126] add the call back file --- nequip/train/callbacks.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 nequip/train/callbacks.py diff --git a/nequip/train/callbacks.py b/nequip/train/callbacks.py new file mode 100644 index 00000000..d10e9528 --- /dev/null +++ b/nequip/train/callbacks.py @@ -0,0 +1,34 @@ +import torch +import logging + + +def equal_loss(self): + + loss_f = self.mae_dict["Validation_loss_f"] + loss_e = self.mae_dict["Validation_loss_e"] + coeff_e = loss_f / loss_e + self.loss.coeffs["forces"] = torch.as_tensor(1, dtype=torch.get_default_dtype()) + self.loss.coeffs["total_energy"] = torch.as_tensor( + loss_f / loss_e, dtype=torch.get_default_dtype() + ) + self.logger.info(f"# update loss coeffs to 1 and {loss_f/loss_e}") + + +def cos_sin(self): + + f = self.kwargs.get("loss_f_mag", 1) + e = self.kwargs.get("loss_e_mag", 1) + pi = self.kwargs.get("loss_coeff_pi", 20) + + dtype = torch.get_default_dtype() + + f = torch.as_tensor(f, dtype=dtype) + e = torch.as_tensor(e, dtype=dtype) + + f = f * torch.sin(torch.as_tensor(self.iepoch / pi, dtype=dtype)) + e = e * torch.cos(torch.as_tensor(self.iepoch / pi, dtype=dtype)) + + self.loss.coeffs["forces"] = f + self.loss.coeffs["total_energy"] = e + + self.logger.info(f"# update loss coeffs to {f} {e}") From d53cdeba34576f6352735876fe0b0835cd0063c0 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 25 May 2021 17:01:38 -0400 Subject: [PATCH 004/126] remove prints --- nequip/train/early_stopping.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nequip/train/early_stopping.py b/nequip/train/early_stopping.py index be5a67ef..c1aee9e5 100644 --- a/nequip/train/early_stopping.py +++ b/nequip/train/early_stopping.py @@ -86,13 +86,11 @@ def __call__(self, metrics) -> None: self.counters[key] = 0 for key, bound in self.lower_bounds.items(): - print(key, bound, type(bound), metrics[key], type(metrics[key])) if metrics[key] < bound: stop_args += f" {key} is smaller than {bound}" stop = True for key, bound in self.upper_bounds.items(): - print(key, bound, type(bound), metrics[key], type(metrics[key])) if metrics[key] > bound: stop_args += f" {key} is larger than {bound}" stop = True From 1ebd412c2f82a16def925c348258de902de5452b Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 25 May 2021 17:14:08 -0400 Subject: [PATCH 005/126] fix typos --- configs/minimal.yaml | 6 +----- nequip/train/callbacks.py | 6 +++++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/minimal.yaml b/configs/minimal.yaml index 58956081..e8cb2917 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -19,8 +19,6 @@ model_initializers: # data dataset: aspirin dataset_file_name: benchmark_data/aspirin_ccsd-train.npz -end_of_epoch_callbacks: # call back functions to adjust hyper-parameters -- !!python/name:nequip.train.callbacks.cos_sin # two examples (equal_loss and cos_sin) are listed in nequip.train.callbacks # logging wandb: false @@ -33,9 +31,7 @@ batch_size: 1 max_epochs: 10 # loss function -loss_coeffs: -- forces -- total_energy +loss_coeffs: forces # optimizer optimizer_name: Adam diff --git a/nequip/train/callbacks.py b/nequip/train/callbacks.py index 6e9be473..0dbfb5c3 100644 --- a/nequip/train/callbacks.py +++ b/nequip/train/callbacks.py 
@@ -21,7 +21,11 @@ def cos_sin(self):
     e = self.kwargs.get("loss_e_mag", 1)
     phi_f = self.kwargs.get("loss_f_phi", 0)
     phi_e = self.kwargs.get("loss_e_phi", 0)
-    pi = self.kwargs.get("loss_coeff_cycle", 20)
+    cycle = self.kwargs.get("loss_coeff_cycle", 20)
+
+    if phi_f == phi_e:
+
+        return
 
     dtype = torch.get_default_dtype()
 

From 042e0e33bdf946b111375a9a9daca4d8e2e8878e Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 26 May 2021 14:16:11 -0400
Subject: [PATCH 006/126] move callback to utils

---
 nequip/{train => utils}/callbacks.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
 rename nequip/{train => utils}/callbacks.py (80%)

diff --git a/nequip/train/callbacks.py b/nequip/utils/callbacks.py
similarity index 80%
rename from nequip/train/callbacks.py
rename to nequip/utils/callbacks.py
index 0dbfb5c3..49651a87 100644
--- a/nequip/train/callbacks.py
+++ b/nequip/utils/callbacks.py
@@ -20,7 +20,7 @@ def cos_sin(self):
     f = self.kwargs.get("loss_f_mag", 1)
     e = self.kwargs.get("loss_e_mag", 1)
     phi_f = self.kwargs.get("loss_f_phi", 0)
-    phi_e = self.kwargs.get("loss_e_phi", 0)
+    phi_e = self.kwargs.get("loss_e_phi", -10)
     cycle = self.kwargs.get("loss_coeff_cycle", 20)
 
     if phi_f == phi_e:
@@ -29,8 +29,8 @@ def cos_sin(self):
 
     dtype = torch.get_default_dtype()
 
-    f = torch.as_tensor(f * cos((self.iepoch + phi_f) / cycle * pi), dtype=dtype)
-    e = torch.as_tensor(e * cos((self.iepoch + phi_e) / cycle * pi), dtype=dtype)
+    f = torch.as_tensor(f * (cos((self.iepoch + phi_f) / cycle * pi)+1), dtype=dtype)
+    e = torch.as_tensor(e *
(cos((self.iepoch + phi_e) / cycle * pi)+1), dtype=dtype)
 
     self.loss.coeffs["forces"] = f
     self.loss.coeffs["total_energy"] = e

From 91dec70ebc3cd0513fb81139dcc650475862053a Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 26 May 2021 14:24:05 -0400
Subject: [PATCH 007/126] update documentation

---
 configs/full.yaml | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/configs/full.yaml b/configs/full.yaml
index 95e75fb3..cc154c16 100644
--- a/configs/full.yaml
+++ b/configs/full.yaml
@@ -101,8 +101,15 @@ early_stopping_lower_bounds:
   LR: 1.0e-10 #
 early_stopping_upper_bounds: # stop early if a metric value is higher than the bound
   wall: 1.0e+100 #
-end_of_epoch_callbacks: # call back functions to adjust hyper-parameters
-- !!python/name:nequip.train.callbacks.cos_sin # two examples (equal_loss and cos_sin) are listed in nequip.train.callbacks
+end_of_epoch_callbacks: # call back functions to adjust hyper-parameters
+- !!python/name:nequip.utils.callbacks.cos_sin # two examples (equal_loss and cos_sin) are listed in nequip.utils.callbacks
+
+# for cos_sin callback: the coefficient will be mag*(cos((iepoch+phi)/cycle*pi)+1)
+loss_f_mag: 1
+loss_e_mag: 1
+loss_f_phi: 0
+loss_e_phi: -10
+loss_coeff_cycle: 20
 
 # loss function
 loss_coeffs: # different weights to use in a weighted loss functions

From 324d55df006a64eff50bfc61b54d53f7fb9648b3 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 26 May 2021 14:24:55 -0400
Subject: [PATCH 008/126] format

---
 nequip/utils/callbacks.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py
index 49651a87..3f57df0d 100644
--- a/nequip/utils/callbacks.py
+++ b/nequip/utils/callbacks.py
@@ -24,13 +24,12 @@ def cos_sin(self):
     cycle = self.kwargs.get("loss_coeff_cycle", 20)
 
     if phi_f == phi_e:
-
         return
 
     dtype = torch.get_default_dtype()
 
-    f = torch.as_tensor(f * (cos((self.iepoch + phi_f) / cycle * pi)+1), dtype=dtype)
-    e = torch.as_tensor(e * (cos((self.iepoch + phi_e) / cycle * pi)+1), dtype=dtype)
+    f = torch.as_tensor(f * (cos((self.iepoch + phi_f) / cycle * pi) + 1), dtype=dtype)
+    e = torch.as_tensor(e * (cos((self.iepoch + phi_e) / cycle * pi) + 1), dtype=dtype)
 
     self.loss.coeffs["forces"] = f
     self.loss.coeffs["total_energy"] = e

From 19c0c6e626d8cc12de04d711041f9bd9b3083950 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 26 May 2021 15:17:36 -0400
Subject: [PATCH 009/126] add sanity check to undo the wandb type sanitization

---
 nequip/utils/wandb.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/nequip/utils/wandb.py b/nequip/utils/wandb.py
index ff10914b..bd6caff0 100644
--- a/nequip/utils/wandb.py
+++ b/nequip/utils/wandb.py
@@ -1,5 +1,7 @@
 import os
 import wandb
+import logging
+from wandb.util import json_friendly_val
 
 
 def init_n_update(config):
@@ -13,9 +15,20 @@ def init_n_update(config):
         resume="allow",
         id=config.run_id,
     )
-    # download from wandb set up
-    config.update(dict(wandb.config))
-    wandb.config.update(dict(run_id=config.run_id), allow_val_change=True)
+    # # download from wandb set up
+    updated_parameters = dict(wandb.config)
+    for k, v_new in updated_parameters.items():
+        skip = False
+        if k in config.keys():
+            # double check the one sanitized by wandb
+            v_old = json_friendly_val(config[k])
+            if repr(v_new) == repr(v_old):
+                skip = True
+        if skip:
+            logging.info(f"# skipping wandb update {k} from {v_old} to {v_new}")
+        else:
+            config.update({k: v_new})
+            logging.info(f"# wandb update {k} from {v_old} to {v_new}")
 
     return config

From 19719324bd7d9d84565737cd6e0b454bbb701b89 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Fri, 12 Nov 2021 15:20:10 -0500
Subject: [PATCH 010/126] add stride argument

---
 configs/full.yaml         |  1 +
 nequip/utils/regressor.py | 17 ++++++++++++-----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/configs/full.yaml b/configs/full.yaml
index b7fc66b1..f19e6f5c 100644
--- a/configs/full.yaml
+++ b/configs/full.yaml
@@ -260,6 +260,7 @@ per_species_rescale_scales: dataset_forces_rms
 # total_energy:
 #   alpha: 0.1
 #   max_iteration: 20
+#   stride: 100
 # keywords for GP decomposition of per species energy. Optional. Defaults to 0.1
 # per_species_rescale_arguments_in_dataset_units: True # if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values.
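For reference, the coefficient schedule that PATCH 007 documents in configs/full.yaml can be reproduced standalone. The snippet below is a minimal sketch that assumes nothing beyond the stated formula mag*(cos((iepoch+phi)/cycle*pi)+1); the helper name `loss_coeff` is hypothetical and not part of nequip:

from math import cos, pi

def loss_coeff(iepoch, mag=1.0, phi=0.0, cycle=20.0):
    # oscillates between 0 and 2*mag with a period of 2*cycle epochs,
    # offset by phi epochs along the epoch axis
    return mag * (cos((iepoch + phi) / cycle * pi) + 1)

# with the defaults above (loss_f_phi: 0, loss_e_phi: -10, loss_coeff_cycle: 20)
# the force and energy weights run a quarter period (10 epochs) out of phase:
for iepoch in (0, 10, 20, 30, 40):
    print(iepoch, loss_coeff(iepoch, phi=0), loss_coeff(iepoch, phi=-10))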
diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py index 6de0ff1b..cb7a80a7 100644 --- a/nequip/utils/regressor.py +++ b/nequip/utils/regressor.py @@ -12,18 +12,19 @@ def solver( y, alpha: Optional[float] = 0.1, max_iteration: Optional[int] = 20, + stride: Optional[int] = None, regressor: Optional[str] = "NormalizedGaussianProcess", ): if regressor == "GaussianProcess": - return gp(X, y, alpha, max_iteration) + return gp(X, y, alpha, max_iteration, stride) elif regressor == "NormalizedGaussianProcess": - return normalized_gp(X, y, alpha, max_iteration) + return normalized_gp(X, y, alpha, max_iteration, stride) else: raise NotImplementedError(f"{regressor} is not implemented") -def normalized_gp(X, y, alpha, max_iteration): +def normalized_gp(X, y, alpha, max_iteration, stride): feature_rms = 1.0 / np.sqrt(np.average(X ** 2, axis=0)) feature_rms = np.nan_to_num(feature_rms, 1) y_mean = torch.sum(y) / torch.sum(X) @@ -34,11 +35,12 @@ def normalized_gp(X, y, alpha, max_iteration): {"diagonal_elements": feature_rms}, alpha, max_iteration, + stride, ) return mean + y_mean, std -def gp(X, y, alpha, max_iteration): +def gp(X, y, alpha, max_iteration, stride): return base_gp( X, y, @@ -46,14 +48,19 @@ def gp(X, y, alpha, max_iteration): {"sigma_0": 0, "sigma_0_bounds": "fixed"}, alpha, max_iteration, + stride, ) -def base_gp(X, y, kernel, kernel_kwargs, alpha, max_iteration:int): +def base_gp(X, y, kernel, kernel_kwargs, alpha, max_iteration: int, stride): if len(y.shape) == 1: y = y.reshape([-1, 1]) + if stride is not None: + X = X[::stride] + y = y[::stride] + not_fit = True iteration = 0 mean = None From d4a0aff6e31ac9a647a46cd0417b97e252f26c3c Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 12 Nov 2021 15:26:04 -0500 Subject: [PATCH 011/126] update unit tests --- tests/unit/data/test_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py index 89bc20b1..7834942f 100644 --- a/tests/unit/data/test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -236,7 +236,11 @@ def test_per_graph_field( modes=["per_species_mean_std"], kwargs={ AtomicDataDict.TOTAL_ENERGY_KEY - + "per_species_mean_std": {"alpha": alpha, "regressor": regressor} + + "per_species_mean_std": { + "alpha": alpha, + "regressor": regressor, + "stride": 1, + } }, ) From 5b67be98d23da3a438299d75fc53e083b6f1cf7b Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 12 Nov 2021 15:44:35 -0500 Subject: [PATCH 012/126] format --- nequip/utils/regressor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py index cb7a80a7..c11842a0 100644 --- a/nequip/utils/regressor.py +++ b/nequip/utils/regressor.py @@ -15,7 +15,6 @@ def solver( stride: Optional[int] = None, regressor: Optional[str] = "NormalizedGaussianProcess", ): - if regressor == "GaussianProcess": return gp(X, y, alpha, max_iteration, stride) elif regressor == "NormalizedGaussianProcess": From 0026309450cc2936bf16ee5016cd94e19086e19a Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 12 Nov 2021 19:41:18 -0500 Subject: [PATCH 013/126] update kwargs --- nequip/utils/regressor.py | 41 +++++++++++++++------------------------ 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py index c11842a0..8da318a8 100644 --- a/nequip/utils/regressor.py +++ b/nequip/utils/regressor.py @@ -7,23 +7,16 @@ from sklearn.gaussian_process.kernels import DotProduct, Kernel, Hyperparameter 
-def solver( - X, - y, - alpha: Optional[float] = 0.1, - max_iteration: Optional[int] = 20, - stride: Optional[int] = None, - regressor: Optional[str] = "NormalizedGaussianProcess", -): +def solver(X, y, regressor: Optional[str] = "NormalizedGaussianProcess", **kwargs): if regressor == "GaussianProcess": - return gp(X, y, alpha, max_iteration, stride) + return gp(X, y, **kwargs) elif regressor == "NormalizedGaussianProcess": - return normalized_gp(X, y, alpha, max_iteration, stride) + return normalized_gp(X, y, **kwargs) else: raise NotImplementedError(f"{regressor} is not implemented") -def normalized_gp(X, y, alpha, max_iteration, stride): +def normalized_gp(X, y, **kwargs): feature_rms = 1.0 / np.sqrt(np.average(X ** 2, axis=0)) feature_rms = np.nan_to_num(feature_rms, 1) y_mean = torch.sum(y) / torch.sum(X) @@ -32,26 +25,26 @@ def normalized_gp(X, y, alpha, max_iteration, stride): y - (torch.sum(X, axis=1) * y_mean).reshape(y.shape), NormalizedDotProduct, {"diagonal_elements": feature_rms}, - alpha, - max_iteration, - stride, + **kwargs, ) return mean + y_mean, std -def gp(X, y, alpha, max_iteration, stride): +def gp(X, y, **kwargs): return base_gp( - X, - y, - DotProduct, - {"sigma_0": 0, "sigma_0_bounds": "fixed"}, - alpha, - max_iteration, - stride, + X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, **kwargs ) -def base_gp(X, y, kernel, kernel_kwargs, alpha, max_iteration: int, stride): +def base_gp( + X, + y, + kernel, + kernel_kwargs, + alpha: Optional[float] = 0.1, + max_iteration: int = 20, + stride: Optional[int] = None, +): if len(y.shape) == 1: y = y.reshape([-1, 1]) @@ -118,8 +111,6 @@ class NormalizedDotProduct(Kernel): r"""Dot-Product kernel. .. math:: k(x_i, x_j) = x_i \cdot A \cdot x_j - The DotProduct kernel is commonly combined with exponentiation. - """ def __init__(self, diagonal_elements): From be3bc7ec57f521b83596b66f8b4a19707f6dd3fe Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 12 Nov 2021 19:42:32 -0500 Subject: [PATCH 014/126] add scalar field and per atom energies to ase one --- nequip/data/AtomicData.py | 20 ++++++++++++++++++++ nequip/data/_build.py | 1 + nequip/model/__init__.py | 2 +- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 81a2978d..a41489e6 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -24,6 +24,11 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] +_DEFAULT_SCALAR_FIELDS: Set[str] = { + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} _DEFAULT_NODE_FIELDS: Set[str] = { AtomicDataDict.POSITIONS_KEY, AtomicDataDict.WEIGHTS_KEY, @@ -48,12 +53,14 @@ _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) +_SCALAR_FIELDS: Set[str] = set(_DEFAULT_SCALAR_FIELDS) def register_fields( node_fields: Sequence[str] = [], edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], + scalar_fields: Sequence[str] = [], ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. 
@@ -64,11 +71,13 @@ def register_fields( node_fields: set = set(node_fields) edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) + scalar_fields: set = set(scalar_fields) allfields = node_fields.union(edge_fields, graph_fields) assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) _NODE_FIELDS.update(node_fields) _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) + _SCALAR_FIELDS.update(scalar_fields) if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) ): @@ -169,6 +178,14 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): kwargs[k] = v.unsqueeze(-1) v = kwargs[k] + if ( + k in set.union(_NODE_FIELDS, _EDGE_FIELDS) + and k not in _SCALAR_FIELDS + and len(v.shape) == 1 + ): + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + if ( k in _NODE_FIELDS and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] @@ -425,6 +442,7 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: cell = getattr(self, AtomicDataDict.CELL_KEY, None) batch = getattr(self, AtomicDataDict.BATCH_KEY, None) energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None) + energies = getattr(self, AtomicDataDict.PER_ATOM_ENERGY_KEY, None) force = getattr(self, AtomicDataDict.FORCE_KEY, None) do_calc = energy is not None or force is not None @@ -456,6 +474,8 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: if do_calc: fields = {} + if energies is not None: + fields["energies"] = energies[mask].cpu().numpy() if energy is not None: fields["energy"] = energy[batch_idx].cpu().numpy() if force is not None: diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 88e7b7ce..76474ebb 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -76,6 +76,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: node_fields=config.get("node_fields", []), edge_fields=config.get("edge_fields", []), graph_fields=config.get("graph_fields", []), + scalar_fields=config.get("scalar_fields", []), ) instance, _ = instantiate( diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..6a99d905 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,7 +1,7 @@ from ._eng import EnergyModel from ._grads import ForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale -from ._weight_init import uniform_initialize_FCs +from ._weight_init import uniform_initialize_FCs, initialize_from_state from ._build import model_from_config From e3db9cbef23121f825aab97c3fbb23c4d9077228 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 12 Nov 2021 19:53:43 -0500 Subject: [PATCH 015/126] add long field --- nequip/data/AtomicData.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index a41489e6..0519a606 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -29,6 +29,12 @@ AtomicDataDict.ATOM_TYPE_KEY, AtomicDataDict.BATCH_KEY, } +_DEFAULT_LONG_FIELDS: Set[str] = { + AtomicDataDict.EDGE_INDEX_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} _DEFAULT_NODE_FIELDS: Set[str] = { AtomicDataDict.POSITIONS_KEY, AtomicDataDict.WEIGHTS_KEY, @@ -54,6 +60,7 @@ _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) _SCALAR_FIELDS: Set[str] = set(_DEFAULT_SCALAR_FIELDS) +_LONG_FIELDS: Set[str] = 
set(_DEFAULT_LONG_FIELDS) def register_fields( @@ -61,6 +68,7 @@ def register_fields( edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], scalar_fields: Sequence[str] = [], + long_fields: Sequence[str] = [], ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. @@ -78,6 +86,7 @@ def register_fields( _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) _SCALAR_FIELDS.update(scalar_fields) + _LONG_FIELDS.update(long_fields) if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) ): @@ -144,12 +153,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): AtomicDataDict.validate_keys(kwargs) # Deal with _some_ dtype issues for k, v in kwargs.items(): - if ( - k == AtomicDataDict.EDGE_INDEX_KEY - or k == AtomicDataDict.ATOMIC_NUMBERS_KEY - or k == AtomicDataDict.ATOM_TYPE_KEY - or k == AtomicDataDict.BATCH_KEY - ): + if k in _LONG_FIELDS: # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) # int32 would pass later checks, but is actually disallowed by torch kwargs[k] = torch.as_tensor(v, dtype=torch.long) From 6392b87f4865faeb9f94fffef557bb4af9aa23d4 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Wed, 17 Nov 2021 16:46:15 -0500 Subject: [PATCH 016/126] change abbreviation --- nequip/data/_build.py | 7 +------ nequip/train/_key.py | 3 ++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 76474ebb..c8bc24b2 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -72,12 +72,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) # Register fields: - register_fields( - node_fields=config.get("node_fields", []), - edge_fields=config.get("edge_fields", []), - graph_fields=config.get("graph_fields", []), - scalar_fields=config.get("scalar_fields", []), - ) + instantiate( register_fields, all_args=config) instance, _ = instantiate( class_name, diff --git a/nequip/train/_key.py b/nequip/train/_key.py index f3582ebd..b91ada75 100644 --- a/nequip/train/_key.py +++ b/nequip/train/_key.py @@ -11,7 +11,8 @@ TRAIN = "training" ABBREV = { - AtomicDataDict.TOTAL_ENERGY_KEY: "e", + AtomicDataDict.TOTAL_ENERGY_KEY: "E", + AtomicDataDict.PER_ATOM_ENERGY_KEY: "e", AtomicDataDict.FORCE_KEY: "f", LOSS_KEY: "loss", VALIDATION: "val", From 9d0b55b0e31e6b1d35da32a9f5ff7a89172ec75c Mon Sep 17 00:00:00 2001 From: nw13slx Date: Wed, 17 Nov 2021 17:25:32 -0500 Subject: [PATCH 017/126] flakes --- nequip/data/_build.py | 2 +- nequip/model/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 2af9afb8..7645bcb7 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -72,7 +72,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) # Register fields: - instantiate( register_fields, all_args=config) + instantiate(register_fields, all_args=config) instance, _ = instantiate( class_name, diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index 6a99d905..670004dc 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -11,5 +11,6 @@ "RescaleEnergyEtc", "PerSpeciesRescale", "uniform_initialize_FCs", + "initialize_from_state", "model_from_config", ] 
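The field registry assembled in PATCHES 014-017 above can also be exercised directly. A minimal sketch, assuming `register_fields` is imported from nequip.data.AtomicData, where these patches define it; the field name "oxidation_state" is hypothetical and only for illustration:

from nequip.data.AtomicData import register_fields

# Declare a custom per-atom integer label: node_fields batches it per-node,
# and long_fields (PATCH 015) stores it as torch.long, like atom types.
register_fields(
    node_fields=["oxidation_state"],
    long_fields=["oxidation_state"],
)

# The same declaration can come straight from the YAML config, since
# dataset_from_config now forwards these keys through
# instantiate(register_fields, all_args=config):
#
#   node_fields: [oxidation_state]
#   long_fields: [oxidation_state]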
From e623eaf20b8a47d553e09e2f2bfa447fbd11ee0e Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Thu, 18 Nov 2021 08:28:50 -0800 Subject: [PATCH 018/126] Update nequip/train/_key.py --- nequip/train/_key.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/train/_key.py b/nequip/train/_key.py index b91ada75..17057b8b 100644 --- a/nequip/train/_key.py +++ b/nequip/train/_key.py @@ -11,8 +11,8 @@ TRAIN = "training" ABBREV = { - AtomicDataDict.TOTAL_ENERGY_KEY: "E", - AtomicDataDict.PER_ATOM_ENERGY_KEY: "e", + AtomicDataDict.TOTAL_ENERGY_KEY: "e", + AtomicDataDict.PER_ATOM_ENERGY_KEY: "Ei", AtomicDataDict.FORCE_KEY: "f", LOSS_KEY: "loss", VALIDATION: "val", From 6f0fd8af591cebee5dffc98fe5605c02f268b79a Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 11:33:59 -0500 Subject: [PATCH 019/126] update changelog.md to be consistent with develop --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26dd4488..adb348e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. -## [Unreleased] +## [Unreleased] - 0.5.0 +### Changed +- Allow e3nn 0.4.*, which changes the default normalization of `TensorProduct`s; this change _should_ not affect typical NequIP networks + +## [Unreleased] - 0.4.0 ### Added - Support for `e3nn`'s `soft_one_hot_linspace` as radial bases - Support for parallel dataloader workers with `dataloader_num_workers` @@ -22,6 +26,7 @@ Most recent change on the bottom. - Better error when instantiation fails - Rename `npz_keys` to `include_keys` - Allow user to register `graph_fields`, `node_fields`, and `edge_fields` via yaml +- Deployed models save the e3nn and torch versions they were created with ### Changed - Update example.yaml to use wandb by default, to only use 100 epochs of training, to set a very large batch logging frequency and to change Validation_loss to validation_loss From 603c82338d7eb9944ca4998c942badd0708076b5 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 11:48:25 -0500 Subject: [PATCH 020/126] move load_callable to utils --- nequip/model/_build.py | 22 ++-------------------- nequip/scripts/deploy.py | 8 ++++++-- nequip/scripts/evaluate.py | 4 +++- nequip/train/trainer.py | 11 ++++++----- nequip/utils/__init__.py | 2 +- nequip/utils/savenload.py | 19 +++++++++++++++++++ 6 files changed, 37 insertions(+), 29 deletions(-) diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 4f1ae7dd..ba2db55f 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -4,25 +4,7 @@ from nequip.data import AtomicDataset from nequip.nn import GraphModuleMixin - - -def _load_callable(obj: Union[str, Callable], prefix: Optional[str] = None) -> Callable: - """Load a callable from a name, or pass through a callable.""" - if callable(obj): - pass - elif isinstance(obj, str): - if "." not in obj: - # It's an unqualified name - if prefix is not None: - obj = prefix + "." 
+ obj - else: - # You can't have an unqualified name without a prefix - raise ValueError(f"Cannot load unqualified name {obj}.") - obj = yaml.load(f"!!python/name:{obj}", Loader=yaml.Loader) - else: - raise TypeError - assert callable(obj), f"{obj} isn't callable" - return obj +from nequip.utils import load_callable def model_from_config( @@ -59,7 +41,7 @@ def model_from_config( # Build builders = [ - _load_callable(b, prefix="nequip.model") + load_callable(b, prefix="nequip.model") for b in config.get("model_builders", []) ] diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index c6a3caba..c1bfbb0d 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -101,7 +101,9 @@ def main(args=None): "info", help="Get information from a deployed model file" ) info_parser.add_argument( - "model_path", help="Path to a deployed model file.", type=pathlib.Path, + "model_path", + help="Path to a deployed model file.", + type=pathlib.Path, ) build_parser = subparsers.add_parser("build", help="Build a deployment model") @@ -111,7 +113,9 @@ def main(args=None): type=pathlib.Path, ) build_parser.add_argument( - "out_file", help="Output file for deployed model.", type=pathlib.Path, + "out_file", + help="Output file for deployed model.", + type=pathlib.Path, ) args = parser.parse_args(args=args) diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 0d727b9c..641807a7 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -171,7 +171,9 @@ def main(args=None, running_as_script: bool = True): model.eval() # Load a config file - logger.info(f"Loading {'original ' if dataset_is_from_training else ''}dataset...",) + logger.info( + f"Loading {'original ' if dataset_is_from_training else ''}dataset...", + ) config = Config.from_file(str(args.dataset_config)) # set global options diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index e543d116..0345513f 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -36,6 +36,7 @@ instantiate, save_file, load_file, + load_callable, atomic_write, dtype_from_name, ) @@ -704,7 +705,7 @@ def train(self): ) for callback in self.init_callbacks: - callback(self) + load_callable(callback)(self) self.init_log() self.wall = perf_counter() @@ -720,7 +721,7 @@ def train(self): self.end_of_epoch_save() for callback in self.final_callbacks: - callback(self) + load_callable(callback)(self) self.final_log() @@ -852,13 +853,13 @@ def epoch_step(self): ) self.end_of_batch_log(batch_type=category) for callback in self.end_of_batch_callbacks: - callback(self) + load_callable(callback)(self) self.metrics_dict[category] = self.metrics.current_result() self.loss_dict[category] = self.loss_stat.current_result() if category == TRAIN: for callback in self.end_of_train_callbacks: - callback(self) + load_callable(callback)(self) self.iepoch += 1 @@ -868,7 +869,7 @@ def epoch_step(self): self.lr_sched.step(metrics=self.mae_dict[self.metrics_key]) for callback in self.end_of_epoch_callbacks: - callback(self) + load_callable(callback)(self) def end_of_batch_log(self, batch_type: str): """ diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index 16ad1ee6..1459273b 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -3,7 +3,7 @@ instantiate, get_w_prefix, ) -from .savenload import save_file, load_file, atomic_write +from .savenload import save_file, load_file, atomic_write, load_callable from .config import Config from .output import Output from .modules import 
find_first_of_type diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 0980fef2..7dae307d 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -10,6 +10,25 @@ from os.path import isfile, isdir, dirname, realpath +def load_callable(obj: Union[str, Callable], prefix: Optional[str] = None) -> Callable: + """Load a callable from a name, or pass through a callable.""" + if callable(obj): + pass + elif isinstance(obj, str): + if "." not in obj: + # It's an unqualified name + if prefix is not None: + obj = prefix + "." + obj + else: + # You can't have an unqualified name without a prefix + raise ValueError(f"Cannot load unqualified name {obj}.") + obj = yaml.load(f"!!python/name:{obj}", Loader=yaml.Loader) + else: + raise TypeError + assert callable(obj), f"{obj} isn't callable" + return obj + + @contextlib.contextmanager def atomic_write(filename: Union[Path, str]): filename = Path(filename) From d440ea009521bf712420e7deee3636a58f9d5b8f Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 17:04:54 -0500 Subject: [PATCH 021/126] move load callables method to utils --- nequip/model/_build.py | 3 +-- nequip/utils/__init__.py | 1 + nequip/utils/savenload.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nequip/model/_build.py b/nequip/model/_build.py index ba2db55f..99378111 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -1,6 +1,5 @@ -from typing import Optional, Union, Callable import inspect -import yaml +from typing import Optional from nequip.data import AtomicDataset from nequip.nn import GraphModuleMixin diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index 1459273b..28158b0d 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -15,6 +15,7 @@ get_w_prefix, save_file, load_file, + load_callable, atomic_write, Config, Output, diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 7dae307d..b927c185 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -1,10 +1,11 @@ """ utilities that involve file searching and operations (i.e. save/load) """ -from typing import Union +from typing import Union, Callable, Optional import sys import logging import contextlib +import yaml from pathlib import Path from os import makedirs from os.path import isfile, isdir, dirname, realpath From c884508278ffd5b120d7724ba961df8cb0051029 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 17:05:20 -0500 Subject: [PATCH 022/126] save num_types and allow gradient skip --- nequip/nn/_atomwise.py | 1 + nequip/nn/_grad_output.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 88718fb8..2503ba77 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -118,6 +118,7 @@ def __init__( irreps_in={}, ): super().__init__() + self.num_types = num_types self.field = field self.out_field = f"shifted_{field}" if out_field is None else out_field self._init_irreps( diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index b5ee9efc..ab785618 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -21,6 +21,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module): sign: either 1 or -1; the returned gradient is multiplied by this. """ sign: float + skip: bool def __init__( self, @@ -35,6 +36,8 @@ def __init__( assert sign in (1.0, -1.0) self.sign = sign self.of = of + self.skip = False + # TO DO: maybe better to force using list? 
if isinstance(wrt, str): wrt = [wrt] @@ -64,6 +67,10 @@ def __init__( ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + + if self.skip: + return self.func(data) + # set req grad wrt_tensors = [] old_requires_grad: List[bool] = [] From 5973cf27fe2ba2c9a2dd6edf902379c055723c3f Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 17:05:20 -0500 Subject: [PATCH 023/126] save num_types and allow gradient skip --- nequip/nn/_atomwise.py | 1 + nequip/nn/_grad_output.py | 7 ++ nequip/utils/callbacks.py | 132 ++++++++++++++++++++++++++++++++------ 3 files changed, 121 insertions(+), 19 deletions(-) diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 88718fb8..2503ba77 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -118,6 +118,7 @@ def __init__( irreps_in={}, ): super().__init__() + self.num_types = num_types self.field = field self.out_field = f"shifted_{field}" if out_field is None else out_field self._init_irreps( diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index b5ee9efc..ab785618 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -21,6 +21,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module): sign: either 1 or -1; the returned gradient is multiplied by this. """ sign: float + skip: bool def __init__( self, @@ -35,6 +36,8 @@ def __init__( assert sign in (1.0, -1.0) self.sign = sign self.of = of + self.skip = False + # TO DO: maybe better to force using list? if isinstance(wrt, str): wrt = [wrt] @@ -64,6 +67,10 @@ def __init__( ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + + if self.skip: + return self.func(data) + # set req grad wrt_tensors = [] old_requires_grad: List[bool] = [] diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py index 3f57df0d..7cc69c96 100644 --- a/nequip/utils/callbacks.py +++ b/nequip/utils/callbacks.py @@ -1,37 +1,131 @@ +import sys import torch -import logging from math import cos, pi +from nequip.nn import PerSpeciesScaleShift, GradientOutput +from nequip.data import AtomicData, AtomicDataDict +from nequip.utils.batch_ops import bincount +from nequip.utils.regressor import solver +from nequip.utils import find_first_of_type +if sys.version_info[1] >= 7: + import contextlib +else: + # has backport of nullcontext + import contextlib2 as contextlib -def equal_loss(self): - loss_f = self.mae_dict["Validation_loss_f"] - loss_e = self.mae_dict["Validation_loss_e"] - coeff_e = loss_f / loss_e - self.loss.coeffs["forces"] = torch.as_tensor(1, dtype=torch.get_default_dtype()) - self.loss.coeffs["total_energy"] = torch.as_tensor( +def equal_loss(trainer): + + loss_f = trainer.mae_dict["Validation_loss_f"] + loss_e = trainer.mae_dict["Validation_loss_e"] + # coeff_e = loss_f / loss_e + trainer.loss.coeffs["forces"] = torch.as_tensor(1, dtype=torch.get_default_dtype()) + trainer.loss.coeffs["total_energy"] = torch.as_tensor( loss_f / loss_e, dtype=torch.get_default_dtype() ) - self.logger.info(f"# update loss coeffs to 1 and {loss_f/loss_e}") + trainer.logger.info(f"# update loss coeffs to 1 and {loss_f/loss_e}") -def cos_sin(self): +def cos_sin(trainer): - f = self.kwargs.get("loss_f_mag", 1) - e = self.kwargs.get("loss_e_mag", 1) - phi_f = self.kwargs.get("loss_f_phi", 0) - phi_e = self.kwargs.get("loss_e_phi", -10) - cycle = self.kwargs.get("loss_coeff_cycle", 20) + f = trainer.kwargs.get("loss_f_mag", 1) + e = trainer.kwargs.get("loss_e_mag", 1) + phi_f = trainer.kwargs.get("loss_f_phi", 0) + phi_e = 
trainer.kwargs.get("loss_e_phi", -10) + cycle = trainer.kwargs.get("loss_coeff_cycle", 20) if phi_f == phi_e: return dtype = torch.get_default_dtype() - f = torch.as_tensor(f * (cos((self.iepoch + phi_f) / cycle * pi) + 1), dtype=dtype) - e = torch.as_tensor(e * (cos((self.iepoch + phi_e) / cycle * pi) + 1), dtype=dtype) + f = torch.as_tensor( + f * (cos((trainer.iepoch + phi_f) / cycle * pi) + 1), dtype=dtype + ) + e = torch.as_tensor( + e * (cos((trainer.iepoch + phi_e) / cycle * pi) + 1), dtype=dtype + ) + + trainer.loss.coeffs["forces"] = f + trainer.loss.coeffs["total_energy"] = e + + trainer.logger.info(f"# update loss coeffs to {f} {e}") + + +def linear_regression(trainer): + """do a linear regration after training epoch""" + + per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) + if per_species_rescale is None: + return + + _key = AtomicDataDict.TOTAL_ENERGY_KEY + if trainer.use_ema: + cm = trainer.ema.average_parameters() + else: + cm = contextlib.nullcontext() + + num_types = per_species_rescale.num_types + force_module = find_first_of_type(trainer.model, GradientOutput) + if force_module is not None: + force_module.skip = True + + dataset = trainer.dl_train + trainer.n_batches = len(dataset) + trainer.model.train() + + X = [] + y = [] + with cm: - self.loss.coeffs["forces"] = f - self.loss.coeffs["total_energy"] = e + for trainer.ibatch, data in enumerate(dataset): + + # trainer.optim.zero_grad(set_to_none=True) + + # Do any target rescaling + data = data.to(trainer.torch_device) + data = AtomicData.to_AtomicDataDict(data) + if hasattr(trainer.model, "unscale"): + data_unscaled = trainer.model.unscale(data) + else: + data_unscaled = data + + input_data = data_unscaled.copy() + out = trainer.model(input_data) + + atom_types = input_data[AtomicDataDict.ATOM_TYPE_KEY] + N = bincount( + atom_types, + input_data[AtomicDataDict.BATCH_KEY], + minlength=num_types, + ) + + # N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes + N = N.type(torch.get_default_dtype()) + res = data_unscaled[_key] - out[_key] + + X += [N] + y += [res] + + with torch.no_grad(): + X = torch.cat(X, dim=0) + y = torch.cat(y, dim=0) + mean, _ = solver(X, y) + + trainer.logger.info(f"residue shifts {mean}") + trainer.delta_shifts = mean + + if force_module is not None: + force_module.skip = False + + return mean + + +def update_rescales(trainer): + per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) + if per_species_rescale is None or not per_species_rescale.has_shifts: + return - self.logger.info(f"# update loss coeffs to {f} {e}") + trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") + per_species_rescale.shifts = per_species_rescale.shifts + trainer.delta_shifts + trainer.logger.info(f" to {per_species_rescale.shifts} .") From 29b0cf4550794dc978eac64dc103c256ea560dda Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 18 Nov 2021 17:32:51 -0500 Subject: [PATCH 024/126] update restore method --- nequip/utils/callbacks.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py index 7cc69c96..982b2c11 100644 --- a/nequip/utils/callbacks.py +++ b/nequip/utils/callbacks.py @@ -122,10 +122,24 @@ def linear_regression(trainer): def update_rescales(trainer): - per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) - if per_species_rescale is None or not per_species_rescale.has_shifts: + + if not hasattr(itrainer, 
"delta_shifts"): return + per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") + per_species_rescale.shifts = per_species_rescale.shifts + trainer.delta_shifts + + trainer.logger.info(f" to {per_species_rescale.shifts} .") + +def recover_rescales(trainer): + + if not hasattr(itrainer, "delta_shifts"): + return + + per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) + + trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") + per_species_rescale.shifts = per_species_rescale.shifts - trainer.delta_shifts trainer.logger.info(f" to {per_species_rescale.shifts} .") From 881b5b45ca06689b017e8756e1d6b25b2bffbf61 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 19 Nov 2021 19:10:21 -0500 Subject: [PATCH 025/126] remove ema --- nequip/utils/callbacks.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py index 982b2c11..0338cb7b 100644 --- a/nequip/utils/callbacks.py +++ b/nequip/utils/callbacks.py @@ -60,10 +60,6 @@ def linear_regression(trainer): return _key = AtomicDataDict.TOTAL_ENERGY_KEY - if trainer.use_ema: - cm = trainer.ema.average_parameters() - else: - cm = contextlib.nullcontext() num_types = per_species_rescale.num_types force_module = find_first_of_type(trainer.model, GradientOutput) @@ -76,9 +72,10 @@ def linear_regression(trainer): X = [] y = [] - with cm: - for trainer.ibatch, data in enumerate(dataset): + with torch.no_grad(): + + for _, data in enumerate(dataset): # trainer.optim.zero_grad(set_to_none=True) @@ -107,7 +104,6 @@ def linear_regression(trainer): X += [N] y += [res] - with torch.no_grad(): X = torch.cat(X, dim=0) y = torch.cat(y, dim=0) mean, _ = solver(X, y) @@ -123,23 +119,24 @@ def linear_regression(trainer): def update_rescales(trainer): - if not hasattr(itrainer, "delta_shifts"): + if not hasattr(trainer, "delta_shifts"): return per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") - per_species_rescale.shifts = per_species_rescale.shifts + trainer.delta_shifts + per_species_rescale.shifts += trainer.delta_shifts trainer.logger.info(f" to {per_species_rescale.shifts} .") + def recover_rescales(trainer): - if not hasattr(itrainer, "delta_shifts"): + if not hasattr(trainer, "delta_shifts"): return per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") - per_species_rescale.shifts = per_species_rescale.shifts - trainer.delta_shifts + per_species_rescale.shifts -= trainer.delta_shifts trainer.logger.info(f" to {per_species_rescale.shifts} .") From a67aa988944c5861dd7539c28d896d10f237a812 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 19 Nov 2021 19:11:07 -0500 Subject: [PATCH 026/126] revise typo --- nequip/utils/callbacks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py index 982b2c11..98280b90 100644 --- a/nequip/utils/callbacks.py +++ b/nequip/utils/callbacks.py @@ -80,7 +80,7 @@ def linear_regression(trainer): for trainer.ibatch, data in enumerate(dataset): - # trainer.optim.zero_grad(set_to_none=True) + trainer.optim.zero_grad(set_to_none=True) # Do any target rescaling data = data.to(trainer.torch_device) @@ -110,10 +110,10 @@ def 
linear_regression(trainer): with torch.no_grad(): X = torch.cat(X, dim=0) y = torch.cat(y, dim=0) - mean, _ = solver(X, y) + mean, _ = solver(X.cpu(), y.cpu()) trainer.logger.info(f"residue shifts {mean}") - trainer.delta_shifts = mean + trainer.delta_shifts = mean.to(trainer.device) if force_module is not None: force_module.skip = False @@ -123,7 +123,7 @@ def linear_regression(trainer): def update_rescales(trainer): - if not hasattr(itrainer, "delta_shifts"): + if not hasattr(trainer, "delta_shifts"): return per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) @@ -135,11 +135,11 @@ def update_rescales(trainer): def recover_rescales(trainer): - if not hasattr(itrainer, "delta_shifts"): + if not hasattr(trainer, "delta_shifts"): return per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) - trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") + trainer.logger.info(f"recover shifts from {per_species_rescale.shifts} ...") per_species_rescale.shifts = per_species_rescale.shifts - trainer.delta_shifts trainer.logger.info(f" to {per_species_rescale.shifts} .") From 51982d33f551ca230294425ea40353d7448ada6f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 00:51:02 -0500 Subject: [PATCH 027/126] support config-only model builders --- CHANGELOG.md | 3 +++ nequip/model/_build.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2cb106..747d7912 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Added +- Model builders may now process only the configuration + ### Fixed - Equivariance testing no longer unintentionally skips translation - Correct cat dim for all registered per-graph fields diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 4f1ae7dd..dbb4c053 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -65,7 +65,7 @@ def model_from_config( model = None - for builder_i, builder in enumerate(builders): + for builder in builders: pnames = inspect.signature(builder).parameters params = {} if "initialize" in pnames: @@ -81,18 +81,18 @@ def model_from_config( ) params["dataset"] = dataset if "model" in pnames: - if builder_i == 0: + if model is None: raise RuntimeError( - f"Builder {builder.__name__} asked for the model as an input, but it's the first builder so there is no model to provide" + f"Builder {builder.__name__} asked for the model as an input, but no previous builder has returned a model" ) params["model"] = model else: - if builder_i > 0: + if model is not None: raise RuntimeError( - f"All model_builders but the first one must take the model as an argument; {builder.__name__} doesn't" + f"All model_builders after the first one that returns a model must take the model as an argument; {builder.__name__} doesn't" ) model = builder(**params) - if not isinstance(model, GraphModuleMixin): + if model is not None and not isinstance(model, GraphModuleMixin): raise TypeError( f"Builder {builder.__name__} didn't return a GraphModuleMixin, got {type(model)} instead" ) From ec1003c8d9868e59f0ace36696845bbd3c2d85d7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 00:51:17 -0500 Subject: [PATCH 028/126] add simplified irreps options --- CHANGELOG.md | 1 + 
configs/minimal.yaml | 6 ++--- nequip/model/__init__.py | 3 ++- nequip/model/_eng.py | 52 ++++++++++++++++++++++++++++++++++++---- nequip/scripts/train.py | 1 + 5 files changed, 55 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 747d7912..55f65e32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Most recent change on the bottom. ## [Unreleased] ### Added - Model builders may now process only the configuration +- Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` ### Fixed - Equivariance testing no longer unintentionally skips translation diff --git a/configs/minimal.yaml b/configs/minimal.yaml index fa05bbbb..13f01546 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -6,9 +6,9 @@ seed: 0 # network num_basis: 8 r_max: 4.0 -irreps_edge_sh: 0e + 1o -conv_to_output_hidden_irreps_out: 16x0e -feature_irreps_hidden: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e +l_max: 2 +parity: true +num_features: 16 # data set # the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..5de10c93 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,4 +1,4 @@ -from ._eng import EnergyModel +from ._eng import EnergyModel, SimpleIrrepsConfig from ._grads import ForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale from ._weight_init import uniform_initialize_FCs @@ -6,6 +6,7 @@ from ._build import model_from_config __all__ = [ + "SimpleIrrepsConfig", "EnergyModel", "ForceOutput", "RescaleEnergyEtc", diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py index 314bef32..4f4bd06f 100644 --- a/nequip/model/_eng.py +++ b/nequip/model/_eng.py @@ -1,5 +1,7 @@ import logging +from e3nn import o3 + from nequip.data import AtomicDataDict from nequip.nn import ( SequentialGraphNetwork, @@ -14,6 +16,51 @@ ) +def SimpleIrrepsConfig(config): + """Builder that pre-processes options to allow "simple" configuration of irreps.""" + # We allow some simpler parameters to be provided, but if they are, + # they have to be correct and not overridden + simple_irreps_keys = ["l_max", "parity", "num_features"] + real_irreps_keys = [ + "chemical_embedding_irreps_out", + "feature_irreps_hidden", + "irreps_edge_sh", + "conv_to_output_hidden_irreps_out", + ] + # check for overlap + is_simple: bool = False + if any(k in config for k in simple_irreps_keys): + is_simple = True + if any(k in config for k in real_irreps_keys): + raise ValueError( + f"Cannot specify irreps using the simple and full option styles at the same time--- the sets of options {simple_irreps_keys} and {real_irreps_keys} are mutually exclusive." + ) + if is_simple: + # nothing to do if not + lmax = config.pop("l_max") + parity = config.pop("parity") + num_features = config.pop("num_features") + config["chemical_embedding_irreps_out"] = repr( + o3.Irreps([(num_features, (0, 1))]) # n scalars + ) + config["irreps_edge_sh"] = repr( + o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1) + ) + config["feature_irreps_hidden"] = repr( + o3.Irreps( + [ + (num_features, (l, p)) + for p in ((1, -1) if parity else (1,)) + for l in range(lmax + 1) + ] + ) + ) + config["conv_to_output_hidden_irreps_out"] = repr( + # num_features // 2 scalars + o3.Irreps([(max(1, num_features // 2), (0, 1))]) + ) + + def EnergyModel(config) -> SequentialGraphNetwork: """Base default energy model archetecture. 
@@ -59,7 +106,4 @@ def EnergyModel(config) -> SequentialGraphNetwork: ), ) - return SequentialGraphNetwork.from_parameters( - shared_params=config, - layers=layers, - ) + return SequentialGraphNetwork.from_parameters(shared_params=config, layers=layers,) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 175979de..119c70b2 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -28,6 +28,7 @@ wandb_project="NequIP", compile_model=False, model_builders=[ + "SimpleIrrepsConfig", "EnergyModel", "PerSpeciesRescale", "ForceOutput", From 1b4ff04f3e7addfb98f0a360de0c2673a02bbc2d Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 13 Jan 2022 00:52:04 -0500 Subject: [PATCH 029/126] add view -1 to spd_idx --- nequip/train/_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 4db16274..70c063e1 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -130,7 +130,7 @@ def __call__( if len(reduce_dims) > 0: per_atom_loss = per_atom_loss.sum(dim=reduce_dims) - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY].view(-1) per_species_loss = scatter(per_atom_loss, spe_idx, dim=0) N = scatter(not_nan, spe_idx, dim=0) @@ -146,7 +146,7 @@ def __call__( per_atom_loss = per_atom_loss.mean(dim=reduce_dims) # offset species index by 1 to use 0 for nan - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY].view(-1) _, inverse_species_index = torch.unique(spe_idx, return_inverse=True) per_species_loss = scatter_mean(per_atom_loss, inverse_species_index, dim=0) From ee74956bf7855c91f372c9b717ed55297d35f060 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 13 Jan 2022 01:02:29 -0500 Subject: [PATCH 030/126] format --- nequip/data/AtomicData.py | 2 +- nequip/utils/savenload.py | 1 + tests/unit/model/test_eng_force.py | 14 +++++++++++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 9f5286fd..56cf6c31 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -24,6 +24,7 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] + _DEFAULT_LONG_FIELDS: Set[str] = { AtomicDataDict.EDGE_INDEX_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY, @@ -74,7 +75,6 @@ def register_fields( node_fields: set = set(node_fields) edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) - scalar_fields: set = set(scalar_fields) allfields = node_fields.union(edge_fields, graph_fields) assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) _NODE_FIELDS.update(node_fields) diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index f943a31e..1b21c271 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -15,6 +15,7 @@ # accumulate writes to group for renaming _MOVE_SET = contextvars.ContextVar("_move_set", default=None) + def _delete_files_if_exist(paths): # clean up # better for python 3.8 > diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index ced59987..524364eb 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -80,8 +80,14 @@ def config(request): @pytest.fixture( params=[ - (["EnergyModel", "ForceOutput"], AtomicDataDict.FORCE_KEY,), - (["EnergyModel"], AtomicDataDict.TOTAL_ENERGY_KEY,), + ( + ["EnergyModel", "ForceOutput"], + 
AtomicDataDict.FORCE_KEY,
+        ),
+        (
+            ["EnergyModel"],
+            AtomicDataDict.TOTAL_ENERGY_KEY,
+        ),
     ]
 )
 def model(request, config):
@@ -133,7 +139,9 @@ def test_jit(self, model, atomic_batch, device):
         model_script = script(instance)
 
         assert torch.allclose(
-            instance(data)[out_field], model_script(data)[out_field], atol=1e-6,
+            instance(data)[out_field],
+            model_script(data)[out_field],
+            atol=1e-6,
         )
 
         # - Try saving, loading in another process, and running -

From ebef18cfa63df19d64c599b636bef4f591848c46 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 13 Jan 2022 16:50:18 -0700
Subject: [PATCH 031/126] bump

---
 CHANGELOG.md       | 2 +-
 nequip/_version.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 37371245..5c6a97aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 Most recent change on the bottom.
 
-## [Unreleased]
+## [Unreleased] - 0.5.2
 
 ## [0.5.1] - 2022-01-13
 ### Added
diff --git a/nequip/_version.py b/nequip/_version.py
index 41f4a9f7..1d0ded98 100644
--- a/nequip/_version.py
+++ b/nequip/_version.py
@@ -2,4 +2,4 @@
 # See Python packaging guide
 # https://packaging.python.org/guides/single-sourcing-package-version/
 
-__version__ = "0.5.1"
+__version__ = "0.5.2"

From 8816e8177ea85de59ab7a4567f0c5883052be450 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 13 Jan 2022 17:48:59 -0700
Subject: [PATCH 032/126] allow disabling shifts/scales independently

---
 CHANGELOG.md             |  2 ++
 nequip/model/_scaling.py | 12 +++++++++---
 nequip/nn/_atomwise.py   |  6 ++++--
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5c6a97aa..c8d1995c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ Most recent change on the bottom.
 
 ## [Unreleased] - 0.5.2
+### Fixed
+- Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistic
 
 ## [0.5.1] - 2022-01-13
 ### Added
diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index 30eee1ed..2811818a 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -156,9 +156,15 @@ def PerSpeciesRescale(
         # Both computed from dataset
         arguments_in_dataset_units = True
     elif len(str_names) == 1:
-        assert config[
-            module_prefix + "arguments_in_dataset_units"
-        ], "Requested to set either the shifts or scales of the per_species_rescale using dataset values, but chose to provide the other in non-dataset units. Please give the explictly specified shifts/scales in dataset units and set per_species_rescale_arguments_in_dataset_units"
+        if None in [scales, shifts]:
+            # if the one that isn't str is null, it's just disabled
+            # that has no units
+            # so it's ok to have just one and to be in dataset units
+            arguments_in_dataset_units = True
+        else:
+            assert config[
+                module_prefix + "_arguments_in_dataset_units"
+            ], "Requested to set either the shifts or scales of the per_species_rescale using dataset values, but chose to provide the other in non-dataset units.
Please give the explicitly specified shifts/scales in dataset units and set per_species_rescale_arguments_in_dataset_units"
 
     # = Compute shifts and scales =
     computed_stats = _compute_stats(
diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py
index 358bdb6c..b8e8e433 100644
--- a/nequip/nn/_atomwise.py
+++ b/nequip/nn/_atomwise.py
@@ -173,11 +173,13 @@ def update_for_rescale(self, rescale_module):
         if self.arguments_in_dataset_units and rescale_module.has_scale:
             logging.debug(
                 f"PerSpeciesScaleShift's arguments were in dataset units; rescaling:\n"
-                f"Original scales {self.scales} shifts: {self.shifts}"
+                f"Original scales {self.scales if self.has_scales else 'n/a'} shifts: {self.shifts if self.has_shifts else 'n/a'}"
             )
             with torch.no_grad():
                 if self.has_scales:
                     self.scales.div_(rescale_module.scale_by)
                 if self.has_shifts:
                     self.shifts.div_(rescale_module.scale_by)
-            logging.debug(f"New scales {self.scales} shifts: {self.shifts}")
+            logging.debug(
+                f"New scales {self.scales if self.has_scales else 'n/a'} shifts: {self.shifts if self.has_shifts else 'n/a'}"
+            )

From eb33114be633de987201e188fb72764ac714914a Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 13 Jan 2022 18:37:14 -0700
Subject: [PATCH 033/126] allow calling `mae` `mean` in metrics

---
 nequip/train/metrics.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py
index 2f790bd9..c1085e2f 100644
--- a/nequip/train/metrics.py
+++ b/nequip/train/metrics.py
@@ -12,7 +12,11 @@
 from ._loss import find_loss_function
 from ._key import ABBREV
 
-metrics_to_reduction = {"mae": Reduction.MEAN, "rmse": Reduction.RMS}
+metrics_to_reduction = {
+    "mae": Reduction.MEAN,
+    "mean": Reduction.MEAN,
+    "rmse": Reduction.RMS,
+}
 
 
 class Metrics:

From d83e364a92a292620591b9a9140c4e44a73289d9 Mon Sep 17 00:00:00 2001
From: Lixin Sun
Date: Tue, 18 Jan 2022 07:28:10 -0800
Subject: [PATCH 034/126] add conditions in converting data to torch.Tensor for AtomicData (#132)

* add conditions for bool and list

Co-authored-by: Simon Batzner
---
 nequip/data/AtomicData.py          |  8 ++++++++
 tests/unit/model/test_eng_force.py | 14 +++++++++++---
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py
index 7ce7f12b..bfe66136 100644
--- a/nequip/data/AtomicData.py
+++ b/nequip/data/AtomicData.py
@@ -117,11 +117,19 @@ def _process_dict(kwargs, ignore_fields=[]):
             # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems)
             # int32 would pass later checks, but is actually disallowed by torch
             kwargs[k] = torch.as_tensor(v, dtype=torch.long)
+        elif isinstance(v, bool):
+            kwargs[k] = torch.as_tensor(v)
         elif isinstance(v, np.ndarray):
             if np.issubdtype(v.dtype, np.floating):
                 kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
             else:
                 kwargs[k] = torch.as_tensor(v)
+        elif isinstance(v, list):
+            ele_dtype = np.array(v).dtype
+            if np.issubdtype(ele_dtype, np.floating):
+                kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
+            else:
+                kwargs[k] = torch.as_tensor(v)
         elif np.issubdtype(type(v), np.floating):
             # Force scalars to be tensors with a data dimension
             # This makes them play well with irreps
diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py
index ced59987..524364eb 100644
--- a/tests/unit/model/test_eng_force.py
+++ b/tests/unit/model/test_eng_force.py
@@ -80,8 +80,14 @@ def config(request):
@pytest.fixture( params=[ - (["EnergyModel", "ForceOutput"], AtomicDataDict.FORCE_KEY,), - (["EnergyModel"], AtomicDataDict.TOTAL_ENERGY_KEY,), + ( + ["EnergyModel", "ForceOutput"], + AtomicDataDict.FORCE_KEY, + ), + ( + ["EnergyModel"], + AtomicDataDict.TOTAL_ENERGY_KEY, + ), ] ) def model(request, config): @@ -133,7 +139,9 @@ def test_jit(self, model, atomic_batch, device): model_script = script(instance) assert torch.allclose( - instance(data)[out_field], model_script(data)[out_field], atol=1e-6, + instance(data)[out_field], + model_script(data)[out_field], + atol=1e-6, ) # - Try saving, loading in another process, and running - From 7b25d247c496e96dca1af4ef85806e4283e93221 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 18 Jan 2022 12:30:32 -0500 Subject: [PATCH 035/126] remove compile mode from trainer --- nequip/scripts/train.py | 6 +- nequip/train/trainer.py | 5 +- nequip/utils/callbacks.py | 140 -------------------------------------- 3 files changed, 5 insertions(+), 146 deletions(-) delete mode 100644 nequip/utils/callbacks.py diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index e9eed8ca..f2867bee 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -177,8 +177,10 @@ def fresh_start(config): logging.info("Successfully built the network...") if config.compile_model: - final_model = e3nn.util.jit.script(final_model) - logging.info("Successfully compiled model...") + raise ValueError("Compile_mode is not available for training") + # Warning("Compile_mode is not recommended for training") + # final_model = e3nn.util.jit.script(final_model) + # logging.info("Successfully compiled model...") # Equivar test if config.equivariance_test > 0: diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index ba94eee0..5bbe43df 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1003,10 +1003,7 @@ def save_ema_model(self, path, blocking: bool = True): def save_model(self, path, blocking: bool = True): with atomic_write(path, blocking=blocking, binary=True) as write_to: - if isinstance(self.model, torch.jit.ScriptModule): - torch.jit.save(self.model, write_to) - else: - torch.save(self.model.state_dict(), write_to) + torch.save(self.model.state_dict(), write_to) def init_log(self): if self.iepoch > 0: diff --git a/nequip/utils/callbacks.py b/nequip/utils/callbacks.py deleted file mode 100644 index addef7e5..00000000 --- a/nequip/utils/callbacks.py +++ /dev/null @@ -1,140 +0,0 @@ -import sys -import torch -from math import cos, pi -from nequip.nn import PerSpeciesScaleShift, GradientOutput -from nequip.data import AtomicData, AtomicDataDict -from nequip.utils.batch_ops import bincount -from nequip.utils.regressor import solver -from nequip.utils import find_first_of_type - -if sys.version_info[1] >= 7: - import contextlib -else: - # has backport of nullcontext - import contextlib2 as contextlib - - -def equal_loss(trainer): - - loss_f = trainer.mae_dict["Validation_loss_f"] - loss_e = trainer.mae_dict["Validation_loss_e"] - # coeff_e = loss_f / loss_e - trainer.loss.coeffs["forces"] = torch.as_tensor(1, dtype=torch.get_default_dtype()) - trainer.loss.coeffs["total_energy"] = torch.as_tensor( - loss_f / loss_e, dtype=torch.get_default_dtype() - ) - trainer.logger.info(f"# update loss coeffs to 1 and {loss_f/loss_e}") - - -def cos_sin(trainer): - - f = trainer.kwargs.get("loss_f_mag", 1) - e = trainer.kwargs.get("loss_e_mag", 1) - phi_f = trainer.kwargs.get("loss_f_phi", 0) - phi_e = trainer.kwargs.get("loss_e_phi", -10) - cycle = 
trainer.kwargs.get("loss_coeff_cycle", 20) - - if phi_f == phi_e: - return - - dtype = torch.get_default_dtype() - - f = torch.as_tensor( - f * (cos((trainer.iepoch + phi_f) / cycle * pi) + 1), dtype=dtype - ) - e = torch.as_tensor( - e * (cos((trainer.iepoch + phi_e) / cycle * pi) + 1), dtype=dtype - ) - - trainer.loss.coeffs["forces"] = f - trainer.loss.coeffs["total_energy"] = e - - trainer.logger.info(f"# update loss coeffs to {f} {e}") - - -def linear_regression(trainer): - """do a linear regration after training epoch""" - - per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) - if per_species_rescale is None: - return - - _key = AtomicDataDict.TOTAL_ENERGY_KEY - - num_types = per_species_rescale.num_types - force_module = find_first_of_type(trainer.model, GradientOutput) - if force_module is not None: - force_module.skip = True - - dataset = trainer.dl_train - trainer.n_batches = len(dataset) - trainer.model.train() - - X = [] - y = [] - - with torch.no_grad(): - - for _, data in enumerate(dataset): - - # Do any target rescaling - data = data.to(trainer.torch_device) - data = AtomicData.to_AtomicDataDict(data) - if hasattr(trainer.model, "unscale"): - data_unscaled = trainer.model.unscale(data) - else: - data_unscaled = data - - input_data = data_unscaled.copy() - out = trainer.model(input_data) - - atom_types = input_data[AtomicDataDict.ATOM_TYPE_KEY] - N = bincount( - atom_types, - input_data[AtomicDataDict.BATCH_KEY], - minlength=num_types, - ) - - # N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes - N = N.type(torch.get_default_dtype()) - res = data_unscaled[_key] - out[_key] - - X += [N] - y += [res] - - X = torch.cat(X, dim=0) - y = torch.cat(y, dim=0) - mean, _ = solver(X.cpu(), y.cpu()) - - trainer.logger.info(f"residue shifts {mean}") - trainer.delta_shifts = mean.to(trainer.device) - - if force_module is not None: - force_module.skip = False - - return mean - - -def update_rescales(trainer): - - if not hasattr(trainer, "delta_shifts"): - return - - per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) - trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") - - per_species_rescale.shifts += trainer.delta_shifts - - trainer.logger.info(f" to {per_species_rescale.shifts} .") - - -def recover_rescales(trainer): - - if not hasattr(trainer, "delta_shifts"): - return - - per_species_rescale = find_first_of_type(trainer.model, PerSpeciesScaleShift) - - trainer.logger.info(f"update shifts from {per_species_rescale.shifts} ...") - per_species_rescale.shifts -= trainer.delta_shifts - trainer.logger.info(f" to {per_species_rescale.shifts} .") From 002c6e81c2b77cb01d79a001608e33346e7a6106 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 18 Jan 2022 12:36:51 -0500 Subject: [PATCH 036/126] fix import errors from merge --- nequip/utils/savenload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 1b21c271..c30f7496 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -10,6 +10,7 @@ from pathlib import Path import shutil import os +import yaml # accumulate writes to group for renaming From eb5a4f4455ce075f75a88836edd65d551b967562 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 18 Jan 2022 12:42:57 -0500 Subject: [PATCH 037/126] fix import errors --- nequip/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index 1c3a435c..e7dd0912 100644 --- 
a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -22,6 +22,7 @@ get_w_prefix, save_file, load_file, + load_callable, atomic_write, finish_all_writes, atomic_write_group, From 1097f289447be5b2ee8f9b2598c6ef619443c9c5 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 12:06:25 -0500 Subject: [PATCH 038/126] remove unused wandb_resume option --- configs/example.yaml | 3 +-- configs/full.yaml | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/configs/example.yaml b/configs/example.yaml index e916b1e7..f92c1fb2 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -67,8 +67,7 @@ chemical_symbols: # logging wandb: true # we recommend using wandb for logging, we'll turn it off here as it's optional wandb_project: toluene-example # project name used in wandb -wandb_resume: true # if true and restart is true, wandb run data will be restarted and updated. - # if false, a new wandb run will be generated + verbose: info # the same as python logging, e.g. warning, info, debug, error. case insensitive log_batch_freq: 1000000 # batch frequency, how often to print training errors withinin the same epoch log_epoch_freq: 1 # epoch frequency, how often to print and save the model diff --git a/configs/full.yaml b/configs/full.yaml index a04e0fd4..4a89781f 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -121,8 +121,7 @@ chemical_symbols: # logging wandb: false # we recommend using wandb for logging, we'll turn it off here as it's optional wandb_project: toluene-example # project name used in wandb -wandb_resume: true # if true and restart is true, wandb run data will be restarted and updated. - # if false, a new wandb run will be generated + verbose: info # the same as python logging, e.g. warning, info, debug, error. case insensitive log_batch_freq: 1 # batch frequency, how often to print training errors withinin the same epoch log_epoch_freq: 1 # epoch frequency, how often to print and save the model From 9dda7adb3a9759fd4c916b1666bdf97d65e92096 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Wed, 19 Jan 2022 10:35:37 -0800 Subject: [PATCH 039/126] Update nequip/scripts/train.py Co-authored-by: Alby M. 
<1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/scripts/train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index f2867bee..81c1722d 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -177,10 +177,7 @@ def fresh_start(config): logging.info("Successfully built the network...") if config.compile_model: - raise ValueError("Compile_mode is not available for training") - # Warning("Compile_mode is not recommended for training") - # final_model = e3nn.util.jit.script(final_model) - # logging.info("Successfully compiled model...") + raise ValueError("the `compile_model` option has been removed") # Equivar test if config.equivariance_test > 0: From a7fbe1bc8bd8b347f1666ee731b91463e8b3ed19 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Wed, 19 Jan 2022 13:40:12 -0500 Subject: [PATCH 040/126] pre-load callbacks --- nequip/train/trainer.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 5bbe43df..f4c2955a 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -328,6 +328,12 @@ def __init__( if train_on_keys is not None: assert set(train_on_keys) == set(self.train_on_keys) + self._init_callbacks = [load_callable(callback) for callback in init_callbacks] + self._end_of_epoch_callbacks = [load_callable(callback) for callback in end_of_epoch_callbacks] + self._end_of_batch_callbacks = [load_callable(callback) for callback in end_of_batch_callbacks] + self._end_of_train_callbacks = [load_callable(callback) for callback in end_of_train_callbacks] + self._final_callbacks = [load_callable(callback) for callback in final_callbacks] + self.init() def init_objects(self): @@ -731,8 +737,8 @@ def train(self): ) ) - for callback in self.init_callbacks: - load_callable(callback)(self) + for callback in self._init_callbacks: + callback(self) self.init_log() self.wall = perf_counter() @@ -750,8 +756,8 @@ def train(self): self.epoch_step() self.end_of_epoch_save() - for callback in self.final_callbacks: - load_callable(callback)(self) + for callback in self._final_callbacks: + callback(self) self.final_log() @@ -883,14 +889,14 @@ def epoch_step(self): validation=(category == VALIDATION), ) self.end_of_batch_log(batch_type=category) - for callback in self.end_of_batch_callbacks: - load_callable(callback)(self) + for callback in self._end_of_batch_callbacks: + callback(self) self.metrics_dict[category] = self.metrics.current_result() self.loss_dict[category] = self.loss_stat.current_result() if category == TRAIN: - for callback in self.end_of_train_callbacks: - load_callable(callback)(self) + for callback in self._end_of_train_callbacks: + callback(self) self.iepoch += 1 @@ -899,8 +905,8 @@ def epoch_step(self): if self.lr_scheduler_name == "ReduceLROnPlateau": self.lr_sched.step(metrics=self.mae_dict[self.metrics_key]) - for callback in self.end_of_epoch_callbacks: - load_callable(callback)(self) + for callback in self._end_of_epoch_callbacks: + callback(self) def end_of_batch_log(self, batch_type: str): """ From 9f6cb5eed185cd560d58d56f72c1d55ee472a708 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 15:15:37 -0500 Subject: [PATCH 041/126] clearer about minimal.yaml --- README.md | 12 ++++++------ configs/minimal.yaml | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 
dce25648..7e765d46 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentia
 NequIP requires:
 
 * Python >= 3.6
-* PyTorch >= 1.8, <=1.10.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. NequIP is also not currently compatible with PyTorch 1.10; PyTorch 1.9 can be specified with `pytorch==1.9` in the install command.
+* PyTorch >= 1.8, <=1.10.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP.
 
 To install:
@@ -33,7 +33,7 @@ pip install .
 
 ### Installation Issues
 
-The easiest way to check if your installation is working is to train a toy model:
+The easiest way to check if your installation is working is to train a **toy** model:
 ```bash
 $ nequip-train configs/minimal.yaml
 ```
@@ -69,10 +69,10 @@ $ nequip-train configs/example.yaml
 ```
 
 A number of example configuration files are provided:
- - [`configs/minimal.yaml`](configs/minimal.yaml): A minimal example of training a toy model on force data.
- - [`configs/minimal_eng.yaml`](configs/minimal_eng.yaml): The same, but for a toy model that predicts and trains on only energy labels.
- - [`configs/example.yaml`](configs/example.yaml): Training a more realistic model on forces and energies. Start here for real models.
- - [`configs/full.yaml`](configs/full.yaml): A complete configuration file containing all available options along with documenting comments.
+ - [`configs/minimal.yaml`](configs/minimal.yaml): A minimal example of training a **toy** model on force data.
+ - [`configs/minimal_eng.yaml`](configs/minimal_eng.yaml): The same, but for a **toy** model that predicts and trains on only energy labels.
+ - [`configs/example.yaml`](configs/example.yaml): Training a more realistic model on forces and energies. **Start here for real models!**
+ - [`configs/full.yaml`](configs/full.yaml): A complete configuration file containing all available options along with documenting comments. This file is **for reference**; `example.yaml` is the right starting point for a project.
 
 Training runs can be restarted using `nequip-restart`; training that starts fresh or restarts depending on the existance of the working directory can be launched using `nequip-requeue`. All `nequip-*` commands accept the `--help` option to show their call signatures and options.
 
diff --git a/configs/minimal.yaml b/configs/minimal.yaml
index 489fc7a6..baf6df4a 100644
--- a/configs/minimal.yaml
+++ b/configs/minimal.yaml
@@ -1,3 +1,9 @@
+# !! PLEASE NOTE: `minimal.yaml` is meant as a _minimal_ example of a tiny, fast
+#                 training that can be used to verify your nequip install,
+#                 the syntax of your configuration edits, etc.
+#                 These are NOT recommended hyperparameters for real applications!
+#                 Please see `example.yaml` for a reasonable starting point.
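+#                 (As a quick smoke test of an install, for example, running
+#                 `nequip-train configs/minimal.yaml` from the repository root
+#                 exercises this entire configuration.)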
+
 # general
 root: results/aspirin
 run_name: minimal

From 68747374a862ef34f5672566f441a6ec0fb6bcef Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jan 2022 15:39:56 -0500
Subject: [PATCH 042/126] fix restarts

---
 nequip/model/_eng.py | 43 +++++++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py
index a24331fc..fc7b907b 100644
--- a/nequip/model/_eng.py
+++ b/nequip/model/_eng.py
@@ -30,26 +30,23 @@ def SimpleIrrepsConfig(config):
         "irreps_edge_sh",
         "conv_to_output_hidden_irreps_out",
     ]
-    # check for overlap
-    is_simple: bool = False
-    if any(k in config for k in simple_irreps_keys):
-        is_simple = True
-        if any(k in config for k in real_irreps_keys):
-            raise ValueError(
-                f"Cannot specify irreps using the simple and full option styles at the same time--- the sets of options {simple_irreps_keys} and {real_irreps_keys} are mutually exclusive."
-            )
-    if is_simple:
+    has_simple: bool = any(k in config for k in simple_irreps_keys)
+    has_full: bool = any(k in config for k in real_irreps_keys)
+    assert has_simple or has_full
+
+    update = {}
+    if has_simple:
         # nothing to do if not
-        lmax = config.pop("l_max")
-        parity = config.pop("parity")
-        num_features = config.pop("num_features")
-        config["chemical_embedding_irreps_out"] = repr(
+        lmax = config["l_max"]
+        parity = config["parity"]
+        num_features = config["num_features"]
+        update["chemical_embedding_irreps_out"] = repr(
             o3.Irreps([(num_features, (0, 1))])  # n scalars
         )
-        config["irreps_edge_sh"] = repr(
+        update["irreps_edge_sh"] = repr(
             o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1)
         )
-        config["feature_irreps_hidden"] = repr(
+        update["feature_irreps_hidden"] = repr(
             o3.Irreps(
                 [
                     (num_features, (l, p))
@@ -58,11 +55,25 @@
                 ]
             )
         )
-        config["conv_to_output_hidden_irreps_out"] = repr(
+        update["conv_to_output_hidden_irreps_out"] = repr(
             # num_features // 2 scalars
             o3.Irreps([(max(1, num_features // 2), (0, 1))])
         )
 
+    # check update is consistent with config
+    # (this is necessary since it is not possible
+    # to delete keys from config, so instead of
+    # making simple and full styles mutually
+    # exclusive, we just insist that if full
+    # and simple are provided, full must be
+    # consistent with simple)
+    for k, v in update.items():
+        if k in config:
+            assert (
+                config[k] == v
+            ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
+        config[k] = v
+
 
 def EnergyModel(
     config, initialize: bool, dataset: Optional[AtomicDataset] = None

From 65d693e5ab30ae0d15a12d768990418dc4f31edd Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jan 2022 15:40:04 -0500
Subject: [PATCH 043/126] update configs

---
 CHANGELOG.md         |  9 +++++++--
 configs/example.yaml |  7 +++----
 configs/full.yaml    | 18 ++++++++++++++----
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7126aed..3390c9d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,13 @@ Most recent change on the bottom.
## [Unreleased] - 0.5.2 +### Added +- Model builders may now process only the configuration +- Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` + +### Changed +- `minimal.yaml` and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` + ### Fixed - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistc @@ -20,8 +27,6 @@ Most recent change on the bottom. - The types may now be specified with a simpler `chemical_symbols` option - Equivariance testing reports per-field errors - `--equivariance-test n` tests equivariance on `n` frames from the training dataset -- Model builders may now process only the configuration -- Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` ### Changed - All fields now have consistant [N, dim] shaping diff --git a/configs/example.yaml b/configs/example.yaml index f92c1fb2..ba0e0b98 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -14,10 +14,9 @@ default_dtype: float32 # network r_max: 4.0 # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan num_layers: 4 # number of interaction blocks, we find 4-6 to work best -chemical_embedding_irreps_out: 32x0e # irreps for the chemical embedding of species -feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster -irreps_edge_sh: 0e + 1o # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer -conv_to_output_hidden_irreps_out: 16x0e # irreps used in hidden layer of output block +l_max: 1 # the maximum irrep order (rotation order) for the network's features +parity: true # whether to include features with odd mirror parity +num_features: 32 # the multiplicity of the features nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended resnet: false # set true to make interaction block a resnet-style update diff --git a/configs/full.yaml b/configs/full.yaml index 4a89781f..f148a515 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -20,10 +20,20 @@ allow_tf32: false # network r_max: 4.0 # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan num_layers: 4 # number of interaction blocks, we find 4-6 to work best -chemical_embedding_irreps_out: 32x0e # irreps for the chemical embedding of species -feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster -irreps_edge_sh: 0e + 1o # irreps of the spherical harmonics used for edges. 
If a single integer, indicates the full SH up to L_max=that_integer -conv_to_output_hidden_irreps_out: 16x0e # irreps used in hidden layer of output block + +l_max: 1 # the maximum irrep order (rotation order) for the network's features +parity: true # whether to include features with odd mirror parity +num_features: 32 # the multiplicity of the features + +# alternatively, the irreps of the features in various parts of the network can be specified directly: +# the following options use e3nn irreps notation +# either these four options, or the above three options, should be provided--- they cannot be mixed. +# chemical_embedding_irreps_out: 32x0e # irreps for the chemical embedding of species +# feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster +# irreps_edge_sh: 0e + 1o # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer +# conv_to_output_hidden_irreps_out: 16x0e # irreps used in hidden layer of output block + + nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended resnet: false # set true to make interaction block a resnet-style update From 83e94350695f584eb816dc58ef007787885b2d9b Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 15:42:10 -0500 Subject: [PATCH 044/126] update minimal_eng --- CHANGELOG.md | 2 +- configs/minimal_eng.yaml | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3390c9d3..66ddaef8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ Most recent change on the bottom. - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` ### Changed -- `minimal.yaml` and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` +- `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` ### Fixed - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistc diff --git a/configs/minimal_eng.yaml b/configs/minimal_eng.yaml index fe002fbc..2b48993e 100644 --- a/configs/minimal_eng.yaml +++ b/configs/minimal_eng.yaml @@ -5,14 +5,15 @@ seed: 0 # network model_builders: + - SimpleIrrepsConfig - EnergyModel - PerSpeciesRescale - RescaleEnergyEtc num_basis: 8 r_max: 4.0 -irreps_edge_sh: 0e + 1o -conv_to_output_hidden_irreps_out: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e -feature_irreps_hidden: 16x0o + 16x0e +l_max: 2 +parity: true +num_features: 16 # data set # the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys From b3f42688b76128817d886b9140c70ccc650458e7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 15:44:36 -0500 Subject: [PATCH 045/126] warning notice --- configs/minimal_eng.yaml | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/configs/minimal_eng.yaml b/configs/minimal_eng.yaml index 2b48993e..4fed9fae 100644 --- a/configs/minimal_eng.yaml +++ b/configs/minimal_eng.yaml @@ -1,14 +1,26 @@ +# !! 
PLEASE NOTE: `minimal_eng.yaml` is meant as a _minimal_ example of a tiny, fast +# training that can be used to verify your nequip install, +# the syntax of your configuration edits, etc. +# These are NOT recommended hyperparameters for real applications! +# Please see `example.yaml` for a reasonable starting point. + # general root: results/aspirin run_name: minimal_eng -seed: 0 +seed: 123 +dataset_seed: 456 # network +# The default is to build a model with forces, so we need to specify +# `model_builders` to get one without forces. This list is the default, +# except without the `ForceOutput` builder that makes a force+energy +# model out of an energy model: model_builders: - - SimpleIrrepsConfig - - EnergyModel - - PerSpeciesRescale - - RescaleEnergyEtc + - SimpleIrrepsConfig # make configuration easier + - EnergyModel # the core nequip model + - PerSpeciesRescale # per-species/per-atom shift & scaling + - RescaleEnergyEtc # global scaling +# options for the model: num_basis: 8 r_max: 4.0 l_max: 2 From acfcc51954d1c34db86cbfee000c52e52d91a979 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 15:56:02 -0500 Subject: [PATCH 046/126] remove old option --- configs/full.yaml | 1 - nequip/scripts/train.py | 1 - nequip/train/trainer.py | 48 +++++++++++++++++++++++------------------ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/configs/full.yaml b/configs/full.yaml index a04e0fd4..8298d0ef 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -48,7 +48,6 @@ invariant_layers: 2 invariant_neurons: 64 # number of hidden neurons in radial function, smaller is faster avg_num_neighbors: auto # number of neighbors to divide by, null => no normalization. use_sc: true # use self-connection or not, usually gives big improvement -compile_model: false # whether to compile the constructed model to TorchScript # to specify different parameters for each convolutional layer, try examples below # layer1_use_sc: true # use "layer{i}_" prefix to specify parameters for only one of the layer, diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 81c1722d..a628459e 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -26,7 +26,6 @@ run_name="NequIP", wandb=False, wandb_project="NequIP", - compile_model=False, model_builders=[ "EnergyModel", "PerSpeciesRescale", diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index f4c2955a..ee06a468 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -329,10 +329,18 @@ def __init__( assert set(train_on_keys) == set(self.train_on_keys) self._init_callbacks = [load_callable(callback) for callback in init_callbacks] - self._end_of_epoch_callbacks = [load_callable(callback) for callback in end_of_epoch_callbacks] - self._end_of_batch_callbacks = [load_callable(callback) for callback in end_of_batch_callbacks] - self._end_of_train_callbacks = [load_callable(callback) for callback in end_of_train_callbacks] - self._final_callbacks = [load_callable(callback) for callback in final_callbacks] + self._end_of_epoch_callbacks = [ + load_callable(callback) for callback in end_of_epoch_callbacks + ] + self._end_of_batch_callbacks = [ + load_callable(callback) for callback in end_of_batch_callbacks + ] + self._end_of_train_callbacks = [ + load_callable(callback) for callback in end_of_train_callbacks + ] + self._final_callbacks = [ + load_callable(callback) for callback in final_callbacks + ] self.init() @@ -669,24 +677,22 @@ def 
load_model_from_training_session( else: config = Config.from_file(traindir + "/config.yaml") - if config.get("compile_model", False): - model = torch.jit.load(traindir + "/" + model_name, map_location=device) - else: - model = model_from_config( - config=config, - initialize=False, + model = model_from_config( + config=config, + initialize=False, + ) + if model is not None: # TODO: why would it be? + # TODO: this is not exactly equivalent to building with + # this set as default dtype... does it matter? + model.to( + device=torch.device(device), + dtype=dtype_from_name(config.default_dtype), ) - if model is not None: - # TODO: this is not exactly equivalent to building with - # this set as default dtype... does it matter? - model.to( - device=torch.device(device), - dtype=dtype_from_name(config.default_dtype), - ) - model_state_dict = torch.load( - traindir + "/" + model_name, map_location=device - ) - model.load_state_dict(model_state_dict) + model_state_dict = torch.load( + traindir + "/" + model_name, map_location=device + ) + model.load_state_dict(model_state_dict) + return model, config def init(self): From fb85f4a8c1e0f802fdcc1290bc9c3cec8c244731 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 15:57:37 -0500 Subject: [PATCH 047/126] rename logger -> _logger --- nequip/scripts/{logger.py => _logger.py} | 0 nequip/scripts/evaluate.py | 2 +- nequip/scripts/train.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename nequip/scripts/{logger.py => _logger.py} (100%) diff --git a/nequip/scripts/logger.py b/nequip/scripts/_logger.py similarity index 100% rename from nequip/scripts/logger.py rename to nequip/scripts/_logger.py diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 733f62a3..894bf42a 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -18,7 +18,7 @@ from nequip.utils import load_file, instantiate from nequip.train.loss import Loss from nequip.train.metrics import Metrics -from nequip.scripts.logger import set_up_script_logger +from ._logger import set_up_script_logger def main(args=None, running_as_script: bool = True): diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index a628459e..5490a19f 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -19,7 +19,7 @@ from nequip.data import dataset_from_config from nequip.utils.test import assert_AtomicData_equivariant, set_irreps_debug from nequip.utils import load_file, dtype_from_name -from nequip.scripts.logger import set_up_script_logger +from ._logger import set_up_script_logger default_config = dict( root="./", From 7f163674b07073c14805644a31020930b5c7dd5d Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 16:00:28 -0500 Subject: [PATCH 048/126] refactor checking key --- nequip/scripts/train.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 5490a19f..a5ca35dd 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -1,6 +1,7 @@ """ Train a network.""" import logging import argparse +import warnings # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch. # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. 
@@ -175,8 +176,8 @@ def fresh_start(config):
 
     logging.info("Successfully built the network...")
 
-    if config.compile_model:
-        raise ValueError("the `compile_model` option has been removed")
+    # by doing this here we check also any keys custom builders may have added
+    _check_old_keys(config)
 
     # Equivar test
     if config.equivariance_test > 0:
@@ -267,5 +268,16 @@ def restart(config):
     return trainer
 
 
+def _check_old_keys(config) -> None:
+    """check ``config`` for old/deprecated keys and emit corresponding errors/warnings"""
+    # compile_model
+    k = "compile_model"
+    if k in config:
+        if config[k]:
+            raise ValueError("the `compile_model` option has been removed")
+        else:
+            warnings.warn("the `compile_model` option has been removed")
+
+
 if __name__ == "__main__":
     main(running_as_script=True)

From 9830462c418973a14a4742323a76ec1fe04e3c36 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jan 2022 16:05:52 -0500
Subject: [PATCH 049/126] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8d1995c..8681545b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,9 @@ Most recent change on the bottom.
 ### Fixed
 - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistic
 
+### Removed
+- `compile_model`
+
 ## [0.5.1] - 2022-01-13
 ### Added
 - `NequIPCalculator` can now be built via a `nequip_calculator()` function. This adds a minimal compatibility with [vibes](https://gitlab.com/vibes-developers/vibes/)

From 6766165e4c261433d4ba4198862e10c0391d5b5e Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 19 Jan 2022 16:16:41 -0500
Subject: [PATCH 050/126] small updates

---
 README.md | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 7e765d46..1e343694 100644
--- a/README.md
+++ b/README.md
@@ -50,15 +50,15 @@ To run the full tests, including a set of longer/more intensive integration test
 pytest tests/
 ```
 
-Note: the integration tests have hung in the past on certain systems that have GPUs. If this happens to you, please report it along with information on your software environment in the [Issues](https://github.com/mir-group/nequip/issues)!
+If a GPU is present, the unit tests will use it.
 
 ## Tutorial
 
-The best way to learn how to use NequIP is through the [Colab Tutorial](https://bit.ly/mrs-nequip). This will run entirely on Google Hardware, you will not need to install anything, but can instead simply run it in your browser.
+The best way to learn how to use NequIP is through the [Colab Tutorial](https://bit.ly/mrs-nequip). This will run entirely on Google's cloud virtual machine; you do not need to install or run anything locally.
 
 ## Usage
 
-**! PLEASE NOTE:** the first few training epochs/calls to a NequIP model can be painfully slow. This is expected behaviour as the [profile-guided optimization of TorchScript models](https://program-transformations.github.io/slides/pytorch_neurips.pdf) takes a number of calls to warm up before optimizing the model. This occurs regardless of whether the entire model is compiled because many core components from e3nn are compiled and optimized through TorchScript.
+**! PLEASE NOTE:** the first few calls to a NequIP model can be painfully slow.
This is expected behaviour as the [profile-guided optimization of TorchScript models](https://program-transformations.github.io/slides/pytorch_neurips.pdf) takes a number of calls to warm up before optimizing the model. (The `nequip-benchmark` script accounts for this.) ### Basic network training @@ -74,7 +74,9 @@ A number of example configuration files are provided: - [`configs/example.yaml`](configs/example.yaml): Training a more realistic model on forces and energies. **Start here for real models!** - [`configs/full.yaml`](configs/full.yaml): A complete configuration file containing all available options along with documenting comments. This file is **for reference**, `example.yaml` is the right starting point for a project. -Training runs can be restarted using `nequip-restart`; training that starts fresh or restarts depending on the existance of the working directory can be launched using `nequip-requeue`. All `nequip-*` commands accept the `--help` option to show their call signatures and options. +Training runs can also be restarted by running the same `nequip-train` command if the `append: True` option is specified in the original YAML. (Otherwise, a new training run with a different name can be started from the loaded state of the previous run.) + +All `nequip-*` commands accept the `--help` option to show their call signatures and options. ### Evaluating trained models (and their error) From b44ef02ce4973cc2cb64b51aa5bd5b40790c7271 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 16:20:16 -0500 Subject: [PATCH 051/126] default `resnet: False` --- CHANGELOG.md | 1 + configs/example.yaml | 1 - configs/full.yaml | 1 + docs/options/model.rst | 2 +- nequip/nn/_convnetlayer.py | 2 +- nequip/nn/_interaction_block.py | 2 +- tests/unit/model/test_eng_force.py | 3 --- 7 files changed, 5 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b27de39..c43d3956 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Most recent change on the bottom. ### Changed - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` +- Default value for `resnet` is now `False` ### Fixed - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistc diff --git a/configs/example.yaml b/configs/example.yaml index ba0e0b98..51bf5a3c 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -18,7 +18,6 @@ l_max: 1 parity: true # whether to include features with odd mirror parity num_features: 32 # the multiplicity of the features nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended -resnet: false # set true to make interaction block a resnet-style update # scalar nonlinearities to use — available options are silu, ssp (shifted softplus), tanh, and abs. # Different nonlinearities are specified for e (even) and o (odd) parity; diff --git a/configs/full.yaml b/configs/full.yaml index 0b67b637..9f79ab09 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -36,6 +36,7 @@ num_features: 32 nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended resnet: false # set true to make interaction block a resnet-style update + # the resnet update will only be applied when the input and output irreps of the layer are the same # scalar nonlinearities to use — available options are silu, ssp (shifted softplus), tanh, and abs. 
# Different nonlinearities are specified for e (even) and o (odd) parity; diff --git a/docs/options/model.rst b/docs/options/model.rst index a123bb80..a9ecb694 100644 --- a/docs/options/model.rst +++ b/docs/options/model.rst @@ -99,7 +99,7 @@ invariant_neurons resnet ^^^^^^ | Type: bool - | Default: ``True`` + | Default: ``False`` nonlinearity_type ^^^^^^^^^^^^^^^^^ diff --git a/nequip/nn/_convnetlayer.py b/nequip/nn/_convnetlayer.py index 98d12bcb..8d3d0dad 100644 --- a/nequip/nn/_convnetlayer.py +++ b/nequip/nn/_convnetlayer.py @@ -37,7 +37,7 @@ def __init__( convolution=InteractionBlock, convolution_kwargs: dict = {}, num_layers: int = 3, - resnet: bool = True, + resnet: bool = False, nonlinearity_type: str = "gate", nonlinearity_scalars: Dict[int, Callable] = {"e": "ssp", "o": "tanh"}, nonlinearity_gates: Dict[int, Callable] = {"e": "ssp", "o": "abs"}, diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index 575144f2..6f70af20 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -144,7 +144,7 @@ def __init__( def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: """ - Evaluate interaction Block with ResNet. + Evaluate interaction Block with ResNet (self-connection). :param node_input: :param node_attr: diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 524364eb..ff5033fb 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -28,7 +28,6 @@ irreps_edge_sh="0e + 1o", r_max=4, feature_irreps_hidden="4x0e + 4x1o", - resnet=True, num_layers=2, num_basis=8, PolynomialCutoff_p=6, @@ -47,7 +46,6 @@ irreps_edge_sh="0e + 1o", r_max=4, feature_irreps_hidden="4x0e + 4x1o", - resnet=True, num_layers=2, num_basis=8, PolynomialCutoff_p=6, @@ -58,7 +56,6 @@ irreps_edge_sh="0e + 1o + 2e", r_max=4, feature_irreps_hidden="2x0e + 2x1o + 2x2e", - resnet=False, num_layers=2, num_basis=3, PolynomialCutoff_p=6, From 545552a90af1ee835dc3b951beb869cfeba8e296 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 16:52:59 -0500 Subject: [PATCH 052/126] wandb.watch --- CHANGELOG.md | 1 + configs/full.yaml | 6 ++++++ nequip/train/trainer_wandb.py | 28 ++++++++++------------------ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c43d3956..99084c63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Most recent change on the bottom. ### Added - Model builders may now process only the configuration - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` +- `wandb.watch` via `wandb_watch` option ### Changed - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` diff --git a/configs/full.yaml b/configs/full.yaml index 9f79ab09..f628e99a 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -131,6 +131,12 @@ chemical_symbols: # logging wandb: false # we recommend using wandb for logging, we'll turn it off here as it's optional wandb_project: toluene-example # project name used in wandb +wandb_watch: false +# see https://docs.wandb.ai/ref/python/watch +# wandb_watch: +# log: all +# log_freq: 1 +# log_graph: true verbose: info # the same as python logging, e.g. warning, info, debug, error. 
case insensitive log_batch_freq: 1 # batch frequency, how often to print training errors withinin the same epoch diff --git a/nequip/train/trainer_wandb.py b/nequip/train/trainer_wandb.py index 2c62493c..f5686046 100644 --- a/nequip/train/trainer_wandb.py +++ b/nequip/train/trainer_wandb.py @@ -1,31 +1,23 @@ -""" Nequip.train.trainer - -Todo: - -isolate the loss function from the training procedure -enable wandb resume -make an interface with ray - -""" - import wandb from .trainer import Trainer class TrainerWandB(Trainer): - """Class to train a model to minimize forces""" - - def __init__(self, **kwargs): - Trainer.__init__(self, **kwargs) + """Trainer class that adds WandB features""" def end_of_epoch_log(self): Trainer.end_of_epoch_log(self) wandb.log(self.mae_dict) - def init_model(self): + def init(self): + super().init() - Trainer.init_model(self) + if not self._initialized: + return - # TODO: test and re-enable this - # wandb.watch(self.model) + wandb_watch = self.kwargs.get("wandb_watch", False) + if wandb_watch is not False: + if wandb_watch is True: + wandb_watch = {} + wandb.watch(self.model, **wandb_watch) From 90bdb1735502150234d48e110cdc57162bd24180 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 17:04:35 -0500 Subject: [PATCH 053/126] Update nequip/train/trainer_wandb.py Co-authored-by: Lixin Sun --- nequip/train/trainer_wandb.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/nequip/train/trainer_wandb.py b/nequip/train/trainer_wandb.py index f5686046..deed5206 100644 --- a/nequip/train/trainer_wandb.py +++ b/nequip/train/trainer_wandb.py @@ -16,8 +16,6 @@ def init(self): if not self._initialized: return - wandb_watch = self.kwargs.get("wandb_watch", False) - if wandb_watch is not False: - if wandb_watch is True: - wandb_watch = {} - wandb.watch(self.model, **wandb_watch) + if self.kwargs.get("wandb_watch", False): + wandb_watch_kwargs = self.kwargs.get("wandb_watch_kwargs", {}) + wandb.watch(self.model, **wandb_watch_kwargs) From 0cc85275fff8866fa9d135fcca85d50377923d19 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 17:04:38 -0500 Subject: [PATCH 054/126] Update configs/full.yaml Co-authored-by: Lixin Sun --- configs/full.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/full.yaml b/configs/full.yaml index f628e99a..40544438 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -133,7 +133,7 @@ wandb: false wandb_project: toluene-example # project name used in wandb wandb_watch: false # see https://docs.wandb.ai/ref/python/watch -# wandb_watch: +# wandb_watch_kwargs: # log: all # log_freq: 1 # log_graph: true From 57f85ec71d4b3eec2bd70e112ddc7ef6633f7904 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 17:22:08 -0500 Subject: [PATCH 055/126] fix include_frames for ASE dataset --- CHANGELOG.md | 1 + nequip/data/dataset.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99084c63..23a9f953 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Most recent change on the bottom. 
 ### Fixed
 - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistic
 - `include_frames` now works with ASE datasets
+- Training data labels are no longer passed to the model in `input_data`
 
 ### Removed
 - `compile_model`
diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py
index d87f52a7..724bca92 100644
--- a/nequip/data/_keys.py
+++ b/nequip/data/_keys.py
@@ -35,11 +35,19 @@
 NODE_ATTRS_KEY: Final[str] = "node_attrs"
 ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
 ATOM_TYPE_KEY: Final[str] = "atom_types"
+
 PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy"
 TOTAL_ENERGY_KEY: Final[str] = "total_energy"
 FORCE_KEY: Final[str] = "forces"
 PARTIAL_FORCE_KEY: Final[str] = "partial_forces"
 
+ALL_ENERGY_KEYS: Final[List[str]] = [
+    PER_ATOM_ENERGY_KEY,
+    TOTAL_ENERGY_KEY,
+    FORCE_KEY,
+    PARTIAL_FORCE_KEY,
+]
+
 BATCH_KEY: Final[str] = "batch"
 
 # Make a list of allowed keys
diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index 2811818a..b3d1f57c 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -88,15 +88,7 @@ def RescaleEnergyEtc(
     # == Build the model ==
     return RescaleOutput(
         model=model,
-        scale_keys=[
-            k
-            for k in (
-                AtomicDataDict.TOTAL_ENERGY_KEY,
-                AtomicDataDict.PER_ATOM_ENERGY_KEY,
-                AtomicDataDict.FORCE_KEY,
-            )
-            if k in model.irreps_out
-        ],
+        scale_keys=[k for k in AtomicDataDict.ALL_ENERGY_KEYS if k in model.irreps_out],
         scale_by=global_scale,
         shift_keys=[
             k for k in (AtomicDataDict.TOTAL_ENERGY_KEY,) if k in model.irreps_out
diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index ee06a468..6c7bbee0 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -324,10 +324,27 @@ def __init__(
             all_args=self.kwargs,
         )
         self.loss_stat = LossStat(self.loss)
+
+        # what do we train on?
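+        # (for example, with `loss_coeffs: {forces: 1.0, total_energy: 1.0}` in
+        #  the config, train_on_keys contains "forces" and "total_energy";
+        #  the coefficient values here are illustrative, not defaults)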
self.train_on_keys = self.loss.keys if train_on_keys is not None: assert set(train_on_keys) == set(self.train_on_keys) + self._remove_from_model_input = set(self.train_on_keys) + if ( + len( + self._remove_from_model_input.intersection( + AtomicDataDict.ALL_ENERGY_KEYS + ) + ) + > 0 + ): + # if we are training on _any_ of the energy quantities (energy, force, partials, stress, etc.) + # then none of them should be fed into the model + self._remove_from_model_input = self._remove_from_model_input.union( + AtomicDataDict.ALL_ENERGY_KEYS + ) + # load all callbacks self._init_callbacks = [load_callable(callback) for callback in init_callbacks] self._end_of_epoch_callbacks = [ load_callable(callback) for callback in end_of_epoch_callbacks @@ -794,7 +811,11 @@ def batch_step(self, data, validation=False): # Run model # We make a shallow copy of the input dict in case the model modifies it - input_data = data_unscaled.copy() + input_data = { + k: v + for k, v in data_unscaled.items() + if k not in self._remove_from_model_input + } out = self.model(input_data) del input_data From 41420ece5633bf01acdbc415408737ddd5b5d97b Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 17:44:22 -0500 Subject: [PATCH 057/126] remove non-existant WEIGHT_KEY --- nequip/data/AtomicData.py | 1 - nequip/data/_keys.py | 1 - nequip/train/_loss.py | 2 -- nequip/train/loss.py | 1 - tests/unit/trainer/test_metrics.py | 7 ------- 5 files changed, 12 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index bfe66136..bc40f873 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -33,7 +33,6 @@ } _DEFAULT_NODE_FIELDS: Set[str] = { AtomicDataDict.POSITIONS_KEY, - AtomicDataDict.WEIGHTS_KEY, AtomicDataDict.NODE_FEATURES_KEY, AtomicDataDict.NODE_ATTRS_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY, diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 724bca92..a4fbcb88 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -13,7 +13,6 @@ # == Define allowed keys as constants == # The positions of the atoms in the system POSITIONS_KEY: Final[str] = "pos" -WEIGHTS_KEY: Final[str] = "weights" # The [2, n_edge] index tensor giving center -> neighbor relations EDGE_INDEX_KEY: Final[str] = "edge_index" diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 7514a509..07ec8c80 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -10,8 +10,6 @@ class SimpleLoss: """wrapper to compute weighted loss function - if atomic_weight_on is True, the loss function will search for - AtomicDataDict.WEIGHTS_KEY+key in the reference data. 
Args: diff --git a/nequip/train/loss.py b/nequip/train/loss.py index e5cc8889..1420fc22 100644 --- a/nequip/train/loss.py +++ b/nequip/train/loss.py @@ -14,7 +14,6 @@ class Loss: Args: coeffs (dict, str): keys with coefficient and loss function name - weight (bool): if True, the results will be weighted with the key: AtomicDataDict.WEIGHTS_KEY+key Example input dictionaries diff --git a/tests/unit/trainer/test_metrics.py b/tests/unit/trainer/test_metrics.py index 17983e70..eb8425c0 100644 --- a/tests/unit/trainer/test_metrics.py +++ b/tests/unit/trainer/test_metrics.py @@ -99,10 +99,3 @@ def metrics(request): coeffs = request.param # noqa instance = Metrics(components=request.param) yield instance - - -# @pytest.fixture(scope="class") -# def w_loss(): -# """""" -# instance = Metrics(coeffs=metrics_tests[-1], atomic_weight_on=True) -# yield instance From 3824d3c943a12e70d7d628c5b943e6cfd5ae91bb Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 18:08:04 -0500 Subject: [PATCH 058/126] fix for testing --- nequip/train/trainer.py | 3 +++ tests/integration/test_train.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 6c7bbee0..920e2a77 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -343,6 +343,9 @@ def __init__( self._remove_from_model_input = self._remove_from_model_input.union( AtomicDataDict.ALL_ENERGY_KEYS ) + if kwargs.get("_override_allow_truth_label_inputs", False): + # needed for unit testing models + self._remove_from_model_input = set() # load all callbacks self._init_callbacks = [load_callable(callback) for callback in init_callbacks] diff --git a/tests/integration/test_train.py b/tests/integration/test_train.py index 72d7ecb1..1c1dc969 100644 --- a/tests/integration/test_train.py +++ b/tests/integration/test_train.py @@ -110,6 +110,8 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, builder): true_config["max_epochs"] = 2 # We just don't add rescaling: true_config["model_builders"] = [builder] + # We need truth labels as inputs for these fake testing models + true_config["_override_allow_truth_label_inputs"] = True config_path = tmpdir + "/conf.yaml" with open(config_path, "w+") as fp: @@ -226,6 +228,8 @@ def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): true_config["default_dtype"] = dtype # We just don't add rescaling: true_config["model_builders"] = [builder] + # We need truth labels as inputs for these fake testing models + true_config["_override_allow_truth_label_inputs"] = True for irun in range(3): From 3de51d0f243b50052dd4e83f3dd6ba91b4480ebb Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 19 Jan 2022 18:17:02 -0500 Subject: [PATCH 059/126] test fix --- tests/integration/test_evaluate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py index be8f65e5..383dc9c8 100644 --- a/tests/integration/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -52,6 +52,8 @@ def training_session(request, BENCHMARK_ROOT, conffile): true_config["default_dtype"] = dtype true_config["max_epochs"] = 2 true_config["model_builders"] = [builder] + # We need truth labels as inputs for these fake testing models + true_config["_override_allow_truth_label_inputs"] = True # to be a true identity, we can't have rescaling true_config["global_rescale_shift"] = None From 
b0f2b994fa7dc2cbdddfd3f9d947dc91790baca0 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Thu, 20 Jan 2022 13:02:41 -0500
Subject: [PATCH 060/126] better error message

---
 nequip/scripts/evaluate.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index 894bf42a..9733136b 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -158,6 +158,7 @@ def main(args=None, running_as_script: bool = True):
     # Load model:
     logger.info("Loading model... ")
     model_from_training: bool = False
+    loaded_deployed_model: bool = False
     try:
         model, _ = load_deployed_model(
             args.model,
@@ -165,7 +166,14 @@ def main(args=None, running_as_script: bool = True):
             set_global_options=True,  # don't warn that setting
         )
         logger.info("loaded deployed model.")
+        loaded_deployed_model = True
     except ValueError:  # it's not a deployed model
+        loaded_deployed_model = False
+    # we don't do this in the `except:` block to avoid "during handling of this exception another exception"
+    # chains if there is an issue loading the training session model. This makes the error messages more
+    # comprehensible:
+    if not loaded_deployed_model:
+        # load a training session model
         model, _ = Trainer.load_model_from_training_session(
             traindir=args.model.parent, model_name=args.model.name
         )

From ae6b4258d1e1894cb02aea8b44e70e4e20b260a0 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 21 Jan 2022 14:46:35 -0500
Subject: [PATCH 061/126] cutoff p besides 6

---
 CHANGELOG.md         | 1 +
 nequip/nn/cutoffs.py | 9 ++-------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23a9f953..897a1c6f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ Most recent change on the bottom.
 - Model builders may now process only the configuration
 - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features`
 - `wandb.watch` via `wandb_watch` option
+- Allow polynomial cutoff _p_ values besides 6.0
 
 ### Changed
 - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features`

diff --git a/nequip/nn/cutoffs.py b/nequip/nn/cutoffs.py
index 99177323..7deb7a7a 100644
--- a/nequip/nn/cutoffs.py
+++ b/nequip/nn/cutoffs.py
@@ -2,8 +2,7 @@
 
 @torch.jit.script
-def _poly_cutoff(x: torch.Tensor, factor: float) -> torch.Tensor:
-    p: float = 6.0
+def _poly_cutoff(x: torch.Tensor, factor: float, p: float = 6.0) -> torch.Tensor:
     x = x * factor
 
     out = 1.0
@@ -31,10 +30,6 @@ def __init__(self, r_max: float, p: float = 6):
             Power used in envelope function
         """
         super().__init__()
-        if p != 6:
-            raise NotImplementedError(
-                "p values other than 6 not currently supported for simplicity; if you need this please file an issue."
- ) self.p = p self._factor = 1.0 / r_max @@ -44,4 +39,4 @@ def forward(self, x): x: torch.Tensor, input distance """ - return _poly_cutoff(x, self._factor) + return _poly_cutoff(x, self._factor, p=self.p) From c9209edb4b97dc2a4f98322173903f3834431b2c Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 21 Jan 2022 15:07:51 -0500 Subject: [PATCH 062/126] bump allowed pytorch to 1.11 --- README.md | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1e343694..b8feca2d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentia NequIP requires: * Python >= 3.6 -* PyTorch >= 1.8, <=1.10.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. +* PyTorch >= 1.8, <=1.11.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. To install: diff --git a/setup.py b/setup.py index f206bba9..bf46b7de 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "numpy", "ase", "tqdm", - "torch>=1.8,<1.11", # torch.fx added in 1.8 + "torch>=1.8,<=1.11", # torch.fx added in 1.8 "e3nn>=0.3.5,<0.5.0", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext From c79f6ba3d6b6f6cde6358787c9cdd3eee46e6667 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 21 Jan 2022 15:08:16 -0500 Subject: [PATCH 063/126] freeze in benchmark --- nequip/scripts/benchmark.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/nequip/scripts/benchmark.py b/nequip/scripts/benchmark.py index 0f1685d5..609b54c8 100644 --- a/nequip/scripts/benchmark.py +++ b/nequip/scripts/benchmark.py @@ -97,20 +97,16 @@ def main(args=None): model.eval() model = script(model) - # OLD ---- OLD ---- OLD - # TODO!!: for now we just compile, but when - # https://github.com/pytorch/pytorch/issues/64957#issuecomment-918632252 - # is resolved, then should be deploying again - # print( - # "WARNING: this is currently not using deployed model, just scripted, because of PyTorch bugs" - # ) - # OLD ---- OLD ---- OLD - - model = _compile_for_deploy(model) # TODO make this an option + model = _compile_for_deploy(model) # save and reload to avoid bugs with tempfile.NamedTemporaryFile() as f: torch.jit.save(model, f.name) model = torch.jit.load(f.name, map_location=device) + # freeze like in the LAMMPS plugin + model = torch.jit.freeze(model) + # and reload again just to avoid bugs + torch.jit.save(model, f.name) + model = torch.jit.load(f.name, map_location=device) # Make sure we're warm past compilation warmup = config["_jit_bailout_depth"] + 4 # just to be safe... 
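A note on patch 061 above: the envelope computed by `_poly_cutoff` (whose body is elided in the hunk) is the smooth polynomial cutoff f_p(u) = 1 - ((p+1)(p+2)/2) u^p + p(p+2) u^(p+1) - (p(p+1)/2) u^(p+2) with u = r/r_max, which equals 1 at u = 0 and decays smoothly to exactly 0 at u = 1 for any p (larger p keeps the envelope flatter near zero and steepens the decay near r_max), so lifting the p == 6 restriction needs no other changes. A minimal sanity-check sketch, assuming the patched `nequip` package is importable; the `r_max` and `p` values are only illustrative:

    import torch

    from nequip.nn.cutoffs import PolynomialCutoff

    # The envelope should start at 1 and reach 0 at r_max for any p,
    # not only the previously hard-coded p = 6.
    for p in (2.0, 6.0, 48.0):
        cutoff = PolynomialCutoff(r_max=4.0, p=p)
        r = torch.linspace(0.0, 4.0, 5)
        out = cutoff(r)
        assert abs(float(out[0]) - 1.0) < 1e-6  # f(0) = 1
        assert abs(float(out[-1])) < 1e-6  # f(r_max) = 0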
From 4d606298b29a5a2fd21c6a09d59cfa5e95c9fe12 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sat, 22 Jan 2022 11:57:58 -0500 Subject: [PATCH 064/126] add related keys and allow multiple layers of rescale --- nequip/nn/_rescale.py | 26 +++++++++++++++++++++----- nequip/train/trainer.py | 29 +++++++++++++++++++---------- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index afd232f5..c5a80d0e 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -1,4 +1,5 @@ from typing import Sequence, List, Union +from rdflib import Graph import torch @@ -15,10 +16,11 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): Args: model : GraphModuleMixin The model whose outputs are to be rescaled. - scale : list of keys, default [] + scale_keys : list of keys, default [] Which fields to rescale. - shift : list of keys, default [] + shift_keys : list of keys, default [] Which fields to shift after rescaling. + related_keys: list of keys that could be contingent to this rescale scale_by : floating or Tensor, default 1. The scaling factor by which to multiply fields in ``scale``. shift_by : floating or Tensor, default 0. @@ -40,6 +42,7 @@ def __init__( model: GraphModuleMixin, scale_keys: Union[Sequence[str], str] = [], shift_keys: Union[Sequence[str], str] = [], + related_keys: Union[Sequence[str], str] = [], scale_by=None, shift_by=None, shift_trainable: bool = False, @@ -47,6 +50,7 @@ def __init__( irreps_in: dict = {}, ): super().__init__() + self.model = model scale_keys = [scale_keys] if isinstance(scale_keys, str) else scale_keys shift_keys = [shift_keys] if isinstance(shift_keys, str) else shift_keys @@ -74,6 +78,7 @@ def __init__( self.scale_keys = list(scale_keys) self.shift_keys = list(shift_keys) + self.related_keys = set(related_keys).union(set(scale_keys), set(shift_keys)) self.has_scale = scale_by is not None self.scale_trainble = scale_trainable @@ -110,16 +115,27 @@ def __init__( # Finally, we tell all the modules in the model that there is rescaling # This allows them to update parameters, like physical constants with units, # that need to be scaled - # + # Note that .modules() walks the full tree, including self - for mod in self.model.modules(): + for mod in self.inner_model.modules(): if isinstance(mod, GraphModuleMixin): callback = getattr(mod, "update_for_rescale", None) - if callable(callback): + contain_related_keys = False + for out_key in mod.irreps_out: + if out_key in self.related_keys: + contain_related_keys = True + if contain_related_keys and callable(callback): # It gets the `RescaleOutput` as an argument, # since that contains all relevant information callback(self) + @property + def inner_model(self): + inner_model = self.model + while isinstance(inner_model, RescaleOutput): + inner_model = inner_model.model + return inner_model + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = self.model(data) if self.training: diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index ee06a468..7bbfb1f6 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -44,6 +44,7 @@ ) from nequip.utils.git import get_commit from nequip.model import model_from_config +from nequip.nn import RescaleOutput from .loss import Loss, LossStat from .metrics import Metrics @@ -701,6 +702,13 @@ def init(self): return self.model.to(self.torch_device) + + self.rescale_layers = [] + outer_layer = self.model + while hasattr(outer_layer, "unscale"): + self.rescale_layers.append(outer_layer) 
+            outer_layer = getattr(outer_layer, "model", None)
+
         self.init_objects()
         self._initialized = True
@@ -783,14 +791,13 @@ def batch_step(self, data, validation=False):
         data = data.to(self.torch_device)
         data = AtomicData.to_AtomicDataDict(data)
-        if hasattr(self.model, "unscale"):
+        data_unscaled = data
+        for layer in self.rescale_layers:
             # This means that self.model is a RescaleOutput
             # this will normalize the targets
             # in validation (eval mode), it does nothing
             # in train mode, it normalizes the targets
-            data_unscaled = self.model.unscale(data)
-        else:
-            data_unscaled = data
+            data_unscaled = layer.unscale(data)
 
         # Run model
         # We make a shallow copy of the input dict in case the model modifies it
@@ -826,16 +833,18 @@ def batch_step(self, data, validation=False):
             self.lr_sched.step(self.iepoch + self.ibatch / self.n_batches)
 
         with torch.no_grad():
-            if hasattr(self.model, "unscale"):
+            if len(self.rescale_layers) > 0:
                 if validation:
-                    # loss function always needs to be in normalized unit
-                    scaled_out = self.model.unscale(out, force_process=True)
-                    _data_unscaled = self.model.unscale(data, force_process=True)
-                    loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
+                    for layer in self.rescale_layers:
+                        # loss function always needs to be in normalized unit
+                        scaled_out = layer.unscale(out, force_process=True)
+                        _data_unscaled = layer.unscale(data, force_process=True)
+                        loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
                 else:
                     # If we are in training mode, we need to bring the prediction
                     # into real units
-                    out = self.model.scale(out, force_process=True)
+                    for layer in self.rescale_layers.reverse():
+                        out = layer.scale(out, force_process=True)
             elif validation:
                 loss, loss_contrib = self.loss(pred=out, ref=data_unscaled)

From afac34290d6cd775c9f2db2f133eba446895035e Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Sat, 22 Jan 2022 12:00:11 -0500
Subject: [PATCH 065/126] add related_key to the builder

---
 nequip/model/_scaling.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index b3d1f57c..56162b8c 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -94,6 +94,7 @@ def RescaleEnergyEtc(
             k for k in (AtomicDataDict.TOTAL_ENERGY_KEY,) if k in model.irreps_out
         ],
         shift_by=global_shift,
+        related_keys=AtomicDataDict.ALL_ENERGY_KEYS+[AtomicDataDict.PER_ATOM_ENERGY_KEY],
         shift_trainable=config.get(f"{module_prefix}_shift_trainable", False),
         scale_trainable=config.get(f"{module_prefix}_scale_trainable", False),
     )

From 0cdbacaf854d9e14b3b1b9830f23e9b4887fc81f Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Sat, 22 Jan 2022 15:56:45 -0500
Subject: [PATCH 066/126] fix boundary condition

---
 nequip/nn/_graph_mixin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py
index 2f3ed396..bd974a35 100644
--- a/nequip/nn/_graph_mixin.py
+++ b/nequip/nn/_graph_mixin.py
@@ -295,7 +295,7 @@ def insert(
             assert AtomicDataDict._irreps_compatible(
                 module_list[idx - 1].irreps_out, module.irreps_in
             )
-        if len(module_list) > idx:
+        if len(module_list) - 1 > idx:
             assert AtomicDataDict._irreps_compatible(
                 module_list[idx + 1].irreps_in, module.irreps_out
             )

From 3d56d6bd642e67b9c82531d62391705097989485 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Sat, 22 Jan 2022 16:02:12 -0500
Subject: [PATCH 067/126] remove arbitrary import added by vscode...
--- nequip/nn/_rescale.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index c5a80d0e..90efdc55 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -1,5 +1,4 @@ from typing import Sequence, List, Union -from rdflib import Graph import torch From b2dd09b731aebca5ab3bedf2961df2fe12241dfb Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sat, 22 Jan 2022 16:09:16 -0500 Subject: [PATCH 068/126] fix unused import --- nequip/train/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index ff6ad0de..be86d6d5 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -44,7 +44,6 @@ ) from nequip.utils.git import get_commit from nequip.model import model_from_config -from nequip.nn import RescaleOutput from .loss import Loss, LossStat from .metrics import Metrics From 162c61ee400c9b9721f17f91bb79749f731045f6 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sat, 22 Jan 2022 18:50:12 -0500 Subject: [PATCH 069/126] increase logging level from debug to info for string conversion --- nequip/model/_scaling.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py index 56162b8c..7697aa60 100644 --- a/nequip/model/_scaling.py +++ b/nequip/model/_scaling.py @@ -63,11 +63,11 @@ def RescaleEnergyEtc( if isinstance(global_scale, str): s = global_scale global_scale = computed_stats[str_names.index(global_scale)] - logging.debug(f"Replace string {s} to {global_scale}") + logging.info(f"Replace string {s} to {global_scale}") if isinstance(global_shift, str): s = global_shift global_shift = computed_stats[str_names.index(global_shift)] - logging.debug(f"Replace string {s} to {global_shift}") + logging.info(f"Replace string {s} to {global_shift}") if isinstance(global_scale, float) and global_scale < RESCALE_THRESHOLD: raise ValueError( @@ -170,14 +170,14 @@ def PerSpeciesRescale( if isinstance(scales, str): s = scales scales = computed_stats[str_names.index(scales)] - logging.debug(f"Replace string {s} to {scales}") + logging.info(f"Replace string {s} to {scales}") elif isinstance(scales, (list, float)): scales = torch.as_tensor(scales) if isinstance(shifts, str): s = shifts shifts = computed_stats[str_names.index(shifts)] - logging.debug(f"Replace string {s} to {shifts}") + logging.info(f"Replace string {s} to {shifts}") elif isinstance(shifts, (list, float)): shifts = torch.as_tensor(shifts) @@ -216,7 +216,7 @@ def PerSpeciesRescale( params=params, ) - logging.debug(f"Atomic outputs are scaled by: {scales}, shifted by {shifts}.") + logging.info(f"Atomic outputs are scaled by: {scales}, shifted by {shifts}.") # == Build the model == return model From f5986f07b2174fc08f8c93a65d3cddfe7d7f97db Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sat, 22 Jan 2022 18:52:05 -0500 Subject: [PATCH 070/126] move the contain related keys condition into callable --- nequip/nn/_atomwise.py | 6 ++++++ nequip/nn/_rescale.py | 6 +----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 57103365..31ee7523 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -171,6 +171,12 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: return data def update_for_rescale(self, rescale_module): + if hasattr(rescale_module, "related_keys"): + if not ( + self.field in rescale_module.related_keys + or self.out_field in rescale_module.related_keys + ): + return 
if self.arguments_in_dataset_units and rescale_module.has_scale: logging.debug( f"PerSpeciesScaleShift's arguments were in dataset units; rescaling:\n" diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 90efdc55..205862a6 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -119,11 +119,7 @@ def __init__( for mod in self.inner_model.modules(): if isinstance(mod, GraphModuleMixin): callback = getattr(mod, "update_for_rescale", None) - contain_related_keys = False - for out_key in mod.irreps_out: - if out_key in self.related_keys: - contain_related_keys = True - if contain_related_keys and callable(callback): + if callable(callback): # It gets the `RescaleOutput` as an argument, # since that contains all relevant information callback(self) From b4e02ed470db71b78eb8169f108b70ff9cb1329e Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sat, 22 Jan 2022 22:38:35 -0500 Subject: [PATCH 071/126] use -1 instead of reverse --- nequip/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index be86d6d5..78403f79 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -866,7 +866,7 @@ def batch_step(self, data, validation=False): else: # If we are in training mode, we need to bring the prediction # into real units - for layer in self.rescale_layers.reverse(): + for layer in self.rescale_layers[::-1]: out = layer.scale(out, force_process=True) elif validation: loss, loss_contrib = self.loss(pred=out, ref=data_unscaled) From 9af3cc5c6c83c6568e999262c2ba251bebf4c7ab Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sun, 23 Jan 2022 12:35:05 -0500 Subject: [PATCH 072/126] fix inner_model error. property tag should not be used --- nequip/nn/_rescale.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 205862a6..078706b0 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -116,7 +116,7 @@ def __init__( # that need to be scaled # Note that .modules() walks the full tree, including self - for mod in self.inner_model.modules(): + for mod in self.inner_model().modules(): if isinstance(mod, GraphModuleMixin): callback = getattr(mod, "update_for_rescale", None) if callable(callback): @@ -124,12 +124,12 @@ def __init__( # since that contains all relevant information callback(self) - @property def inner_model(self): - inner_model = self.model - while isinstance(inner_model, RescaleOutput): - inner_model = inner_model.model - return inner_model + model = self.model + while isinstance(model, RescaleOutput) and hasattr(model, "model"): + print(type(model)) + model = model.model + return model def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = self.model(data) From e363d2cc04d1fc77c04f364d6274c92a6f9bb1a1 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sun, 23 Jan 2022 12:37:11 -0500 Subject: [PATCH 073/126] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd817e43..33d2a495 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Most recent change on the bottom. 
 - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features`
 - `wandb.watch` via `wandb_watch` option
 - Allow polynomial cutoff _p_ values besides 6.0
+- Support multiple rescale layers in trainer
 
 ### Changed
 - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features`

From 690e2be736d527f4892123e70628d18be0af7665 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Sun, 23 Jan 2022 13:53:33 -0500
Subject: [PATCH 074/126] modularize rescale builder for general purpose

---
 nequip/model/_scaling.py | 47 +++++++++++++++++++++++++++-------------
 1 file changed, 32 insertions(+), 15 deletions(-)

diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index 7697aa60..cc942e23 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
@@ -11,25 +11,44 @@
 
 
 def RescaleEnergyEtc(
+    model: GraphModuleMixin, config, dataset: AtomicDataset, initialize: bool
+):
+
+    return GeneralRescale(
+        model=model,
+        config=config,
+        dataset=dataset,
+        initialize=initialize,
+        module_prefix="global_rescale",
+        default_global_scale=f"dataset_{AtomicDataDict.FORCE_KEY}_rms"
+        if AtomicDataDict.FORCE_KEY in model.irreps_out
+        else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std",
+        default_global_shift=None,
+        default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS,
+        default_shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY,
+        default_related_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY],
+    )
+
+
+def GeneralRescale(
     model: GraphModuleMixin,
     config,
     dataset: AtomicDataset,
     initialize: bool,
+    module_prefix: str,
+    default_scale: Union[str, float, list],
+    default_shift: Union[str, float, list],
+    default_scale_keys: list,
+    default_shift_keys: list,
+    default_related_keys: list,
 ):
     """Add global rescaling for energy(-based quantities).
 
     If ``initialize`` is false, doesn't compute statistics.
""" - module_prefix = "global_rescale" - - global_scale = config.get( - f"{module_prefix}_scale", - f"dataset_{AtomicDataDict.FORCE_KEY}_rms" - if AtomicDataDict.FORCE_KEY in model.irreps_out - else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", - ) - global_shift = config.get(f"{module_prefix}_shift", None) + global_scale = config.get(f"{module_prefix}_scale", default_scale) + global_shift = config.get(f"{module_prefix}_shift", default_shift) if global_shift is not None: logging.warning( @@ -88,13 +107,11 @@ def RescaleEnergyEtc( # == Build the model == return RescaleOutput( model=model, - scale_keys=[k for k in AtomicDataDict.ALL_ENERGY_KEYS if k in model.irreps_out], + scale_keys=[k for k in default_scale_keys if k in model.irreps_out], scale_by=global_scale, - shift_keys=[ - k for k in (AtomicDataDict.TOTAL_ENERGY_KEY,) if k in model.irreps_out - ], + shift_keys=[k for k in default_shift_keys if k in model.irreps_out], shift_by=global_shift, - related_keys=AtomicDataDict.ALL_ENERGY_KEYS+[AtomicDataDict.PER_ATOM_ENERGY_KEY], + related_keys=default_related_keys, shift_trainable=config.get(f"{module_prefix}_shift_trainable", False), scale_trainable=config.get(f"{module_prefix}_scale_trainable", False), ) From 03cd0d08d056ced525e824928b378b38d6116cbc Mon Sep 17 00:00:00 2001 From: nw13slx Date: Sun, 23 Jan 2022 13:55:33 -0500 Subject: [PATCH 075/126] fix wrong names --- nequip/model/_scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py index cc942e23..5f8e1f25 100644 --- a/nequip/model/_scaling.py +++ b/nequip/model/_scaling.py @@ -20,10 +20,10 @@ def RescaleEnergyEtc( dataset=dataset, initialize=initialize, module_prefix="global_rescale", - default_global_scale=f"dataset_{AtomicDataDict.FORCE_KEY}_rms" + default_scale=f"dataset_{AtomicDataDict.FORCE_KEY}_rms" if AtomicDataDict.FORCE_KEY in model.irreps_out else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", - default_global_shift=None, + default_shift=None, default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS, default_shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY, default_related_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY], From fff137b4c8ba42923be6ef8ec3f4808452b4dd37 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Mon, 24 Jan 2022 14:54:50 -0500 Subject: [PATCH 076/126] Update nequip/nn/_rescale.py --- nequip/nn/_rescale.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 078706b0..6c35e7df 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -127,7 +127,6 @@ def __init__( def inner_model(self): model = self.model while isinstance(model, RescaleOutput) and hasattr(model, "model"): - print(type(model)) model = model.model return model From 85733d9a893e1947511cfbb9c90ae183577f6952 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Mon, 24 Jan 2022 14:56:15 -0500 Subject: [PATCH 077/126] Update nequip/nn/_rescale.py Co-authored-by: Alby M. 
<1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/nn/_rescale.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 6c35e7df..23f92b25 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -124,7 +124,8 @@ def __init__( # since that contains all relevant information callback(self) - def inner_model(self): + def get_inner_model(self): + """Get the outermost child module that is not another ``RescaleOutput``""" model = self.model while isinstance(model, RescaleOutput) and hasattr(model, "model"): model = model.model From 6fdc06c884b013c44bb101d0c7bd0b28c8b421c0 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Mon, 24 Jan 2022 14:56:22 -0500 Subject: [PATCH 078/126] Update nequip/nn/_rescale.py Co-authored-by: Alby M. <1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/nn/_rescale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 23f92b25..645851e2 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -116,7 +116,7 @@ def __init__( # that need to be scaled # Note that .modules() walks the full tree, including self - for mod in self.inner_model().modules(): + for mod in self.get_inner_model().modules(): if isinstance(mod, GraphModuleMixin): callback = getattr(mod, "update_for_rescale", None) if callable(callback): From fb88575316b12d84a53a62dd458947d1ec5d4066 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 24 Jan 2022 15:41:21 -0500 Subject: [PATCH 079/126] r_max default for evaluate --- CHANGELOG.md | 1 + nequip/scripts/evaluate.py | 21 +++++++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd817e43..c150ba87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Most recent change on the bottom. - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features` - `wandb.watch` via `wandb_watch` option - Allow polynomial cutoff _p_ values besides 6.0 +- `nequip-evaluate` now sets a default `r_max` taken from the model for the dataset config ### Changed - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features` diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 9733136b..92a0454c 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -13,7 +13,7 @@ from nequip.utils import Config from nequip.data import AtomicData, Collater, dataset_from_config from nequip.train import Trainer -from nequip.scripts.deploy import load_deployed_model +from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY from nequip.scripts.train import default_config, _set_global_options from nequip.utils import load_file, instantiate from nequip.train.loss import Loss @@ -159,13 +159,15 @@ def main(args=None, running_as_script: bool = True): logger.info("Loading model... 
") model_from_training: bool = False loaded_deployed_model: bool = False + model_r_max = None try: - model, _ = load_deployed_model( + model, metadata = load_deployed_model( args.model, device=device, set_global_options=True, # don't warn that setting ) logger.info("loaded deployed model.") + model_r_max = float(metadata[R_MAX_KEY]) loaded_deployed_model = True except ValueError: # its not a deployed model loaded_deployed_model = False @@ -174,19 +176,26 @@ def main(args=None, running_as_script: bool = True): # comprehensible: if not loaded_deployed_model: # load a training session model - model, _ = Trainer.load_model_from_training_session( + model, model_config = Trainer.load_model_from_training_session( traindir=args.model.parent, model_name=args.model.name ) model_from_training = True model = model.to(device) logger.info("loaded model from training session") + model_r_max = model_config["r_max"] model.eval() # Load a config file logger.info( f"Loading {'original ' if dataset_is_from_training else ''}dataset...", ) - config = Config.from_file(str(args.dataset_config)) + dataset_config = Config.from_file( + str(args.dataset_config), defaults={"r_max": model_r_max} + ) + if dataset_config["r_max"] != model_r_max: + logger.warn( + f"Dataset config has r_max={dataset_config['r_max']}, but model has r_max={model_r_max}!" + ) # set global options if model_from_training: @@ -208,11 +217,11 @@ def main(args=None, running_as_script: bool = True): with contextlib.redirect_stdout(sys.stderr): try: # Try to get validation dataset - dataset = dataset_from_config(config, prefix="validation_dataset") + dataset = dataset_from_config(dataset_config, prefix="validation_dataset") dataset_is_validation = True except KeyError: # Get shared train + validation dataset - dataset = dataset_from_config(config) + dataset = dataset_from_config(dataset_config) logger.info( f"Loaded {'validation_' if dataset_is_validation else ''}dataset specified in {args.dataset_config.name}.", ) From de716eb1372778c7ff8a9b48d4706ad1a37f3332 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 16:24:34 -0500 Subject: [PATCH 080/126] rename to global rescale --- nequip/model/_scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py index 5f8e1f25..11a456eb 100644 --- a/nequip/model/_scaling.py +++ b/nequip/model/_scaling.py @@ -14,7 +14,7 @@ def RescaleEnergyEtc( model: GraphModuleMixin, config, dataset: AtomicDataset, initialize: bool ): - return GeneralRescale( + return GlobalRescale( model=model, config=config, dataset=dataset, @@ -30,7 +30,7 @@ def RescaleEnergyEtc( ) -def GeneralRescale( +def GlobalRescale( model: GraphModuleMixin, config, dataset: AtomicDataset, From 5e6f60cb4b1fa3a9e8f7607dc33e354607017955 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Mon, 24 Jan 2022 16:30:59 -0500 Subject: [PATCH 081/126] fix unscale in validation mode Co-authored-by: Alby M. 
<1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/train/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 78403f79..cdaac4f7 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -858,11 +858,13 @@ def batch_step(self, data, validation=False): with torch.no_grad(): if len(self.rescale_layers) > 0: if validation: + scaled_out = out + _data_unscaled = data for layer in self.rescale_layers: # loss function always needs to be in normalized unit - scaled_out = layer.unscale(out, force_process=True) - _data_unscaled = layer.unscale(data, force_process=True) - loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled) + scaled_out = layer.unscale(scaled_out, force_process=True) + _data_unscaled = layer.unscale(_data_unscaled, force_process=True) + loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled) else: # If we are in training mode, we need to bring the prediction # into real units From 762491732b23429773f3a13fe3569ab3bae69885 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 16:38:23 -0500 Subject: [PATCH 082/126] format --- nequip/train/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index cdaac4f7..3ed6fb92 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -863,7 +863,9 @@ def batch_step(self, data, validation=False): for layer in self.rescale_layers: # loss function always needs to be in normalized unit scaled_out = layer.unscale(scaled_out, force_process=True) - _data_unscaled = layer.unscale(_data_unscaled, force_process=True) + _data_unscaled = layer.unscale( + _data_unscaled, force_process=True + ) loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled) else: # If we are in training mode, we need to bring the prediction From ac3177deba6aabc78c59c45d11969be9ce6699af Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 17:18:13 -0500 Subject: [PATCH 083/126] add related_shift keys --- nequip/nn/_atomwise.py | 6 +++--- nequip/nn/_rescale.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 31ee7523..26a19df3 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -171,10 +171,10 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: return data def update_for_rescale(self, rescale_module): - if hasattr(rescale_module, "related_keys"): + if hasattr(rescale_module, "related_scale_keys"): if not ( - self.field in rescale_module.related_keys - or self.out_field in rescale_module.related_keys + self.field in rescale_module.related_scale_keys + or self.out_field in rescale_module.related_scale_keys ): return if self.arguments_in_dataset_units and rescale_module.has_scale: diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 645851e2..f7f755c1 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -19,7 +19,8 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): Which fields to rescale. shift_keys : list of keys, default [] Which fields to shift after rescaling. - related_keys: list of keys that could be contingent to this rescale + related_scale_keys: list of keys that could be contingent to this rescale + related_shift_keys: list of keys that could be contingent to this rescale scale_by : floating or Tensor, default 1. The scaling factor by which to multiply fields in ``scale``. 
shift_by : floating or Tensor, default 0. @@ -30,6 +31,8 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): scale_keys: List[str] shift_keys: List[str] + related_scale_keys: List[str] + related_shift_keys: List[str] scale_trainble: bool rescale_trainable: bool @@ -41,7 +44,8 @@ def __init__( model: GraphModuleMixin, scale_keys: Union[Sequence[str], str] = [], shift_keys: Union[Sequence[str], str] = [], - related_keys: Union[Sequence[str], str] = [], + related_shift_keys: Union[Sequence[str], str] = [], + related_scale_keys: Union[Sequence[str], str] = [], scale_by=None, shift_by=None, shift_trainable: bool = False, @@ -77,7 +81,8 @@ def __init__( self.scale_keys = list(scale_keys) self.shift_keys = list(shift_keys) - self.related_keys = set(related_keys).union(set(scale_keys), set(shift_keys)) + self.related_scale_keys = set(related_scale_keys).union(scale_keys) + self.related_shift_keys = set(related_shift_keys).union(shift_keys) self.has_scale = scale_by is not None self.scale_trainble = scale_trainable From 937249cd9a508a00e6a29b51bafc2c4c03ac0095 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 17:30:04 -0500 Subject: [PATCH 084/126] update call arguments for the rescale function --- nequip/model/_scaling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py index 11a456eb..aeac1476 100644 --- a/nequip/model/_scaling.py +++ b/nequip/model/_scaling.py @@ -26,7 +26,8 @@ def RescaleEnergyEtc( default_shift=None, default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS, default_shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY, - default_related_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY], + default_related_scale_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY], + default_related_shift_keys=[], ) @@ -40,7 +41,8 @@ def GlobalRescale( default_shift: Union[str, float, list], default_scale_keys: list, default_shift_keys: list, - default_related_keys: list, + default_related_scale_keys: list, + default_related_shift_keys: list, ): """Add global rescaling for energy(-based quantities). 
@@ -111,7 +113,8 @@ def GlobalRescale( scale_by=global_scale, shift_keys=[k for k in default_shift_keys if k in model.irreps_out], shift_by=global_shift, - related_keys=default_related_keys, + related_scale_keys=default_related_scale_keys, + related_shift_keys=default_related_shift_keys, shift_trainable=config.get(f"{module_prefix}_shift_trainable", False), scale_trainable=config.get(f"{module_prefix}_scale_trainable", False), ) From 84630cc7901bfe53d63b1cf080abf438a23c8046 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 20:24:01 -0500 Subject: [PATCH 085/126] add other prefix as an option --- nequip/nn/_graph_mixin.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py index bd974a35..254ec6db 100644 --- a/nequip/nn/_graph_mixin.py +++ b/nequip/nn/_graph_mixin.py @@ -190,7 +190,7 @@ def from_parameters( instance, _ = instantiate( builder=builder, - prefix=name, + prefix=[name] + params.get("other_prefix", []), positional_args=( dict( irreps_in=( @@ -244,7 +244,7 @@ def append_from_parameters( """ instance, _ = instantiate( builder=builder, - prefix=name, + prefix=[name] + params.get("other_prefix", []), positional_args=(dict(irreps_in=self[-1].irreps_out)), optional_args=params, all_args=shared_params, @@ -317,6 +317,7 @@ def insert_from_parameters( params: Dict[str, Any] = {}, after: Optional[str] = None, before: Optional[str] = None, + prefix: Optional[Sequence[str]] = [], ) -> None: r"""Build a module from parameters and insert it after ``after``. @@ -339,7 +340,7 @@ def insert_from_parameters( idx += 1 instance, _ = instantiate( builder=builder, - prefix=name, + prefix=[name] + params.get("other_prefix", []), positional_args=(dict(irreps_in=self[idx].irreps_out)), optional_args=params, all_args=shared_params, From a1c9b6f0751d4ef34804d2ee59228949fcb48404 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Mon, 24 Jan 2022 20:55:16 -0500 Subject: [PATCH 086/126] reverse previous commit --- nequip/nn/_graph_mixin.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py index 254ec6db..2f3ed396 100644 --- a/nequip/nn/_graph_mixin.py +++ b/nequip/nn/_graph_mixin.py @@ -190,7 +190,7 @@ def from_parameters( instance, _ = instantiate( builder=builder, - prefix=[name] + params.get("other_prefix", []), + prefix=name, positional_args=( dict( irreps_in=( @@ -244,7 +244,7 @@ def append_from_parameters( """ instance, _ = instantiate( builder=builder, - prefix=[name] + params.get("other_prefix", []), + prefix=name, positional_args=(dict(irreps_in=self[-1].irreps_out)), optional_args=params, all_args=shared_params, @@ -295,7 +295,7 @@ def insert( assert AtomicDataDict._irreps_compatible( module_list[idx - 1].irreps_out, module.irreps_in ) - if len(module_list) - 1 > idx: + if len(module_list) > idx: assert AtomicDataDict._irreps_compatible( module_list[idx + 1].irreps_in, module.irreps_out ) @@ -317,7 +317,6 @@ def insert_from_parameters( params: Dict[str, Any] = {}, after: Optional[str] = None, before: Optional[str] = None, - prefix: Optional[Sequence[str]] = [], ) -> None: r"""Build a module from parameters and insert it after ``after``. 
@@ -340,7 +339,7 @@ def insert_from_parameters( idx += 1 instance, _ = instantiate( builder=builder, - prefix=[name] + params.get("other_prefix", []), + prefix=name, positional_args=(dict(irreps_in=self[idx].irreps_out)), optional_args=params, all_args=shared_params, From 858fa0ac3858ded8aab986fd1a9cb15b183c9e23 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 25 Jan 2022 11:36:41 -0500 Subject: [PATCH 087/126] fix update_for_rescale condition --- nequip/nn/_atomwise.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 26a19df3..e5815fe2 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -172,10 +172,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: def update_for_rescale(self, rescale_module): if hasattr(rescale_module, "related_scale_keys"): - if not ( - self.field in rescale_module.related_scale_keys - or self.out_field in rescale_module.related_scale_keys - ): + if self.out_field not in rescale_module.related_scale_keys: return if self.arguments_in_dataset_units and rescale_module.has_scale: logging.debug( From b68e7bcf581534f8ed77cfab7a9ffefb54cf65b7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 25 Jan 2022 11:37:40 -0500 Subject: [PATCH 088/126] type correct --- nequip/nn/_rescale.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index f7f755c1..b0ca3f83 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -81,8 +81,8 @@ def __init__( self.scale_keys = list(scale_keys) self.shift_keys = list(shift_keys) - self.related_scale_keys = set(related_scale_keys).union(scale_keys) - self.related_shift_keys = set(related_shift_keys).union(shift_keys) + self.related_scale_keys = list(set(related_scale_keys).union(scale_keys)) + self.related_shift_keys = list(set(related_shift_keys).union(shift_keys)) self.has_scale = scale_by is not None self.scale_trainble = scale_trainable From 6fde41464f2d2aaaeff5f591f7e96183e989423b Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 25 Jan 2022 11:39:01 -0500 Subject: [PATCH 089/126] no silent failure --- nequip/nn/_rescale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index b0ca3f83..8bea7096 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -132,7 +132,7 @@ def __init__( def get_inner_model(self): """Get the outermost child module that is not another ``RescaleOutput``""" model = self.model - while isinstance(model, RescaleOutput) and hasattr(model, "model"): + while isinstance(model, RescaleOutput): model = model.model return model From 3b97cce439b7b996bce42a1ee12b92e21887bf33 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Tue, 25 Jan 2022 14:28:05 -0500 Subject: [PATCH 090/126] change the prefix to autodetection --- nequip/model/_eng.py | 98 +++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py index fc7b907b..f89a7448 100644 --- a/nequip/model/_eng.py +++ b/nequip/model/_eng.py @@ -21,6 +21,7 @@ def SimpleIrrepsConfig(config): """Builder that pre-processes options to allow "simple" configuration of irreps.""" + # We allow some simpler parameters to be provided, but if they are, # they have to be correct and not 
overridden
     simple_irreps_keys = ["l_max", "parity", "num_features"]
@@ -30,49 +31,62 @@ def SimpleIrrepsConfig(config):
         "irreps_edge_sh",
         "conv_to_output_hidden_irreps_out",
     ]
-    has_simple: bool = any(k in config for k in simple_irreps_keys)
-    has_full: bool = any(k in config for k in real_irreps_keys)
-    assert has_simple or has_full
-
-    update = {}
-    if has_simple:
-        # nothing to do if not
-        lmax = config["l_max"]
-        parity = config["parity"]
-        num_features = config["num_features"]
-        update["chemical_embedding_irreps_out"] = repr(
-            o3.Irreps([(num_features, (0, 1))])  # n scalars
-        )
-        update["irreps_edge_sh"] = repr(
-            o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1)
-        )
-        update["feature_irreps_hidden"] = repr(
-            o3.Irreps(
-                [
-                    (num_features, (l, p))
-                    for p in ((1, -1) if parity else (1,))
-                    for l in range(lmax + 1)
-                ]
-            )
-        )
-        update["conv_to_output_hidden_irreps_out"] = repr(
-            # num_features // 2 scalars
-            o3.Irreps([(max(1, num_features // 2), (0, 1))])
-        )
-
-        # check update is consistent with config
-        # (this is necessary since it is not possible
-        # to delete keys from config, so instead of
-        # making simple and full styles mutually
-        # exclusive, we just insist that if full
-        # and simple are provided, full must be
-        # consistent with simple)
-        for k, v in update.items():
-            if k in config:
-                assert (
-                    config[k] == v
-                ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
-            config[k] = v
+
+    # search for prefix
+    prefixes = set()
+    for key in config.keys():
+        for simple_key in simple_irreps_keys:
+            if key.endswith(simple_key):
+                prefixes.update((key[: -len(simple_key)],))
+        for real_key in real_irreps_keys:
+            if key.endswith(real_key):
+                prefixes.update((key[: -len(real_key)],))
+
+    for prefix in prefixes:
+
+        has_simple: bool = any(f"{prefix}{k}" in config for k in simple_irreps_keys)
+        has_full: bool = any(f"{prefix}{k}" in config for k in real_irreps_keys)
+        assert has_simple or has_full
+
+        update = {}
+        if has_simple:
+            # nothing to do if not
+            lmax = config.get(f"{prefix}l_max", config["l_max"])
+            parity = config.get(f"{prefix}parity", config["parity"])
+            num_features = config.get(f"{prefix}num_features", config["num_features"])
+            update[f"{prefix}chemical_embedding_irreps_out"] = repr(
+                o3.Irreps([(num_features, (0, 1))])  # n scalars
+            )
+            update[f"{prefix}irreps_edge_sh"] = repr(
+                o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1)
+            )
+            update[f"{prefix}feature_irreps_hidden"] = repr(
+                o3.Irreps(
+                    [
+                        (num_features, (l, p))
+                        for p in ((1, -1) if parity else (1,))
+                        for l in range(lmax + 1)
+                    ]
+                )
+            )
+            update[f"{prefix}conv_to_output_hidden_irreps_out"] = repr(
+                # num_features // 2 scalars
+                o3.Irreps([(max(1, num_features // 2), (0, 1))])
+            )
+
+            # check update is consistent with config
+            # (this is necessary since it is not possible
+            # to delete keys from config, so instead of
+            # making simple and full styles mutually
+            # exclusive, we just insist that if full
+            # and simple are provided, full must be
+            # consistent with simple)
+            for k, v in update.items():
+                if k in config:
+                    assert (
+                        config[k] == v
+                    ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
+                config[k] = v
 
 
 def EnergyModel(

From 44431ef3e007db70c53080c63b6f960830e6c37e Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 14:44:39 -0500
Subject: [PATCH 091/126] revert back to argument style
---
 nequip/model/_eng.py | 92 ++++++++++++++++++++++----------------------
 1 file changed, 46 insertions(+), 46 deletions(-)

diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py
index f89a7448..cd6b62d3 100644
--- a/nequip/model/_eng.py
+++ b/nequip/model/_eng.py
@@ -19,7 +19,7 @@
 from . import builder_utils
 
 
-def SimpleIrrepsConfig(config):
+def SimpleIrrepsConfig(config, prefix: Optional[str] = None):
     """Builder that pre-processes options to allow "simple" configuration of irreps."""
 
     # We allow some simpler parameters to be provided, but if they are,
@@ -32,61 +32,61 @@ def SimpleIrrepsConfig(config):
         "conv_to_output_hidden_irreps_out",
     ]
 
-    # search for prefix
-    prefixes = set()
-    for key in config.keys():
-        for simple_key in simple_irreps_keys:
-            if key.endswith(simple_key):
-                prefixes.update((key[: -len(simple_key)],))
-        for real_key in real_irreps_keys:
-            if key.endswith(real_key):
-                prefixes.update((key[: -len(real_key)],))
-
-    for prefix in prefixes:
-
-        has_simple: bool = any(f"{prefix}{k}" in config for k in simple_irreps_keys)
-        has_full: bool = any(f"{prefix}{k}" in config for k in real_irreps_keys)
-        assert has_simple or has_full
-
-        update = {}
-        if has_simple:
-            # nothing to do if not
-            lmax = config.get(f"{prefix}l_max", config["l_max"])
-            parity = config.get(f"{prefix}parity", config["parity"])
-            num_features = config.get(f"{prefix}num_features", config["num_features"])
-            update[f"{prefix}chemical_embedding_irreps_out"] = repr(
-                o3.Irreps([(num_features, (0, 1))])  # n scalars
-            )
-            update[f"{prefix}irreps_edge_sh"] = repr(
-                o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1)
-            )
-            update[f"{prefix}feature_irreps_hidden"] = repr(
-                o3.Irreps(
-                    [
-                        (num_features, (l, p))
-                        for p in ((1, -1) if parity else (1,))
-                        for l in range(lmax + 1)
-                    ]
-                )
-            )
-            update[f"{prefix}conv_to_output_hidden_irreps_out"] = repr(
-                # num_features // 2 scalars
-                o3.Irreps([(max(1, num_features // 2), (0, 1))])
-            )
-
-            # check update is consistent with config
-            # (this is necessary since it is not possible
-            # to delete keys from config, so instead of
-            # making simple and full styles mutually
-            # exclusive, we just insist that if full
-            # and simple are provided, full must be
-            # consistent with simple)
-            for k, v in update.items():
-                if k in config:
-                    assert (
-                        config[k] == v
-                    ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
-                config[k] = v
+    prefix = "" if prefix is None else f"{prefix}_"
+
+    has_simple: bool = any(f"{prefix}{k}" in config for k in simple_irreps_keys)
+    has_full: bool = any(f"{prefix}{k}" in config for k in real_irreps_keys)
+    if prefix == "":
+        assert has_simple or has_full
+
+    update = {}
+    if has_simple:
+        # nothing to do if not
+        lmax = config.get(f"{prefix}l_max", config["l_max"])
+        parity = config.get(f"{prefix}parity", config["parity"])
+        num_features = config.get(f"{prefix}num_features", config["num_features"])
+        update[f"{prefix}chemical_embedding_irreps_out"] = repr(
+            o3.Irreps([(num_features, (0, 1))])  # n scalars
+        )
+        update[f"{prefix}irreps_edge_sh"] = repr(
+            o3.Irreps.spherical_harmonics(lmax=lmax, p=-1 if parity else 1)
+        )
+        update[f"{prefix}feature_irreps_hidden"] = repr(
+            o3.Irreps(
+                [
+                    (num_features, (l, p))
+                    for p in ((1, -1) if parity else (1,))
+                    for l in range(lmax + 1)
+                ]
+            )
+        )
+        update[f"{prefix}conv_to_output_hidden_irreps_out"] = repr(
+            # num_features // 2 scalars
+            o3.Irreps([(max(1, num_features // 2), (0, 1))])
+        )
+
+    # check update is consistent with config
+    # (this is necessary since it is not possible
+    # to delete keys from config, so instead of
+    # making simple and full styles mutually
+    # exclusive, we just insist that if full
+    # and simple are provided, full must be
+    # consistent with simple)
+    if len(prefix) == "":
+        for k, v in update.items():
+            if k in config:
+                assert (
+                    config[k] == v
+                ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
+            config[k] = v
+    else:
+        for k, v in update.items():
+            if k in config and config[k] != v:
+                Warning(
+                    f"For key {k}, the simple irreps options {v} is overridden by full irreps `{config[k]}`"
+                )
+            else:
+                config[k] = v

From 63bdad6193e6eac528c6d5dc10275b2a9a469a10 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 14:50:47 -0500
Subject: [PATCH 092/126] remove redundant condition

---
 nequip/model/_eng.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py
index cd6b62d3..a426ee4a 100644
--- a/nequip/model/_eng.py
+++ b/nequip/model/_eng.py
@@ -36,8 +36,7 @@ def SimpleIrrepsConfig(config, prefix: Optional[str] = None):
 
     has_simple: bool = any(f"{prefix}{k}" in config for k in simple_irreps_keys)
     has_full: bool = any(f"{prefix}{k}" in config for k in real_irreps_keys)
-    if prefix == "":
-        assert has_simple or has_full
+    assert has_simple or has_full
 
     update = {}
     if has_simple:

From 515b5eaa19213794ef4e9aa4f82860c8b3a6ccab Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 14:56:13 -0500
Subject: [PATCH 093/126] remove separate conditioning for w/w.o prefix

---
 nequip/model/_eng.py | 27 +++++++++++-----------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py
index a426ee4a..034add54 100644
--- a/nequip/model/_eng.py
+++ b/nequip/model/_eng.py
@@ -41,9 +41,9 @@ def SimpleIrrepsConfig(config, prefix: Optional[str] = None):
     update = {}
     if has_simple:
         # nothing to do if not
-        lmax = config.get(f"{prefix}l_max", config["l_max"])
-        parity = config.get(f"{prefix}parity", config["parity"])
-        num_features = config.get(f"{prefix}num_features", config["num_features"])
+        lmax = config[f"{prefix}l_max"]
+        parity = config[f"{prefix}parity"]
+        num_features = config[f"{prefix}num_features"]
         update[f"{prefix}chemical_embedding_irreps_out"] = repr(
             o3.Irreps([(num_features, (0, 1))])  # n scalars
         )
@@ -71,21 +71,12 @@ def SimpleIrrepsConfig(config, prefix: Optional[str] = None):
     # exclusive, we just insist that if full
    # and simple are provided, full must be
     # consistent with simple)
-    if len(prefix) == "":
-        for k, v in update.items():
-            if k in config:
-                assert (
-                    config[k] == v
-                ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
-            config[k] = v
-    else:
-        for k, v in update.items():
-            if k in config and config[k] != v:
-                Warning(
-                    f"For key {k}, the simple irreps options {v} is overridden by full irreps `{config[k]}`"
-                )
-            else:
-                config[k] = v
+    for k, v in update.items():
+        if k in config:
+            assert (
+                config[k] == v
+            ), f"For key {k}, the full irreps options had value `{config[k]}` inconsistent with the value derived from the simple irreps options `{v}`"
+        config[k] = v

From 6b70ca338fccd335cb31c60ac7c63deec5ae334d Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 16:26:27 -0500
Subject: [PATCH 094/126] allow default in prefix mode

---
 nequip/model/_eng.py | 14 +++++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py
index 034add54..421e9112 100644
--- a/nequip/model/_eng.py
+++ b/nequip/model/_eng.py
@@ -34,16 +34,20 @@ def
From 135ab66a71c470e719d6e519ea8209d26e7f3a11 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jan 2022 16:36:41 -0500
Subject: [PATCH 095/126] fix small cutoffs bug

---
 CHANGELOG.md                  | 1 +
 nequip/model/builder_utils.py | 8 +++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e898525f..3b0e7875 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ Most recent change on the bottom.
 - Handle one of `per_species_shifts`/`scales` being `null` when the other is a dataset statistic
 - `include_frames` now works with ASE datasets
 - no training data labels in input_data
+- Average number of neighbors no longer crashes sometimes when not all nodes have neighbors (small cutoffs)

 ### Removed
 - `compile_model`
diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py
index f2a402f9..bc78c510 100644
--- a/nequip/model/builder_utils.py
+++ b/nequip/model/builder_utils.py
@@ -24,9 +24,11 @@ def add_avg_num_neighbors(
         ann = dataset.statistics(
             fields=[
                 lambda data: (
-                    torch.unique(
-                        data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True
-                    )[1],
+                    torch.bincount(
+                        data[AtomicDataDict.EDGE_INDEX_KEY][0],
+                        # make sure we have the right number of counts even if some nodes have no neighbors
+                        minlength=len(data[AtomicDataDict.POSITIONS_KEY]),
+                    ),
                     "node",
                 )
             ],

From 8d7c44287e40b72b729bb5558c923a48a7e23ffb Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jan 2022 17:54:26 -0500
Subject: [PATCH 096/126] organize keys

---
 nequip/data/_keys.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py
index a4fbcb88..fa0cee3c 100644
--- a/nequip/data/_keys.py
+++ b/nequip/data/_keys.py
@@ -13,11 +13,29 @@
 # == Define allowed keys as constants ==
 # The positions of the atoms in the system
 POSITIONS_KEY: Final[str] = "pos"
-
 # The [2, n_edge] index tensor giving center -> neighbor relations
 EDGE_INDEX_KEY: Final[str] = "edge_index"
 # A [n_edge, 3] tensor of how many periodic cells each edge crosses in each cell vector
 EDGE_CELL_SHIFT_KEY: Final[str] = "edge_cell_shift"
+# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors
+CELL_KEY: Final[str] = "cell"
+# [n_batch, 3] bool tensor
+PBC_KEY: Final[str] = "pbc"
+# [n_atom, 1] long tensor
+ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
+# [n_atom, 1] long tensor
+ATOM_TYPE_KEY: Final[str] = "atom_types"
+
+BASIC_STRUCTURE_KEYS: Final[List[str]] = [
+    POSITIONS_KEY,
+    EDGE_INDEX_KEY,
+    EDGE_CELL_SHIFT_KEY,
+    CELL_KEY,
+    PBC_KEY,
+    ATOM_TYPE_KEY,
+    ATOMIC_NUMBERS_KEY,
+]
+
 # A [n_edge, 3] tensor of displacement vectors associated to edges
 EDGE_VECTORS_KEY: Final[str] = "edge_vectors"
 # A [n_edge] tensor of the lengths of EDGE_VECTORS
 EDGE_LENGTH_KEY: Final[str] = "edge_lengths"
@@ -27,13 +45,8 @@
 # [n_edge, dim] invariant embedding of the edges
 EDGE_EMBEDDING_KEY: Final[str] = "edge_embedding"

-CELL_KEY: Final[str] = "cell"
-PBC_KEY: Final[str] = "pbc"
-
 NODE_FEATURES_KEY: Final[str] = "node_features"
 NODE_ATTRS_KEY: Final[str] = "node_attrs"
-ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
-ATOM_TYPE_KEY: Final[str] = "atom_types"

 PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy"
 TOTAL_ENERGY_KEY: Final[str] = "total_energy"
From cdd0600487bd080139d6523593ada3c2adb88899 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jan 2022 20:20:51 -0500
Subject: [PATCH 097/126] to_ase extra_fields

---
 CHANGELOG.md              |  1 +
 nequip/data/AtomicData.py | 46 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3b0e7875..fc04a517 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ Most recent change on the bottom.
 - Allow polynomial cutoff _p_ values besides 6.0
 - `nequip-evaluate` now sets a default `r_max` taken from the model for the dataset config
 - Support multiple rescale layers in trainer
+- `AtomicData.to_ase` supports arbitrary fields

 ### Changed
 - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features`
diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py
index bc40f873..a3744b44 100644
--- a/nequip/data/AtomicData.py
+++ b/nequip/data/AtomicData.py
@@ -426,7 +426,11 @@ def from_ase(
             **add_fields,
         )

-    def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]:
+    def to_ase(
+        self,
+        type_mapper=None,
+        extra_fields: List[str] = [],
+    ) -> Union[List[ase.Atoms], ase.Atoms]:
         """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object.

         For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``,
@@ -436,12 +440,18 @@ def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]:
         Args:
             type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into
                 elements, if the configuration of the ``type_mapper`` allows.
+            extra_fields: fields other than those handled explicitly (currently
+                those defining the structure as well as energy, per-atom energy,
+                and forces) to include in the output object. Per-atom (per-node)
+                quantities will be included in ``arrays``; per-graph and per-edge
+                quantities will be included in ``info``.

         Returns:
             A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self
             and is not None. Otherwise, a single ``ase.Atoms`` object is returned.
         """
         positions = self.pos
+        edge_index = self[AtomicDataDict.EDGE_INDEX_KEY]
         if positions.device != torch.device("cpu"):
             raise TypeError(
                 "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`."
@@ -463,6 +473,21 @@ def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]:
         force = getattr(self, AtomicDataDict.FORCE_KEY, None)
         do_calc = energy is not None or force is not None

+        assert (
+            len(
+                set(extra_fields).intersection(
+                    [  # exclude those that are special for ASE and that we process separately
+                        AtomicDataDict.POSITIONS_KEY,
+                        AtomicDataDict.CELL_KEY,
+                        AtomicDataDict.PBC_KEY,
+                        AtomicDataDict.ATOMIC_NUMBERS_KEY,
+                    ]
+                    + AtomicDataDict.ALL_ENERGY_KEYS
+                )
+            )
+            == 0
+        ), "Cannot specify typical keys as `extra_fields` for atoms output"
+
         if cell is not None:
             cell = cell.view(-1, 3, 3)
         if pbc is not None:
@@ -480,8 +505,11 @@ def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]:
         if batch is not None:
             mask = batch == batch_idx
             mask = mask.view(-1)
+            # if both ends of the edge are in the batch, the edge is in the batch
+            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
         else:
             mask = slice(None)
+            edge_mask = slice(None)

         mol = ase.Atoms(
             numbers=atomic_nums[mask].view(-1),  # must be flat for ASE
@@ -500,6 +528,22 @@ def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]:
                 fields["forces"] = force[mask].cpu().numpy()
             mol.calc = SinglePointCalculator(mol, **fields)

+            # add other information
+            for key in extra_fields:
+                if key in _NODE_FIELDS:
+                    # mask it
+                    mol.arrays[key] = self[key][mask].cpu().numpy()
+                elif key in _EDGE_FIELDS:
+                    mol.info[key] = self[key][edge_mask].cpu().numpy()
+                elif key == AtomicDataDict.EDGE_INDEX_KEY:
+                    mol.info[key] = self[key][:, edge_mask].cpu().numpy()
+                elif key in _GRAPH_FIELDS:
+                    mol.info[key] = self[key][batch_idx].cpu().numpy()
+                else:
+                    raise RuntimeError(
+                        f"Extra field `{key}` isn't registered as node/edge/graph"
+                    )
+
         batch_atoms.append(mol)

     if batch is not None:

From bf29cd58f7f5976706d121d7c11320690628c34a Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 25 Jan 2022 20:25:11 -0500
Subject: [PATCH 098/126] nequip-evaluate output arb fields

---
 CHANGELOG.md               |  1 +
 nequip/scripts/evaluate.py | 27 +++++++++++++++++++++------
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc04a517..228874b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ Most recent change on the bottom.
 - `nequip-evaluate` now sets a default `r_max` taken from the model for the dataset config
 - Support multiple rescale layers in trainer
 - `AtomicData.to_ase` supports arbitrary fields
+- `nequip-evaluate` can now output arbitrary fields to an XYZ file

 ### Changed
 - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features`
diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index 92a0454c..c5babaa9 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import sys
 import argparse
 import logging
@@ -83,10 +84,16 @@ def main(args=None, running_as_script: bool = True):
     )
     parser.add_argument(
         "--output",
-        help="XYZ file to write out the test set and model predicted forces, energies, etc. to.",
+        help="ExtXYZ (.xyz) file to write out the test set and model predictions to.",
         type=Path,
         default=None,
     )
+    parser.add_argument(
+        "--output-fields",
+        help="Extra fields to write to the `--output`.",
+        type=str,
+        default="",
+    )
     parser.add_argument(
         "--log",
         help="log file to store all the metrics and screen logging.debug",
@@ -135,9 +142,13 @@ def main(args=None, running_as_script: bool = True):
         )
         if args.model is None:
             raise ValueError("--model or --train-dir must be provided")
+    output_type: Optional[str] = None
     if args.output is not None:
         if args.output.suffix != ".xyz":
-            raise ValueError("Only extxyz format for `--output` is supported.")
+            raise ValueError("Only .xyz format for `--output` is supported.")
+        args.output_fields = args.output_fields.split(",")
+        assert len(args.output_fields) > 0
+        output_type = "xyz"

     if args.device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -320,7 +331,7 @@ def main(args=None, running_as_script: bool = True):
             )
         )

-        if args.output is not None:
+        if output_type is not None:
             output = context_stack.enter_context(open(args.output, "w"))
         else:
             output = None
@@ -338,16 +349,20 @@ def main(args=None, running_as_script: bool = True):
         with torch.no_grad():
             # Write output
-            # TODO: make sure don't keep appending to existing file
-            if output is not None:
+            if output_type == "xyz":
+                # append to the file
                 ase.io.write(
                     output,
                     AtomicData.from_AtomicDataDict(out)
                     .to(device="cpu")
-                    .to_ase(type_mapper=dataset.type_mapper),
+                    .to_ase(
+                        type_mapper=dataset.type_mapper,
+                        extra_fields=args.output_fields,
+                    ),
                     format="extxyz",
                     append=True,
                 )
+
             # Accumulate metrics
             if do_metrics:
                 metrics(out, batch)
From 9f33aab7eb6e8e8822417daf439f2400af2bb1ab Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 21:16:51 -0500
Subject: [PATCH 099/126] fix bug of wrong variable name. simplify condition

---
 nequip/train/trainer.py | 33 ++++++++++++++-------------------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index 3ed6fb92..41b792fe 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -816,7 +816,7 @@ def batch_step(self, data, validation=False):
                 # this will normalize the targets
                 # in validation (eval mode), it does nothing
                 # in train mode, it normalizes the targets
-                data_unscaled = layer.unscale(data)
+                data_unscaled = layer.unscale(data_unscaled)

         # Run model
         # We make a shallow copy of the input dict in case the model modifies it
@@ -856,24 +856,19 @@ def batch_step(self, data, validation=False):
             self.lr_sched.step(self.iepoch + self.ibatch / self.n_batches)

         with torch.no_grad():
-            if len(self.rescale_layers) > 0:
-                if validation:
-                    scaled_out = out
-                    _data_unscaled = data
-                    for layer in self.rescale_layers:
-                        # loss function always needs to be in normalized unit
-                        scaled_out = layer.unscale(scaled_out, force_process=True)
-                        _data_unscaled = layer.unscale(
-                            _data_unscaled, force_process=True
-                        )
-                    loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
-                else:
-                    # If we are in training mode, we need to bring the prediction
-                    # into real units
-                    for layer in self.rescale_layers[::-1]:
-                        out = layer.scale(out, force_process=True)
-            elif validation:
-                loss, loss_contrib = self.loss(pred=out, ref=data_unscaled)
+            if validation:
+                scaled_out = out
+                _data_unscaled = data
+                for layer in self.rescale_layers:
+                    # loss function always needs to be in normalized unit
+                    scaled_out = layer.unscale(scaled_out, force_process=True)
+                    _data_unscaled = layer.unscale(_data_unscaled, force_process=True)
+                loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
+            else:
+                # If we are in training mode, we need to bring the prediction
+                # into real units
+                for layer in self.rescale_layers[::-1]:
+                    out = layer.scale(out, force_process=True)

         # save metrics stats
         self.batch_losses = self.loss_stat(loss, loss_contrib)

From 9591f1f20f021fa44acaee38251414294e64a849 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Tue, 25 Jan 2022 21:24:21 -0500
Subject: [PATCH 100/126] reverse simplification

---
 nequip/train/trainer.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index 41b792fe..8c8bec55 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -856,19 +856,24 @@ def batch_step(self, data, validation=False):
             self.lr_sched.step(self.iepoch + self.ibatch / self.n_batches)

         with torch.no_grad():
-            if validation:
-                scaled_out = out
-                _data_unscaled = data
-                for layer in self.rescale_layers:
-                    # loss function always needs to be in normalized unit
-                    scaled_out = layer.unscale(scaled_out, force_process=True)
-                    _data_unscaled = layer.unscale(_data_unscaled, force_process=True)
-                loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
-            else:
-                # If we are in training mode, we need to bring the prediction
-                # into real units
-                for layer in self.rescale_layers[::-1]:
-                    out = layer.scale(out, force_process=True)
+            if len(self.rescale_layers) > 0:
+                if validation:
+                    scaled_out = out
+                    _data_unscaled = data
+                    for layer in self.rescale_layers:
+                        # loss function always needs to be in normalized unit
+                        scaled_out = layer.unscale(scaled_out, force_process=True)
+                        _data_unscaled = layer.unscale(
+                            _data_unscaled, force_process=True
+                        )
+                    loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
+                else:
+                    # If we are in training mode, we need to bring the prediction
+                    # into real units
+                    for layer in self.rescale_layers[::-1]:
+                        out = layer.scale(out, force_process=True)
+            elif validation:
+                loss, loss_contrib = self.loss(pred=out, ref=data_unscaled)

         # save metrics stats
         self.batch_losses = self.loss_stat(loss, loss_contrib)

From 757a5ef0fc19eb85c98059df7b5979d007d0d0d6 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jan 2022 16:07:57 -0500
Subject: [PATCH 101/126] better message

---
 nequip/scripts/evaluate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index c5babaa9..e292aaf1 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -90,7 +90,7 @@ def main(args=None, running_as_script: bool = True):
     )
     parser.add_argument(
         "--output-fields",
-        help="Extra fields to write to the `--output`.",
+        help="Extra fields (names comma separated with no spaces) to write to the `--output`.",
         type=str,
         default="",
     )

From 939769a79c4d1ac49f5d9f790ea490e41d47104c Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jan 2022 16:26:55 -0500
Subject: [PATCH 102/126] register fields in _set_global_options

---
 CHANGELOG.md            | 1 +
 nequip/data/_build.py   | 1 +
 nequip/scripts/train.py | 7 +++++--
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 228874b7..1317af33 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,7 @@ Most recent change on the bottom.
 - `include_frames` now works with ASE datasets
 - no training data labels in input_data
 - Average number of neighbors no longer crashes sometimes when not all nodes have neighbors (small cutoffs)
+- Handle field registrations correctly in `nequip-evaluate`

 ### Removed
 - `compile_model`
diff --git a/nequip/data/_build.py b/nequip/data/_build.py
index 7645bcb7..8757198f 100644
--- a/nequip/data/_build.py
+++ b/nequip/data/_build.py
@@ -72,6 +72,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
         type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config)

     # Register fields:
+    # This might reregister fields, but that's OK:
     instantiate(register_fields, all_args=config)

     instance, _ = instantiate(
diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py
index 130fd18f..108d39f7 100644
--- a/nequip/scripts/train.py
+++ b/nequip/scripts/train.py
@@ -16,8 +16,8 @@
 import e3nn.util.jit

 from nequip.model import model_from_config
-from nequip.utils import Config
-from nequip.data import dataset_from_config
+from nequip.utils import Config, instantiate
+from nequip.data import dataset_from_config, register_fields
 from nequip.utils.test import assert_AtomicData_equivariant, set_irreps_debug
 from nequip.utils import load_file, dtype_from_name
 from ._logger import set_up_script_logger
@@ -130,6 +130,9 @@ def _set_global_options(config):

     e3nn.set_optimization_defaults(**config.get("e3nn_optimization_defaults", {}))

+    # Register fields:
+    instantiate(register_fields, all_args=config)
+

 def fresh_start(config):

From 9d44022a69f97b9ea236ca5fddc72518e566b826 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jan 2022 16:27:24 -0500
Subject: [PATCH 103/126] make inconsistent r_max an error

---
 nequip/scripts/evaluate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index e292aaf1..fccb13cc 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -204,7 +204,7 @@ def main(args=None, running_as_script: bool = True):
             str(args.dataset_config), defaults={"r_max": model_r_max}
         )
         if dataset_config["r_max"] != model_r_max:
-            logger.warn(
+            raise RuntimeError(
                 f"Dataset config has r_max={dataset_config['r_max']}, but model has r_max={model_r_max}!"
             )

From 7732bd75ce63e711e31a4d476b12ed8af8cd9c28 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 26 Jan 2022 16:43:19 -0500
Subject: [PATCH 104/126] fix bug

---
 nequip/scripts/evaluate.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index fccb13cc..3d2fec45 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -146,9 +146,11 @@ def main(args=None, running_as_script: bool = True):
     if args.output is not None:
         if args.output.suffix != ".xyz":
             raise ValueError("Only .xyz format for `--output` is supported.")
-        args.output_fields = args.output_fields.split(",")
-        assert len(args.output_fields) > 0
+        args.output_fields = [e for e in args.output_fields.split(",") if e != ""]
         output_type = "xyz"
+    else:
+        assert args.output_fields == ""
+        args.output_fields = []

     if args.device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

From 4ecb6555bba065f723b755185da804986dda0149 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 26 Jan 2022 16:58:03 -0500
Subject: [PATCH 105/126] change import so the script can be run from anywhere

---
 nequip/scripts/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py
index 130fd18f..1cd7afbe 100644
--- a/nequip/scripts/train.py
+++ b/nequip/scripts/train.py
@@ -20,7 +20,7 @@
 from nequip.data import dataset_from_config
 from nequip.utils.test import assert_AtomicData_equivariant, set_irreps_debug
 from nequip.utils import load_file, dtype_from_name
-from ._logger import set_up_script_logger
+from nequip.scripts._logger import set_up_script_logger

 default_config = dict(
     root="./",
This makes the error messages more # comprehensible: if not loaded_deployed_model: + # Use the model config, regardless of dataset config + global_config = args.model.parent / "config.yaml" + global_config = Config.from_file(str(global_config), defaults=default_config) + _set_global_options(global_config) + del global_config # load a training session model model, model_config = Trainer.load_model_from_training_session( traindir=args.model.parent, model_name=args.model.name ) - model_from_training = True model = model.to(device) logger.info("loaded model from training session") model_r_max = model_config["r_max"] @@ -210,19 +216,6 @@ def main(args=None, running_as_script: bool = True): f"Dataset config has r_max={dataset_config['r_max']}, but model has r_max={model_r_max}!" ) - # set global options - if model_from_training: - # Use the model config, regardless of dataset config - global_config = args.model.parent / "config.yaml" - global_config = Config.from_file(str(global_config), defaults=default_config) - _set_global_options(global_config) - del global_config - else: - # the global settings for a deployed model are set by - # set_global_options in the call to load_deployed_model - # above - pass - dataset_is_validation: bool = False # Currently, pytorch_geometric prints some status messages to stdout while loading the dataset # TODO: fix may come soon: https://github.com/rusty1s/pytorch_geometric/pull/2950 From 85bb0ab1712b7265c23429bae2498c97911d6e43 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 26 Jan 2022 17:32:32 -0500 Subject: [PATCH 107/126] don't validate from_AtomicDataDict --- nequip/data/AtomicData.py | 81 +++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index a3744b44..6c017d54 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -206,7 +206,9 @@ class AtomicData(Data): **kwargs: other data, optional. 
""" - def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): + def __init__( + self, irreps: Dict[str, e3nn.o3.Irreps] = {}, _validate: bool = True, **kwargs + ): # empty init needed by get_example if len(kwargs) == 0 and len(irreps) == 0: @@ -214,47 +216,49 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): return # Check the keys - AtomicDataDict.validate_keys(kwargs) - _process_dict(kwargs) + if _validate: + AtomicDataDict.validate_keys(kwargs) + _process_dict(kwargs) super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) - # Validate shapes - assert self.pos.dim() == 2 and self.pos.shape[1] == 3 - assert self.edge_index.dim() == 2 and self.edge_index.shape[0] == 2 - if "edge_cell_shift" in self and self.edge_cell_shift is not None: - assert self.edge_cell_shift.shape == (self.num_edges, 3) - assert self.edge_cell_shift.dtype == self.pos.dtype - if "cell" in self and self.cell is not None: - assert (self.cell.shape == (3, 3)) or ( - self.cell.dim() == 3 and self.cell.shape[1:] == (3, 3) - ) - assert self.cell.dtype == self.pos.dtype - if "node_features" in self and self.node_features is not None: - assert self.node_features.shape[0] == self.num_nodes - assert self.node_features.dtype == self.pos.dtype - if "node_attrs" in self and self.node_attrs is not None: - assert self.node_attrs.shape[0] == self.num_nodes - assert self.node_attrs.dtype == self.pos.dtype - - if ( - AtomicDataDict.ATOMIC_NUMBERS_KEY in self - and self.atomic_numbers is not None - ): - assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES - if "batch" in self and self.batch is not None: - assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes - # Check that there are the right number of cells + if _validate: + # Validate shapes + assert self.pos.dim() == 2 and self.pos.shape[1] == 3 + assert self.edge_index.dim() == 2 and self.edge_index.shape[0] == 2 + if "edge_cell_shift" in self and self.edge_cell_shift is not None: + assert self.edge_cell_shift.shape == (self.num_edges, 3) + assert self.edge_cell_shift.dtype == self.pos.dtype if "cell" in self and self.cell is not None: - cell = self.cell.view(-1, 3, 3) - assert cell.shape[0] == self.batch.max() + 1 + assert (self.cell.shape == (3, 3)) or ( + self.cell.dim() == 3 and self.cell.shape[1:] == (3, 3) + ) + assert self.cell.dtype == self.pos.dtype + if "node_features" in self and self.node_features is not None: + assert self.node_features.shape[0] == self.num_nodes + assert self.node_features.dtype == self.pos.dtype + if "node_attrs" in self and self.node_attrs is not None: + assert self.node_attrs.shape[0] == self.num_nodes + assert self.node_attrs.dtype == self.pos.dtype - # Validate irreps - # __*__ is the only way to hide from torch_geometric - self.__irreps__ = AtomicDataDict._fix_irreps_dict(irreps) - for field, irreps in self.__irreps__: - if irreps is not None: - assert self[field].shape[-1] == irreps.dim + if ( + AtomicDataDict.ATOMIC_NUMBERS_KEY in self + and self.atomic_numbers is not None + ): + assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES + if "batch" in self and self.batch is not None: + assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes + # Check that there are the right number of cells + if "cell" in self and self.cell is not None: + cell = self.cell.view(-1, 3, 3) + assert cell.shape[0] == self.batch.max() + 1 + + # Validate irreps + # __*__ is the only way to hide from torch_geometric + self.__irreps__ = AtomicDataDict._fix_irreps_dict(irreps) + for field, 
irreps in self.__irreps__: + if irreps is not None: + assert self[field].shape[-1] == irreps.dim @classmethod def from_points( @@ -579,7 +583,8 @@ def to_AtomicDataDict( @classmethod def from_AtomicDataDict(cls, data: AtomicDataDict.Type): - return cls(**data) + # it's an AtomicDataDict, so don't validate-- assume valid: + return cls(_validate=False, **data) @property def irreps(self): From b6185afc2f1b23824265890e21c09a882df2f70f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 26 Jan 2022 17:39:06 -0500 Subject: [PATCH 108/126] flatten fields for ASE --- nequip/data/AtomicData.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 6c017d54..bb23923e 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -536,13 +536,17 @@ def to_ase( for key in extra_fields: if key in _NODE_FIELDS: # mask it - mol.arrays[key] = self[key][mask].cpu().numpy() + mol.arrays[key] = ( + self[key][mask].cpu().numpy().reshape(mask.sum(), -1) + ) elif key in _EDGE_FIELDS: - mol.info[key] = self[key][edge_mask].cpu().numpy() + mol.info[key] = ( + self[key][edge_mask].cpu().numpy().reshape(edge_mask.sum(), -1) + ) elif key == AtomicDataDict.EDGE_INDEX_KEY: mol.info[key] = self[key][:, edge_mask].cpu().numpy() elif key in _GRAPH_FIELDS: - mol.info[key] = self[key][batch_idx].cpu().numpy() + mol.info[key] = self[key][batch_idx].cpu().numpy().reshape(-1) else: raise RuntimeError( f"Extra field `{key}` isn't registered as node/edge/graph" From c4d4dd3c7cb417650a3f9390f2c8227f6047861d Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 27 Jan 2022 10:31:55 -0500 Subject: [PATCH 109/126] update deploy global setup --- nequip/scripts/deploy.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 8bbea037..f0fd9d21 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -21,7 +21,9 @@ from e3nn.util.jit import script +from nequip.scripts.train import _set_global_options from nequip.train import Trainer +from nequip.utils import Config CONFIG_KEY: Final[str] = "config" NEQUIP_VERSION_KEY: Final[str] = "nequip_version" @@ -171,6 +173,11 @@ def main(args=None): raise ValueError( f"{args.out_dir} is a directory, but a path to a file for the deployed model must be given" ) + + # load config + config = Config.from_file(str(args.train_dir / "config.yaml")) + _set_global_options(config) + # -- load model -- model, _ = Trainer.load_model_from_training_session( args.train_dir, model_name="best_model.pth", device="cpu" @@ -180,10 +187,6 @@ def main(args=None): model = _compile_for_deploy(model) logging.info("Compiled & optimized model.") - # load config - config_str = (args.train_dir / "config.yaml").read_text() - config = yaml.load(config_str, Loader=yaml.Loader) - # Deploy metadata: dict = {} for code in ["e3nn", "nequip", "torch"]: From 09c6568b82d861aab4f8be6e35b448e292b2aff3 Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 27 Jan 2022 11:40:33 -0500 Subject: [PATCH 110/126] add config_str back --- nequip/scripts/deploy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index f0fd9d21..745ab9c8 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -209,7 +209,7 @@ def main(args=None): metadata[JIT_BAILOUT_KEY] = str(config["_jit_bailout_depth"]) metadata[TF32_KEY] = 
str(int(config["allow_tf32"])) - metadata[CONFIG_KEY] = config_str + metadata[CONFIG_KEY] = (args.train_dir / "config.yaml").read_text() metadata = {k: v.encode("ascii") for k, v in metadata.items()} torch.jit.save(model, args.out_file, _extra_files=metadata) From 673d19d87eff985227f3a72b5ced8cf7d500c9ed Mon Sep 17 00:00:00 2001 From: nw13slx Date: Thu, 27 Jan 2022 12:16:34 -0500 Subject: [PATCH 111/126] remove unused import --- nequip/scripts/deploy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 745ab9c8..d9c91ce0 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -9,7 +9,6 @@ import pathlib import logging import warnings -import yaml # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch. # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. From 43f4ef514a1b477f76d3709b73433a29f33392a5 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 31 Jan 2022 17:09:26 -0500 Subject: [PATCH 112/126] fix avg_num_neighbors auto with subdatasets and small cutoffs --- nequip/model/builder_utils.py | 25 +++++++++++++--------- tests/unit/model/test_builder_utils.py | 29 +++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py index bc78c510..5c0ec04a 100644 --- a/nequip/model/builder_utils.py +++ b/nequip/model/builder_utils.py @@ -6,6 +6,20 @@ from nequip.data import AtomicDataset, AtomicDataDict +def _add_avg_num_neighbors_helper(data): + counts = torch.unique( + data[AtomicDataDict.EDGE_INDEX_KEY][0], + sorted=True, + return_counts=True, + )[1] + # in case the cutoff is small and some nodes have no neighbors, + # we need to pad `counts` up to the right length + counts = torch.nn.functional.pad( + counts, pad=(0, len(data[AtomicDataDict.POSITIONS_KEY]) - len(counts)) + ) + return (counts, "node") + + def add_avg_num_neighbors( config: Config, initialize: bool, @@ -22,16 +36,7 @@ def add_avg_num_neighbors( "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" ) ann = dataset.statistics( - fields=[ - lambda data: ( - torch.bincount( - data[AtomicDataDict.EDGE_INDEX_KEY][0], - # make sure we have the right number of counts even if some nodes have no neighbors - minlength=len(data[AtomicDataDict.POSITIONS_KEY]), - ), - "node", - ) - ], + fields=[_add_avg_num_neighbors_helper], modes=["mean_std"], stride=config.get("dataset_statistics_stride", 1), )[0][0].item() diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py index dcb45d4d..4cfc6002 100644 --- a/tests/unit/model/test_builder_utils.py +++ b/tests/unit/model/test_builder_utils.py @@ -1,13 +1,33 @@ import pytest +import tempfile import torch -from nequip.data import AtomicDataDict +import ase.io + +from nequip.data import AtomicDataDict, ASEDataset +from nequip.data.transforms import TypeMapper from nequip.model.builder_utils import add_avg_num_neighbors -def test_avg_num_neighbors(nequip_dataset): +@pytest.mark.parametrize("r_max", [3.0, 2.0, 1.1]) +def test_avg_num_neighbors(molecules, temp_data, r_max): + with tempfile.NamedTemporaryFile(suffix=".xyz") as fp: + for atoms in molecules: + # Reverse the atoms so the one without neighbors ends up at the end + # to test the minlength style padding logic + # this is specific 
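For reference, the metadata that these three patches store in the deployed archive can be read back through TorchScript's extra-files mechanism. This is a sketch only, not the supported loader (`load_deployed_model` is the real API, and the file name here is hypothetical); `"config"` is `CONFIG_KEY` from the diff above:

    import torch

    # ask torch.jit.load to fill in the named extra file from the archive
    extra_files = {"config": b""}
    model = torch.jit.load("deployed.pth", _extra_files=extra_files)
    # values were stored ASCII-encoded, per the deploy code above
    config_yaml = extra_files["config"].decode("ascii")
    print(config_yaml)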
From 43f4ef514a1b477f76d3709b73433a29f33392a5 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 17:09:26 -0500
Subject: [PATCH 112/126] fix avg_num_neighbors auto with subdatasets and small cutoffs

---
 nequip/model/builder_utils.py          | 25 +++++++++++++---------
 tests/unit/model/test_builder_utils.py | 29 +++++++++++++++++++++++---
 2 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py
index bc78c510..5c0ec04a 100644
--- a/nequip/model/builder_utils.py
+++ b/nequip/model/builder_utils.py
@@ -6,6 +6,20 @@
 from nequip.data import AtomicDataset, AtomicDataDict

+def _add_avg_num_neighbors_helper(data):
+    counts = torch.unique(
+        data[AtomicDataDict.EDGE_INDEX_KEY][0],
+        sorted=True,
+        return_counts=True,
+    )[1]
+    # in case the cutoff is small and some nodes have no neighbors,
+    # we need to pad `counts` up to the right length
+    counts = torch.nn.functional.pad(
+        counts, pad=(0, len(data[AtomicDataDict.POSITIONS_KEY]) - len(counts))
+    )
+    return (counts, "node")
+
+
 def add_avg_num_neighbors(
     config: Config,
     initialize: bool,
@@ -22,16 +36,7 @@ def add_avg_num_neighbors(
                 "When avg_num_neighbors = auto, the dataset is required to build+initialize a model"
             )
         ann = dataset.statistics(
-            fields=[
-                lambda data: (
-                    torch.bincount(
-                        data[AtomicDataDict.EDGE_INDEX_KEY][0],
-                        # make sure we have the right number of counts even if some nodes have no neighbors
-                        minlength=len(data[AtomicDataDict.POSITIONS_KEY]),
-                    ),
-                    "node",
-                )
-            ],
+            fields=[_add_avg_num_neighbors_helper],
             modes=["mean_std"],
             stride=config.get("dataset_statistics_stride", 1),
         )[0][0].item()
diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py
index dcb45d4d..4cfc6002 100644
--- a/tests/unit/model/test_builder_utils.py
+++ b/tests/unit/model/test_builder_utils.py
@@ -1,13 +1,33 @@
 import pytest
+import tempfile

 import torch

-from nequip.data import AtomicDataDict
+
+import ase.io
+
+from nequip.data import AtomicDataDict, ASEDataset
+from nequip.data.transforms import TypeMapper
 from nequip.model.builder_utils import add_avg_num_neighbors

-def test_avg_num_neighbors(nequip_dataset):
+@pytest.mark.parametrize("r_max", [3.0, 2.0, 1.1])
+def test_avg_num_neighbors(molecules, temp_data, r_max):
+    with tempfile.NamedTemporaryFile(suffix=".xyz") as fp:
+        for atoms in molecules:
+            # Reverse the atoms so the one without neighbors ends up at the end
+            # to test the minlength style padding logic
+            # this is specific to the current contents and ordering of `molecules`!
+            ase.io.write(
+                fp.name, ase.Atoms(list(atoms)[::-1]), format="extxyz", append=True
+            )
+        nequip_dataset = ASEDataset(
+            file_name=fp.name,
+            root=temp_data,
+            extra_fixed_fields={"r_max": r_max},
+            ase_args=dict(format="extxyz"),
+            type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}),
+        )
+
     # test basic options
     annkey = "avg_num_neighbors"
     config = {annkey: 3}
@@ -28,7 +48,10 @@ def test_avg_num_neighbors(molecules, temp_data, r_max):
     for i in range(len(nequip_dataset)):
         frame = nequip_dataset[i]
         num_neigh.append(
-            torch.bincount(frame[AtomicDataDict.EDGE_INDEX_KEY][0]).float()
+            torch.bincount(
+                frame[AtomicDataDict.EDGE_INDEX_KEY][0],
+                minlength=len(frame[AtomicDataDict.POSITIONS_KEY]),
+            ).float()
         )
     avg_num_neighbor_truth = torch.mean(torch.cat(num_neigh, dim=0))

From c6b3d76b44fcffd3aa8ab21c658a916e6dc18d67 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 17:17:18 -0500
Subject: [PATCH 113/126] test subsets too

---
 tests/unit/model/test_builder_utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py
index 4cfc6002..30298b44 100644
--- a/tests/unit/model/test_builder_utils.py
+++ b/tests/unit/model/test_builder_utils.py
@@ -12,7 +12,8 @@

 @pytest.mark.parametrize("r_max", [3.0, 2.0, 1.1])
-def test_avg_num_neighbors(molecules, temp_data, r_max):
+@pytest.mark.parametrize("subset", [False, True])
+def test_avg_num_neighbors(molecules, temp_data, r_max, subset):
     with tempfile.NamedTemporaryFile(suffix=".xyz") as fp:
         for atoms in molecules:
             # Reverse the atoms so the one without neighbors ends up at the end
@@ -28,6 +29,13 @@ def test_avg_num_neighbors(molecules, temp_data, r_max):
             ase_args=dict(format="extxyz"),
             type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}),
         )
+
+    if subset:
+        old_nequip_dataset = nequip_dataset  # noqa
+        nequip_dataset = nequip_dataset.index_select(
+            torch.randperm(len(nequip_dataset))[: len(nequip_dataset) // 2]
+        )
+
     # test basic options
     annkey = "avg_num_neighbors"
     config = {annkey: 3}
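The padding logic exercised by the two patches above can be seen in isolation. A standalone illustration with a hand-made edge list (note that, as the test comment says, the appended zeros implicitly belong to the highest-indexed nodes, which is why the test reverses the atoms; the mean is unaffected by which node the zeros are attributed to):

    import torch

    # three nodes, but node 2 has no edges (e.g. because the cutoff is small)
    edge_centers = torch.tensor([0, 0, 1])
    counts = torch.unique(edge_centers, sorted=True, return_counts=True)[1]
    assert counts.tolist() == [2, 1]  # only two entries for three nodes
    # pad up to the number of nodes so neighborless nodes count as zero
    counts = torch.nn.functional.pad(counts, pad=(0, 3 - len(counts)))
    assert counts.tolist() == [2, 1, 0]
    # the average number of neighbors now reflects the neighborless node
    assert abs(counts.float().mean().item() - 1.0) < 1e-12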
From e0818b36ecb88340cc558f574e47a805e16131aa Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 17:29:22 -0500
Subject: [PATCH 114/126] print total number of frames

---
 nequip/scripts/evaluate.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index a2672ddc..6fed5b2f 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -247,13 +247,13 @@ def main(args=None, running_as_script: bool = True):
         if dataset_is_validation:
             test_idcs = list(all_idcs - val_idcs)
             logger.info(
-                f"Using original validation dataset minus validation set frames, yielding a test set size of {len(test_idcs)} frames.",
+                f"Using original validation dataset ({len(dataset)} frames) minus validation set frames ({len(val_idcs)} frames), yielding a test set size of {len(test_idcs)} frames.",
             )
         else:
             test_idcs = list(all_idcs - train_idcs - val_idcs)
             assert set(test_idcs).isdisjoint(train_idcs)
             logger.info(
-                f"Using original training dataset minus training and validation frames, yielding a test set size of {len(test_idcs)} frames.",
+                f"Using original training dataset ({len(dataset)} frames) minus training ({len(train_idcs)} frames) and validation frames ({len(val_idcs)} frames), yielding a test set size of {len(test_idcs)} frames.",
             )
         # No matter what it should be disjoint from validation:
         assert set(test_idcs).isdisjoint(val_idcs)

From 3cffe6052923eb86810d63b73ca9fac88e5202bd Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 17:44:21 -0500
Subject: [PATCH 115/126] save dataset index in evaluate

---
 CHANGELOG.md               |  1 +
 nequip/scripts/evaluate.py | 20 +++++++++++++++-----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1317af33..ecc4a7b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ Most recent change on the bottom.
 - Support multiple rescale layers in trainer
 - `AtomicData.to_ase` supports arbitrary fields
 - `nequip-evaluate` can now output arbitrary fields to an XYZ file
+- `nequip-evaluate` reports which frame in the original dataset was used as input for each output frame

 ### Changed
 - `minimal.yaml`, `minimal_eng.yaml`, and `example.yaml` now use the simplified irreps options `l_max`, `parity`, and `num_features`
diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py
index 6fed5b2f..a7d65397 100644
--- a/nequip/scripts/evaluate.py
+++ b/nequip/scripts/evaluate.py
@@ -12,7 +12,7 @@
 import torch

 from nequip.utils import Config
-from nequip.data import AtomicData, Collater, dataset_from_config
+from nequip.data import AtomicData, Collater, dataset_from_config, register_fields
 from nequip.train import Trainer
 from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY
 from nequip.scripts.train import default_config, _set_global_options
@@ -22,6 +22,10 @@
 from ._logger import set_up_script_logger

+ORIGINAL_DATASET_INDEX_KEY: str = "original_dataset_index"
+register_fields(graph_fields=[ORIGINAL_DATASET_INDEX_KEY])
+
+
 def main(args=None, running_as_script: bool = True):
     # in results dir, do: nequip-deploy build . deployed.pth
     parser = argparse.ArgumentParser(
@@ -146,7 +150,9 @@ def main(args=None, running_as_script: bool = True):
     if args.output is not None:
         if args.output.suffix != ".xyz":
             raise ValueError("Only .xyz format for `--output` is supported.")
-        args.output_fields = [e for e in args.output_fields.split(",") if e != ""]
+        args.output_fields = [e for e in args.output_fields.split(",") if e != ""] + [
+            ORIGINAL_DATASET_INDEX_KEY
+        ]
         output_type = "xyz"
     else:
         assert args.output_fields == ""
@@ -332,10 +338,10 @@ def main(args=None, running_as_script: bool = True):
             output = None

         while True:
-            datas = [
-                dataset[int(idex)]
-                for idex in test_idcs[batch_i * batch_size : (batch_i + 1) * batch_size]
+            this_batch_test_indexes = test_idcs[
+                batch_i * batch_size : (batch_i + 1) * batch_size
             ]
+            datas = [dataset[int(idex)] for idex in this_batch_test_indexes]
             if len(datas) == 0:
                 break
             batch = c.collate(datas)
@@ -345,6 +351,10 @@ def main(args=None, running_as_script: bool = True):
             with torch.no_grad():
                 # Write output
                 if output_type == "xyz":
+                    # add test frame to the output:
+                    out[ORIGINAL_DATASET_INDEX_KEY] = torch.LongTensor(
+                        this_batch_test_indexes
+                    )
                     # append to the file
                     ase.io.write(
                         output,
                         AtomicData.from_AtomicDataDict(out)
                         .to(device="cpu")
                         .to_ase(
                             type_mapper=dataset.type_mapper,
                             extra_fields=args.output_fields,
                         ),
                         format="extxyz",
                         append=True,
                     )

From 03d8a5b47d8379fd772b52d96abf3976138024ff Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 18:40:28 -0500
Subject: [PATCH 116/126] test --output-fields

---
 tests/integration/test_evaluate.py | 33 +++++++++++++++++++++++++++---
 tests/integration/test_train.py    |  9 ++++++++
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py
index 383dc9c8..7af11a62 100644
--- a/tests/integration/test_evaluate.py
+++ b/tests/integration/test_evaluate.py
@@ -12,6 +12,8 @@
 import torch

+from e3nn import o3
+
 from nequip.data import AtomicDataDict

 from test_train import ConstFactorModel, IdentityModel  # noqa
@@ -76,7 +78,8 @@ def training_session(request, BENCHMARK_ROOT, conffile):

 @pytest.mark.parametrize("do_test_idcs", [True, False])
 @pytest.mark.parametrize("do_metrics", [True, False])
-def test_metrics(training_session, do_test_idcs, do_metrics):
+@pytest.mark.parametrize("do_output_fields", [True, False])
+def test_metrics(training_session, do_test_idcs, do_metrics, do_output_fields):
     builder, true_config, tmpdir, env = training_session

     # == Run test error ==
@@ -152,15 +155,23 @@ def runit(params: dict):
         expect_metrics = {"f_mae", "f_rmse"}
     default_params["metrics-config"] = metrics_yaml

-    # First run
+    if do_output_fields:
+        output_fields = [AtomicDataDict.NODE_FEATURES_KEY]
+        default_params["output-fields"] = ",".join(output_fields)
+    else:
+        output_fields = None
+
+    # -- First run --
     metrics = runit({"train-dir": outdir, "batch-size": 200, "device": "cpu"})
     # move out.xyz to out-orig.xyz
     shutil.move(tmpdir + "/out.xyz", tmpdir + "/out-orig.xyz")
     # Load it
     orig_atoms = ase.io.read(tmpdir + "/out-orig.xyz", index=":", format="extxyz")
+
     # check that we have the metrics
     assert set(metrics.keys()) == expect_metrics

+    # check metrics
     if builder == IdentityModel:
         for metric, err in metrics.items():
             assert np.allclose(err, 0.0), f"Metric `{metric}` wasn't zero!"
@@ -168,7 +179,19 @@ def runit(params: dict):
         # TODO: check comparable to naive numpy compute
         pass

-    # Check insensitive to batch size
+    # check we got output fields
+    if output_fields is not None:
+        for a in orig_atoms:
+            for key in output_fields:
+                if key == AtomicDataDict.NODE_FEATURES_KEY:
+                    assert a.arrays[AtomicDataDict.NODE_FEATURES_KEY].shape == (
+                        len(a),
+                        3,  # THIS IS SPECIFIC TO THE HACK IN ConstFactorModel and friends
+                    )
+                else:
+                    raise RuntimeError
+
+    # -- Check insensitive to batch size --
     for batch_size in (13, 1000):
         metrics2 = runit(
             {
@@ -191,6 +214,10 @@ def runit(params: dict):
             )
             assert np.array_equal(origframe.get_pbc(), newframe.get_pbc())
             assert np.array_equal(origframe.get_cell(), newframe.get_cell())
+            if output_fields is not None:
+                for key in output_fields:
+                    # TODO handle info fields too
+                    assert np.allclose(origframe.arrays[key], newframe.arrays[key])

     # Check GPU
     if torch.cuda.is_available():
diff --git a/tests/integration/test_train.py b/tests/integration/test_train.py
index 1c1dc969..36597a98 100644
--- a/tests/integration/test_train.py
+++ b/tests/integration/test_train.py
@@ -25,6 +25,9 @@ def __init__(self, **kwargs):

     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         data[AtomicDataDict.FORCE_KEY] = self.one * data[AtomicDataDict.FORCE_KEY]
+        data[AtomicDataDict.NODE_FEATURES_KEY] = (
+            0.77 * data[AtomicDataDict.FORCE_KEY].tanh()
+        )  # some BS
         data[AtomicDataDict.TOTAL_ENERGY_KEY] = (
             self.one * data[AtomicDataDict.TOTAL_ENERGY_KEY]
         )
@@ -48,6 +51,9 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         data[AtomicDataDict.FORCE_KEY] = (
             self.factor * data[AtomicDataDict.FORCE_KEY] + 0.0 * self.dummy
         )
+        data[AtomicDataDict.NODE_FEATURES_KEY] = (
+            0.77 * data[AtomicDataDict.FORCE_KEY].tanh()
+        )  # some BS
         data[AtomicDataDict.TOTAL_ENERGY_KEY] = (
             self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY] + 0.0 * self.dummy
         )
@@ -70,6 +76,9 @@ def __init__(self, **kwargs):

     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         data[AtomicDataDict.FORCE_KEY] = self.factor * data[AtomicDataDict.FORCE_KEY]
+        data[AtomicDataDict.NODE_FEATURES_KEY] = (
+            0.77 * data[AtomicDataDict.FORCE_KEY].tanh()
+        )  # some BS
         data[AtomicDataDict.TOTAL_ENERGY_KEY] = (
             self.factor * data[AtomicDataDict.TOTAL_ENERGY_KEY]
         )
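Downstream of the last two patches, the written file can be inspected with plain ASE. A sketch (the file name is hypothetical, as produced by an invocation like the one shown after PATCH 098; per `to_ase`, graph-level fields such as `original_dataset_index` land in `info` while node fields land in `arrays`):

    import ase.io

    frames = ase.io.read("out.xyz", index=":", format="extxyz")
    for atoms in frames:
        which_frame = atoms.info["original_dataset_index"]  # graph field -> info
        node_features = atoms.arrays["node_features"]  # node field -> arrays
        print(which_frame, node_features.shape)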
From a43e3f343f90440dda4ad5307dceb6e587c60f2e Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Mon, 31 Jan 2022 18:42:10 -0500
Subject: [PATCH 117/126] lint

---
 tests/integration/test_evaluate.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py
index 7af11a62..5cd6b1d3 100644
--- a/tests/integration/test_evaluate.py
+++ b/tests/integration/test_evaluate.py
@@ -12,8 +12,6 @@
 import torch

-from e3nn import o3
-
 from nequip.data import AtomicDataDict

 from test_train import ConstFactorModel, IdentityModel  # noqa

From 62996aca6fbb93c85f29e9fff4762218f9ee84a4 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Tue, 1 Feb 2022 23:31:37 -0500
Subject: [PATCH 118/126] correct type annotation

---
 nequip/nn/_atomwise.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py
index e5815fe2..a4a1dba1 100644
--- a/nequip/nn/_atomwise.py
+++ b/nequip/nn/_atomwise.py
@@ -109,8 +109,8 @@ def __init__(
         self,
         field: str,
         num_types: int,
-        shifts: List[float],
-        scales: List[float],
+        shifts: Optional[List[float]],
+        scales: Optional[List[float]],
         arguments_in_dataset_units: bool,
         out_field: Optional[str] = None,
         scales_trainable: bool = False,

From a2d6ff82a3af23773f9b82f8c5fa86ac4b155732 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 2 Feb 2022 09:43:42 -0500
Subject: [PATCH 119/126] sort the graph_selector

---
 nequip/data/dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py
index a1a986dd..b9e01483 100644
--- a/nequip/data/dataset.py
+++ b/nequip/data/dataset.py
@@ -352,7 +352,8 @@ def statistics(
             return []

         if self._indices is not None:
-            graph_selector = torch.as_tensor(self._indices)[::stride]
+            graph_selector = torch.as_tensor(self._indices)[::stride]
+            graph_selector, _ = torch.sort(graph_selector)
         else:
             graph_selector = torch.arange(0, self.len(), stride)
         num_graphs = len(graph_selector)

From 32d6ae6b065e603c489f890c1388a04aa603adc6 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 2 Feb 2022 09:52:36 -0500
Subject: [PATCH 120/126] fix for flake8

---
 nequip/data/dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py
index b9e01483..3c6a0434 100644
--- a/nequip/data/dataset.py
+++ b/nequip/data/dataset.py
@@ -352,8 +352,8 @@ def statistics(
             return []

         if self._indices is not None:
-            graph_selector = torch.as_tensor(self._indices)[::stride]
-            graph_selector, _ = torch.sort(graph_selector)
+            graph_selector = torch.as_tensor(self._indices)[::stride]
+            graph_selector, _ = torch.sort(graph_selector)
         else:
             graph_selector = torch.arange(0, self.len(), stride)
         num_graphs = len(graph_selector)

From 6ef8b01d84620c7121a3d5008faa8b18da34ef08 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 2 Feb 2022 13:02:42 -0500
Subject: [PATCH 121/126] update unit tests

---
 tests/unit/data/test_dataset.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py
index 01b89e2d..20efa5c0 100644
--- a/tests/unit/data/test_dataset.py
+++ b/tests/unit/data/test_dataset.py
@@ -2,7 +2,6 @@
 import pytest
 import tempfile
 import torch
-
 from os.path import isdir, isfile

 from ase.data import chemical_symbols
@@ -192,9 +191,12 @@ def test_edgewise_stats(self, npz_dataset):
 class TestPerSpeciesStatistics:
     @pytest.mark.parametrize("fixed_field", [True, False])
     @pytest.mark.parametrize("mode", ["mean_std", "rms"])
-    def test_per_node_field(self, npz_dataset, fixed_field, mode):
+    @pytest.mark.parametrize("subset", [True, False])
+    def test_per_node_field(self, npz_dataset, fixed_field, mode, subset):
         # set up the transformer
-        npz_dataset = set_up_transformer(npz_dataset, not fixed_field, fixed_field)
+        npz_dataset = set_up_transformer(
+            npz_dataset, not fixed_field, fixed_field, subset
+        )

         (result,) = npz_dataset.statistics(
             [AtomicDataDict.BATCH_KEY],
@@ -205,14 +207,15 @@ def test_per_node_field(self, npz_dataset, fixed_field, mode):
     @pytest.mark.parametrize("alpha", [1e-10, 1e-6, 0.1, 0.5, 1])
     @pytest.mark.parametrize("fixed_field", [True, False])
     @pytest.mark.parametrize("full_rank", [True, False])
+    @pytest.mark.parametrize("subset", [True, False])
     @pytest.mark.parametrize(
         "regressor", ["NormalizedGaussianProcess", "GaussianProcess"]
     )
     def test_per_graph_field(
-        self, npz_dataset, alpha, fixed_field, full_rank, regressor
+        self, npz_dataset, alpha, fixed_field, full_rank, regressor, subset
     ):
-        npz_dataset = set_up_transformer(npz_dataset, full_rank, fixed_field)
+        npz_dataset = set_up_transformer(npz_dataset, full_rank, fixed_field, subset)
         if npz_dataset is None:
             return
@@ -234,7 +237,11 @@ def test_per_graph_field(
         else:
             ref_mean, ref_std, E = generate_E(N, 100, 0.5)

-        npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E
+        E_orig_order = torch.zeros_like(
+            npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY]
+        )
+        E_orig_order[npz_dataset._indices] = E.unsqueeze(-1)
+        npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E_orig_order

         ref_res2 = torch.square(
             torch.matmul(N, ref_mean.reshape([-1, 1])) - E.reshape([-1, 1])
@@ -367,7 +374,7 @@ def generate_E(N, mean, std):
     return ref_mean, ref_std, (N * E).sum(axis=-1)

-def set_up_transformer(npz_dataset, full_rank, fixed_field):
+def set_up_transformer(npz_dataset, full_rank, fixed_field, subset):

     if full_rank:
@@ -405,4 +412,7 @@ def set_up_transformer(npz_dataset, full_rank, fixed_field):
                 chemical_symbols[n]: i for i, n in enumerate([1, ntype + 1])
             }
         )
-    return npz_dataset
+    if subset:
+        return npz_dataset.index_select(torch.randperm(len(npz_dataset)))
+    else:
+        return npz_dataset

From 7184b22afb41214efd76d2fd0318345fc4fb2397 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Wed, 2 Feb 2022 13:40:11 -0500
Subject: [PATCH 122/126] add condition when it is full set

---
 tests/unit/data/test_dataset.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py
index 20efa5c0..e580a49a 100644
--- a/tests/unit/data/test_dataset.py
+++ b/tests/unit/data/test_dataset.py
@@ -237,11 +237,14 @@ def test_per_graph_field(
         else:
             ref_mean, ref_std, E = generate_E(N, 100, 0.5)

-        E_orig_order = torch.zeros_like(
-            npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY]
-        )
-        E_orig_order[npz_dataset._indices] = E.unsqueeze(-1)
-        npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E_orig_order
+        if subset:
+            E_orig_order = torch.zeros_like(
+                npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY]
+            )
+            E_orig_order[npz_dataset._indices] = E.unsqueeze(-1)
+            npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E_orig_order
+        else:
+            npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E

         ref_res2 = torch.square(
             torch.matmul(N, ref_mean.reshape([-1, 1])) - E.reshape([-1, 1])

From 686c07b6fb67384c50984e3722089d291aa64c30 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Fri, 4 Feb 2022 13:17:06 -0500
Subject: [PATCH 123/126] fix the wrong list statement

---
 nequip/model/_scaling.py |  2 +-
 nequip/train/trainer.py  | 31 +++++++++++++------------------
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index aeac1476..37708088 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -25,7 +25,7 @@ def RescaleEnergyEtc(
         else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std",
         default_shift=None,
         default_scale_keys=AtomicDataDict.ALL_ENERGY_KEYS,
-        default_shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY,
+        default_shift_keys=[AtomicDataDict.TOTAL_ENERGY_KEY],
         default_related_scale_keys=[AtomicDataDict.PER_ATOM_ENERGY_KEY],
         default_related_shift_keys=[],
     )
diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index 8c8bec55..41b792fe 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -856,24 +856,19 @@ def batch_step(self, data, validation=False):
             self.lr_sched.step(self.iepoch + self.ibatch / self.n_batches)

         with torch.no_grad():
-            if len(self.rescale_layers) > 0:
-                if validation:
-                    scaled_out = out
-                    _data_unscaled = data
-                    for layer in self.rescale_layers:
-                        # loss function always needs to be in normalized unit
-                        scaled_out = layer.unscale(scaled_out, force_process=True)
-                        _data_unscaled = layer.unscale(
-                            _data_unscaled, force_process=True
-                        )
-                    loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
-                else:
-                    # If we are in training mode, we need to bring the prediction
-                    # into real units
-                    for layer in self.rescale_layers[::-1]:
-                        out = layer.scale(out, force_process=True)
-            elif validation:
-                loss, loss_contrib = self.loss(pred=out, ref=data_unscaled)
+            if validation:
+                scaled_out = out
+                _data_unscaled = data
+                for layer in self.rescale_layers:
+                    # loss function always needs to be in normalized unit
+                    scaled_out = layer.unscale(scaled_out, force_process=True)
+                    _data_unscaled = layer.unscale(_data_unscaled, force_process=True)
+                loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
+            else:
+                # If we are in training mode, we need to bring the prediction
+                # into real units
+                for layer in self.rescale_layers[::-1]:
+                    out = layer.scale(out, force_process=True)

         # save metrics stats
         self.batch_losses = self.loss_stat(loss, loss_contrib)

From f78b8e8b3431be3e682c0bf23fcf3754df6d0ab7 Mon Sep 17 00:00:00 2001
From: nw13slx
Date: Fri, 4 Feb 2022 13:23:51 -0500
Subject: [PATCH 124/126] add assert

---
 nequip/model/_scaling.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index 37708088..bcd31e1e 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -106,6 +106,12 @@ def GlobalRescale(
         if global_scale is not None:
             global_scale = 1.0  # same,

+    error_string = "keys need to be a list"
+    assert isinstance(default_scale_keys, list), error_string
+    assert isinstance(default_shift_keys, list), error_string
+    assert isinstance(default_related_scale_keys, list), error_string
+    assert isinstance(default_related_shift_keys, list), error_string
+
     # == Build the model ==
     return RescaleOutput(
         model=model,

From 202ebe5a7cd18d9abff5f337f426ee0baf04a841 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 4 Feb 2022 17:31:09 -0500
Subject: [PATCH 125/126] comment

---
 nequip/data/dataset.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py
index 3c6a0434..44afe560 100644
--- a/nequip/data/dataset.py
+++ b/nequip/data/dataset.py
@@ -353,6 +353,16 @@ def statistics(

         if self._indices is not None:
             graph_selector = torch.as_tensor(self._indices)[::stride]
+            # note that self._indices is _not_ necessarily in order,
+            # while self.data --- which we take our arrays from ---
+            # is always in the original order.
+            # In particular, the values of `self.data.batch`
+            # are indexes in the ORIGINAL order
+            # thus we need graph level properties to also be in the original order
+            # so that batch values index into them correctly
+            # since self.data.batch is always sorted & contiguous
+            # (because of Batch.from_data_list)
+            # we sort it:
             graph_selector, _ = torch.sort(graph_selector)
         else:
             graph_selector = torch.arange(0, self.len(), stride)

From 1cbfd333140e07e93c2301bb49032ac52b517cf5 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Fri, 4 Feb 2022 17:46:34 -0500
Subject: [PATCH 126/126] changelog

---
 CHANGELOG.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ecc4a7b2..8ad45d6a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,9 @@
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 Most recent change on the bottom.

-## [Unreleased] - 0.5.2
+## [Unreleased]
+
+## [0.5.2] - 2022-02-04
 ### Added
 - Model builders may now process only the configuration
 - Allow irreps to optionally be specified through the simplified keys `l_max`, `parity`, and `num_features`