Skip to content

Commit

Permalink
Add atomic_numbers kwarg to allow use of NNPOps for AEVs.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Oct 17, 2024
1 parent c4f40c1 commit 3f8a6b7
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 35 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ To stop the server:
emle-stop
```

## NNPOps

The ``EMLE`` model uses Atomic Environment Vectors (AEVs) for the calculation of
the electrostatic embedding energy. For performance, it is desirable to use
the optimised symmetry functions provided by the [NNPOps](https://github.com/openmm/NNPOps)
package. This requires a *static* compute graph, so needs to know the atomic
numbers for the atoms in the QM region in advance. These can be specified using
the ``EMLE_ATOMIC_NUMBERS`` environment variable, or the ``--atomic-numbers``
command-line argument when launching the server. This option should only be
used if the QM region is fixed, i.e. the atoms in the QM region do not change
each time the calculator is called.

## Backends

The embedding method relies on in vacuo energies and gradients, to which
Expand Down
41 changes: 41 additions & 0 deletions bin/emle-server
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ except:
species = None
method = os.getenv("EMLE_METHOD")
alpha_mode = os.getenv("EMLE_ALPHA_MODE")
atomic_numbers = os.getenv("EMLE_ATOMIC_NUMBERS")
mm_charges = os.getenv("EMLE_MM_CHARGES")
try:
num_clients = int(os.getenv("EMLE_NUM_CLIENTS"))
Expand Down Expand Up @@ -136,6 +137,7 @@ env = {
"species": species,
"method": method,
"alpha_mode": alpha_mode,
"atomic_numbers": atomic_numbers,
"mm_charges": mm_charges,
"num_clients": num_clients,
"backend": backend,
Expand Down Expand Up @@ -210,6 +212,13 @@ parser.add_argument(
choices=["species", "reference"],
required=False,
)
parser.add_argument(
"--atomic-numbers",
type=str,
nargs="*",
help="the atomic numbers of the atoms in the qm region",
required=False,
)
parser.add_argument(
"--mm-charges",
type=str,
Expand Down Expand Up @@ -480,6 +489,38 @@ if set_lambda_interpolate is not None:

# Handle special case formatting for environment variables.

# Validate the atomic numbers.
if args["atomic_numbers"] is not None:
# Whether we are parsing a list of atomic numbers, rather than a file.
is_list = False

if isinstance(args["atomic_numbers"], str):
# If this isn't a path to a file, try splitting on commas.
if not os.path.isfile(args["atomic_numbers"]) or not os.path.isfile(
os.path.abspath(args["atomic_numbers"])
):
try:
args["atomic_numbers"] = args["atomic_numbers"].split(",")
is_list = True
except:
raise ValueError(
"Unable to parse EMLE_ATOMIC_NUMBERS environment variable as a comma-separated list of ints"
)

# A single entry list is assumed to be the path to a file.
elif isinstance(args["atomic_numbers"], list):
if len(args["atomic_numbers"]) == 1:
args["atomic_numbers"] = args["atomic_numbers"][0]
else:
is_list = True

# Try to parse lists of atomic numbers into a list of ints.
if is_list:
try:
args["atomic_numbers"] = [int(x) for x in args["atomic_numbers"]]
except:
raise TypeError("Unable to parse atomic numbers as a list of ints")

# Validate the MM charges.
if args["mm_charges"] is None:
if method == "mm":
Expand Down
14 changes: 14 additions & 0 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
species=None,
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
backend="torchani",
external_backend=None,
plugin_path=".",
Expand Down Expand Up @@ -158,6 +159,12 @@ def __init__(
scaling factors are obtained with GPR using the values learned
for each reference environment
atomic_numbers: List[int], Tuple[int], numpy.ndarray
Atomic numbers for the QM region. This allows use of optimised AEV
symmetry functions from the NNPOps package. Only use this option if
you are using a fixed QM region, i.e. the same QM region for each
call to the calculator.
external_backend: str
The name of an external backend to use to compute in vacuo energies.
This should be a callback function formatted as 'module.function'.
Expand Down Expand Up @@ -411,6 +418,8 @@ def __init__(
self._emle = _EMLE(
model=model,
method=method,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
mm_charges=self._mm_charges,
device=self._device,
)
Expand Down Expand Up @@ -786,6 +795,7 @@ def __init__(
self._emle_mm = _EMLE(
model=model,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
method="mm",
mm_charges=self._mm_charges,
device=self._device,
Expand Down Expand Up @@ -933,12 +943,16 @@ def __init__(
self._method = self._emle._method
self._alpha_mode = self._emle._alpha_mode

if isinstance(atomic_numbers, _np.ndarray):
atomic_numbers = atomic_numbers.tolist()

# Store the settings as a dictionary.
self._settings = {
"model": None if model is None else self._model,
"species": None if species is None else self._species,
"method": self._method,
"alpha_mode": self._alpha_mode,
"atomic_numbers": None if atomic_numbers is None else atomic_numbers,
"backend": self._backend,
"external_backend": None if external_backend is None else external_backend,
"mm_charges": None if mm_charges is None else self._mm_charges.tolist(),
Expand Down
34 changes: 22 additions & 12 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

__all__ = ["ANI2xEMLE"]

import numpy as _np
import torch as _torch
import torchani as _torchani

Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(
MM charges are used for the core charge and valence charges
are set to zero.
emle_species: List[int]
emle_species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
Expand All @@ -101,9 +102,9 @@ def __init__(
scaling factors are obtained with GPR using the values learned
for each reference environment
mm_charges: numpy.ndarray
An array of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' emle_method is specified.
mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
model_index: int
The index of the ANI2x model to use. If None, then the full 8 model
Expand All @@ -115,10 +116,11 @@ def __init__(
the ANI2x model from which it derived was created using
periodic_table_index=True.
atomic_numbers: torch.Tensor (N_ATOMS,)
List of atomic numbers to use in the ANI2x model. If specified,
and NNPOps is available, then an optimised version of ANI2x will
be used.
atomic_numbers: List[float], Tuple[float], numpy.ndarray, torch.Tensor (N_ATOMS,)
Atomic numbers for the QM region. This allows use of optimised AEV
symmetry functions from the NNPOps package. Only use this option
if you are using a fixed QM region, i.e. the same QM region for each
evalulation of the module.
device: torch.device
The device on which to run the model.
Expand Down Expand Up @@ -150,6 +152,13 @@ def __init__(
dtype = _torch.get_default_dtype()

if atomic_numbers is not None:
if isinstance(atomic_numbers, _np.ndarray):
atomic_numbers = atomic_numbers.tolist()
if isinstance(atomic_numbers, (list, tuple)):
if not all(isinstance(i, int) for i in atomic_numbers):
raise ValueError("'atomic_numbers' must be a list of integers")
else:
atomic_numbers = _torch.tensor(atomic_numbers, dtype=_torch.int64)
if not isinstance(atomic_numbers, _torch.Tensor):
raise TypeError("'atomic_numbers' must be of type 'torch.Tensor'")
# Check that they are integers.
Expand All @@ -165,6 +174,7 @@ def __init__(
method=emle_method,
species=emle_species,
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
device=device,
dtype=dtype,
Expand Down Expand Up @@ -212,10 +222,10 @@ def __init__(
# Optimise the ANI2x model if atomic_numbers are specified.
if atomic_numbers is not None:
try:
species = atomic_numbers.reshape(1, *atomic_numbers.shape)
self._ani2x = _NNPOps.OptimizedTorchANI(self._ani2x, species).to(
device
)
atomic_numbers = atomic_numbers.reshape(1, *atomic_numbers.shape)
self._ani2x = _NNPOps.OptimizedTorchANI(
self._ani2x, atomic_numbers
).to(device)
except:
pass

Expand Down
61 changes: 56 additions & 5 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
method="electrostatic",
species=None,
alpha_mode="species",
atomic_numbers=None,
mm_charges=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(
should also specify the MM charges for atoms in the QM
region.
species: List[int]
species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
Expand All @@ -129,8 +130,14 @@ def __init__(
scaling factors are obtained with GPR using the values learned
for each reference environment
mm_charges: numpy.ndarray
An array of MM charges for atoms in the QM region in units of mod
atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor
Atomic numbers for the QM region. This allows use of optimised AEV
symmetry functions from the NNPOps package. Only use this option
if you are using a fixed QM region, i.e. the same QM region for each
evalulation of the module.
mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
device: torch.device
Expand Down Expand Up @@ -174,9 +181,29 @@ def __init__(
raise ValueError("'alpha_mode' must be 'species' or 'reference'")
self._alpha_mode = alpha_mode

if atomic_numbers is not None:
if isinstance(atomic_numbers, (_np.ndarray, _torch.Tensor)):
atomic_numbers = atomic_numbers.tolist()
if not isinstance(atomic_numbers, (tuple, list)):
raise TypeError(
"'atomic_numbers' must be of type 'list', 'tuple', or 'numpy.ndarray'"
)
if not all(isinstance(a, int) for a in atomic_numbers):
raise TypeError(
"All elements of 'atomic_numbers' must be of type 'int'"
)
if not all(a > 0 for a in atomic_numbers):
raise ValueError(
"All elements of 'atomic_numbers' must be greater than zero"
)

if method == "mm":
if mm_charges is None:
raise ValueError("MM charges must be provided for the 'mm' method")
if isinstance(mm_charges, (list, tuple)):
mm_charges = _np.array(mm_charges)
elif isinstance(mm_charges, _torch.Tensor):
mm_charges = mm_charges.cpu().numpy()
if not isinstance(mm_charges, _np.ndarray):
raise TypeError("'mm_charges' must be of type 'numpy.ndarray'")
if mm_charges.dtype != _np.float64:
Expand All @@ -201,8 +228,12 @@ def __init__(

# Validate the species for the custom model.
if species is not None:
if not isinstance(species, list):
raise TypeError("'species' must be of type 'list'")
if isinstance(species, (_np.ndarray, _torch.Tensor)):
species = species.tolist()
if not isinstance(species, (tuple, list)):
raise TypeError(
"'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
)
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
Expand Down Expand Up @@ -241,6 +272,26 @@ def __init__(
if create_aev_calculator:
ani2x = _torchani.models.ANI2x(periodic_table_index=True).to(device)
self._aev_computer = ani2x.aev_computer

# Optimise the AEV computer using NNPOps if available.
if atomic_numbers is not None:
if _has_nnpops:
try:
atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=device
)
atomic_numbers = atomic_numbers.reshape(
1, *atomic_numbers.shape
)
self._ani2x.aev_computer = (
_NNPOps.SymmetryFunctions.TorchANISymmetryFunctions(
self._aev_computer.species_converter,
self._aev_computer.aev_computer,
atomic_numbers,
)
)
except:
pass
else:
self._aev_computer = None

Expand Down
Loading

0 comments on commit 3f8a6b7

Please sign in to comment.