Skip to content

Commit

Permalink
Merge pull request #32 from kzinovjev/aev-tweaks
Browse files Browse the repository at this point in the history
Add aev_mean model parameter
  • Loading branch information
lohedges authored Oct 29, 2024
2 parents e0f5985 + ea45505 commit 53049aa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(
q_core,
aev_computer=self._aev_computer,
aev_mask=aev_mask,
aev_mean=params.get("aev_mean"),
alpha_mode=self._alpha_mode,
species=params.get("species", self._species),
device=device,
Expand Down
12 changes: 12 additions & 0 deletions emle/models/_emle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
q_core,
aev_computer=None,
aev_mask=None,
aev_mean=None,
species=None,
alpha_mode="species",
device=None,
Expand Down Expand Up @@ -102,6 +103,9 @@ def __init__(
aev_mask: torch.Tensor
Mask for features coming from aev_computer.
aev_mean: torch.Tensor
Mean values to be subtracted from features
species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
List of species (atomic numbers) supported by the EMLE model.
Expand Down Expand Up @@ -203,6 +207,10 @@ def __init__(
else:
dtype = _torch.get_default_dtype()

self._aev_mean = None
if aev_mean is not None:
self._aev_mean = _torch.tensor(aev_mean, dtype=dtype, device=device)

# Store model parameters as tensors.
self.a_QEq = _torch.nn.Parameter(params["a_QEq"])
self.a_Thole = _torch.nn.Parameter(params["a_Thole"])
Expand Down Expand Up @@ -425,6 +433,10 @@ def forward(self, atomic_numbers, xyz_qm, q_total):
# The AEVs have been pre-computed by a parent model.
else:
aev = self._aev[:, :, self._aev_mask]

if self._aev_mean is not None:
aev = aev - self._aev_mean[None, None, :]

aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True)

# Compute the MBIS valence shell widths.
Expand Down

0 comments on commit 53049aa

Please sign in to comment.