Skip to content

Commit

Permalink
一些比较脏的优化
Browse files Browse the repository at this point in the history
...

...
  • Loading branch information
hzhangxyz committed Aug 16, 2024
1 parent 8447116 commit 8109dc5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 37 deletions.
134 changes: 98 additions & 36 deletions tetragono/tetragono/sampling_neural_state/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,35 +192,59 @@ def add_observer(self, name, observers):
for positions, observer in observers.items():
if not isinstance(observer, self.owner.Tensor):
raise TypeError("Wrong observer type")
self._observer[name] = observers

def add_energy(self):
"""
Add energy as an observer.
"""
self.add_observer("energy", self.owner._hamiltonians)

def enable_gradient(self):
"""
Enable observing gradient.
"""
if self._start:
raise RuntimeError("Cannot enable gradient after sampling start")
if "energy" not in self._observer:
self.add_energy()
self._enable_gradient = True
self._observer[name] = {
positions: (
observer, # The tensor
torch.tensor(positions, device=torch.device("cpu")), # The position tensor
self._prepare_element(observer, positions),
) for positions, observer in observers.items()
}

def enable_natural_gradient(self):
"""
Enable observing natural gradient.
"""
if self._start:
raise RuntimeError("Cannot enable natural gradient after sampling start")
if not self._enable_gradient:
self.enable_gradient()
self._enable_natural = True
def _prepare_element(self, observer, positions):
element_pool = index_tensor_element(observer)
return {
x: {
y: (
item,
tensor_y,
self._fermi_sign(x, y, positions),
) for y, [item, tensor_y] in items.items()
} for x, items in element_pool.items()
}

def _fermi_sign(self, config, config_s, positions):
def _fermi_sign(self, x, y, positions):
if self.owner.op_pool is None:
return (
False,
torch.empty([0, 3], device=torch.device("cpu")),
torch.empty([0, 3], device=torch.device("cpu")),
)

L1, L2, orbit, dim = self.owner.op_pool.shape
x = torch.tensor(x, device=torch.device("cpu"))
y = torch.tensor(y, device=torch.device("cpu"))
positions = torch.tensor(positions, device=torch.device("cpu"))
indices = (positions[:, 0] * L2 + positions[:, 1]) * orbit + positions[:, 2]
fix_update = (indices.unsqueeze(0) > indices.unsqueeze(1)).triu()
op_x = self.owner.op_pool[positions[:, 0], positions[:, 1], positions[:, 2], x]
op_y = self.owner.op_pool[positions[:, 0], positions[:, 1], positions[:, 2], y]

mask_x = torch.zeros([L1 * L2 * orbit], device=torch.device("cpu"), dtype=torch.bool)
for index, op in zip(indices, op_x):
mask_x[:index] ^= op
mask_y = torch.zeros([L1 * L2 * orbit], device=torch.device("cpu"), dtype=torch.bool)
for index, op in zip(indices, op_y):
mask_y[:index] ^= op
count = 0
count += torch.sum(op_x.unsqueeze(0) * op_x.unsqueeze(1) * fix_update)
count += torch.sum(op_y.unsqueeze(0) * op_y.unsqueeze(1) * fix_update)
return (
count % 2 != 0,
mask_x.reshape([L1, L2, orbit]).to(dtype=torch.int64).nonzero(),
mask_y.reshape([L1, L2, orbit]).to(dtype=torch.int64).nonzero(),
)

def _fermi_sign_old(self, config, config_s, positions):
if self.owner.op_pool is None:
return +1

Expand Down Expand Up @@ -249,6 +273,32 @@ def _fermi_sign(self, config, config_s, positions):
else:
return -1

def add_energy(self):
"""
Add energy as an observer.
"""
self.add_observer("energy", self.owner._hamiltonians)

def enable_gradient(self):
"""
Enable observing gradient.
"""
if self._start:
raise RuntimeError("Cannot enable gradient after sampling start")
if "energy" not in self._observer:
self.add_energy()
self._enable_gradient = True

def enable_natural_gradient(self):
"""
Enable observing natural gradient.
"""
if self._start:
raise RuntimeError("Cannot enable natural gradient after sampling start")
if not self._enable_gradient:
self.enable_gradient()
self._enable_natural = True

def __call__(self, configurations, amplitudes, weights, multiplicities):
"""
Collect observer value from given configurations, the sampling should have distribution based on weights
Expand Down Expand Up @@ -448,20 +498,32 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
for batch_index in range(batch_size):
configuration_cpu = configurations_cpu[batch_index]
amplitude = amplitudes_cpu[batch_index]
parity = torch.gather(self.owner.op_pool, 3, configuration_cpu.unsqueeze(-1))
for name, observers in self._observer.items():
for positions, observer in observers.items():
for positions, [observer, positions_tensor, element_pool] in observers.items():
body = len(positions)
element_pool = index_tensor_element(observer)
positions_configuration = tuple(configuration_cpu[l1l2o].item() for l1l2o in positions)
positions_configuration = tuple(configuration_cpu[
positions_tensor[:, 0],
positions_tensor[:, 1],
positions_tensor[:, 2],
].tolist())
if positions_configuration not in element_pool:
continue
for positions_configuration_s, item in element_pool[positions_configuration].items():
for positions_configuration_s, [
item, tensor_positions_configuration_s, [base_sign, fermi_sign, fermi_sign_s]
] in element_pool[positions_configuration].items():
configuration_cpu_s = configuration_cpu.clone()
for l1l2o, value in zip(positions, positions_configuration_s):
configuration_cpu_s[l1l2o] = value
configuration_cpu_s[
positions_tensor[:, 0],
positions_tensor[:, 1],
positions_tensor[:, 2],
] = tensor_positions_configuration_s
parity_s = torch.gather(self.owner.op_pool, 3, configuration_cpu_s.unsqueeze(-1))
# self.owner(configuration_s) to be multiplied
value = item * self._fermi_sign(configuration_cpu, configuration_cpu_s,
positions) / amplitude.conj()
total_parity = ((parity[fermi_sign[:, 0], fermi_sign[:, 1], fermi_sign[:, 2]].sum() % 2 != 0) ^
(parity_s[fermi_sign_s[:, 0], fermi_sign_s[:, 1], fermi_sign_s[:, 2]].sum() % 2
!= 0) ^ base_sign)
value = item * (-1 if total_parity else +1) / amplitude.conj()
if torch.equal(configuration_cpu_s, configuration_cpu):
result[batch_index][name][positions] += amplitude.conj().item() * complex(value)
whole_result[batch_index][name] += amplitude.conj().item() * complex(value)
Expand All @@ -487,7 +549,7 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
self._total_weight_square += multiplicity * reweight * reweight
self._total_log_ws += multiplicity * amplitude.abs().log().item()

for name, observers in self._observer.items():
for name, _ in self._observer.items():
# for positions in observers:
# to_save = result[batch_index][name][positions].real
# self._result_reweight[name][positions] += multiplicity * to_save * reweight
Expand Down
5 changes: 4 additions & 1 deletion tetragono/tetragono/sampling_neural_state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def index_tensor_element(tensor, pool={}):
new_pack = result[new_x] = {}
for y, tensor_shrinked in pack.items():
new_y = tuple(tensor.edge_by_name(f"O{rank}").index_by_point(config) for rank, config in enumerate(y))
new_pack[new_y] = tensor_shrinked.transpose(standard_names).storage[0].item()
new_pack[new_y] = (
tensor_shrinked.transpose(standard_names).storage[0].item(),
torch.tensor(new_y, device=torch.device("cpu")),
)
pool[tensor_id] = result
return pool[tensor_id]

Expand Down

0 comments on commit 8109dc5

Please sign in to comment.