From 197b6d46f3d8a0b052e58b29a0786edbe3a2dbd0 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Tue, 22 Oct 2024 20:25:17 +0200 Subject: [PATCH 1/2] Provide "species" from model params to EMLEBase --- emle/models/_emle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 344faf1..7a23c0d 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -348,7 +348,7 @@ def __init__( aev_computer=self._aev_computer, aev_mask=aev_mask, alpha_mode=self._alpha_mode, - species=self._species, + species=params.get("species", self._species), device=device, dtype=dtype, ) From e2f2dbdafff130d6b74c66a53e63ff9a0116a274 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Tue, 22 Oct 2024 20:27:24 +0200 Subject: [PATCH 2/2] Remove check for k_Z in model params (now is always present) --- emle/models/_emle.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 7a23c0d..1e9254e 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -310,11 +310,7 @@ def __init__( "ref_values_chi": _torch.tensor( params["chi_ref"], dtype=dtype, device=device ), - "k_Z": ( - _torch.tensor(params["k_Z"], dtype=dtype, device=device) - if "k_Z" in params - else None - ), + "k_Z": _torch.tensor(params["k_Z"], dtype=dtype, device=device), "sqrtk_ref": ( _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) if "sqrtk_ref" in params