Skip to content

Commit

Permalink
Merge pull request #31 from chemle/feature_species
Browse files Browse the repository at this point in the history
Remove redundant attributes and kwargs now that the species are encoded in future models
  • Loading branch information
lohedges authored Oct 23, 2024
2 parents 7e54719 + d05a491 commit cd9633a
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 87 deletions.
12 changes: 0 additions & 12 deletions bin/emle-server
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ try:
except:
port = None
model = os.getenv("EMLE_MODEL")
try:
species = [int(x) for x in os.getenv("EMLE_SPECIES").split(",")]
except:
species = None
method = os.getenv("EMLE_METHOD")
alpha_mode = os.getenv("EMLE_ALPHA_MODE")
atomic_numbers = os.getenv("EMLE_ATOMIC_NUMBERS")
Expand Down Expand Up @@ -134,7 +130,6 @@ env = {
"host": host,
"port": port,
"model": model,
"species": species,
"method": method,
"alpha_mode": alpha_mode,
"atomic_numbers": atomic_numbers,
Expand Down Expand Up @@ -191,13 +186,6 @@ parser.add_argument("--port", type=str, help="the port number", required=False)
parser.add_argument(
"--model", type=str, help="path to an EMLE model file", required=False
)
parser.add_argument(
"--species",
type=str,
nargs="*",
help="the species supported by the model",
required=False,
)
parser.add_argument(
"--method",
type=str,
Expand Down
3 changes: 2 additions & 1 deletion emle/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _fetch_resources():

import os as _os
import pygit2 as _pygit2
import sys

# Create the name for the expected resources directory.
resource_dir = _os.path.join(
Expand All @@ -40,7 +41,7 @@ def _fetch_resources():
# Check if the resources directory exists.
if not _os.path.exists(resource_dir):
# If it doesn't, clone the resources repository.
print("Downloading EMLE resources...")
print("Downloading EMLE resources...", file=sys.stderr)
_pygit2.clone_repository(
"https://github.com/chemle/emle-models.git", resource_dir
)
Expand Down
40 changes: 4 additions & 36 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class EMLECalculator:
def __init__(
self,
model=None,
species=None,
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
Expand Down Expand Up @@ -127,10 +126,6 @@ def __init__(
Path to the EMLE embedding model parameter file. If None, then a
default model will be used.
species: List[int]
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
method: str
The desired embedding method. Options are:
"electrostatic":
Expand Down Expand Up @@ -418,7 +413,6 @@ def __init__(
self._emle = _EMLE(
model=model,
method=method,
species=species,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
mm_charges=self._mm_charges,
Expand Down Expand Up @@ -795,7 +789,6 @@ def __init__(
# Create an MM EMLE model for interpolation.
self._emle_mm = _EMLE(
model=model,
species=species,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
method="mm",
Expand Down Expand Up @@ -941,7 +934,6 @@ def __init__(

# Get the settings from the internal EMLE model.
self._model = self._emle._model
self._species = self._emle._species
self._method = self._emle._method
self._alpha_mode = self._emle._alpha_mode
self._atomic_numbers = self._emle._atomic_numbers
Expand All @@ -952,7 +944,6 @@ def __init__(
# Store the settings as a dictionary.
self._settings = {
"model": None if model is None else self._model,
"species": None if species is None else self._species,
"method": self._method,
"alpha_mode": self._alpha_mode,
"atomic_numbers": None if atomic_numbers is None else atomic_numbers,
Expand Down Expand Up @@ -1052,21 +1043,10 @@ def run(self, path=None):
xyz_mm = _np.append(xyz_mm, xyz_mm_pad, axis=0)
charges_mm = _np.append(charges_mm, charges_mm_pad)

# Convert the QM atomic numbers to elements and species IDs.
species_id = []
# Convert the QM atomic numbers to elements.
elements = []
for id in atomic_numbers:
try:
species_id.append(self._species.index(id))
elements.append(_ase.Atom(id).symbol)
except:
msg = (
f"Unsupported element index '{id}'. "
f"The current model supports {', '.join(self._supported_elements)}"
)
_logger.error(msg)
raise ValueError(msg)
self._species_id = _torch.tensor(_np.array(species_id), device=self._device)
elements.append(_ase.Atom(id).symbol)

# First try to use the specified backend to compute in vacuo
# energies and (optionally) gradients.
Expand Down Expand Up @@ -1437,21 +1417,10 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
xyz_mm = _np.append(xyz_mm, xyz_mm_pad, axis=0)
charges_mm = _np.append(charges_mm, charges_mm_pad)

# Convert the QM atomic numbers to elements and species IDs.
species_id = []
# Convert the QM atomic numbers to elements.
elements = []
for id in atomic_numbers:
try:
species_id.append(self._species.index(id))
elements.append(_ase.Atom(id).symbol)
except:
msg = (
f"Unsupported element index '{id}'. "
f"The current model supports {', '.join(self._supported_elements)}"
)
_logger.error(msg)
raise ValueError(msg)
self._species_id = _torch.tensor(_np.array(species_id), device=self._device)
elements.append(_ase.Atom(id).symbol)

# First try to use the specified backend to compute in vacuo
# energies and (optionally) gradients.
Expand Down Expand Up @@ -1774,7 +1743,6 @@ def _sire_callback_optimised(
# Create the model.
ani2x_emle = _ANI2xEMLE(
emle_model=self._model,
emle_species=self._species,
alpha_mode=self._alpha_mode,
mm_charges=self._mm_charges,
model_index=self._ani2x_model_index,
Expand Down
6 changes: 0 additions & 6 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(
self,
emle_model=None,
emle_method="electrostatic",
emle_species=None,
alpha_mode="species",
mm_charges=None,
model_index=None,
Expand Down Expand Up @@ -90,10 +89,6 @@ def __init__(
MM charges are used for the core charge and valence charges
are set to zero.
emle_species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -173,7 +168,6 @@ def __init__(
self._emle = _EMLE(
model=emle_model,
method=emle_method,
species=emle_species,
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
Expand Down
25 changes: 0 additions & 25 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def __init__(
self,
model=None,
method="electrostatic",
species=None,
alpha_mode="species",
atomic_numbers=None,
mm_charges=None,
Expand Down Expand Up @@ -116,10 +115,6 @@ def __init__(
should also specify the MM charges for atoms in the QM
region.
species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -220,26 +215,6 @@ def __init__(
if not _os.path.isfile(abs_model):
raise IOError(f"Unable to locate EMLE embedding model file: '{model}'")
self._model = abs_model

# Validate the species for the custom model.
if species is not None:
if isinstance(species, (_np.ndarray, _torch.Tensor)):
species = species.tolist()
if not isinstance(species, (tuple, list)):
raise TypeError(
"'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
)
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
raise ValueError(
"All elements of 'species' must be greater than zero"
)
# Use the custom species.
self._species = species
else:
# Use the default species.
species = self._species
else:
# Set to None as this will be used in any calculator configuration.
self._model = None
Expand Down
3 changes: 2 additions & 1 deletion emle/models/_emle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def __init__(
)
if not isinstance(aev_computer, tuple(allowed_types)):
raise TypeError(
"'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
"'aev_computer' must be of type 'torchani.AEVComputer' or "
"'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
)
self._aev_computer = aev_computer
else:
Expand Down
6 changes: 0 additions & 6 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(
self,
emle_model=None,
emle_method="electrostatic",
emle_species=None,
alpha_mode="species",
mm_charges=None,
mace_model=None,
Expand Down Expand Up @@ -95,10 +94,6 @@ def __init__(
MM charges are used for the core charge and valence charges
are set to zero.
emle_species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.
alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -175,7 +170,6 @@ def __init__(
self._emle = _EMLE(
model=emle_model,
method=emle_method,
species=emle_species,
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
Expand Down

0 comments on commit cd9633a

Please sign in to comment.