Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant attributes and kwargs now that the species are encoded in future models #31

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions bin/emle-server
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ try:
except:
port = None
model = os.getenv("EMLE_MODEL")
try:
species = [int(x) for x in os.getenv("EMLE_SPECIES").split(",")]
except:
species = None
method = os.getenv("EMLE_METHOD")
alpha_mode = os.getenv("EMLE_ALPHA_MODE")
atomic_numbers = os.getenv("EMLE_ATOMIC_NUMBERS")
Expand Down Expand Up @@ -134,7 +130,6 @@ env = {
"host": host,
"port": port,
"model": model,
"species": species,
"method": method,
"alpha_mode": alpha_mode,
"atomic_numbers": atomic_numbers,
Expand Down Expand Up @@ -191,13 +186,6 @@ parser.add_argument("--port", type=str, help="the port number", required=False)
parser.add_argument(
"--model", type=str, help="path to an EMLE model file", required=False
)
parser.add_argument(
"--species",
type=str,
nargs="*",
help="the species supported by the model",
required=False,
)
parser.add_argument(
"--method",
type=str,
Expand Down
3 changes: 2 additions & 1 deletion emle/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _fetch_resources():

import os as _os
import pygit2 as _pygit2
import sys

# Create the name for the expected resources directory.
resource_dir = _os.path.join(
Expand All @@ -40,7 +41,7 @@ def _fetch_resources():
# Check if the resources directory exists.
if not _os.path.exists(resource_dir):
# If it doesn't, clone the resources repository.
print("Downloading EMLE resources...")
print("Downloading EMLE resources...", file=sys.stderr)
_pygit2.clone_repository(
"https://github.com/chemle/emle-models.git", resource_dir
)
Expand Down
40 changes: 4 additions & 36 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class EMLECalculator:
def __init__(
self,
model=None,
species=None,
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
Expand Down Expand Up @@ -127,10 +126,6 @@ def __init__(
Path to the EMLE embedding model parameter file. If None, then a
default model will be used.

species: List[int]
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.

method: str
The desired embedding method. Options are:
"electrostatic":
Expand Down Expand Up @@ -418,7 +413,6 @@ def __init__(
self._emle = _EMLE(
model=model,
method=method,
species=species,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
mm_charges=self._mm_charges,
Expand Down Expand Up @@ -795,7 +789,6 @@ def __init__(
# Create an MM EMLE model for interpolation.
self._emle_mm = _EMLE(
model=model,
species=species,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
method="mm",
Expand Down Expand Up @@ -941,7 +934,6 @@ def __init__(

# Get the settings from the internal EMLE model.
self._model = self._emle._model
self._species = self._emle._species
self._method = self._emle._method
self._alpha_mode = self._emle._alpha_mode
self._atomic_numbers = self._emle._atomic_numbers
Expand All @@ -952,7 +944,6 @@ def __init__(
# Store the settings as a dictionary.
self._settings = {
"model": None if model is None else self._model,
"species": None if species is None else self._species,
"method": self._method,
"alpha_mode": self._alpha_mode,
"atomic_numbers": None if atomic_numbers is None else atomic_numbers,
Expand Down Expand Up @@ -1052,21 +1043,10 @@ def run(self, path=None):
xyz_mm = _np.append(xyz_mm, xyz_mm_pad, axis=0)
charges_mm = _np.append(charges_mm, charges_mm_pad)

# Convert the QM atomic numbers to elements and species IDs.
species_id = []
# Convert the QM atomic numbers to elements.
elements = []
for id in atomic_numbers:
try:
species_id.append(self._species.index(id))
elements.append(_ase.Atom(id).symbol)
except:
msg = (
f"Unsupported element index '{id}'. "
f"The current model supports {', '.join(self._supported_elements)}"
)
_logger.error(msg)
raise ValueError(msg)
self._species_id = _torch.tensor(_np.array(species_id), device=self._device)
elements.append(_ase.Atom(id).symbol)

# First try to use the specified backend to compute in vacuo
# energies and (optionally) gradients.
Expand Down Expand Up @@ -1437,21 +1417,10 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
xyz_mm = _np.append(xyz_mm, xyz_mm_pad, axis=0)
charges_mm = _np.append(charges_mm, charges_mm_pad)

# Convert the QM atomic numbers to elements and species IDs.
species_id = []
# Convert the QM atomic numbers to elements.
elements = []
for id in atomic_numbers:
try:
species_id.append(self._species.index(id))
elements.append(_ase.Atom(id).symbol)
except:
msg = (
f"Unsupported element index '{id}'. "
f"The current model supports {', '.join(self._supported_elements)}"
)
_logger.error(msg)
raise ValueError(msg)
self._species_id = _torch.tensor(_np.array(species_id), device=self._device)
elements.append(_ase.Atom(id).symbol)

# First try to use the specified backend to compute in vacuo
# energies and (optionally) gradients.
Expand Down Expand Up @@ -1774,7 +1743,6 @@ def _sire_callback_optimised(
# Create the model.
ani2x_emle = _ANI2xEMLE(
emle_model=self._model,
emle_species=self._species,
alpha_mode=self._alpha_mode,
mm_charges=self._mm_charges,
model_index=self._ani2x_model_index,
Expand Down
6 changes: 0 additions & 6 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(
self,
emle_model=None,
emle_method="electrostatic",
emle_species=None,
alpha_mode="species",
mm_charges=None,
model_index=None,
Expand Down Expand Up @@ -90,10 +89,6 @@ def __init__(
MM charges are used for the core charge and valence charges
are set to zero.

emle_species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.

alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -173,7 +168,6 @@ def __init__(
self._emle = _EMLE(
model=emle_model,
method=emle_method,
species=emle_species,
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
Expand Down
25 changes: 0 additions & 25 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def __init__(
self,
model=None,
method="electrostatic",
species=None,
alpha_mode="species",
atomic_numbers=None,
mm_charges=None,
Expand Down Expand Up @@ -116,10 +115,6 @@ def __init__(
should also specify the MM charges for atoms in the QM
region.

species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.

alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -220,26 +215,6 @@ def __init__(
if not _os.path.isfile(abs_model):
raise IOError(f"Unable to locate EMLE embedding model file: '{model}'")
self._model = abs_model

# Validate the species for the custom model.
if species is not None:
if isinstance(species, (_np.ndarray, _torch.Tensor)):
species = species.tolist()
if not isinstance(species, (tuple, list)):
raise TypeError(
"'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
)
if not all(isinstance(s, int) for s in species):
raise TypeError("All elements of 'species' must be of type 'int'")
if not all(s > 0 for s in species):
raise ValueError(
"All elements of 'species' must be greater than zero"
)
# Use the custom species.
self._species = species
else:
# Use the default species.
species = self._species
else:
# Set to None as this will be used in any calculator configuration.
self._model = None
Expand Down
3 changes: 2 additions & 1 deletion emle/models/_emle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def __init__(
)
if not isinstance(aev_computer, tuple(allowed_types)):
raise TypeError(
"'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
"'aev_computer' must be of type 'torchani.AEVComputer' or "
"'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
)
self._aev_computer = aev_computer
else:
Expand Down
6 changes: 0 additions & 6 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(
self,
emle_model=None,
emle_method="electrostatic",
emle_species=None,
alpha_mode="species",
mm_charges=None,
mace_model=None,
Expand Down Expand Up @@ -95,10 +94,6 @@ def __init__(
MM charges are used for the core charge and valence charges
are set to zero.

emle_species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model. If
None, then the default species list will be used.

alpha_mode: str
How atomic polarizabilities are calculated.
"species":
Expand Down Expand Up @@ -175,7 +170,6 @@ def __init__(
self._emle = _EMLE(
model=emle_model,
method=emle_method,
species=emle_species,
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
Expand Down