From cc8ba5138180d0fef0ba6596e0e43cdbe15cbc0b Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Thu, 3 Aug 2023 14:07:51 +0800 Subject: [PATCH] feat: only_left_angle to reduce memory cost --- tf_pwa/cal_angle.py | 30 +++++++++++- tf_pwa/config_loader/data.py | 7 ++- tf_pwa/data.py | 78 ++++++++++++++++++++++---------- tf_pwa/tests/config_lazycall.yml | 1 + 4 files changed, 89 insertions(+), 27 deletions(-) diff --git a/tf_pwa/cal_angle.py b/tf_pwa/cal_angle.py index 0e74b1bb..b7113912 100644 --- a/tf_pwa/cal_angle.py +++ b/tf_pwa/cal_angle.py @@ -56,6 +56,7 @@ from .angle import SU2M, EulerAngle, LorentzVector, Vector3, _epsilon from .config import get_config from .data import ( + HeavyCall, LazyCall, data_index, data_merge, @@ -404,6 +405,7 @@ def cal_angle_from_particle( r_boost=True, final_rest=True, align_ref=None, # "center_mass", + only_left_angle=False, ): """ Calculate helicity angle for particle momentum, add aligned angle. @@ -473,6 +475,10 @@ def cal_angle_from_particle( # ang = AlignmentAngle.angle_px_px(z1, x1, z2, x2) part_data[i]["aligned_angle"] = ang ret = data_strip(decay_data, ["r_matrix", "b_matrix", "x", "z"]) + if only_left_angle: + for i in ret: + for j in ret[i]: + del ret[i][j][j.outs[1]]["ang"] return ret @@ -628,6 +634,7 @@ def cal_angle_from_momentum_base( random_z=False, batch=65000, align_ref=None, + only_left_angle=False, ) -> CalAngleData: """ Transform 4-momentum data in files for the amplitude model automatically via DecayGroup. @@ -645,6 +652,7 @@ def cal_angle_from_momentum_base( r_boost, random_z, align_ref=align_ref, + only_left_angle=only_left_angle, ) ret = [] for i in split_generator(p, batch): @@ -657,6 +665,7 @@ def cal_angle_from_momentum_base( r_boost, random_z, align_ref=align_ref, + only_left_angle=only_left_angle, ) ) return data_merge(*ret) @@ -706,11 +715,20 @@ def cal_angle_from_momentum_id_swap( random_z=False, batch=65000, align_ref=None, + only_left_angle=False, ) -> CalAngleData: ret = [] id_particles = decs.identical_particles data = cal_angle_from_momentum_base( - p, decs, using_topology, center_mass, r_boost, random_z, batch + p, + decs, + using_topology, + center_mass, + r_boost, + random_z, + batch, + align_ref=align_ref, + only_left_angle=only_left_angle, ) if id_particles is None or len(id_particles) == 0: return data @@ -726,6 +744,7 @@ def cal_angle_from_momentum_id_swap( random_z, batch, align_ref=align_ref, + only_left_angle=only_left_angle, ) return data @@ -739,6 +758,7 @@ def cal_angle_from_momentum( random_z=False, batch=65000, align_ref=None, + only_left_angle=False, ) -> CalAngleData: """ Transform 4-momentum data in files for the amplitude model automatically via DecayGroup. @@ -749,13 +769,15 @@ def cal_angle_from_momentum( """ if isinstance(p, LazyCall): return LazyCall( - cal_angle_from_momentum, + HeavyCall(cal_angle_from_momentum), p, decs=decs, using_topology=using_topology, center_mass=center_mass, r_boost=r_boost, random_z=random_z, + align_ref=align_ref, + only_left_angle=only_left_angle, batch=batch, ) ret = [] @@ -770,6 +792,7 @@ def cal_angle_from_momentum( random_z, batch, align_ref=align_ref, + only_left_angle=only_left_angle, ) if cp_particles is None or len(cp_particles) == 0: return data @@ -784,6 +807,7 @@ def cal_angle_from_momentum( random_z, batch, align_ref=align_ref, + only_left_angle=only_left_angle, ) return data @@ -796,6 +820,7 @@ def cal_angle_from_momentum_single( r_boost=True, random_z=True, align_ref=None, + only_left_angle=False, ) -> CalAngleData: """ Transform 4-momentum data in files for the amplitude model automatically via DecayGroup. @@ -823,6 +848,7 @@ def cal_angle_from_momentum_single( r_boost=r_boost, random_z=random_z, align_ref=align_ref, + only_left_angle=only_left_angle, ) data = {"particle": data_p, "decay": data_d} add_relative_momentum(data) diff --git a/tf_pwa/config_loader/data.py b/tf_pwa/config_loader/data.py index 7c07bead..462f6b91 100644 --- a/tf_pwa/config_loader/data.py +++ b/tf_pwa/config_loader/data.py @@ -139,8 +139,9 @@ def process_scale(self, idx, data): def set_lazy_call(self, data, idx): if isinstance(data, LazyCall): - data.name = idx - data.cached_file = self.dic.get("cached_lazy_call", None) + name = idx + cached_file = self.dic.get("cached_lazy_call", None) + data.set_cached_file(cached_file, name) def get_n_data(self): data = self.get_data("data") @@ -162,6 +163,7 @@ def cal_angle(self, p4, charge=None): r_boost = self.dic.get("r_boost", True) random_z = self.dic.get("random_z", True) align_ref = self.dic.get("align_ref", None) + only_left_angle = self.dic.get("only_left_angle", False) data = cal_angle_from_momentum( p4, self.decay_struct, @@ -169,6 +171,7 @@ def cal_angle(self, p4, charge=None): r_boost=r_boost, random_z=random_z, align_ref=align_ref, + only_left_angle=only_left_angle, ) if charge is not None: data["charge_conjugation"] = charge diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 63b70c11..745f670e 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -69,6 +69,14 @@ from collections import Iterable +class HeavyCall: + def __init__(self, f): + self.f = f + + def __call__(self, *args, **kwargs): + return self.f(*args, **kwargs) + + class LazyCall: def __init__(self, f, x, *args, **kwargs): self.f = f @@ -76,55 +84,80 @@ def __init__(self, f, x, *args, **kwargs): self.args = args self.kwargs = kwargs self.extra = {} + self.batch_size = None self.cached_batch = {} self.cached_file = None self.name = "" - self.version = 0 - - def batch(self, batch, axis): - for i, j in zip( - data_split(self.x, batch, axis=axis), - data_split(self.extra, batch, axis=axis), - ): - ret = LazyCall(self.f, i, *self.args, **self.kwargs) - for k, v in j.items(): - ret[k] = v - yield ret + + def batch(self, batch, axis=0): + return self.as_dataset(batch) + + def __iter__(self): + assert self.batch_size is not None, "" + if isinstance(self.f, HeavyCall): + for i, j in zip( + self.cached_batch[self.batch_size], + split_generator(self.extra, self.batch_size), + ): + yield {**i, **j} + elif isinstance(self.x, LazyCall): + for i, j in zip( + self.x, split_generator(self.extra, self.batch_size) + ): + yield {**i, **j} + else: + for i, j in zip( + split_generator(self.x, self.batch_size), + split_generator(self.extra, self.batch_size), + ): + yield {**i, **j} def as_dataset(self, batch=65000): + self.batch_size = batch + if isinstance(self.x, LazyCall): + self.x.as_dataset(batch) + + if not isinstance(self.f, HeavyCall): + return self + if batch in self.cached_batch: - return self.cached_batch[batch] + return self def f(x): - x_a = x["x"] - extra = x["extra"] - ret = self.f(x_a, *self.args, **self.kwargs) - return {**ret, **extra} + ret = self.f(x, *self.args, **self.kwargs) + return ret if isinstance(self.x, LazyCall): real_x = self.x.eval() else: real_x = self.x - data = tf.data.Dataset.from_tensor_slices( - {"x": real_x, "extra": self.extra} - ) + data = tf.data.Dataset.from_tensor_slices(real_x) # data = data.batch(batch).cache().map(f) if self.cached_file is not None: from tf_pwa.utils import create_dir - cached_file = self.cached_file + self.name + str(self.version) + cached_file = self.cached_file + self.name cached_file += "_" + str(batch) create_dir(cached_file) data = data.batch(batch).map(f) - data = data.cache(cached_file) + if self.cached_file == "": + data = data.cache() + else: + data = data.cache(cached_file) else: data = data.batch(batch).cache().map(f) data = data.prefetch(tf.data.AUTOTUNE) self.cached_batch[batch] = data - return data + return self + + def set_cached_file(self, cached_file, name): + if isinstance(self.x, LazyCall): + self.x.set_cached_file(cached_file, name) + self.cached_file = cached_file + self.name = name def merge(self, *other, axis=0): all_x = [self.x] @@ -166,7 +199,6 @@ def copy(self): ret.extra = self.extra.copy() ret.cached_file = self.cached_file ret.name = self.name - ret.version += self.version + 1 return ret def eval(self): diff --git a/tf_pwa/tests/config_lazycall.yml b/tf_pwa/tests/config_lazycall.yml index 735a889d..490e43e8 100644 --- a/tf_pwa/tests/config_lazycall.yml +++ b/tf_pwa/tests/config_lazycall.yml @@ -10,6 +10,7 @@ data: use_tf_function: True no_id_cached: True jit_compile: True + only_left_angle: True cached_lazy_call: toy_data/cached/ decay: