From a8d8a75203e4754a889d5707e22c7847d5a5dd64 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Fri, 9 Aug 2024 14:24:15 +0100 Subject: [PATCH] Add support for custom models and species. --- emle/models/_emle.py | 98 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 15 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 47d9a7b..bba4671 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -61,8 +61,30 @@ class EMLE(_torch.nn.Module): embedding. """ + # Class attributes. + + # Store the expected path to the resources directory. + _resource_dir = _os.path.join( + _os.path.dirname(_os.path.abspath(__file__)), "..", "resources" + ) + + # Create the name of the default model file for each alpha mode. + _default_models = { + "species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"), + "reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"), + } + + # Store the list of supported species. + _species = [1, 6, 7, 8, 16] + def __init__( - self, alpha_mode="species", device=None, dtype=None, create_aev_calculator=True + self, + model=None, + species=None, + alpha_mode="species", + device=None, + dtype=None, + create_aev_calculator=True, ): """ Constructor. @@ -70,6 +92,14 @@ def __init__( Parameters ---------- + model: str + Path to a custom EMLE model parameter file. If None, then the + default model for the specified 'alpha_mode' will be used. + + species: List[int] + List of atomic numbers for the species in the model. If None, then + the default species list will be used. + alpha_mode: str How atomic polarizabilities are calculated. "species": @@ -99,22 +129,48 @@ def __init__( # Fetch or update the resources. _fetch_resources() - # Store the expected path to the resources directory. - resource_dir = _os.path.join( - _os.path.dirname(_os.path.abspath(__file__)), "..", "resources" - ) - if not isinstance(alpha_mode, str): raise TypeError("'alpha_mode' must be of type 'str'") + # Convert to lower case and strip whitespace. + alpha_mode = alpha_mode.lower().replace(" ", "") if alpha_mode not in ["species", "reference"]: raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode - # Choose the model based on the alpha_mode. - if alpha_mode == "species": - model = _os.path.join(resource_dir, "emle_qm7_aev.mat") + if model is not None: + if not isinstance(model, str): + msg = "'model' must be of type 'str'" + _logger.error(msg) + raise TypeError(msg) + + # Convert to an absolute path. + abs_model = _os.path.abspath(model) + + if not _os.path.isfile(abs_model): + msg = f"Unable to locate EMLE embedding model file: '{model}'" + _logger.error(msg) + raise IOError(msg) + self._model = abs_model + + # Validate the species for the custom model. + if species is not None: + if not isinstance(species, list): + raise TypeError("'species' must be of type 'list'") + if not all(isinstance(s, int) for s in species): + raise TypeError("All elements of 'species' must be of type 'int'") + if not all(s > 0 for s in species): + raise ValueError( + "All elements of 'species' must be greater than zero" + ) + else: + # Use the default species. + species = self._species else: - model = _os.path.join(resource_dir, "emle_qm7_aev_alphagpr.mat") + # Choose the model based on the alpha_mode. + model = self._default_models[alpha_mode] + + # Use the default species. + species = self._species if device is not None: if not isinstance(device, _torch.device): @@ -144,9 +200,6 @@ def __init__( except: raise IOError(f"Unable to load model parameters from: '{model}'") - # Set the supported species. - species = [1, 6, 7, 8, 16] - # Create a map between species and their indices. species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) for i, s in enumerate(species): @@ -161,9 +214,24 @@ def __init__( a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) if self._alpha_mode == "species": - k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + try: + k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'k_Z' key in model. This is required when " + "using 'species' alpha mode." + ) + raise ValueError(msg) else: - k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + try: + k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'sqrtk_ref' key in model. This is required when " + "using 'reference' alpha mode." + ) + raise ValueError(msg) + q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device )