Skip to content

Commit

Permalink
[tet.py, tetku.py] Add support for trim style hamiltonian.
Browse files Browse the repository at this point in the history
This kind of hamiltonian is useful especially for quantum chemistry
models, which is set by attribute["quantum_chemistry_term"].
  • Loading branch information
hzhangxyz committed Jul 30, 2024
1 parent e16bc66 commit e576340
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 3 deletions.
168 changes: 165 additions & 3 deletions tetragono/tetragono/sampling_neural_state/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import torch
from ..utility import allreduce_buffer, allreduce_number, show, showln
from .state import Configuration, index_tensor_element
from .state import Configuration, index_tensor_element, torch_grad


class Observer():
Expand Down Expand Up @@ -284,6 +284,167 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
} for _ in range(batch_size)]
whole_result = [{name: 0.0 for name in self._observer} for _ in range(batch_size)]

quantum_chemistry_term = [0 for _ in range(batch_size)]
if "quantum_chemistry_term" in self.owner.attribute:
with torch_grad(False):
# Precalculate hopping
# batch_size * site * site
# we only calculate site a and site b exchange here
# since for quantum chemistry model hopping is just exchange if one has electron and the other is empty
# configurations_cpu : batch * L1 * L2 * 1
# configurations_precalc : (site * site-1 / 2) * batch * L1 * L2 * 1
configurations_precalc = configurations.unsqueeze(0).repeat(
[self.owner.site_number * (self.owner.site_number - 1) // 2 + 1, 1, 1, 1, 1])
for al1, al2 in self.owner.sites():
for bl1, bl2 in self.owner.sites():
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
if ai < bi:
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
a = configurations_precalc[i, :, al1, al2, :].clone()
b = configurations_precalc[i, :, bl1, bl2, :]
configurations_precalc[i, :, al1, al2, :] = b
configurations_precalc[i, :, bl1, bl2, :] = a
amplitudes_precalc = self.owner(configurations_precalc.reshape([-1, self.owner.L1, self.owner.L2, 1]),
enable_grad=False).reshape([
self.owner.site_number * (self.owner.site_number - 1) // 2 + 1,
batch_size
])
amplitudes_precalc_cpu = amplitudes_precalc.cpu()
gradients_precalc_norm = torch.zeros_like(amplitudes_precalc_cpu)
gradients_precalc_conj = torch.zeros_like(amplitudes_precalc_cpu)

for batch_index in range(batch_size):
configuration_cpu = configurations_cpu[batch_index]
amplitude = amplitudes_cpu[batch_index]
weight = weights_cpu[batch_index]
multiplicity = multiplicities_cpu[batch_index]
reweight = (amplitude.abs()**2 / weight).item() # <psi|s|psi> / p(s)
energy = 0
for positions, observer, exists in self.owner.attribute["quantum_chemistry_term"][0]:
if not all(configuration_cpu[exist] == 1 for exist in exists):
continue
body = 2
element_pool = index_tensor_element(observer)
positions_configuration = tuple(configuration_cpu[l1l2o].item() for l1l2o in positions)
if positions_configuration not in element_pool:
continue
sub_energy = 0
for positions_configuration_s, item in element_pool[positions_configuration].items():
configuration_cpu_s = configuration_cpu.clone()
for l1l2o, value in zip(positions, positions_configuration_s):
configuration_cpu_s[l1l2o] = value
# we only have hopping item, since position configuration found in element pool, it must be a swap item.
((al1, al2, _), (bl1, bl2, _)) = positions
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
ai, bi = sorted([ai, bi])
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
value = item * self._fermi_sign(
configuration_cpu, configuration_cpu_s,
positions) * amplitudes_precalc_cpu[i, batch_index].conj() / amplitude.conj()
sub_energy = sub_energy + value
if self._enable_gradient:
gradients_precalc_conj[i, batch_index] += multiplicity * reweight * value / 2
energy = energy + sub_energy
if self._enable_gradient:
gradients_precalc_norm[0, batch_index] += multiplicity * reweight * sub_energy / 2
for positions_1, observer_1, positions_2, observer_2, exists in self.owner.attribute[
"quantum_chemistry_term"][1]:
if not all(configuration_cpu[exist] == 1 for exist in exists):
continue
body_1 = body_2 = 2
element_pool_1 = index_tensor_element(observer_1)
element_pool_2 = index_tensor_element(observer_2)
positions_configuration_1 = tuple(configuration_cpu[l1l2o].item() for l1l2o in positions_1)
positions_configuration_2 = tuple(configuration_cpu[l1l2o].item() for l1l2o in positions_2)
if positions_configuration_1 not in element_pool_1:
continue
if positions_configuration_2 not in element_pool_2:
continue
sub_energy_1 = 0
for positions_configuration_s_1, item_1 in element_pool_1[positions_configuration_1].items():
configuration_cpu_s_1 = configuration_cpu.clone()
for l1l2o, value in zip(positions_1, positions_configuration_s_1):
configuration_cpu_s_1[l1l2o] = value
# we only have hopping item, since position configuration found in element pool, it must be a swap item.
((al1, al2, _), (bl1, bl2, _)) = positions_1
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
ai, bi = sorted([ai, bi])
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
value = item_1 * self._fermi_sign(
configuration_cpu, configuration_cpu_s_1,
positions_1) * amplitudes_precalc_cpu[i, batch_index].conj() / amplitude.conj()
sub_energy_1 = sub_energy_1 + value
sub_energy_2 = 0
for positions_configuration_s_2, item_2 in element_pool_2[positions_configuration_2].items():
configuration_cpu_s_2 = configuration_cpu.clone()
for l1l2o, value in zip(positions_2, positions_configuration_s_2):
configuration_cpu_s_2[l1l2o] = value
# we only have hopping item, since position configuration found in element pool, it must be a swap item.
((al1, al2, _), (bl1, bl2, _)) = positions_2
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
ai, bi = sorted([ai, bi])
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
value = item_2 * self._fermi_sign(
configuration_cpu, configuration_cpu_s_2,
positions_2) * amplitudes_precalc_cpu[i, batch_index].conj() / amplitude.conj()
sub_energy_2 = sub_energy_2 + value
energy = energy + sub_energy_1 * sub_energy_2.conj()
if self._enable_gradient:
for positions_configuration_s_1, item_1 in element_pool_1[positions_configuration_1].items(
):
configuration_cpu_s_1 = configuration_cpu.clone()
for l1l2o, value in zip(positions_1, positions_configuration_s_1):
configuration_cpu_s_1[l1l2o] = value
# we only have hopping item, since position configuration found in element pool, it must be a swap item.
((al1, al2, _), (bl1, bl2, _)) = positions_1
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
ai, bi = sorted([ai, bi])
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
value = item_1 * self._fermi_sign(
configuration_cpu, configuration_cpu_s_1,
positions_1) * amplitudes_precalc_cpu[i, batch_index].conj() / amplitude.conj()
gradients_precalc_conj[
i, batch_index] += multiplicity * reweight * value * sub_energy_2.conj() / 2

for positions_configuration_s_2, item_2 in element_pool_2[positions_configuration_2].items(
):
configuration_cpu_s_2 = configuration_cpu.clone()
for l1l2o, value in zip(positions_2, positions_configuration_s_2):
configuration_cpu_s_2[l1l2o] = value
# we only have hopping item, since position configuration found in element pool, it must be a swap item.
((al1, al2, _), (bl1, bl2, _)) = positions_2
ai = al1 * self.owner.L2 + al2
bi = bl1 * self.owner.L2 + bl2
ai, bi = sorted([ai, bi])
i = (2 * self.owner.site_number - ai - 1) * ai // 2 + (bi - ai - 1) + 1
value = item_2 * self._fermi_sign(
configuration_cpu, configuration_cpu_s_2,
positions_2) * amplitudes_precalc_cpu[i, batch_index].conj() / amplitude.conj()
gradients_precalc_norm[
i, batch_index] += multiplicity * reweight * sub_energy_1 * value.conj() / 2

whole_result[batch_index]["energy"] += complex(energy)
quantum_chemistry_term[batch_index] = complex(energy)
if self._enable_gradient:
edelta = 0
for i in range(self.owner.site_number * (self.owner.site_number - 1) // 2 + 1):
for j in range(batch_size):
if gradients_precalc_norm[i, j] != 0 or gradients_precalc_conj[i, j] != 0:
configuration = configurations_precalc[i, j]
amplitude = self.owner(configuration.unsqueeze(0), enable_grad=True)
grad = self.owner.holes(amplitude)
edelta = edelta + gradients_precalc_norm[i, j] * grad
edelta = edelta + gradients_precalc_conj[i, j] * grad.conj()
if self._EDelta is None:
self._EDelta = edelta
else:
self._EDelta += edelta

for batch_index in range(batch_size):
configuration_cpu = configurations_cpu[batch_index]
amplitude = amplitudes_cpu[batch_index]
Expand Down Expand Up @@ -340,7 +501,8 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
self._total_imaginary_energy_reweight += multiplicity * whole_result[batch_index][
name].imag * reweight
if name == "energy" and self._enable_gradient:
Es = whole_result[batch_index][name]
Es = whole_result[batch_index][name] - quantum_chemistry_term[batch_index]
# The EDelta contributed by quantum chemistry term has been collected before.
if self.owner.Tensor.is_real:
Es = Es.real

Expand Down Expand Up @@ -571,7 +733,7 @@ def DT(v):
def A(v):
return DT(D(v))

b = DT(Energy)
b = ((self._EDelta / self._total_weight) - energy * (self._Delta / self._total_weight))
b_square = torch.dot(torch.conj(b), b).real

x = torch.zeros_like(b)
Expand Down
Loading

0 comments on commit e576340

Please sign in to comment.