Skip to content

Commit

Permalink
Add support for custom models and species.
Browse files (browse the repository at this point in the history)
  • Loading branch information
lohedges committed Aug 9, 2024
1 parent 0ffc0c4 commit 8ce4be2
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 16 deletions.
14 changes: 13 additions & 1 deletion emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
class ANI2xEMLE(_EMLE):
def __init__(
self,
emle_model=None,
emle_species=None,
alpha_mode="species",
model_index=None,
ani2x_model=None,
Expand All @@ -60,6 +62,14 @@ def __init__(
Parameters
----------
emle_model: str
Path to a custom EMLE model parameter file. If None, then the
default model for the specified 'alpha_mode' will be used.
emle_species: List[int]
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand All @@ -69,7 +79,7 @@ def __init__(
for each reference environment
model_index: int
The index of the model to use. If None, then the full 8 model
The index of the ANI2x model to use. If None, then the full 8 model
ensemble will be used.
ani2x_model: torchani.models.ANI2x, NNPOPS.OptimizedTorchANI
Expand Down Expand Up @@ -120,6 +130,8 @@ def __init__(

# Call the base class constructor.
super().__init__(
model=emle_model,
species=emle_species,
alpha_mode=alpha_mode,
device=device,
dtype=dtype,
Expand Down
98 changes: 83 additions & 15 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,45 @@ class EMLE(_torch.nn.Module):
embedding.
"""

# Class attributes.

# Store the expected path to the resources directory.
_resource_dir = _os.path.join(
_os.path.dirname(_os.path.abspath(__file__)), "..", "resources"
)

# Create the name of the default model file for each alpha mode.
_default_models = {
"species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"),
"reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"),
}

# Store the list of supported species.
_species = [1, 6, 7, 8, 16]

def __init__(
self, alpha_mode="species", device=None, dtype=None, create_aev_calculator=True
self,
model=None,
species=None,
alpha_mode="species",
device=None,
dtype=None,
create_aev_calculator=True,
):
"""
Constructor.
Parameters
----------
model: str
Path to a custom EMLE model parameter file. If None, then the
default model for the specified 'alpha_mode' will be used.
species: List[int]
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -99,22 +129,48 @@ def __init__(
# Fetch or update the resources.
_fetch_resources()

# Store the expected path to the resources directory.
resource_dir = _os.path.join(
_os.path.dirname(_os.path.abspath(__file__)), "..", "resources"
)

if not isinstance(alpha_mode, str):
raise TypeError("'alpha_mode' must be of type 'str'")
# Convert to lower case and strip whitespace.
alpha_mode = alpha_mode.lower().replace(" ", "")
if alpha_mode not in ["species", "reference"]:
raise ValueError("'alpha_mode' must be 'species' or 'reference'")
self._alpha_mode = alpha_mode

# Choose the model based on the alpha_mode.
if alpha_mode == "species":
model = _os.path.join(resource_dir, "emle_qm7_aev.mat")
if model is not None:
if not isinstance(model, str):
msg = "'model' must be of type 'str'"
_logger.error(msg)
raise TypeError(msg)

# Convert to an absolute path.
abs_model = _os.path.abspath(model)

if not _os.path.isfile(abs_model):
msg = f"Unable to locate EMLE embedding model file: '{model}'"
_logger.error(msg)
raise IOError(msg)
self._model = abs_model

# Validate the species for the custom model.
if species is not None:
if not isinstance(species, list):
raise TypeError("'species' must be of type 'list'")
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
raise ValueError(
"All elements of 'species' must be greater than zero"
)
else:
# Use the default species.
species = self._species
else:
model = _os.path.join(resource_dir, "emle_qm7_aev_alphagpr.mat")
# Choose the model based on the alpha_mode.
model = self._default_models[alpha_mode]

# Use the default species.
species = self._species

if device is not None:
if not isinstance(device, _torch.device):
Expand Down Expand Up @@ -144,9 +200,6 @@ def __init__(
except:
raise IOError(f"Unable to load model parameters from: '{model}'")

# Set the supported species.
species = [1, 6, 7, 8, 16]

# Create a map between species and their indices.
species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64)
for i, s in enumerate(species):
Expand All @@ -161,9 +214,24 @@ def __init__(
a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device)
a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device)
if self._alpha_mode == "species":
k = _torch.tensor(params["k_Z"], dtype=dtype, device=device)
try:
k = _torch.tensor(params["k_Z"], dtype=dtype, device=device)
except:
msg = (
"Missing 'k_Z' key in model. This is required when "
"using 'species' alpha mode."
)
raise ValueError(msg)
else:
k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device)
try:
k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device)
except:
msg = (
"Missing 'sqrtk_ref' key in model. This is required when "
"using 'reference' alpha mode."
)
raise ValueError(msg)

q_total = _torch.tensor(
params.get("total_charge", 0), dtype=dtype, device=device
)
Expand Down
13 changes: 13 additions & 0 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
class MACEEMLE(_EMLE):
def __init__(
self,
emle_model=None,
emle_species=None,
alpha_mode="species",
mace_model=None,
atomic_numbers=None,
Expand All @@ -65,6 +67,15 @@ def __init__(
Parameters
----------
emle_model: str
Path to a custom EMLE model parameter file. If None, then the
default model for the specified 'alpha_mode' will be used.
emle_species: List[int]
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -111,6 +122,8 @@ def __init__(

# Call the base class constructor.
super().__init__(
model=emle_model,
species=emle_species,
alpha_mode=alpha_mode,
device=device,
dtype=dtype,
Expand Down

0 comments on commit 8ce4be2

Please sign in to comment.