diff --git a/emle/emle.py b/emle/emle.py index f7428b0..b7b2f8e 100644 --- a/emle/emle.py +++ b/emle/emle.py @@ -49,9 +49,11 @@ BOHR_TO_ANGSTROM = ase.units.Bohr EV_TO_HARTREE = 1.0 / ase.units.Hartree KCAL_MOL_TO_HARTREE = 1.0 / ase.units.Hartree * ase.units.kcal / ase.units.mol + +# Settings for the default model. For system specific models, these will be +# overwritten by values in the model file. SPECIES = (1, 6, 7, 8, 16) SIGMA = 1e-3 - SPHERICAL_EXPANSION_HYPERS_COMMON = { "gaussian_sigma_constant": 0.5, "gaussian_sigma_type": "Constant", @@ -61,8 +63,6 @@ "global_species": SPECIES, } -ELEMENT_DICT = {1: "H", 6: "C", 7: "N", 8: "O", 16: "S"} - class GPRCalculator: """Predicts an atomic property for a molecule with Gaussian Process Regression (GPR).""" @@ -844,6 +844,11 @@ def __init__( except: self._hypers[key] = self._params[key] + # Work out the supported elements. + self._supported_elements = [] + for id in self._hypers["global_species"]: + self._supported_elements.append(ase.atoms.Atom(id).symbol) + self._get_soap = SOAPCalculatorSpinv(self._hypers) self._q_core = torch.tensor( self._params["q_core"], dtype=torch.float32, device=self._device @@ -1001,11 +1006,11 @@ def run(self, path=None): for id in atomic_numbers: try: species_id.append(self._hypers["global_species"].index(id)) - elements.append(ELEMENT_DICT[id]) + elements.append(ase.atom.Atom(id).symbol) except: raise ValueError( f"Unsupported element index '{id}'. " - f"We currently support {', '.join(ELEMENT_DICT.values())}." + f"The current model supports {', '.join(self._supported_elements)}" ) self._species_id = np.array(species_id)