From 5f46c8c76272df9c23e9479d1a30afcbe2613a56 Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Thu, 29 Feb 2024 11:05:47 +0000 Subject: [PATCH] Update `plot_tsi` * Update default values of `x_offset` and `y_offset` and update type hints. * Rename argument `fname` to `save`. * Add arguments `figsize` and `dpi`. * Update docstrings. * Update figure setup. * Fix x-axis labels caused by latest changes to `get_tsi`. * Use `save_fig` function. * Return matplotlib `Figure` and `Axis` instances. --- .../estimators/terminal_states/_gpcca.py | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index 113f1d239..1e81303cc 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -588,39 +588,35 @@ def get_tsi_score(self) -> float: return self.tsi["Identified terminal states"].sum() / optimal_score + @d.dedent def plot_tsi( self, tsi_df: pd.DataFrame, - fname: Optional[Path] = None, - x_offset: Optional[Tuple[float, float]] = None, - y_offset: Optional[Tuple[float, float]] = None, + x_offset: Tuple[float, float] = (0.2, 0.2), + y_offset: Tuple[float, float] = (0.1, 0.1), + figsize: Tuple[float, float] = (6, 4), + dpi: Optional[int] = None, + save: Optional[Union[str, Path]] = None, **kwargs: Any, - ): + ) -> Tuple[plt.Figure, Axes]: """Plot terminal state identificiation (TSI). Parameters ---------- tsi_df - Pre-computed TSI DataFrame. - fname - File name under which the plot is saved. The plot is not saved if the argument is not specified. + Pre-computed TSI DataFrame with :meth:`get_tsi_score`. x_offset - Offset of x-axis. Defaults to `[0.2, 0.2]` if not specified. + Offset of x-axis. y_offset - Offset of y-axis. Defaults to `[0.1, 0.1]` if not specified. + Offset of y-axis. + %(plotting)s kwargs Keyword arguments for :meth:`~seaborn.lineplot`. Returns ------- - Returns TSI as a Pandas DataFrame and adds the class attribute :attr:`tsi`. + Plot TSI of the kernel and an optimal identification strategy. """ - if x_offset is None: - x_offset = [0.2, 0.2] - - if y_offset is None: - y_offset = [0.1, 0.1] - optimal_identification = tsi_df[["Number of macrostates", "Optimal identification"]] optimal_identification = optimal_identification.rename( columns={"Optimal identification": "Identified terminal states"} @@ -634,7 +630,7 @@ def plot_tsi( df = pd.concat([df, optimal_identification]) - fig, ax = plt.subplots(figsize=(6, 4)) + fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True) sns.lineplot( data=df, x="Number of macrostates", @@ -647,7 +643,8 @@ def plot_tsi( ) ax.set_xticks(df["Number of macrostates"].unique().astype(int)) - for label_id, label in enumerate(ax.xaxis.get_ticklabels()): + # Plot is generated from large to small values on the x-axis + for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]): if ((label_id + 1) % 5 != 0) and label_id != 0: label.set_visible(False) ax.set_yticks(df["Identified terminal states"].unique()) @@ -666,17 +663,11 @@ def plot_tsi( handles = handles[: (n_methods + 1)] labels = labels[: (n_methods + 1)] fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1)) - plt.tight_layout() - plt.show() - - if fname is not None: - format = fname.suffix[1:] - fig.savefig( - fname=fname, - format=format, - transparent=True, - bbox_inches="tight", - ) + + if save is not None: + save_fig(fig=fig, path=save) + + return fig, ax @d.dedent def fit(