Skip to content

Commit

Permalink
Improved plotting.
Browse files Browse the repository at this point in the history
  • Loading branch information
hpparvi committed Jun 17, 2024
1 parent 68fdde6 commit 1518561
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions spright/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __repr__(self):
s = f" T(m={p[0]:3.2f}, σ={p[1]:3.2f}, λ={p[2]:3.2f})"
return c + s

def _fit_kde(self) -> tuple[ndarray, ndarray]:
return az.kde(self.samples, adaptive=True)
def _fit_kde(self, bw_fct: float = 1) -> tuple[ndarray, ndarray]:
return az.kde(self.samples, adaptive=True, bw_fct=bw_fct)

def _fit_distribution(self, m1: float, m2: Optional[float]) -> tuple[Callable, ndarray, float, Optional[float]]:
if m2 is None:
Expand Down Expand Up @@ -103,10 +103,10 @@ def minfun(pv):
self.model, self.model_pars, self._m1, self._m2 = dmodel, res.x, res.x[1], res.x[4]
return dmodel, res.x, res.x[1], res.x[4]

def plot(self, plot_model: bool = True, plot_modes: bool = True, ax = None):
def plot(self, plot_model: bool = True, plot_modes: bool = True, ax = None, bw_fct: float = 1):
plot_model &= self.model is not None
ps = percentile(self.samples, [50, 16, 84, 2.5, 97.5])
x, y = self._fit_kde()
x, y = self._fit_kde(bw_fct=bw_fct)
il, iu = argmin(abs(x - ps[1])), argmin(abs(x - ps[2]))
ill, iuu = argmin(abs(x - ps[3])), argmin(abs(x - ps[4]))

Expand Down Expand Up @@ -134,4 +134,5 @@ def plot(self, plot_model: bool = True, plot_modes: bool = True, ax = None):

setp(ax, ylabel='Posterior probability', xlabel=xlabel, yticks=[], xlim=percentile(self.samples, [1, 99]))
if fig is not None:
fig.tight_layout()
fig.tight_layout()
return ax
4 changes: 2 additions & 2 deletions spright/relationmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def sample_separated(self, v: tuple[float, float], pvs: ndarray, relation: str,
return vs[m], samples[m]


def plot_map(self, ax=None, cm=None):
def plot_map(self, ax=None, cm=None, norm=None):
"""
Plots the data map.
Expand All @@ -267,7 +267,7 @@ def plot_map(self, ax=None, cm=None):
"""
if ax is None:
fig, ax = subplots()
ax.imshow(self._pmapf.T, origin='lower', aspect='auto', cmap=cm,
ax.imshow(self._pmapc.T, origin='lower', aspect='auto', cmap=cm, norm=norm, interpolation='bicubic',
extent=(self.x[0], self.x[-1], self.y[0], self.y[-1]))
setp(ax, xlabel=f"{self.xname} [{self.xunit}]", ylabel=f"{self.yname} [{self.yunit}]")

Expand Down

0 comments on commit 1518561

Please sign in to comment.