Skip to content

Commit

Permalink
Remove need for hard-coded ELEMENT_DICT.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Nov 21, 2023
1 parent cdc4ec8 commit 42fcc57
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions emle/emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
BOHR_TO_ANGSTROM = ase.units.Bohr
EV_TO_HARTREE = 1.0 / ase.units.Hartree
KCAL_MOL_TO_HARTREE = 1.0 / ase.units.Hartree * ase.units.kcal / ase.units.mol

# Settings for the default model. For system specific models, these will be
# overwritten by values in the model file.
SPECIES = (1, 6, 7, 8, 16)
SIGMA = 1e-3

SPHERICAL_EXPANSION_HYPERS_COMMON = {
"gaussian_sigma_constant": 0.5,
"gaussian_sigma_type": "Constant",
Expand All @@ -61,8 +63,6 @@
"global_species": SPECIES,
}

ELEMENT_DICT = {1: "H", 6: "C", 7: "N", 8: "O", 16: "S"}


class GPRCalculator:
"""Predicts an atomic property for a molecule with Gaussian Process Regression (GPR)."""
Expand Down Expand Up @@ -844,6 +844,11 @@ def __init__(
except:
self._hypers[key] = self._params[key]

# Work out the supported elements.
self._supported_elements = []
for id in self._hypers["global_species"]:
self._supported_elements.append(ase.atoms.Atom(id).symbol)

self._get_soap = SOAPCalculatorSpinv(self._hypers)
self._q_core = torch.tensor(
self._params["q_core"], dtype=torch.float32, device=self._device
Expand Down Expand Up @@ -1001,11 +1006,11 @@ def run(self, path=None):
for id in atomic_numbers:
try:
species_id.append(self._hypers["global_species"].index(id))
elements.append(ELEMENT_DICT[id])
elements.append(ase.atom.Atom(id).symbol)
except:
raise ValueError(
f"Unsupported element index '{id}'. "
f"We currently support {', '.join(ELEMENT_DICT.values())}."
f"The current model supports {', '.join(self._supported_elements)}"
)
self._species_id = np.array(species_id)

Expand Down

0 comments on commit 42fcc57

Please sign in to comment.