Make k_Z a torch parameter and fix module methods.
lohedges committed Oct 22, 2024
1 parent 1bc96bf commit 85040b2
Showing 3 changed files with 28 additions and 22 deletions.
emle/models/_ani.py: 21 additions, 1 deletion
@@ -144,6 +144,7 @@ def __init__(
                 raise TypeError("'device' must be of type 'torch.device'")
         else:
             device = _torch.get_default_device()
+        self._device = device

         if dtype is not None:
             if not isinstance(dtype, _torch.dtype):

@@ -229,8 +230,15 @@ def __init__(
         except:
             pass

+        # Add a hook to the ANI2x model to capture the AEV features.
+        self._add_hook()
+
+    def _add_hook(self):
+        """
+        Add a hook to the ANI2x model to capture the AEV features.
+        """
         # Assign a tensor attribute that can be used for assigning the AEVs.
-        self._ani2x.aev_computer._aev = _torch.empty(0, device=device)
+        self._ani2x.aev_computer._aev = _torch.empty(0, device=self._device)

         # Hook the forward pass of the ANI2x model to get the AEV features.
         # Note that this currently requires a patched version of TorchANI and NNPOps.

@@ -261,6 +269,13 @@ def to(self, *args, **kwargs):
         """
         self._emle = self._emle.to(*args, **kwargs)
         self._ani2x = self._ani2x.to(*args, **kwargs)
+
+        # Check for a device type in args and update the device attribute.
+        for arg in args:
+            if isinstance(arg, _torch.device):
+                self._device = arg
+                break
+
         return self

     def cpu(self, **kwargs):

@@ -269,6 +284,7 @@ def cpu(self, **kwargs):
         """
         self._emle = self._emle.cpu(**kwargs)
         self._ani2x = self._ani2x.cpu(**kwargs)
+        self._device = _torch.device("cpu")
         return self

     def cuda(self, **kwargs):

@@ -277,6 +293,7 @@ def cuda(self, **kwargs):
         """
         self._emle = self._emle.cuda(**kwargs)
         self._ani2x = self._ani2x.cuda(**kwargs)
+        self._device = _torch.device("cuda")
         return self

     def double(self):

@@ -306,6 +323,9 @@ def float(self):
         except:
             pass

+        # Re-append the hook.
+        self._add_hook()
+
         return self

     def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):

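Aside: the new _add_hook method exists so the AEV-capture hook can be re-applied after dtype conversions like the float() override above, which re-appends it. The capture pattern itself is plain PyTorch. A minimal sketch with illustrative names (the real code also relies on patched TorchANI/NNPOps builds, which this does not show):

import torch

class AEVCapture(torch.nn.Module):
    """Toy model that captures a submodule's output, as _add_hook does."""

    def __init__(self, aev_computer):
        super().__init__()
        self.aev_computer = aev_computer
        self._add_hook()

    def _add_hook(self):
        # Tensor attribute used to stash the captured features.
        self.aev_computer._aev = torch.empty(0)

        def hook(module, inputs, output):
            # Runs after every forward pass of the submodule.
            module._aev = output

        self.aev_computer.register_forward_hook(hook)

model = AEVCapture(torch.nn.Linear(4, 8))
model.aev_computer(torch.randn(2, 4))
print(model.aev_computer._aev.shape)  # torch.Size([2, 8])
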
emle/models/_emle.py: 2 additions, 13 deletions
@@ -353,17 +353,6 @@ def __init__(
             dtype=dtype,
         )

-    def _to_dict(self):
-        """
-        Return the configuration of the module as a dictionary.
-        """
-        return {
-            "model": self._model,
-            "method": self._method,
-            "species": self._species_map.tolist(),
-            "alpha_mode": self._alpha_mode,
-        }
-
     def to(self, *args, **kwargs):
         """
         Performs Tensor dtype and/or device conversion on the model.

@@ -393,7 +382,7 @@ def cuda(self, **kwargs):
         self._emle_base = self._emle_base.cuda(**kwargs)

         # Update the device attribute.
-        self._device = self._species_map.device
+        self._device = self._q_total.device

         return self

@@ -408,7 +397,7 @@ def cpu(self, **kwargs):
         self._emle_base = self._emle_base.cpu()

         # Update the device attribute.
-        self._device = self._species_map.device
+        self._device = self._q_total.device

         return self

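The switch from _species_map to _q_total only changes which registered tensor the device is read from; any tensor attribute that moves with the module works for this. A minimal sketch of the pattern with illustrative names (not the EMLE API):

import torch

class Tracked(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are moved by .to()/.cpu()/.cuda(), so their .device
        # always reflects where the module currently lives.
        self.register_buffer("_q_total", torch.zeros(1))

    @property
    def device(self):
        return self._q_total.device

model = Tracked()
print(model.device)  # cpu
if torch.cuda.is_available():
    print(model.cuda().device)  # cuda:0
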
emle/models/_emle_base.py: 5 additions, 8 deletions
@@ -207,7 +207,7 @@ def __init__(
         self.a_Thole = _torch.nn.Parameter(params["a_Thole"])
         self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"])
         self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"])
-        k_Z = _torch.nn.Parameter(params["k_Z"])
+        self.k_Z = _torch.nn.Parameter(params["k_Z"])

         if self._alpha_mode == "reference":
             try:

@@ -270,7 +270,6 @@ def __init__(
         self.register_buffer("_c_s", c_s)
         self.register_buffer("_c_chi", c_chi)
         self.register_buffer("_c_sqrtk", c_sqrtk)
-        self.register_buffer("_k_Z", k_Z)

         # Initialise an empty AEV tensor to use to store the AEVs in parent models.
         # If AEVs are computed externally, then this tensor will be set by the

@@ -292,7 +291,6 @@ def to(self, *args, **kwargs):
         self._c_s = self._c_s.to(*args, **kwargs)
         self._c_chi = self._c_chi.to(*args, **kwargs)
         self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs)
-        self._k_Z = self._k_Z.to(*args, **kwargs)

         # Check for a device type in args and update the device attribute.
         for arg in args:

@@ -320,7 +318,7 @@ def cuda(self, **kwargs):
         self._c_s = self._c_s.cuda(**kwargs)
         self._c_chi = self._c_chi.cuda(**kwargs)
         self._c_sqrtk = self._c_sqrtk.cuda(**kwargs)
-        self._k_Z = self._k_Z.cuda(**kwargs)
+        self.k_Z = self.k_Z.cuda(**kwargs)

         # Update the device attribute.
         self._device = self._species_map.device

@@ -345,7 +343,6 @@ def cpu(self, **kwargs):
         self._c_s = self._c_s.cpu(**kwargs)
         self._c_chi = self._c_chi.cpu(**kwargs)
         self._c_sqrtk = self._c_sqrtk.cpu(**kwargs)
-        self._k_Z = self._k_Z.cpu(**kwargs)

         # Update the device attribute.
         self._device = self._species_map.device

@@ -367,7 +364,7 @@ def double(self):
         self._c_s = self._c_s.double()
         self._c_chi = self._c_chi.double()
         self._c_sqrtk = self._c_sqrtk.double()
-        self._k_Z = self._k_Z.double()
+        self.k_Z = _torch.nn.Parameter(self.k_Z.double())
         return self

     def float(self):

@@ -385,7 +382,7 @@ def float(self):
         self._c_s = self._c_s.float()
         self._c_chi = self._c_chi.float()
         self._c_sqrtk = self._c_sqrtk.float()
-        self._k_Z = self._k_Z.float()
+        self.k_Z = _torch.nn.Parameter(self.k_Z.float())
         return self

     def forward(self, atomic_numbers, xyz_qm, q_total):

@@ -445,7 +442,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total):
         q = self._get_q(r_data, s, chi, q_total, mask)
         q_val = q - q_core

-        k = self._k_Z[species_id]
+        k = self.k_Z[species_id]

         if self._alpha_mode == "reference":
             k_scale = (

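The thrust of this file's changes: k_Z is now a registered torch.nn.Parameter (trainable, returned by parameters(), and moved automatically by Module.to()) rather than a buffer converted by hand in each method. The re-wrapping in the double()/float() overrides is needed because a dtype conversion on a Parameter returns a plain Tensor, which PyTorch will not assign back into a registered parameter slot. A small sketch demonstrating the subtlety (illustrative module, not the EMLE code):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.k_Z = torch.nn.Parameter(torch.ones(3))

m = M()

# A dtype conversion on a Parameter yields a plain Tensor ...
print(isinstance(m.k_Z.double(), torch.nn.Parameter))  # False

# ... so it must be re-wrapped before being assigned back:
m.k_Z = torch.nn.Parameter(m.k_Z.double())
print(m.k_Z.dtype)                             # torch.float64
print(sum(p.numel() for p in m.parameters()))  # 3 -- still registered
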
