Skip to content

Commit

Permalink
Merge pull request #140 from jiangyi15/histogram_fill
Browse files Browse the repository at this point in the history
Histogram fill
  • Loading branch information
jiangyi15 authored Feb 18, 2024
2 parents ac14967 + 56ceea2 commit 8225a1b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 55 deletions.
37 changes: 19 additions & 18 deletions tf_pwa/amp/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class InterpLinearNpy(InterpolationParticle):
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".npy")
>>> m = np.linspace(0.2, 0.9)
>>> m = np.linspace(0.2, 0.9, 51)
>>> mi = m[::5]
>>> np.save(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_npy", {"file": a})
Expand All @@ -113,11 +113,16 @@ class InterpLinearNpy(InterpolationParticle):
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.load(self.input_file)
self.data = self.get_data(**kwargs)
points = self.data[:, 0]
kwargs["points"] = points
super().__init__(*args, **kwargs)
kwargs["points"] = points.tolist()
super(InterpLinearNpy, self).__init__(*args, **kwargs)
self.delta = np.concatenate([[1e20], points[1:] - points[:-1], [1e20]])
self.x_left = np.concatenate([[points[-1] - 1], points])

def get_data(self, **kwargs):
self.input_file = kwargs.get("file")
return np.load(self.input_file)

def init_params(self):
pass
Expand All @@ -129,20 +134,19 @@ def get_point_values(self):

def interp(self, m):
x, p_r, p_i = self.get_point_values()
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x)
bin_idx = (bin_idx) % (len(self.bound) + 1)
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=self.points)
ret_r_r = tf.gather(p_r[1:], bin_idx)
ret_i_r = tf.gather(p_i[1:], bin_idx)
ret_r_l = tf.gather(p_r[:-1], bin_idx)
ret_i_l = tf.gather(p_i[:-1], bin_idx)
delta = np.concatenate([[1e20], x[1:] - x[:-1], [1e20]])
x_left = np.concatenate([[x[0] - 1], x])
delta = tf.gather(delta, bin_idx)
x_left = tf.gather(x_left, bin_idx)
delta = tf.gather(self.delta, bin_idx)
x_left = tf.gather(self.x_left, bin_idx)
step = (m - x_left) / delta
a = step * (ret_r_r - ret_r_l) + ret_r_l
b = step * (ret_i_r - ret_i_l) + ret_i_l
return tf.complex(a, b)
ret = tf.complex(a, b)
cut = (bin_idx <= 0) | (bin_idx >= self.delta.shape[-1] - 1)
return tf.where(cut, tf.zeros_like(ret), ret)


@register_particle("linear_txt")
Expand All @@ -158,20 +162,17 @@ class InterpLinearTxt(InterpLinearNpy):
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".txt")
>>> m = np.linspace(0.2, 0.9)
>>> m = np.linspace(0.2, 0.9, 51)
>>> mi = m[::5]
>>> np.savetxt(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_txt", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")
"""

def __init__(self, *args, **kwargs):
def get_data(self, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.loadtxt(self.input_file)
points = self.data[:, 0]
kwargs["points"] = points
super(InterpLinearNpy, self).__init__(*args, **kwargs)
return np.loadtxt(self.input_file)


@register_particle("interp")
Expand Down
6 changes: 5 additions & 1 deletion tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,13 @@ def set_params(self, params, neglect_params=None):
if neglect_params is None:
neglect_params = self._neglect_when_set_params
if len(neglect_params) != 0:
# warnings.warn("Neglect {} when setting params.".format(neglect_params))
for v in params:
if v in self._neglect_when_set_params:
warnings.warn(
"Neglect {} when setting params.".format(
neglect_params
)
)
del ret[v]
amplitude.set_params(ret)
return True
Expand Down
27 changes: 1 addition & 26 deletions tf_pwa/config_loader/multi_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,32 +227,7 @@ def get_params(self, trainable_only=False):

def set_params(self, params, neglect_params=None):
_amps = self.get_amplitudes()
if isinstance(params, str):
if params == "":
return False
try:
with open(params) as f:
params = yaml.safe_load(f)
except Exception as e:
print(e)
return False
if hasattr(params, "params"):
params = params.params
if isinstance(params, dict):
if "value" in params:
params = params["value"]
ret = params.copy()
if neglect_params is None:
neglect_params = self._neglect_when_set_params
if len(neglect_params) != 0:
warnings.warn(
"Neglect {} when setting params.".format(neglect_params)
)
for v in params:
if v in self._neglect_when_set_params:
del ret[v]
self.vm.set_all(ret)
return True
self.configs[0].set_params(params, neglect_params=neglect_params)

@contextlib.contextmanager
def params_trans(self):
Expand Down
39 changes: 29 additions & 10 deletions tf_pwa/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,17 @@ def draw(self, ax=plt, **kwargs):
ret = []
for i in draw_type.split("+"):
ret.append(self.draw(ax=ax, type=i, **kwargs))
elif draw_type == "hist":
ret = self.draw_hist(ax=ax, **kwargs)
elif draw_type == "bar":
ret = self.draw_bar(ax=ax, **kwargs)
elif draw_type == "kde":
ret = self.draw_kde(ax=ax, **kwargs)
elif draw_type == "error":
ret = self.draw_error(ax=ax, **kwargs)
elif draw_type == "line":
ret = self.draw_line(ax=ax, **kwargs)
elif draw_type in [
"hist",
"bar",
"kde",
"error",
"line",
"fill",
"stepfill",
]:
draw_fun = getattr(self, "draw_" + draw_type)
ret = draw_fun(ax=ax, **kwargs)
else:
raise NotImplementedError()
return ret
Expand Down Expand Up @@ -130,6 +131,24 @@ def draw_kde(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
else:
return ax.plot(x, kde(x), color=color, **kwargs)

def draw_fill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
color = kwargs.pop("color", self._cached_color)
m = self.bin_center
bw = self.bin_width * bin_scale
kde = weighted_kde(m, self.count, bw, kind)
x = np.linspace(
self.binning[0], self.binning[-1], self.count.shape[0] * 10
)
return ax.fill_between(
x, kde(x), np.zeros_like(x), color=color, **kwargs
)

def draw_stepfill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
color = kwargs.pop("color", self._cached_color)
x = np.repeat(self.binning, 2)
y = np.concatenate([[0], np.repeat(self.count, 2), [0]])
return ax.fill_between(x, y, np.zeros_like(x), color=color, **kwargs)

def draw_pull(self, ax=plt, **kwargs):
with np.errstate(divide="ignore", invalid="ignore"):
y_error = np.where(self.error == 0, 0, self.count / self.error)
Expand Down
2 changes: 2 additions & 0 deletions tf_pwa/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_hist1d():
hist.draw(type="line+bar")
hist.draw_kde(ax, kind="gauss", color="blue")
hist.draw_kde(ax, kind="cauchy", color="red")
(hist * 0.5).draw(ax, type="fill", facecolor="none", hatch="///")
(hist * 0.4).draw_stepfill(ax, facecolor="none", hatch="\\")
(0.1 * hist + hist * 0.1).draw_bar(ax)
hist.draw_error(ax)
hist2 = Hist1D.histogram(
Expand Down

0 comments on commit 8225a1b

Please sign in to comment.