Skip to content

Commit

Permalink
Add support for custom models and species.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Aug 9, 2024
1 parent 0ffc0c4 commit a8d8a75
Showing 1 changed file with 83 additions and 15 deletions.
98 changes: 83 additions & 15 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,45 @@ class EMLE(_torch.nn.Module):
embedding.
"""

# Class attributes.

# Store the expected path to the resources directory.
_resource_dir = _os.path.join(
_os.path.dirname(_os.path.abspath(__file__)), "..", "resources"
)

# Create the name of the default model file for each alpha mode.
_default_models = {
"species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"),
"reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"),
}

# Store the list of supported species.
_species = [1, 6, 7, 8, 16]

def __init__(
self, alpha_mode="species", device=None, dtype=None, create_aev_calculator=True
self,
model=None,
species=None,
alpha_mode="species",
device=None,
dtype=None,
create_aev_calculator=True,
):
"""
Constructor.
Parameters
----------
model: str
Path to a custom EMLE model parameter file. If None, then the
default model for the specified 'alpha_mode' will be used.
species: List[int]
List of atomic numbers for the species in the model. If None, then
the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -99,22 +129,48 @@ def __init__(
# Fetch or update the resources.
_fetch_resources()

# Store the expected path to the resources directory.
resource_dir = _os.path.join(
_os.path.dirname(_os.path.abspath(__file__)), "..", "resources"
)

if not isinstance(alpha_mode, str):
raise TypeError("'alpha_mode' must be of type 'str'")
# Convert to lower case and strip whitespace.
alpha_mode = alpha_mode.lower().replace(" ", "")
if alpha_mode not in ["species", "reference"]:
raise ValueError("'alpha_mode' must be 'species' or 'reference'")
self._alpha_mode = alpha_mode

# Choose the model based on the alpha_mode.
if alpha_mode == "species":
model = _os.path.join(resource_dir, "emle_qm7_aev.mat")
if model is not None:
if not isinstance(model, str):
msg = "'model' must be of type 'str'"
_logger.error(msg)
raise TypeError(msg)

# Convert to an absolute path.
abs_model = _os.path.abspath(model)

if not _os.path.isfile(abs_model):
msg = f"Unable to locate EMLE embedding model file: '{model}'"
_logger.error(msg)
raise IOError(msg)
self._model = abs_model

# Validate the species for the custom model.
if species is not None:
if not isinstance(species, list):
raise TypeError("'species' must be of type 'list'")
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
raise ValueError(
"All elements of 'species' must be greater than zero"
)
else:
# Use the default species.
species = self._species
else:
model = _os.path.join(resource_dir, "emle_qm7_aev_alphagpr.mat")
# Choose the model based on the alpha_mode.
model = self._default_models[alpha_mode]

# Use the default species.
species = self._species

if device is not None:
if not isinstance(device, _torch.device):
Expand Down Expand Up @@ -144,9 +200,6 @@ def __init__(
except:
raise IOError(f"Unable to load model parameters from: '{model}'")

# Set the supported species.
species = [1, 6, 7, 8, 16]

# Create a map between species and their indices.
species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64)
for i, s in enumerate(species):
Expand All @@ -161,9 +214,24 @@ def __init__(
a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device)
a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device)
if self._alpha_mode == "species":
k = _torch.tensor(params["k_Z"], dtype=dtype, device=device)
try:
k = _torch.tensor(params["k_Z"], dtype=dtype, device=device)
except:
msg = (
"Missing 'k_Z' key in model. This is required when "
"using 'species' alpha mode."
)
raise ValueError(msg)
else:
k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device)
try:
k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device)
except:
msg = (
"Missing 'sqrtk_ref' key in model. This is required when "
"using 'reference' alpha mode."
)
raise ValueError(msg)

q_total = _torch.tensor(
params.get("total_charge", 0), dtype=dtype, device=device
)
Expand Down

0 comments on commit a8d8a75

Please sign in to comment.