Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support for large data #88

Merged
merged 17 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def _get_cg_matrix(self, ls): # CG factor inside H
lambda_b - lambda_c,
)
)
return tf.convert_to_tensor(ret)
return ret

def get_helicity_amp(self, data, data_p, **kwargs):
m_dep = self.get_ls_amp(data, data_p, **kwargs)
Expand Down Expand Up @@ -1622,6 +1622,8 @@ def add_used_chains(self, used_chains):
self.chains_idx.append(i)

def set_used_chains(self, used_chains):
if isinstance(used_chains, str):
used_chains = [used_chains]
self.chains_idx = list(used_chains)
if len(self.chains_idx) != len(self.chains):
self.not_full = True
Expand Down Expand Up @@ -1704,10 +1706,18 @@ def value_and_grad(f, var):

class AmplitudeModel(object):
def __init__(
self, decay_group, name="", polar=None, vm=None, use_tf_function=False
self,
decay_group,
name="",
polar=None,
vm=None,
use_tf_function=False,
no_id_cached=False,
jit_compile=False,
):
self.decay_group = decay_group
self._name = name
self.no_id_cached = no_id_cached
with variable_scope(vm) as vm:
if polar is not None:
vm.polar = polar
Expand All @@ -1720,7 +1730,9 @@ def __init__(
if use_tf_function:
from tf_pwa.experimental.wrap_function import WrapFun

self.cached_fun = WrapFun(self.decay_group.sum_amp)
self.cached_fun = WrapFun(
self.decay_group.sum_amp, jit_compile=jit_compile
)
else:
self.cached_fun = self.decay_group.sum_amp

Expand Down Expand Up @@ -1783,7 +1795,7 @@ def trainable_variables(self):
def __call__(self, data, cached=False):
if isinstance(data, LazyCall):
data = data.eval()
if id(data) in self.f_data:
if id(data) in self.f_data or self.no_id_cached:
if not self.decay_group.not_full:
return self.cached_fun(data)
else:
Expand Down
37 changes: 31 additions & 6 deletions tf_pwa/cal_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .angle import SU2M, EulerAngle, LorentzVector, Vector3, _epsilon
from .config import get_config
from .data import (
HeavyCall,
LazyCall,
data_index,
data_merge,
Expand Down Expand Up @@ -261,8 +262,8 @@ def cal_single_boost(data, decay_chain: DecayChain) -> dict:
def cal_helicity_angle(
data: dict,
decay_chain: DecayChain,
base_z=np.array([[0.0, 0.0, 1.0]]),
base_x=np.array([[1.0, 0.0, 0.0]]),
base_z=np.array([0.0, 0.0, 1.0]),
base_x=np.array([1.0, 0.0, 0.0]),
) -> dict:
"""
Calculate helicity angle for A -> B + C: :math:`\\theta_{B}^{A}, \\phi_{B}^{A}` from momentum.
Expand All @@ -276,7 +277,6 @@ def cal_helicity_angle(

# print(decay_chain, part_data)
part_data = cal_chain_boost(data, decay_chain)
# print(decay_chain , part_data)
# calculate angle and base x,z axis from mother particle rest frame momentum and base axis
set_x = {decay_chain.top: base_x}
set_z = {decay_chain.top: base_z}
Expand Down Expand Up @@ -405,6 +405,7 @@ def cal_angle_from_particle(
r_boost=True,
final_rest=True,
align_ref=None, # "center_mass",
only_left_angle=False,
):
"""
Calculate helicity angle for particle momentum, add aligned angle.
Expand All @@ -422,7 +423,7 @@ def cal_angle_from_particle(
# get base z axis
p4 = data[decay_group.top]["p"]
p3 = LorentzVector.vect(p4)
base_z = np.array([[0.0, 0.0, 1.0]]) + tf.zeros_like(p3)
base_z = np.array([0.0, 0.0, 1.0]) + tf.zeros_like(p3)
if random_z:
p3_norm = Vector3.norm(p3)
mask = tf.expand_dims(p3_norm < 1e-5, -1)
Expand Down Expand Up @@ -474,6 +475,10 @@ def cal_angle_from_particle(
# ang = AlignmentAngle.angle_px_px(z1, x1, z2, x2)
part_data[i]["aligned_angle"] = ang
ret = data_strip(decay_data, ["r_matrix", "b_matrix", "x", "z"])
if only_left_angle:
for i in ret:
for j in ret[i]:
del ret[i][j][j.outs[1]]["ang"]
return ret


Expand Down Expand Up @@ -629,6 +634,7 @@ def cal_angle_from_momentum_base(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -646,6 +652,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
ret = []
for i in split_generator(p, batch):
Expand All @@ -658,6 +665,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
)
return data_merge(*ret)
Expand Down Expand Up @@ -707,11 +715,20 @@ def cal_angle_from_momentum_id_swap(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
ret = []
id_particles = decs.identical_particles
data = cal_angle_from_momentum_base(
p, decs, using_topology, center_mass, r_boost, random_z, batch
p,
decs,
using_topology,
center_mass,
r_boost,
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if id_particles is None or len(id_particles) == 0:
return data
Expand All @@ -727,6 +744,7 @@ def cal_angle_from_momentum_id_swap(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -740,6 +758,7 @@ def cal_angle_from_momentum(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -750,13 +769,15 @@ def cal_angle_from_momentum(
"""
if isinstance(p, LazyCall):
return LazyCall(
cal_angle_from_momentum,
HeavyCall(cal_angle_from_momentum),
p,
decs=decs,
using_topology=using_topology,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
batch=batch,
)
ret = []
Expand All @@ -771,6 +792,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if cp_particles is None or len(cp_particles) == 0:
return data
Expand All @@ -785,6 +807,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -797,6 +820,7 @@ def cal_angle_from_momentum_single(
r_boost=True,
random_z=True,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand Down Expand Up @@ -824,6 +848,7 @@ def cal_angle_from_momentum_single(
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
data = {"particle": data_p, "decay": data_d}
add_relative_momentum(data)
Expand Down
21 changes: 16 additions & 5 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,23 @@ def get_decay(self, full=True):

@functools.lru_cache()
def get_amplitude(self, vm=None, name=""):
use_tf_function = self.config.get("data", {}).get(
"use_tf_function", False
)
amp_config = self.config.get("data", {})
use_tf_function = amp_config.get("use_tf_function", False)
no_id_cached = amp_config.get("no_id_cached", False)
jit_compile = amp_config.get("jit_compile", False)
decay_group = self.full_decay
self.check_valid_jp(decay_group)
if vm is None:
vm = self.vm
if vm in self.amps:
return self.amps[vm]
amp = AmplitudeModel(
decay_group, vm=vm, name=name, use_tf_function=use_tf_function
decay_group,
vm=vm,
name=name,
use_tf_function=use_tf_function,
no_id_cached=no_id_cached,
jit_compile=jit_compile,
)
self.add_constraints(amp)
self.amps[vm] = amp
Expand Down Expand Up @@ -561,6 +567,7 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
bg = [None] * self._Ngroup
model = self._get_model(vm=vm, name=name)
fcns = []

# print(self.config["data"].get("using_mix_likelihood", False))
if self.config["data"].get("using_mix_likelihood", False):
print(" Using Mix Likelihood")
Expand All @@ -575,7 +582,9 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
if all_data is None:
self.cached_fcn[vm] = fcn
return fcn
for md, dt, mc, sb, ij in zip(model, data, phsp, bg, inmc):
for idx, (md, dt, mc, sb, ij) in enumerate(
zip(model, data, phsp, bg, inmc)
):
if self.config["data"].get("model", "auto") == "cfit":
fcns.append(
FCN(
Expand Down Expand Up @@ -644,6 +653,7 @@ def fit(
maxiter=None,
jac=True,
print_init_nll=True,
callback=None,
):
if data is None and phsp is None:
data, phsp, bg, inmc = self.get_all_data()
Expand Down Expand Up @@ -677,6 +687,7 @@ def fit(
improve=False,
maxiter=maxiter,
jac=jac,
callback=callback,
)
if self.fit_params.hess_inv is not None:
self.inv_he = self.fit_params.hess_inv
Expand Down
46 changes: 31 additions & 15 deletions tf_pwa/config_loader/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def get_data(self, idx) -> dict:
weight_sign = self.get_weight_sign(idx)
charge = self.dic.get(idx + "_charge", None)
ret = self.load_data(files, weights, weight_sign, charge)
return self.process_scale(idx, ret)
ret = self.process_scale(idx, ret)
return ret

def process_scale(self, idx, data):
if idx in self.scale_list and self.dic.get("weight_scale", False):
Expand All @@ -136,6 +137,12 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
    """Attach the configured on-disk cache location to *data*.

    Only ``LazyCall`` objects support deferred/cached evaluation; any
    other data type is left untouched.  The cache file path is read from
    the ``cached_lazy_call`` entry of the data config (``None`` when the
    option is absent, which disables file caching), and *idx* (e.g.
    ``"data"``, ``"phsp"``) is used as the sub-name inside the cache.
    """
    if not isinstance(data, LazyCall):
        return
    data.set_cached_file(self.dic.get("cached_lazy_call", None), idx)

def get_n_data(self):
data = self.get_data("data")
weight = data.get("weight", np.ones((data_shape(data),)))
Expand All @@ -156,13 +163,15 @@ def cal_angle(self, p4, charge=None):
r_boost = self.dic.get("r_boost", True)
random_z = self.dic.get("random_z", True)
align_ref = self.dic.get("align_ref", None)
only_left_angle = self.dic.get("only_left_angle", False)
data = cal_angle_from_momentum(
p4,
self.decay_struct,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if charge is not None:
data["charge_conjugation"] = charge
Expand All @@ -185,18 +194,17 @@ def load_data(
p4 = self.load_p4(files)
charges = None if charges is None else charges[: data_shape(p4)]
data = self.cal_angle(p4, charges)
if weights is not None:
if isinstance(weights, float):
data["weight"] = np.array(
[weights * weights_sign] * data_shape(data)
)
elif isinstance(weights, str): # weight files
weight = self.load_weight_file(weights)
data["weight"] = weight[: data_shape(data)] * weights_sign
else:
raise TypeError(
"weight format error: {}".format(type(weights))
)
if weights is None:
data["weight"] = np.array([1.0 * weights_sign] * data_shape(data))
elif isinstance(weights, float):
data["weight"] = np.array(
[weights * weights_sign] * data_shape(data)
)
elif isinstance(weights, str): # weight files
weight = self.load_weight_file(weights)
data["weight"] = weight[: data_shape(data)] * weights_sign
else:
raise TypeError("weight format error: {}".format(type(weights)))

if charge is None:
data["charge_conjugation"] = tf.ones((data_shape(data),))
Expand Down Expand Up @@ -322,8 +330,11 @@ def savetxt(self, file_name, data):
else:
raise ValueError("not support data")
p4 = data_to_numpy(p4)
p4 = np.stack(p4).transpose((1, 0, 2)).reshape((-1, 4))
np.savetxt(file_name, p4)
p4 = np.stack(p4).transpose((1, 0, 2))
if file_name.endswith("npy"):
np.save(file_name, p4)
else:
np.savetxt(file_name, p4.reshape((-1, 4)))


@register_data_mode("multi")
Expand All @@ -342,6 +353,10 @@ def process_scale(self, idx, data):
)
return data

def set_lazy_call(self, data, idx):
    """Attach cache locations to every sub-dataset of a multi-data group.

    Delegates to the single-dataset implementation, tagging the *i*-th
    sub-dataset as ``"s{i}{idx}"`` so that cached files of different
    sub-datasets cannot collide.
    """
    for pos, sub_data in enumerate(data):
        super().set_lazy_call(sub_data, "s{}{}".format(pos, idx))

def get_n_data(self):
data = self.get_data("data")
weight = [
Expand Down Expand Up @@ -405,6 +420,7 @@ def get_data(self, idx) -> list:
data_shape(k)
)
ret = self.process_scale(idx, ret)
self.set_lazy_call(ret, idx)
return ret

def get_phsp_noeff(self):
Expand Down
Loading