Skip to content

Commit

Permalink
Update plot_tsi
Browse files Browse the repository at this point in the history
* Update default values of `x_offset` and `y_offset` and update type
hints.
* Rename argument `fname` to `save`.
* Add arguments `figsize` and `dpi`.
* Update docstrings.
* Update figure setup.
* Fix x-axis labels caused by latest changes to `get_tsi`.
* Use `save_fig` function.
* Return matplotlib `Figure` and `Axis` instances.
  • Loading branch information
WeilerP committed Feb 29, 2024
1 parent 56f03e6 commit 5f46c8c
Showing 1 changed file with 20 additions and 29 deletions.
49 changes: 20 additions & 29 deletions src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,39 +588,35 @@ def get_tsi_score(self) -> float:

return self.tsi["Identified terminal states"].sum() / optimal_score

@d.dedent
def plot_tsi(
self,
tsi_df: pd.DataFrame,
fname: Optional[Path] = None,
x_offset: Optional[Tuple[float, float]] = None,
y_offset: Optional[Tuple[float, float]] = None,
x_offset: Tuple[float, float] = (0.2, 0.2),
y_offset: Tuple[float, float] = (0.1, 0.1),
figsize: Tuple[float, float] = (6, 4),
dpi: Optional[int] = None,
save: Optional[Union[str, Path]] = None,
**kwargs: Any,
):
) -> Tuple[plt.Figure, Axes]:
"""Plot terminal state identificiation (TSI).
Parameters
----------
tsi_df
Pre-computed TSI DataFrame.
fname
File name under which the plot is saved. The plot is not saved if the argument is not specified.
Pre-computed TSI DataFrame with :meth:`get_tsi_score`.
x_offset
Offset of x-axis. Defaults to `[0.2, 0.2]` if not specified.
Offset of x-axis.
y_offset
Offset of y-axis. Defaults to `[0.1, 0.1]` if not specified.
Offset of y-axis.
%(plotting)s
kwargs
Keyword arguments for :meth:`~seaborn.lineplot`.
Returns
-------
Returns TSI as a Pandas DataFrame and adds the class attribute :attr:`tsi`.
Plot TSI of the kernel and an optimal identification strategy.
"""
if x_offset is None:
x_offset = [0.2, 0.2]

if y_offset is None:
y_offset = [0.1, 0.1]

optimal_identification = tsi_df[["Number of macrostates", "Optimal identification"]]
optimal_identification = optimal_identification.rename(
columns={"Optimal identification": "Identified terminal states"}
Expand All @@ -634,7 +630,7 @@ def plot_tsi(

df = pd.concat([df, optimal_identification])

fig, ax = plt.subplots(figsize=(6, 4))
fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
sns.lineplot(
data=df,
x="Number of macrostates",
Expand All @@ -647,7 +643,8 @@ def plot_tsi(
)

ax.set_xticks(df["Number of macrostates"].unique().astype(int))
for label_id, label in enumerate(ax.xaxis.get_ticklabels()):
# Plot is generated from large to small values on the x-axis
for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]):
if ((label_id + 1) % 5 != 0) and label_id != 0:
label.set_visible(False)
ax.set_yticks(df["Identified terminal states"].unique())
Expand All @@ -666,17 +663,11 @@ def plot_tsi(
handles = handles[: (n_methods + 1)]
labels = labels[: (n_methods + 1)]
fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1))
plt.tight_layout()
plt.show()

if fname is not None:
format = fname.suffix[1:]
fig.savefig(
fname=fname,
format=format,
transparent=True,
bbox_inches="tight",
)

if save is not None:
save_fig(fig=fig, path=save)

return fig, ax

@d.dedent
def fit(
Expand Down

0 comments on commit 5f46c8c

Please sign in to comment.