diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md index f90b4eed..c6a6a531 100644 --- a/.github/ISSUE_TEMPLATE/question.md +++ b/.github/ISSUE_TEMPLATE/question.md @@ -7,4 +7,4 @@ assignees: '' --- - +If this isn't an issue with the code or a request, please use our [GitHub Discussions](https://github.com/mir-group/nequip/discussions) instead. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 09a5e48f..1f6e8b38 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -24,8 +24,9 @@ Resolves: #??? - [ ] My code follows the code style of this project and has been formatted using `black`. -- [ ] I have updated the documentation (if relevant). -- [ ] I have added tests that cover my changes (if relevant). - [ ] All new and existing tests passed. -- [ ] `example.yaml` (and other relevant `configs/`) have been updated with new or changed options. -- [ ] I have updated `CHANGELOG.md`. \ No newline at end of file +- [ ] I have added tests that cover my changes (if relevant). +- [ ] The option documentation (`docs/options`) has been updated with new or changed options. +- [ ] I have updated `CHANGELOG.md`. +- [ ] I have updated the documentation (if relevant). + diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..41e389be --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,48 @@ +name: Check Syntax and Run Tests + +on: + push: + branches: + - main + + pull_request: + branches: + - main + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.9] + torch-version: [1.8.0, 1.9.0] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install flake8 + run: | + pip install flake8 + - name: Lint with flake8 + run: | + flake8 . --count --show-source --statistics + - name: Install dependencies + env: + TORCH: "${{ matrix.torch-version }}" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + python -m pip install --upgrade pip + pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install . + - name: Install pytest + run: | + pip install pytest + pip install pytest-xdist[psutil] + - name: Test with pytest + run: | + # See https://github.com/pytest-dev/pytest/issues/1075 + PYTHONHASHSEED=0 pytest -n auto --ignore=docs/ . diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml new file mode 100644 index 00000000..a69a728d --- /dev/null +++ b/.github/workflows/tests_develop.yml @@ -0,0 +1,48 @@ +name: Check Syntax and Run Tests + +on: + push: + branches: + - develop + + pull_request: + branches: + - develop + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + torch-version: [1.9.0] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install flake8 + run: | + pip install flake8 + - name: Lint with flake8 + run: | + flake8 . --count --show-source --statistics + - name: Install dependencies + env: + TORCH: "${{ matrix.torch-version }}" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + python -m pip install --upgrade pip + pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install . + - name: Install pytest + run: | + pip install pytest + pip install pytest-xdist[psutil] + - name: Test with pytest + run: | + # See https://github.com/pytest-dev/pytest/issues/1075 + PYTHONHASHSEED=0 pytest -n auto --ignore=docs/ . diff --git a/CHANGELOG.md b/CHANGELOG.md index b4859880..e7040c1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,63 @@ Most recent change on the bottom. ## [Unreleased] +## [0.5.0] - 2021-11-24 +### Changed +- Allow e3nn 0.4.*, which changes the default normalization of `TensorProduct`s; this change _should_ not affect typical NequIP networks +- Deployed are now frozen on load, rather than compile + +### Fixed +- `load_deployed_model` respects global JIT settings + +## [0.4.0] - not released +### Added +- Support for `e3nn`'s `soft_one_hot_linspace` as radial bases +- Support for parallel dataloader workers with `dataloader_num_workers` +- Optionally independently configure validation and training datasets +- Save dataset parameters along with processed data +- Gradient clipping +- Arbitrary atom type support +- Unified, modular model building and initialization architecture +- Added `nequip-benchmark` script for benchmarking and profiling models +- Add before option to SequentialGraphNetwork.insert +- Normalize total energy loss by the number of atoms via PerAtomLoss +- Model builder to initialize training from previous checkpoint +- Better error when instantiation fails +- Rename `npz_keys` to `include_keys` +- Allow user to register `graph_fields`, `node_fields`, and `edge_fields` via yaml +- Deployed models save the e3nn and torch versions they were created with + +### Changed +- Update example.yaml to use wandb by default, to only use 100 epochs of training, to set a very large batch logging frequency and to change Validation_loss to validation_loss +- Name processed datasets based on a hash of their parameters to ensure only valid cached data is used +- Do not use TensorFloat32 by default on Ampere GPUs until we understand it better +- No atomic numbers in networks +- `dataset_energy_std`/`dataset_energy_mean` to `dataset_total_energy_*` +- `nequip.dynamics` -> `nequip.ase` +- update example.yaml and full.yaml with better defaults, new loss function, and switched to toluene-ccsd(t) as example +data +- `use_sc` defaults to `True` +- `register_fields` is now in `nequip.data` +- Default total energy scaling is changed from global mode to per species mode. +- Renamed `trainable_global_rescale_scale` to `global_rescale_scale_trainble` +- Renamed `trainable_global_rescale_shift` to `global_rescale_shift_trainble` +- Renamed `PerSpeciesScaleShift_` to `per_species_rescale` +- Change default and allowed values of `metrics_key` from `loss` to `validation_loss`. The old default `loss` will no longer be accepted. +- Renamed `per_species_rescale_trainable` to `per_species_rescale_scales_trainable` and `per_species_rescale_shifts_trainable` + +### Fixed +- The first 20 epochs/calls of inference are no longer painfully slow for recompilation +- Set global options like TF32, dtype in `nequip-evaluate` +- Avoid possilbe race condition in caching of processed datasets across multiple training runs + +### Removed +- Removed `allowed_species` +- Removed `--update-config`; start a new training and load old state instead +- Removed dependency on `pytorch_geometric` +- `nequip-train` no longer prints the full config, which can be found in the training dir as `config.yaml`. +- `nequip.datasets.AspirinDataset` & `nequip.datasets.WaterDataset` +- Dependency on `pytorch_scatter` + ## [0.3.3] - 2021-08-11 ### Added - `to_ase` method in `AtomicData.py` to convert `AtomicData` object to (list of) `ase.Atoms` object(s) diff --git a/README.md b/README.md index b41ce9ec..3b82dc52 100644 --- a/README.md +++ b/README.md @@ -13,17 +13,10 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentia NequIP requires: * Python >= 3.6 -* PyTorch >= 1.8, <1.10 (PyTorch 1.10 support is in the works on `develop`.) +* PyTorch >= 1.8, <=1.10.*. PyTorch can be installed following the [instructions from their documentation](https://pytorch.org/get-started/locally/). Note that neither `torchvision` nor `torchaudio`, included in the default install command, are needed for NequIP. NequIP is also not currently compatible with PyTorch 1.10; PyTorch 1.9 can be specified with `pytorch==1.9` in the install command. To install: -* Install [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), following [their installation instructions](https://pytorch-geometric.readthedocs.io/en/1.7.2/notes/installation.html) and making sure to install with the correct version of CUDA. Please note that `torch_geometric==1.7.2` is required. - -* Install our fork of [`pytorch_ema`](https://github.com/Linux-cpp-lisp/pytorch_ema) for using an Exponential Moving Average on the weights: -```bash -$ pip install "git+https://github.com/Linux-cpp-lisp/pytorch_ema@context_manager#egg=torch_ema" -``` - * We use [Weights&Biases](https://wandb.ai) to keep track of experiments. This is not a strict requirement — you can use our package without it — but it may make your life easier. If you want to use it, create an account [here](https://wandb.ai) and install the Python package: ``` @@ -40,14 +33,24 @@ pip install . ### Installation Issues -We recommend running the tests using ```pytest```: +The easiest way to check if your installation is working is to train a toy model: +```bash +$ nequip-train configs/minimal.yaml +``` + +If you suspect something is wrong, encounter errors, or just want to confirm that everything is in working order, you can also run the unit tests: ``` pip install pytest -pytest ./tests/ +pytest tests/unit/ ``` -While the tests are somewhat compute intensive, we've known them to hang on certain systems that have GPUs. If this happens to you, please report it along with information on your software environment in the [Issues](https://github.com/mir-group/nequip/issues)! +To run the full tests, including a set of longer/more intensive integration tests, run: +``` +pytest tests/ +``` + +Note: the integration tests have hung in the past on certain systems that have GPUs. If this happens to you, please report it along with information on your software environment in the [Issues](https://github.com/mir-group/nequip/issues)! ## Usage @@ -64,7 +67,7 @@ $ nequip-train configs/example.yaml A number of example configuration files are provided: - [`configs/minimal.yaml`](configs/minimal.yaml): A minimal example of training a toy model on force data. - [`configs/minimal_eng.yaml`](configs/minimal_eng.yaml): The same, but for a toy model that predicts and trains on only energy labels. - - [`configs/example.yaml`](configs/example.yaml): Training a more realistic model on forces and energies. + - [`configs/example.yaml`](configs/example.yaml): Training a more realistic model on forces and energies. Start here for real models. - [`configs/full.yaml`](configs/full.yaml): A complete configuration file containing all available options along with documenting comments. Training runs can be restarted using `nequip-restart`; training that starts fresh or restarts depending on the existance of the working directory can be launched using `nequip-requeue`. All `nequip-*` commands accept the `--help` option to show their call signatures and options. @@ -87,14 +90,12 @@ The `nequip-deploy` command is used to deploy the result of a training session i It compiles a NequIP model trained in Python to [TorchScript](https://pytorch.org/docs/stable/jit.html). The result is an optimized model file that has no dependency on the `nequip` Python library, or even on Python itself: ```bash -nequip-deploy build path/to/training/session/ path/to/deployed.pth +nequip-deploy build path/to/training/session/ where/to/put/deployed_model.pth ``` For more details on this command, please run `nequip-deploy --help`. ### Using models in Python -Both deployed and undeployed models can be used in Python code; for details, see the end of the [Developer's tutorial](https://deepnote.com/project/2412ca93-7ad1-4458-972c-5d5add5a667e) mentioned again below. - An ASE calculator is also provided in `nequip.dynamics`. ### LAMMPS Integration @@ -113,18 +114,12 @@ The result is an optimized model file that has no Python dependency and can be u ``` pair_style nequip -pair_coeff * * deployed.pth +pair_coeff * * deployed.pth ... ``` For installation instructions, please see the [`pair_nequip` repository](https://github.com/mir-group/pair_nequip). -## Developer's tutorial - -A more in-depth introduction to the internals of NequIP can be found in the [tutorial notebook](https://deepnote.com/project/2412ca93-7ad1-4458-972c-5d5add5a667e). This notebook discusses theoretical background as well as the Python interfaces that can be used to train and call models. - -Please note that for most common usecases, including customized models, the `nequip-*` commands should be prefered for training models. - ## References & citing The theory behind NequIP is described in our preprint (1). NequIP's backend builds on e3nn, a general framework for building E(3)-equivariant neural networks (2). If you use this repository in your work, please consider citing NequIP (1) and e3nn (3): diff --git a/configs/example.yaml b/configs/example.yaml index 0096269e..70dfdd98 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -1,49 +1,55 @@ -# an example yaml file -# for a full yaml file containing all possible features check out full.yaml +# a simple example config file # Two folders will be used during the training: 'root'/process and 'root'/'run_name' # run_name contains logfiles and saved models # process contains processed data sets # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead. -root: results/aspirin -run_name: example-run +root: results/toluene +run_name: example-run-toluene seed: 0 # random number seed for numpy and torch -restart: false # set True for a restarted run -append: false # set True if a restarted run should append to the previous log file +append: true # set true if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 # network -r_max: 4.0 # cutoff radius in length units - -num_layers: 6 # number of interaction blocks, we found 5-6 to work best +r_max: 4.0 # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan +num_layers: 4 # number of interaction blocks, we find 4-6 to work best chemical_embedding_irreps_out: 32x0e # irreps for the chemical embedding of species -feature_irreps_hidden: 32x0o + 32x0e + 16x1o + 16x1e + 8x2o + 8x2e # irreps used for hidden features, here we go up to lmax=2, with even and odd parities -irreps_edge_sh: 0e + 1o + 2e # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer +feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster +irreps_edge_sh: 0e + 1o # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer conv_to_output_hidden_irreps_out: 16x0e # irreps used in hidden layer of output block - nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended resnet: false # set true to make interaction block a resnet-style update + +# scalar nonlinearities to use — available options are silu, ssp (shifted softplus), tanh, and abs. +# Different nonlinearities are specified for e (even) and o (odd) parity; +# note that only tanh and abs are correct for o (odd parity). +nonlinearity_scalars: + e: silu + o: tanh + +nonlinearity_gates: + e: silu + o: tanh + +# radial network basis num_basis: 8 # number of basis functions used in the radial basis +BesselBasis_trainable: true # set true to train the bessel weights +PolynomialCutoff_p: 6 # p-exponent used in polynomial cutoff function # radial network -invariant_layers: 2 # number of radial layers, we found it important to keep this small, 1 or 2 +invariant_layers: 2 # number of radial layers, usually 1-3 works best, smaller is faster invariant_neurons: 64 # number of hidden neurons in radial function, smaller is faster -avg_num_neighbors: null # number of neighbors to divide by, None => no normalization. +avg_num_neighbors: null # number of neighbors to divide by, null => no normalization. use_sc: true # use self-connection or not, usually gives big improvement -# to specify different parameters for each convolutional layer, try examples below -# layer1_use_sc: true # use "layer{i}_" prefix to specify parameters for only one of the layer, -# priority for different definition: -# invariant_neurons < InteractionBlock_invariant_neurons < layer{i}_invariant_neurons - # data set # the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys # key_mapping is used to map the key in the npz file to the NequIP default values (see data/_key.py) # all arrays are expected to have the shape of (nframe, natom, ?) except the fixed fields # note that if your data set uses pbc, you need to also pass an array that maps to the nequip "pbc" key dataset: npz # type of data set, can be npz or ase -dataset_url: http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip # url to download the npz. optional -dataset_file_name: ./benchmark_data/aspirin_ccsd-train.npz # path to data set file +dataset_url: http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip # url to download the npz. optional +dataset_file_name: ./benchmark_data/toluene_ccsd_t-train.npz # path to data set file key_mapping: z: atomic_numbers # atomic species, integers E: total_energy # total potential eneriges to train to @@ -52,57 +58,81 @@ key_mapping: npz_fixed_field_keys: # fields that are repeated across different examples - atomic_numbers -# As an alternative option to npz, you can also pass data ase ASE Atoms-objects -# This can often be easier to work with, simply make sure the ASE Atoms object -# has a calculator for which atoms.get_potential_energy() and atoms.get_forces() are defined -# dataset: ase -# dataset_file_name: xxx.xyz # need to be a format accepted by ase.io.read -# ase_args: # any arguments needed by ase.io.read -# format: extxyz +# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. +chemical_symbol_to_type: + H: 0 + C: 1 # logging -wandb: true # we recommend using wandb for logging, we'll turn it off here as it's optional -wandb_project: aspirin # project name used in wandb +wandb: true # we recommend using wandb for logging, we'll turn it off here as it's optional +wandb_project: toluene-example # project name used in wandb wandb_resume: true # if true and restart is true, wandb run data will be restarted and updated. # if false, a new wandb run will be generated verbose: info # the same as python logging, e.g. warning, info, debug, error. case insensitive -log_batch_freq: 1 # batch frequency, how often to print training errors withinin the same epoch +log_batch_freq: 1000000 # batch frequency, how often to print training errors withinin the same epoch log_epoch_freq: 1 # epoch frequency, how often to print and save the model +save_checkpoint_freq: -1 # frequency to save the intermediate checkpoint. no saving when the value is not positive. +save_ema_checkpoint_freq: -1 # frequency to save the intermediate ema checkpoint. no saving when the value is not positive. # training n_train: 100 # number of training data n_val: 50 # number of validation data -learning_rate: 0.01 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune -batch_size: 5 # batch size, we found it important to keep this small for most applications (1-5) -max_epochs: 1000000 # stop training after _ number of epochs -metrics_key: loss # metrics used for scheduling and saving best model. Options: loss, or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse -use_ema: false # if true, use exponential moving average on weights for val/test, usually helps a lot with training, in particular for energy errors -ema_decay: 0.999 # ema weight, commonly set to 0.999 +learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune +batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better +max_epochs: 100 # stop training after _ number of epochs, we set a very large number here, it won't take this long in practice and we will use early stopping instead +train_val_split: random # can be random or sequential. if sequential, first n_train elements are training, next n_val are val, else random, usually random is the right choice +shuffle: true # If true, the data loader will shuffle the data, usually a good idea +metrics_key: validation_loss # metrics used for scheduling and saving best model. Options: `set`_`quantity`, set can be either "train" or "validation, "quantity" can be loss or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse +use_ema: true # if true, use exponential moving average on weights for val/test, usually helps a lot with training, in particular for energy errors +ema_decay: 0.99 # ema weight, typically set to 0.99 or 0.999 +ema_use_num_updates: true # whether to use number of updates when computing averages + +# early stopping based on metrics values. +early_stopping_patiences: # stop early if a metric value stopped decreasing for n epochs + validation_loss: 50 # loss function loss_coeffs: # different weights to use in a weighted loss functions - forces: 100 # for MD applications, we recommed a force weight of 100 and an energy weight of 1 - total_energy: 1 # alternatively, if energies are not of importance, a force weight 1 and an energy weight of 0 also works. + forces: 1 # for MD applications, we recommed a force weight of 100 and an energy weight of 1 + total_energy: # alternatively, if energies are not of importance, a force weight 1 and an energy weight of 0 also works. + - 1 + - PerAtomMSELoss # output metrics metrics_components: - - - forces # key - - rmse # "rmse" or "mse" - - PerSpecies: True # if true, per species contribution is counted separately - report_per_component: False # if true, statistics on each component (i.e. fx, fy, fz) will be counted separately - - forces - mae - PerSpecies: True report_per_component: False - - total_energy - mae + - PerAtom: True # if true, energy is normalized by the number of atoms # optimizer, may be any optimizer defined in torch.optim # the name `optimizer_name`is case sensitive -optimizer_name: Adam +optimizer_name: Adam # default optimizer is Adam in the amsgrad mode optimizer_amsgrad: true -# lr scheduler, on plateau +# lr scheduler, currently only supports the two options listed below, if you need more please file an issue +# first: on-plateau, reduce lr by factory of lr_scheduler_factor if metrics_key hasn't improved for lr_scheduler_patience epoch lr_scheduler_name: ReduceLROnPlateau lr_scheduler_patience: 100 lr_scheduler_factor: 0.5 + +# we provide a series of options to shift and scale the data +# these are for advanced use and usually the defaults work very well +# the default is to scale the atomic energy and forces by scaling them by the force standard deviation and to shift the energy by the mean atomic energy +# in certain cases, it can be useful to have a trainable shift/scale and to also have species-dependent shifts/scales for each atom + +per_species_rescale_shifts_trainable: false +per_species_rescale_scales_trainable: false + +# whether the shifts and scales are trainable. Defaults to False. Optional +per_species_rescale_shifts: dataset_per_atom_total_energy_mean +# initial atomic energy shift for each species. default to the mean of per atom energy. Optional +# the value can be a constant float value, an array for each species, or a string that defines a statistics over the training dataset +per_species_rescale_scales: dataset_forces_rms +# initial atomic energy scale for each species. Optional. +# the value can be a constant float value, an array for each species, or a string +# per_species_rescale_arguments_in_dataset_units: True +# if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values. diff --git a/configs/full.yaml b/configs/full.yaml index 7ae8d632..57d99094 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -1,66 +1,67 @@ -# a full yaml file with all nequip options -# this is primarily intested to serve as documentation/reference for all options -# for a simpler yaml file containing all necessary feature to get you started check out example.yaml +# IMPORTANT: READ THIS + +# This is a full yaml file with all nequip options. +# It is primarily intented to serve as documentation/reference for all options +# For a simpler yaml file containing all necessary feature to get you started check out configs/example.yaml # Two folders will be used during the training: 'root'/process and 'root'/'run_name' # run_name contains logfiles and saved models # process contains processed data sets # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead. -root: results/aspirin -run_name: example-run-full +root: results/toluene +run_name: example-run-toluene seed: 0 # random number seed for numpy and torch -restart: false # set True for a restarted run -append: false # set True if a restarted run should append to the previous log file +append: true # set true if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 -allow_tf32: True # whether to use TensorFloat32 if it is available +allow_tf32: false # whether to use TensorFloat32 if it is available +device: cuda # which device to use. Default: automatically detected cuda or "cpu" # network -r_max: 4.0 # cutoff radius in length units - -num_layers: 6 # number of interaction blocks, we found 5-6 to work best +r_max: 4.0 # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan +num_layers: 4 # number of interaction blocks, we find 4-6 to work best chemical_embedding_irreps_out: 32x0e # irreps for the chemical embedding of species -feature_irreps_hidden: 32x0o + 32x0e + 16x1o + 16x1e + 8x2o + 8x2e # irreps used for hidden features, here we go up to lmax=2, with even and odd parities -irreps_edge_sh: 0e + 1o + 2e # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer +feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster +irreps_edge_sh: 0e + 1o # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer conv_to_output_hidden_irreps_out: 16x0e # irreps used in hidden layer of output block - nonlinearity_type: gate # may be 'gate' or 'norm', 'gate' is recommended +resnet: false # set true to make interaction block a resnet-style update # scalar nonlinearities to use — available options are silu, ssp (shifted softplus), tanh, and abs. # Different nonlinearities are specified for e (even) and o (odd) parity; # note that only tanh and abs are correct for o (odd parity). nonlinearity_scalars: - e: ssp + e: silu o: tanh -nonlinearity_gates: - e: ssp - o: abs -resnet: false # set true to make interaction block a resnet-style update +nonlinearity_gates: + e: silu + o: tanh +# radial network basis num_basis: 8 # number of basis functions used in the radial basis BesselBasis_trainable: true # set true to train the bessel weights PolynomialCutoff_p: 6 # p-exponent used in polynomial cutoff function # radial network -invariant_layers: 2 # number of radial layers, we found it important to keep this small, 1 or 2 +invariant_layers: 2 # number of radial layers, usually 1-3 works best, smaller is faster invariant_neurons: 64 # number of hidden neurons in radial function, smaller is faster -avg_num_neighbors: null # number of neighbors to divide by, None => no normalization. +avg_num_neighbors: null # number of neighbors to divide by, null => no normalization. use_sc: true # use self-connection or not, usually gives big improvement compile_model: false # whether to compile the constructed model to TorchScript # to specify different parameters for each convolutional layer, try examples below # layer1_use_sc: true # use "layer{i}_" prefix to specify parameters for only one of the layer, -# priority for different definition: +# priority for different definitions: # invariant_neurons < InteractionBlock_invariant_neurons < layer{i}_invariant_neurons # data set -# the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys +# the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or include_keys # key_mapping is used to map the key in the npz file to the NequIP default values (see data/_key.py) # all arrays are expected to have the shape of (nframe, natom, ?) except the fixed fields # note that if your data set uses pbc, you need to also pass an array that maps to the nequip "pbc" key dataset: npz # type of data set, can be npz or ase -dataset_url: http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip # url to download the npz. optional -dataset_file_name: ./benchmark_data/aspirin_ccsd-train.npz # path to data set file +dataset_url: http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip # url to download the npz. optional +dataset_file_name: ./benchmark_data/toluene_ccsd_t-train.npz # path to data set file key_mapping: z: atomic_numbers # atomic species, integers E: total_energy # total potential eneriges to train to @@ -69,6 +70,35 @@ key_mapping: npz_fixed_field_keys: # fields that are repeated across different examples - atomic_numbers +# # for extxyz file +# dataset: ase +# dataset_file_name: H2.extxyz +# ase_args: +# format: extxyz +# include_keys: +# - user_label +# key_mapping: +# user_label: label0 +# +# # for VASP OUTCAR, the yaml input should be +# dataset: ase +# dataset_file_name: OUTCAR +# ase_args: +# format: vasp-out +# key_mapping: +# free_energy: total_energy + +# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. +chemical_symbol_to_type: + H: 0 + C: 1 + +# Alternatively, if the dataset has type indicess, the total number of types is all that is required: +# type_names: +# 0: my_type +# 1: atom +# 2: thing + # As an alternative option to npz, you can also pass data ase ASE Atoms-objects # This can often be easier to work with, simply make sure the ASE Atoms object # has a calculator for which atoms.get_potential_energy() and atoms.get_forces() are defined @@ -77,9 +107,14 @@ npz_fixed_field_keys: # ase_args: # any arguments needed by ase.io.read # format: extxyz +# If you want to use a different dataset for validation, you can specify +# the same types of options using a `validation_` prefix: +# validation_dataset: ase +# validation_dataset_file_name: xxx.xyz # need to be a format accepted by ase.io.read + # logging -wandb: false # we recommend using wandb for logging, we'll turn it off here as it's optional -wandb_project: aspirin # project name used in wandb +wandb: false # we recommend using wandb for logging, we'll turn it off here as it's optional +wandb_project: toluene-example # project name used in wandb wandb_resume: true # if true and restart is true, wandb run data will be restarted and updated. # if false, a new wandb run will be generated verbose: info # the same as python logging, e.g. warning, info, debug, error. case insensitive @@ -91,35 +126,40 @@ save_ema_checkpoint_freq: -1 # training n_train: 100 # number of training data n_val: 50 # number of validation data -learning_rate: 0.01 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune -batch_size: 5 # batch size, we found it important to keep this small for most applications (1-5) -max_epochs: 1000000 # stop training after _ number of epochs +learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune +batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better +max_epochs: 100000 # stop training after _ number of epochs, we set a very large number here, it won't take this long in practice and we will use early stopping instead train_val_split: random # can be random or sequential. if sequential, first n_train elements are training, next n_val are val, else random, usually random is the right choice shuffle: true # If true, the data loader will shuffle the data, usually a good idea -metrics_key: loss # metrics used for scheduling and saving best model. Options: loss, or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse -use_ema: false # if true, use exponential moving average on weights for val/test, usually helps a lot with training, in particular for energy errors -ema_decay: 0.999 # ema weight, commonly set to 0.999 +metrics_key: validation_loss # metrics used for scheduling and saving best model. Options: `set`_`quantity`, set can be either "train" or "validation, "quantity" can be loss or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse +use_ema: true # if true, use exponential moving average on weights for val/test, usually helps a lot with training, in particular for energy errors +ema_decay: 0.99 # ema weight, typically set to 0.99 or 0.999 ema_use_num_updates: true # whether to use number of updates when computing averages +report_init_validation: false # if True, report the validation error for just initialized model # early stopping based on metrics values. # LR, wall and any keys printed in the log file can be used. # The key can start with Training or Validation. If not defined, the validation value will be used. early_stopping_patiences: # stop early if a metric value stopped decreasing for n epochs - Validation_loss: 50 # - Training_loss: 100 # - e_mae: 100 # -early_stopping_delta: # If delta is defined, a tiny decrease smaller than delta will not be considered as a decrease - Training_loss: 0.005 # + Validation_loss: 50 + +early_stopping_delta: # If delta is defined, a decrease smaller than delta will not be considered as a decrease + Validation_loss: 0.005 + early_stopping_cumulative_delta: false # If True, the minimum value recorded will not be updated when the decrease is smaller than delta + early_stopping_lower_bounds: # stop early if a metric value is lower than the bound - LR: 1.0e-10 # + LR: 1.0e-6 + early_stopping_upper_bounds: # stop early if a metric value is higher than the bound - wall: 1.0e+100 # + wall: 1.0e+100 # loss function loss_coeffs: # different weights to use in a weighted loss functions - forces: 100 # for MD applications, we recommed a force weight of 100 and an energy weight of 1 - total_energy: 1 # alternatively, if energies are not of importance, a force weight 1 and an energy weight of 0 also works. + forces: 1 # for MD applications, we recommed a force weight of 100 and an energy weight of 1 + total_energy: # alternatively, if energies are not of importance, a force weight 1 and an energy weight of 0 also works. + - 1 + - PerAtomMSELoss # # default loss function is MSELoss, the name has to be exactly the same as those in torch.nn. # the only supprted targets are forces and total_energy @@ -134,6 +174,11 @@ loss_coeffs: # - MSELoss # # loss_coeffs: +# total_energy: +# - 1.0 +# - PerAtomMSELoss +# +# loss_coeffs: # forces: # - 1.0 # - PerSpeciesL1Loss @@ -158,6 +203,7 @@ metrics_components: report_per_component: False - - total_energy - mae + - PerAtom: True # if true, energy is normalized by the number of atoms # optimizer, may be any optimizer defined in torch.optim # the name `optimizer_name`is case sensitive @@ -169,17 +215,10 @@ optimizer_betas: !!python/tuple optimizer_eps: 1.0e-08 optimizer_weight_decay: 0 -# weight initialization -# this can be the importable name of any function that can be `model.apply`ed to initialize some weights in the model. NequIP provides a number of useful initializers: -# For more details please see the docstrings of the individual initializers -#model_initializers: -# - nequip.utils.initialization.uniform_initialize_fcs -# - nequip.utils.initialization.uniform_initialize_equivariant_linears -# - nequip.utils.initialization.uniform_initialize_tp_internal_weights -# - nequip.utils.initialization.xavier_initialize_fcs -# - nequip.utils.initialization.(unit_)orthogonal_initialize_equivariant_linears -# - nequip.utils.initialization.(unit_)orthogonal_initialize_fcs -# - nequip.utils.initialization.(unit_)orthogonal_initialize_e3nn_fcs +# gradient clipping using torch.nn.utils.clip_grad_norm_ +# see https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html#torch.nn.utils.clip_grad_norm_ +# setting to inf or null disables it +max_gradient_norm: null # lr scheduler, currently only supports the two options listed below, if you need more please file an issue # first: on-plateau, reduce lr by factory of lr_scheduler_factor if metrics_key hasn't improved for lr_scheduler_patience epoch @@ -187,7 +226,7 @@ lr_scheduler_name: ReduceLROnPlateau lr_scheduler_patience: 100 lr_scheduler_factor: 0.5 -# second, consine annealing with warm restart +# second, cosine annealing with warm restart # lr_scheduler_name: CosineAnnealingWarmRestarts # lr_scheduler_T_0: 10000 # lr_scheduler_T_mult: 2 @@ -196,28 +235,68 @@ lr_scheduler_factor: 0.5 # we provide a series of options to shift and scale the data # these are for advanced use and usually the defaults work very well -# the deafult is to scale the energies and forces by scaling them by the force standard deviation and to shift the energy by its mean +# the default is to scale the energies and forces by scaling them by the force standard deviation and to shift the energy by its mean # in certain cases, it can be useful to have a trainable shift/scale and to also have species-dependent shifts/scales for each atom -# whether to apply a shift and scale, defined per-species, to the atomic energies -PerSpeciesScaleShift_enable: false -# if the PerSpeciesScaleShift is enabled, whether the shifts and scales are trainable -PerSpeciesScaleShift_trainable: true -# optional initial atomic energy shift for each species. order should be the same as the allowed_species used in train.py. Defaults to zeros. -PerSpeciesScaleShift_shifts: [0.0, 0.0, 0.0] -# optional initial atomic energy scale for each species. order should be the same as the allowed_species used in train.py. Defaults to ones. -PerSpeciesScaleShift_scales: [1.0, 1.0, 1.0] - -# global energy shift. When "dataset_energy_mean" (the default), the mean energy of the dataset. When None, disables the global shift. When a number, used directly. -global_rescale_shift: dataset_energy_mean -# global energy scale. When "dataset_force_rms", the RMS of force components in the dataset. When "dataset_energy_std", the stdev of energies in the dataset. When None, disables the global scale. When a number, used directly. -# If not provided, defaults to either dataset_force_rms or dataset_energy_std, depending on whether forces are being trained. -global_rescale_scale: dataset_force_rms +per_species_rescale_scales_trainable: false +# whether the scales are trainable. Defaults to False. Optional +per_species_rescale_shifts_trainable: false +# whether the shifts are trainable. Defaults to False. Optional +per_species_rescale_shifts: dataset_per_atom_total_energy_mean +# initial atomic energy shift for each species. default to the mean of per atom energy. Optional +# the value can be a constant float value, an array for each species, or a string +# string option include: +# * "dataset_per_atom_total_energy_mean", which computes the per atom average +# * "dataset_per_species_total_energy_mean", which automatically compute the per atom energy mean using a GP model +per_species_rescale_scales: dataset_forces_rms +# initial atomic energy scale for each species. Optional. +# the value can be a constant float value, an array for each species, or a string +# string option include: +# * "dataset_per_atom_total_energy_std", which computes the per atom energy std +# * "dataset_per_species_total_energy_std", which uses the GP model uncertainty +# * "dataset_per_species_forces_rms", which compute the force rms for each species +# If not provided, defaults to dataset_per_species_force_rms or dataset_per_atom_total_energy_std, depending on whether forces are being trained. +# per_species_rescale_kwargs: +# total_energy: +# alpha: 0.1 +# max_iteration: 20 +# stride: 100 +# keywords for GP decomposition of per specie energy. Optional. Defaults to 0.1 +# per_species_rescale_arguments_in_dataset_units: True +# if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values. + +# global energy shift and scale +# When "dataset_total_energy_mean", the mean energy of the dataset. When None, disables the global shift. When a number, used directly. +# Warning: if this value is not None, the model is no longer size extensive +global_rescale_shift: null + +# global energy scale. When "dataset_force_rms", the RMS of force components in the dataset. When "dataset_total_energy_std", the stdev of energies in the dataset. When null, disables the global scale. When a number, used directly. +# If not provided, defaults to either dataset_force_rms or dataset_total_energy_std, depending on whether forces are being trained. +global_rescale_scale: dataset_forces_rms + # whether the shift of the final global energy rescaling should be trainable -trainable_global_rescale_shift: false +global_rescale_shift_trainable: false + # whether the scale of the final global energy rescaling should be trainable -trainable_global_rescale_scale: false +global_rescale_scale_trainable: false + +# # full block needed for per specie rescale +# global_rescale_shift: null +# global_rescale_shift_trainable: false +# global_rescale_scale: dataset_forces_rms +# global_rescale_scale_trainable: false +# per_species_rescale_trainable: true +# per_species_rescale_shifts: dataset_per_atom_total_energy_mean +# per_species_rescale_scales: dataset_per_atom_total_energy_std +# # full block needed for global rescale +# global_rescale_shift: dataset_total_energy_mean +# global_rescale_shift_trainable: false +# global_rescale_scale: dataset_forces_rms +# global_rescale_scale_trainable: false +# per_species_rescale_trainable: false +# per_species_rescale_shifts: null +# per_species_rescale_scales: null # Options for e3nn's set_optimization_defaults. A dict: # e3nn_optimization_defaults: diff --git a/configs/minimal.yaml b/configs/minimal.yaml index a45810f9..fa05bbbb 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -9,11 +9,32 @@ r_max: 4.0 irreps_edge_sh: 0e + 1o conv_to_output_hidden_irreps_out: 16x0e feature_irreps_hidden: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e -model_uniform_init: false -# data -dataset: aspirin -dataset_file_name: benchmark_data/aspirin_ccsd-train.npz +# data set +# the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys +# key_mapping is used to map the key in the npz file to the NequIP default values (see data/_key.py) +# all arrays are expected to have the shape of (nframe, natom, ?) except the fixed fields +# note that if your data set uses pbc, you need to also pass an array that maps to the nequip "pbc" key +dataset: npz # type of data set, can be npz or ase +dataset_url: http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip # url to download the npz. optional +dataset_file_name: ./benchmark_data/aspirin_ccsd-train.npz # path to data set file +key_mapping: + z: atomic_numbers # atomic species, integers + E: total_energy # total potential eneriges to train to + F: forces # atomic forces to train to + R: pos # raw atomic positions +npz_fixed_field_keys: # fields that are repeated across different examples + - atomic_numbers +# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. +chemical_symbol_to_type: + H: 0 + C: 1 + O: 2 +# Alternatively, if the dataset has type indexes, the total number of types is all that is required: +# type_names: +# 0: my_type +# 1: atom +# 2: thing # logging wandb: false diff --git a/configs/minimal_eng.yaml b/configs/minimal_eng.yaml index be1a27c7..0b4ad0d3 100644 --- a/configs/minimal_eng.yaml +++ b/configs/minimal_eng.yaml @@ -4,15 +4,41 @@ run_name: minimal_eng seed: 0 # network -model_builder: nequip.models.EnergyModel +model_builders: + - EnergyModel + - PerSpeciesRescale + - RescaleEnergyEtc num_basis: 8 r_max: 4.0 irreps_edge_sh: 0e + 1o conv_to_output_hidden_irreps_out: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e feature_irreps_hidden: 16x0o + 16x0e -# data -dataset: aspirin +# data set +# the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or npz_keys +# key_mapping is used to map the key in the npz file to the NequIP default values (see data/_key.py) +# all arrays are expected to have the shape of (nframe, natom, ?) except the fixed fields +# note that if your data set uses pbc, you need to also pass an array that maps to the nequip "pbc" key +dataset: npz # type of data set, can be npz or ase +dataset_url: http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip # url to download the npz. optional +dataset_file_name: ./benchmark_data/aspirin_ccsd-train.npz # path to data set file +key_mapping: + z: atomic_numbers # atomic species, integers + E: total_energy # total potential eneriges to train to + F: forces # atomic forces to train to + R: pos # raw atomic positions +npz_fixed_field_keys: # fields that are repeated across different examples + - atomic_numbers +# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. +chemical_symbol_to_type: + H: 0 + C: 1 + O: 2 +# Alternatively, if the dataset has type indexes, the total number of types is all that is required: +# type_names: +# 0: my_type +# 1: atom +# 2: thing # logging wandb: false diff --git a/configs/requeue.yaml b/configs/requeue.yaml deleted file mode 100644 index 446157c3..00000000 --- a/configs/requeue.yaml +++ /dev/null @@ -1,34 +0,0 @@ -# general -root: results/ -wandb_project: aspirin -workdir: results/requeue -run_name: minimal_requeue -requeue: True - -# network -num_basis: 8 -r_max: 4.0 -irreps_edge_sh: 0e + 1o -conv_to_output_hidden_irreps_out: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e -feature_irreps_hidden: 16x0o + 16x0e - -# data -dataset: aspirin -dataset_file_name: benchmark_data/aspirin_ccsd-train.npz - -# logging -wandb: false -wandb_project: aspirin -# verbose: debug - -# training -n_train: 5 -n_val: 5 -batch_size: 1 -max_epochs: 10 - -# loss function -loss_coeffs: forces - -# optimizer -optimizer_name: Adam diff --git a/configs/restart.yaml b/configs/restart.yaml deleted file mode 100644 index 2e669eaa..00000000 --- a/configs/restart.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# # python script/restart.py xxx/trainer.pth config/restart.yaml(optional) - -# general -root: results/aspirin -wandb_project: aspirin -run_name: new_minimal -seed: 0 -restart: true -append: false -verbose: debug diff --git a/docs/api/nequip.rst b/docs/api/nequip.rst index 6b44348c..13bc37ca 100644 --- a/docs/api/nequip.rst +++ b/docs/api/nequip.rst @@ -1,4 +1,4 @@ -nequip API +Python API ========== .. toctree:: diff --git a/docs/guide/FAQ.rst b/docs/guide/FAQ.rst new file mode 100644 index 00000000..92ac758e --- /dev/null +++ b/docs/guide/FAQ.rst @@ -0,0 +1,21 @@ +FAQ +=== + +How do I... +----------- + +... continue to train a model that reached a stopping condition? + There will be an answer here. + +1. Reload the model trained with version 0.3.3 to the code in 0.4. + check out the migration note at :ref:`migration_note`. + +Common errors +------------- + +Various shape errors + Check the sanity of the shapes in your dataset. + +Out-of-memory errors with `nequip-evaluate` + Choose a lower ``--batch-size``; while the highest value that fits in your GPU memory is good for performance, + lowering this does *not* affect the final results (beyond numerics). diff --git a/docs/guide/conventions.rst b/docs/guide/conventions.rst index fe609116..f4679a76 100644 --- a/docs/guide/conventions.rst +++ b/docs/guide/conventions.rst @@ -1,4 +1,5 @@ Conventions =========== - - Cells vectors are given in ASE style as the **rows** of the cell matrix \ No newline at end of file + - Cells vectors are given in ASE style as the **rows** of the cell matrix + - The first index in an edge tuple (``edge_index[0]``) is the center atom, and the second (``edge_index[1]``) is the neighbor \ No newline at end of file diff --git a/docs/guide/guide.rst b/docs/guide/guide.rst index e8a19a77..6def3859 100644 --- a/docs/guide/guide.rst +++ b/docs/guide/guide.rst @@ -4,4 +4,6 @@ NequIP User Guide .. toctree:: intro - conventions \ No newline at end of file + irreps + conventions + FAQ \ No newline at end of file diff --git a/docs/guide/irreps.rst b/docs/guide/irreps.rst new file mode 100644 index 00000000..5f9b2735 --- /dev/null +++ b/docs/guide/irreps.rst @@ -0,0 +1,9 @@ +Irreps +====== + +.. _Irreps: + +Syntax to specify irreps +------------------------ + +TODO: descripe irreps syntax here \ No newline at end of file diff --git a/docs/guide/migrate.rst b/docs/guide/migrate.rst new file mode 100644 index 00000000..b0225dee --- /dev/null +++ b/docs/guide/migrate.rst @@ -0,0 +1,168 @@ +.. _migration_note: + +How to migrate to newer NequIP versions +======================================= + +(Written for migration from 0.3.3 to 0.4. Nov. 3. 2021) + +If the model are mostly the same and there is only some internal +variable changes, it is possible to migrate your NequIP model from the +older version to the newer version. + +Upgrade NequIP +-------------- + +1. Record the old version +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Go to the code folder in your virtual environment, find out the last +commit that you are using + +.. code:: bash + + # bash code + NEQUIP_FOLDER=$(python -c "import nequip; print(\"/\".join(nequip.__file__.split(\"/\")[:-1]))") + cd ${NEQUIP_FOLDER} + pwd + git show --oneline -s + OLD_COMMIT=$(git show --oneline -s|awk '{print $1}') + +2. Update your main nequip repo +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: bash + + # bash code + git pull origin main + pip install -e ./ + +Obtain the state_dict from the old ver +-------------------------------------- + +For version before 0.3.3, the ``last_model.pth`` stores the whole pickle +model. So you need to save the ``state_dict()``; otherwise, skip this +section. + +1. Back up the old version +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Git clone the old commit to a new folder + +.. code:: bash + + # bash code + git clone git@github.com:mir-group/nequip.git -n old_nequip + cd old_nequip + git checkout ${OLD_COMMIT} + +2. Save the state_dict from the old verion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Go to the old_nequip folder, make sure that your current nequip is +overloaded by local nequip folder. The result of the code below should +show the old_nequip folder instead of the one usually used in the +virtualenv. + +.. code:: python + + #python + import nequip + print(nequip.__file__) + +Load the old model with the old verion in python. + +.. code:: python + + # save_state_dict.py + import torch + import sys + model_folder = sys.argv[1] + old_model=torch.load( + f"{model_folder}/last_model.pth", + map_location=torch.device('cpu') # if it operates on CPU + ) + torch.save(old_model.state_dict(), f"{model_folder}/new_last_model.pth") + +Load the state_dict in the new version +-------------------------------------- + +Go to any other directorys that are not in the old version nequip +folder. + +Double check now the ``nequip.__file__`` should locate at the +``${NEQUIP_FOLDER}`` + +Then try to load the old ``state_dict()`` to the new model. + +.. code:: python + + # in new nequip + import torch + from nequip.utils import Config + from nequip.model import model_from_config + + config = Config.from_file("config_final.yaml") + + # only needed for version 0.3.3 + config["train_on_keys"]=["forces", "total_energy"] + config["model_builders"] = ["EnergyModel", "PerSpeciesRescale", "ForceOutput", "RescaleEnergyEtc"] + + model = model_from_config(config, initialize=False) + + d = torch.load("new.pth") + # load the state dict to the new model + model.load_state_dict(d) + +The code will likely to fail. Render some outputs like below: + +.. code:: bash + + RuntimeError: Error(s) in loading state_dict for RescaleOutput: + Missing key(s) in state_dict: "model.func.per_species_rescale.shifts", "model.func.per_species_rescale.scales". + Unexpected key(s) in state_dict: "model.func.per_species_scale_shift.shifts", "model.func.per_species_scale_shift.scales", "model.func.radial_basis.cutoff.p", "model.func.radial_basis.cutoff.r_max" + +According to this output and the CHANGELOG.md file, we can revise the +dictionary by renaming or removing variables. + +.. code:: python + + # rename all parameters listed in the change log as changed. + d["model.func.per_species_rescale.shifts"]=d.pop("model.func.per_species_scale_shift.shifts") + d["model.func.per_species_rescale.scales"]=d.pop("model.func.per_species_scale_shift.scales") + d.pop("model.func.radial_basis.cutoff.p") + d.pop("model.func.radial_basis.cutoff.r_max") + + # load the state dict to the new model + model.load_state_dict(d) + + # save the new state dict + import nequip + torch.save(model.state_dict(), f"new_last_model_{nequip.__version__}.pth') + +Validate the result using nequip-evaluate +----------------------------------------- + +Old model +~~~~~~~~~ + +.. code:: bash + + python nequip/script/evaluate.py + +New model +~~~~~~~~~ + +.. code:: bash + + nequip-evaluate --train-dir new_model/ --dataset-config data.yaml --output new.xyz + +.. code:: yaml + + root: ./ + r_max: 4 + validation_dataset: ase + validation_dataset_file_name: validate.xyz + chemical_symbol_to_type: + H: 0 + C: 1 + O: 2 diff --git a/docs/index.rst b/docs/index.rst index 3256117e..dc6ecd43 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ NequIP is an open-source package for creating, training, and using E(3)-equivari :caption: Contents: guide/guide + options/options api/nequip diff --git a/docs/options/HOWTO.md b/docs/options/HOWTO.md new file mode 100644 index 00000000..44bc5508 --- /dev/null +++ b/docs/options/HOWTO.md @@ -0,0 +1,32 @@ +Add this code to `auto_init.py`: + +```python +f = open("auto_all_options.rst", "w") + + +def print_option(builder, file): + print(f"!! {builder.__name__}", file=f) + if inspect.isclass(builder): + builder = builder.__init__ + sig = inspect.signature(builder) + for k, v in sig.parameters.items(): + if k == "self": + continue + print(k, file=f) + print(len(k) * "^", file=f) + if v.default == inspect.Parameter.empty: + print(f" | Type:", file=f) + print( + f" | Default: n/a\n", + file=f, + ) + else: + typestr = type(v.default).__name__ + print(f" | Type: {typestr}", file=f) + print( + f" | Default: ``{str(v.default)}``\n", + file=f, + ) +``` + +and call the function in every `instantiate`. \ No newline at end of file diff --git a/docs/options/dataset.rst b/docs/options/dataset.rst new file mode 100644 index 00000000..22cd701e --- /dev/null +++ b/docs/options/dataset.rst @@ -0,0 +1,67 @@ +Dataset +======= + +Basic +----- + +r_max +^^^^^ + See :ref:`r_max_option`. + +type_names +^^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + +chemical_symbol_to_type +^^^^^^^^^^^^^^^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + +avg_num_neighbors +^^^^^^^^^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + +key_mapping +^^^^^^^^^^^ + | Type: dict + | Default: ``{'positions': 'pos', 'energy': 'total_energy', 'force': 'forces', 'forces': 'forces', 'Z': 'atomic_numbers', 'atomic_number': 'atomic_numbers'}`` + +npz_keys +^^^^^^^^ + | Type: list + | Default: ``[]`` + +npz_fixed_field_keys +^^^^^^^^^^^^^^^^^^^^ + | Type: list + | Default: ``[]`` + +file_name +^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + +url +^^^ + | Type: NoneType + | Default: ``None`` + +force_fixed_keys +^^^^^^^^^^^^^^^^ + | Type: list + | Default: ``[]`` + +extra_fixed_fields +^^^^^^^^^^^^^^^^^^ + | Type: dict + | Default: ``{}`` + +include_frames +^^^^^^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + +Advanced +-------- \ No newline at end of file diff --git a/docs/options/general.rst b/docs/options/general.rst new file mode 100644 index 00000000..1b75b6d9 --- /dev/null +++ b/docs/options/general.rst @@ -0,0 +1,28 @@ +General +======= + +Basic +----- + +root +^^^^ + | Type: + | Default: n/a + +run_name +^^^^^^^^ + | Type: path + | Default: n/a + + ``run_name`` specifies something about whatever + +Advanced +-------- + +allow_tf32 +^^^^^^^^^^ + | Type: bool + | Default: ``False`` + + If ``False``, the use of NVIDIA's TensorFloat32 on Tensor Cores (Ampere architecture and later) will be disabled. + If ``True``, the PyTorch defaults (use anywhere possible) will remain. \ No newline at end of file diff --git a/docs/options/logging.rst b/docs/options/logging.rst new file mode 100644 index 00000000..675cdc45 --- /dev/null +++ b/docs/options/logging.rst @@ -0,0 +1,8 @@ +Logging +======= + +Basic +----- + +Advanced +-------- \ No newline at end of file diff --git a/docs/options/model.rst b/docs/options/model.rst new file mode 100644 index 00000000..a123bb80 --- /dev/null +++ b/docs/options/model.rst @@ -0,0 +1,149 @@ +Model +===== + +Edge Basis +********** + +Basic +----- + +.. _r_max_option: + +r_max +^^^^^ + | Type: float + | Default: n/a + + The cutoff radius within which an atom is considered a neighbor. + +irreps_edge_sh +^^^^^^^^^^^^^^ + | Type: :ref:`Irreps` or int + | Default: n/a + + The irreps to use for the spherical harmonic projection of the edges. + If an integer, specifies all spherical harmonics up to and including that integer as :math:`\ell_{\text{max}}`. + If provided as explicit irreps, all multiplicities should be 1. + +num_basis +^^^^^^^^^ + | Type: int + | Default: ``8`` + + The number of radial basis functions to use. + +chemical_embedding_irreps_out +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Type: :ref:`Irreps` + | Default: n/a + + The size of the linear embedding of the chemistry of an atom. + +Advanced +-------- + +BesselBasis_trainable +^^^^^^^^^^^^^^^^^^^^^ + | Type: bool + | Default: ``True`` + + Whether the Bessel radial basis should be trainable. + +basis +^^^^^ + | Type: type + | Default: ```` + + The radial basis to use. + +Convolution +*********** + +Basic +----- + +num_layers +^^^^^^^^^^ + | Type: int + | Default: ``3`` + + The number of convolution layers. + + +feature_irreps_hidden +^^^^^^^^^^^^^^^^^^^^^ + | Type: :ref:`Irreps` + | Default: n/a + + Specifies the irreps and multiplicities of the hidden features. + Typically, include irreps with all :math:`\ell` values up to :math:`\ell_{\text{max}}` (see `irreps_edge_sh`_), each with both even and odd parity. + For example, for ``irreps_edge_sh: 1``, one might provide: ``feature_irreps_hidden: 16x0e + 16x0o + 16x1e + 16x1o``. + +Advanced +-------- + +invariant_layers +^^^^^^^^^^^^^^^^ + | Type: int + | Default: ``1`` + + The number of hidden layers in the radial neural network. + +invariant_neurons +^^^^^^^^^^^^^^^^^ + | Type: int + | Default: ``8`` + + The width of the hidden layers of the radial neural network. + +resnet +^^^^^^ + | Type: bool + | Default: ``True`` + +nonlinearity_type +^^^^^^^^^^^^^^^^^ + | Type: str + | Default: ``gate`` + +nonlinearity_scalars +^^^^^^^^^^^^^^^^^^^^ + | Type: dict + | Default: ``{'e': 'ssp', 'o': 'tanh'}`` + +nonlinearity_gates +^^^^^^^^^^^^^^^^^^ + | Type: dict + | Default: ``{'e': 'ssp', 'o': 'abs'}`` + +use_sc +^^^^^^ + | Type: bool + | Default: ``True`` + +Output block +************ + +Basic +----- + +conv_to_output_hidden_irreps_out +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Type: :ref:`Irreps` + | Default: n/a + + The middle (hidden) irreps of the output block. Should only contain irreps that are contained in the output of the network (``0e`` for potentials). + +Advanced +-------- + + + + + + + + + + + diff --git a/docs/options/options.rst b/docs/options/options.rst new file mode 100644 index 00000000..95ab66ea --- /dev/null +++ b/docs/options/options.rst @@ -0,0 +1,10 @@ +All Options +=========== + + .. toctree:: + + general + dataset + model + training + logging diff --git a/docs/options/training.rst b/docs/options/training.rst new file mode 100644 index 00000000..b8c1711b --- /dev/null +++ b/docs/options/training.rst @@ -0,0 +1,8 @@ +Training +======== + +Basic +----- + +Advanced +-------- \ No newline at end of file diff --git a/examples/monkeypatch.py b/examples/monkeypatch.py new file mode 100644 index 00000000..1234054f --- /dev/null +++ b/examples/monkeypatch.py @@ -0,0 +1,56 @@ +"""Example of patching a model after training to analyze it. + +This file shows how to load a pickled Python model after training +and modify it to save and output the features after the first +convolution for later analysis. +""" + +import torch + +from nequip.utils import Config, find_first_of_type +from nequip.data import AtomicDataDict, AtomicData, dataset_from_config +from nequip.nn import SequentialGraphNetwork, SaveForOutput + +# The path to the original training session +path = "../results/aspirin/minimal" + +# Load the model +model = torch.load(path + "/best_model.pth") + + +# Find the SequentialGraphNetwork, which contains the +# sequential bulk of the NequIP GNN model. To see the +# structure of the GNN models, see +# nequip/models/_eng.py +sgn = find_first_of_type(model, SequentialGraphNetwork) + +# Now insert a SaveForOutput +insert_after = "layer1_convnet" # change this, again see nequip/models/_eng.py +# `insert_from_parameters` builds the module for us from `shared_parameters` +# You could also build it manually and use `insert`, but `insert_from_parameters` +# has the advantage of constructing it with the correct input irreps +# based on whatever comes before: +sgn.insert_from_parameters( + after=insert_after, + name="feature_extractor", + shared_params=dict( + field=AtomicDataDict.NODE_FEATURES_KEY, + out_field="saved", + ), + builder=SaveForOutput, +) + +# Now, we can test our patched model: +# Load the original config file --- this could be a new one too: +config = Config.from_file(path + "/config.yaml") +# Load the dataset: +# (Note that this loads the training dataset if there are separate training and validation datasets defined.) +dataset = dataset_from_config(config) + +# Evaluate the model on a configuration: +data = dataset[0] +out = sgn(AtomicData.to_AtomicDataDict(data)) + +# Check that our extracted data is there: +assert "saved" in out +print(out["saved"].shape) diff --git a/nequip/_version.py b/nequip/_version.py index 93a83158..edbb9d1b 100644 --- a/nequip/_version.py +++ b/nequip/_version.py @@ -2,4 +2,4 @@ # See Python packaging guide # https://packaging.python.org/guides/single-sourcing-package-version/ -__version__ = "0.3.3" +__version__ = "0.5.0" diff --git a/nequip/ase/__init__.py b/nequip/ase/__init__.py new file mode 100644 index 00000000..df5a0666 --- /dev/null +++ b/nequip/ase/__init__.py @@ -0,0 +1,4 @@ +from .nequip_calculator import NequIPCalculator +from .nosehoover import NoseHoover + +__all__ = [NequIPCalculator, NoseHoover] diff --git a/nequip/dynamics/nequip_calculator.py b/nequip/ase/nequip_calculator.py similarity index 50% rename from nequip/dynamics/nequip_calculator.py rename to nequip/ase/nequip_calculator.py index 3e58ccfd..39723dfe 100644 --- a/nequip/dynamics/nequip_calculator.py +++ b/nequip/ase/nequip_calculator.py @@ -1,14 +1,23 @@ -from typing import Union +from typing import Union, Optional, Callable, Dict +import warnings import torch +import ase.data from ase.calculators.calculator import Calculator, all_changes from nequip.data import AtomicData, AtomicDataDict +from nequip.data.transforms import TypeMapper import nequip.scripts.deploy class NequIPCalculator(Calculator): - """NequIP ASE Calculator.""" + """NequIP ASE Calculator. + + .. warning:: + + If you are running MD with custom species, please make sure to set the correct masses for ASE. + + """ implemented_properties = ["energy", "forces"] @@ -19,6 +28,7 @@ def __init__( device: Union[str, torch.device], energy_units_to_eV: float = 1.0, length_units_to_A: float = 1.0, + transform: Callable = lambda x: x, **kwargs ): Calculator.__init__(self, **kwargs) @@ -28,19 +38,51 @@ def __init__( self.device = device self.energy_units_to_eV = energy_units_to_eV self.length_units_to_A = length_units_to_A + self.transform = transform @classmethod def from_deployed_model( - cls, model_path, device: Union[str, torch.device] = "cpu", **kwargs + cls, + model_path, + device: Union[str, torch.device] = "cpu", + species_to_type_name: Optional[Dict[str, str]] = None, + set_global_options: Union[str, bool] = "warn", + **kwargs ): # load model model, metadata = nequip.scripts.deploy.load_deployed_model( - model_path=model_path, device=device + model_path=model_path, + device=device, + set_global_options=set_global_options, ) r_max = float(metadata[nequip.scripts.deploy.R_MAX_KEY]) + # build typemapper + type_names = metadata[nequip.scripts.deploy.TYPE_NAMES_KEY].split(" ") + if species_to_type_name is None: + # Default to species names + warnings.warn( + "Trying to use chemical symbols as NequIP type names; this may not be correct for your model! To avoid this warning, please provide `species_to_type_name` explicitly." + ) + species_to_type_name = {s: s for s in ase.data.chemical_symbols} + type_name_to_index = {n: i for i, n in enumerate(type_names)} + chemical_symbol_to_type = { + sym: type_name_to_index[species_to_type_name[sym]] + for sym in ase.data.chemical_symbols + if sym in type_name_to_index + } + if len(chemical_symbol_to_type) != len(type_names): + raise ValueError( + "The default mapping of chemical symbols as type names didn't make sense; please provide an explicit mapping in `species_to_type_name`" + ) + transform = TypeMapper(chemical_symbol_to_type=chemical_symbol_to_type) + # build nequip calculator - return cls(model=model, r_max=r_max, device=device, **kwargs) + if "transform" in kwargs: + raise TypeError("`transform` not allowed here") + return cls( + model=model, r_max=r_max, device=device, transform=transform, **kwargs + ) def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): """ @@ -56,6 +98,7 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change # prepare data data = AtomicData.from_ase(atoms=atoms, r_max=self.r_max) + data = self.transform(data) data = data.to(self.device) diff --git a/nequip/dynamics/nosehoover.py b/nequip/ase/nosehoover.py similarity index 100% rename from nequip/dynamics/nosehoover.py rename to nequip/ase/nosehoover.py diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 252e5b72..81a2978d 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -5,24 +5,94 @@ import warnings from copy import deepcopy -from typing import Union, Tuple, Dict, Optional, List +from typing import Union, Tuple, Dict, Optional, List, Set, Sequence from collections.abc import Mapping import numpy as np import ase.neighborlist import ase from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator +from ase.calculators.calculator import all_properties as ase_all_properties import torch -from torch_geometric.data import Data import e3nn.o3 from . import AtomicDataDict from ._util import _TORCH_INTEGER_DTYPES +from nequip.utils.torch_geometric import Data # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] +_DEFAULT_NODE_FIELDS: Set[str] = { + AtomicDataDict.POSITIONS_KEY, + AtomicDataDict.WEIGHTS_KEY, + AtomicDataDict.NODE_FEATURES_KEY, + AtomicDataDict.NODE_ATTRS_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.FORCE_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.BATCH_KEY, +} +_DEFAULT_EDGE_FIELDS: Set[str] = { + AtomicDataDict.EDGE_CELL_SHIFT_KEY, + AtomicDataDict.EDGE_VECTORS_KEY, + AtomicDataDict.EDGE_LENGTH_KEY, + AtomicDataDict.EDGE_ATTRS_KEY, + AtomicDataDict.EDGE_EMBEDDING_KEY, +} +_DEFAULT_GRAPH_FIELDS: Set[str] = { + AtomicDataDict.TOTAL_ENERGY_KEY, +} +_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) +_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) +_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) + + +def register_fields( + node_fields: Sequence[str] = [], + edge_fields: Sequence[str] = [], + graph_fields: Sequence[str] = [], +) -> None: + r"""Register fields as being per-atom, per-edge, or per-frame. + + Args: + node_permute_fields: fields that are equivariant to node permutations. + edge_permute_fields: fields that are equivariant to edge permutations. + """ + node_fields: set = set(node_fields) + edge_fields: set = set(edge_fields) + graph_fields: set = set(graph_fields) + allfields = node_fields.union(edge_fields, graph_fields) + assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) + _NODE_FIELDS.update(node_fields) + _EDGE_FIELDS.update(edge_fields) + _GRAPH_FIELDS.update(graph_fields) + if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( + len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) + ): + raise ValueError( + "At least one key was registered as more than one of node, edge, or graph!" + ) + + +def deregister_fields(*fields: Sequence[str]) -> None: + r"""Deregister a field registered with ``register_fields``. + + Silently ignores fields that were never registered to begin with. + + Args: + *fields: fields to deregister. + """ + for f in fields: + assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field" + _NODE_FIELDS.discard(f) + _EDGE_FIELDS.discard(f) + _GRAPH_FIELDS.discard(f) + class AtomicData(Data): """A neighbor graph for points in (periodic triclinic) real space. @@ -50,7 +120,7 @@ class AtomicData(Data): node_attrs (Tensor [n_atom, ...]): the attributes of the nodes, for instance the atom type, optional batch (Tensor [n_atom]): the graph to which the node belongs, optional atomic_numbers (Tensor [n_atom]): optional. - species_index (Tensor [n_atom]): optional. + atom_type (Tensor [n_atom]): optional. **kwargs: other data, optional. """ @@ -68,7 +138,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): if ( k == AtomicDataDict.EDGE_INDEX_KEY or k == AtomicDataDict.ATOMIC_NUMBERS_KEY - or k == AtomicDataDict.SPECIES_INDEX_KEY + or k == AtomicDataDict.ATOM_TYPE_KEY or k == AtomicDataDict.BATCH_KEY ): # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) @@ -82,13 +152,42 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): elif np.issubdtype(type(v), np.floating): # Force scalars to be tensors with a data dimension # This makes them play well with irreps - kwargs[k] = torch.as_tensor( - v, dtype=torch.get_default_dtype() - ).unsqueeze(-1) + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) elif isinstance(v, torch.Tensor) and len(v.shape) == 0: # ^ this tensor is a scalar; we need to give it # a data dimension to play nice with irreps + kwargs[k] = v + + if AtomicDataDict.BATCH_KEY in kwargs: + num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 + else: + num_frames = 1 + + for k, v in kwargs.items(): + + if len(kwargs[k].shape) == 0: kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if ( + k in _NODE_FIELDS + and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] + ): + raise ValueError( + f"{k} is a node field but has the wrong dimension {v.shape}" + ) + elif ( + k in _EDGE_FIELDS + and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is a edge field but has the wrong dimension {v.shape}" + ) + elif k in _GRAPH_FIELDS: + if num_frames == 1 and v.shape[0] != 1: + kwargs[k] = v.unsqueeze(0) + elif v.shape[0] != num_frames: + raise ValueError(f"Wrong shape for graph property {k}") super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) @@ -154,8 +253,7 @@ def from_points( strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the two instances of the atom are in different periodic images. Defaults to True, should be True for most applications. - **kwargs (optional): other information to pass to the ``torch_geometric.data.Data`` constructor for - inclusion in the object. Keys listed in ``AtomicDataDict.*_KEY` will be treated specially. + **kwargs (optional): other fields to add. Keys listed in ``AtomicDataDict.*_KEY` will be treated specially. """ if pos is None or r_max is None: raise ValueError("pos and r_max must be given.") @@ -196,7 +294,14 @@ def from_points( return cls(edge_index=edge_index, pos=torch.as_tensor(pos), **kwargs) @classmethod - def from_ase(cls, atoms, r_max, **kwargs): + def from_ase( + cls, + atoms, + r_max, + key_mapping: Optional[Dict[str, str]] = {}, + include_keys: Optional[list] = [], + **kwargs, + ): """Build a ``AtomicData`` from an ``ase.Atoms`` object. Respects ``atoms``'s ``pbc`` and ``cell``. @@ -211,51 +316,84 @@ def from_ase(cls, atoms, r_max, **kwargs): r_max (float): neighbor cutoff radius. features (torch.Tensor shape [N, M], optional): per-atom M-dimensional feature vectors. If ``None`` (the default), uses a one-hot encoding of the species present in ``atoms``. + include_keys (list): list of additional keys to include in AtomicData aside from the ones defined in + ase.calculators.calculator.all_properties. Optional + key_mapping (dict): rename ase property name to a new string name. Optional **kwargs (optional): other arguments for the ``AtomicData`` constructor. Returns: A ``AtomicData``. """ + from nequip.ase import NequIPCalculator + + assert "pos" not in kwargs + + default_args = set( + [ + "numbers", + "positions", + ] # ase internal names for position and atomic_numbers + + ["pbc", "cell", "pos", "r_max"] # arguments for from_points method + + list(kwargs.keys()) + ) + # the keys that are duplicated in kwargs are removed from the include_keys + include_keys = list(set(include_keys + ase_all_properties) - default_args) + + km = { + "forces": AtomicDataDict.FORCE_KEY, + "energy": AtomicDataDict.TOTAL_ENERGY_KEY, + } + km.update(key_mapping) + key_mapping = km + add_fields = {} + + # Get info from atoms.arrays; lowest priority. copy first + add_fields = { + key_mapping.get(k, k): v + for k, v in atoms.arrays.items() + if k in include_keys + } + + # Get info from atoms.info; second lowest priority. + add_fields.update( + { + key_mapping.get(k, k): v + for k, v in atoms.info.items() + if k in include_keys + } + ) + if atoms.calc is not None: + if isinstance( atoms.calc, (SinglePointCalculator, SinglePointDFTCalculator) ): - add_fields = deepcopy(atoms.calc.results) - if "forces" in add_fields: - add_fields.pop("forces") - add_fields[AtomicDataDict.FORCE_KEY] = atoms.get_forces() - - if "free_energy" in add_fields and "energy" not in add_fields: - add_fields[AtomicDataDict.TOTAL_ENERGY_KEY] = add_fields.pop( - "free_energy" - ) - elif "energy" in add_fields: - add_fields[AtomicDataDict.TOTAL_ENERGY_KEY] = add_fields.pop( - "energy" - ) - - if AtomicDataDict.FORCE_KEY not in add_fields: - # Get it from arrays - for k in ("force", "forces"): - if k in atoms.arrays: - add_fields[AtomicDataDict.FORCE_KEY] = atoms.arrays[k] - break - - if AtomicDataDict.TOTAL_ENERGY_KEY not in add_fields: - # Get it from arrays - for k in ("energy", "energies"): - if k in atoms.arrays: - add_fields[AtomicDataDict.TOTAL_ENERGY_KEY] = atoms.arrays[k] - break + add_fields.update( + { + key_mapping.get(k, k): deepcopy(v) + for k, v in atoms.calc.results.items() + if k in include_keys + } + ) + elif isinstance(atoms.calc, NequIPCalculator): + pass # otherwise the calculator breaks + else: + raise NotImplementedError( + f"`from_ase` does not support calculator {atoms.calc}" + ) add_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = atoms.get_atomic_numbers() + # cell and pbc in kwargs can override the ones stored in atoms + cell = kwargs.pop("cell", atoms.get_cell()) + pbc = kwargs.pop("pbc", atoms.pbc) + return cls.from_points( pos=atoms.positions, r_max=r_max, - cell=atoms.get_cell(), - pbc=atoms.pbc, + cell=cell, + pbc=pbc, **kwargs, **add_fields, ) @@ -276,7 +414,13 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: raise TypeError( "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`." ) - atomic_nums = self.atomic_numbers + if AtomicDataDict.ATOMIC_NUMBERS_KEY in self: + atomic_nums = self.atomic_numbers + else: + warnings.warn( + "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong" + ) + atomic_nums = self[AtomicDataDict.ATOM_TYPE_KEY] pbc = getattr(self, AtomicDataDict.PBC_KEY, None) cell = getattr(self, AtomicDataDict.CELL_KEY, None) batch = getattr(self, AtomicDataDict.BATCH_KEY, None) diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index 598b3ff1..069f8cff 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -30,11 +30,6 @@ def validate_keys(keys, graph_required=True): raise KeyError("At least pos and edge_index must be supplied") if _keys.EDGE_CELL_SHIFT_KEY in keys and "cell" not in keys: raise ValueError("If `edge_cell_shift` given, `cell` must be given.") - # This is in flux; TODO - # if _keys.ATOMIC_NUMBERS_KEY in keys and _keys.SPECIES_INDEX_KEY in keys: - # raise ValueError( - # "'atomic_numbers' and 'species_index' cannot be simultaneously provided" - # ) _SPECIAL_IRREPS = [None] diff --git a/nequip/data/__init__.py b/nequip/data/__init__.py index 3fc8a7af..37dd219d 100644 --- a/nequip/data/__init__.py +++ b/nequip/data/__init__.py @@ -1,3 +1,29 @@ -from .AtomicData import AtomicData, PBC +from .AtomicData import ( + AtomicData, + PBC, + register_fields, + deregister_fields, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, +) from .dataset import AtomicDataset, AtomicInMemoryDataset, NpzDataset, ASEDataset from .dataloader import DataLoader, Collater +from ._build import dataset_from_config + +__all__ = [ + AtomicData, + PBC, + register_fields, + deregister_fields, + AtomicDataset, + AtomicInMemoryDataset, + NpzDataset, + ASEDataset, + DataLoader, + Collater, + dataset_from_config, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, +] diff --git a/nequip/data/_build.py b/nequip/data/_build.py new file mode 100644 index 00000000..670b6156 --- /dev/null +++ b/nequip/data/_build.py @@ -0,0 +1,88 @@ +import inspect +from importlib import import_module + +from nequip import data +from nequip.data.transforms import TypeMapper +from nequip.data import AtomicDataset, register_fields +from nequip.utils import instantiate, get_w_prefix + + +def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: + """initialize database based on a config instance + + It needs dataset type name (case insensitive), + and all the parameters needed in the constructor. + + Examples see tests/data/test_dataset.py TestFromConfig + and tests/datasets/test_simplest.py + + Args: + + config (dict, nequip.utils.Config): dict/object that store all the parameters + prefix (str): Optional. The prefix of all dataset parameters + + Return: + + dataset (nequip.data.AtomicDataset) + """ + + config_dataset = config.get(prefix, None) + if config_dataset is None: + raise KeyError(f"Dataset with prefix `{prefix}` isn't present in this config!") + + if inspect.isclass(config_dataset): + # user define class + class_name = config_dataset + else: + try: + module_name = ".".join(config_dataset.split(".")[:-1]) + class_name = ".".join(config_dataset.split(".")[-1:]) + class_name = getattr(import_module(module_name), class_name) + except Exception: + # ^ TODO: don't catch all Exception + # default class defined in nequip.data or nequip.dataset + dataset_name = config_dataset.lower() + + class_name = None + for k, v in inspect.getmembers(data, inspect.isclass): + if k.endswith("Dataset"): + if k.lower() == dataset_name: + class_name = v + if k[:-7].lower() == dataset_name: + class_name = v + elif k.lower() == dataset_name: + class_name = v + + if class_name is None: + raise NameError(f"dataset type {dataset_name} does not exists") + + # if dataset r_max is not found, use the universal r_max + eff_key = "extra_fixed_fields" + prefixed_eff_key = f"{prefix}_{eff_key}" + config[prefixed_eff_key] = get_w_prefix( + eff_key, {}, prefix=prefix, arg_dicts=config + ) + config[prefixed_eff_key]["r_max"] = get_w_prefix( + "r_max", + prefix=prefix, + arg_dicts=[config[prefixed_eff_key], config], + ) + + # Build a TypeMapper from the config + type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) + + # Register fields: + register_fields( + node_fields=config.get("node_fields", []), + edge_fields=config.get("edge_fields", []), + graph_fields=config.get("graph_fields", []), + ) + + instance, _ = instantiate( + class_name, + prefix=prefix, + positional_args={"type_mapper": type_mapper}, + optional_args=config, + ) + + return instance diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 13ff7c82..58918657 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -34,7 +34,7 @@ NODE_FEATURES_KEY: Final[str] = "node_features" NODE_ATTRS_KEY: Final[str] = "node_attrs" ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers" -SPECIES_INDEX_KEY: Final[str] = "species_index" +ATOM_TYPE_KEY: Final[str] = "atom_types" PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy" TOTAL_ENERGY_KEY: Final[str] = "total_energy" FORCE_KEY: Final[str] = "forces" diff --git a/nequip/data/dataloader.py b/nequip/data/dataloader.py index 6d1debb7..a6c16670 100644 --- a/nequip/data/dataloader.py +++ b/nequip/data/dataloader.py @@ -2,23 +2,42 @@ import torch -from torch_geometric.data import Batch, Data +from nequip.utils.torch_geometric import Batch, Data class Collater(object): - def __init__(self, fixed_fields=[], exclude_keys=[]): + """Collate a list of ``AtomicData``. + + Args: + fixed_fields: which fields are fixed fields + exclude_keys: keys to ignore in the input, not copying to the output + """ + + def __init__( + self, + fixed_fields: List[str] = [], + exclude_keys: List[str] = [], + ): self.fixed_fields = fixed_fields - self.exclude_keys = exclude_keys self._exclude_keys = set(exclude_keys) @classmethod - def for_dataset(cls, dataset, exclude_keys=[]): + def for_dataset( + cls, + dataset, + exclude_keys: List[str] = [], + ): + """Construct a collater appropriate to ``dataset``. + + All kwargs besides ``fixed_fields`` are passed through to the constructor. + """ return cls( fixed_fields=list(getattr(dataset, "fixed_fields", {}).keys()), exclude_keys=exclude_keys, ) - def collate(self, batch: List[Data]): + def collate(self, batch: List[Data]) -> Batch: + """Collate a list of data""" # For fixed fields, we need to batch those that are per-node or # per-edge, since they need to be repeated in order to have the same # number of nodes/edges as the full batch graph. @@ -39,12 +58,24 @@ def collate(self, batch: List[Data]): out[f] = batch[0][f] return out - def __call__(self, batch): + def __call__(self, batch: List[Data]) -> Batch: + """Collate a list of data""" return self.collate(batch) + @property + def exclude_keys(self): + return list(self._exclude_keys) + class DataLoader(torch.utils.data.DataLoader): - def __init__(self, dataset, batch_size=1, shuffle=False, exclude_keys=[], **kwargs): + def __init__( + self, + dataset, + batch_size: int = 1, + shuffle: bool = False, + exclude_keys: List[str] = [], + **kwargs, + ): if "collate_fn" in kwargs: del kwargs["collate_fn"] @@ -53,5 +84,5 @@ def __init__(self, dataset, batch_size=1, shuffle=False, exclude_keys=[], **kwar batch_size, shuffle, collate_fn=Collater.for_dataset(dataset, exclude_keys=exclude_keys), - **kwargs + **kwargs, ) diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index a0bcc446..a83d65c9 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -1,22 +1,33 @@ -""" -Dataset classes that parse array of positions, cells to AtomicData object - -This module requre the torch_geometric to catch up with the github main branch from Jan. 18, 2021 - -""" import numpy as np import logging import tempfile +import inspect +import yaml +import hashlib from os.path import dirname, basename, abspath from typing import Tuple, Dict, Any, List, Callable, Union, Optional, Sequence import ase import torch -from torch_geometric.data import Batch, Dataset, download_url, extract_zip -from nequip.data import AtomicData, AtomicDataDict -from ._util import _TORCH_INTEGER_DTYPES +from torch_runstats.scatter import scatter_std, scatter_mean + +from nequip.utils.torch_geometric import Batch, Dataset +from nequip.utils.torch_geometric.utils import download_url, extract_zip + +import nequip +from nequip.data import ( + AtomicData, + AtomicDataDict, + _NODE_FIELDS, + _EDGE_FIELDS, + _GRAPH_FIELDS, +) +from nequip.utils.batch_ops import bincount +from nequip.utils.regressor import solver +from nequip.utils.savenload import atomic_write +from .transforms import TypeMapper class AtomicDataset(Dataset): @@ -26,32 +37,23 @@ class AtomicDataset(Dataset): root: str def statistics( - self, fields: List[Union[str, Callable]], stride: int = 1, unbiased: bool = True + self, + fields: List[Union[str, Callable]], + modes: List[str], + stride: int = 1, + unbiased: bool = True, + kwargs: Optional[Dict[str, dict]] = {}, ) -> List[tuple]: - """Compute the statistics of ``fields`` in the dataset. - - If the values at the fields are vectors/multidimensional, they must be of fixed shape and elementwise statistics will be computed. - - Args: - fields: the names of the fields to compute statistics for. - Instead of a field name, a callable can also be given that reuturns a quantity to compute the statisics for. - - If a callable is given, it will be called with a (possibly batched) ``Data`` object and must return a sequence of points to add to the set over which the statistics will be computed. - - For example, to compute the overall statistics of the x,y, and z components of a per-node vector ``force`` field: - - data.statistics([lambda data: data.force.flatten()]) - - The above computes the statistics over a set of size 3N, where N is the total number of nodes in the dataset. - - Returns: - List of statistics. For fields of floating dtype the statistics are the two-tuple (mean, std); for fields of integer dtype the statistics are a one-tuple (bincounts,) - """ # TODO: If needed, this can eventually be implimented for general AtomicDataset by computing an online running mean and using Welford's method for a stable running standard deviation: https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/ # That would be needed if we have lazy loading datasets. # TODO: When lazy-loading datasets are implimented, how to deal with statistics, sampling, and subsets? raise NotImplementedError("not implimented for general AtomicDataset yet") + @property + def type_mapper(self) -> Optional[TypeMapper]: + # self.transform is always a TypeMapper + return self.transform + class AtomicInMemoryDataset(AtomicDataset): r"""Base class for all datasets that fit in memory. @@ -74,6 +76,7 @@ class AtomicInMemoryDataset(AtomicDataset): force_fixed_keys (list, optional): keys to move from AtomicData to fixed_fields dictionary extra_fixed_fields (dict, optional): extra key that are not stored in data but needed for AtomicData initialization include_frames (list, optional): the frames to process with the constructor. + type_mapper (TypeMapper): the transformation to map atomic information to species index. Optional """ def __init__( @@ -84,8 +87,9 @@ def __init__( force_fixed_keys: List[str] = [], extra_fixed_fields: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, ): - # TO DO, this may be symplified + # TO DO, this may be simplified # See if a subclass defines some inputs self.file_name = ( getattr(type(self), "FILE_NAME", None) if file_name is None else file_name @@ -115,7 +119,7 @@ def __init__( # Initialize the InMemoryDataset, which runs download and process # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets # Then pre-process the data if disk files are not found - super().__init__(root=root) + super().__init__(root=root, transform=type_mapper) if self.data is None: self.data, self.fixed_fields, include_frames = torch.load( self.processed_paths[0] @@ -135,14 +139,40 @@ def len(self): def raw_file_names(self): raise NotImplementedError() + def _get_parameters(self) -> Dict[str, Any]: + """Get a dict of the parameters used to build this dataset.""" + pnames = list(inspect.signature(self.__init__).parameters) + IGNORE_KEYS = { + # the type mapper is applied after saving, not before, so doesn't matter to cache validity + "type_mapper" + } + params = { + k: getattr(self, k) + for k in pnames + if k not in IGNORE_KEYS and hasattr(self, k) + } + # Add other relevant metadata: + params["dtype"] = str(torch.get_default_dtype()) + params["nequip_version"] = nequip.__version__ + return params + @property - def processed_file_names(self): - # TO DO, can be updated to hash all simple terms in extra_fixed_fields - r_max = self.extra_fixed_fields["r_max"] - dtype = str(torch.get_default_dtype()) - if dtype.startswith("torch."): - dtype = dtype[len("torch.") :] - return [f"{r_max}_{dtype}_data.pt"] + def processed_dir(self) -> str: + # We want the file name to change when the parameters change + # So, first we get all parameters: + params = self._get_parameters() + # Make some kind of string of them: + # we don't care about this possibly changing between python versions, + # since a change in python version almost certainly means a change in + # versions of other things too, and is a good reason to recompute + buffer = yaml.dump(params).encode("ascii") + # And hash it: + param_hash = hashlib.sha1(buffer).hexdigest() + return f"{self.root}/processed_dataset_{param_hash}" + + @property + def processed_file_names(self) -> List[str]: + return ["data.pth", "params.yaml"] def get_data( self, @@ -265,7 +295,19 @@ def process(self): logging.info(f"Loaded data: {data}") - torch.save((data, fixed_fields, self.include_frames), self.processed_paths[0]) + # use atomic writes to avoid race conditions between + # different trainings that use the same dataset + # since those separate trainings should all produce the same results, + # it doesn't matter if they overwrite each others cached' + # datasets. It only matters that they don't simultaneously try + # to write the _same_ file, corrupting it. + with atomic_write(self.processed_paths[0]) as tmppth: + torch.save((data, fixed_fields, self.include_frames), tmppth) + with atomic_write(self.processed_paths[1]) as tmppth: + with open(tmppth, "w") as f: + yaml.dump(self._get_parameters(), f) + + logging.info("Cached processed data to disk") self.data = data self.fixed_fields = fixed_fields @@ -280,99 +322,321 @@ def get(self, idx): def statistics( self, fields: List[Union[str, Callable]], + modes: List[str], stride: int = 1, unbiased: bool = True, - modes: Optional[List[Union[str]]] = None, + kwargs: Optional[Dict[str, dict]] = {}, ) -> List[tuple]: + """Compute the statistics of ``fields`` in the dataset. + + If the values at the fields are vectors/multidimensional, they must be of fixed shape and elementwise statistics will be computed. + + Args: + fields: the names of the fields to compute statistics for. + Instead of a field name, a callable can also be given that reuturns a quantity to compute the statisics for. + + If a callable is given, it will be called with a (possibly batched) ``Data``-like object and must return a sequence of points to add to the set over which the statistics will be computed. + The callable must also return a string, one of ``"node"`` or ``"graph"``, indicating whether the value it returns is a per-node or per-graph quantity. + PLEASE NOTE: the argument to the callable may be "batched", and it may not be batched "contiguously": ``batch`` and ``edge_index`` may have "gaps" in their values. + + For example, to compute the overall statistics of the x,y, and z components of a per-node vector ``force`` field: + + data.statistics([lambda data: (data.force.flatten(), "node")]) + + The above computes the statistics over a set of size 3N, where N is the total number of nodes in the dataset. + + modes: the statistic to compute for each field. Valid options are TODO. + + stride: the stride over the dataset while computing statistcs. + + unbiased: whether to use unbiased for standard deviations. + + kwargs: other options for individual statistics modes. + + Returns: + List of statistics. For fields of floating dtype the statistics are the two-tuple (mean, std); for fields of integer dtype the statistics are a one-tuple (bincounts,) + """ + + # Short circut: + assert len(modes) == len(fields) + if len(fields) == 0: + return [] + if self._indices is not None: - selector = torch.as_tensor(self._indices)[::stride] + graph_selector = torch.as_tensor(self._indices)[::stride] else: - selector = torch.arange(0, self.len(), stride) + graph_selector = torch.arange(0, self.len(), stride) + num_graphs = len(graph_selector) node_selector = torch.as_tensor( - np.in1d(self.data.batch.numpy(), selector.numpy()) + np.in1d(self.data.batch.numpy(), graph_selector.numpy()) ) - # the pure PyTorch alternative to ^ is: - # hack for in1d: https://github.com/pytorch/pytorch/issues/3025#issuecomment-392601780 - # node_selector = (self.data.batch[..., None] == selector).any(-1) - # but this is unnecessary because no backward is done through statistics + num_nodes = node_selector.sum() - if modes is not None: - assert len(modes) == len(fields) + edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY] + edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]] + num_edges = edge_selector.sum() + del edge_index - out = [] - for ifield, field in enumerate(fields): - - if field in self.fixed_fields: - obj = self.fixed_fields - else: - obj = self.data + if self.transform is not None: + # pre-transform the fixed fields and data so that statistics process transformed data + ff_transformed = self.transform(self.fixed_fields, types_required=False) + data_transformed = self.transform(self.data.to_dict(), types_required=False) + else: + ff_transformed = self.fixed_fields + data_transformed = self.data.to_dict() + # pre-select arrays + # this ensures that all following computations use the right data + selectors = {} + for k in list(ff_transformed.keys()) + list(data_transformed.keys()): + if k in _NODE_FIELDS: + selectors[k] = node_selector + elif k in _GRAPH_FIELDS: + selectors[k] = graph_selector + elif k == AtomicDataDict.EDGE_INDEX_KEY: + selectors[k] = (slice(None, None, None), edge_selector) + elif k in _EDGE_FIELDS: + selectors[k] = edge_selector + # TODO: do the batch indexes, edge_indexes, etc. after selection need to be + # "compacted" to subtract out their offsets? For now, we just punt this + # onto the writer of the callable field. + # do not actually select on fixed fields, since they are constant + # but still only select fields that are correctly registered + ff_transformed = {k: v for k, v in ff_transformed.items() if k in selectors} + # apply selector to actual data + data_transformed = { + k: data_transformed[k][selectors[k]] + for k in data_transformed.keys() + if k in selectors + } + atom_types: Optional[torch.Tensor] = None + out: list = [] + for ifield, field in enumerate(fields): if callable(field): - arr = field(obj) - else: - arr = obj[field] - # Apply selector - # TODO: this might be quite expensive if the dataset is large. - # Better to impliment the general running average case in AtomicDataset, - # and just call super here in AtomicInMemoryDataset? - # - # TODO: !!! this is a terrible shape-based hack that needs to be fixed !!! - if len(self.data.batch) == self.data.num_graphs: - raise NotImplementedError( - "AtomicDataset.statistics cannot currently handle datasets whose number of examples is the same as their number of nodes" - ) - if obj is self.fixed_fields: - # arr is fixed, nothing to select. - pass - elif len(arr) == self.data.num_graphs: - # arr is per example (probably) - arr = arr[selector] - elif len(arr) == len(self.data.batch): - # arr is per-node (probably) - arr = arr[node_selector] + # make a joined thing? so it includes fixed fields + arr, arr_is_per = field(data_transformed) + arr = arr.to( + torch.get_default_dtype() + ) # all statistics must be on floating + assert arr_is_per in ("node", "graph", "edge") else: - raise NotImplementedError( - "Statistics of properties that are not per-graph or per-node are not yet implimented" - ) + # Give a better error + if field not in selectors: + # this means field is not selected and so not available + raise RuntimeError( + f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`" + ) + if field in ff_transformed: + arr = ff_transformed[field] + else: + arr = data_transformed[field] + if field in _NODE_FIELDS: + arr_is_per = "node" + elif field in _GRAPH_FIELDS: + arr_is_per = "graph" + elif field in _EDGE_FIELDS: + arr_is_per = "edge" + else: + raise RuntimeError - ana_mode = None if modes is None else modes[ifield] + # Check arr + if arr is None: + raise ValueError( + f"Cannot compute statistics over field `{field}` whose value is None!" + ) if not isinstance(arr, torch.Tensor): if np.issubdtype(arr.dtype, np.floating): arr = torch.as_tensor(arr, dtype=torch.get_default_dtype()) else: arr = torch.as_tensor(arr) - if ana_mode is None: - ana_mode = "count" if arr.dtype in _TORCH_INTEGER_DTYPES else "mean_std" - + if arr_is_per == "node": + arr = arr.view(num_nodes, -1) + elif arr_is_per == "graph": + arr = arr.view(num_graphs, -1) + elif arr_is_per == "edge": + arr = arr.view(num_edges, -1) + + ana_mode = modes[ifield] + # compute statistics if ana_mode == "count": + # count integers uniq, counts = torch.unique( torch.flatten(arr), return_counts=True, sorted=True ) out.append((uniq, counts)) elif ana_mode == "rms": - + # root-mean-square out.append((torch.sqrt(torch.mean(arr * arr)),)) elif ana_mode == "mean_std": - + # mean and std mean = torch.mean(arr, dim=0) std = torch.std(arr, dim=0, unbiased=unbiased) out.append((mean, std)) + elif ana_mode.startswith("per_species_"): + # per-species + algorithm_kwargs = kwargs.pop(field + ana_mode, {}) + + ana_mode = ana_mode[len("per_species_") :] + + if atom_types is None: + if AtomicDataDict.ATOM_TYPE_KEY in data_transformed: + atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY] + elif AtomicDataDict.ATOM_TYPE_KEY in ff_transformed: + atom_types = ff_transformed[AtomicDataDict.ATOM_TYPE_KEY] + atom_types = ( + atom_types.unsqueeze(0) + .expand((num_graphs,) + atom_types.shape) + .reshape(-1) + ) + + results = self._per_species_statistics( + ana_mode, + arr, + arr_is_per=arr_is_per, + batch=data_transformed[AtomicDataDict.BATCH_KEY], + atom_types=atom_types, + unbiased=unbiased, + algorithm_kwargs=algorithm_kwargs, + ) + out.append(results) + + elif ana_mode.startswith("per_atom_"): + # per-atom + # only makes sense for a per-graph quantity + if arr_is_per != "graph": + raise ValueError( + f"It doesn't make sense to ask for `{ana_mode}` since `{field}` is not per-graph" + ) + ana_mode = ana_mode[len("per_atom_") :] + results = self._per_atom_statistics( + ana_mode=ana_mode, + arr=arr, + batch=data_transformed[AtomicDataDict.BATCH_KEY], + unbiased=unbiased, + ) + out.append(results) + + else: + raise NotImplementedError(f"Cannot handle statistics mode {ana_mode}") + return out + @staticmethod + def _per_atom_statistics( + ana_mode: str, + arr: torch.Tensor, + batch: torch.Tensor, + unbiased: bool = True, + ): + """Compute "per-atom" statistics that are normalized by the number of atoms in the system. + + Only makes sense for a graph-level quantity (checked by .statistics). + """ + # using unique_consecutive handles the non-contiguous selected batch index + _, N = torch.unique_consecutive(batch, return_counts=True) + if ana_mode == "mean_std": + arr = arr / N + mean = torch.mean(arr) + std = torch.std(arr, unbiased=unbiased) + return mean, std + elif ana_mode == "rms": + arr = arr / N + return (torch.sqrt(torch.mean(arr.square())),) + else: + raise NotImplementedError( + f"{ana_mode} for per-atom analysis is not implemented" + ) + + @staticmethod + def _per_species_statistics( + ana_mode: str, + arr: torch.Tensor, + arr_is_per: str, + atom_types: torch.Tensor, + batch: torch.Tensor, + unbiased: bool = True, + algorithm_kwargs: Optional[dict] = {}, + ): + """Compute "per-species" statistics. + + For a graph-level quantity, models it as a linear combintation of the number of atoms of different types in the graph. + + For a per-node quantity, computes the expected statistic but for each type instead of over all nodes. + """ + N = bincount(atom_types, batch) + N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes + + if arr_is_per == "graph": + + if ana_mode != "mean_std": + raise NotImplementedError( + f"{ana_mode} for per species analysis is not implemented for shape {arr.shape}" + ) + + N = N.type(torch.get_default_dtype()) + + return solver(N, arr, **algorithm_kwargs) + + elif arr_is_per == "node": + arr = arr.type(torch.get_default_dtype()) + + if ana_mode == "mean_std": + mean = scatter_mean(arr, atom_types, dim=0) + std = scatter_std(arr, atom_types, dim=0, unbiased=unbiased) + return mean, std + elif ana_mode == "rms": + square = scatter_mean(arr.square(), atom_types, dim=0) + dims = len(square.shape) - 1 + for i in range(dims): + square = square.mean(axis=-1) + return (torch.sqrt(square),) + + else: + raise NotImplementedError + # TODO: document fixed field mapped key behavior more clearly class NpzDataset(AtomicInMemoryDataset): """Load data from an npz file. - To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``npz_keys``, or ``npz_fixed_fields``. + To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``, + ``npz_fixed_fields`` or ``extra_fixed_fields``. Args: - file_name (str): file name of the npz file - key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys - force_fixed_keys (list): keys in the npz to treat as fixed quantities that don't change across examples. For example: cell, atomic_numbers + key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional + include_keys (list): the attributes to be processed and stored. Optional + npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional + + Example: Given a npz file with 10 configurations, each with 14 atoms. + + position: (10, 14, 3) + force: (10, 14, 3) + energy: (10,) + Z: (14) + user_label1: (10) # per config + user_label2: (10, 14, 3) # per atom + + The input yaml should be + + ```yaml + dataset: npz + dataset_file_name: example.npz + include_keys: + - user_label1 + - user_label2 + npz_fixed_field_keys: + - cell + - atomic_numbers + key_mapping: + position: pos + force: forces + energy: total_energy + Z: atomic_numbers + ``` + """ def __init__( @@ -386,17 +650,18 @@ def __init__( "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY, "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY, }, - npz_keys: List[str] = [], + include_keys: List[str] = [], npz_fixed_field_keys: List[str] = [], file_name: Optional[str] = None, url: Optional[str] = None, force_fixed_keys: List[str] = [], extra_fixed_fields: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, ): self.key_mapping = key_mapping self.npz_fixed_field_keys = npz_fixed_field_keys - self.npz_keys = npz_keys + self.include_keys = include_keys super().__init__( file_name=file_name, @@ -405,6 +670,7 @@ def __init__( force_fixed_keys=force_fixed_keys, extra_fixed_fields=extra_fixed_fields, include_frames=include_frames, + type_mapper=type_mapper, ) @property @@ -417,17 +683,22 @@ def raw_dir(self): # TODO: fixed fields? def get_data(self): + data = np.load(self.raw_dir + "/" + self.raw_file_names[0], allow_pickle=True) + # only the keys explicitly mentioned in the yaml file will be parsed keys = set(list(self.key_mapping.keys())) keys.update(self.npz_fixed_field_keys) - keys.update(self.npz_keys) + keys.update(self.include_keys) + keys.update(list(self.extra_fixed_fields.keys())) keys = keys.intersection(set(list(data.keys()))) + mapped = {self.key_mapping.get(k, k): data[k] for k in keys} + # TODO: generalize this? for intkey in ( AtomicDataDict.ATOMIC_NUMBERS_KEY, - AtomicDataDict.SPECIES_INDEX_KEY, + AtomicDataDict.ATOM_TYPE_KEY, AtomicDataDict.EDGE_INDEX_KEY, ): if intkey in mapped: @@ -441,9 +712,53 @@ def get_data(self): class ASEDataset(AtomicInMemoryDataset): - """TODO + """ + + Args: + ase_args (dict): arguments for ase.io.read + include_keys (list): in addition to forces and energy, the keys that needs to + be parsed into dataset + The data stored in ase.atoms.Atoms.array has the lowest priority, + and it will be overrided by data in ase.atoms.Atoms.info + and ase.atoms.Atoms.calc.results. Optional + key_mapping (dict): rename some of the keys to the value str. Optional + + Example: Given an atomic data stored in "H2.extxyz" that looks like below: + + ```H2.extxyz + 2 + Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F" + H 0.00000000 0.00000000 0.00000000 + H 0.00000000 0.00000000 1.02000000 + ``` + + The yaml input should be + + ``` + dataset: ase + dataset_file_name: H2.extxyz + ase_args: + format: extxyz + include_keys: + - user_label + key_mapping: + user_label: label0 + chemical_symbol_to_type: + H: 0 + ``` + + for VASP parser, the yaml input should be + ``` + dataset: ase + dataset_file_name: OUTCAR + ase_args: + format: vasp-out + key_mapping: + free_energy: total_energy + chemical_symbol_to_type: + H: 0 + ``` - r_max and an override PBC can be specified in extra_fixed_fields """ def __init__( @@ -455,12 +770,18 @@ def __init__( force_fixed_keys: List[str] = [], extra_fixed_fields: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, + type_mapper: TypeMapper = None, + key_mapping: Optional[dict] = None, + include_keys: Optional[List[str]] = None, ): self.ase_args = dict(index=":") self.ase_args.update(getattr(type(self), "ASE_ARGS", dict())) self.ase_args.update(ase_args) + self.include_keys = include_keys + self.key_mapping = key_mapping + super().__init__( file_name=file_name, url=url, @@ -468,6 +789,7 @@ def __init__( force_fixed_keys=force_fixed_keys, extra_fixed_fields=extra_fixed_fields, include_frames=include_frames, + type_mapper=type_mapper, ) @classmethod @@ -519,19 +841,26 @@ def get_atoms(self): return aseread(self.raw_dir + "/" + self.raw_file_names[0], **self.ase_args) def get_data(self): + # Get our data atoms_list = self.get_atoms() + + # skip the None arguments + kwargs = dict( + include_keys=self.include_keys, + key_mapping=self.key_mapping, + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + kwargs.update(self.extra_fixed_fields) + if self.include_frames is None: return ( - [ - AtomicData.from_ase(atoms=atoms, **self.extra_fixed_fields) - for atoms in atoms_list - ], + [AtomicData.from_ase(atoms=atoms, **kwargs) for atoms in atoms_list], ) else: return ( [ - AtomicData.from_ase(atoms=atoms_list[i], **self.extra_fixed_fields) + AtomicData.from_ase(atoms=atoms_list[i], **kwargs) for i in self.include_frames ], ) diff --git a/nequip/data/transforms.py b/nequip/data/transforms.py new file mode 100644 index 00000000..7ac9d724 --- /dev/null +++ b/nequip/data/transforms.py @@ -0,0 +1,103 @@ +from typing import Dict, Optional, Union, List +import warnings + +import torch + +import ase.data + +from nequip.data import AtomicData, AtomicDataDict + + +class TypeMapper: + """Based on a configuration, map atomic numbers to types.""" + + num_types: int + chemical_symbol_to_type: Optional[Dict[str, int]] + type_names: List[str] + _min_Z: int + + def __init__( + self, + type_names: Optional[List[str]] = None, + chemical_symbol_to_type: Optional[Dict[str, int]] = None, + ): + # Build from chem->type mapping, if provided + self.chemical_symbol_to_type = chemical_symbol_to_type + if self.chemical_symbol_to_type is not None: + # Validate + for sym, type in self.chemical_symbol_to_type.items(): + assert sym in ase.data.atomic_numbers, f"Invalid chemical symbol {sym}" + assert 0 <= type, f"Invalid type number {type}" + assert set(self.chemical_symbol_to_type.values()) == set( + range(len(self.chemical_symbol_to_type)) + ) + if type_names is None: + # Make type_names + type_names = [None] * len(self.chemical_symbol_to_type) + for sym, type in self.chemical_symbol_to_type.items(): + type_names[type] = sym + else: + # Make sure they agree on types + # We already checked that chem->type is contiguous, + # so enough to check length since type_names is a list + assert len(type_names) == len(self.chemical_symbol_to_type) + # Make mapper array + valid_atomic_numbers = [ + ase.data.atomic_numbers[sym] for sym in self.chemical_symbol_to_type + ] + self._min_Z = min(valid_atomic_numbers) + self._max_Z = max(valid_atomic_numbers) + Z_to_index = torch.full( + size=(1 + self._max_Z - self._min_Z,), fill_value=-1, dtype=torch.long + ) + for sym, type in self.chemical_symbol_to_type.items(): + Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type + self._Z_to_index = Z_to_index + self._valid_set = set(valid_atomic_numbers) + # check + if type_names is None: + raise ValueError( + "Neither chemical_symbol_to_type nor type_names was provided; one or the other is required" + ) + # validate type names + assert all( + n.isalnum() for n in type_names + ), "Type names must contain only alphanumeric characters" + # Set to however many maps specified -- we already checked contiguous + self.num_types = len(type_names) + # Check type_names + self.type_names = type_names + + def __call__( + self, data: Union[AtomicDataDict.Type, AtomicData], types_required: bool = True + ) -> Union[AtomicDataDict.Type, AtomicData]: + if AtomicDataDict.ATOM_TYPE_KEY in data: + if AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + warnings.warn( + "Data contained both ATOM_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY" + ) + elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: + assert ( + self.chemical_symbol_to_type is not None + ), "Atomic numbers provided but there is no chemical_symbol_to_type mapping!" + atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + del data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + + data[AtomicDataDict.ATOM_TYPE_KEY] = self.transform(atomic_numbers) + else: + if types_required: + raise KeyError( + "Data doesn't contain any atom type information (ATOM_TYPE_KEY or ATOMIC_NUMBERS_KEY)" + ) + return data + + def transform(self, atomic_numbers): + """core function to transform an array to specie index list""" + + if atomic_numbers.min() < self._min_Z or atomic_numbers.max() > self._max_Z: + bad_set = set(torch.unique(atomic_numbers).cpu().tolist()) - self._valid_set + raise ValueError( + f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!" + ) + + return self._Z_to_index[atomic_numbers - self._min_Z] diff --git a/nequip/datasets/__init__.py b/nequip/datasets/__init__.py deleted file mode 100644 index 45f2b476..00000000 --- a/nequip/datasets/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .aspirin import AspirinDataset -from .water import ChengWaterDataset diff --git a/nequip/datasets/aspirin.py b/nequip/datasets/aspirin.py deleted file mode 100644 index 01557e0d..00000000 --- a/nequip/datasets/aspirin.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - -from os.path import dirname, basename, abspath - -from nequip.data import AtomicDataDict, AtomicInMemoryDataset - - -class AspirinDataset(AtomicInMemoryDataset): - """Aspirin DFT/CCSD(T) data""" - - URL = "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip" - FILE_NAME = "benchmark_data/aspirin_ccsd-train.npz" - - @property - def raw_file_names(self): - return [basename(self.file_name)] - - @property - def raw_dir(self): - return dirname(abspath(self.file_name)) - - def get_data(self): - data = np.load(self.raw_dir + "/" + self.raw_file_names[0]) - arrays = { - AtomicDataDict.POSITIONS_KEY: data["R"], - AtomicDataDict.FORCE_KEY: data["F"], - AtomicDataDict.TOTAL_ENERGY_KEY: data["E"].reshape([-1, 1]), - } - fixed_fields = { - AtomicDataDict.ATOMIC_NUMBERS_KEY: np.asarray(data["z"], dtype=int), - AtomicDataDict.PBC_KEY: np.array([False, False, False]), - } - return arrays, fixed_fields diff --git a/nequip/datasets/water.py b/nequip/datasets/water.py deleted file mode 100644 index 9ed2d8fc..00000000 --- a/nequip/datasets/water.py +++ /dev/null @@ -1,20 +0,0 @@ -import numpy as np - -from ase import units -from ase.io import read - -from nequip.data import ASEDataset, AtomicDataDict - - -class ChengWaterDataset(ASEDataset): - """Cheng Water Dataset - - TODO: energies - """ - - URL = None - FILE_NAME = "benchmark_data/dataset_1593.xyz" - FORCE_FIXED_KEYS = [AtomicDataDict.PBC_KEY] - - def download(self): - pass diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py new file mode 100644 index 00000000..b849efed --- /dev/null +++ b/nequip/model/__init__.py @@ -0,0 +1,15 @@ +from ._eng import EnergyModel +from ._grads import ForceOutput +from ._scaling import RescaleEnergyEtc, PerSpeciesRescale +from ._weight_init import uniform_initialize_FCs + +from ._build import model_from_config + +__all__ = [ + "EnergyModel", + "ForceOutput", + "RescaleEnergyEtc", + "PerSpeciesRescale", + "uniform_initialize_FCs", + "model_from_config", +] diff --git a/nequip/model/_build.py b/nequip/model/_build.py new file mode 100644 index 00000000..4f1ae7dd --- /dev/null +++ b/nequip/model/_build.py @@ -0,0 +1,100 @@ +from typing import Optional, Union, Callable +import inspect +import yaml + +from nequip.data import AtomicDataset +from nequip.nn import GraphModuleMixin + + +def _load_callable(obj: Union[str, Callable], prefix: Optional[str] = None) -> Callable: + """Load a callable from a name, or pass through a callable.""" + if callable(obj): + pass + elif isinstance(obj, str): + if "." not in obj: + # It's an unqualified name + if prefix is not None: + obj = prefix + "." + obj + else: + # You can't have an unqualified name without a prefix + raise ValueError(f"Cannot load unqualified name {obj}.") + obj = yaml.load(f"!!python/name:{obj}", Loader=yaml.Loader) + else: + raise TypeError + assert callable(obj), f"{obj} isn't callable" + return obj + + +def model_from_config( + config, initialize: bool = False, dataset: Optional[AtomicDataset] = None +) -> GraphModuleMixin: + """Build a model based on `config`. + + Model builders (`model_builders`) can have arguments: + - ``config``: the config. Always present. + - ``model``: the model produced by the previous builder. Cannot be requested by the first builder, must be requested by subsequent ones. + - ``initialize``: whether to initialize the model + - ``dataset``: if ``initialize`` is True, the dataset + + Args: + config + initialize (bool): if True (default False), ``model_initializers`` will also be run. + dataset: dataset for initializers if ``initialize`` is True. + + Returns: + The build model. + """ + # Pre-process config + if initialize and dataset is not None: + if "num_types" in config: + assert ( + config["num_types"] == dataset.type_mapper.num_types + ), "inconsistant config & dataset" + if "type_names" in config: + assert ( + config["type_names"] == dataset.type_mapper.type_names + ), "inconsistant config & dataset" + config["num_types"] = dataset.type_mapper.num_types + config["type_names"] = dataset.type_mapper.type_names + + # Build + builders = [ + _load_callable(b, prefix="nequip.model") + for b in config.get("model_builders", []) + ] + + model = None + + for builder_i, builder in enumerate(builders): + pnames = inspect.signature(builder).parameters + params = {} + if "initialize" in pnames: + params["initialize"] = initialize + if "config" in pnames: + params["config"] = config + if "dataset" in pnames: + if "initialize" not in pnames: + raise ValueError("Cannot request dataset without requesting initialize") + if initialize and dataset is None: + raise RuntimeError( + f"Builder {builder.__name__} asked for the dataset, initialize is true, but no dataset was provided to `model_from_config`." + ) + params["dataset"] = dataset + if "model" in pnames: + if builder_i == 0: + raise RuntimeError( + f"Builder {builder.__name__} asked for the model as an input, but it's the first builder so there is no model to provide" + ) + params["model"] = model + else: + if builder_i > 0: + raise RuntimeError( + f"All model_builders but the first one must take the model as an argument; {builder.__name__} doesn't" + ) + model = builder(**params) + if not isinstance(model, GraphModuleMixin): + raise TypeError( + f"Builder {builder.__name__} didn't return a GraphModuleMixin, got {type(model)} instead" + ) + + return model diff --git a/nequip/models/_eng.py b/nequip/model/_eng.py similarity index 62% rename from nequip/models/_eng.py rename to nequip/model/_eng.py index ec17a7b6..314bef32 100644 --- a/nequip/models/_eng.py +++ b/nequip/model/_eng.py @@ -2,12 +2,9 @@ from nequip.data import AtomicDataDict from nequip.nn import ( - GraphModuleMixin, SequentialGraphNetwork, AtomwiseLinear, AtomwiseReduce, - ForceOutput, - PerSpeciesScaleShift, ConvNetLayer, ) from nequip.nn.embedding import ( @@ -17,15 +14,14 @@ ) -def EnergyModel(**shared_params) -> SequentialGraphNetwork: +def EnergyModel(config) -> SequentialGraphNetwork: """Base default energy model archetecture. For minimal and full configuration option listings, see ``minimal.yaml`` and ``example.yaml``. """ logging.debug("Start building the network model") - num_layers = shared_params.pop("num_layers", 3) - add_per_species_shift = shared_params.pop("PerSpeciesScaleShift_enable", False) + num_layers = config.get("num_layers", 3) layers = { # -- Encode -- @@ -54,15 +50,6 @@ def EnergyModel(**shared_params) -> SequentialGraphNetwork: } ) - if add_per_species_shift: - layers["per_species_scale_shift"] = ( - PerSpeciesScaleShift, - dict( - field=AtomicDataDict.PER_ATOM_ENERGY_KEY, - out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY, - ), - ) - layers["total_energy_sum"] = ( AtomwiseReduce, dict( @@ -73,17 +60,6 @@ def EnergyModel(**shared_params) -> SequentialGraphNetwork: ) return SequentialGraphNetwork.from_parameters( - shared_params=shared_params, + shared_params=config, layers=layers, ) - - -def ForceModel(**shared_params) -> GraphModuleMixin: - """Base default energy and force model archetecture. - - For minimal and full configuration option listings, see ``minimal.yaml`` and ``example.yaml``. - - A convinience method, equivalent to constructing ``EnergyModel`` and passing it to ``nequip.nn.ForceOutput``. - """ - energy_model = EnergyModel(**shared_params) - return ForceOutput(energy_model=energy_model) diff --git a/nequip/model/_grads.py b/nequip/model/_grads.py new file mode 100644 index 00000000..dcf85f9b --- /dev/null +++ b/nequip/model/_grads.py @@ -0,0 +1,22 @@ +from nequip.nn import GraphModuleMixin, GradientOutput +from nequip.data import AtomicDataDict + + +def ForceOutput(model: GraphModuleMixin) -> GradientOutput: + r"""Add forces to a model that predicts energy. + + Args: + energy_model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output. + + Returns: + A ``GradientOutput`` wrapping ``energy_model``. + """ + if AtomicDataDict.FORCE_KEY in model.irreps_out: + raise ValueError("This model already has force outputs.") + return GradientOutput( + func=model, + of=AtomicDataDict.TOTAL_ENERGY_KEY, + wrt=AtomicDataDict.POSITIONS_KEY, + out_field=AtomicDataDict.FORCE_KEY, + sign=-1, # force is the negative gradient + ) diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py new file mode 100644 index 00000000..30eee1ed --- /dev/null +++ b/nequip/model/_scaling.py @@ -0,0 +1,292 @@ +import logging +from typing import List, Optional + +import torch + +from nequip.nn import RescaleOutput, GraphModuleMixin, PerSpeciesScaleShift +from nequip.data import AtomicDataDict, AtomicDataset + + +RESCALE_THRESHOLD = 1e-6 + + +def RescaleEnergyEtc( + model: GraphModuleMixin, + config, + dataset: AtomicDataset, + initialize: bool, +): + """Add global rescaling for energy(-based quantities). + + If ``initialize`` is false, doesn't compute statistics. + """ + + module_prefix = "global_rescale" + + global_scale = config.get( + f"{module_prefix}_scale", + f"dataset_{AtomicDataDict.FORCE_KEY}_rms" + if AtomicDataDict.FORCE_KEY in model.irreps_out + else f"dataset_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", + ) + global_shift = config.get(f"{module_prefix}_shift", None) + + if global_shift is not None: + logging.warning( + f"!!!! Careful global_shift is set to {global_shift}." + f"The energy model will no longer be extensive" + ) + + # = Get statistics of training dataset = + if initialize: + str_names = [] + for value in [global_scale, global_shift]: + if isinstance(value, str): + str_names += [value] + elif ( + value is None + or isinstance(value, float) + or isinstance(value, torch.Tensor) + ): + # valid values + pass + else: + raise ValueError(f"Invalid global scale `{value}`") + + # = Compute shifts and scales = + computed_stats = _compute_stats( + str_names=str_names, + dataset=dataset, + stride=config.dataset_statistics_stride, + ) + + if isinstance(global_scale, str): + s = global_scale + global_scale = computed_stats[str_names.index(global_scale)] + logging.debug(f"Replace string {s} to {global_scale}") + if isinstance(global_shift, str): + s = global_shift + global_shift = computed_stats[str_names.index(global_shift)] + logging.debug(f"Replace string {s} to {global_shift}") + + if isinstance(global_scale, float) and global_scale < RESCALE_THRESHOLD: + raise ValueError( + f"Global energy scaling was very low: {global_scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None." + ) + + logging.debug( + f"Initially outputs are globally scaled by: {global_scale}, total_energy are globally shifted by {global_shift}." + ) + + else: + # Put dummy values + if global_shift is not None: + global_shift = 0.0 # it has some kind of value + if global_scale is not None: + global_scale = 1.0 # same, + + # == Build the model == + return RescaleOutput( + model=model, + scale_keys=[ + k + for k in ( + AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.PER_ATOM_ENERGY_KEY, + AtomicDataDict.FORCE_KEY, + ) + if k in model.irreps_out + ], + scale_by=global_scale, + shift_keys=[ + k for k in (AtomicDataDict.TOTAL_ENERGY_KEY,) if k in model.irreps_out + ], + shift_by=global_shift, + shift_trainable=config.get(f"{module_prefix}_shift_trainable", False), + scale_trainable=config.get(f"{module_prefix}_scale_trainable", False), + ) + + +def PerSpeciesRescale( + model: GraphModuleMixin, + config, + dataset: AtomicDataset, + initialize: bool, +): + """Add global rescaling for energy(-based quantities). + + If ``initialize`` is false, doesn't compute statistics. + """ + module_prefix = "per_species_rescale" + + # = Determine energy rescale type = + scales = config.get( + module_prefix + "_scales", + f"dataset_{AtomicDataDict.FORCE_KEY}_rms" + # if `train_on_keys` isn't provided, assume conservatively + # that we aren't "training" on anything (i.e. take the + # most general defaults) + if AtomicDataDict.FORCE_KEY in config.get("train_on_keys", []) + else f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_std", + ) + shifts = config.get( + module_prefix + "_shifts", + f"dataset_per_atom_{AtomicDataDict.TOTAL_ENERGY_KEY}_mean", + ) + + # = Determine what statistics need to be compute =\ + arguments_in_dataset_units = None + if initialize: + str_names = [] + for value in [scales, shifts]: + if isinstance(value, str): + str_names += [value] + elif ( + value is None + or isinstance(value, float) + or isinstance(value, list) + or isinstance(value, torch.Tensor) + ): + # valid values + pass + else: + raise ValueError(f"Invalid value `{value}` of type {type(value)}") + + if len(str_names) == 2: + # Both computed from dataset + arguments_in_dataset_units = True + elif len(str_names) == 1: + assert config[ + module_prefix + "arguments_in_dataset_units" + ], "Requested to set either the shifts or scales of the per_species_rescale using dataset values, but chose to provide the other in non-dataset units. Please give the explictly specified shifts/scales in dataset units and set per_species_rescale_arguments_in_dataset_units" + + # = Compute shifts and scales = + computed_stats = _compute_stats( + str_names=str_names, + dataset=dataset, + stride=config.dataset_statistics_stride, + kwargs=config.get(module_prefix + "_kwargs", {}), + ) + + if isinstance(scales, str): + s = scales + scales = computed_stats[str_names.index(scales)] + logging.debug(f"Replace string {s} to {scales}") + elif isinstance(scales, (list, float)): + scales = torch.as_tensor(scales) + + if isinstance(shifts, str): + s = shifts + shifts = computed_stats[str_names.index(shifts)] + logging.debug(f"Replace string {s} to {shifts}") + elif isinstance(shifts, (list, float)): + shifts = torch.as_tensor(shifts) + + if scales is not None and torch.min(scales) < RESCALE_THRESHOLD: + raise ValueError( + f"Per species energy scaling was very low: {scales}. Maybe try setting {module_prefix}_scales = 1." + ) + + else: + + # Put dummy values + # the real ones will be loaded from the state dict later + # note that the state dict includes buffers, + # so this is fine regardless of whether its trainable. + scales = 1.0 + shifts = 0.0 + # values correctly scaled according to where the come from + # will be brought from the state dict later, + # so what you set this to doesnt matter: + arguments_in_dataset_units = False + + # insert in per species shift + params = dict( + field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY, + shifts=shifts, + scales=scales, + ) + + params["arguments_in_dataset_units"] = arguments_in_dataset_units + model.insert_from_parameters( + before="total_energy_sum", + name=module_prefix, + shared_params=config, + builder=PerSpeciesScaleShift, + params=params, + ) + + logging.debug(f"Atomic outputs are scaled by: {scales}, shifted by {shifts}.") + + # == Build the model == + return model + + +def _compute_stats( + str_names: List[str], dataset, stride: int, kwargs: Optional[dict] = {} +): + """return the values of statistics over dataset + quantity name should be dataset_key_stat, where key can be any key + that exists in the dataset, stat can be mean, std + + Args: + + str_names: list of strings that define the quantity to compute + dataset: dataset object to run the stats over + stride: # frames to skip for every one frame to include + """ + + # parse the list of string to field, mode + # and record which quantity correspond to which computed_item + stat_modes = [] + stat_fields = [] + stat_strs = [] + ids = [] + tuple_ids = [] + tuple_id_map = {"mean": 0, "std": 1, "rms": 0} + input_kwargs = {} + for name in str_names: + + # remove dataset prefix + if name.startswith("dataset_"): + name = name[len("dataset_") :] + # identify per_species and per_atom modes + prefix = "" + if name.startswith("per_species_"): + name = name[len("per_species_") :] + prefix = "per_species_" + elif name.startswith("per_atom_"): + name = name[len("per_atom_") :] + prefix = "per_atom_" + + stat = name.split("_")[-1] + field = "_".join(name.split("_")[:-1]) + if stat in ["mean", "std"]: + stat_mode = prefix + "mean_std" + stat_str = field + prefix + "mean_std" + elif stat in ["rms"]: + stat_mode = prefix + "rms" + stat_str = field + prefix + "rms" + else: + raise ValueError(f"Cannot handle {stat} type quantity") + + if stat_str in stat_strs: + ids += [stat_strs.index(stat_str)] + else: + ids += [len(stat_strs)] + stat_strs += [stat_str] + stat_modes += [stat_mode] + stat_fields += [field] + if stat_mode.startswith("per_species_"): + if field in kwargs: + input_kwargs[field + stat_mode] = kwargs[field] + tuple_ids += [tuple_id_map[stat]] + + values = dataset.statistics( + fields=stat_fields, + modes=stat_modes, + stride=stride, + kwargs=input_kwargs, + ) + return [values[idx][tuple_ids[i]] for i, idx in enumerate(ids)] diff --git a/nequip/model/_weight_init.py b/nequip/model/_weight_init.py new file mode 100644 index 00000000..60783be3 --- /dev/null +++ b/nequip/model/_weight_init.py @@ -0,0 +1,52 @@ +import math + +import torch + +import e3nn.o3 +import e3nn.nn + +from nequip.nn import GraphModuleMixin +from nequip.utils import Config + + +# == Load old state == +def initialize_from_state(config: Config, model: GraphModuleMixin, initialize: bool): + """Initialize the model from the state dict file given by the config options `initial_model_state`.""" + if not initialize: + return model # do nothing + key = "initial_model_state" + if key not in config: + raise KeyError( + f"initialize_from_state requires the `{key}` option specifying the state to initialize from" + ) + state = torch.load(config[key]) + model.load_state_dict(state) + return model + + +# == Init functions == +def unit_uniform_init_(t: torch.Tensor): + """Uniform initialization with = 1""" + t.uniform_(-math.sqrt(3), math.sqrt(3)) + + +# TODO: does this normalization make any sense +# def unit_orthogonal_init_(t: torch.Tensor): +# """Orthogonal init with = 1""" +# assert t.ndim == 2 +# torch.nn.init.orthogonal_(t, gain=math.sqrt(max(t.shape))) + + +def uniform_initialize_FCs(model: GraphModuleMixin, initialize: bool): + """Initialize ``e3nn.nn.FullyConnectedNet``s with unit uniform initialization""" + if initialize: + + def _uniform_initialize_fcs(mod: torch.nn.Module): + if isinstance(mod, e3nn.nn.FullyConnectedNet): + for layer in mod: + # in FC, normalization is expected + unit_uniform_init_(layer.weight) + + with torch.no_grad(): + model.apply(_uniform_initialize_fcs) + return model diff --git a/nequip/models/__init__.py b/nequip/models/__init__.py deleted file mode 100644 index 32e2721a..00000000 --- a/nequip/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ._eng import EnergyModel, ForceModel # noqa: F401 diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index f86102b8..383f619b 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,8 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput, ForceOutput # noqa: F401 +from ._grad_output import GradientOutput # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 from ._util import SaveForOutput # noqa: F401 +from ._concat import Concat # noqa: F401 diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 57331b5a..88718fb8 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -1,13 +1,13 @@ +import logging from typing import Optional, List import torch import torch.nn.functional -from torch_scatter import scatter +from torch_runstats.scatter import scatter from e3nn.o3 import Linear from nequip.data import AtomicDataDict -from nequip.utils.batch_ops import bincount from ._graph_mixin import GraphModuleMixin @@ -61,7 +61,7 @@ def __init__( self, field: str, out_field: Optional[str] = None, reduce="sum", irreps_in={} ): super().__init__() - assert reduce in ("sum", "mean", "min", "max") + assert reduce in ("sum", "mean") self.reduce = reduce self.field = field self.out_field = f"{reduce}_{field}" if out_field is None else out_field @@ -81,18 +81,40 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: class PerSpeciesScaleShift(GraphModuleMixin, torch.nn.Module): + """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters. + + Args: + field: the per-atom field to scale/shift. + num_types: the number of types in the model. + shifts: the initial shifts to use, one per atom type. + scales: the initial scales to use, one per atom type. + arguments_in_dataset_units: if ``True``, says that the provided shifts/scales are in dataset + units (in which case they will be rescaled appropriately by any global rescaling later + applied to the model); if ``False``, the provided shifts/scales will be used without modification. + + For example, if identity shifts/scales of zeros and ones are provided, this should be ``False``. + But if scales/shifts computed from the training data are used, and are thus in dataset units, + this should be ``True``. + out_field: the output field; defaults to ``field``. + """ + field: str out_field: str - trainable: bool + scales_trainble: bool + shifts_trainable: bool + has_scales: bool + has_shifts: bool def __init__( self, field: str, - allowed_species: List[int], + num_types: int, + shifts: List[float], + scales: List[float], + arguments_in_dataset_units: bool, out_field: Optional[str] = None, - shifts: Optional[list] = None, - scales: Optional[list] = None, - trainable: bool = False, + scales_trainable: bool = False, + shifts_trainable: bool = False, irreps_in={}, ): super().__init__() @@ -104,39 +126,57 @@ def __init__( irreps_out={self.out_field: irreps_in[self.field]}, ) - shifts = ( - torch.zeros(len(allowed_species)) - if shifts is None - else torch.as_tensor(shifts, dtype=torch.get_default_dtype()) - ) - assert shifts.shape == ( - len(allowed_species), - ), f"Invalid shape of shifts {shifts}" - scales = ( - torch.ones(len(allowed_species)) - if scales is None - else torch.as_tensor(scales, dtype=torch.get_default_dtype()) - ) - assert scales.shape == ( - len(allowed_species), - ), f"Invalid shape of scales {scales}" - - self.trainable = trainable - if trainable: - self.shifts = torch.nn.Parameter(shifts) - self.scales = torch.nn.Parameter(scales) - else: - self.register_buffer("shifts", shifts) - self.register_buffer("scales", scales) + self.has_shifts = shifts is not None + if shifts is not None: + shifts = torch.as_tensor(shifts, dtype=torch.get_default_dtype()) + if len(shifts.reshape([-1])) == 1: + shifts = torch.ones(num_types) * shifts + assert shifts.shape == (num_types,), f"Invalid shape of shifts {shifts}" + self.shifts_trainable = shifts_trainable + if shifts_trainable: + self.shifts = torch.nn.Parameter(shifts) + else: + self.register_buffer("shifts", shifts) + + self.has_scales = scales is not None + if scales is not None: + scales = torch.as_tensor(scales, dtype=torch.get_default_dtype()) + if len(scales.reshape([-1])) == 1: + scales = torch.ones(num_types) * scales + assert scales.shape == (num_types,), f"Invalid shape of scales {scales}" + self.scales_trainable = scales_trainable + if scales_trainable: + self.scales = torch.nn.Parameter(scales) + else: + self.register_buffer("scales", scales) + + self.arguments_in_dataset_units = arguments_in_dataset_units def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: - species_idx = data[AtomicDataDict.SPECIES_INDEX_KEY] + + if not (self.has_scales or self.has_shifts): + return data + + species_idx = data[AtomicDataDict.ATOM_TYPE_KEY] in_field = data[self.field] assert len(in_field) == len( species_idx ), "in_field doesnt seem to have correct per-atom shape" - data[self.out_field] = ( - self.shifts[species_idx].view(-1, 1) - + self.scales[species_idx].view(-1, 1) * in_field - ) + if self.has_scales: + in_field = self.scales[species_idx].view(-1, 1) * in_field + if self.has_shifts: + data[self.out_field] = self.shifts[species_idx].view(-1, 1) + in_field return data + + def update_for_rescale(self, rescale_module): + if self.arguments_in_dataset_units and rescale_module.has_scale: + logging.debug( + f"PerSpeciesScaleShift's arguments were in dataset units; rescaling:\n" + f"Original scales {self.scales} shifts: {self.shifts}" + ) + with torch.no_grad(): + if self.has_scales: + self.scales.div_(rescale_module.scale_by) + if self.has_shifts: + self.shifts.div_(rescale_module.scale_by) + logging.debug(f"New scales {self.scales} shifts: {self.shifts}") diff --git a/nequip/nn/_concat.py b/nequip/nn/_concat.py new file mode 100644 index 00000000..f3b59a39 --- /dev/null +++ b/nequip/nn/_concat.py @@ -0,0 +1,25 @@ +from typing import List + +import torch + +from e3nn import o3 + +from nequip.data import AtomicDataDict +from nequip.nn import GraphModuleMixin + + +class Concat(GraphModuleMixin, torch.nn.Module): + """Concatenate multiple fields into one.""" + + def __init__(self, in_fields: List[str], out_field: str, irreps_in={}): + super().__init__() + self.in_fields = list(in_fields) + self.out_field = out_field + self._init_irreps(irreps_in=irreps_in, required_irreps_in=self.in_fields) + self.irreps_out[self.out_field] = sum( + (self.irreps_in[k] for k in self.in_fields), o3.Irreps() + ) + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data[self.out_field] = torch.cat([data[k] for k in self.in_fields], dim=-1) + return data diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index 15cdfc2d..b5ee9efc 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -97,21 +97,3 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data[k].requires_grad_(req_grad) return data - - -def ForceOutput(energy_model: GraphModuleMixin) -> GradientOutput: - r"""Convinience constructor for ``GradientOutput`` with settings for forces. - - Args: - energy_model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output. - - Returns: - A ``GradientOutput`` wrapping ``energy_model``. - """ - return GradientOutput( - func=energy_model, - of=AtomicDataDict.TOTAL_ENERGY_KEY, - wrt=AtomicDataDict.POSITIONS_KEY, - out_field=AtomicDataDict.FORCE_KEY, - sign=-1, # force is the negative gradient - ) diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py index 9498201b..2f3ed396 100644 --- a/nequip/nn/_graph_mixin.py +++ b/nequip/nn/_graph_mixin.py @@ -78,6 +78,25 @@ def _init_irreps( new_out.update(irreps_out) self.irreps_out = new_out + def _add_independent_irreps(self, irreps: Dict[str, Any]): + """ + Insert some independent irreps that need to be exposed to the self.irreps_in and self.irreps_out. + The terms that have already appeared in the irreps_in will be removed. + + Args: + irreps (dict): maps names of all new fields + """ + + irreps = { + key: irrep for key, irrep in irreps.items() if key not in self.irreps_in + } + irreps_in = AtomicDataDict._fix_irreps_dict(irreps) + irreps_out = AtomicDataDict._fix_irreps_dict( + {key: irrep for key, irrep in irreps.items() if key not in self.irreps_out} + ) + self.irreps_in.update(irreps_in) + self.irreps_out.update(irreps_out) + def _make_tracing_inputs(self, n): # We impliment this to be able to trace graph modules out = [] @@ -89,7 +108,11 @@ def _make_tracing_inputs(self, n): out.append( { "forward": ( - {k: i.randn(batch, -1) for k, i in self.irreps_in.items()}, + { + k: i.randn(batch, -1) + for k, i in self.irreps_in.items() + if i is not None + }, ) } ) @@ -229,43 +252,91 @@ def append_from_parameters( self.append(name, instance) return - def insert(self, after: str, name: str, module: GraphModuleMixin) -> None: + def insert( + self, + name: str, + module: GraphModuleMixin, + after: Optional[str] = None, + before: Optional[str] = None, + ) -> None: """Insert a module after the module with name ``after``. Args: - after: the module to insert after name: the name of the module to insert module: the moldule to insert + after: the module to insert after + before: the module to insert before """ + + if (before is None) is (after is None): + raise ValueError("Only one of before or after argument needs to be defined") + elif before is None: + insert_location = after + else: + insert_location = before + # This checks names, etc. self.add_module(name, module) # Now insert in the right place by overwriting names = list(self._modules.keys()) modules = list(self._modules.values()) - idx = names.index(after) - names.insert(idx + 1, name) - modules.insert(idx + 1, module) + idx = names.index(insert_location) + if before is None: + idx += 1 + names.insert(idx, name) + modules.insert(idx, module) + self._modules = OrderedDict(zip(names, modules)) + + module_list = list(self._modules.values()) + + # sanity check the compatibility + if idx > 0: + assert AtomicDataDict._irreps_compatible( + module_list[idx - 1].irreps_out, module.irreps_in + ) + if len(module_list) > idx: + assert AtomicDataDict._irreps_compatible( + module_list[idx + 1].irreps_in, module.irreps_out + ) + + # insert the new irreps_out to the later modules + for module_id, next_module in enumerate(module_list[idx + 1 :]): + next_module._add_independent_irreps(module.irreps_out) + + # update the final wrapper irreps_out + self.irreps_out = dict(module_list[-1].irreps_out) + return def insert_from_parameters( self, - after: str, shared_params: Mapping, name: str, builder: Callable, params: Dict[str, Any] = {}, + after: Optional[str] = None, + before: Optional[str] = None, ) -> None: r"""Build a module from parameters and insert it after ``after``. Args: - after: the name of the module to insert after shared_params (dict-like): shared parameters from which to pull when instantiating the module name (str): the name for the module builder (callable): a class or function to build a module params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params`` + after: the name of the module to insert after + before: the name of the module to insert before """ - idx = list(self._modules.keys()).index(after) + if (before is None) is (after is None): + raise ValueError("Only one of before or after argument needs to be defined") + elif before is None: + insert_location = after + else: + insert_location = before + idx = list(self._modules.keys()).index(insert_location) - 1 + if before is None: + idx += 1 instance, _ = instantiate( builder=builder, prefix=name, @@ -273,7 +344,7 @@ def insert_from_parameters( optional_args=params, all_args=shared_params, ) - self.insert(after, name, instance) + self.insert(after=after, before=before, name=name, module=instance) return # Copied from https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index 8370cf85..575144f2 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -3,7 +3,7 @@ import torch -from torch_scatter import scatter +from torch_runstats.scatter import scatter from e3nn import o3 from e3nn.nn import FullyConnectedNet @@ -25,7 +25,7 @@ def __init__( invariant_layers=1, invariant_neurons=8, avg_num_neighbors=None, - use_sc=False, + use_sc=True, nonlinearity_scalars: Dict[int, Callable] = {"e": "ssp"}, ) -> None: """ diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index a946eb80..afd232f5 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -29,11 +29,11 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): scale_keys: List[str] shift_keys: List[str] - trainable_global_rescale_scale: bool - trainable_global_rescale_shift: bool + scale_trainble: bool + rescale_trainable: bool - _has_scale: bool - _has_shift: bool + has_scale: bool + has_shift: bool def __init__( self, @@ -42,8 +42,8 @@ def __init__( shift_keys: Union[Sequence[str], str] = [], scale_by=None, shift_by=None, - trainable_global_rescale_shift: bool = False, - trainable_global_rescale_scale: bool = False, + shift_trainable: bool = False, + scale_trainable: bool = False, irreps_in: dict = {}, ): super().__init__() @@ -75,33 +75,33 @@ def __init__( self.scale_keys = list(scale_keys) self.shift_keys = list(shift_keys) - self._has_scale = scale_by is not None - self.trainable_global_rescale_scale = trainable_global_rescale_scale - if self._has_scale: + self.has_scale = scale_by is not None + self.scale_trainble = scale_trainable + if self.has_scale: scale_by = torch.as_tensor(scale_by) - if self.trainable_global_rescale_scale: + if self.scale_trainble: self.scale_by = torch.nn.Parameter(scale_by) else: self.register_buffer("scale_by", scale_by) - elif self.trainable_global_rescale_scale: + elif self.scale_trainble: raise ValueError( - "Asked for a trainable_global_rescale_scale, but this RescaleOutput has no scaling (`scale_by = None`)" + "Asked for a scale_trainable, but this RescaleOutput has no scaling (`scale_by = None`)" ) else: # register dummy for TorchScript self.register_buffer("scale_by", torch.Tensor()) - self._has_shift = shift_by is not None - self.trainable_global_rescale_shift = trainable_global_rescale_shift - if self._has_shift: + self.has_shift = shift_by is not None + self.rescale_trainable = shift_trainable + if self.has_shift: shift_by = torch.as_tensor(shift_by) - if self.trainable_global_rescale_shift: + if self.rescale_trainable: self.shift_by = torch.nn.Parameter(shift_by) else: self.register_buffer("shift_by", shift_by) - elif self.trainable_global_rescale_shift: + elif self.rescale_trainable: raise ValueError( - "Asked for a trainable_global_rescale_shift, but this RescaleOutput has no shift (`shift_by = None`)" + "Asked for a shift_trainable, but this RescaleOutput has no shift (`shift_by = None`)" ) else: # register dummy for TorchScript @@ -126,10 +126,10 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: return data else: # Scale then shift - if self._has_scale: + if self.has_scale: for field in self.scale_keys: data[field] = data[field] * self.scale_by - if self._has_shift: + if self.has_shift: for field in self.shift_keys: data[field] = data[field] + self.shift_by return data @@ -154,11 +154,11 @@ def scale( if self.training and not force_process: return data else: - if self._has_scale: + if self.has_scale: for field in self.scale_keys: if field in data: data[field] = data[field] * self.scale_by - if self._has_shift: + if self.has_shift: for field in self.shift_keys: if field in data: data[field] = data[field] + self.shift_by @@ -183,11 +183,11 @@ def unscale( data = data.copy() if self.training or force_process: # To invert, -shift then divide by scale - if self._has_shift: + if self.has_shift: for field in self.shift_keys: if field in data: data[field] = data[field] - self.shift_by - if self._has_scale: + if self.has_scale: for field in self.scale_keys: if field in data: data[field] = data[field] / self.scale_by diff --git a/nequip/nn/cutoffs.py b/nequip/nn/cutoffs.py index cd1c44a8..99177323 100644 --- a/nequip/nn/cutoffs.py +++ b/nequip/nn/cutoffs.py @@ -1,9 +1,24 @@ import torch -from torch import nn -class PolynomialCutoff(nn.Module): - def __init__(self, r_max, p=6): +@torch.jit.script +def _poly_cutoff(x: torch.Tensor, factor: float) -> torch.Tensor: + p: float = 6.0 + x = x * factor + + out = 1.0 + out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p)) + out = out + (p * (p + 2.0) * torch.pow(x, p + 1.0)) + out = out - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0)) + + return out * (x < 1.0) + + +class PolynomialCutoff(torch.nn.Module): + _factor: float + p: float + + def __init__(self, r_max: float, p: float = 6): r"""Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123 @@ -15,10 +30,13 @@ def __init__(self, r_max, p=6): p : int Power used in envelope function """ - super(PolynomialCutoff, self).__init__() - - self.register_buffer("p", torch.Tensor([p])) - self.register_buffer("r_max", torch.Tensor([r_max])) + super().__init__() + if p != 6: + raise NotImplementedError( + "p values other than 6 not currently supported for simplicity; if you need this please file an issue." + ) + self.p = p + self._factor = 1.0 / r_max def forward(self, x): """ @@ -26,13 +44,4 @@ def forward(self, x): x: torch.Tensor, input distance """ - envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) - * torch.pow(x / self.r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1.0) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2.0) - ) - - envelope *= (x < self.r_max).float() - return envelope + return _poly_cutoff(x, self._factor) diff --git a/nequip/nn/embedding/__init__.py b/nequip/nn/embedding/__init__.py index 103cc9e7..dfc9b710 100644 --- a/nequip/nn/embedding/__init__.py +++ b/nequip/nn/embedding/__init__.py @@ -1,2 +1,4 @@ from ._one_hot import OneHotAtomEncoding from ._edge import SphericalHarmonicEdgeAttrs, RadialBasisEdgeEncoding + +__all__ = [OneHotAtomEncoding, SphericalHarmonicEdgeAttrs, RadialBasisEdgeEncoding] diff --git a/nequip/nn/embedding/_one_hot.py b/nequip/nn/embedding/_one_hot.py index 50ad9d2e..373243d2 100644 --- a/nequip/nn/embedding/_one_hot.py +++ b/nequip/nn/embedding/_one_hot.py @@ -10,7 +10,7 @@ @compile_mode("script") class OneHotAtomEncoding(GraphModuleMixin, torch.nn.Module): - num_species: int + num_types: int set_features: bool # TODO: use torch.unique? @@ -18,60 +18,25 @@ class OneHotAtomEncoding(GraphModuleMixin, torch.nn.Module): # Docstrings def __init__( self, - allowed_species=None, - num_species: int = None, + num_types: int, set_features: bool = True, irreps_in=None, ): super().__init__() - if allowed_species is not None and num_species is not None: - raise ValueError("allowed_species and num_species cannot both be provided.") - - if allowed_species is not None: - num_species = len(allowed_species) - allowed_species = torch.as_tensor(allowed_species) - self.register_buffer("_min_Z", allowed_species.min()) - self.register_buffer("_max_Z", allowed_species.max()) - Z_to_index = torch.full( - (1 + self._max_Z - self._min_Z,), -1, dtype=torch.long - ) - Z_to_index[allowed_species - self._min_Z] = torch.arange(num_species) - self.register_buffer("_Z_to_index", Z_to_index) - self.num_species = num_species + self.num_types = num_types self.set_features = set_features - # Output irreps are num_species even (invariant) scalars - irreps_out = { - AtomicDataDict.NODE_ATTRS_KEY: Irreps([(self.num_species, (0, 1))]) - } + # Output irreps are num_types even (invariant) scalars + irreps_out = {AtomicDataDict.NODE_ATTRS_KEY: Irreps([(self.num_types, (0, 1))])} if self.set_features: irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = irreps_out[ AtomicDataDict.NODE_ATTRS_KEY ] self._init_irreps(irreps_in=irreps_in, irreps_out=irreps_out) - @torch.jit.export - def index_for_atomic_numbers(self, atomic_nums: torch.Tensor): - if atomic_nums.min() < self._min_Z or atomic_nums.max() > self._max_Z: - raise RuntimeError("Invalid atomic numbers for this OneHotEncoding") - - out = self._Z_to_index[atomic_nums - self._min_Z] - assert out.min() >= 0, "Invalid atomic numbers for this OneHotEncoding" - return out - def forward(self, data: AtomicDataDict.Type): - if AtomicDataDict.SPECIES_INDEX_KEY in data: - type_numbers = data[AtomicDataDict.SPECIES_INDEX_KEY] - elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: - type_numbers = self.index_for_atomic_numbers( - data[AtomicDataDict.ATOMIC_NUMBERS_KEY] - ) - data[AtomicDataDict.SPECIES_INDEX_KEY] = type_numbers - else: - raise ValueError( - "Nothing in this `data` to encode, need either species index or atomic numbers" - ) + type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY] one_hot = torch.nn.functional.one_hot( - type_numbers, num_classes=self.num_species + type_numbers, num_classes=self.num_types ).to(device=type_numbers.device, dtype=data[AtomicDataDict.POSITIONS_KEY].dtype) data[AtomicDataDict.NODE_ATTRS_KEY] = one_hot if self.set_features: diff --git a/nequip/nn/nonlinearities.py b/nequip/nn/nonlinearities.py index a8469f8f..7ddfba00 100644 --- a/nequip/nn/nonlinearities.py +++ b/nequip/nn/nonlinearities.py @@ -3,29 +3,6 @@ import math -class _ShiftedSoftPlus(torch.nn.Module): - """ - Shifted softplus as defined in SchNet, NeurIPS 2017. - - :param beta: value for the a more general softplus, default = 1 - :param threshold: values above are linear function, default = 20 - """ - - _log2: float - - def __init__(self, beta=1, threshold=20): - super().__init__() - self.softplus = torch.nn.Softplus(beta=beta, threshold=threshold) - self._log2 = math.log(2.0) - - def forward(self, x): - """ - Evaluate shifted softplus - - :param x: torch.Tensor, input - :return: torch.Tensor, ssp(x) - """ - return self.softplus(x) - self._log2 - - -ShiftedSoftPlus = _ShiftedSoftPlus() +@torch.jit.script +def ShiftedSoftPlus(x): + return torch.nn.functional.softplus(x) - math.log(2.0) diff --git a/nequip/nn/radial_basis.py b/nequip/nn/radial_basis.py index 21d04e06..b525679c 100644 --- a/nequip/nn/radial_basis.py +++ b/nequip/nn/radial_basis.py @@ -1,9 +1,47 @@ +from typing import Optional import math import torch from torch import nn +from e3nn.math import soft_one_hot_linspace +from e3nn.util.jit import compile_mode + + +@compile_mode("trace") +class e3nn_basis(nn.Module): + r_max: float + r_min: float + e3nn_basis_name: str + num_basis: int + + def __init__( + self, + r_max: float, + r_min: Optional[float] = None, + e3nn_basis_name: str = "gaussian", + num_basis: int = 8, + ): + super().__init__() + self.r_max = r_max + self.r_min = r_min if r_min is not None else 0.0 + self.e3nn_basis_name = e3nn_basis_name + self.num_basis = num_basis + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return soft_one_hot_linspace( + x, + start=self.r_min, + end=self.r_max, + number=self.num_basis, + basis=self.e3nn_basis_name, + cutoff=True, + ) + + def _make_tracing_inputs(self, n: int): + return [{"forward": (torch.randn(5, 1),)} for _ in range(n)] + class BesselBasis(nn.Module): r_max: float @@ -40,7 +78,7 @@ def __init__(self, r_max, num_basis=8, trainable=True): else: self.register_buffer("bessel_weights", bessel_weights) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Evaluate Bessel Basis for input x. @@ -52,3 +90,29 @@ def forward(self, x): numerator = torch.sin(self.bessel_weights * x.unsqueeze(-1) / self.r_max) return self.prefactor * (numerator / x.unsqueeze(-1)) + + +# class GaussianBasis(nn.Module): +# r_max: float + +# def __init__(self, r_max, r_min=0.0, num_basis=8, trainable=True): +# super().__init__() + +# self.trainable = trainable +# self.num_basis = num_basis + +# self.r_max = float(r_max) +# self.r_min = float(r_min) + +# means = torch.linspace(self.r_min, self.r_max, self.num_basis) +# stds = torch.full(size=means.size, fill_value=means[1] - means[0]) +# if self.trainable: +# self.means = nn.Parameter(means) +# self.stds = nn.Parameter(stds) +# else: +# self.register_buffer("means", means) +# self.register_buffer("stds", stds) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# x = (x[..., None] - self.means) / self.stds +# x = x.square().mul(-0.5).exp() / self.stds # sqrt(2 * pi) diff --git a/nequip/scripts/benchmark.py b/nequip/scripts/benchmark.py new file mode 100644 index 00000000..0f1685d5 --- /dev/null +++ b/nequip/scripts/benchmark.py @@ -0,0 +1,180 @@ +import sys +import argparse +import textwrap +import tempfile +import contextlib +import itertools + +import torch +from torch.utils.benchmark import Timer, Measurement +from torch.utils.benchmark.utils.common import trim_sigfig, select_unit + +from e3nn.util.jit import script + +from nequip.utils import Config +from nequip.data import AtomicData, dataset_from_config +from nequip.model import model_from_config +from nequip.scripts.deploy import _compile_for_deploy +from nequip.scripts.train import _set_global_options, default_config + + +def main(args=None): + parser = argparse.ArgumentParser( + description=textwrap.dedent( + """Benchmark the approximate MD performance of a given model configuration / dataset pair.""" + ) + ) + parser.add_argument("config", help="configuration file") + parser.add_argument( + "--profile", + help="Profile instead of timing, creating and outputing a Chrome trace JSON to the given path.", + type=str, + default=None, + ) + parser.add_argument( + "--device", + help="Device to run the model on. If not provided, defaults to CUDA if available and CPU otherwise.", + type=str, + default=None, + ) + parser.add_argument( + "-n", + help="Number of trials.", + type=int, + default=30, + ) + parser.add_argument( + "--n-data", + help="Number of frames to use.", + type=int, + default=1, + ) + parser.add_argument( + "--timestep", + help="MD timestep for ns/day esimation, in fs. Defauts to 1fs.", + type=float, + default=1, + ) + + # TODO: option to profile + # TODO: option to show memory use + + # Parse the args + args = parser.parse_args(args=args) + + if args.device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(args.device) + + print(f"Using device: {device}") + + config = Config.from_file(args.config, defaults=default_config) + _set_global_options(config) + + # Load dataset to get something to benchmark on + print("Loading dataset... ") + # Currently, pytorch_geometric prints some status messages to stdout while loading the dataset + # TODO: fix may come soon: https://github.com/rusty1s/pytorch_geometric/pull/2950 + # Until it does, just redirect them. + with contextlib.redirect_stdout(sys.stderr): + dataset = dataset_from_config(config) + datas = [ + AtomicData.to_AtomicDataDict(dataset[i].to(device)) + for i in torch.randperm(len(dataset))[: args.n_data] + ] + n_atom: int = len(datas[0]["pos"]) + assert all(len(d["pos"]) == n_atom for d in datas) # TODO handle the general case + # TODO: show some stats about datas + + datas = itertools.cycle(datas) + + # Load model: + print("Loading model... ") + model = model_from_config(config, initialize=True, dataset=dataset) + print("Compile...") + # "Deploy" it + model.eval() + model = script(model) + + # OLD ---- OLD ---- OLD + # TODO!!: for now we just compile, but when + # https://github.com/pytorch/pytorch/issues/64957#issuecomment-918632252 + # is resolved, then should be deploying again + # print( + # "WARNING: this is currently not using deployed model, just scripted, because of PyTorch bugs" + # ) + # OLD ---- OLD ---- OLD + + model = _compile_for_deploy(model) # TODO make this an option + # save and reload to avoid bugs + with tempfile.NamedTemporaryFile() as f: + torch.jit.save(model, f.name) + model = torch.jit.load(f.name, map_location=device) + + # Make sure we're warm past compilation + warmup = config["_jit_bailout_depth"] + 4 # just to be safe... + + if args.profile is not None: + + def trace_handler(p): + p.export_chrome_trace(args.profile) + print(f"Wrote profiling trace to `{args.profile}`") + + print("Starting...") + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ] + + ([torch.profiler.ProfilerActivity.CUDA] if device.type == "cuda" else []), + schedule=torch.profiler.schedule( + wait=1, warmup=warmup, active=args.n, repeat=1 + ), + on_trace_ready=trace_handler, + ) as p: + for _ in range(1 + warmup + args.n): + model(next(datas).copy()) + p.step() + else: + print("Warmup...") + for _ in range(warmup): + model(next(datas).copy()) + + print("Starting...") + # just time + t = Timer( + stmt="model(next(datas).copy())", globals={"model": model, "datas": datas} + ) + perloop: Measurement = t.timeit(args.n) + + print(" -- Results --") + print( + f"PLEASE NOTE: these are speeds for the MODEL, evaluated on --n-data={args.n_data} configurations kept in memory." + ) + print( + " \\_ MD itself, memory copies, and other overhead will affect real-world performance." + ) + print() + trim_time = trim_sigfig(perloop.times[0], perloop.significant_figures) + time_unit, time_scale = select_unit(trim_time) + time_str = ("{:.%dg}" % perloop.significant_figures).format( + trim_time / time_scale + ) + print(f"The average call took {time_str}{time_unit}") + print( + "Assuming linear scaling — which is ALMOST NEVER true in practice, especially on GPU —" + ) + per_atom_time = trim_time / n_atom + time_unit_per, time_scale_per = select_unit(per_atom_time) + print( + f" \\_ this comes out to {per_atom_time/time_scale_per:g} {time_unit_per}/atom/call" + ) + ns_day = (86400.0 / trim_time) * args.timestep * 1e-6 + # day in s^ s/step^ ^ fs / step ^ ns / fs + print( + f"For this system, at a {args.timestep:.2f}fs timestep, this comes out to {ns_day:.2f} ns/day" + ) + + +if __name__ == "__main__": + main() diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 6c7a4e66..8bbea037 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -17,21 +17,47 @@ import torch +import ase.data + from e3nn.util.jit import script -import nequip -from nequip.nn import GraphModuleMixin +from nequip.train import Trainer CONFIG_KEY: Final[str] = "config" NEQUIP_VERSION_KEY: Final[str] = "nequip_version" +TORCH_VERSION_KEY: Final[str] = "torch_version" +E3NN_VERSION_KEY: Final[str] = "e3nn_version" R_MAX_KEY: Final[str] = "r_max" N_SPECIES_KEY: Final[str] = "n_species" +TYPE_NAMES_KEY: Final[str] = "type_names" +JIT_BAILOUT_KEY: Final[str] = "_jit_bailout_depth" +TF32_KEY: Final[str] = "allow_tf32" + +_ALL_METADATA_KEYS = [ + CONFIG_KEY, + NEQUIP_VERSION_KEY, + R_MAX_KEY, + N_SPECIES_KEY, + TYPE_NAMES_KEY, + JIT_BAILOUT_KEY, + TF32_KEY, +] + + +def _compile_for_deploy(model): + model.eval() -_ALL_METADATA_KEYS = [CONFIG_KEY, NEQUIP_VERSION_KEY, R_MAX_KEY, N_SPECIES_KEY] + if not isinstance(model, torch.jit.ScriptModule): + model = script(model) + + return model def load_deployed_model( - model_path: Union[pathlib.Path, str], device: Union[str, torch.device] = "cpu" + model_path: Union[pathlib.Path, str], + device: Union[str, torch.device] = "cpu", + freeze: bool = True, + set_global_options: Union[str, bool] = "warn", ) -> Tuple[torch.jit.ScriptModule, Dict[str, str]]: r"""Load a deployed model. @@ -43,6 +69,7 @@ def load_deployed_model( """ metadata = {k: "" for k in _ALL_METADATA_KEYS} try: + # TODO: use .to()? instead of map_location model = torch.jit.load(model_path, map_location=device, _extra_files=metadata) except RuntimeError as e: raise ValueError( @@ -53,19 +80,42 @@ def load_deployed_model( raise ValueError( f"{model_path} does not seem to be a deployed NequIP model file" ) - # Remove missing metadata - for k in metadata: - # TODO: some better semver based checking of versions here, or something - if metadata[k] == "": - warnings.warn( - f"Metadata key `{k}` wasn't present in the saved model; this may indicate compatability issues." - ) # Confirm its TorchScript assert isinstance(model, torch.jit.ScriptModule) # Make sure we're in eval mode model.eval() + # Freeze on load: + if freeze and hasattr(model, "training"): + # hasattr is how torch checks whether model is unfrozen + # only freeze if already unfrozen + model = torch.jit.freeze(model) # Everything we store right now is ASCII, so decode for printing metadata = {k: v.decode("ascii") for k, v in metadata.items()} + # Set up global settings: + assert set_global_options in (True, False, "warn") + if set_global_options: + # Set TF32 support + # See https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if torch.cuda.is_available() and metadata[TF32_KEY] != "": + allow_tf32 = bool(int(metadata[TF32_KEY])) + if torch.torch.backends.cuda.matmul.allow_tf32 is not allow_tf32: + # Update setting + if set_global_options == "warn": + warnings.warn( + "Loaded model had a different value for allow_tf32 than was currently set; changing the GLOBAL setting!" + ) + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + torch.backends.cudnn.allow_tf32 = allow_tf32 + + # JIT bailout + if metadata[JIT_BAILOUT_KEY] != "": + jit_bailout: int = int(metadata[JIT_BAILOUT_KEY]) + # no way to get current value, so assume we are overwriting it + if set_global_options == "warn": + warnings.warn( + "Loaded model had a different value for _jit_bailout_depth than was currently set; changing the GLOBAL setting!" + ) + torch._C._jit_set_bailout_depth(jit_bailout) return model, metadata @@ -73,7 +123,12 @@ def main(args=None): parser = argparse.ArgumentParser( description="Create and view information about deployed NequIP potentials." ) - subparsers = parser.add_subparsers(dest="command", required=True, title="commands") + # backward compat for 3.6 + if sys.version_info[1] > 6: + required = {"required": True} + else: + required = {} + subparsers = parser.add_subparsers(dest="command", title="commands", **required) info_parser = subparsers.add_parser( "info", help="Get information from a deployed model file" ) @@ -101,10 +156,11 @@ def main(args=None): logging.basicConfig(level=logging.INFO) if args.command == "info": - model, metadata = load_deployed_model(args.model_path) + model, metadata = load_deployed_model(args.model_path, set_global_options=False) del model config = metadata.pop(CONFIG_KEY) - logging.info(f"Loaded TorchScript model with metadata {metadata}") + metadata_str = "\n".join(" %s: %s" % e for e in metadata.items()) + logging.info(f"Loaded TorchScript model with metadata:\n{metadata_str}\n") logging.info("Model was built with config:") print(config) @@ -116,42 +172,42 @@ def main(args=None): f"{args.out_dir} is a directory, but a path to a file for the deployed model must be given" ) # -- load model -- - model_is_jit = False - model_path = args.train_dir / "best_model.pth" - try: - model = torch.jit.load(model_path, map_location=torch.device("cpu")) - model_is_jit = True - logging.info("Loaded TorchScript model") - except RuntimeError: - # ^ jit.load throws this when it can't find TorchScript files - model = torch.load(model_path, map_location=torch.device("cpu")) - if not isinstance(model, GraphModuleMixin): - raise TypeError( - "Model contained object that wasn't a NequIP model (nequip.nn.GraphModuleMixin)" - ) - logging.info("Loaded pickled model") - model = model.to(device=torch.device("cpu")) + model, _ = Trainer.load_model_from_training_session( + args.train_dir, model_name="best_model.pth", device="cpu" + ) # -- compile -- - if not model_is_jit: - model = script(model) - logging.info("Compiled model to TorchScript") - - model.eval() # just to be sure - - model = torch.jit.freeze(model) - logging.info("Froze TorchScript model") + model = _compile_for_deploy(model) + logging.info("Compiled & optimized model.") # load config - # TODO: walk module tree if config does not exist to find params? - config_str = (args.train_dir / "config_final.yaml").read_text() + config_str = (args.train_dir / "config.yaml").read_text() config = yaml.load(config_str, Loader=yaml.Loader) # Deploy - metadata: dict = {NEQUIP_VERSION_KEY: nequip.__version__} + metadata: dict = {} + for code in ["e3nn", "nequip", "torch"]: + metadata[code + "_version"] = config[code + "_version"] + metadata[R_MAX_KEY] = str(float(config["r_max"])) - metadata[N_SPECIES_KEY] = str(len(config["allowed_species"])) + if "allowed_species" in config: + # This is from before the atomic number updates + n_species = len(config["allowed_species"]) + type_names = { + type: ase.data.chemical_symbols[atomic_num] + for type, atomic_num in enumerate(config["allowed_species"]) + } + else: + # The new atomic number setup + n_species = str(config["num_types"]) + type_names = config["type_names"] + metadata[N_SPECIES_KEY] = str(n_species) + metadata[TYPE_NAMES_KEY] = " ".join(type_names) + + metadata[JIT_BAILOUT_KEY] = str(config["_jit_bailout_depth"]) + metadata[TF32_KEY] = str(int(config["allow_tf32"])) metadata[CONFIG_KEY] = config_str + metadata = {k: v.encode("ascii") for k, v in metadata.items()} torch.jit.save(model, args.out_file, _extra_files=metadata) else: diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 420bcc9e..67221af1 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -1,5 +1,6 @@ import sys import argparse +import logging import textwrap from pathlib import Path import contextlib @@ -9,23 +10,30 @@ import torch -from nequip.utils import Config, dataset_from_config -from nequip.data import AtomicData, Collater +from nequip.utils import Config +from nequip.data import AtomicData, Collater, dataset_from_config +from nequip.train import Trainer from nequip.scripts.deploy import load_deployed_model +from nequip.scripts.train import default_config, _set_global_options from nequip.utils import load_file, instantiate from nequip.train.loss import Loss from nequip.train.metrics import Metrics +from nequip.scripts.logger import set_up_script_logger -def main(args=None): +def main(args=None, running_as_script: bool = True): # in results dir, do: nequip-deploy build . deployed.pth parser = argparse.ArgumentParser( description=textwrap.dedent( """Compute the error of a model on a test set using various metrics. The model, metrics, dataset, etc. can specified individually, or a training session can be indicated with `--train-dir`. + In order of priority, the global settings (dtype, TensorFloat32, etc.) are taken from: + 1. The model config (for a training session) + 2. The dataset config (for a deployed model) + 3. The defaults - Prints only the final result in `name = num` format to stdout; all other information is printed to stderr. + Prints only the final result in `name = num` format to stdout; all other information is logging.debuged to stderr. WARNING: Please note that results of CUDA models are rarely exactly reproducible, and that even CPU models can be nondeterministic. """ @@ -45,13 +53,13 @@ def main(args=None): ) parser.add_argument( "--dataset-config", - help="A YAML config file specifying the dataset to load test data from. If omitted, `config_final.yaml` in `train_dir` will be used", + help="A YAML config file specifying the dataset to load test data from. If omitted, `config.yaml` in `train_dir` will be used", type=Path, default=None, ) parser.add_argument( "--metrics-config", - help="A YAML config file specifying the metrics to compute. If omitted, `config_final.yaml` in `train_dir` will be used. If the config does not specify `metrics_components`, the default is to print MAEs and RMSEs for all fields given in the loss function. If the literal string `None`, no metrics will be computed.", + help="A YAML config file specifying the metrics to compute. If omitted, `config.yaml` in `train_dir` will be used. If the config does not specify `metrics_components`, the default is to logging.debug MAEs and RMSEs for all fields given in the loss function. If the literal string `None`, no metrics will be computed.", type=str, default=None, ) @@ -63,7 +71,7 @@ def main(args=None): ) parser.add_argument( "--batch-size", - help="Batch size to use. Larger is usually faster on GPU.", + help="Batch size to use. Larger is usually faster on GPU. If you run out of memory, lower this.", type=int, default=50, ) @@ -79,8 +87,14 @@ def main(args=None): type=Path, default=None, ) + parser.add_argument( + "--log", + help="log file to store all the metrics and screen logging.debug", + type=Path, + default=None, + ) # Something has to be provided - # See https://stackoverflow.com/questions/22368458/how-to-make-argparse-print-usage-when-no-option-is-given-to-the-code + # See https://stackoverflow.com/questions/22368458/how-to-make-argparse-logging.debug-usage-when-no-option-is-given-to-the-code if len(sys.argv) == 1: parser.print_help() parser.exit() @@ -91,10 +105,10 @@ def main(args=None): dataset_is_from_training: bool = False if args.train_dir: if args.dataset_config is None: - args.dataset_config = args.train_dir / "config_final.yaml" + args.dataset_config = args.train_dir / "config.yaml" dataset_is_from_training = True if args.metrics_config is None: - args.metrics_config = args.train_dir / "config_final.yaml" + args.metrics_config = args.train_dir / "config.yaml" if args.model is None: args.model = args.train_dir / "best_model.pth" if args.test_indexes is None: @@ -129,56 +143,107 @@ def main(args=None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) - print(f"Using device: {device}", file=sys.stderr) + + if running_as_script: + set_up_script_logger(args.log) + logger = logging.getLogger("nequip-evaluate") + logger.setLevel(logging.INFO) + + logger.info(f"Using device: {device}") if device.type == "cuda": - print( + logger.info( "WARNING: please note that models running on CUDA are usually nondeterministc and that this manifests in the final test errors; for a _more_ deterministic result, please use `--device cpu`", - file=sys.stderr, ) # Load model: - print("Loading model... ", file=sys.stderr, end="") + logger.info("Loading model... ") + model_from_training: bool = False try: - model, _ = load_deployed_model(args.model, device=device) - print("loaded deployed model.", file=sys.stderr) + model, _ = load_deployed_model( + args.model, + device=device, + set_global_options=True, # don't warn that setting + ) + logger.info("loaded deployed model.") except ValueError: # its not a deployed model - model = torch.load(args.model, map_location=device) + model, _ = Trainer.load_model_from_training_session( + traindir=args.model.parent, model_name=args.model.name + ) + model_from_training = True model = model.to(device) - print("loaded pickled Python model.", file=sys.stderr) + logger.info("loaded model from training session") + model.eval() # Load a config file - print( - f"Loading {'original training ' if dataset_is_from_training else ''}dataset...", - file=sys.stderr, + logger.info( + f"Loading {'original ' if dataset_is_from_training else ''}dataset...", ) config = Config.from_file(str(args.dataset_config)) + # set global options + if model_from_training: + # Use the model config, regardless of dataset config + global_config = args.model.parent / "config.yaml" + global_config = Config.from_file(str(global_config), defaults=default_config) + _set_global_options(global_config) + del global_config + else: + # the global settings for a deployed model are set by + # set_global_options in the call to load_deployed_model + # above + pass + + dataset_is_validation: bool = False # Currently, pytorch_geometric prints some status messages to stdout while loading the dataset # TODO: fix may come soon: https://github.com/rusty1s/pytorch_geometric/pull/2950 # Until it does, just redirect them. with contextlib.redirect_stdout(sys.stderr): - dataset = dataset_from_config(config) + try: + # Try to get validation dataset + dataset = dataset_from_config(config, prefix="validation_dataset") + dataset_is_validation = True + except KeyError: + # Get shared train + validation dataset + dataset = dataset_from_config(config) + logger.info( + f"Loaded {'validation_' if dataset_is_validation else ''}dataset specified in {args.dataset_config.name}.", + ) c = Collater.for_dataset(dataset, exclude_keys=[]) # Determine the test set # this makes no sense if a dataset is given seperately - if train_idcs is not None and dataset_is_from_training: + if ( + args.test_indexes is None + and train_idcs is not None + and dataset_is_from_training + ): # we know the train and val, get the rest all_idcs = set(range(len(dataset))) # set operations - test_idcs = list(all_idcs - train_idcs - val_idcs) - assert set(test_idcs).isdisjoint(train_idcs) + if dataset_is_validation: + test_idcs = list(all_idcs - val_idcs) + logger.info( + f"Using origial validation dataset minus validation set frames, yielding a test set size of {len(test_idcs)} frames.", + ) + else: + test_idcs = list(all_idcs - train_idcs - val_idcs) + assert set(test_idcs).isdisjoint(train_idcs) + logger.info( + f"Using origial training dataset minus training and validation frames, yielding a test set size of {len(test_idcs)} frames.", + ) + # No matter what it should be disjoint from validation: assert set(test_idcs).isdisjoint(val_idcs) - print( - f"Using training dataset minus training and validation frames, yielding a test set size of {len(test_idcs)} frames.", - file=sys.stderr, - ) if not do_metrics: - print( + logger.info( "WARNING: using the automatic test set ^^^ but not computing metrics, is this really what you wanted to do?", - file=sys.stderr, ) + elif args.test_indexes is None: + # Default to all frames + test_idcs = torch.arange(dataset.len()) + logger.info( + f"Using all frames from the specified test dataset, yielding a test set size of {len(test_idcs)} frames.", + ) else: # load from file test_idcs = load_file( @@ -187,9 +252,8 @@ def main(args=None): ), filename=str(args.test_indexes), ) - print( + logger.info( f"Using provided test set indexes, yielding a test set size of {len(test_idcs)} frames.", - file=sys.stderr, ) # Figure out what metrics we're actually computing @@ -224,7 +288,7 @@ def main(args=None): batch_i: int = 0 batch_size: int = args.batch_size - print("Starting...", file=sys.stderr) + logger.info("Starting...") context_stack = contextlib.ExitStack() with contextlib.ExitStack() as context_stack: # "None" checks if in a TTY and disables if not @@ -246,7 +310,7 @@ def main(args=None): while True: datas = [ - dataset.get(int(idex)) + dataset[int(idex)] for idex in test_idcs[batch_i * batch_size : (batch_i + 1) * batch_size] ] if len(datas) == 0: @@ -269,7 +333,7 @@ def main(args=None): metrics(out, batch) display_bar.set_description_str( " | ".join( - f"{k} = {v:4.2f}" + f"{k} = {v:4.4f}" for k, v in metrics.flatten_metrics( metrics.current_result() )[0].items() @@ -284,9 +348,8 @@ def main(args=None): display_bar.close() if do_metrics: - print(file=sys.stderr) - print("--- Final result: ---", file=sys.stderr) - print( + logger.info("\n--- Final result: ---") + logger.critical( "\n".join( f"{k:>20s} = {v:< 20f}" for k, v in metrics.flatten_metrics(metrics.current_result())[0].items() @@ -295,4 +358,4 @@ def main(args=None): if __name__ == "__main__": - main() + main(running_as_script=True) diff --git a/nequip/scripts/logger.py b/nequip/scripts/logger.py new file mode 100644 index 00000000..cfc93246 --- /dev/null +++ b/nequip/scripts/logger.py @@ -0,0 +1,19 @@ +import logging +import sys + + +def set_up_script_logger(logfile: str, verbose: str = "INFO"): + # Configure the root logger so stuff gets printed + root_logger = logging.getLogger() + root_logger.setLevel(logging.CRITICAL) + root_logger.handlers = [ + logging.StreamHandler(sys.stderr), + logging.StreamHandler(sys.stdout), + ] + level = getattr(logging, verbose.upper()) + root_logger.handlers[0].setLevel(level) + root_logger.handlers[1].setLevel(logging.CRITICAL) + if logfile is not None: + root_logger.addHandler(logging.FileHandler(logfile, mode="w")) + root_logger.handlers[-1].setLevel(level) + return root_logger diff --git a/nequip/scripts/requeue.py b/nequip/scripts/requeue.py deleted file mode 100644 index 02b55103..00000000 --- a/nequip/scripts/requeue.py +++ /dev/null @@ -1,58 +0,0 @@ -""" Start or automatically restart training. - -Arguments: config.yaml - -config.yaml: requeue=True, and workdir, root, run_name have to be unique. -""" -from os.path import isfile - -from nequip.scripts.train import fresh_start, parse_command_line -from nequip.scripts.restart import restart -from nequip.utils import Config - - -def main(args=None): - config = parse_command_line() - requeue(config) - - -def requeue(config): - - assert config.get( - "requeue", False - ), "This script only works for configs with `requeue` explicitly set to True. Be careful!!" - for key in ["workdir", "root", "run_name"]: - assert isinstance( - config[key], str - ), f"{key} has to be defined for requeue script" - - found_restart_file = isfile(config.workdir + "/trainer.pth") - config.restart = found_restart_file - config.append = found_restart_file - config.force_append = True - - # for fresh new train - if not found_restart_file: - config.run_time = 1 - fresh_start(config) - else: - new_config = Config( - dict(wandb_resume=True), - allow_list=[ - "run_name", - "run_time", - "run_id", - "restart", - "append", - "force_append", - "wandb_resume", - ], - ) - new_config.update(config) - restart(config.workdir + "/trainer.pth", new_config, mode="requeue") - - return - - -if __name__ == "__main__": - main() diff --git a/nequip/scripts/restart.py b/nequip/scripts/restart.py deleted file mode 100644 index 26f14e7c..00000000 --- a/nequip/scripts/restart.py +++ /dev/null @@ -1,93 +0,0 @@ -""" Restart previous training - -Arguments: file_name config.yaml(optional) - -file_name: trainer.pth from a previous training -config.yaml: any parameters that needs to be revised -""" -import logging -import argparse - -import torch - -from nequip.utils import Config, dataset_from_config, Output, load_file - - -def main(args=None): - file_name, config = parse_command_line(args) - restart(file_name, config, mode="update") - - -def parse_command_line(args=None): - parser = argparse.ArgumentParser( - description="Restart an existing NequIP training session." - ) - parser.add_argument("session", help="trainer.pth from a previous training") - parser.add_argument( - "--update-config", help="File containing any config paramters to update" - ) - args = parser.parse_args(args=args) - - if args.update_config: - config = Config.from_file(args.update_config) - else: - config = Config() - - config.append = config.get("append", True) - if config.append is None: - config.append = True - config.wandb_resume = config.get("wandb_resume", True) - - return args.session, config - - -def restart(file_name, config, mode="update"): - - # load the dictionary - dictionary = load_file( - supported_formats=dict(torch=["pt", "pth"]), - filename=file_name, - enforced_format="torch", - ) - - dictionary.update(config) - dictionary["run_time"] = 1 + dictionary.get("run_time", 0) - - config = Config(dictionary, exclude_keys=["state_dict", "progress"]) - - torch.set_default_dtype( - {"float32": torch.float32, "float64": torch.float64}[config.default_dtype] - ) - - if config.wandb: - from nequip.train.trainer_wandb import TrainerWandB - - # resume wandb run - if config.wandb_resume: - from nequip.utils.wandb import resume - - resume(config) - else: - from nequip.utils.wandb import init_n_update - - config = init_n_update(config) - - trainer = TrainerWandB.from_dict(dictionary) - else: - from nequip.train.trainer import Trainer - - trainer = Trainer.from_dict(dictionary) - - config.update(trainer.output.updated_dict()) - - dataset = dataset_from_config(config) - logging.info(f"Successfully reload the data set of type {dataset}...") - - trainer.set_dataset(dataset) - trainer.train() - - return - - -if __name__ == "__main__": - main() diff --git a/nequip/scripts/run_md.py b/nequip/scripts/run_md.py index c67e4553..97c78f65 100644 --- a/nequip/scripts/run_md.py +++ b/nequip/scripts/run_md.py @@ -11,10 +11,8 @@ from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.md.velocitydistribution import Stationary, ZeroRotation -import nequip -from nequip.dynamics.nequip_calculator import NequIPCalculator -from nequip.scripts.deploy import load_deployed_model -from nequip.dynamics.nosehoover import NoseHoover +from nequip.ase import NequIPCalculator +from nequip.ase import NoseHoover def save_to_xyz(atoms, logdir, prefix=""): diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index d064f1d3..175979de 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -1,43 +1,75 @@ """ Train a network.""" -from typing import Union, Callable import logging import argparse -import yaml # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch. # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. import numpy as np # noqa: F401 +from os.path import isdir +from pathlib import Path + import torch import e3nn import e3nn.util.jit -from nequip.utils import Config, dataset_from_config -from nequip.data import AtomicDataDict -from nequip.nn import RescaleOutput +from nequip.model import model_from_config +from nequip.utils import Config +from nequip.data import dataset_from_config from nequip.utils.test import assert_AtomicData_equivariant, set_irreps_debug +from nequip.utils import load_file, dtype_from_name +from nequip.scripts.logger import set_up_script_logger default_config = dict( - requeue=False, + root="./", + run_name="NequIP", wandb=False, wandb_project="NequIP", - wandb_resume=False, compile_model=False, - model_builder="nequip.models.ForceModel", - model_initializers=[], + model_builders=[ + "EnergyModel", + "PerSpeciesRescale", + "ForceOutput", + "RescaleEnergyEtc", + ], dataset_statistics_stride=1, default_dtype="float32", - allow_tf32=False, + allow_tf32=False, # TODO: until we understand equivar issues verbose="INFO", model_debug_mode=False, equivariance_test=False, grad_anomaly_mode=False, + append=False, + _jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286 ) -def main(args=None): - fresh_start(parse_command_line(args)) +def main(args=None, running_as_script: bool = True): + + config = parse_command_line(args) + + if running_as_script: + set_up_script_logger(config.get("log", None), config.verbose) + + found_restart_file = isdir(f"{config.root}/{config.run_name}") + if found_restart_file and not config.append: + raise RuntimeError( + f"Training instance exists at {config.root}/{config.run_name}; " + "either set append to True or use a different root or runname" + ) + + # for fresh new train + if not found_restart_file: + trainer = fresh_start(config) + else: + trainer = restart(config) + + # Train + trainer.save() + trainer.train() + + return def parse_command_line(args=None): @@ -58,6 +90,12 @@ def parse_command_line(args=None): help="enable PyTorch autograd anomaly mode to debug NaN gradients. Do not use for production training!", action="store_true", ) + parser.add_argument( + "--log", + help="log file to store all the screen logging", + type=Path, + default=None, + ) args = parser.parse_args(args=args) config = Config.from_file(args.config, defaults=default_config) @@ -67,19 +105,8 @@ def parse_command_line(args=None): return config -def _load_callable(obj: Union[str, Callable]) -> Callable: - if callable(obj): - pass - elif isinstance(obj, str): - obj = yaml.load(f"!!python/name:{obj}", Loader=yaml.Loader) - else: - raise TypeError - assert callable(obj), f"{obj} isn't callable" - return obj - - -def fresh_start(config): - # = Set global state = +def _set_global_options(config): + """Configure global options of libraries like `torch` and `e3nn` based on `config`.""" # Set TF32 support # See https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if torch.cuda.is_available(): @@ -88,16 +115,23 @@ def fresh_start(config): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + # For avoiding 20 steps of painfully slow JIT recompilation + # See https://github.com/pytorch/pytorch/issues/52286 + torch._C._jit_set_bailout_depth(config["_jit_bailout_depth"]) + if config.model_debug_mode: set_irreps_debug(enabled=True) - torch.set_default_dtype( - {"float32": torch.float32, "float64": torch.float64}[config.default_dtype] - ) + torch.set_default_dtype(dtype_from_name(config.default_dtype)) if config.grad_anomaly_mode: torch.autograd.set_detect_anomaly(True) e3nn.set_optimization_defaults(**config.get("e3nn_optimization_defaults", {})) + +def fresh_start(config): + + _set_global_options(config) + # = Make the trainer = if config.wandb: import wandb # noqa: F401 @@ -114,133 +148,28 @@ def fresh_start(config): trainer = Trainer(model=None, **dict(config)) - output = trainer.output - config.update(output.updated_dict()) + # what is this + # to update wandb data? + config.update(trainer.params) # = Load the dataset = - dataset = dataset_from_config(config) + dataset = dataset_from_config(config, prefix="dataset") logging.info(f"Successfully loaded the data set of type {dataset}...") - - # = Train/test split = - trainer.set_dataset(dataset) - - # = Determine training type = - train_on = config.loss_coeffs - train_on = [train_on] if isinstance(train_on, str) else train_on - train_on = set(train_on) - if not train_on.issubset({"forces", "total_energy"}): - raise NotImplementedError( - f"Training on fields `{train_on}` besides forces and total energy not supported in the out-of-the-box training script yet; please use your own training script based on train.py." + try: + validation_dataset = dataset_from_config(config, prefix="validation_dataset") + logging.info( + f"Successfully loaded the validation data set of type {validation_dataset}..." ) - force_training = "forces" in train_on - logging.debug(f"Force training mode: {force_training}") - del train_on - - # = Get statistics of training dataset = - stats_fields = [ - AtomicDataDict.TOTAL_ENERGY_KEY, - AtomicDataDict.ATOMIC_NUMBERS_KEY, - ] - stats_modes = ["mean_std", "count"] - if force_training: - stats_fields.append(AtomicDataDict.FORCE_KEY) - stats_modes.append("rms") - stats = trainer.dataset_train.statistics( - fields=stats_fields, modes=stats_modes, stride=config.dataset_statistics_stride - ) - ( - (energies_mean, energies_std), - (allowed_species, Z_count), - ) = stats[:2] - if force_training: - # Scale by the force std instead - force_rms = stats[2][0] - del stats_modes - del stats_fields - - config.update(dict(allowed_species=allowed_species)) - - # = Build a model = - model_builder = _load_callable(config.model_builder) - core_model = model_builder(**dict(config)) - - # = Reinit if wanted = - with torch.no_grad(): - for initer in config.model_initializers: - initer = _load_callable(initer) - core_model.apply(initer) - - # = Determine shifts, scales = - # This is a bit awkward, but necessary for there to be a value - # in the config that signals "use dataset" - global_shift = config.get("global_rescale_shift", "dataset_energy_mean") - if global_shift == "dataset_energy_mean": - global_shift = energies_mean - elif ( - global_shift is None - or isinstance(global_shift, float) - or isinstance(global_shift, torch.Tensor) - ): - # valid values - pass - else: - raise ValueError(f"Invalid global shift `{global_shift}`") - - global_scale = config.get( - "global_rescale_scale", force_rms if force_training else energies_std - ) - if global_scale == "dataset_energy_std": - global_scale = energies_std - elif global_scale == "dataset_force_rms": - if not force_training: - raise ValueError( - "Cannot have global_scale = 'dataset_force_rms' without force training" - ) - global_scale = force_rms - elif ( - global_scale is None - or isinstance(global_scale, float) - or isinstance(global_scale, torch.Tensor) - ): - # valid values - pass - else: - raise ValueError(f"Invalid global scale `{global_scale}`") + except KeyError: + # It couldn't be found + validation_dataset = None - RESCALE_THRESHOLD = 1e-6 - if isinstance(global_scale, float) and global_scale < RESCALE_THRESHOLD: - raise ValueError( - f"Global energy scaling was very low: {global_scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None." - ) - # TODO: offer option to disable rescaling? - - logging.debug( - f"Initially outputs are scaled by: {global_scale}, eneriges are shifted by {global_shift}." - ) + # = Train/test split = + trainer.set_dataset(dataset, validation_dataset) - # == Build the model == - final_model = RescaleOutput( - model=core_model, - scale_keys=[AtomicDataDict.TOTAL_ENERGY_KEY] - + ( - [AtomicDataDict.FORCE_KEY] - if AtomicDataDict.FORCE_KEY in core_model.irreps_out - else [] - ) - + ( - [AtomicDataDict.PER_ATOM_ENERGY_KEY] - if AtomicDataDict.PER_ATOM_ENERGY_KEY in core_model.irreps_out - else [] - ), - scale_by=global_scale, - shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY, - shift_by=global_shift, - trainable_global_rescale_shift=config.get( - "trainable_global_rescale_shift", False - ), - trainable_global_rescale_scale=config.get( - "trainable_global_rescale_scale", False - ), + # = Build model = + final_model = model_from_config( + config=config, initialize=True, dataset=trainer.dataset_train ) logging.info("Successfully built the network...") @@ -249,17 +178,12 @@ def fresh_start(config): final_model = e3nn.util.jit.script(final_model) logging.info("Successfully compiled model...") - # Record final config - with open(output.generate_file("config_final.yaml"), "w+") as fp: - yaml.dump(dict(config), fp) - # Equivar test if config.equivariance_test: - equivar_err = assert_AtomicData_equivariant(final_model, dataset.get(0)) - errstr = "\n".join( - f" parity_k={parity_k.item()}, did_translate={did_trans} -> max componentwise error={err.item()}" - for (parity_k, did_trans), err in equivar_err.items() - ) + from e3nn.util.test import format_equivariance_error + + equivar_err = assert_AtomicData_equivariant(final_model, dataset[0]) + errstr = format_equivariance_error(equivar_err) del equivar_err logging.info(f"Equivariance test passed; equivariance errors:\n{errstr}") del errstr @@ -267,11 +191,70 @@ def fresh_start(config): # Set the trainer trainer.model = final_model - # Train - trainer.train() + # Store any updated config information in the trainer + trainer.update_kwargs(config) - return + return trainer + + +def restart(config): + + # load the dictionary + restart_file = f"{config.root}/{config.run_name}/trainer.pth" + dictionary = load_file( + supported_formats=dict(torch=["pt", "pth"]), + filename=restart_file, + enforced_format="torch", + ) + + # compare dictionary to config and update stop condition related arguments + for k in config.keys(): + if config[k] != dictionary.get(k, ""): + if k == "max_epochs": + dictionary[k] = config[k] + logging.info(f'Update "{k}" to {dictionary[k]}') + elif k.startswith("early_stop"): + dictionary[k] = config[k] + logging.info(f'Update "{k}" to {dictionary[k]}') + elif isinstance(config[k], type(dictionary.get(k, ""))): + raise ValueError( + f'Key "{k}" is different in config and the result trainer.pth file. Please double check' + ) + + # recursive loop, if same type but different value + # raise error + + config = Config(dictionary, exclude_keys=["state_dict", "progress"]) + + # dtype, etc. + _set_global_options(config) + + if config.wandb: + from nequip.train.trainer_wandb import TrainerWandB + from nequip.utils.wandb import resume + + resume(config) + trainer = TrainerWandB.from_dict(dictionary) + else: + from nequip.train.trainer import Trainer + + trainer = Trainer.from_dict(dictionary) + + # = Load the dataset = + dataset = dataset_from_config(config, prefix="dataset") + logging.info(f"Successfully re-loaded the data set of type {dataset}...") + try: + validation_dataset = dataset_from_config(config, prefix="validation_dataset") + logging.info( + f"Successfully re-loaded the validation data set of type {validation_dataset}..." + ) + except KeyError: + # It couldn't be found + validation_dataset = None + trainer.set_dataset(dataset, validation_dataset) + + return trainer if __name__ == "__main__": - main() + main(running_as_script=True) diff --git a/nequip/train/__init__.py b/nequip/train/__init__.py index 7c3e9df4..7cff3661 100644 --- a/nequip/train/__init__.py +++ b/nequip/train/__init__.py @@ -1,3 +1,5 @@ from .loss import Loss from .metrics import Metrics from .trainer import Trainer + +__all__ = [Loss, Metrics, Trainer] diff --git a/nequip/train/_key.py b/nequip/train/_key.py index 4603eb02..f3582ebd 100644 --- a/nequip/train/_key.py +++ b/nequip/train/_key.py @@ -7,12 +7,13 @@ VALUE_KEY = "value" CONTRIB = "contrib" -VALIDATION = "Validation" -TRAIN = "Training" +VALIDATION = "validation" +TRAIN = "training" ABBREV = { AtomicDataDict.TOTAL_ENERGY_KEY: "e", AtomicDataDict.FORCE_KEY: "f", LOSS_KEY: "loss", VALIDATION: "val", + TRAIN: "train", } diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 97c8a024..4db16274 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -1,7 +1,8 @@ +import inspect import logging import torch.nn -from torch_scatter import scatter +from torch_runstats.scatter import scatter, scatter_mean from nequip.data import AtomicDataDict from nequip.utils import instantiate_from_cls_name @@ -19,9 +20,14 @@ class SimpleLoss: and reference tensor for its call functions, and outputs a vector with the same shape as pred/ref params (str): arguments needed to initialize the function above + + Return: + + if mean is True, return a scalar; else return the error matrix of each entry """ def __init__(self, func_name: str, params: dict = {}): + self.ignore_nan = params.get("ignore_nan", False) func, _ = instantiate_from_cls_name( torch.nn, class_name=func_name, @@ -30,6 +36,7 @@ def __init__(self, func_name: str, params: dict = {}): optional_args=params, all_args={}, ) + self.func_name = func_name self.func = func def __call__( @@ -37,28 +44,57 @@ def __call__( pred: dict, ref: dict, key: str, - atomic_weight_on: bool = False, mean: bool = True, ): - loss = self.func(pred[key], ref[key]) - weights_key = AtomicDataDict.WEIGHTS_KEY + key - if weights_key in ref and atomic_weight_on: - weights = ref[weights_key] - # TO DO + # zero the nan entries + has_nan = self.ignore_nan and torch.isnan(ref[key].mean()) + if has_nan: + not_nan = (ref[key] == ref[key]).int() + loss = self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan if mean: - return (loss * weights).mean() / weights.mean() + return loss.sum() / not_nan.sum() else: - raise NotImplementedError( - "metrics and running stat needs to be compatible with this" - ) - return loss * weights, weights + return loss else: + loss = self.func(pred[key], ref[key]) if mean: return loss.mean() else: return loss - return loss + +class PerAtomLoss(SimpleLoss): + def __call__( + self, + pred: dict, + ref: dict, + key: str, + mean: bool = True, + ): + # zero the nan entries + has_nan = self.ignore_nan and torch.isnan(ref[key].sum()) + N = torch.bincount(ref[AtomicDataDict.BATCH_KEY]) + N = N.reshape((-1, 1)) + if has_nan: + not_nan = (ref[key] == ref[key]).int() + loss = ( + self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan / N + ) + if self.func_name == "MSELoss": + loss = loss / N + if mean: + return loss.sum() / not_nan.sum() + else: + return loss + else: + loss = self.func(pred[key], ref[key]) + loss = loss / N + if self.func_name == "MSELoss": + loss = loss / N + if mean: + return loss.mean() + else: + return loss class PerSpeciesLoss(SimpleLoss): @@ -73,49 +109,70 @@ def __call__( pred: dict, ref: dict, key: str, - atomic_weight_on: bool = False, mean: bool = True, ): if not mean: - raise NotImplementedError("cannot handle this yet") + raise NotImplementedError("Cannot handle this yet") - per_atom_loss = self.func(pred[key], ref[key]) - per_atom_loss = per_atom_loss.mean(dim=-1, keepdim=True) + has_nan = self.ignore_nan and torch.isnan(ref[key].mean()) - # if there is atomic weights - weights_key = AtomicDataDict.WEIGHTS_KEY + key - if weights_key in ref and atomic_weight_on: - weights = ref[weights_key] - per_atom_loss = per_atom_loss * weights + if has_nan: + not_nan = (ref[key] == ref[key]).int() + per_atom_loss = ( + self.func(pred[key], torch.nan_to_num(ref[key], nan=0.0)) * not_nan + ) else: - atomic_weight_on = False + per_atom_loss = self.func(pred[key], ref[key]) + + reduce_dims = tuple(i + 1 for i in range(len(per_atom_loss.shape) - 1)) - species_index = pred[AtomicDataDict.SPECIES_INDEX_KEY] - _, inverse_species_index = torch.unique(species_index, return_inverse=True) + if has_nan: + if len(reduce_dims) > 0: + per_atom_loss = per_atom_loss.sum(dim=reduce_dims) + + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] + per_species_loss = scatter(per_atom_loss, spe_idx, dim=0) + + N = scatter(not_nan, spe_idx, dim=0) + N = N.sum(reduce_dims) + N = 1.0 / N + N_species = ((N == N).int()).sum() + + return (per_species_loss * N).sum() / N_species - if atomic_weight_on: - # TO DO - per_species_weight = scatter(weights, inverse_species_index, dim=0) - per_species_loss = scatter(per_atom_loss, inverse_species_index, dim=0) - return (per_species_loss / per_species_weight).mean() else: - return scatter( - per_atom_loss, inverse_species_index, reduce="mean", dim=0 - ).mean() + + if len(reduce_dims) > 0: + per_atom_loss = per_atom_loss.mean(dim=reduce_dims) + + # offset species index by 1 to use 0 for nan + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] + _, inverse_species_index = torch.unique(spe_idx, return_inverse=True) + + per_species_loss = scatter_mean(per_atom_loss, inverse_species_index, dim=0) + + return per_species_loss.mean() def find_loss_function(name: str, params): + """ + Search for loss functions in this module + + If the name starts with PerSpecies, return the PerSpeciesLoss instance + """ wrapper_list = dict( - PerSpecies=PerSpeciesLoss, + perspecies=PerSpeciesLoss, + peratom=PerAtomLoss, ) if isinstance(name, str): for key in wrapper_list: - if name.startswith(key): + if name.lower().startswith(key): logging.debug(f"create loss instance {wrapper_list[key]}") return wrapper_list[key](name[len(key) :], params) - + return SimpleLoss(name, params) + elif inspect.isclass(name): return SimpleLoss(name, params) elif callable(name): return name diff --git a/nequip/train/early_stopping.py b/nequip/train/early_stopping.py index c1aee9e5..3b21497c 100644 --- a/nequip/train/early_stopping.py +++ b/nequip/train/early_stopping.py @@ -1,6 +1,6 @@ from collections import OrderedDict from copy import deepcopy -from typing import Mapping, Optional, cast +from typing import Mapping class EarlyStopping: @@ -79,7 +79,7 @@ def __call__(self, metrics) -> None: self.counters[key] += 1 debug_args = f"EarlyStopping: {self.counters[key]} / {pat}" if self.counters[key] >= pat: - stop_args += " {key} has not reduced for {pat} epochs" + stop_args += f" {key} has not reduced for {pat} epochs" stop = True else: self.minimums[key] = value diff --git a/nequip/train/loss.py b/nequip/train/loss.py index cb1db763..e5cc8889 100644 --- a/nequip/train/loss.py +++ b/nequip/train/loss.py @@ -29,7 +29,6 @@ class Loss: 'force': (1.0, 'Weighted_L1Loss', param_dict)} ``` - If atomic_weight_on is True, all the loss function will be weighed by ref[AtomicDataDict.WEIGHTS_KEY+key] (if it exists) The loss function can be a loss class name that is exactly the same (case sensitive) to the ones defined in torch.nn. It can also be a user define class type that - takes "reduction=none" as init argument @@ -41,14 +40,13 @@ class Loss: def __init__( self, coeffs: Union[dict, str, List[str]], - atomic_weight_on: bool = False, coeff_schedule: str = "constant", ): - self.atomic_weight_on = atomic_weight_on self.coeff_schedule = coeff_schedule self.coeffs = {} self.funcs = {} + self.keys = [] mseloss = find_loss_function("MSELoss", {}) if isinstance(coeffs, str): @@ -97,6 +95,7 @@ def __init__( for key, coeff in self.coeffs.items(): self.coeffs[key] = torch.as_tensor(coeff, dtype=torch.get_default_dtype()) + self.keys += [key] def __call__(self, pred: dict, ref: dict): @@ -107,7 +106,6 @@ def __call__(self, pred: dict, ref: dict): pred=pred, ref=ref, key=key, - atomic_weight_on=self.atomic_weight_on, mean=True, ) contrib[key] = _loss @@ -117,23 +115,64 @@ def __call__(self, pred: dict, ref: dict): class LossStat: - def __init__(self, keys): - self.loss_stat = {"total": RunningStats(dim=tuple(), reduction=Reduction.MEAN)} + """ + The class that accumulate the loss function values over all batches + for each loss component. + + Args: + + keys (null): redundant argument + + """ + + def __init__(self, loss_instance=None): + self.loss_stat = { + "total": RunningStats( + dim=tuple(), reduction=Reduction.MEAN, ignore_nan=False + ) + } + self.ignore_nan = {} + if loss_instance is not None: + for key, func in loss_instance.funcs.items(): + self.ignore_nan[key] = ( + func.ignore_nan if hasattr(func, "ignore_nan") else False + ) def __call__(self, loss, loss_contrib): + """ + Args: + + loss (torch.Tensor): the value of the total loss function for the current batch + loss (Dict(torch.Tensor)): the dictionary which contain the loss components + """ + results = {} + results["loss"] = self.loss_stat["total"].accumulate_batch(loss).item() + + # go through each component for k, v in loss_contrib.items(): + + # initialize for the 1st batch if k not in self.loss_stat: - self.loss_stat[k] = RunningStats(dim=tuple(), reduction=Reduction.MEAN) + self.loss_stat[k] = RunningStats( + dim=tuple(), + reduction=Reduction.MEAN, + ignore_nan=self.ignore_nan.get(k, False), + ) device = v.get_device() self.loss_stat[k].to(device="cpu" if device == -1 else device) + results["loss_" + ABBREV.get(k, k)] = ( self.loss_stat[k].accumulate_batch(v).item() ) return results def reset(self): + """ + Reset all the counters to zero + """ + for v in self.loss_stat.values(): v.reset() diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py index f22d24cb..820afd85 100644 --- a/nequip/train/metrics.py +++ b/nequip/train/metrics.py @@ -1,6 +1,9 @@ from copy import deepcopy +from hashlib import sha1 from typing import Union, Sequence, Tuple +import yaml + import torch from nequip.data import AtomicDataDict @@ -32,6 +35,8 @@ class Metrics: default: "L1Loss" PerSpecies: whether compute the estimation for each species or not + the keys are case-sensitive. + ``` components = ( @@ -51,38 +56,52 @@ def __init__( ): self.running_stats = {} - self.per_species = {} + self.params = {} self.funcs = {} self.kwargs = {} for component in components: key, reduction, params = Metrics.parse(component) - functional = params.pop("functional", "L1Loss") + params["PerSpecies"] = params.get("PerSpecies", False) + params["PerAtom"] = params.get("PerAtom", False) + + param_hash = Metrics.hash_component(component) + + functional = params.get("functional", "L1Loss") # default is to flatten the array - per_species = params.pop("PerSpecies", False) if key not in self.running_stats: self.running_stats[key] = {} - self.per_species[key] = {} self.funcs[key] = find_loss_function(functional, {}) self.kwargs[key] = {} + self.params[key] = {} # store for initialization kwargs = deepcopy(params) + kwargs.pop("functional", "L1Loss") + kwargs.pop("PerSpecies") + kwargs.pop("PerAtom") # by default, report a scalar that is mae and rmse over all component - self.kwargs[key][reduction] = dict( + self.kwargs[key][param_hash] = dict( reduction=metrics_to_reduction.get(reduction, reduction), ) - self.kwargs[key][reduction].update(kwargs) - - self.per_species[key][reduction] = per_species + self.kwargs[key][param_hash].update(kwargs) + self.params[key][param_hash] = (reduction, params) def init_runstat(self, params, error: torch.Tensor): + """ + Initialize Runstat Counter based on the shape of the error matrix + + Args: + params (dict): dictionary of additional arguments + error (torch.Tensor): error matrix + """ kwargs = deepcopy(params) + # automatically define the dimensionality if "dim" not in kwargs: kwargs["dim"] = error.shape[1:] @@ -94,6 +113,11 @@ def init_runstat(self, params, error: torch.Tensor): rs.to(device=error.device) return rs + @staticmethod + def hash_component(component): + buffer = yaml.dump(component).encode("ascii") + return sha1(buffer).hexdigest() + @staticmethod def parse(component): # parse the input list @@ -118,36 +142,49 @@ def parse(component): def __call__(self, pred: dict, ref: dict): metrics = {} + N = None for key, func in self.funcs.items(): + error = func( pred=pred, ref=ref, key=key, - atomic_weight_on=False, mean=False, ) - for reduction, kwargs in self.kwargs[key].items(): + for param_hash, kwargs in self.kwargs[key].items(): + + _, params = self.params[key][param_hash] + per_species = params["PerSpecies"] + per_atom = params["PerAtom"] # initialize the internal run_stat base on the error shape - if reduction not in self.running_stats[key]: - self.running_stats[key][reduction] = self.init_runstat( + if param_hash not in self.running_stats[key]: + self.running_stats[key][param_hash] = self.init_runstat( params=kwargs, error=error ) - stat = self.running_stats[key][reduction] + stat = self.running_stats[key][param_hash] params = {} - if self.per_species[key][reduction]: + if per_species: # TO DO, this needs OneHot component. will need to be decoupled - params = {"accumulate_by": pred[AtomicDataDict.SPECIES_INDEX_KEY]} + params = {"accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY]} + if per_atom: + if N is None: + N = torch.bincount(ref[AtomicDataDict.BATCH_KEY]).unsqueeze(-1) + error_N = error / N + else: + error_N = error - if stat.dim == () and not self.per_species[key][reduction]: - metrics[(key, reduction)] = stat.accumulate_batch( - error.flatten(), **params + if stat.dim == () and not per_species: + metrics[(key, param_hash)] = stat.accumulate_batch( + error_N.flatten(), **params ) else: - metrics[(key, reduction)] = stat.accumulate_batch(error, **params) + metrics[(key, param_hash)] = stat.accumulate_batch( + error_N, **params + ) return metrics @@ -169,36 +206,38 @@ def current_result(self): metrics[(key, reduction)] = stat.current_result() return metrics - def flatten_metrics(self, metrics, allowed_species=None): + def flatten_metrics(self, metrics, type_names=None): flat_dict = {} skip_keys = [] for k, value in metrics.items(): - key, reduction = k + key, param_hash = k + reduction, params = self.params[key][param_hash] + short_name = ABBREV.get(key, key) - item_name = f"{short_name}_{reduction}" + per_atom = params["PerAtom"] + suffix = "/N" if per_atom else "" + item_name = f"{short_name}{suffix}_{reduction}" - stat = self.running_stats[key][reduction] - per_species = self.per_species[key][reduction] + stat = self.running_stats[key][param_hash] + per_species = params["PerSpecies"] if per_species: - - element_names = ( - list(range(value.shape[0])) - if allowed_species is None - else list(allowed_species) - ) - if stat.output_dim == tuple(): + if type_names is None: + type_names = [i for i in range(len(value))] for id_ele, v in enumerate(value): - flat_dict[f"{element_names[id_ele]}_{item_name}"] = v.item() + if type_names is not None: + flat_dict[f"{type_names[id_ele]}_{item_name}"] = v.item() + else: + flat_dict[f"{id_ele}_{item_name}"] = v.item() flat_dict[f"all_{item_name}"] = value.mean().item() else: for id_ele, vec in enumerate(value): - ele = element_names[id_ele] + ele = type_names[id_ele] for idx, v in enumerate(vec): name = f"{ele}_{item_name}_{idx}" flat_dict[name] = v.item() diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index c801a0f8..e543d116 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -10,12 +10,11 @@ import sys import inspect import logging -import yaml from copy import deepcopy from os.path import isfile -from time import perf_counter, gmtime, strftime -from typing import Optional, Union - +from time import perf_counter +from typing import Callable, Optional, Union, Tuple, List +from pathlib import Path if sys.version_info[1] >= 7: import contextlib @@ -25,20 +24,23 @@ import numpy as np import e3nn -import torch_geometric import torch from torch_ema import ExponentialMovingAverage import nequip -from nequip.data import DataLoader, AtomicData, AtomicDataDict +from nequip.data import DataLoader, AtomicData, AtomicDataDict, AtomicDataset from nequip.utils import ( Output, + Config, instantiate_from_cls_name, instantiate, save_file, load_file, atomic_write, + dtype_from_name, ) +from nequip.utils.git import get_commit +from nequip.model import model_from_config from .loss import Loss, LossStat from .metrics import Metrics @@ -131,23 +133,14 @@ class Trainer: seed (int): random see number - run_name (str): run name. - root (str): the name of root dir to make work folders - timestr (optional, str): unique string to differentiate this trainer from others. - - restart (bool) : If true, the init_model function will not be callsed. Default: False - append (bool): If true, the preexisted workfolder and files will be overwritten. And log files will be appended - loss_coeffs (dict): dictionary to store coefficient and loss functions max_epochs (int): maximum number of epochs - lr_sched (optional): scheduler learning_rate (float): initial learning rate lr_scheduler_name (str): scheduler name lr_scheduler_kwargs (dict): parameters to initialize the scheduler - optim (): optimizer optimizer_name (str): name for optimizer optim_kwargs (dict): parameters to initialize the optimizer @@ -176,17 +169,14 @@ class Trainer: Additional Attributes: - init_params (list): list of parameters needed to reconstruct this instance - device : torch device - optim: optimizer - lr_sched: scheduler + init_keys (list): list of parameters needed to reconstruct this instance dl_train (DataLoader): training data dl_val (DataLoader): test data iepoch (int): # of epoches ran stop_arg (str): reason why the training stops batch_mae (float): the mae of the latest batch mae_dict (dict): all loss, mae of the latest validation - best_val_metrics (float): current best validation mae + best_metrics (float): current best validation mae best_epoch (float): current best epoch best_model_path (str): path to save the best model last_model_path (str): path to save the latest model @@ -213,31 +203,31 @@ class Trainer: ``` """ + stop_keys = ["max_epochs", "early_stopping", "early_stopping_kwargs"] + object_keys = ["lr_sched", "optim", "ema", "early_stopping_conds"] lr_scheduler_module = torch.optim.lr_scheduler optim_module = torch.optim def __init__( self, model, - run_name: Optional[str] = None, - root: Optional[str] = None, - timestr: Optional[str] = None, + model_builders: Optional[list] = [], + device: str = "cuda" if torch.cuda.is_available() else "cpu", seed: Optional[int] = None, - restart: bool = False, - append: bool = True, loss_coeffs: Union[dict, str] = AtomicDataDict.TOTAL_ENERGY_KEY, + train_on_keys: Optional[List[str]] = None, metrics_components: Optional[Union[dict, str]] = None, - metrics_key: str = ABBREV.get(LOSS_KEY, LOSS_KEY), - early_stopping: Optional[EarlyStopping] = None, + metrics_key: str = f"{VALIDATION}_" + ABBREV.get(LOSS_KEY, LOSS_KEY), + early_stopping_conds: Optional[EarlyStopping] = None, + early_stopping: Optional[Callable] = None, early_stopping_kwargs: Optional[dict] = None, max_epochs: int = 1000000, - lr_sched=None, learning_rate: float = 1e-2, lr_scheduler_name: str = "none", lr_scheduler_kwargs: Optional[dict] = None, - optim=None, optimizer_name: str = "Adam", optimizer_kwargs: Optional[dict] = None, + max_gradient_norm: float = float("inf"), use_ema: bool = False, ema_decay: float = 0.999, ema_use_num_updates=True, @@ -259,6 +249,7 @@ def __init__( log_epoch_freq: int = 1, save_checkpoint_freq: int = -1, save_ema_checkpoint_freq: int = -1, + report_init_validation: bool = False, verbose="INFO", **kwargs, ): @@ -266,45 +257,44 @@ def __init__( logging.debug("* Initialize Trainer") # store all init arguments - self.root = root self.model = model - self.optim = optim - self.lr_sched = lr_sched _local_kwargs = {} - for key in self.init_params: + for key in self.init_keys: setattr(self, key, locals()[key]) _local_kwargs[key] = locals()[key] - if self.use_ema: - self.ema = None + self.ema = None - output = Output.get_output(timestr, dict(**_local_kwargs, **kwargs)) + output = Output.get_output(dict(**_local_kwargs, **kwargs)) self.output = output - # timestr run_name root workdir logfile - for key, value in output.updated_dict().items(): - setattr(self, key, value) - - if self.logfile is None: - self.logfile = output.open_logfile("log", propagate=True) + self.logfile = output.open_logfile("log", propagate=True) self.epoch_log = output.open_logfile("metrics_epoch.csv", propagate=False) + self.init_epoch_log = output.open_logfile( + "metrics_initialization.csv", propagate=False + ) self.batch_log = { - TRAIN: output.open_logfile("metrics_batch_train.csv", propagate=False), - VALIDATION: output.open_logfile("metrics_batch_val.csv", propagate=False), + TRAIN: output.open_logfile( + f"metrics_batch_{ABBREV[TRAIN]}.csv", propagate=False + ), + VALIDATION: output.open_logfile( + f"metrics_batch_{ABBREV[VALIDATION]}.csv", propagate=False + ), } # add filenames if not defined self.best_model_path = output.generate_file("best_model.pth") self.last_model_path = output.generate_file("last_model.pth") self.trainer_save_path = output.generate_file("trainer.pth") + self.config_path = self.output.generate_file("config.yaml") - if not (seed is None or self.restart): + if seed is not None: torch.manual_seed(seed) np.random.seed(seed) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.logger.info(f"Torch device: {self.device}") + self.torch_device = torch.device(self.device) # sort out all the other parameters # for samplers, optimizer and scheduler @@ -313,36 +303,116 @@ def __init__( self.lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs) self.early_stopping_kwargs = deepcopy(early_stopping_kwargs) - # initialize the optimizer and scheduler, the params will be updated in the function + # initialize training states + self.best_metrics = float("inf") + self.best_epoch = 0 + self.iepoch = -1 if self.report_init_validation else 0 + + self.loss, _ = instantiate( + builder=Loss, + prefix="loss", + positional_args=dict(coeffs=self.loss_coeffs), + all_args=self.kwargs, + ) + self.loss_stat = LossStat(self.loss) + self.train_on_keys = self.loss.keys + if train_on_keys is not None: + assert set(train_on_keys) == set(self.train_on_keys) + self.init() - if not (restart and append): + def init_objects(self): + # initialize optimizer + self.optim, self.optimizer_kwargs = instantiate_from_cls_name( + module=torch.optim, + class_name=self.optimizer_name, + prefix="optimizer", + positional_args=dict(params=self.model.parameters(), lr=self.learning_rate), + all_args=self.kwargs, + optional_args=self.optimizer_kwargs, + ) + + self.max_gradient_norm = ( + float(self.max_gradient_norm) + if self.max_gradient_norm is not None + else float("inf") + ) - d = self.as_dict() - for key in list(d.keys()): - if not isinstance(d[key], (float, int, str, list, tuple)): - d[key] = repr(d[key]) + # initialize scheduler + assert ( + self.lr_scheduler_name + in ["CosineAnnealingWarmRestarts", "ReduceLROnPlateau", "none"] + ) or ( + (len(self.end_of_epoch_callbacks) + len(self.end_of_batch_callbacks)) > 0 + ), f"{self.lr_scheduler_name} cannot be used unless callback functions are defined" + self.lr_sched = None + self.lr_scheduler_kwargs = {} + if self.lr_scheduler_name != "none": + self.lr_sched, self.lr_scheduler_kwargs = instantiate_from_cls_name( + module=torch.optim.lr_scheduler, + class_name=self.lr_scheduler_name, + prefix="lr_scheduler", + positional_args=dict(optimizer=self.optim), + optional_args=self.lr_scheduler_kwargs, + all_args=self.kwargs, + ) + + # initialize early stopping conditions + key_mapping, kwargs = instantiate( + EarlyStopping, + prefix="early_stopping", + optional_args=self.early_stopping_kwargs, + all_args=self.kwargs, + return_args_only=True, + ) + n_args = 0 + for key, item in kwargs.items(): + # prepand VALIDATION string if k is not with + if isinstance(item, dict): + new_dict = {} + for k, v in item.items(): + if ( + k.lower().startswith(VALIDATION) + or k.lower().startswith(TRAIN) + or k.lower() in ["lr", "wall"] + ): + new_dict[k] = item[k] + else: + new_dict[f"{VALIDATION}_{k}"] = item[k] + kwargs[key] = new_dict + n_args += len(new_dict) + self.early_stopping_conds = EarlyStopping(**kwargs) if n_args > 0 else None - d["start_time"] = strftime("%a, %d %b %Y %H:%M:%S", gmtime()) + if self.use_ema and self.ema is None: + self.ema = ExponentialMovingAverage( + self.model.parameters(), + decay=self.ema_decay, + use_num_updates=self.ema_use_num_updates, + ) - self.log_dictionary(d, name="Initialization") + if hasattr(self.model, "irreps_out"): + for key in self.train_on_keys: + if key not in self.model.irreps_out: + raise RuntimeError( + "Loss function include fields that are not predicted by the model" + ) - logging.debug("! Done Initialize Trainer") + @property + def init_keys(self): + return [ + key + for key in list(inspect.signature(Trainer.__init__).parameters.keys()) + if key not in (["self", "kwargs", "model"] + Trainer.object_keys) + ] @property - def init_params(self): - d = inspect.signature(Trainer.__init__) - names = list(d.parameters.keys()) - for key in [ - "model", - "optim", - "lr_sched", - "self", - "kwargs", - ]: - if key in names: - names.remove(key) - return names + def params(self): + return self.as_dict(state_dict=False, training_progress=False, kwargs=False) + + def update_kwargs(self, config): + self.kwargs.update( + {key: value for key, value in config.items() if key not in self.init_keys} + ) @property def logger(self): @@ -352,7 +422,16 @@ def logger(self): def epoch_logger(self): return logging.getLogger(self.epoch_log) - def as_dict(self, state_dict: bool = False, training_progress: bool = False): + @property + def init_epoch_logger(self): + return logging.getLogger(self.init_epoch_log) + + def as_dict( + self, + state_dict: bool = False, + training_progress: bool = False, + kwargs: bool = True, + ): """convert instance to a dictionary Args: @@ -361,38 +440,30 @@ def as_dict(self, state_dict: bool = False, training_progress: bool = False): dictionary = {} - for key in self.init_params: + for key in self.init_keys: dictionary[key] = getattr(self, key, None) - dictionary.update(getattr(self, "kwargs", {})) + + if kwargs: + dictionary.update(getattr(self, "kwargs", {})) if state_dict: dictionary["state_dict"] = {} - dictionary["state_dict"]["optim"] = self.optim.state_dict() - if self.lr_sched is not None: - dictionary["state_dict"]["lr_sched"] = self.lr_sched.state_dict() + for key in Trainer.object_keys: + item = getattr(self, key, None) + if item is not None: + dictionary["state_dict"][key] = item.state_dict() dictionary["state_dict"]["rng_state"] = torch.get_rng_state() if torch.cuda.is_available(): dictionary["state_dict"]["cuda_rng_state"] = torch.cuda.get_rng_state( - device=self.device + device=self.torch_device ) - if self.use_ema: - dictionary["state_dict"]["ema_state"] = self.ema.state_dict() - if self.early_stopping is not None: - dictionary["state_dict"][ - "early_stopping" - ] = self.early_stopping.state_dict() - - if hasattr(self.model, "save") and not issubclass( - type(self.model), torch.jit.ScriptModule - ): - dictionary["model_class"] = type(self.model) if training_progress: dictionary["progress"] = {} for key in ["iepoch", "best_epoch"]: dictionary["progress"][key] = self.__dict__.get(key, -1) - dictionary["progress"]["best_val_metrics"] = self.__dict__.get( - "best_val_metrics", float("inf") + dictionary["progress"]["best_metrics"] = self.__dict__.get( + "best_metrics", float("inf") ) dictionary["progress"]["stop_arg"] = self.__dict__.get("stop_arg", None) @@ -400,13 +471,25 @@ def as_dict(self, state_dict: bool = False, training_progress: bool = False): dictionary["progress"]["best_model_path"] = self.best_model_path dictionary["progress"]["last_model_path"] = self.last_model_path dictionary["progress"]["trainer_save_path"] = self.trainer_save_path + if hasattr(self, "config_save_path"): + dictionary["progress"]["config_save_path"] = self.config_save_path - for code in [e3nn, nequip, torch, torch_geometric]: + for code in [e3nn, nequip, torch]: dictionary[f"{code.__name__}_version"] = code.__version__ + for code in ["e3nn", "nequip"]: + dictionary[f"{code}_commit"] = get_commit(code) return dictionary - def save(self, filename, format=None): + def save_config(self) -> None: + save_file( + item=self.as_dict(state_dict=False, training_progress=False), + supported_formats=dict(yaml=["yaml"]), + filename=self.config_path, + enforced_format=None, + ) + + def save(self, filename: Optional[str] = None, format=None): """save the file as filename Args: @@ -415,6 +498,9 @@ def save(self, filename, format=None): format (str): format of the file. yaml and json format will not save the weights. """ + if filename is None: + filename = self.trainer_save_path + logger = self.logger state_dict = ( @@ -433,6 +519,7 @@ def save(self, filename, format=None): ) logger.debug(f"Saved trainer to {filename}") + self.save_config() self.save_model(self.last_model_path) logger.debug(f"Saved last model to to {self.last_model_path}") @@ -469,7 +556,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): d = deepcopy(dictionary) - for code in [e3nn, nequip, torch, torch_geometric]: + for code in [e3nn, nequip, torch]: version = d.get(f"{code.__name__}_version", None) if version is not None and version != code.__version__: logging.warning( @@ -479,16 +566,11 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): ) # update the restart and append option - d["restart"] = True if append is not None: d["append"] = append - # update the file and folder name - output = Output.from_config(d) - d.update(output.updated_dict()) - model = None - iepoch = 0 + iepoch = -1 if "model" in d: model = d.pop("model") elif "progress" in d: @@ -497,18 +579,14 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): # load the model from file iepoch = progress["iepoch"] if isfile(progress["last_model_path"]): - load_path = progress["last_model_path"] + load_path = Path(progress["last_model_path"]) iepoch = progress["iepoch"] else: raise AttributeError("model weights & bias are not saved") - if "model_class" in d: - model = d["model_class"].load(load_path) - else: - if dictionary.get("compile_model", False): - model = torch.jit.load(load_path) - else: - model = torch.load(load_path) + model, _ = Trainer.load_model_from_training_session( + traindir=load_path.parent, model_name=load_path.name + ) logging.debug(f"Reload the model from {load_path}") d.pop("progress") @@ -519,27 +597,22 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): if state_dict is not None and trainer.model is not None: logging.debug("Reload optimizer and scheduler states") - trainer.optim.load_state_dict(state_dict["optim"]) - - if trainer.lr_sched is not None: - trainer.lr_sched.load_state_dict(state_dict["lr_sched"]) - - if trainer.early_stopping is not None: - trainer.early_stopping.load_state_dict(state_dict["early_stopping"]) + for key in Trainer.object_keys: + item = getattr(trainer, key, None) + if item is not None: + item.load_state_dict(state_dict[key]) + trainer._initialized = True torch.set_rng_state(state_dict["rng_state"]) if torch.cuda.is_available(): torch.cuda.set_rng_state(state_dict["cuda_rng_state"]) - if trainer.use_ema: - trainer.ema.load_state_dict(state_dict["ema_state"]) - if "progress" in d: - trainer.best_val_metrics = progress["best_val_metrics"] + trainer.best_metrics = progress["best_metrics"] trainer.best_epoch = progress["best_epoch"] stop_arg = progress.pop("stop_arg", None) else: - trainer.best_val_metrics = float("inf") + trainer.best_metrics = float("inf") trainer.best_epoch = 0 stop_arg = None trainer.iepoch = iepoch @@ -553,93 +626,51 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): return trainer - def init(self): - """ initialize optimizer """ - if self.model is None: - return + @staticmethod + def load_model_from_training_session( + traindir, model_name="best_model.pth", device="cpu" + ) -> Tuple[torch.nn.Module, Config]: + traindir = str(traindir) + model_name = str(model_name) - self.model.to(self.device) + config = Config.from_file(traindir + "/config.yaml") - if self.optim is None: - self.optim, self.optimizer_kwargs = instantiate_from_cls_name( - module=torch.optim, - class_name=self.optimizer_name, - prefix="optimizer", - positional_args=dict( - params=self.model.parameters(), lr=self.learning_rate - ), - all_args=self.kwargs, - optional_args=self.optimizer_kwargs, + if config.get("compile_model", False): + model = torch.jit.load(traindir + "/" + model_name, map_location=device) + else: + model = model_from_config( + config=config, + initialize=False, ) + if model is not None: + # TODO: this is not exactly equivalent to building with + # this set as default dtype... does it matter? + model.to( + device=torch.device(device), + dtype=dtype_from_name(config.default_dtype), + ) + model_state_dict = torch.load( + traindir + "/" + model_name, map_location=device + ) + model.load_state_dict(model_state_dict) + return model, config - if self.use_ema and self.ema is None: - self.ema = ExponentialMovingAverage( - self.model.parameters(), - decay=self.ema_decay, - use_num_updates=self.ema_use_num_updates, - ) + def init(self): + """initialize optimizer""" + if self.model is None: + return - if self.lr_sched is None: - assert ( - self.lr_scheduler_name - in ["CosineAnnealingWarmRestarts", "ReduceLROnPlateau", "none"] - ) or ( - (len(self.end_of_epoch_callbacks) + len(self.end_of_batch_callbacks)) - > 0 - ), f"{self.lr_scheduler_name} cannot be used unless callback functions are defined" - self.lr_sched = None - self.lr_scheduler_kwargs = {} - if self.lr_scheduler_name != "none": - self.lr_sched, self.lr_scheduler_kwargs = instantiate_from_cls_name( - module=torch.optim.lr_scheduler, - class_name=self.lr_scheduler_name, - prefix="lr_scheduler", - positional_args=dict(optimizer=self.optim), - optional_args=self.lr_scheduler_kwargs, - all_args=self.kwargs, - ) + self.model.to(self.torch_device) + self.init_objects() - self.loss, _ = instantiate( - builder=Loss, - prefix="loss", - positional_args=dict(coeffs=self.loss_coeffs), - all_args=self.kwargs, - ) - self.loss_stat = LossStat(keys=list(self.loss.funcs.keys())) self._initialized = True - if self.early_stopping is None: - key_mapping, kwargs = instantiate( - EarlyStopping, - prefix="early_stopping", - optional_args=self.early_stopping_kwargs, - all_args=self.kwargs, - return_args_only=True, - ) - n_args = 0 - for key, item in kwargs.items(): - # prepand VALIDATION string if k is not with - if isinstance(item, dict): - new_dict = {} - for k, v in item.items(): - if ( - k.startswith(VALIDATION) - or k.startswith(TRAIN) - or k in ["LR", "wall"] - ): - new_dict[k] = item[k] - else: - new_dict[f"{VALIDATION}_{k}"] = item[k] - kwargs[key] = new_dict - n_args += len(new_dict) - self.early_stopping = EarlyStopping(**kwargs) if n_args > 0 else None - def init_metrics(self): if self.metrics_components is None: self.metrics_components = [] for key, func in self.loss.funcs.items(): params = { - "PerSpecies": type(func).__name__.startswith("PerSpecies"), + "PerSpecies": type(func).__name__.lower().startswith("perspecies"), } self.metrics_components.append((key, "mae", params)) self.metrics_components.append((key, "rmse", params)) @@ -652,28 +683,25 @@ def init_metrics(self): ) if not ( - self.metrics_key.startswith(VALIDATION) - or self.metrics_key.startswith(TRAIN) + self.metrics_key.lower().startswith(VALIDATION) + or self.metrics_key.lower().startswith(TRAIN) ): - self.metrics_key = f"{VALIDATION}_{self.metrics_key}" - - def init_model(self): - logger = self.logger - logger.info( - "Number of weights: {}".format( - sum(p.numel() for p in self.model.parameters()) + raise RuntimeError( + f"metrics_key should start with either {VALIDATION} or {TRAIN}" ) - ) def train(self): + """Training""" if getattr(self, "dl_train", None) is None: raise RuntimeError("You must call `set_dataset()` before calling `train()`") if not self._initialized: self.init() - - if not self.restart: - self.init_model() + self.logger.info( + "Number of weights: {}".format( + sum(p.numel() for p in self.model.parameters()) + ) + ) for callback in self.init_callbacks: callback(self) @@ -681,10 +709,8 @@ def train(self): self.init_log() self.wall = perf_counter() - if not self.restart: - self.best_val_metrics = float("inf") - self.best_epoch = 0 - self.iepoch = 0 + if self.iepoch == -1: + self.save() self.init_metrics() @@ -698,7 +724,7 @@ def train(self): self.final_log() - self.save(self.trainer_save_path) + self.save() def batch_step(self, data, validation=False): # no need to have gradients from old steps taking up memory @@ -710,7 +736,7 @@ def batch_step(self, data, validation=False): self.model.train() # Do any target rescaling - data = data.to(self.device) + data = data.to(self.torch_device) data = AtomicData.to_AtomicDataDict(data) if hasattr(self.model, "unscale"): @@ -734,10 +760,19 @@ def batch_step(self, data, validation=False): # Note that either way all normalization was handled internally by RescaleOutput if not validation: + # Actually do an optimization step, since we're training: loss, loss_contrib = self.loss(pred=out, ref=data_unscaled) # see https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-parameter-grad-none-instead-of-model-zero-grad-or-optimizer-zero-grad self.optim.zero_grad(set_to_none=True) loss.backward() + + # See https://stackoverflow.com/a/56069467 + # Has to happen after .backward() so there are grads to clip + if self.max_gradient_norm < float("inf"): + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.max_gradient_norm + ) + self.optim.step() if self.use_ema: @@ -770,10 +805,12 @@ def batch_step(self, data, validation=False): @property def stop_cond(self): - """ kill the training early """ + """kill the training early""" - if self.early_stopping is not None and hasattr(self, "mae_dict"): - early_stop, early_stop_args, debug_args = self.early_stopping(self.mae_dict) + if self.early_stopping_conds is not None and hasattr(self, "mae_dict"): + early_stop, early_stop_args, debug_args = self.early_stopping_conds( + self.mae_dict + ) if debug_args is not None: self.logger.debug(debug_args) if early_stop: @@ -788,14 +825,14 @@ def stop_cond(self): def reset_metrics(self): self.loss_stat.reset() - self.loss_stat.to(self.device) + self.loss_stat.to(self.torch_device) self.metrics.reset() - self.metrics.to(self.device) + self.metrics.to(self.torch_device) def epoch_step(self): datasets = [self.dl_train, self.dl_val] - categories = [TRAIN, VALIDATION] + categories = [TRAIN, VALIDATION] if self.iepoch >= 0 else [VALIDATION] self.metrics_dict = {} self.loss_dict = {} @@ -833,22 +870,13 @@ def epoch_step(self): for callback in self.end_of_epoch_callbacks: callback(self) - def log_dictionary(self, dictionary: dict, name: str = ""): - """ - dump the keys and values of a dictionary - """ - - logger = self.logger - logger.info(f"* {name}") - logger.info(yaml.dump(dictionary)) - def end_of_batch_log(self, batch_type: str): """ store all the loss/mae of each batch """ mat_str = f"{self.iepoch+1:5d}, {self.ibatch+1:5d}" - log_str = f"{self.iepoch+1:5d} {self.ibatch+1:5d}" + log_str = f" {self.iepoch+1:5d} {self.ibatch+1:5d}" header = "epoch, batch" log_header = "# Epoch batch" @@ -858,12 +886,13 @@ def end_of_batch_log(self, batch_type: str): mat_str += f", {value:16.5g}" header += f", {name}" log_str += f" {value:12.3g}" - log_header += f" {name:>12s}" + log_header += f" {name:>12.12}" # append details from metrics metrics, skip_keys = self.metrics.flatten_metrics( metrics=self.batch_metrics, - allowed_species=self.model.config.get("allowed_species", None) + # TO DO, how about chemical to symbol + type_names=self.model.config.get("type_names") if hasattr(self.model, "config") else None, ) @@ -873,17 +902,19 @@ def end_of_batch_log(self, batch_type: str): header += f", {key}" if key not in skip_keys: log_str += f" {value:12.3g}" - log_header += f" {key:>12s}" + log_header += f" {key:>12.12}" batch_logger = logging.getLogger(self.batch_log[batch_type]) - if not self.batch_header_print[batch_type]: - self.batch_header_print[batch_type] = True - batch_logger.info(header) if self.ibatch == 0: self.logger.info("") self.logger.info(f"{batch_type}") self.logger.info(log_header) + init_step = -1 if self.report_init_validation else 0 + if (self.iepoch == init_step and batch_type == VALIDATION) or ( + self.iepoch == 0 and batch_type == TRAIN + ): + batch_logger.info(header) batch_logger.info(mat_str) if (self.ibatch + 1) % self.log_batch_freq == 0 or ( @@ -896,19 +927,19 @@ def end_of_epoch_save(self): save model and trainer details """ - val_metrics = self.mae_dict[self.metrics_key] - if val_metrics < self.best_val_metrics: - self.best_val_metrics = val_metrics + current_metrics = self.mae_dict[self.metrics_key] + if current_metrics < self.best_metrics: + self.best_metrics = current_metrics self.best_epoch = self.iepoch self.save_ema_model(self.best_model_path) self.logger.info( - f"! Best model {self.best_epoch:8d} {self.best_val_metrics:8.3f}" + f"! Best model {self.best_epoch:8d} {self.best_metrics:8.3f}" ) if (self.iepoch + 1) % self.log_epoch_freq == 0: - self.save(self.trainer_save_path) + self.save() if ( self.save_checkpoint_freq > 0 @@ -939,25 +970,22 @@ def save_ema_model(self, path): self.save_model(path) def save_model(self, path): - + self.save_config() with atomic_write(path) as write_to: - if hasattr(self.model, "save"): - self.model.save(write_to) + if isinstance(self.model, torch.jit.ScriptModule): + torch.jit.save(self.model, write_to) else: - torch.save(self.model, write_to) + torch.save(self.model.state_dict(), write_to) def init_log(self): - - if self.restart: + if self.iepoch > 0: self.logger.info("! Restarting training ...") else: self.logger.info("! Starting training ...") - self.epoch_header_print = False - self.batch_header_print = {TRAIN: False, VALIDATION: False} def final_log(self): - self.logger.info(f"! Stop training for eaching {self.stop_arg}") + self.logger.info(f"! Stop training: {self.stop_arg}") wall = perf_counter() - self.wall self.logger.info(f"Wall time: {wall}") @@ -976,7 +1004,7 @@ def end_of_epoch_log(self): header = "epoch, wall, LR" - categories = [TRAIN, VALIDATION] + categories = [TRAIN, VALIDATION] if self.iepoch > 0 else [VALIDATION] log_header = {} log_str = {} @@ -991,7 +1019,7 @@ def end_of_epoch_log(self): met, skip_keys = self.metrics.flatten_metrics( metrics=self.metrics_dict[category], - allowed_species=self.model.config.get("allowed_species", None) + type_names=self.model.config.get("type_names") if hasattr(self.model, "config") else None, ) @@ -1001,7 +1029,7 @@ def end_of_epoch_log(self): mat_str += f", {value:16.5g}" header += f", {category}_{key}" log_str[category] += f" {value:12.3g}" - log_header[category] += f" {key:>12s}" + log_header[category] += f" {key:>12.12}" self.mae_dict[f"{category}_{key}"] = value # append details from metrics @@ -1010,72 +1038,118 @@ def end_of_epoch_log(self): header += f", {category}_{key}" if key not in skip_keys: log_str[category] += f" {value:12.3g}" - log_header[category] += f" {key:>12s}" + log_header[category] += f" {key:>12.12}" self.mae_dict[f"{category}_{key}"] = value - if not self.epoch_header_print: + if self.iepoch == 0: + self.init_epoch_logger.info(header) + self.init_epoch_logger.info(mat_str) + elif self.iepoch == 1: self.epoch_logger.info(header) - self.epoch_header_print = True - self.epoch_logger.info(mat_str) - self.logger.info("\n\n Train " + log_header[TRAIN]) - self.logger.info("! Train " + log_str[TRAIN]) - self.logger.info(" Validation " + log_header[VALIDATION]) - self.logger.info("! Validation " + log_str[VALIDATION]) + if self.iepoch > 0: + self.epoch_logger.info(mat_str) + + if self.iepoch > 0: + self.logger.info("\n\n Train " + log_header[TRAIN]) + self.logger.info("! Train " + log_str[TRAIN]) + self.logger.info("! Validation " + log_str[VALIDATION]) + else: + self.logger.info("\n\n Initialization " + log_header[VALIDATION]) + self.logger.info("! Initial Validation " + log_str[VALIDATION]) + + wall = perf_counter() - self.wall + self.logger.info(f"Wall time: {wall}") def __del__(self): - if not self.append: + if not self._initialized: + return - logger = self.logger - for hdl in logger.handlers: - hdl.flush() - hdl.close() - logger.handlers = [] + logger = self.logger + for hdl in logger.handlers: + hdl.flush() + hdl.close() + logger.handlers = [] - for i in range(len(logger.handlers)): - logger.handlers.pop() + for i in range(len(logger.handlers)): + logger.handlers.pop() - def set_dataset(self, dataset): + def set_dataset( + self, + dataset: AtomicDataset, + validation_dataset: Optional[AtomicDataset] = None, + ) -> None: + """Set the dataset(s) used by this trainer. - if self.train_idcs is None or self.val_idcs is None: + Training and validation datasets will be sampled from + them in accordance with the trainer's parameters. - total_n = len(dataset) + If only one dataset is provided, the train and validation + datasets will both be sampled from it. Otherwise, if + `validation_dataset` is provided, it will be used. + """ - if (self.n_train + self.n_val) > total_n: - raise ValueError( - "too little data for training and validation. please reduce n_train and n_val" - ) + if self.train_idcs is None or self.val_idcs is None: + if validation_dataset is None: + # Sample both from `dataset`: + total_n = len(dataset) + if (self.n_train + self.n_val) > total_n: + raise ValueError( + "too little data for training and validation. please reduce n_train and n_val" + ) + + if self.train_val_split == "random": + idcs = torch.randperm(total_n) + elif self.train_val_split == "sequential": + idcs = torch.arange(total_n) + else: + raise NotImplementedError( + f"splitting mode {self.train_val_split} not implemented" + ) - if self.train_val_split == "random": - idcs = torch.randperm(total_n) - elif self.train_val_split == "sequential": - idcs = torch.arange(total_n) + self.train_idcs = idcs[: self.n_train] + self.val_idcs = idcs[self.n_train : self.n_train + self.n_val] else: - raise NotImplementedError( - f"splitting mode {self.train_val_split} not implemented" - ) + if self.n_train > len(dataset): + raise ValueError("Not enough data in dataset for requested n_train") + if self.n_val > len(validation_dataset): + raise ValueError("Not enough data in dataset for requested n_train") + if self.train_val_split == "random": + self.train_idcs = torch.randperm(len(dataset))[: self.n_train] + self.val_idcs = torch.randperm(len(validation_dataset))[ + : self.n_val + ] + elif self.train_val_split == "sequential": + self.train_idcs = torch.arange(self.n_train) + self.val_idcs = torch.arange(self.n_val) + else: + raise NotImplementedError( + f"splitting mode {self.train_val_split} not implemented" + ) - self.train_idcs = idcs[: self.n_train] - self.val_idcs = idcs[self.n_train : self.n_train + self.n_val] + if validation_dataset is None: + validation_dataset = dataset # torch_geometric datasets inherantly support subsets using `index_select` self.dataset_train = dataset.index_select(self.train_idcs) - self.dataset_val = dataset.index_select(self.val_idcs) + self.dataset_val = validation_dataset.index_select(self.val_idcs) # based on recommendations from # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#enable-async-data-loading-and-augmentation - if self.dataloader_num_workers != 0: - # some issues with timeouts need to be investigated - raise NotImplementedError dl_kwargs = dict( batch_size=self.batch_size, shuffle=self.shuffle, exclude_keys=self.exclude_keys, - # num_workers=self.dataloader_num_workers, - # persistent_workers=(self.max_epochs > 1), - pin_memory=(self.device != torch.device("cpu")), - # timeout=10, # just so you don't get stuck + num_workers=self.dataloader_num_workers, + # keep stuff around in memory + persistent_workers=( + self.dataloader_num_workers > 0 and self.max_epochs > 1 + ), + # PyTorch recommends this for GPU since it makes copies much faster + pin_memory=(self.torch_device != torch.device("cpu")), + # avoid getting stuck + timeout=(10 if self.dataloader_num_workers > 0 else 0), ) self.dl_train = DataLoader(dataset=self.dataset_train, **dl_kwargs) self.dl_val = DataLoader(dataset=self.dataset_val, **dl_kwargs) diff --git a/nequip/train/trainer_wandb.py b/nequip/train/trainer_wandb.py index 925ba7e1..2c62493c 100644 --- a/nequip/train/trainer_wandb.py +++ b/nequip/train/trainer_wandb.py @@ -18,7 +18,6 @@ class TrainerWandB(Trainer): def __init__(self, **kwargs): Trainer.__init__(self, **kwargs) - wandb.config.update(self.output.updated_dict(), allow_val_change=True) def end_of_epoch_log(self): Trainer.end_of_epoch_log(self) @@ -28,6 +27,5 @@ def init_model(self): Trainer.init_model(self) - # TO DO, this will trigger pickel failure - # we may need to go back to state_dict method for saving + # TODO: test and re-enable this # wandb.watch(self.model) diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index d9234f69..16ad1ee6 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -1,4 +1,23 @@ -from .auto_init import instantiate_from_cls_name, instantiate, dataset_from_config +from .auto_init import ( + instantiate_from_cls_name, + instantiate, + get_w_prefix, +) from .savenload import save_file, load_file, atomic_write from .config import Config from .output import Output +from .modules import find_first_of_type +from .misc import dtype_from_name + +__all__ = [ + instantiate_from_cls_name, + instantiate, + get_w_prefix, + save_file, + load_file, + atomic_write, + Config, + Output, + find_first_of_type, + dtype_from_name, +] diff --git a/nequip/utils/auto_init.py b/nequip/utils/auto_init.py index a913dd1e..8a9a9917 100644 --- a/nequip/utils/auto_init.py +++ b/nequip/utils/auto_init.py @@ -1,65 +1,10 @@ +from typing import Optional, Union, List import inspect import logging -from importlib import import_module -from typing import Optional, Union, List - -from nequip import data, datasets from .config import Config -def dataset_from_config(config): - """initialize database based on a config instance - - It needs dataset type name (case insensitive), - and all the parameters needed in the constructor. - - Examples see tests/data/test_dataset.py TestFromConfig - and tests/datasets/test_simplest.py - """ - - if inspect.isclass(config.dataset): - # user define class - class_name = config.dataset - else: - try: - module_name = ".".join(config.dataset.split(".")[:-1]) - class_name = ".".join(config.dataset.split(".")[-1:]) - class_name = getattr(import_module(module_name), class_name) - except Exception as e: - # ^ TODO: don't catch all Exception - # default class defined in nequip.data or nequip.dataset - dataset_name = config.dataset.lower() - - class_name = None - for k, v in inspect.getmembers(data, inspect.isclass) + inspect.getmembers( - datasets, inspect.isclass - ): - if k.endswith("Dataset"): - if k.lower() == dataset_name: - class_name = v - if k[:-7].lower() == dataset_name: - class_name = v - elif k.lower() == dataset_name: - class_name = v - - if class_name is None: - raise NameError(f"dataset {dataset_name} does not exists") - - # if dataset r_max is not found, use the universal r_max - if "dataset_extra_fixed_fields" not in config: - config.dataset_extra_fixed_fields = {} - if "extra_fixed_fields" in config: - config.dataset_extra_fixed_fields.update(config.extra_fixed_fields) - - if "r_max" in config and "r_max" not in config.dataset_extra_fixed_fields: - config.dataset_extra_fixed_fields["r_max"] = config.r_max - - instance, _ = instantiate(class_name, prefix="dataset", optional_args=config) - - return instance - - def instantiate_from_cls_name( module, class_name: str, @@ -144,8 +89,10 @@ def instantiate( prefix_list = [builder.__name__] if inspect.isclass(builder) else [] if isinstance(prefix, str): prefix_list += [prefix] - else: + elif isinstance(prefix, list): prefix_list += prefix + else: + raise ValueError(f"prefix has the wrong type {type(prefix)}") # detect the input parameters needed from params config = Config.from_class(builder, remove_kwargs=remove_kwargs) @@ -281,6 +228,75 @@ def instantiate( logging.debug(f"... optional_args = {final_optional_args},") logging.debug(f"... positional_args = {positional_args})") - instance = builder(**positional_args, **final_optional_args) + try: + instance = builder(**positional_args, **final_optional_args) + except Exception as e: + raise RuntimeError( + f"Failed to build object with prefix `{prefix}` using builder `{builder.__name__}`" + ) from e return instance, final_optional_args + + +def get_w_prefix( + key: List[str], + *kwargs, + arg_dicts: List[dict] = [], + prefix: Optional[Union[str, List[str]]] = [], +): + """ + act as the get function and try to search for the value key from arg_dicts + """ + + # detect the input parameters needed from params + config = Config(config={}, allow_list=[key]) + + # sort out all possible prefixes + if isinstance(prefix, str): + prefix_list = [prefix] + elif isinstance(prefix, list): + prefix_list = prefix + else: + raise ValueError(f"prefix is with a wrong type {type(prefix)}") + + if not isinstance(arg_dicts, list): + arg_dicts = [arg_dicts] + + # extract all the parameters that has the pattern prefix_variable + # debug container to record all the variable name transformation + key_mapping = {} + for idx, arg_dict in enumerate(arg_dicts[::-1]): + # fetch paratemeters that directly match the name + _keys = config.update(arg_dict) + key_mapping[idx] = {k: k for k in _keys} + # fetch paratemeters that match prefix + "_" + name + for idx, prefix_str in enumerate(prefix_list): + _keys = config.update_w_prefix( + arg_dict, + prefix=prefix_str, + ) + key_mapping[idx].update(_keys) + + # for logging only, remove the overlapped keys + num_dicts = len(arg_dicts) + if num_dicts > 1: + for id_dict in range(num_dicts - 1): + higher_priority_keys = [] + for id_higher in range(id_dict + 1, num_dicts): + higher_priority_keys += list(key_mapping[id_higher].keys()) + key_mapping[id_dict] = { + k: v + for k, v in key_mapping[id_dict].items() + if k not in higher_priority_keys + } + + # debug info + logging.debug(f"search for {key} with prefix {prefix}") + for t in key_mapping: + for k, v in key_mapping[t].items(): + string = f" {str(t):>10.10}_args : {k:>50s}" + if k != v: + string += f" <- {v:>50s}" + logging.debug(string) + + return config.get(key, *kwargs) diff --git a/nequip/utils/batch_ops.py b/nequip/utils/batch_ops.py index 190e17ff..6740661e 100644 --- a/nequip/utils/batch_ops.py +++ b/nequip/utils/batch_ops.py @@ -2,8 +2,6 @@ import torch -from torch_scatter import scatter - def bincount( input: torch.Tensor, batch: Optional[torch.Tensor] = None, minlength: int = 0 diff --git a/nequip/utils/config.py b/nequip/utils/config.py index 7ea74977..72b896a1 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -242,7 +242,7 @@ def update_locked(self, d, user=None): pass def save(self, filename: str, format: Optional[str] = None): - """ Print config to file. """ + """Print config to file.""" supported_formats = {"yaml": ("yml", "yaml"), "json": "json"} return save_file( @@ -254,7 +254,7 @@ def save(self, filename: str, format: Optional[str] = None): @staticmethod def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}): - """ Load arguments from file """ + """Load arguments from file""" supported_formats = {"yaml": ("yml", "yaml"), "json": "json"} dictionary = load_file( diff --git a/nequip/utils/git.py b/nequip/utils/git.py new file mode 100644 index 00000000..168f8921 --- /dev/null +++ b/nequip/utils/git.py @@ -0,0 +1,21 @@ +import subprocess +from pathlib import Path +from importlib import import_module + + +def get_commit(module: str): + + module = import_module(module) + path = str(Path(module.__file__).parents[0] / "..") + + retcode = subprocess.run( + "git show --oneline -s".split(), + cwd=path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + if retcode.returncode == 0: + return retcode.stdout.decode().splitlines()[0].split()[0] + else: + return "NaN" diff --git a/nequip/utils/initialization.py b/nequip/utils/initialization.py deleted file mode 100644 index 70ef1fe9..00000000 --- a/nequip/utils/initialization.py +++ /dev/null @@ -1,103 +0,0 @@ -import math - -import torch - -import e3nn.o3 -import e3nn.nn - - -# == Uniform init == -def unit_uniform_init_(t: torch.Tensor): - """Uniform initialization with = 1""" - t.uniform_(-math.sqrt(3), math.sqrt(3)) - - -def uniform_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s with ``unit_uniform_init_`` - - No need to do torch.nn.Linear, which is uniform by default. - """ - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - unit_uniform_init_(layer.weight) - - -def uniform_initialize_equivariant_linears(mod: torch.nn.Module): - """Initialize ``e3nn.o3.Linear``s that have internal weights with ``unit_uniform_init_``""" - if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: - unit_uniform_init_(mod.weight) - - -def uniform_initialize_tp_internal_weights(mod: torch.nn.Module): - """Initialize ``e3nn.o3.TensorProduct``s that have internal weights with ``unit_uniform_init_``""" - if isinstance(mod, e3nn.o3.TensorProduct) and mod.internal_weights: - unit_uniform_init_(mod.weight) - - -# == Xavier == -def xavier_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with Xavier uniform initialization""" - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - # in FC: - # h_in, _h_out = W.shape - # W = W / h_in**0.5 - torch.nn.init.xavier_uniform_( - layer.weight, gain=layer.weight.shape[0] ** 0.5 - ) - elif isinstance(mod, torch.nn.Linear): - torch.nn.init.xavier_uniform_(mod.weight) - - -# == Orthogonal == -# TODO: does this normalization make any sense -def unit_orthogonal_init_(t: torch.Tensor): - """Orthogonal init with = 1""" - assert t.ndim == 2 - torch.nn.init.orthogonal_(t, gain=math.sqrt(max(t.shape))) - - -def unit_orthogonal_initialize_equivariant_linears(mod: torch.nn.Module): - """Initialize ``e3nn.o3.Linear``s that have internal weights with ``unit_orthogonal_init_``""" - if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: - for w in mod.weight_views(): - unit_orthogonal_init_(w) - - -def unit_orthogonal_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with ``unit_orthogonal_init_``""" - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - unit_orthogonal_init_(layer.weight) - elif isinstance(mod, torch.nn.Linear): - unit_orthogonal_init_(mod.weight) - - -def unit_orthogonal_initialize_e3nn_fcs(mod: torch.nn.Module): - """Initialize only ``e3nn.nn.FullyConnectedNet``s with ``unit_orthogonal_init_``""" - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - unit_orthogonal_init_(layer.weight) - - -def orthogonal_initialize_equivariant_linears(mod: torch.nn.Module): - """Initialize ``e3nn.o3.Linear``s that have internal weights with ``torch.nn.init.orthogonal_``""" - if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: - for w in mod.weight_views(): - torch.nn.init.orthogonal_(w) - - -def orthogonal_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with ``torch.nn.init.orthogonal_``""" - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - torch.nn.init.orthogonal_(layer.weight) - elif isinstance(mod, torch.nn.Linear): - torch.nn.init.orthogonal_(mod.weight) - - -def orthogonal_initialize_e3nn_fcs(mod: torch.nn.Module): - """Initialize only ``e3nn.nn.FullyConnectedNet``s with ``torch.nn.init.orthogonal_``""" - if isinstance(mod, e3nn.nn.FullyConnectedNet): - for layer in mod: - torch.nn.init.orthogonal_(layer.weight) diff --git a/nequip/utils/misc.py b/nequip/utils/misc.py new file mode 100644 index 00000000..4beba97b --- /dev/null +++ b/nequip/utils/misc.py @@ -0,0 +1,5 @@ +import torch + + +def dtype_from_name(name: str) -> torch.dtype: + return {"float32": torch.float32, "float64": torch.float64}[name] diff --git a/nequip/utils/modules.py b/nequip/utils/modules.py new file mode 100644 index 00000000..07e1383a --- /dev/null +++ b/nequip/utils/modules.py @@ -0,0 +1,15 @@ +from typing import Optional + +import torch + + +def find_first_of_type(m: torch.nn.Module, kls) -> Optional[torch.nn.Module]: + """Find the first module of a given type in a module tree.""" + if isinstance(m, kls): + return m + else: + for child in m.children(): + tmp = find_first_of_type(child, kls) + if tmp is not None: + return tmp + return None diff --git a/nequip/utils/output.py b/nequip/utils/output.py index 94ad13b6..a8dbf760 100644 --- a/nequip/utils/output.py +++ b/nequip/utils/output.py @@ -1,4 +1,3 @@ -import datetime import inspect import logging import sys @@ -6,7 +5,6 @@ from logging import FileHandler, StreamHandler from os import makedirs from os.path import abspath, relpath, isfile, isdir -from time import time from typing import Optional from .config import Config @@ -18,29 +16,20 @@ class Output: Args: run_name: unique name of the simulation root: the base folder where the processed data will be stored - workdir: the path where all log files will be stored. will be updated to root/{run_name}_{timestr} if the folder already exists. - timestr (optional): unique id to generate work folder and store the output instance. default is time stamp if not defined. logfile (optional): if define, an additional logger (from the root one) will be defined and write to the file - restart (optional): if True, the append flag will be used. append (optional): if True, the workdir and files can be append screen (optional): if True, root logger print to screen verbose (optional): same as Logging verbose level """ - instances = {} - def __init__( self, - run_name: Optional[str] = None, - root: Optional[str] = None, - timestr: Optional[str] = None, - workdir: Optional[str] = None, + root: str, + run_name: str, logfile: Optional[str] = None, - restart: bool = False, append: bool = False, screen: bool = False, verbose: str = "info", - force_append: bool = False, ): # add screen output to the universal logger @@ -57,43 +46,26 @@ def __init__( for handler in logger.handlers: handler.setFormatter(fmt=formatter) - self.restart = restart self.append = append - self.force_append = force_append self.screen = screen self.verbose = verbose # open root folder for storing # if folder exists and not append, the folder name and filename will be updated - if ((not force_append) and (restart and not append)) or timestr is None: - timestr = datetime.datetime.fromtimestamp(time()).strftime( - "%Y-%m-%d_%H:%M:%S:%f" - ) - if not force_append: - root = set_if_none(root, f".") - run_name = set_if_none(run_name, f"NequIP") - workdir = set_if_none(workdir, f"{root}/{run_name}") + self.root = set_if_none(root, ".") + self.run_name = run_name + self.workdir = f"{self.root}/{self.run_name}" assert "/" not in run_name # if folder exists in a non-append-mode or a fresh run # rename the work folder based on run name - if ( - isdir(workdir) - and (((restart and not append) or (not restart))) - and not force_append - ): - logging.debug(f" ...renaming workdir from {workdir} to") - - workdir = f"{root}/{run_name}_{timestr}" - logging.debug(f" ...{workdir}") - - makedirs(workdir, exist_ok=True) + if isdir(self.workdir) and not append: + raise RuntimeError( + f"project {self.run_name} already exist under {self.root}" + ) - self.timestr = timestr - self.run_name = run_name - self.root = root - self.workdir = workdir + makedirs(self.workdir, exist_ok=True) self.logfile = logfile if logfile is not None: @@ -102,17 +74,6 @@ def __init__( ) logging.debug(f" ...logfile {self.logfile} to") - Output.instances[self.timestr] = self - - def updated_dict(self): - return dict( - timestr=self.timestr, - run_name=self.run_name, - root=self.root, - workdir=self.workdir, - logfile=self.logfile, - ) - def generate_file(self, file_name: str): """ only works with relative path. open a file @@ -122,7 +83,7 @@ def generate_file(self, file_name: str): raise ValueError("filename should be a relative path file name") file_name = f"{self.workdir}/{file_name}" - if isfile(file_name) and not (self.restart and self.append): + if isfile(file_name) and not self.append: raise RuntimeError( f"Tried to create file `{file_name}` but it already exists and either (1) append is disabled or (2) this run is not a restart" ) @@ -182,22 +143,15 @@ def as_dict(self): } @classmethod - def get_output(cls, timestr: str, kwargs: dict = {}): - if len(kwargs) == 0: - return cls.instances.get(timestr, cls(root="./")) - else: - if "timestr" in kwargs: - timestr = kwargs.get("timestr", "./") - if timestr in cls.instances: - return cls.instances[timestr] - - d = inspect.signature(cls.__init__) - _kwargs = { - key: kwargs.get(key, None) - for key in list(d.parameters.keys()) - if key not in ["self", "kwargs"] - } - return cls(**_kwargs) + def get_output(cls, kwargs: dict = {}): + + d = inspect.signature(cls.__init__) + _kwargs = { + key: kwargs.get(key, None) + for key in list(d.parameters.keys()) + if key not in ["self", "kwargs"] + } + return cls(**_kwargs) @classmethod def from_config(cls, config): diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py new file mode 100644 index 00000000..930af115 --- /dev/null +++ b/nequip/utils/regressor.py @@ -0,0 +1,181 @@ +import logging +import torch +import numpy as np +from typing import Optional +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import DotProduct, Kernel, Hyperparameter + + +def solver(X, y, regressor: Optional[str] = "NormalizedGaussianProcess", **kwargs): + if regressor == "GaussianProcess": + return gp(X, y, **kwargs) + elif regressor == "NormalizedGaussianProcess": + return normalized_gp(X, y, **kwargs) + else: + raise NotImplementedError(f"{regressor} is not implemented") + + +def normalized_gp(X, y, **kwargs): + feature_rms = 1.0 / np.sqrt(np.average(X ** 2, axis=0)) + feature_rms = np.nan_to_num(feature_rms, 1) + y_mean = torch.sum(y) / torch.sum(X) + mean, std = base_gp( + X, + y - (torch.sum(X, axis=1) * y_mean).reshape(y.shape), + NormalizedDotProduct, + {"diagonal_elements": feature_rms}, + **kwargs, + ) + return mean + y_mean, std + + +def gp(X, y, **kwargs): + return base_gp( + X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, **kwargs + ) + + +def base_gp( + X, + y, + kernel, + kernel_kwargs, + alpha: Optional[float] = 0.1, + max_iteration: int = 20, + stride: Optional[int] = None, +): + + if len(y.shape) == 1: + y = y.reshape([-1, 1]) + + if stride is not None: + X = X[::stride] + y = y[::stride] + + not_fit = True + iteration = 0 + mean = None + std = None + while not_fit: + logging.debug(f"GP fitting iteration {iteration} {alpha}") + try: + _kernel = kernel(**kernel_kwargs) + gpr = GaussianProcessRegressor(kernel=_kernel, random_state=0, alpha=alpha) + gpr = gpr.fit(X, y) + + vec = torch.diag(torch.ones(X.shape[1])) + mean, std = gpr.predict(vec, return_std=True) + + mean = torch.as_tensor(mean, dtype=torch.get_default_dtype()).reshape([-1]) + # ignore all the off-diagonal terms + std = torch.as_tensor(std, dtype=torch.get_default_dtype()).reshape([-1]) + likelihood = gpr.log_marginal_likelihood() + + res = torch.sqrt( + torch.square(torch.matmul(X, mean.reshape([-1, 1])) - y).mean() + ) + + logging.debug( + f"GP fitting: alpha {alpha}:\n" + f" residue {res}\n" + f" mean {mean} std {std}\n" + f" log marginal likelihood {likelihood}" + ) + not_fit = False + + except Exception as e: + logging.info(f"GP fitting failed for alpha={alpha} and {e.args}") + if alpha == 0 or alpha is None: + logging.info("try a non-zero alpha") + not_fit = False + raise ValueError( + f"Please set the {alpha} to non-zero value. \n" + "The dataset energy is rank deficient to be solved with GP" + ) + else: + alpha = alpha * 2 + iteration += 1 + logging.debug(f" increase alpha to {alpha}") + + if iteration >= max_iteration or not_fit is False: + raise ValueError( + "Please set the per species shift and scale to zeros and ones. \n" + "The dataset energy is to diverge to be solved with GP" + ) + + return mean, std + + +class NormalizedDotProduct(Kernel): + r"""Dot-Product kernel. + .. math:: + k(x_i, x_j) = x_i \cdot A \cdot x_j + """ + + def __init__(self, diagonal_elements): + # TO DO: check shape + self.diagonal_elements = diagonal_elements + self.A = np.diag(diagonal_elements) + + def __call__(self, X, Y=None, eval_gradient=False): + """Return the kernel k(X, Y) and optionally its gradient. + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + Left argument of the returned kernel k(X, Y) + Y : ndarray of shape (n_samples_Y, n_features), default=None + Right argument of the returned kernel k(X, Y). If None, k(X, X) + if evaluated instead. + eval_gradient : bool, default=False + Determines whether the gradient with respect to the log of + the kernel hyperparameter is computed. + Only supported when Y is None. + Returns + ------- + K : ndarray of shape (n_samples_X, n_samples_Y) + Kernel k(X, Y) + K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims),\ + optional + The gradient of the kernel k(X, X) with respect to the log of the + hyperparameter of the kernel. Only returned when `eval_gradient` + is True. + """ + X = np.atleast_2d(X) + if Y is None: + K = (X.dot(self.A)).dot(X.T) + else: + if eval_gradient: + raise ValueError("Gradient can only be evaluated when Y is None.") + K = (X.dot(self.A)).dot(Y.T) + + if eval_gradient: + return K, np.empty((X.shape[0], X.shape[0], 0)) + else: + return K + + def diag(self, X): + """Returns the diagonal of the kernel k(X, X). + The result of this method is identical to np.diag(self(X)); however, + it can be evaluated more efficiently since only the diagonal is + evaluated. + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + Left argument of the returned kernel k(X, Y). + Returns + ------- + K_diag : ndarray of shape (n_samples_X,) + Diagonal of kernel k(X, X). + """ + return np.einsum("ij,ij,jj->i", X, X, self.A) + + def __repr__(self): + return "" + + def is_stationary(self): + """Returns whether the kernel is stationary.""" + return False + + @property + def hyperparameter_diagonal_elements(self): + return Hyperparameter("diagonal_elements", "numeric", "fixed") diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 8c87a853..0980fef2 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -94,7 +94,8 @@ def load_file(supported_formats: dict, filename: str, enforced_format: str = Non format = enforced_format if not isfile(filename): - raise OSError(f"file {filename} is not found") + abs_path = str(Path(filename).resolve()) + raise OSError(f"file {filename} at {abs_path} is not found") if format == "json": import json diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 0085814b..da2fe4ef 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -1,11 +1,16 @@ -from typing import Union, Sequence, Set +from typing import Union import torch from e3nn import o3 from e3nn.util.test import assert_equivariant from nequip.nn import GraphModuleMixin -from nequip.data import AtomicData, AtomicDataDict +from nequip.data import ( + AtomicData, + AtomicDataDict, + _NODE_FIELDS, + _EDGE_FIELDS, +) PERMUTATION_FLOAT_TOLERANCE = {torch.float32: 1e-5, torch.float64: 1e-10} @@ -18,84 +23,18 @@ def _inverse_permutation(perm): return inv -_DEFAULT_NODE_PERMUTE_FIELDS: Set[str] = { - AtomicDataDict.POSITIONS_KEY, - AtomicDataDict.WEIGHTS_KEY, - AtomicDataDict.NODE_FEATURES_KEY, - AtomicDataDict.NODE_ATTRS_KEY, - AtomicDataDict.ATOMIC_NUMBERS_KEY, - AtomicDataDict.SPECIES_INDEX_KEY, - AtomicDataDict.FORCE_KEY, - AtomicDataDict.PER_ATOM_ENERGY_KEY, - AtomicDataDict.BATCH_KEY, -} -_DEFAULT_EDGE_PERMUTE_FIELDS: Set[str] = { - AtomicDataDict.EDGE_CELL_SHIFT_KEY, - AtomicDataDict.EDGE_VECTORS_KEY, - AtomicDataDict.EDGE_LENGTH_KEY, - AtomicDataDict.EDGE_ATTRS_KEY, - AtomicDataDict.EDGE_EMBEDDING_KEY, -} -_NODE_PERMUTE_FIELDS: Set[str] = set(_DEFAULT_NODE_PERMUTE_FIELDS) -_EDGE_PERMUTE_FIELDS: Set[str] = set(_DEFAULT_EDGE_PERMUTE_FIELDS) - - -def register_fields( - node_permute_fields: Sequence[str] = [], edge_permute_fields: Sequence[str] = [] -) -> None: - r"""Register a field as having specific properties for testing purposes. - - See ``assert_permutation_equivariant``. - - Args: - node_permute_fields: fields that are equivariant to node permutations. - edge_permute_fields: fields that are equivariant to edge permutations. - """ - node_permute_fields: set = set(node_permute_fields) - edge_permute_fields: set = set(edge_permute_fields) - assert node_permute_fields.isdisjoint( - edge_permute_fields - ), "Fields cannot be both node and edge equivariant" - assert (_NODE_PERMUTE_FIELDS.union(_EDGE_PERMUTE_FIELDS)).isdisjoint( - node_permute_fields.union(edge_permute_fields) - ), "Cannot reregister a field that has already been registered" - _NODE_PERMUTE_FIELDS.update(node_permute_fields) - _EDGE_PERMUTE_FIELDS.update(edge_permute_fields) - - -def deregister_fields(*fields: Sequence[str]) -> None: - r"""Deregister a field registered with ``register_fields``. - - Silently ignores fields that were never registered to begin with. - - Args: - *fields: fields to deregister. - """ - for f in fields: - assert f not in _DEFAULT_EDGE_PERMUTE_FIELDS, "Cannot deregister built-in field" - assert f not in _DEFAULT_NODE_PERMUTE_FIELDS, "Cannot deregister built-in field" - _NODE_PERMUTE_FIELDS.discard(f) - _EDGE_PERMUTE_FIELDS.discard(f) - - def assert_permutation_equivariant( - func: GraphModuleMixin, - data_in: AtomicDataDict.Type, - extra_node_permute_fields: Sequence[str] = [], - extra_edge_permute_fields: Sequence[str] = [], + func: GraphModuleMixin, data_in: AtomicDataDict.Type ): r"""Test the permutation equivariance of ``func``. - Standard fields are assumed to be equivariant to node or edge permutations according to their standard interpretions; all other fields are assumed to be invariant to all permutations. Non-standard fields can be registered as node/edge permutation equivariant using ``register_fields``, or can be provided directly in the - ``extra_node_permute_fields`` and ``extra_edge_permute_fields`` arguments. + Standard fields are assumed to be equivariant to node or edge permutations according to their standard interpretions; all other fields are assumed to be invariant to all permutations. Non-standard fields can be registered as node/edge permutation equivariant using ``register_fields``. Raises ``AssertionError`` if issues are found. Args: func: the module or model to test data_in: the example input data to test with - extra_node_permute_fields: names of non-standard fields that should be equivariant to permutations of the *node* ordering - extra_edge_permute_fields: names of non-standard fields that should be equivariant to permutations of the *edge* ordering """ # Prevent pytest from showing this function in the traceback # __tracebackhide__ = True @@ -105,20 +44,8 @@ def assert_permutation_equivariant( data_in = data_in.copy() device = data_in[AtomicDataDict.POSITIONS_KEY].device - # instead of doing fragile shape checks, just do a list of fields that permute - extra_node_permute_fields: Set[str] = set(extra_node_permute_fields) - extra_edge_permute_fields: Set[str] = set(extra_edge_permute_fields) - assert extra_edge_permute_fields.isdisjoint( - extra_node_permute_fields - ), "A field cannot be both edge and node permutation equivariant" - assert _EDGE_PERMUTE_FIELDS.isdisjoint( - extra_node_permute_fields - ), "Some member of extra_node_permute_fields is registered as an edge permutation equivariant" - assert _NODE_PERMUTE_FIELDS.isdisjoint( - extra_edge_permute_fields - ), "Some member of extra_edge_permute_fields is registered as an node permutation equivariant" - node_permute_fields = _NODE_PERMUTE_FIELDS.union(extra_node_permute_fields) - edge_permute_fields = _EDGE_PERMUTE_FIELDS.union(extra_edge_permute_fields) + node_permute_fields = _NODE_FIELDS + edge_permute_fields = _EDGE_FIELDS # Make permutations and make sure they are not identities n_node: int = len(data_in[AtomicDataDict.POSITIONS_KEY]) @@ -193,8 +120,6 @@ def assert_permutation_equivariant( def assert_AtomicData_equivariant( func: GraphModuleMixin, data_in: Union[AtomicData, AtomicDataDict.Type], - extra_node_permute_fields: Sequence[str] = [], - extra_edge_permute_fields: Sequence[str] = [], **kwargs, ): r"""Test the rotation, translation, parity, and permutation equivariance of ``func``. @@ -207,8 +132,6 @@ def assert_AtomicData_equivariant( Args: func: the module or model to test data_in: the example input data to test with - extra_node_permute_fields: see ``assert_permutation_equivariant`` - extra_edge_permute_fields: see ``assert_permutation_equivariant`` **kwargs: passed to ``e3nn.util.test.assert_equivariant`` Returns: @@ -224,8 +147,6 @@ def assert_AtomicData_equivariant( assert_permutation_equivariant( func, data_in, - extra_node_permute_fields=extra_node_permute_fields, - extra_edge_permute_fields=extra_edge_permute_fields, ) # == Test rotation, parity, and translation using e3nn == @@ -298,7 +219,7 @@ def set_irreps_debug(enabled: bool = False) -> None: pass import torch.nn.modules - from torch_geometric.data import Data + from nequip.utils.torch_geometric import Data def pre_hook(mod: GraphModuleMixin, inp): # __tracebackhide__ = True diff --git a/nequip/utils/torch_geometric/README.md b/nequip/utils/torch_geometric/README.md new file mode 100644 index 00000000..353d1901 --- /dev/null +++ b/nequip/utils/torch_geometric/README.md @@ -0,0 +1,10 @@ +# Trimmed-down `pytorch_geometric` + +NequIP uses the data format and code of the excellent [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. We use, however, only a very limited subset of that library: the most basic graph data structures. + +To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is neccessary for our code. + +We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch. + + [1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric + [2] https://arxiv.org/abs/1903.02428 \ No newline at end of file diff --git a/nequip/utils/torch_geometric/__init__.py b/nequip/utils/torch_geometric/__init__.py new file mode 100644 index 00000000..818ea494 --- /dev/null +++ b/nequip/utils/torch_geometric/__init__.py @@ -0,0 +1,5 @@ +from .batch import Batch +from .data import Data +from .dataset import Dataset + +__all__ = ["Batch", "Data", "Dataset"] diff --git a/nequip/utils/torch_geometric/batch.py b/nequip/utils/torch_geometric/batch.py new file mode 100644 index 00000000..301e6d8c --- /dev/null +++ b/nequip/utils/torch_geometric/batch.py @@ -0,0 +1,258 @@ +from typing import List + +from collections.abc import Sequence + +import torch +import numpy as np +from torch import Tensor + +from .data import Data +from .dataset import IndexType + + +class Batch(Data): + r"""A plain old python object modeling a batch of graphs as one big + (disconnected) graph. With :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + In addition, single graphs can be reconstructed via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. + """ + + def __init__(self, batch=None, ptr=None, **kwargs): + super(Batch, self).__init__(**kwargs) + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + self.batch = batch + self.ptr = ptr + self.__data_class__ = Data + self.__slices__ = None + self.__cumsum__ = None + self.__cat_dims__ = None + self.__num_nodes_list__ = None + self.__num_graphs__ = None + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert "batch" not in keys and "ptr" not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != "__" and key[-2:] != "__": + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ["batch"]: + batch[key] = [] + batch["ptr"] = [0] + + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + cat_dims = {} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. + item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + + batch[key].append(item) # Append item to the attribute list. + + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f"{key}_{j}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + else: + tmp = f"{key}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + + if hasattr(data, "__num_nodes__"): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long, device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + # if torch_geometric.is_debug_enabled(): + # batch.debug() + + return batch.contiguous() + + def get_example(self, idx: int) -> Data: + r"""Reconstructs the :class:`torch_geometric.data.Data` object at index + :obj:`idx` from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + + if self.__slices__ is None: + raise RuntimeError( + ( + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." + ) + ) + + data = self.__data_class__() + idx = self.num_graphs + idx if idx < 0 else idx + + for key in self.__slices__.keys(): + item = self[key] + if self.__cat_dims__[key] is None: + # The item was concatenated along a new batch dimension, + # so just index in that dimension: + item = item[idx] + else: + # Narrow the item based on the values in `__slices__`. + if isinstance(item, Tensor): + dim = self.__cat_dims__[key] + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item.narrow(dim, start, end - start) + else: + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item[start:end] + item = item[0] if len(item) == 1 else item + + # Decrease its value by `cumsum` value: + cum = self.__cumsum__[key][idx] + if isinstance(item, Tensor): + if not isinstance(cum, int) or cum != 0: + item = item - cum + elif isinstance(item, (int, float)): + item = item - cum + + data[key] = item + + if self.__num_nodes_list__[idx] is not None: + data.num_nodes = self.__num_nodes_list__[idx] + + return data + + def index_select(self, idx: IndexType) -> List[Data]: + if isinstance(idx, slice): + idx = list(range(self.num_graphs)[idx]) + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + idx = idx.flatten().tolist() + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + idx = idx.flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0].flatten().tolist() + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + pass + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + return [self.get_example(i) for i in idx] + + def __getitem__(self, idx): + if isinstance(idx, str): + return super(Batch, self).__getitem__(idx) + elif isinstance(idx, (int, np.integer)): + return self.get_example(idx) + else: + return self.index_select(idx) + + def to_data_list(self) -> List[Data]: + r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects + from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + return [self.get_example(i) for i in range(self.num_graphs)] + + @property + def num_graphs(self) -> int: + """Returns the number of graphs in the batch.""" + if self.__num_graphs__ is not None: + return self.__num_graphs__ + elif self.ptr is not None: + return self.ptr.numel() - 1 + elif self.batch is not None: + return int(self.batch.max()) + 1 + else: + raise ValueError diff --git a/nequip/utils/torch_geometric/data.py b/nequip/utils/torch_geometric/data.py new file mode 100644 index 00000000..3b737f49 --- /dev/null +++ b/nequip/utils/torch_geometric/data.py @@ -0,0 +1,441 @@ +import re +import copy +import collections + +import torch + +# from ..utils.num_nodes import maybe_num_nodes + +__num_nodes_warn_msg__ = ( + "The number of nodes in your data object can only be inferred by its {} " + "indices, and hence may result in unexpected batch-wise behavior, e.g., " + "in case there exists isolated nodes. Please consider explicitly setting " + "the number of nodes for this data object by assigning it to " + "data.num_nodes." +) + + +def size_repr(key, item, indent=0): + indent_str = " " * indent + if torch.is_tensor(item) and item.dim() == 0: + out = item.item() + elif torch.is_tensor(item): + out = str(list(item.size())) + elif isinstance(item, list) or isinstance(item, tuple): + out = str([len(item)]) + elif isinstance(item, dict): + lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] + out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" + elif isinstance(item, str): + out = f'"{item}"' + else: + out = str(item) + + return f"{indent_str}{key}={out}" + + +class Data(object): + r"""A plain old python object modeling a single graph with various + (optional) attributes: + + Args: + x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, + num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph or node targets with arbitrary shape. + (default: :obj:`None`) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + normal (Tensor, optional): Normal vector matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + face (LongTensor, optional): Face adjacency matrix with shape + :obj:`[3, num_faces]`. (default: :obj:`None`) + + The data object is not restricted to these attributes and can be extented + by any other additional data. + + Example:: + + data = Data(x=x, edge_index=edge_index) + data.train_idx = torch.tensor([...], dtype=torch.long) + data.test_mask = torch.tensor([...], dtype=torch.bool) + """ + + def __init__( + self, + x=None, + edge_index=None, + edge_attr=None, + y=None, + pos=None, + normal=None, + face=None, + **kwargs, + ): + self.x = x + self.edge_index = edge_index + self.edge_attr = edge_attr + self.y = y + self.pos = pos + self.normal = normal + self.face = face + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + if edge_index is not None and edge_index.dtype != torch.long: + raise ValueError( + ( + f"Argument `edge_index` needs to be of type `torch.long` but " + f"found type `{edge_index.dtype}`." + ) + ) + + if face is not None and face.dtype != torch.long: + raise ValueError( + ( + f"Argument `face` needs to be of type `torch.long` but found " + f"type `{face.dtype}`." + ) + ) + + @classmethod + def from_dict(cls, dictionary): + r"""Creates a data object from a python dictionary.""" + data = cls() + + for key, item in dictionary.items(): + data[key] = item + + return data + + def to_dict(self): + return {key: item for key, item in self} + + def to_namedtuple(self): + keys = self.keys + DataTuple = collections.namedtuple("DataTuple", keys) + return DataTuple(*[self[key] for key in keys]) + + def __getitem__(self, key): + r"""Gets the data of the attribute :obj:`key`.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Sets the attribute :obj:`key` to :obj:`value`.""" + setattr(self, key, value) + + def __delitem__(self, key): + r"""Delete the data of the attribute :obj:`key`.""" + return delattr(self, key) + + @property + def keys(self): + r"""Returns all names of graph attributes.""" + keys = [key for key in self.__dict__.keys() if self[key] is not None] + keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] + return keys + + def __len__(self): + r"""Returns the number of all present attributes.""" + return len(self.keys) + + def __contains__(self, key): + r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the + data.""" + return key in self.keys + + def __iter__(self): + r"""Iterates over all present attributes in the data, yielding their + attribute names and content.""" + for key in sorted(self.keys): + yield key, self[key] + + def __call__(self, *keys): + r"""Iterates over all attributes :obj:`*keys` in the data, yielding + their attribute names and content. + If :obj:`*keys` is not given this method will iterative over all + present attributes.""" + for key in sorted(self.keys) if not keys else keys: + if key in self: + yield key, self[key] + + def __cat_dim__(self, key, value): + r"""Returns the dimension for which :obj:`value` of attribute + :obj:`key` will get concatenated when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + if bool(re.search("(index|face)", key)): + return -1 + return 0 + + def __inc__(self, key, value): + r"""Returns the incremental count to cumulatively increase the value + of the next attribute of :obj:`key` when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + # Only `*index*` and `*face*` attributes should be cumulatively summed + # up when creating batches. + return self.num_nodes if bool(re.search("(index|face)", key)) else 0 + + @property + def num_nodes(self): + r"""Returns or sets the number of nodes in the graph. + + .. note:: + The number of nodes in your data object is typically automatically + inferred, *e.g.*, when node features :obj:`x` are present. + In some cases however, a graph may only be given by its edge + indices :obj:`edge_index`. + PyTorch Geometric then *guesses* the number of nodes + according to :obj:`edge_index.max().item() + 1`, but in case there + exists isolated nodes, this number has not to be correct and can + therefore result in unexpected batch-wise behavior. + Thus, we recommend to set the number of nodes in your data object + explicitly via :obj:`data.num_nodes = ...`. + You will be given a warning that requests you to do so. + """ + if hasattr(self, "__num_nodes__"): + return self.__num_nodes__ + for key, item in self("x", "pos", "normal", "batch"): + return item.size(self.__cat_dim__(key, item)) + if hasattr(self, "adj"): + return self.adj.size(0) + if hasattr(self, "adj_t"): + return self.adj_t.size(1) + # if self.face is not None: + # logging.warning(__num_nodes_warn_msg__.format("face")) + # return maybe_num_nodes(self.face) + # if self.edge_index is not None: + # logging.warning(__num_nodes_warn_msg__.format("edge")) + # return maybe_num_nodes(self.edge_index) + return None + + @num_nodes.setter + def num_nodes(self, num_nodes): + self.__num_nodes__ = num_nodes + + @property + def num_edges(self): + """ + Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + for key, item in self("edge_index", "edge_attr"): + return item.size(self.__cat_dim__(key, item)) + for key, item in self("adj", "adj_t"): + return item.nnz() + return None + + @property + def num_faces(self): + r"""Returns the number of faces in the mesh.""" + if self.face is not None: + return self.face.size(self.__cat_dim__("face", self.face)) + return None + + @property + def num_node_features(self): + r"""Returns the number of features per node in the graph.""" + if self.x is None: + return 0 + return 1 if self.x.dim() == 1 else self.x.size(1) + + @property + def num_features(self): + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self): + r"""Returns the number of features per edge in the graph.""" + if self.edge_attr is None: + return 0 + return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + + def __apply__(self, item, func): + if torch.is_tensor(item): + return func(item) + elif isinstance(item, (tuple, list)): + return [self.__apply__(v, func) for v in item] + elif isinstance(item, dict): + return {k: self.__apply__(v, func) for k, v in item.items()} + else: + return item + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + self[key] = self.__apply__(item, func) + return self + + def contiguous(self, *keys): + r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. + If :obj:`*keys` is not given, all present attributes are ensured to + have a contiguous memory layout.""" + return self.apply(lambda x: x.contiguous(), *keys) + + def to(self, device, *keys, **kwargs): + r"""Performs tensor dtype and/or device conversion to all attributes + :obj:`*keys`. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.to(device, **kwargs), *keys) + + def cpu(self, *keys): + r"""Copies all attributes :obj:`*keys` to CPU memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.cpu(), *keys) + + def cuda(self, device=None, non_blocking=False, *keys): + r"""Copies all attributes :obj:`*keys` to CUDA memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply( + lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys + ) + + def clone(self): + r"""Performs a deep-copy of the data object.""" + return self.__class__.from_dict( + { + k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) + for k, v in self.__dict__.items() + } + ) + + def pin_memory(self, *keys): + r"""Copies all attributes :obj:`*keys` to pinned memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.pin_memory(), *keys) + + def debug(self): + if self.edge_index is not None: + if self.edge_index.dtype != torch.long: + raise RuntimeError( + ( + "Expected edge indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.edge_index.dtype) + ) + + if self.face is not None: + if self.face.dtype != torch.long: + raise RuntimeError( + ( + "Expected face indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.face.dtype) + ) + + if self.edge_index is not None: + if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: + raise RuntimeError( + ( + "Edge indices should have shape [2, num_edges] but found" + " shape {}" + ).format(self.edge_index.size()) + ) + + if self.edge_index is not None and self.num_nodes is not None: + if self.edge_index.numel() > 0: + min_index = self.edge_index.min() + max_index = self.edge_index.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Edge indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.face is not None: + if self.face.dim() != 2 or self.face.size(0) != 3: + raise RuntimeError( + ( + "Face indices should have shape [3, num_faces] but found" + " shape {}" + ).format(self.face.size()) + ) + + if self.face is not None and self.num_nodes is not None: + if self.face.numel() > 0: + min_index = self.face.min() + max_index = self.face.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Face indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.edge_index is not None and self.edge_attr is not None: + if self.edge_index.size(1) != self.edge_attr.size(0): + raise RuntimeError( + ( + "Edge indices and edge attributes hold a differing " + "number of edges, found {} and {}" + ).format(self.edge_index.size(), self.edge_attr.size()) + ) + + if self.x is not None and self.num_nodes is not None: + if self.x.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node features should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.x.size(0)) + ) + + if self.pos is not None and self.num_nodes is not None: + if self.pos.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node positions should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.pos.size(0)) + ) + + if self.normal is not None and self.num_nodes is not None: + if self.normal.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node normals should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.normal.size(0)) + ) + + def __repr__(self): + cls = str(self.__class__.__name__) + has_dict = any([isinstance(item, dict) for _, item in self]) + + if not has_dict: + info = [size_repr(key, item) for key, item in self] + return "{}({})".format(cls, ", ".join(info)) + else: + info = [size_repr(key, item, indent=2) for key, item in self] + return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/nequip/utils/torch_geometric/dataset.py b/nequip/utils/torch_geometric/dataset.py new file mode 100644 index 00000000..1a3401f3 --- /dev/null +++ b/nequip/utils/torch_geometric/dataset.py @@ -0,0 +1,282 @@ +from typing import List, Optional, Callable, Union, Any, Tuple + +import re +import copy +import warnings +import numpy as np +import os.path as osp +from collections.abc import Sequence + +import torch.utils.data +from torch import Tensor + +from .data import Data +from .utils import makedirs + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class Dataset(torch.utils.data.Dataset): + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (string, optional): Root directory where the dataset should be + saved. (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.raw_dir` folder in + order to skip the download.""" + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + raise NotImplementedError + + def download(self): + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self): + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + raise NotImplementedError + + def get(self, idx: int) -> Data: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + super().__init__() + + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self._indices: Optional[Sequence] = None + + if "download" in self.__class__.__dict__.keys(): + self._download() + + if "process" in self.__class__.__dict__.keys(): + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return osp.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + @property + def raw_paths(self) -> List[str]: + r"""The filepaths to find in order to skip the download.""" + files = to_list(self.raw_file_names) + return [osp.join(self.raw_dir, f) for f in files] + + @property + def processed_paths(self) -> List[str]: + r"""The filepaths to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + files = to_list(self.processed_file_names) + return [osp.join(self.processed_dir, f) for f in files] + + def _download(self): + if files_exist(self.raw_paths): # pragma: no cover + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + f = osp.join(self.processed_dir, "pre_transform.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): + warnings.warn( + f"The `pre_transform` argument differs from the one used in " + f"the pre-processed version of this dataset. If you want to " + f"make use of another pre-processing technique, make sure to " + f"sure to delete '{self.processed_dir}' first" + ) + + f = osp.join(self.processed_dir, "pre_filter.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): + warnings.warn( + "The `pre_filter` argument differs from the one used in the " + "pre-processed version of this dataset. If you want to make " + "use of another pre-fitering technique, make sure to delete " + "'{self.processed_dir}' first" + ) + + if files_exist(self.processed_paths): # pragma: no cover + return + + print("Processing...") + + makedirs(self.processed_dir) + self.process() + + path = osp.join(self.processed_dir, "pre_transform.pt") + torch.save(_repr(self.pre_transform), path) + path = osp.join(self.processed_dir, "pre_filter.pt") + torch.save(_repr(self.pre_filter), path) + + print("Done!") + + def __len__(self) -> int: + r"""The number of examples in the dataset.""" + return len(self.indices()) + + def __getitem__( + self, + idx: Union[int, np.integer, IndexType], + ) -> Union["Dataset", Data]: + r"""In case :obj:`idx` is of type integer, will return the data object + at index :obj:`idx` (and transforms it in case :obj:`transform` is + present). + In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a + tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy + :obj:`np.array`, will return a subset of the dataset at the specified + indices.""" + if ( + isinstance(idx, (int, np.integer)) + or (isinstance(idx, Tensor) and idx.dim() == 0) + or (isinstance(idx, np.ndarray) and np.isscalar(idx)) + ): + + data = self.get(self.indices()[idx]) + data = data if self.transform is None else self.transform(data) + return data + + else: + return self.index_select(idx) + + def index_select(self, idx: IndexType) -> "Dataset": + indices = self.indices() + + if isinstance(idx, slice): + indices = indices[idx] + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False) + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0] + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + indices = [indices[i] for i in idx] + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + dataset = copy.copy(self) + dataset._indices = indices + return dataset + + def shuffle( + self, + return_perm: bool = False, + ) -> Union["Dataset", Tuple["Dataset", Tensor]]: + r"""Randomly shuffles the examples in the dataset. + + Args: + return_perm (bool, optional): If set to :obj:`True`, will return + the random permutation used to shuffle the dataset in addition. + (default: :obj:`False`) + """ + perm = torch.randperm(len(self)) + dataset = self.index_select(perm) + return (dataset, perm) if return_perm is True else dataset + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"{self.__class__.__name__}({arg_repr})" + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + +def files_exist(files: List[str]) -> bool: + # NOTE: We return `False` in case `files` is empty, leading to a + # re-processing of files on every instantiation. + return len(files) != 0 and all([osp.exists(f) for f in files]) + + +def _repr(obj: Any) -> str: + if obj is None: + return "None" + return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/nequip/utils/torch_geometric/utils.py b/nequip/utils/torch_geometric/utils.py new file mode 100644 index 00000000..c56ab0fe --- /dev/null +++ b/nequip/utils/torch_geometric/utils.py @@ -0,0 +1,54 @@ +import ssl +import os +import os.path as osp +import urllib +import zipfile + + +def makedirs(dir): + os.makedirs(dir, exist_ok=True) + + +def download_url(url, folder, log=True): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (string): The url. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + filename = url.rpartition("/")[2].split("?")[0] + path = osp.join(folder, filename) + + if osp.exists(path): # pragma: no cover + if log: + print("Using exist file", filename) + return path + + if log: + print("Downloading", url) + + makedirs(folder) + + context = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=context) + + with open(path, "wb") as f: + f.write(data.read()) + + return path + + +def extract_zip(path, folder, log=True): + r"""Extracts a zip archive to a specific folder. + + Args: + path (string): The path to the tar archive. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + with zipfile.ZipFile(path, "r") as f: + f.extractall(folder) diff --git a/nequip/utils/wandb.py b/nequip/utils/wandb.py index 8926097e..2391a9f4 100644 --- a/nequip/utils/wandb.py +++ b/nequip/utils/wandb.py @@ -1,4 +1,3 @@ -import os import wandb import logging from wandb.util import json_friendly_val diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..e188ebe4 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 127 +select = E,F,W,C +ignore = E226,E501,E741,E743,C901,W503,E203 +exclude = .eggs,*.egg,build,dist,docs diff --git a/setup.py b/setup.py index 0322ebf4..d01cf27e 100644 --- a/setup.py +++ b/setup.py @@ -20,9 +20,8 @@ # make the scripts available as command line scripts "console_scripts": [ "nequip-train = nequip.scripts.train:main", - "nequip-restart = nequip.scripts.restart:main", - "nequip-requeue = nequip.scripts.requeue:main", "nequip-evaluate = nequip.scripts.evaluate:main", + "nequip-benchmark = nequip.scripts.benchmark:main", "nequip-deploy = nequip.scripts.deploy:main", ] }, @@ -30,13 +29,14 @@ "numpy", "ase", "tqdm", - "torch>=1.8", - "torch_geometric==1.7.2", - "e3nn==0.3.5", + "torch>=1.8,<1.11", # torch.fx added in 1.8 + "e3nn>=0.3.5,<0.5.0", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext - "typing_extensions;python_version<'3.8'", - "torch-runstats", + "typing_extensions;python_version<'3.8'", # backport of Final + "torch-runstats>=0.2.0", + "torch-ema>=0.3.0", + "scikit_learn", # for GaussianProcess for per-species statistics ], zip_safe=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index 4a03bcb3..0222a8a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import numpy as np import pathlib import pytest @@ -10,10 +10,11 @@ from ase.io import write import torch -from torch_geometric.data import Batch from nequip.utils.test import set_irreps_debug from nequip.data import AtomicData, ASEDataset +from nequip.data.transforms import TypeMapper +from nequip.utils.torch_geometric import Batch # For good practice, we *should* do this: # See https://docs.pytest.org/en/stable/fixture.html#using-fixtures-from-other-projects @@ -54,7 +55,15 @@ def temp_data(float_tolerance): @pytest.fixture(scope="session") -def CH3CHO(float_tolerance) -> AtomicData: +def CH3CHO(CH3CHO_no_typemap) -> Tuple[Atoms, AtomicData]: + atoms, data = CH3CHO_no_typemap + tm = TypeMapper(chemical_symbol_to_type={"C": 0, "O": 1, "H": 2}) + data = tm(data) + return atoms, data + + +@pytest.fixture(scope="session") +def CH3CHO_no_typemap(float_tolerance) -> Tuple[Atoms, AtomicData]: atoms = molecule("CH3CHO") data = AtomicData.from_ase(atoms, r_max=2.0) return atoms, data @@ -87,13 +96,14 @@ def nequip_dataset(molecules, temp_data, float_tolerance): root=temp_data, extra_fixed_fields={"r_max": 3.0}, ase_args=dict(format="extxyz"), + type_mapper=TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}), ) yield a @pytest.fixture(scope="session") def atomic_batch(nequip_dataset): - return Batch.from_data_list([nequip_dataset.data[0], nequip_dataset.data[1]]) + return Batch.from_data_list([nequip_dataset[0], nequip_dataset[1]]) # Use debug mode diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py deleted file mode 100644 index c75b5d82..00000000 --- a/tests/data/test_dataset.py +++ /dev/null @@ -1,194 +0,0 @@ -import numpy as np -import pytest -import tempfile -import torch - -from os.path import isdir - -from ase.io import write - -from nequip.data import ( - AtomicDataDict, - AtomicInMemoryDataset, - NpzDataset, - ASEDataset, -) -from nequip.utils import dataset_from_config, Config - - -@pytest.fixture(scope="module") -def ase_file(molecules): - with tempfile.NamedTemporaryFile(suffix=".xyz") as fp: - for atoms in molecules: - write(fp.name, atoms, format="extxyz", append=True) - yield fp.name - - -@pytest.fixture(scope="session") -def npz(): - natoms = 3 - nframes = 4 - yield dict( - positions=np.random.random((nframes, natoms, 3)), - force=np.random.random((nframes, natoms, 3)), - energy=np.random.random(nframes), - Z=np.random.randint(1, 8, size=(nframes, natoms)), - ) - - -@pytest.fixture(scope="session") -def npz_data(npz): - with tempfile.NamedTemporaryFile(suffix=".npz") as path: - np.savez(path.name, **npz) - yield path.name - - -@pytest.fixture(scope="session") -def npz_dataset(npz_data, temp_data): - a = NpzDataset( - file_name=npz_data, - root=temp_data + "/test_dataset", - extra_fixed_fields={"r_max": 3}, - ) - yield a - - -@pytest.fixture(scope="function") -def root(): - with tempfile.TemporaryDirectory(prefix="datasetroot") as path: - yield path - - -class TestInit: - def test_init(self): - with pytest.raises(NotImplementedError) as excinfo: - a = AtomicInMemoryDataset(root=None) - assert str(excinfo.value) == "" - - def test_npz(self, npz_data, root): - g = NpzDataset(file_name=npz_data, root=root, extra_fixed_fields={"r_max": 3.0}) - assert isdir(g.root) - assert isdir(f"{g.root}/processed") - - def test_ase(self, ase_file, root): - a = ASEDataset( - file_name=ase_file, - root=root, - extra_fixed_fields={"r_max": 3.0}, - ase_args=dict(format="extxyz"), - ) - assert isdir(a.root) - assert isdir(f"{a.root}/processed") - - -class TestStatistics: - @pytest.mark.xfail( - reason="Current subset hack doesn't support statistics of non-per-node callable" - ) - def test_callable(self, npz_dataset, npz): - # Get componentwise statistics - ((f_mean, f_std),) = npz_dataset.statistics( - [lambda d: torch.flatten(d[AtomicDataDict.FORCE_KEY])] - ) - n_ex, n_at, _ = npz["force"].shape - f_raveled = npz["force"].reshape((n_ex * n_at * 3,)) - assert np.allclose(np.mean(f_raveled), f_mean) - # By default we follow torch convention of defaulting to the unbiased std - assert np.allclose(np.std(f_raveled, ddof=1), f_std) - - def test_statistics(self, npz_dataset, npz): - - (eng_mean, eng_std), (Z_unique, Z_count) = npz_dataset.statistics( - [AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY] - ) - - eng = npz["energy"] - assert np.allclose(eng_mean, np.mean(eng)) - # By default we follow torch convention of defaulting to the unbiased std - assert np.allclose(eng_std, np.std(eng, ddof=1)) - - if isinstance(Z_count, torch.Tensor): - Z_count = Z_count.numpy() - Z_unique = Z_unique.numpy() - - uniq, count = np.unique(npz["Z"].ravel(), return_counts=True) - assert np.all(Z_unique == uniq) - assert np.all(Z_count == count) - - def test_with_subset(self, npz_dataset, npz): - - dataset = npz_dataset.index_select([0]) - - ((Z_unique, Z_count), (force_rms,)) = dataset.statistics( - [AtomicDataDict.ATOMIC_NUMBERS_KEY, AtomicDataDict.FORCE_KEY], - modes=["count", "rms"], - ) - print("npz", npz["Z"]) - - uniq, count = np.unique(npz["Z"][0].ravel(), return_counts=True) - assert np.all(Z_unique.numpy() == uniq) - assert np.all(Z_count.numpy() == count) - - assert np.allclose( - force_rms.numpy(), np.sqrt(np.mean(np.square(npz["force"][0]))) - ) - - -class TestReload: - @pytest.mark.parametrize("change_rmax", [0, 1]) - def test_reload(self, npz_dataset, npz_data, change_rmax): - r_max = npz_dataset.extra_fixed_fields["r_max"] + change_rmax - a = NpzDataset( - file_name=npz_data, - root=npz_dataset.root, - extra_fixed_fields={"r_max": r_max}, - ) - print(a.processed_file_names[0]) - print(npz_dataset.processed_file_names[0]) - assert (a.processed_file_names[0] == npz_dataset.processed_file_names[0]) == ( - change_rmax == 0 - ) - - -class TestFromConfig: - @pytest.mark.parametrize( - "args", - [ - dict(extra_fixed_fields={"r_max": 3.0}), - dict(dataset_extra_fixed_fields={"r_max": 3.0}), - dict(r_max=3.0), - dict(r_max=3.0, extra_fixed_fields={}), - ], - ) - def test_npz(self, npz_data, root, args): - config = Config(dict(dataset="npz", file_name=npz_data, root=root, **args)) - g = dataset_from_config(config) - assert g.fixed_fields["r_max"] == 3 - assert isdir(g.root) - assert isdir(f"{g.root}/processed") - - def test_ase(self, ase_file, root): - config = Config( - dict( - dataset="ASEDataset", - file_name=ase_file, - root=root, - extra_fixed_fields={"r_max": 3.0}, - ase_args=dict(format="extxyz"), - ) - ) - a = dataset_from_config(config) - assert isdir(a.root) - assert isdir(f"{a.root}/processed") - - -class TestFromList: - def test_from_atoms(self, molecules): - dataset = ASEDataset.from_atoms_list( - molecules, extra_fixed_fields={"r_max": 4.5} - ) - assert len(dataset) == len(molecules) - for i, mol in enumerate(molecules): - assert np.array_equal( - mol.get_atomic_numbers(), dataset.get(i).to_ase().get_atomic_numbers() - ) diff --git a/tests/datasets/test_simplesets.py b/tests/datasets/test_simplesets.py deleted file mode 100644 index d345b0c0..00000000 --- a/tests/datasets/test_simplesets.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest - -from os.path import isdir -from shutil import rmtree - -from nequip.utils.config import Config -from nequip.utils.auto_init import dataset_from_config - - -include_frames = [0, 1] - - -@pytest.mark.parametrize("name", ["aspirin"]) -def test_simple(name, temp_data, BENCHMARK_ROOT): - - config = Config( - dict( - dataset=name, - root=f"{temp_data}/{name}", - extra_fixed_fields={"r_max": 3}, - include_frames=include_frames, - ) - ) - - if name == "aspirin": - config.dataset_file_name = str(BENCHMARK_ROOT / "aspirin_ccsd-train.npz") - - a = dataset_from_config(config) - print(a.data) - print(a.fixed_fields) - assert isdir(config.root) - assert isdir(f"{config.root}/processed") - assert len(a.data.edge_index) == len(include_frames) - rmtree(f"{config.root}/processed") diff --git a/tests/scripts/test_deploy.py b/tests/integration/test_deploy.py similarity index 53% rename from tests/scripts/test_deploy.py rename to tests/integration/test_deploy.py index c0b32124..e132e413 100644 --- a/tests/scripts/test_deploy.py +++ b/tests/integration/test_deploy.py @@ -3,6 +3,7 @@ import pathlib import yaml import subprocess +import sys import numpy as np import torch @@ -10,10 +11,14 @@ import nequip from nequip.data import AtomicDataDict, AtomicData from nequip.scripts import deploy +from nequip.train import Trainer +from nequip.ase import NequIPCalculator -def test_deploy(nequip_dataset, BENCHMARK_ROOT): - +@pytest.mark.parametrize( + "device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +) +def test_deploy(nequip_dataset, BENCHMARK_ROOT, device): dtype = str(torch.get_default_dtype())[len("torch.") :] # if torch.cuda.is_available(): @@ -25,8 +30,9 @@ def test_deploy(nequip_dataset, BENCHMARK_ROOT): with tempfile.TemporaryDirectory() as tmpdir: # Save time run_name = "test_deploy" + dtype + root = "./" true_config["run_name"] = run_name - true_config["root"] = tmpdir + true_config["root"] = root true_config["dataset_file_name"] = str( BENCHMARK_ROOT / "aspirin_ccsd-train.npz" ) @@ -34,55 +40,78 @@ def test_deploy(nequip_dataset, BENCHMARK_ROOT): true_config["max_epochs"] = 1 true_config["n_train"] = 1 true_config["n_val"] = 1 - config_path = tmpdir + "/conf.yaml" - with open(config_path, "w+") as fp: + config_path = "conf.yaml" + with open(f"{tmpdir}/{config_path}", "w+") as fp: yaml.dump(true_config, fp) # Train model retcode = subprocess.run(["nequip-train", str(config_path)], cwd=tmpdir) retcode.check_returncode() # Deploy - deployed_path = tmpdir / pathlib.Path(f"deployed_{dtype}.pth") + deployed_path = pathlib.Path(f"deployed_{dtype}.pth") retcode = subprocess.run( - ["nequip-deploy", "build", f"{tmpdir}/{run_name}/", str(deployed_path)], + ["nequip-deploy", "build", f"{root}/{run_name}/", str(deployed_path)], cwd=tmpdir, ) retcode.check_returncode() + deployed_path = tmpdir / deployed_path assert deployed_path.is_file(), "Deploy didn't create file" # now test predictions the same - best_mod = torch.load(f"{tmpdir}/{run_name}/best_model.pth") - device = next(best_mod.parameters()).device - data = AtomicData.to_AtomicDataDict(nequip_dataset.get(0).to(device)) + best_mod, _ = Trainer.load_model_from_training_session( + traindir=f"{tmpdir}/{root}/{run_name}/", + model_name="best_model.pth", + device=device, + ) + best_mod.eval() + + data = AtomicData.to_AtomicDataDict(nequip_dataset[0].to(device)) # Needed because of debug mode: data[AtomicDataDict.TOTAL_ENERGY_KEY] = data[ AtomicDataDict.TOTAL_ENERGY_KEY ].unsqueeze(0) - train_pred = best_mod(data)[AtomicDataDict.TOTAL_ENERGY_KEY] + train_pred = best_mod(data)[AtomicDataDict.TOTAL_ENERGY_KEY].to("cpu") # load model and check that metadata saved - metadata = { - deploy.NEQUIP_VERSION_KEY: "", - deploy.R_MAX_KEY: "", - } - deploy_mod = torch.jit.load( - deployed_path, _extra_files=metadata, map_location="cpu" + # TODO: use both CPU and CUDA to load? + deploy_mod, metadata = deploy.load_deployed_model( + deployed_path, + device="cpu", + set_global_options=False, # don't need this corrupting test environment ) # Everything we store right now is ASCII, so decode for printing - metadata = {k: v.decode("ascii") for k, v in metadata.items()} assert metadata[deploy.NEQUIP_VERSION_KEY] == nequip.__version__ assert np.allclose(float(metadata[deploy.R_MAX_KEY]), true_config["r_max"]) + assert len(metadata[deploy.TYPE_NAMES_KEY].split(" ")) == 3 # C, H, O - data = AtomicData.to_AtomicDataDict(nequip_dataset.get(0).to("cpu")) + data_idx = 0 + data = AtomicData.to_AtomicDataDict(nequip_dataset[data_idx].to("cpu")) deploy_pred = deploy_mod(data)[AtomicDataDict.TOTAL_ENERGY_KEY] - assert torch.allclose(train_pred.to("cpu"), deploy_pred, atol=1e-7) + assert torch.allclose(train_pred, deploy_pred, atol=1e-7) # now test info + # hack for old version + if sys.version_info[1] > 6: + text = {"text": True} + else: + text = {} retcode = subprocess.run( ["nequip-deploy", "info", str(deployed_path)], - text=True, stdout=subprocess.PIPE, + **text, ) retcode.check_returncode() # Try to load extract config config = yaml.load(retcode.stdout, Loader=yaml.Loader) del config + + # Finally, try to load in ASE + calc = NequIPCalculator.from_deployed_model( + deployed_path, + device="cpu", + species_to_type_name={s: s for s in ("C", "H", "O")}, + ) + # use .get() so it's not transformed + atoms = nequip_dataset.get(data_idx).to_ase() + atoms.calc = calc + ase_forces = atoms.get_potential_energy() + assert torch.allclose(train_pred, torch.as_tensor(ase_forces), atol=1e-7) diff --git a/tests/scripts/test_evaluate.py b/tests/integration/test_evaluate.py similarity index 77% rename from tests/scripts/test_evaluate.py rename to tests/integration/test_evaluate.py index 472de3aa..be8f65e5 100644 --- a/tests/scripts/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -8,6 +8,8 @@ import shutil import numpy as np +import ase.io + import torch from nequip.data import AtomicDataDict @@ -43,13 +45,13 @@ def training_session(request, BENCHMARK_ROOT, conffile): # Save time run_name = "test_train_" + dtype true_config["run_name"] = run_name - true_config["root"] = tmpdir + true_config["root"] = "./" true_config["dataset_file_name"] = str( BENCHMARK_ROOT / "aspirin_ccsd-train.npz" ) true_config["default_dtype"] = dtype true_config["max_epochs"] = 2 - true_config["model_builder"] = builder + true_config["model_builders"] = [builder] # to be a true identity, we can't have rescaling true_config["global_rescale_shift"] = None @@ -64,9 +66,7 @@ def training_session(request, BENCHMARK_ROOT, conffile): env["PYTHONPATH"] = ":".join( [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") ) - retcode = subprocess.run( - ["nequip-train", str(config_path)], cwd=tmpdir, env=env - ) + retcode = subprocess.run(["nequip-train", "conf.yaml"], cwd=tmpdir, env=env) retcode.check_returncode() yield builder, true_config, tmpdir, env @@ -75,11 +75,16 @@ def training_session(request, BENCHMARK_ROOT, conffile): @pytest.mark.parametrize("do_test_idcs", [True, False]) @pytest.mark.parametrize("do_metrics", [True, False]) def test_metrics(training_session, do_test_idcs, do_metrics): + builder, true_config, tmpdir, env = training_session # == Run test error == outdir = f"{true_config['root']}/{true_config['run_name']}/" - default_params = {"train-dir": outdir, "output": tmpdir + "/out.xyz"} + default_params = { + "train-dir": outdir, + "output": "out.xyz", + "log": "out.log", + } def runit(params: dict): tmp = default_params.copy() @@ -90,11 +95,12 @@ def runit(params: dict): ["nequip-evaluate"] + sum( (["--" + k, str(v)] for k, v in params.items() if v is not None), - start=[], + [], ), cwd=tmpdir, env=env, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) retcode.check_returncode() @@ -113,8 +119,8 @@ def runit(params: dict): # The Aspirin dataset is 1000 frames long # Pick some arbitrary number of frames test_idcs_arr = torch.randperm(1000)[:257] - test_idcs = tmpdir + "/some-test-idcs.pth" - torch.save(test_idcs_arr, test_idcs) + test_idcs = "some-test-idcs.pth" + torch.save(test_idcs_arr, f"{tmpdir}/{test_idcs}") else: test_idcs = None # ignore and use default default_params["test-indexes"] = test_idcs @@ -122,10 +128,10 @@ def runit(params: dict): # Metrics if do_metrics: # Write an explicit metrics file - metrics_yaml = tmpdir + "/my-metrics.yaml" - with open(metrics_yaml, "w") as f: + metrics_yaml = "my-metrics.yaml" + with open(f"{tmpdir}/{metrics_yaml}", "w") as f: # Write out a fancier metrics file - # We don't use PerSpecies here since the simple models don't fill SPECIES_INDEX right now + # We don't use PerSpecies here since the simple models don't fill ATOM_TYPE_KEY right now # ^ TODO! f.write( textwrap.dedent( @@ -148,6 +154,8 @@ def runit(params: dict): metrics = runit({"train-dir": outdir, "batch-size": 200, "device": "cpu"}) # move out.xyz to out-orig.xyz shutil.move(tmpdir + "/out.xyz", tmpdir + "/out-orig.xyz") + # Load it + orig_atoms = ase.io.read(tmpdir + "/out-orig.xyz", index=":", format="extxyz") assert set(metrics.keys()) == expect_metrics @@ -161,25 +169,26 @@ def runit(params: dict): # Check insensitive to batch size for batch_size in (13, 1000): metrics2 = runit( - {"train-dir": outdir, "batch-size": batch_size, "device": "cpu"} + { + "train-dir": outdir, + "batch-size": batch_size, + "device": "cpu", + "output": f"{batch_size}.xyz", + "log": f"{batch_size}.log", + } ) for k, v in metrics.items(): assert np.all(np.abs(v - metrics2[k]) < 1e-5) - # Diff the output XYZ, which shouldn't change at all - # Use `cmp`, which is UNIX standard, to make efficient - # See https://stackoverflow.com/questions/12900538/fastest-way-to-tell-if-two-files-have-the-same-contents-in-unix-linux - cmp_retval = subprocess.run( - ["cmp", "--silent", tmpdir + "/out-orig.xyz", tmpdir + "/out.xyz"] - ) - if cmp_retval.returncode == 0: - # same - pass - if cmp_retval.returncode == 1: - raise AssertionError( - f"Changing batch size to {batch_size} changed out.xyz!" + + # Check the output XYZ + batch_atoms = ase.io.read(tmpdir + "/out-orig.xyz", index=":", format="extxyz") + for origframe, newframe in zip(orig_atoms, batch_atoms): + assert np.allclose(origframe.get_positions(), newframe.get_positions()) + assert np.array_equal( + origframe.get_atomic_numbers(), newframe.get_atomic_numbers() ) - else: - cmp_retval.check_returncode() # error out for subprocess problem + assert np.array_equal(origframe.get_pbc(), newframe.get_pbc()) + assert np.array_equal(origframe.get_cell(), newframe.get_cell()) # Check GPU if torch.cuda.is_available(): diff --git a/tests/scripts/test_train.py b/tests/integration/test_train.py similarity index 69% rename from tests/scripts/test_train.py rename to tests/integration/test_train.py index c6d798b1..72d7ecb1 100644 --- a/tests/scripts/test_train.py +++ b/tests/integration/test_train.py @@ -9,7 +9,7 @@ import torch from nequip.data import AtomicDataDict -from nequip.nn import GraphModuleMixin, RescaleOutput +from nequip.nn import GraphModuleMixin class IdentityModel(GraphModuleMixin, torch.nn.Module): @@ -19,7 +19,7 @@ def __init__(self, **kwargs): irreps_in={ AtomicDataDict.TOTAL_ENERGY_KEY: "0e", AtomicDataDict.FORCE_KEY: "1o", - } + }, ) self.one = torch.nn.Parameter(torch.as_tensor(1.0)) @@ -38,7 +38,7 @@ def __init__(self, **kwargs): irreps_in={ AtomicDataDict.TOTAL_ENERGY_KEY: "0e", AtomicDataDict.FORCE_KEY: "1o", - } + }, ) # to keep the optimizer happy: self.dummy = torch.nn.Parameter(torch.zeros(1)) @@ -61,7 +61,7 @@ def __init__(self, **kwargs): irreps_in={ AtomicDataDict.TOTAL_ENERGY_KEY: "0e", AtomicDataDict.FORCE_KEY: "1o", - } + }, ) # By using a big factor, we keep it in a nice descending part # of the optimization without too much oscilation in loss at @@ -77,16 +77,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: @pytest.mark.parametrize( - "conffile,field", + "conffile", [ - ("minimal.yaml", AtomicDataDict.FORCE_KEY), - ("minimal_eng.yaml", AtomicDataDict.TOTAL_ENERGY_KEY), + "minimal.yaml", + "minimal_eng.yaml", ], ) @pytest.mark.parametrize( "builder", [IdentityModel, ConstFactorModel, LearningFactorModel] ) -def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): +def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, builder): dtype = str(torch.get_default_dtype())[len("torch.") :] @@ -97,21 +97,19 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): path_to_this_file = pathlib.Path(__file__) config_path = path_to_this_file.parents[2] / f"configs/{conffile}" true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + with tempfile.TemporaryDirectory() as tmpdir: # Save time run_name = "test_train_" + dtype true_config["run_name"] = run_name - true_config["root"] = tmpdir + true_config["root"] = "./" true_config["dataset_file_name"] = str( BENCHMARK_ROOT / "aspirin_ccsd-train.npz" ) true_config["default_dtype"] = dtype true_config["max_epochs"] = 2 - true_config["model_builder"] = builder - - # to be a true identity, we can't have rescaling - true_config["global_rescale_shift"] = None - true_config["global_rescale_scale"] = None + # We just don't add rescaling: + true_config["model_builders"] = [builder] config_path = tmpdir + "/conf.yaml" with open(config_path, "w+") as fp: @@ -122,16 +120,22 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): env["PYTHONPATH"] = ":".join( [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") ) + retcode = subprocess.run( - ["nequip-train", str(config_path)], cwd=tmpdir, env=env + ["nequip-train", "conf.yaml"], + cwd=tmpdir, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) retcode.check_returncode() # == Load metrics == - outdir = f"{true_config['root']}/{true_config['run_name']}/" + outdir = f"{tmpdir}/{true_config['root']}/{run_name}/" if builder == IdentityModel or builder == LearningFactorModel: for which in ("train", "val"): + dat = np.genfromtxt( f"{outdir}/metrics_batch_{which}.csv", delimiter=",", @@ -161,10 +165,11 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): for field in dat.dtype.names: if field == "epoch" or field == "wall" or field == "LR": continue + # Everything else should be a loss or a metric if builder == IdentityModel: assert np.allclose( - dat[field], 0.0 + dat[field][1:], 0.0 ), f"Loss/metric `{field}` wasn't all equal to zero for epoch" elif builder == ConstFactorModel: # otherwise just check its constant. @@ -179,14 +184,78 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): # == Check model == model = torch.load(outdir + "/last_model.pth") - assert isinstance( - model, RescaleOutput - ) # make sure trainer and this test aren't out of sync if builder == IdentityModel: - one = model.model.one + one = model["one"] # Since the loss is always zero, even though the constant # 1 was trainable, it shouldn't have changed assert torch.allclose( one, torch.ones(1, device=one.device, dtype=one.dtype) ) + + +@pytest.mark.parametrize( + "conffile", + [ + "minimal.yaml", + "minimal_eng.yaml", + ], +) +def test_requeue(nequip_dataset, BENCHMARK_ROOT, conffile): + + builder = IdentityModel + dtype = str(torch.get_default_dtype())[len("torch.") :] + + # if torch.cuda.is_available(): + # # TODO: is this true? + # pytest.skip("CUDA and subprocesses have issues") + + path_to_this_file = pathlib.Path(__file__) + config_path = path_to_this_file.parents[2] / f"configs/{conffile}" + true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + + with tempfile.TemporaryDirectory() as tmpdir: + + run_name = "test_requeue_" + dtype + true_config["run_name"] = run_name + true_config["append"] = True + true_config["root"] = "./" + true_config["dataset_file_name"] = str( + BENCHMARK_ROOT / "aspirin_ccsd-train.npz" + ) + true_config["default_dtype"] = dtype + # We just don't add rescaling: + true_config["model_builders"] = [builder] + + for irun in range(3): + + true_config["max_epochs"] = 2 * (irun + 1) + config_path = tmpdir + "/conf.yaml" + with open(config_path, "w+") as fp: + yaml.dump(true_config, fp) + + # == Train model == + env = dict(os.environ) + # make this script available so model builders can be loaded + env["PYTHONPATH"] = ":".join( + [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") + ) + + retcode = subprocess.run( + ["nequip-train", "conf.yaml"], + cwd=tmpdir, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + retcode.check_returncode() + + # == Load metrics == + dat = np.genfromtxt( + f"{tmpdir}/{run_name}/metrics_epoch.csv", + delimiter=",", + names=True, + dtype=None, + ) + + assert len(dat["epoch"]) == true_config["max_epochs"] diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/trainer/test_loss.py b/tests/trainer/test_loss.py deleted file mode 100644 index 9b0322cf..00000000 --- a/tests/trainer/test_loss.py +++ /dev/null @@ -1,137 +0,0 @@ -import pytest -import torch -from nequip.data import AtomicDataDict -from nequip.train import Loss - -# all the config to test init -# only the last one will be used to test the loss and mae -dicts = ( - {AtomicDataDict.TOTAL_ENERGY_KEY: (3.0, "MSELoss")}, - {AtomicDataDict.TOTAL_ENERGY_KEY: "MSELoss"}, - [AtomicDataDict.FORCE_KEY, AtomicDataDict.TOTAL_ENERGY_KEY], - {AtomicDataDict.FORCE_KEY: (1.0, "PerSpeciesMSELoss")}, - {AtomicDataDict.FORCE_KEY: (1.0), "k": (1.0, torch.nn.L1Loss())}, - AtomicDataDict.TOTAL_ENERGY_KEY, - { - AtomicDataDict.TOTAL_ENERGY_KEY: (3.0, "L1Loss"), - AtomicDataDict.FORCE_KEY: (1.0), - "k": 1.0, - }, -) - - -class TestInit: - @pytest.mark.parametrize("loss", dicts, indirect=True) - def test_init(self, loss): - - assert len(loss.funcs) == len(loss.coeffs) - for key, value in loss.coeffs.items(): - assert isinstance(value, torch.Tensor) - - -class TestLoss: - @pytest.mark.parametrize("loss", dicts[-2:], indirect=True) - def test_loss(self, loss, data): - - pred, ref = data - - loss_value = loss(pred, ref) - - loss_value, contrib = loss_value - assert len(contrib) > 0 - assert isinstance(contrib, dict) - for key, value in contrib.items(): - assert isinstance(value, torch.Tensor) - - assert isinstance(loss_value, torch.Tensor) - - -class TestWeight: - def test_loss(self, data): - - pred, ref = data - - loss = Loss(coeffs=dicts[-1], atomic_weight_on=False) - w_loss = Loss(coeffs=dicts[-1], atomic_weight_on=True) - - w_l, w_contb = w_loss(pred, ref) - l, contb = loss(pred, ref) - - assert isinstance(w_l, torch.Tensor) - assert not torch.isclose(w_l, l) - assert torch.isclose( - w_contb[AtomicDataDict.FORCE_KEY], contb[AtomicDataDict.FORCE_KEY] - ) - - def test_per_specie(self, data): - - pred, ref = data - - config = {AtomicDataDict.FORCE_KEY: (1.0, "PerSpeciesMSELoss")} - loss = Loss(coeffs=config, atomic_weight_on=False) - w_loss = Loss(coeffs=config, atomic_weight_on=True) - - w_l, w_contb = w_loss(pred, ref) - l, contb = loss(pred, ref) - - # first half data are specie 1 - # loss_ref_1 = torch.square(pred[AtomicDataDict.FORCE_KEY][:5] - ref[AtomicDataDict.FORCE_KEY][:5]).mean() - # loss_ref_0 = torch.square(pred[AtomicDataDict.FORCE_KEY][5:] - ref[AtomicDataDict.FORCE_KEY][5:]).mean() - - # since atomic weights are all the same value, - # the two loss should have the same result - assert isinstance(w_l, torch.Tensor) - print(w_l) - print(l) - assert torch.isclose(w_l, l) - - for c in [w_contb, contb]: - for key, value in c.items(): - assert key in [AtomicDataDict.FORCE_KEY] - - assert torch.allclose( - w_contb[AtomicDataDict.FORCE_KEY], contb[AtomicDataDict.FORCE_KEY] - ) - # assert torch.isclose(w_contb[1][AtomicDataDict.FORCE_KEY], loss_ref_1) - # assert torch.isclose(w_contb[0][AtomicDataDict.FORCE_KEY], loss_ref_0) - - -@pytest.fixture(scope="class") -def loss(request): - """""" - d = request.param - instance = Loss(coeffs=d, atomic_weight_on=False) - yield instance - - -@pytest.fixture(scope="class") -def w_loss(): - """""" - instance = Loss(coeffs=dicts[-1], atomic_weight_on=True) - yield instance - - -@pytest.fixture(scope="module") -def data(float_tolerance): - """""" - pred = { - AtomicDataDict.FORCE_KEY: torch.rand(10, 3), - AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), - "k": torch.rand((2, 1)), - AtomicDataDict.SPECIES_INDEX_KEY: torch.as_tensor( - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0] - ), - } - ref = { - AtomicDataDict.FORCE_KEY: torch.rand(10, 3), - AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), - "k": torch.rand((2, 1)), - AtomicDataDict.SPECIES_INDEX_KEY: torch.as_tensor( - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0] - ), - } - ref[AtomicDataDict.WEIGHTS_KEY + AtomicDataDict.FORCE_KEY] = 2 * torch.ones((10, 1)) - ref[AtomicDataDict.WEIGHTS_KEY + AtomicDataDict.TOTAL_ENERGY_KEY] = torch.rand( - (2, 1) - ) - yield pred, ref diff --git a/tests/data/test_AtomicData.py b/tests/unit/data/test_AtomicData.py similarity index 97% rename from tests/data/test_AtomicData.py rename to tests/unit/data/test_AtomicData.py index 4af9cd27..f5d6dc27 100644 --- a/tests/data/test_AtomicData.py +++ b/tests/unit/data/test_AtomicData.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch_geometric.data import Batch +from nequip.utils.torch_geometric import Batch import ase.build import ase.geometry @@ -21,8 +21,8 @@ def test_from_ase(CuFcc): assert data[key].shape == (len(atoms), 3) # 4 species in this atoms -def test_to_ase(CH3CHO): - atoms, data = CH3CHO +def test_to_ase(CH3CHO_no_typemap): + atoms, data = CH3CHO_no_typemap to_ase_atoms = data.to_ase() assert np.allclose(atoms.get_positions(), to_ase_atoms.get_positions()) assert np.array_equal(atoms.get_atomic_numbers(), to_ase_atoms.get_atomic_numbers()) @@ -39,7 +39,7 @@ def test_to_ase_batches(atomic_batch): assert np.allclose(atoms.get_positions(), atomic_data.pos[mask]) assert atoms.get_atomic_numbers().shape == (len(atoms),) assert np.array_equal( - atoms.get_atomic_numbers(), atomic_data.atomic_numbers[mask] + atoms.get_atomic_numbers(), atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask] ) assert np.array_equal(atoms.get_cell(), atomic_data.cell[batch_idx]) assert np.array_equal(atoms.get_pbc(), atomic_data.pbc[batch_idx]) diff --git a/tests/data/test_dataloader.py b/tests/unit/data/test_dataloader.py similarity index 90% rename from tests/data/test_dataloader.py rename to tests/unit/data/test_dataloader.py index 871ecf96..5fbeeb93 100644 --- a/tests/data/test_dataloader.py +++ b/tests/unit/data/test_dataloader.py @@ -9,12 +9,10 @@ class TestInit: def test_init(self, npz_dataset): - dl = DataLoader( - npz_dataset, batch_size=2, shuffle=True, exclude_keys=["energy"] - ) + DataLoader(npz_dataset, batch_size=2, shuffle=True, exclude_keys=["energy"]) def test_subset(self, npz_dataset): - subset = npz_dataset[[1, 3]] + npz_dataset[[1, 3]] class TestLoop: @@ -24,7 +22,7 @@ def test_whole(self, dloader): print(batch) def test_non_divisor(self, npz_dataset): - dataset = [npz_dataset.get(i) for i in range(7)] # make it odd length + dataset = [npz_dataset[i] for i in range(7)] # make it odd length dl = DataLoader(dataset, batch_size=2, shuffle=True, exclude_keys=["energy"]) dl_iter = iter(dl) for _ in range(3): diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py new file mode 100644 index 00000000..98aa635c --- /dev/null +++ b/tests/unit/data/test_dataset.py @@ -0,0 +1,397 @@ +import numpy as np +import pytest +import tempfile +import torch + +from os.path import isdir, isfile + +from ase.data import chemical_symbols +from ase.io import write + +from nequip.data import ( + AtomicDataDict, + AtomicInMemoryDataset, + NpzDataset, + ASEDataset, + dataset_from_config, +) +from nequip.data.transforms import TypeMapper +from nequip.utils import Config + + +@pytest.fixture(scope="module") +def ase_file(molecules): + with tempfile.NamedTemporaryFile(suffix=".xyz") as fp: + for atoms in molecules: + write(fp.name, atoms, format="extxyz", append=True) + yield fp.name + + +MAX_ATOMIC_NUMBER: int = 5 +NATOMS = 3 + + +@pytest.fixture(scope="function") +def npz(): + natoms = NATOMS + nframes = 8 + yield dict( + positions=np.random.random((nframes, natoms, 3)), + force=np.random.random((nframes, natoms, 3)), + energy=np.random.random(nframes) * -600, + Z=np.random.randint(1, MAX_ATOMIC_NUMBER, size=(nframes, natoms)), + ) + + +@pytest.fixture(scope="function") +def npz_data(npz): + with tempfile.NamedTemporaryFile(suffix=".npz") as path: + np.savez(path.name, **npz) + yield path.name + + +@pytest.fixture(scope="function") +def npz_dataset(npz_data, temp_data): + a = NpzDataset( + file_name=npz_data, + root=temp_data + "/test_dataset", + extra_fixed_fields={"r_max": 3}, + ) + yield a + + +@pytest.fixture(scope="function") +def root(): + with tempfile.TemporaryDirectory(prefix="datasetroot") as path: + yield path + + +class TestInit: + def test_init(self): + with pytest.raises(NotImplementedError) as excinfo: + AtomicInMemoryDataset(root=None) + assert str(excinfo.value) == "" + + def test_npz(self, npz_data, root): + g = NpzDataset(file_name=npz_data, root=root, extra_fixed_fields={"r_max": 3.0}) + assert isdir(g.root) + assert isdir(g.processed_dir) + assert isfile(g.processed_dir + "/data.pth") + + def test_ase(self, ase_file, root): + a = ASEDataset( + file_name=ase_file, + root=root, + extra_fixed_fields={"r_max": 3.0}, + ase_args=dict(format="extxyz"), + ) + assert isdir(a.root) + assert isdir(a.processed_dir) + assert isfile(a.processed_dir + "/data.pth") + + +class TestStatistics: + @pytest.mark.xfail( + reason="Current subset hack doesn't support statistics of non-per-node callable" + ) + def test_callable(self, npz_dataset, npz): + # Get componentwise statistics + ((f_mean, f_std),) = npz_dataset.statistics( + [lambda d: torch.flatten(d[AtomicDataDict.FORCE_KEY])] + ) + n_ex, n_at, _ = npz["force"].shape + f_raveled = npz["force"].reshape((n_ex * n_at * 3,)) + assert np.allclose(np.mean(f_raveled), f_mean) + # By default we follow torch convention of defaulting to the unbiased std + assert np.allclose(np.std(f_raveled, ddof=1), f_std) + + def test_statistics(self, npz_dataset, npz): + + (eng_mean, eng_std), (Z_unique, Z_count) = npz_dataset.statistics( + fields=[AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY], + modes=["mean_std", "count"], + ) + + eng = npz["energy"] + assert np.allclose(eng_mean, np.mean(eng)) + # By default we follow torch convention of defaulting to the unbiased std + assert np.allclose(eng_std, np.std(eng, ddof=1)) + + if isinstance(Z_count, torch.Tensor): + Z_count = Z_count.numpy() + Z_unique = Z_unique.numpy() + + uniq, count = np.unique(npz["Z"].ravel(), return_counts=True) + assert np.all(Z_unique == uniq) + assert np.all(Z_count == count) + + def test_with_subset(self, npz_dataset, npz): + + dataset = npz_dataset.index_select([0]) + + ((Z_unique, Z_count), (force_rms,)) = dataset.statistics( + [AtomicDataDict.ATOMIC_NUMBERS_KEY, AtomicDataDict.FORCE_KEY], + modes=["count", "rms"], + ) + + uniq, count = np.unique(npz["Z"][0].ravel(), return_counts=True) + assert np.all(Z_unique.numpy() == uniq) + assert np.all(Z_count.numpy() == count) + + assert np.allclose( + force_rms.numpy(), np.sqrt(np.mean(np.square(npz["force"][0]))) + ) + + def test_atom_types(self, npz_dataset): + ((avg_num_neigh, _),) = npz_dataset.statistics( + fields=[ + lambda data: ( + torch.unique( + data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True + )[1], + "node", + ) + ], + modes=["mean_std"], + ) + # They are all homogenous in this dataset: + assert ( + avg_num_neigh + == torch.bincount(npz_dataset[0][AtomicDataDict.EDGE_INDEX_KEY][0])[0] + ) + + def test_edgewise_stats(self, npz_dataset): + ((avg_edge_length, std_edge_len),) = npz_dataset.statistics( + fields=[ + lambda data: ( + ( + data[AtomicDataDict.POSITIONS_KEY][ + data[AtomicDataDict.EDGE_INDEX_KEY][1] + ] + - data[AtomicDataDict.POSITIONS_KEY][ + data[AtomicDataDict.EDGE_INDEX_KEY][0] + ] + ).norm(dim=-1), + "edge", + ) + ], + modes=["mean_std"], + ) + # TODO: check correct + + +class TestPerSpeciesStatistics: + @pytest.mark.parametrize("fixed_field", [True, False]) + @pytest.mark.parametrize("mode", ["mean_std", "rms"]) + def test_per_node_field(self, npz_dataset, fixed_field, mode): + # set up the transformer + npz_dataset = set_up_transformer(npz_dataset, not fixed_field, fixed_field) + + (result,) = npz_dataset.statistics( + [AtomicDataDict.BATCH_KEY], + modes=[f"per_species_{mode}"], + ) + print(result) + + @pytest.mark.parametrize("alpha", [1e-10, 1e-6, 0.1, 0.5, 1]) + @pytest.mark.parametrize("fixed_field", [True, False]) + @pytest.mark.parametrize("full_rank", [True, False]) + @pytest.mark.parametrize( + "regressor", ["NormalizedGaussianProcess", "GaussianProcess"] + ) + def test_per_graph_field( + self, npz_dataset, alpha, fixed_field, full_rank, regressor + ): + + npz_dataset = set_up_transformer(npz_dataset, full_rank, fixed_field) + if npz_dataset is None: + return + + # get species count per graph + Ns = [] + for i in range(npz_dataset.len()): + Ns.append(torch.bincount(npz_dataset[i][AtomicDataDict.ATOM_TYPE_KEY])) + n_spec = max(len(e) for e in Ns) + N = torch.zeros(len(Ns), n_spec) + for i in range(len(Ns)): + N[i, : len(Ns[i])] = Ns[i] + del n_spec + del Ns + + if alpha == 1e-10: + ref_mean, ref_std, E = generate_E(N, 100, 0.0) + else: + ref_mean, ref_std, E = generate_E(N, 100, 0.5) + + npz_dataset.data[AtomicDataDict.TOTAL_ENERGY_KEY] = E + + ref_res2 = torch.square( + torch.matmul(N, ref_mean.reshape([-1, 1])) - E.reshape([-1, 1]) + ).sum() + + ((mean, std),) = npz_dataset.statistics( + [AtomicDataDict.TOTAL_ENERGY_KEY], + modes=["per_species_mean_std"], + kwargs={ + AtomicDataDict.TOTAL_ENERGY_KEY + + "per_species_mean_std": { + "alpha": alpha, + "regressor": regressor, + "stride": 1, + } + }, + ) + + res = torch.matmul(N, mean.reshape([-1, 1])) - E.reshape([-1, 1]) + res2 = torch.sum(torch.square(res)) + print("residue", alpha, res2 - ref_res2) + print("mean", mean, ref_mean) + print("diff in mean", mean - ref_mean) + print("std", std, ref_std) + + if alpha == 1e-10 and full_rank: + assert torch.allclose(mean, ref_mean, rtol=1e-1) + assert torch.allclose(std, torch.zeros_like(ref_mean), atol=1e-2) + # else: + # assert res2 > ref_res2 + + +class TestReload: + @pytest.mark.parametrize("change_rmax", [0, 1]) + @pytest.mark.parametrize("give_url", [True, False]) + @pytest.mark.parametrize("change_key_map", [True, False]) + def test_reload(self, npz_dataset, npz_data, change_rmax, give_url, change_key_map): + r_max = npz_dataset.extra_fixed_fields["r_max"] + change_rmax + keymap = npz_dataset.key_mapping.copy() # the default one + if change_key_map: + keymap["x1"] = "x2" + a = NpzDataset( + file_name=npz_data, + root=npz_dataset.root, + extra_fixed_fields={"r_max": r_max}, + key_mapping=keymap, + **({"url": "example.com/data.dat"} if give_url else {}), + ) + print(a.processed_file_names[0]) + print(npz_dataset.processed_file_names[0]) + assert (a.processed_dir == npz_dataset.processed_dir) == ( + (change_rmax == 0) and (not give_url) and (not change_key_map) + ) + + +class TestFromConfig: + @pytest.mark.parametrize( + "args", + [ + dict(extra_fixed_fields={"r_max": 3.0}), + dict(dataset_extra_fixed_fields={"r_max": 3.0}), + dict(r_max=3.0), + dict(r_max=3.0, extra_fixed_fields={}), + ], + ) + def test_npz(self, npz_data, root, args): + config = Config( + dict( + dataset="npz", + file_name=npz_data, + root=root, + chemical_symbol_to_type={ + chemical_symbols[an]: an - 1 for an in range(1, MAX_ATOMIC_NUMBER) + }, + **args, + ) + ) + g = dataset_from_config(config) + assert g.fixed_fields["r_max"] == 3 + assert isdir(g.root) + assert isdir(g.processed_dir) + assert isfile(g.processed_dir + "/data.pth") + + @pytest.mark.parametrize("prefix", ["dataset", "thingy"]) + def test_ase(self, ase_file, root, prefix): + config = Config( + dict( + file_name=ase_file, + root=root, + extra_fixed_fields={"r_max": 3.0}, + ase_args=dict(format="extxyz"), + chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}, + ) + ) + config[prefix] = "ASEDataset" + a = dataset_from_config(config, prefix=prefix) + assert isdir(a.root) + assert isdir(a.processed_dir) + assert isfile(a.processed_dir + "/data.pth") + + # Test reload + # Change some random ASE specific parameter + # See https://wiki.fysik.dtu.dk/ase/ase/io/io.html + config["ase_args"]["do_not_split_by_at_sign"] = True + b = dataset_from_config(config, prefix=prefix) + assert isdir(b.processed_dir) + assert isfile(b.processed_dir + "/data.pth") + assert a.processed_dir != b.processed_dir + + +class TestFromList: + def test_from_atoms(self, molecules): + dataset = ASEDataset.from_atoms_list( + molecules, extra_fixed_fields={"r_max": 4.5} + ) + assert len(dataset) == len(molecules) + for i, mol in enumerate(molecules): + assert np.array_equal( + mol.get_atomic_numbers(), dataset[i].to_ase().get_atomic_numbers() + ) + + +def generate_E(N, mean, std): + torch.manual_seed(0) + ref_mean = torch.rand((N.shape[1])) * mean + t_mean = torch.ones((N.shape[0], 1)) * ref_mean.reshape([1, -1]) + ref_std = torch.rand((N.shape[1])) * std + t_std = torch.ones((N.shape[0], 1)) * ref_std.reshape([1, -1]) + E = torch.normal(t_mean, t_std) + return ref_mean, ref_std, (N * E).sum(axis=-1) + + +def set_up_transformer(npz_dataset, full_rank, fixed_field): + + if full_rank: + + if fixed_field: + return + + unique = torch.unique(npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY]) + npz_dataset.transform = TypeMapper( + chemical_symbol_to_type={ + chemical_symbols[n]: i for i, n in enumerate(unique) + } + ) + else: + ntype = 2 + + # let all atoms to be the same type distribution + num_nodes = npz_dataset.data[AtomicDataDict.BATCH_KEY].shape[0] + if fixed_field: + del npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY] + del npz_dataset.data.__slices__[ + AtomicDataDict.ATOMIC_NUMBERS_KEY + ] # remove batch metadata for the key + new_n = torch.ones(NATOMS, dtype=torch.int64) + new_n[0] += ntype + npz_dataset.fixed_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = new_n + else: + npz_dataset.fixed_fields.pop(AtomicDataDict.ATOMIC_NUMBERS_KEY, None) + new_n = torch.ones(num_nodes, dtype=torch.int64) + new_n[::NATOMS] += ntype + npz_dataset.data[AtomicDataDict.ATOMIC_NUMBERS_KEY] = new_n + + # set up the transformer + npz_dataset.transform = TypeMapper( + chemical_symbol_to_type={ + chemical_symbols[n]: i for i, n in enumerate([1, ntype + 1]) + } + ) + return npz_dataset diff --git a/tests/model/test_eng_force.py b/tests/unit/model/test_eng_force.py similarity index 83% rename from tests/model/test_eng_force.py rename to tests/unit/model/test_eng_force.py index a8db666e..05fca69a 100644 --- a/tests/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -3,26 +3,27 @@ import logging import tempfile import torch -from os.path import isfile import numpy as np from e3nn import o3 from e3nn.util.jit import script -from nequip.data import AtomicDataDict, AtomicData -from nequip.models import EnergyModel, ForceModel +from nequip.data import AtomicDataDict, AtomicData, Collater +from nequip.data.transforms import TypeMapper +from nequip.model import model_from_config, uniform_initialize_FCs from nequip.nn import GraphModuleMixin, AtomwiseLinear -from nequip.utils.initialization import uniform_initialize_equivariant_linears from nequip.utils.test import assert_AtomicData_equivariant logging.basicConfig(level=logging.DEBUG) -ALLOWED_SPECIES = [1, 6, 8] +COMMON_CONFIG = { + "num_types": 3, + "types_names": ["H", "C", "O"], +} r_max = 3 minimal_config1 = dict( - allowed_species=ALLOWED_SPECIES, irreps_edge_sh="0e + 1o", r_max=4, feature_irreps_hidden="4x0e + 4x1o", @@ -31,17 +32,17 @@ num_basis=8, PolynomialCutoff_p=6, nonlinearity_type="norm", + **COMMON_CONFIG ) minimal_config2 = dict( - allowed_species=ALLOWED_SPECIES, irreps_edge_sh="0e + 1o", r_max=4, chemical_embedding_irreps_out="8x0e + 8x0o + 8x1e + 8x1o", irreps_mid_output_block="2x0e", feature_irreps_hidden="4x0e + 4x1o", + **COMMON_CONFIG ) minimal_config3 = dict( - allowed_species=ALLOWED_SPECIES, irreps_edge_sh="0e + 1o", r_max=4, feature_irreps_hidden="4x0e + 4x1o", @@ -50,9 +51,9 @@ num_basis=8, PolynomialCutoff_p=6, nonlinearity_type="gate", + **COMMON_CONFIG ) minimal_config4 = dict( - allowed_species=ALLOWED_SPECIES, irreps_edge_sh="0e + 1o + 2e", r_max=4, feature_irreps_hidden="2x0e + 2x1o + 2x2e", @@ -64,6 +65,7 @@ # test custom nonlinearities nonlinearity_scalars={"e": "silu", "o": "tanh"}, nonlinearity_gates={"e": "silu", "o": "abs"}, + **COMMON_CONFIG ) @@ -77,15 +79,23 @@ def config(request): @pytest.fixture( params=[ - (ForceModel, AtomicDataDict.FORCE_KEY), - (EnergyModel, AtomicDataDict.TOTAL_ENERGY_KEY), + ( + ["EnergyModel", "ForceOutput"], + AtomicDataDict.FORCE_KEY, + ), + ( + ["EnergyModel"], + AtomicDataDict.TOTAL_ENERGY_KEY, + ), ] ) def model(request, config): torch.manual_seed(0) np.random.seed(0) builder, out_field = request.param - return builder(**config), out_field + config = config.copy() + config["model_builders"] = builder + return model_from_config(config), out_field @pytest.fixture( @@ -116,8 +126,7 @@ def test_weight_init(self, model, atomic_batch, device): out_orig = instance(data)[out_field] - with torch.no_grad(): - instance.apply(uniform_initialize_equivariant_linears) + instance = uniform_initialize_FCs(instance, initialize=True) out_unif = instance(data)[out_field] assert not torch.allclose(out_orig, out_unif) @@ -129,7 +138,9 @@ def test_jit(self, model, atomic_batch, device): model_script = script(instance) assert torch.allclose( - instance(data)[out_field], model_script(data)[out_field], atol=1e-7 + instance(data)[out_field], + model_script(data)[out_field], + atol=1e-6, ) # - Try saving, loading in another process, and running - @@ -156,7 +167,9 @@ def test_jit(self, model, atomic_batch, device): ) def test_submods(self): - model = EnergyModel(**minimal_config2) + config = minimal_config2.copy() + config["model_builders"] = ["EnergyModel"] + model = model_from_config(config=config, initialize=True) assert isinstance(model.chemical_embedding, AtomwiseLinear) true_irreps = o3.Irreps(minimal_config2["chemical_embedding_irreps_out"]) assert ( @@ -177,18 +190,15 @@ def test_forward(self, model, atomic_batch, device): assert out_field in output def test_saveload(self, model): - with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: - instance, _ = model - torch.save(instance, tmp.name) - assert isfile(tmp.name) - - new_model = torch.load(tmp.name) - assert isinstance(new_model, type(instance)) + # TO DO, test load/save state_dict + pass class TestGradient: def test_numeric_gradient(self, config, atomic_batch, device, float_tolerance): - model = ForceModel(**config) + config = config.copy() + config["model_builders"] = ["EnergyModel", "ForceOutput"] + model = model_from_config(config=config, initialize=True) model.to(device) data = atomic_batch.to(device) output = model(AtomicData.to_AtomicDataDict(data)) @@ -218,9 +228,12 @@ def test_numeric_gradient(self, config, atomic_batch, device, float_tolerance): class TestAutoGradient: def test_cross_frame_grad(self, config, nequip_dataset): - batch = nequip_dataset.data + c = Collater.for_dataset(nequip_dataset) + batch = c([nequip_dataset[i] for i in range(len(nequip_dataset))]) device = "cpu" - energy_model = EnergyModel(**config) + config = config.copy() + config["model_builders"] = ["EnergyModel"] + energy_model = model_from_config(config=config, initialize=True) energy_model.to(device) data = AtomicData.to_AtomicDataDict(batch.to(device)) data[AtomicDataDict.POSITIONS_KEY].requires_grad = True @@ -251,7 +264,7 @@ def test_forward(self, model, atomic_batch, device): class TestCutoff: def test_large_separation(self, model, config, molecules): - atol = {torch.float32: 1e-6, torch.float64: 1e-10}[torch.get_default_dtype()] + atol = {torch.float32: 1e-4, torch.float64: 1e-10}[torch.get_default_dtype()] instance, _ = model r_max = config["r_max"] atoms1 = molecules[0].copy() @@ -260,9 +273,10 @@ def test_large_separation(self, model, config, molecules): atoms2.positions += 40.0 + np.random.randn(3) atoms_both = atoms1.copy() atoms_both.extend(atoms2) - data1 = AtomicData.from_ase(atoms1, r_max=r_max) - data2 = AtomicData.from_ase(atoms2, r_max=r_max) - data_both = AtomicData.from_ase(atoms_both, r_max=r_max) + tm = TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}) + data1 = tm(AtomicData.from_ase(atoms1, r_max=r_max)) + data2 = tm(AtomicData.from_ase(atoms2, r_max=r_max)) + data_both = tm(AtomicData.from_ase(atoms_both, r_max=r_max)) assert ( data_both[AtomicDataDict.EDGE_INDEX_KEY].shape[1] == data1[AtomicDataDict.EDGE_INDEX_KEY].shape[1] @@ -284,7 +298,7 @@ def test_large_separation(self, model, config, molecules): atoms3 = atoms2.copy() atoms3.positions += np.random.randn(3) atoms_both2.extend(atoms3) - data_both2 = AtomicData.from_ase(atoms_both2, r_max=r_max) + data_both2 = tm(AtomicData.from_ase(atoms_both2, r_max=r_max)) out_both2 = instance(AtomicData.to_AtomicDataDict(data_both2)) assert torch.allclose( out_both2[AtomicDataDict.TOTAL_ENERGY_KEY], @@ -298,12 +312,14 @@ def test_large_separation(self, model, config, molecules): ) def test_embedding_cutoff(self, config): - instance = EnergyModel(**config) + config = config.copy() + config["model_builders"] = ["EnergyModel"] + instance = model_from_config(config=config, initialize=True) r_max = config["r_max"] # make a synthetic three atom example data = AtomicData( - atomic_numbers=np.random.choice(ALLOWED_SPECIES, size=3), + atom_types=np.random.choice([0, 1, 2], size=3), pos=np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), edge_index=np.array([[0, 1, 0, 2], [1, 0, 2, 0]]), ) diff --git a/tests/nn/test_atomic.py b/tests/unit/nn/test_atomic.py similarity index 72% rename from tests/nn/test_atomic.py rename to tests/unit/nn/test_atomic.py index 16c593cc..3b00f7ae 100644 --- a/tests/nn/test_atomic.py +++ b/tests/unit/nn/test_atomic.py @@ -11,7 +11,7 @@ from nequip.nn.embedding import ( OneHotAtomEncoding, ) -from torch_geometric.data import Batch +from nequip.utils.torch_geometric import Batch @pytest.fixture(scope="class", params=[0, 1, 2]) @@ -19,7 +19,11 @@ def model(float_tolerance, request): zero_species = request.param shifts = [3, 5, 7] shifts[zero_species] = 0 - params = dict(allowed_species=[1, 6, 8], total_shift=1.0, shifts=shifts) + params = dict( + num_types=3, + total_shift=1.0, + shifts=shifts, + ) return SequentialGraphNetwork.from_parameters( shared_params=params, layers={ @@ -30,7 +34,13 @@ def model(float_tolerance, request): ), "shift": ( PerSpeciesScaleShift, - dict(field="e", out_field="shifted"), + dict( + field="e", + out_field="shifted", + scales=1.0, + shifts=0.0, + arguments_in_dataset_units=False, + ), ), "sum": ( AtomwiseReduce, @@ -44,13 +54,11 @@ def model(float_tolerance, request): def batches(float_tolerance, nequip_dataset): b = [] for idx in [[0], [1], [0, 1]]: - b += [ - AtomicData.to_AtomicDataDict(Batch.from_data_list(nequip_dataset.data[idx])) - ] + b += [AtomicData.to_AtomicDataDict(Batch.from_data_list(nequip_dataset[idx]))] return b -def test_per_specie_shift(nequip_dataset, batches, model): +def test_per_species_shift(nequip_dataset, batches, model): batch1, batch2, batch12 = batches result1 = model(batch1) result2 = model(batch2) diff --git a/tests/nn/test_embed.py b/tests/unit/nn/test_embed.py similarity index 89% rename from tests/nn/test_embed.py rename to tests/unit/nn/test_embed.py index a137e9d4..04fd5766 100644 --- a/tests/nn/test_embed.py +++ b/tests/unit/nn/test_embed.py @@ -1,8 +1,5 @@ -import torch - from e3nn.util.test import assert_auto_jitable -from nequip.data import AtomicDataDict from nequip.utils.test import assert_AtomicData_equivariant from nequip.nn.radial_basis import BesselBasis from nequip.nn.cutoffs import PolynomialCutoff @@ -16,7 +13,7 @@ def test_onehot(CH3CHO): _, data = CH3CHO oh = OneHotAtomEncoding( - allowed_species=torch.unique(data[AtomicDataDict.ATOMIC_NUMBERS_KEY]), + num_types=3, irreps_in=data.irreps, ) assert_auto_jitable(oh) diff --git a/tests/nn/test_rescale.py b/tests/unit/nn/test_rescale.py similarity index 81% rename from tests/nn/test_rescale.py rename to tests/unit/nn/test_rescale.py index d888e16c..22f4652c 100644 --- a/tests/nn/test_rescale.py +++ b/tests/unit/nn/test_rescale.py @@ -1,6 +1,12 @@ import pytest -import contextlib +import sys + +if sys.version_info[1] >= 7: + import contextlib +else: + # has backport of nullcontext + import contextlib2 as contextlib import torch @@ -14,26 +20,26 @@ @pytest.mark.parametrize("scale_by", [0.77, 1.0, None]) @pytest.mark.parametrize("shift_by", [0.0, 0.4443, None]) -@pytest.mark.parametrize("trainable_global_rescale_scale", [True, False]) -@pytest.mark.parametrize("trainable_global_rescale_shift", [True, False]) +@pytest.mark.parametrize("scale_trainable", [True, False]) +@pytest.mark.parametrize("shift_trainable", [True, False]) def test_rescale( CH3CHO, scale_by, shift_by, - trainable_global_rescale_scale, - trainable_global_rescale_shift, + scale_trainable, + shift_trainable, ): _, data = CH3CHO oh = OneHotAtomEncoding( - allowed_species=torch.unique(data[AtomicDataDict.ATOMIC_NUMBERS_KEY]), + num_types=3, irreps_in=data.irreps, ) # some combinations are illegal and should raise build_with = contextlib.nullcontext() - if scale_by is None and trainable_global_rescale_scale: + if scale_by is None and scale_trainable: build_with = pytest.raises(ValueError) - elif shift_by is None and trainable_global_rescale_shift: + elif shift_by is None and shift_trainable: build_with = pytest.raises(ValueError) rescale = None @@ -44,8 +50,8 @@ def test_rescale( shift_keys=AtomicDataDict.NODE_ATTRS_KEY, scale_by=scale_by, shift_by=shift_by, - trainable_global_rescale_scale=trainable_global_rescale_scale, - trainable_global_rescale_shift=trainable_global_rescale_shift, + scale_trainable=scale_trainable, + shift_trainable=shift_trainable, ) if rescale is None: diff --git a/tests/nn/test_sequential.py b/tests/unit/nn/test_sequential.py similarity index 76% rename from tests/nn/test_sequential.py rename to tests/unit/nn/test_sequential.py index a081ff34..4b5f257b 100644 --- a/tests/nn/test_sequential.py +++ b/tests/unit/nn/test_sequential.py @@ -1,3 +1,4 @@ +import pytest import torch from nequip.data import AtomicDataDict @@ -7,21 +8,21 @@ def test_basic(): sgn = SequentialGraphNetwork.from_parameters( - shared_params={"num_species": 3}, + shared_params={"num_types": 3}, layers={"one_hot": OneHotAtomEncoding, "linear": AtomwiseLinear}, ) sgn( { AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), - AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + AtomicDataDict.ATOM_TYPE_KEY: torch.LongTensor([0, 0, 1, 2, 0]), } ) def test_append(): sgn = SequentialGraphNetwork.from_parameters( - shared_params={"num_species": 3}, layers={"one_hot": OneHotAtomEncoding} + shared_params={"num_types": 3}, layers={"one_hot": OneHotAtomEncoding} ) sgn.append_from_parameters( shared_params={"out_field": AtomicDataDict.NODE_FEATURES_KEY}, @@ -34,19 +35,21 @@ def test_append(): { AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), - AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + AtomicDataDict.ATOM_TYPE_KEY: torch.LongTensor([0, 0, 1, 2, 0]), } ) assert out["thing"].shape == out[AtomicDataDict.NODE_FEATURES_KEY].shape -def test_insert(): +@pytest.mark.parametrize("mode", {"before", "after"}) +def test_insert(mode): sgn = SequentialGraphNetwork.from_parameters( - shared_params={"num_species": 3}, + shared_params={"num_types": 3}, layers={"one_hot": OneHotAtomEncoding, "lin2": AtomwiseLinear}, ) + keys = {"before": "lin2", "after": "one_hot"} sgn.insert_from_parameters( - after="one_hot", + **{mode: keys[mode]}, shared_params={"out_field": "thing"}, name="lin1", builder=AtomwiseLinear, @@ -61,7 +64,7 @@ def test_insert(): { AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), - AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + AtomicDataDict.ATOM_TYPE_KEY: torch.LongTensor([0, 0, 1, 2, 0]), } ) assert AtomicDataDict.NODE_FEATURES_KEY in out diff --git a/tests/nn/test_utils.py b/tests/unit/nn/test_utils.py similarity index 87% rename from tests/nn/test_utils.py rename to tests/unit/nn/test_utils.py index ee645924..8f62e868 100644 --- a/tests/nn/test_utils.py +++ b/tests/unit/nn/test_utils.py @@ -7,7 +7,7 @@ def test_basic(): sgn = SequentialGraphNetwork.from_parameters( - shared_params={"num_species": 4}, + shared_params={"num_types": 4}, layers={ "one_hot": OneHotAtomEncoding, "save": ( @@ -21,7 +21,7 @@ def test_basic(): { AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), - AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + AtomicDataDict.ATOM_TYPE_KEY: torch.LongTensor([0, 0, 1, 2, 0]), } ) saved = out["saved"] diff --git a/nequip/dynamics/__init__.py b/tests/unit/trainer/__init__.py similarity index 100% rename from nequip/dynamics/__init__.py rename to tests/unit/trainer/__init__.py diff --git a/tests/trainer/test_early_stopping.py b/tests/unit/trainer/test_early_stopping.py similarity index 100% rename from tests/trainer/test_early_stopping.py rename to tests/unit/trainer/test_early_stopping.py diff --git a/tests/unit/trainer/test_loss.py b/tests/unit/trainer/test_loss.py new file mode 100644 index 00000000..101d28fb --- /dev/null +++ b/tests/unit/trainer/test_loss.py @@ -0,0 +1,223 @@ +import pytest +import torch +from nequip.data import AtomicDataDict +from nequip.train import Loss + +# all the config to test init +# only the last one will be used to test the loss and mae +dicts = ( + {AtomicDataDict.TOTAL_ENERGY_KEY: (3.0, "MSELoss")}, + {AtomicDataDict.TOTAL_ENERGY_KEY: "MSELoss"}, + [AtomicDataDict.FORCE_KEY, AtomicDataDict.TOTAL_ENERGY_KEY], + {AtomicDataDict.FORCE_KEY: (1.0, "PerSpeciesMSELoss")}, + {AtomicDataDict.FORCE_KEY: (1.0), "k": (1.0, torch.nn.L1Loss())}, + AtomicDataDict.TOTAL_ENERGY_KEY, + { + AtomicDataDict.TOTAL_ENERGY_KEY: (3.0, "L1Loss"), + AtomicDataDict.FORCE_KEY: (1.0), + "k": 1.0, + }, +) +nan_dict = { + AtomicDataDict.TOTAL_ENERGY_KEY: (3.0, "L1Loss", {"ignore_nan": True}), + AtomicDataDict.FORCE_KEY: (1.0, "MSELoss", {"ignore_nan": True}), + "k": 1.0, +} + + +class TestInit: + @pytest.mark.parametrize("loss", dicts, indirect=True) + def test_init(self, loss): + + assert len(loss.funcs) == len(loss.coeffs) + for key, value in loss.coeffs.items(): + assert isinstance(value, torch.Tensor) + + +class TestLoss: + @pytest.mark.parametrize("loss", dicts[-2:], indirect=True) + def test_loss(self, loss, data): + + pred, ref = data + + loss_value = loss(pred, ref) + + loss_value, contrib = loss_value + assert len(contrib) > 0 + assert isinstance(contrib, dict) + for key, value in contrib.items(): + assert isinstance(value, torch.Tensor) + + assert isinstance(loss_value, torch.Tensor) + + def test_per_species(self, data): + + pred, ref = data + + config = {AtomicDataDict.FORCE_KEY: (1.0, "PerSpeciesMSELoss")} + loss = Loss(coeffs=config) + + l, contb = loss(pred, ref) + + # first half data are specie 1 + loss_ref_1 = torch.square( + pred[AtomicDataDict.FORCE_KEY][:5] - ref[AtomicDataDict.FORCE_KEY][:5] + ).mean() + loss_ref_0 = torch.square( + pred[AtomicDataDict.FORCE_KEY][5:] - ref[AtomicDataDict.FORCE_KEY][5:] + ).mean() + + assert torch.isclose( + contb[AtomicDataDict.FORCE_KEY], (loss_ref_0 + loss_ref_1) / 2.0 + ) + + def test_per_atom(self, data): + + pred, ref = data + + config = {AtomicDataDict.TOTAL_ENERGY_KEY: (1.0, "PerAtomMSELoss")} + loss = Loss(coeffs=config) + + l, contb = loss(pred, ref) + + # first graph + loss_ref_1 = torch.square( + ( + pred[AtomicDataDict.TOTAL_ENERGY_KEY][0] + - ref[AtomicDataDict.TOTAL_ENERGY_KEY][0] + ) + / 3.0 + ) + # second graph + loss_ref_2 = torch.square( + ( + pred[AtomicDataDict.TOTAL_ENERGY_KEY][1] + - ref[AtomicDataDict.TOTAL_ENERGY_KEY][1] + ) + / 7.0 + ) + loss_ref = (loss_ref_1 + loss_ref_2) / 2.0 + + assert torch.isclose(l, loss_ref) + + +class TestNaN: + def test_loss(self, data_w_NaN): + + pred, ref, wo_nan_pred, wo_nan_ref = data_w_NaN + + loss = Loss(coeffs=nan_dict) + l, contb = loss(pred, ref) + l_wo_nan, contb_wo_nan = loss(wo_nan_pred, wo_nan_ref) + + assert torch.isclose(l_wo_nan, l) + for k in contb: + assert torch.isclose(contb_wo_nan[k], contb[k]) + + def test_per_species(self, data_w_NaN): + + pred, ref, wo_nan_pred, wo_nan_ref = data_w_NaN + + config = { + AtomicDataDict.FORCE_KEY: (1.0, "PerSpeciesMSELoss", {"ignore_nan": True}) + } + loss = Loss(coeffs=config) + + l, contb = loss(pred, ref) + l_wo_nan, contb_wo_nan = loss(wo_nan_pred, wo_nan_ref) + + assert torch.isclose(l_wo_nan, l) + for k in contb: + assert torch.isclose(contb_wo_nan[k], contb[k]) + + def test_per_atom(self, data_w_NaN): + + pred, ref, wo_nan_pred, wo_nan_ref = data_w_NaN + + config = { + AtomicDataDict.TOTAL_ENERGY_KEY: ( + 1.0, + "PerAtomMSELoss", + {"ignore_nan": True}, + ) + } + loss = Loss(coeffs=config) + l_wo_nan, contb_wo_nan = loss(wo_nan_pred, wo_nan_ref) + + l, contb = loss(pred, ref) + + assert torch.isclose(l_wo_nan, l) + for k in contb: + assert torch.isclose(contb_wo_nan[k], contb[k]) + + # first half data are specie 1 + loss_ref = torch.square( + ( + pred[AtomicDataDict.TOTAL_ENERGY_KEY][0] + - ref[AtomicDataDict.TOTAL_ENERGY_KEY][0] + ) + / 3.0 + ) + + assert torch.isclose(l, loss_ref) + + +@pytest.fixture(scope="class") +def loss(request): + """""" + d = request.param + instance = Loss(coeffs=d) + yield instance + + +@pytest.fixture(scope="module") +def data(float_tolerance): + """""" + pred = { + AtomicDataDict.FORCE_KEY: torch.rand(10, 3), + AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), + "k": torch.rand((2, 1)), + AtomicDataDict.ATOM_TYPE_KEY: torch.as_tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + } + ref = { + AtomicDataDict.BATCH_KEY: torch.tensor( + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int + ), + AtomicDataDict.FORCE_KEY: torch.rand(10, 3), + AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), + "k": torch.rand((2, 1)), + AtomicDataDict.ATOM_TYPE_KEY: torch.as_tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + } + yield pred, ref + + +@pytest.fixture(scope="module") +def data_w_NaN(float_tolerance, data): + """""" + _pred, _ref = data + + pred = {k: torch.clone(v) for k, v in _pred.items()} + ref = {k: torch.clone(v) for k, v in _ref.items()} + ref[AtomicDataDict.FORCE_KEY][-1] = float("nan") + ref[AtomicDataDict.FORCE_KEY][0] = float("nan") + ref[AtomicDataDict.TOTAL_ENERGY_KEY][1] = float("nan") + + wo_nan_pred = {k: torch.clone(v) for k, v in _pred.items()} + wo_nan_ref = {k: torch.clone(v) for k, v in _ref.items()} + wo_nan_ref[AtomicDataDict.FORCE_KEY] = wo_nan_ref[AtomicDataDict.FORCE_KEY][1:-1] + wo_nan_ref[AtomicDataDict.TOTAL_ENERGY_KEY] = wo_nan_ref[ + AtomicDataDict.TOTAL_ENERGY_KEY + ][:1] + wo_nan_ref[AtomicDataDict.ATOM_TYPE_KEY] = wo_nan_ref[AtomicDataDict.ATOM_TYPE_KEY][ + 1:-1 + ] + wo_nan_ref[AtomicDataDict.BATCH_KEY] = torch.tensor([0, 0, 0], dtype=torch.int) + wo_nan_pred[AtomicDataDict.FORCE_KEY] = wo_nan_pred[AtomicDataDict.FORCE_KEY][1:-1] + wo_nan_pred[AtomicDataDict.ATOM_TYPE_KEY] = wo_nan_pred[ + AtomicDataDict.ATOM_TYPE_KEY + ][1:-1] + wo_nan_pred[AtomicDataDict.TOTAL_ENERGY_KEY] = wo_nan_pred[ + AtomicDataDict.TOTAL_ENERGY_KEY + ][:1] + + yield pred, ref, wo_nan_pred, wo_nan_ref diff --git a/tests/trainer/test_metrics.py b/tests/unit/trainer/test_metrics.py similarity index 77% rename from tests/trainer/test_metrics.py rename to tests/unit/trainer/test_metrics.py index d0486c1f..17983e70 100644 --- a/tests/trainer/test_metrics.py +++ b/tests/unit/trainer/test_metrics.py @@ -1,4 +1,4 @@ -import inspect +# flake8: noqa import pytest import torch from nequip.data import AtomicDataDict @@ -41,7 +41,7 @@ def test_run(self, metrics, data): class TestWeight: @pytest.mark.parametrize("per_comp", [True, False]) - def test_per_specie(self, data, per_comp): + def test_per_species(self, data, per_comp): pred, ref = data @@ -65,7 +65,7 @@ def test_per_specie(self, data, per_comp): w_contb = w_loss(pred, ref) contb = loss(pred, ref) - # first half data are specie 1 + # first half data are species 1 loss_ref_0 = torch.square(pred["forces"][5:] - ref["forces"][5:]) loss_ref_1 = torch.square(pred["forces"][:5] - ref["forces"][:5]) if per_comp: @@ -81,17 +81,22 @@ def test_per_specie(self, data, per_comp): assert k in ["forces"] # mae should be the same cause # of type 0 == # of type 1 + dim = {"dim": 3, "report_per_component": per_comp} + hash_str_ref = Metrics.hash_component((AtomicDataDict.FORCE_KEY, "mae", dim)) + dim["PerSpecies"] = True + hash_str = Metrics.hash_component((AtomicDataDict.FORCE_KEY, "mae", dim)) assert torch.allclose( - w_contb[("forces", "mae")].mean(dim=0), contb[("forces", "mae")] + w_contb[("forces", hash_str)].mean(dim=0), contb[("forces", hash_str_ref)] ) - assert torch.allclose(w_contb[("forces", "rmse")][0], loss_ref_0) - assert torch.allclose(w_contb[("forces", "rmse")][1], loss_ref_1) + hash_str = Metrics.hash_component((AtomicDataDict.FORCE_KEY, "rmse", dim)) + assert torch.allclose(w_contb[("forces", hash_str)][0], loss_ref_0) + assert torch.allclose(w_contb[("forces", hash_str)][1], loss_ref_1) @pytest.fixture(scope="class", params=metrics_tests) def metrics(request): """""" - coeffs = request.param + coeffs = request.param # noqa instance = Metrics(components=request.param) yield instance diff --git a/tests/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py similarity index 88% rename from tests/trainer/test_trainer.py rename to tests/unit/trainer/test_trainer.py index 992a97ac..c8169fda 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/unit/trainer/test_trainer.py @@ -10,16 +10,23 @@ import torch from torch.nn import Linear +from nequip.model import model_from_config from nequip.data import AtomicDataDict from nequip.train.trainer import Trainer from nequip.utils.savenload import load_file from nequip.nn import GraphModuleMixin + +def dummy_builder(): + return DummyNet(3) + + # set up two config to test DEBUG = False NATOMS = 3 NFRAMES = 10 minimal_config = dict( + run_name="test", n_train=4, n_val=4, exclude_keys=["sth"], @@ -28,29 +35,26 @@ learning_rate=1e-2, optimizer="Adam", seed=0, - restart=False, + append=False, T_0=50, T_mult=2, loss_coeffs={"forces": 2}, early_stopping_patiences={"loss": 50}, early_stopping_lower_bounds={"LR": 1e-10}, + model_builders=[dummy_builder], ) -configs_to_test = [dict(), minimal_config] -loop_config = pytest.mark.parametrize("trainer", configs_to_test, indirect=True) -one_config_test = pytest.mark.parametrize("trainer", [minimal_config], indirect=True) @pytest.fixture(scope="class") -def trainer(request): +def trainer(): """ Generate a class instance with minimal configurations """ - params = request.param - - model = DummyNet(3) + minimal_config["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] + model = model_from_config(minimal_config) with tempfile.TemporaryDirectory(prefix="output") as path: - params["root"] = path - c = Trainer(model=model, **params) + minimal_config["root"] = path + c = Trainer(model=model, **minimal_config) yield c @@ -59,37 +63,27 @@ class TestTrainerSetUp: test initialization """ - @one_config_test def test_init(self, trainer): assert isinstance(trainer, Trainer) class TestDuplicateError: def test_duplicate_id_2(self, temp_data): + """ + check whether the Output class can automatically + insert timestr when a workdir has pre-existed + """ minimal_config["root"] = temp_data model = DummyNet(3) - c1 = Trainer(model=model, **minimal_config) - logfile1 = c1.logfile + Trainer(model=model, **minimal_config) - c2 = Trainer(model=model, **minimal_config) - logfile2 = c2.logfile - - assert c1.root == c2.root - assert c1.workdir != c2.workdir - assert c1.logfile.endswith("log") - assert c2.logfile.endswith("log") - - -class TestInit: - @one_config_test - def test_init_model(self, trainer): - trainer.init_model() + with pytest.raises(RuntimeError): + Trainer(model=model, **minimal_config) class TestSaveLoad: - @loop_config @pytest.mark.parametrize("state_dict", [True, False]) @pytest.mark.parametrize("training_progress", [True, False]) def test_as_dict(self, trainer, state_dict, training_progress): @@ -103,7 +97,6 @@ def test_as_dict(self, trainer, state_dict, training_progress): assert training_progress == ("progress" in dictionary) assert len(dictionary["optimizer_kwargs"]) > 1 - @loop_config @pytest.mark.parametrize("format, suffix", [("torch", "pth"), ("yaml", "yaml")]) def test_save(self, trainer, format, suffix): @@ -113,8 +106,7 @@ def test_save(self, trainer, format, suffix): assert isfile(file_name), "fail to save to file" assert suffix in file_name - @loop_config - @pytest.mark.parametrize("append", [True, False]) + @pytest.mark.parametrize("append", [True]) # , False]) def test_from_dict(self, trainer, append): # torch.save(trainer.model, trainer.best_model_path) @@ -132,11 +124,9 @@ def test_from_dict(self, trainer, append): ]: v1 = getattr(trainer, key, None) v2 = getattr(trainer1, key, None) - print(key, v1, v2) assert append == (v1 == v2) - @loop_config - @pytest.mark.parametrize("append", [True, False]) + @pytest.mark.parametrize("append", [True]) # , False]) def test_from_file(self, trainer, append): format = "torch" @@ -160,7 +150,6 @@ def test_from_file(self, trainer, append): ]: v1 = getattr(trainer, key, None) v2 = getattr(trainer1, key, None) - print(key, v1, v2) assert append == (v1 == v2) for iparam, group1 in enumerate(trainer.optim.param_groups): @@ -174,7 +163,6 @@ def test_from_file(self, trainer, append): class TestData: - @one_config_test @pytest.mark.parametrize("mode", ["random", "sequential"]) def test_split(self, trainer, nequip_dataset, mode): @@ -185,7 +173,6 @@ def test_split(self, trainer, nequip_dataset, mode): class TestTrain: - @one_config_test def test_train(self, trainer, nequip_dataset): v0 = get_param(trainer.model) @@ -196,7 +183,6 @@ def test_train(self, trainer, nequip_dataset): assert not np.allclose(v0, v1), "fail to train parameters" assert isfile(trainer.last_model_path), "fail to save best model" - @one_config_test def test_load_w_revision(self, trainer): with tempfile.TemporaryDirectory() as folder: @@ -216,12 +202,10 @@ def test_load_w_revision(self, trainer): assert trainer1.iepoch == trainer.iepoch assert trainer1.max_epochs == minimal_config["max_epochs"] * 2 - @one_config_test def test_restart_training(self, trainer, nequip_dataset): - - model = trainer.model - device = trainer.device - optimizer = trainer.optim + _ = trainer.model + _ = trainer.device + _ = trainer.optim trainer.set_dataset(nequip_dataset) trainer.train() @@ -318,7 +302,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: class DummyScale(torch.nn.Module): - """ mimic the rescale model""" + """mimic the rescale model""" def __init__(self, key, scale, shift) -> None: super().__init__() @@ -360,6 +344,7 @@ def scale_train(nequip_dataset): batch_size=2, loss_coeffs=AtomicDataDict.FORCE_KEY, root=path, + run_name="test_scale", ) trainer.set_dataset(nequip_dataset) trainer.train() diff --git a/tests/utils/test_batch_ops.py b/tests/unit/utils/test_batch_ops.py similarity index 97% rename from tests/utils/test_batch_ops.py rename to tests/unit/utils/test_batch_ops.py index 5111e855..8788f86e 100644 --- a/tests/utils/test_batch_ops.py +++ b/tests/unit/utils/test_batch_ops.py @@ -12,7 +12,7 @@ def test_bincount(n_class, n_batch, n_max_nodes): n_nodes = torch.randint(1, n_max_nodes + 1, size=(n_batch,)) total_n_nodes = n_nodes.sum() input = torch.randint(0, n_class, size=(total_n_nodes,)) - batch = torch.LongTensor(sum(([i] * n for i, n in enumerate(n_nodes)), start=[])) + batch = torch.LongTensor(sum(([i] * n for i, n in enumerate(n_nodes)), [])) truth = [] for b in range(n_batch): diff --git a/tests/utils/test_config.py b/tests/unit/utils/test_config.py similarity index 98% rename from tests/utils/test_config.py rename to tests/unit/utils/test_config.py index f2992afd..0cd3151e 100644 --- a/tests/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -1,13 +1,8 @@ """ Config tests """ - - -import numpy as np import pytest -import torch - from os import remove from nequip.utils import Config diff --git a/tests/utils/test_instantiate.py b/tests/unit/utils/test_instantiate.py similarity index 98% rename from tests/utils/test_instantiate.py rename to tests/unit/utils/test_instantiate.py index 1762d51e..1009bd5f 100644 --- a/tests/utils/test_instantiate.py +++ b/tests/unit/utils/test_instantiate.py @@ -155,8 +155,8 @@ def __init__(self, cls_b, cls_b_kwargs): class C: - def __init__(self, cls_c, cls_c_kwargs): - self.c_obj = c_cls(**c_cls_kwargs) + def __init__(self, cls_c, cls_c_kwargs): # noqa + self.c_obj = c_cls(**c_cls_kwargs) # noqa def test_deep_nests(): diff --git a/tests/utils/test_output.py b/tests/unit/utils/test_output.py similarity index 82% rename from tests/utils/test_output.py rename to tests/unit/utils/test_output.py index 3035b3d0..cdc7b4ac 100644 --- a/tests/utils/test_output.py +++ b/tests/unit/utils/test_output.py @@ -1,14 +1,9 @@ """ Config tests """ - - -import numpy as np import pytest import tempfile -import torch -from os import remove from os.path import isdir from nequip.utils.output import Output @@ -20,7 +15,7 @@ class TestInit: def test_empty_init(self, root): - output = Output(root=root) + output = Output(root=root, run_name="test") print(output.root) print(output.workdir) assert isdir(output.root) @@ -36,9 +31,8 @@ def test_empty_init(self, root): class TestReload: - @pytest.mark.parametrize("restart", [True, False]) @pytest.mark.parametrize("append", [True, False]) - def test_restart(self, restart, append): + def test_restart(self, append): pass diff --git a/tests/utils/test_tests.py b/tests/unit/utils/test_tests.py similarity index 94% rename from tests/utils/test_tests.py rename to tests/unit/utils/test_tests.py index 36f79650..d176a0e1 100644 --- a/tests/utils/test_tests.py +++ b/tests/unit/utils/test_tests.py @@ -4,13 +4,11 @@ from e3nn import o3 -from nequip.data import AtomicDataDict +from nequip.data import AtomicDataDict, register_fields, deregister_fields from nequip.nn import GraphModuleMixin from nequip.utils.test import ( assert_AtomicData_equivariant, assert_permutation_equivariant, - register_fields, - deregister_fields, ) @@ -153,8 +151,6 @@ def test_permute_register(): with pytest.raises(AssertionError): # Fails because thinks "my_edge" is invariant assert_permutation_equivariant(mod, data_in=dict(inp)) - assert_permutation_equivariant( - mod, data_in=dict(inp), extra_edge_permute_fields=["my_edge"] - ) - register_fields(edge_permute_fields=["my_edge"]) + + register_fields(edge_fields=["my_edge"]) assert_permutation_equivariant(mod, data_in=dict(inp)) diff --git a/tests/unit/utils/test_weight_init.py b/tests/unit/utils/test_weight_init.py new file mode 100644 index 00000000..e76344ba --- /dev/null +++ b/tests/unit/utils/test_weight_init.py @@ -0,0 +1,12 @@ +import pytest + +import torch + +from nequip.model._weight_init import unit_uniform_init_ + + +@pytest.mark.parametrize("init_func_", [unit_uniform_init_]) +def test_2mom(init_func_): + t = torch.empty(1000, 100) + init_func_(t) + assert (t.square().mean() - 1.0).abs() <= 0.1 diff --git a/tests/utils/test_weight_init.py b/tests/utils/test_weight_init.py deleted file mode 100644 index b5e2a013..00000000 --- a/tests/utils/test_weight_init.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest - -import torch - -from nequip.utils.initialization import unit_uniform_init_, unit_orthogonal_init_ - - -@pytest.mark.parametrize("init_func_", [unit_uniform_init_, unit_orthogonal_init_]) -def test_2mom(init_func_): - t = torch.empty(1000, 100) - init_func_(t) - assert (t.square().mean() - 1.0).abs() <= 0.1