Skip to content

Commit

Permalink
feat: stepfill for Hist1D
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Feb 17, 2024
1 parent 685ffa6 commit 56ceea2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
29 changes: 17 additions & 12 deletions tf_pwa/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,17 @@ def draw(self, ax=plt, **kwargs):
ret = []
for i in draw_type.split("+"):
ret.append(self.draw(ax=ax, type=i, **kwargs))
elif draw_type == "hist":
ret = self.draw_hist(ax=ax, **kwargs)
elif draw_type == "bar":
ret = self.draw_bar(ax=ax, **kwargs)
elif draw_type == "kde":
ret = self.draw_kde(ax=ax, **kwargs)
elif draw_type == "error":
ret = self.draw_error(ax=ax, **kwargs)
elif draw_type == "line":
ret = self.draw_line(ax=ax, **kwargs)
elif draw_type == "fill":
ret = self.draw_fill(ax=ax, **kwargs)
elif draw_type in [
"hist",
"bar",
"kde",
"error",
"line",
"fill",
"stepfill",
]:
draw_fun = getattr(self, "draw_" + draw_type)
ret = draw_fun(ax=ax, **kwargs)
else:
raise NotImplementedError()
return ret
Expand Down Expand Up @@ -144,6 +143,12 @@ def draw_fill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
x, kde(x), np.zeros_like(x), color=color, **kwargs
)

def draw_stepfill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
color = kwargs.pop("color", self._cached_color)
x = np.repeat(self.binning, 2)
y = np.concatenate([[0], np.repeat(self.count, 2), [0]])
return ax.fill_between(x, y, np.zeros_like(x), color=color, **kwargs)

def draw_pull(self, ax=plt, **kwargs):
with np.errstate(divide="ignore", invalid="ignore"):
y_error = np.where(self.error == 0, 0, self.count / self.error)
Expand Down
1 change: 1 addition & 0 deletions tf_pwa/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_hist1d():
hist.draw_kde(ax, kind="gauss", color="blue")
hist.draw_kde(ax, kind="cauchy", color="red")
(hist * 0.5).draw(ax, type="fill", facecolor="none", hatch="///")
(hist * 0.4).draw_stepfill(ax, facecolor="none", hatch="\\")
(0.1 * hist + hist * 0.1).draw_bar(ax)
hist.draw_error(ax)
hist2 = Hist1D.histogram(
Expand Down

0 comments on commit 56ceea2

Please sign in to comment.