Skip to content

Commit

Permalink
some dirty opt.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Aug 18, 2024
1 parent 8109dc5 commit 13e1317
Showing 1 changed file with 21 additions and 30 deletions.
51 changes: 21 additions & 30 deletions tetragono/tetragono/sampling_neural_state/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def __init__(
if enable_natural_gradient:
self.enable_natural_gradient()

def _positions_tensor(self, positions):
L1, L2, orbit, dim = self.owner.op_pool.shape
positions = torch.tensor(positions, device=torch.device("cpu"))
indices = (positions[:, 0] * L2 + positions[:, 1]) * orbit + positions[:, 2]
return indices

def add_observer(self, name, observers):
"""
Add an observer set into this observer object, cannot add observer once observer started.
Expand All @@ -192,10 +198,11 @@ def add_observer(self, name, observers):
for positions, observer in observers.items():
if not isinstance(observer, self.owner.Tensor):
raise TypeError("Wrong observer type")
L1, L2, orbit, dim = self.owner.op_pool.shape
self._observer[name] = {
positions: (
observer, # The tensor
torch.tensor(positions, device=torch.device("cpu")), # The position tensor
observer,
self._positions_tensor(positions),
self._prepare_element(observer, positions),
) for positions, observer in observers.items()
}
Expand All @@ -214,11 +221,7 @@ def _prepare_element(self, observer, positions):

def _fermi_sign(self, x, y, positions):
if self.owner.op_pool is None:
return (
False,
torch.empty([0, 3], device=torch.device("cpu")),
torch.empty([0, 3], device=torch.device("cpu")),
)
return (False, torch.empty([0, 3], device=torch.device("cpu")))

L1, L2, orbit, dim = self.owner.op_pool.shape
x = torch.tensor(x, device=torch.device("cpu"))
Expand All @@ -238,11 +241,10 @@ def _fermi_sign(self, x, y, positions):
count = 0
count += torch.sum(op_x.unsqueeze(0) * op_x.unsqueeze(1) * fix_update)
count += torch.sum(op_y.unsqueeze(0) * op_y.unsqueeze(1) * fix_update)
return (
count % 2 != 0,
mask_x.reshape([L1, L2, orbit]).to(dtype=torch.int64).nonzero(),
mask_y.reshape([L1, L2, orbit]).to(dtype=torch.int64).nonzero(),
)
for index, o_x, o_y in zip(indices, op_x, op_y):
count += (o_x ^ o_y) and mask_y[index]
mask = mask_x ^ mask_y
return (count % 2 != 0, mask.to(dtype=torch.int64).nonzero())

def _fermi_sign_old(self, config, config_s, positions):
if self.owner.op_pool is None:
Expand Down Expand Up @@ -498,31 +500,20 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
for batch_index in range(batch_size):
configuration_cpu = configurations_cpu[batch_index]
amplitude = amplitudes_cpu[batch_index]
parity = torch.gather(self.owner.op_pool, 3, configuration_cpu.unsqueeze(-1))
parity = torch.gather(self.owner.op_pool, 3, configuration_cpu.unsqueeze(-1)).reshape([-1])
for name, observers in self._observer.items():
for positions, [observer, positions_tensor, element_pool] in observers.items():
body = len(positions)
positions_configuration = tuple(configuration_cpu[
positions_tensor[:, 0],
positions_tensor[:, 1],
positions_tensor[:, 2],
].tolist())
positions_configuration = tuple(configuration_cpu.view([-1])[positions_tensor].tolist())
if positions_configuration not in element_pool:
continue
for positions_configuration_s, [
item, tensor_positions_configuration_s, [base_sign, fermi_sign, fermi_sign_s]
] in element_pool[positions_configuration].items():
for positions_configuration_s, [item, tensor_positions_configuration_s,
[base_sign,
fermi_sign]] in element_pool[positions_configuration].items():
configuration_cpu_s = configuration_cpu.clone()
configuration_cpu_s[
positions_tensor[:, 0],
positions_tensor[:, 1],
positions_tensor[:, 2],
] = tensor_positions_configuration_s
parity_s = torch.gather(self.owner.op_pool, 3, configuration_cpu_s.unsqueeze(-1))
configuration_cpu_s.view([-1])[positions_tensor] = tensor_positions_configuration_s
# self.owner(configuration_s) to be multiplied
total_parity = ((parity[fermi_sign[:, 0], fermi_sign[:, 1], fermi_sign[:, 2]].sum() % 2 != 0) ^
(parity_s[fermi_sign_s[:, 0], fermi_sign_s[:, 1], fermi_sign_s[:, 2]].sum() % 2
!= 0) ^ base_sign)
total_parity = ((parity[fermi_sign].sum() % 2 != 0) ^ base_sign)
value = item * (-1 if total_parity else +1) / amplitude.conj()
if torch.equal(configuration_cpu_s, configuration_cpu):
result[batch_index][name][positions] += amplitude.conj().item() * complex(value)
Expand Down

0 comments on commit 13e1317

Please sign in to comment.