diff --git a/emle/models/_ani.py b/emle/models/_ani.py index ac0f5e2..47ba6b8 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -47,6 +47,8 @@ class ANI2xEMLE(_EMLE): def __init__( self, + emle_model=None, + emle_species=None, alpha_mode="species", model_index=None, ani2x_model=None, @@ -60,6 +62,14 @@ def __init__( Parameters ---------- + emle_model: str + Path to a custom EMLE model parameter file. If None, then the + default model for the specified 'alpha_mode' will be used. + + emle_species: List[int] + List of species (atomic numbers) supported by the EMLE model. If + None, then the default species list will be used. + alpha_mode: str How atomic polarizabilities are calculated. "species": @@ -69,7 +79,7 @@ def __init__( for each reference environment model_index: int - The index of the model to use. If None, then the full 8 model + The index of the ANI2x model to use. If None, then the full 8 model ensemble will be used. ani2x_model: torchani.models.ANI2x, NNPOPS.OptimizedTorchANI @@ -120,6 +130,8 @@ def __init__( # Call the base class constructor. super().__init__( + model=emle_model, + species=emle_species, alpha_mode=alpha_mode, device=device, dtype=dtype, diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 47d9a7b..3dfc5b7 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -61,8 +61,30 @@ class EMLE(_torch.nn.Module): embedding. """ + # Class attributes. + + # Store the expected path to the resources directory. + _resource_dir = _os.path.join( + _os.path.dirname(_os.path.abspath(__file__)), "..", "resources" + ) + + # Create the name of the default model file for each alpha mode. + _default_models = { + "species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"), + "reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"), + } + + # Store the list of supported species. + _species = [1, 6, 7, 8, 16] + def __init__( - self, alpha_mode="species", device=None, dtype=None, create_aev_calculator=True + self, + model=None, + species=None, + alpha_mode="species", + device=None, + dtype=None, + create_aev_calculator=True, ): """ Constructor. @@ -70,6 +92,14 @@ def __init__( Parameters ---------- + model: str + Path to a custom EMLE model parameter file. If None, then the + default model for the specified 'alpha_mode' will be used. + + species: List[int] + List of species (atomic numbers) supported by the EMLE model. If + None, then the default species list will be used. + alpha_mode: str How atomic polarizabilities are calculated. "species": @@ -99,22 +129,48 @@ def __init__( # Fetch or update the resources. _fetch_resources() - # Store the expected path to the resources directory. - resource_dir = _os.path.join( - _os.path.dirname(_os.path.abspath(__file__)), "..", "resources" - ) - if not isinstance(alpha_mode, str): raise TypeError("'alpha_mode' must be of type 'str'") + # Convert to lower case and strip whitespace. + alpha_mode = alpha_mode.lower().replace(" ", "") if alpha_mode not in ["species", "reference"]: raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode - # Choose the model based on the alpha_mode. - if alpha_mode == "species": - model = _os.path.join(resource_dir, "emle_qm7_aev.mat") + if model is not None: + if not isinstance(model, str): + msg = "'model' must be of type 'str'" + _logger.error(msg) + raise TypeError(msg) + + # Convert to an absolute path. + abs_model = _os.path.abspath(model) + + if not _os.path.isfile(abs_model): + msg = f"Unable to locate EMLE embedding model file: '{model}'" + _logger.error(msg) + raise IOError(msg) + self._model = abs_model + + # Validate the species for the custom model. + if species is not None: + if not isinstance(species, list): + raise TypeError("'species' must be of type 'list'") + if not all(isinstance(s, int) for s in species): + raise TypeError("All elements of 'species' must be of type 'int'") + if not all(s > 0 for s in species): + raise ValueError( + "All elements of 'species' must be greater than zero" + ) + else: + # Use the default species. + species = self._species else: - model = _os.path.join(resource_dir, "emle_qm7_aev_alphagpr.mat") + # Choose the model based on the alpha_mode. + model = self._default_models[alpha_mode] + + # Use the default species. + species = self._species if device is not None: if not isinstance(device, _torch.device): @@ -144,9 +200,6 @@ def __init__( except: raise IOError(f"Unable to load model parameters from: '{model}'") - # Set the supported species. - species = [1, 6, 7, 8, 16] - # Create a map between species and their indices. species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) for i, s in enumerate(species): @@ -161,9 +214,24 @@ def __init__( a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) if self._alpha_mode == "species": - k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + try: + k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'k_Z' key in model. This is required when " + "using 'species' alpha mode." + ) + raise ValueError(msg) else: - k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + try: + k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'sqrtk_ref' key in model. This is required when " + "using 'reference' alpha mode." + ) + raise ValueError(msg) + q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device ) diff --git a/emle/models/_mace.py b/emle/models/_mace.py index 44ebe2f..f70c2e7 100644 --- a/emle/models/_mace.py +++ b/emle/models/_mace.py @@ -54,6 +54,8 @@ class MACEEMLE(_EMLE): def __init__( self, + emle_model=None, + emle_species=None, alpha_mode="species", mace_model=None, atomic_numbers=None, @@ -65,6 +67,15 @@ def __init__( Parameters ---------- + + emle_model: str + Path to a custom EMLE model parameter file. If None, then the + default model for the specified 'alpha_mode' will be used. + + emle_species: List[int] + List of species (atomic numbers) supported by the EMLE model. If + None, then the default species list will be used. + alpha_mode: str How atomic polarizabilities are calculated. "species": @@ -111,6 +122,8 @@ def __init__( # Call the base class constructor. super().__init__( + model=emle_model, + species=emle_species, alpha_mode=alpha_mode, device=device, dtype=dtype,