Skip to content

Commit

Permalink
feat: only_left_angle to reduce memory cost
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Aug 3, 2023
1 parent cf44fe4 commit cc8ba51
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 27 deletions.
30 changes: 28 additions & 2 deletions tf_pwa/cal_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .angle import SU2M, EulerAngle, LorentzVector, Vector3, _epsilon
from .config import get_config
from .data import (
HeavyCall,
LazyCall,
data_index,
data_merge,
Expand Down Expand Up @@ -404,6 +405,7 @@ def cal_angle_from_particle(
r_boost=True,
final_rest=True,
align_ref=None, # "center_mass",
only_left_angle=False,
):
"""
Calculate helicity angle for particle momentum, add aligned angle.
Expand Down Expand Up @@ -473,6 +475,10 @@ def cal_angle_from_particle(
# ang = AlignmentAngle.angle_px_px(z1, x1, z2, x2)
part_data[i]["aligned_angle"] = ang
ret = data_strip(decay_data, ["r_matrix", "b_matrix", "x", "z"])
if only_left_angle:
for i in ret:
for j in ret[i]:
del ret[i][j][j.outs[1]]["ang"]

Check warning on line 481 in tf_pwa/cal_angle.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/cal_angle.py#L481

Added line #L481 was not covered by tests
return ret


Expand Down Expand Up @@ -628,6 +634,7 @@ def cal_angle_from_momentum_base(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -645,6 +652,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
ret = []
for i in split_generator(p, batch):
Expand All @@ -657,6 +665,7 @@ def cal_angle_from_momentum_base(
r_boost,
random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
)
return data_merge(*ret)
Expand Down Expand Up @@ -706,11 +715,20 @@ def cal_angle_from_momentum_id_swap(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
ret = []
id_particles = decs.identical_particles
data = cal_angle_from_momentum_base(
p, decs, using_topology, center_mass, r_boost, random_z, batch
p,
decs,
using_topology,
center_mass,
r_boost,
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if id_particles is None or len(id_particles) == 0:
return data
Expand All @@ -726,6 +744,7 @@ def cal_angle_from_momentum_id_swap(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -739,6 +758,7 @@ def cal_angle_from_momentum(
random_z=False,
batch=65000,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand All @@ -749,13 +769,15 @@ def cal_angle_from_momentum(
"""
if isinstance(p, LazyCall):
return LazyCall(
cal_angle_from_momentum,
HeavyCall(cal_angle_from_momentum),
p,
decs=decs,
using_topology=using_topology,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
batch=batch,
)
ret = []
Expand All @@ -770,6 +792,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if cp_particles is None or len(cp_particles) == 0:
return data
Expand All @@ -784,6 +807,7 @@ def cal_angle_from_momentum(
random_z,
batch,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
return data

Expand All @@ -796,6 +820,7 @@ def cal_angle_from_momentum_single(
r_boost=True,
random_z=True,
align_ref=None,
only_left_angle=False,
) -> CalAngleData:
"""
Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
Expand Down Expand Up @@ -823,6 +848,7 @@ def cal_angle_from_momentum_single(
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
data = {"particle": data_p, "decay": data_d}
add_relative_momentum(data)
Expand Down
7 changes: 5 additions & 2 deletions tf_pwa/config_loader/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def process_scale(self, idx, data):

def set_lazy_call(self, data, idx):
if isinstance(data, LazyCall):
data.name = idx
data.cached_file = self.dic.get("cached_lazy_call", None)
name = idx
cached_file = self.dic.get("cached_lazy_call", None)
data.set_cached_file(cached_file, name)

def get_n_data(self):
data = self.get_data("data")
Expand All @@ -162,13 +163,15 @@ def cal_angle(self, p4, charge=None):
r_boost = self.dic.get("r_boost", True)
random_z = self.dic.get("random_z", True)
align_ref = self.dic.get("align_ref", None)
only_left_angle = self.dic.get("only_left_angle", False)
data = cal_angle_from_momentum(
p4,
self.decay_struct,
center_mass=center_mass,
r_boost=r_boost,
random_z=random_z,
align_ref=align_ref,
only_left_angle=only_left_angle,
)
if charge is not None:
data["charge_conjugation"] = charge
Expand Down
78 changes: 55 additions & 23 deletions tf_pwa/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,62 +69,95 @@
from collections import Iterable


class HeavyCall:
    """Thin wrapper that tags a callable as computationally heavy.

    Wrapping a function in ``HeavyCall`` does not change its behavior;
    it only serves as a type marker so that ``LazyCall`` (via
    ``isinstance(self.f, HeavyCall)``) knows the call is expensive and
    its results are worth batching/caching through ``tf.data``.
    """

    def __init__(self, f):
        # Keep the wrapped callable on ``self.f`` so callers can
        # introspect or unwrap it.
        self.f = f

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped callable, forwarding all arguments unchanged."""
        result = self.f(*args, **kwargs)
        return result


class LazyCall:
def __init__(self, f, x, *args, **kwargs):
self.f = f
self.x = x
self.args = args
self.kwargs = kwargs
self.extra = {}
self.batch_size = None
self.cached_batch = {}
self.cached_file = None
self.name = ""
self.version = 0

def batch(self, batch, axis):
for i, j in zip(
data_split(self.x, batch, axis=axis),
data_split(self.extra, batch, axis=axis),
):
ret = LazyCall(self.f, i, *self.args, **self.kwargs)
for k, v in j.items():
ret[k] = v
yield ret

    def batch(self, batch, axis=0):
        # Yield batches by delegating to as_dataset().
        # ``axis`` is accepted for backward compatibility with the old
        # data_split-based implementation but is ignored here — batching
        # is always along the first axis of the tf.data pipeline.
        # NOTE(review): silently ignoring axis != 0 may surprise callers;
        # confirm no caller relies on a non-zero axis.
        return self.as_dataset(batch)

def __iter__(self):
assert self.batch_size is not None, ""
if isinstance(self.f, HeavyCall):
for i, j in zip(
self.cached_batch[self.batch_size],
split_generator(self.extra, self.batch_size),
):
yield {**i, **j}
elif isinstance(self.x, LazyCall):
for i, j in zip(
self.x, split_generator(self.extra, self.batch_size)
):
yield {**i, **j}

Check warning on line 107 in tf_pwa/data.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/data.py#L107

Added line #L107 was not covered by tests
else:
for i, j in zip(
split_generator(self.x, self.batch_size),
split_generator(self.extra, self.batch_size),
):
yield {**i, **j}

Check warning on line 113 in tf_pwa/data.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/data.py#L113

Added line #L113 was not covered by tests

def as_dataset(self, batch=65000):
self.batch_size = batch
if isinstance(self.x, LazyCall):
self.x.as_dataset(batch)

if not isinstance(self.f, HeavyCall):
return self

if batch in self.cached_batch:
return self.cached_batch[batch]
return self

def f(x):
x_a = x["x"]
extra = x["extra"]
ret = self.f(x_a, *self.args, **self.kwargs)
return {**ret, **extra}
ret = self.f(x, *self.args, **self.kwargs)
return ret

Check warning on line 128 in tf_pwa/data.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/data.py#L127-L128

Added lines #L127 - L128 were not covered by tests

if isinstance(self.x, LazyCall):
real_x = self.x.eval()
else:
real_x = self.x

Check warning on line 133 in tf_pwa/data.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/data.py#L133

Added line #L133 was not covered by tests

data = tf.data.Dataset.from_tensor_slices(
{"x": real_x, "extra": self.extra}
)
data = tf.data.Dataset.from_tensor_slices(real_x)
# data = data.batch(batch).cache().map(f)
if self.cached_file is not None:
from tf_pwa.utils import create_dir

cached_file = self.cached_file + self.name + str(self.version)
cached_file = self.cached_file + self.name

cached_file += "_" + str(batch)
create_dir(cached_file)
data = data.batch(batch).map(f)
data = data.cache(cached_file)
if self.cached_file == "":
data = data.cache()

Check warning on line 146 in tf_pwa/data.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/data.py#L146

Added line #L146 was not covered by tests
else:
data = data.cache(cached_file)
else:
data = data.batch(batch).cache().map(f)
data = data.prefetch(tf.data.AUTOTUNE)

self.cached_batch[batch] = data
return data
return self

    def set_cached_file(self, cached_file, name):
        """Set the on-disk cache location and cache name for this lazy call.

        Propagates recursively through chained LazyCall inputs so every
        stage of the pipeline shares the same cache settings.

        :param cached_file: base path/prefix for the tf.data cache files
            (presumably a directory prefix such as ``"toy_data/cached/"``
            — confirm against config usage), or None to disable file caching.
        :param name: identifier appended to the cache path to keep
            different datasets (data/MC/background) in separate files.
        """
        if isinstance(self.x, LazyCall):
            # Recurse first so inner stages are configured before this one.
            self.x.set_cached_file(cached_file, name)
        self.cached_file = cached_file
        self.name = name

def merge(self, *other, axis=0):
all_x = [self.x]
Expand Down Expand Up @@ -166,7 +199,6 @@ def copy(self):
ret.extra = self.extra.copy()
ret.cached_file = self.cached_file
ret.name = self.name
ret.version += self.version + 1
return ret

def eval(self):
Expand Down
1 change: 1 addition & 0 deletions tf_pwa/tests/config_lazycall.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ data:
use_tf_function: True
no_id_cached: True
jit_compile: True
only_left_angle: True
cached_lazy_call: toy_data/cached/

decay:
Expand Down

0 comments on commit cc8ba51

Please sign in to comment.