Skip to content

Commit

Permalink
fix(plots): organize reporting function for use in summarization task
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Sep 5, 2024
1 parent 03eea70 commit c5a284d
Showing 1 changed file with 137 additions and 161 deletions.
298 changes: 137 additions & 161 deletions src/pyrovelocity/plots/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import dill
import matplotlib.pyplot as plt
from beartype.typing import Dict, List, Optional, Tuple
from anndata import AnnData
from beartype.typing import Any, Dict, List, Optional, Tuple
from matplotlib.axes import Axes
from matplotlib.figure import FigureBase
from matplotlib.gridspec import GridSpec
from numpy.typing import NDArray
from pandas import DataFrame

from pyrovelocity.analysis.analyze import top_mae_genes
Expand All @@ -22,27 +24,9 @@

configure_matplotlib_style()

adata = load_anndata_from_path("models/larry_model2/postprocessed.h5ad")
posterior_samples = CompressedPickle.load(
"models/larry_model2/pyrovelocity.pkl.zst"
)
volcano_data: DataFrame = posterior_samples["gene_ranking"]
vector_field_basis = "emb"
cell_state = "state_info"

putative_marker_genes = top_mae_genes(
volcano_data=volcano_data,
mae_top_percentile=3,
min_genes_per_bin=3,
)
rainbow_genes = putative_marker_genes[:6]
rainbow_genes = ["Cyp11a1", "Csf2rb", "Osbpl8", "Lgals1", "Cmtm7", "Runx1"]
putative_marker_genes = list(set(putative_marker_genes + rainbow_genes))


__all__ = [
"plot_main",
"plot_subfigures",
"plot_report",
"save_subfigures",
]


Expand All @@ -69,52 +53,71 @@ def create_main_figure(
return fig, axes, gs


def extract_subfigures(
buffer: Optional[bytes] = None,
axes_to_keep: List[str] = [],
main_fig: Optional[FigureBase] = None,
figure_file_path: Optional[str | Path] = None,
def plot_report(
adata: AnnData,
posterior_samples: Dict[str, NDArray[Any]],
volcano_data: DataFrame,
putative_marker_genes: List[str],
rainbow_genes: List[str],
figure_file_path: Path | str = "example_report_figure.dill.zst",
vector_field_basis: str = "emb",
cell_state: str = "state_info",
report_file_path: Path | str = "example_plot_report.pdf",
) -> FigureBase:
"""
Extract a subset of axes from the main figure using dill for serialization.
Plot a report figure with multiple subplots and serialize the Figure object.
Args:
main_fig: The main Figure object
axes_to_keep: List of axes keys to keep in the new figure
figure_file_path (Path | str, optional):
Figure object file. Defaults to "example_report_figure.dill.zst".
adata (AnnData, optional):
AnnData object. Defaults to adata.
posterior_samples (Dict[str, NDArray[Any]], optional):
Posterior samples dictionary. Defaults to posterior_samples.
volcano_data (DataFrame, optional):
Volcano data DataFrame. Defaults to volcano_data.
vector_field_basis (str, optional):
Vector field basis identifier. Defaults to "emb".
cell_state (str, optional):
Cell state identifier. Defaults to "state_info".
putative_marker_genes (List[str], optional):
List of putative marker genes. Defaults to putative_marker_genes.
rainbow_genes (List[str], optional):
List of genes to be included in report figure. Defaults to rainbow_genes.
report_file_path (Path | str, optional):
File to save report figure. Defaults to "example_plot_report.pdf".
Returns:
A new Figure object with only the specified axes
FigureBase: Figure object containing the report figure.
Examples:
>>> # xdoctest: +SKIP
>>> adata = load_anndata_from_path("models/larry_model2/postprocessed.h5ad")
>>> posterior_samples = CompressedPickle.load(
... "models/larry_model2/pyrovelocity.pkl.zst"
... )
>>> volcano_data: DataFrame = posterior_samples["gene_ranking"]
>>> vector_field_basis = "emb"
>>> cell_state = "state_info"
>>> putative_marker_genes = top_mae_genes(
... volcano_data=volcano_data,
... mae_top_percentile=3,
... min_genes_per_bin=3,
... )
>>> rainbow_genes = putative_marker_genes[:6]
>>> rainbow_genes = ["Cyp11a1", "Csf2rb", "Osbpl8", "Lgals1", "Cmtm7", "Runx1"]
>>> putative_marker_genes = list(set(putative_marker_genes + rainbow_genes))
>>> plot_report(
... adata=adata,
... posterior_samples=posterior_samples,
... volcano_data=volcano_data,
... putative_marker_genes=putative_marker_genes,
... rainbow_genes=rainbow_genes,
... vector_field_basis=vector_field_basis,
... cell_state=cell_state,
... report_file_path="example_plot_report.pdf",
... )
"""

if buffer:
subfig = dill.loads(buffer)
elif main_fig:
buffer = dill.dumps(main_fig)
subfig = dill.loads(buffer)
else:
raise ValueError("Either buffer or main_fig must be provided.")

for text in subfig.texts[:]:
subfig.texts.remove(text)

axes_to_remove = [
ax for ax in subfig.axes if ax.get_label() not in axes_to_keep
]

for ax in axes_to_remove:
subfig.delaxes(ax)

if figure_file_path:
with Path(figure_file_path).open("wb") as f:
dill.dump(subfig, f)

return subfig


def plot_main(
figure_file_path: Path | str = "main_figure.dill.zst",
) -> FigureBase:
"""Create an example plot with a custom layout and demonstrate subplot extraction."""
width = 8.5 - 1
height = (11 - 1) * 0.9
layout = {
Expand Down Expand Up @@ -144,13 +147,15 @@ def plot_main(
time_correlation_with="st",
show_marginal_histograms=False,
)

plot_parameter_posterior_distributions(
posterior_samples=posterior_samples,
adata=adata,
geneset=rainbow_genes,
fig=fig,
gs=gs[1, 1:],
)

rainbowplot(
volcano_data=volcano_data,
adata=adata,
Expand All @@ -172,131 +177,102 @@ def plot_main(
add_panel_label(fig, "c", 0.47, y_row2)
add_panel_label(fig, "d", x_col1, 0.57)

fig.savefig("example_plot_report.pdf", format="pdf")
fig.savefig(report_file_path, format="pdf")
CompressedPickle.save(figure_file_path, fig)

return fig


def plot_subfigures(figure_file_path: Path | str = "main_figure.dill.zst"):
def extract_subfigures(
buffer: Optional[bytes] = None,
axes_to_keep: List[str] = [],
main_fig: Optional[FigureBase] = None,
figure_file_path: Optional[str | Path] = None,
) -> FigureBase:
"""
Extract a subset of axes from the main figure using dill for serialization.
Args:
main_fig: The main Figure object
axes_to_keep: List of axes keys to keep in the new figure
Returns:
A new Figure object with only the specified axes
"""

if buffer:
subfig = dill.loads(buffer)
elif main_fig:
buffer = dill.dumps(main_fig)
subfig = dill.loads(buffer)
else:
raise ValueError("Either buffer or main_fig must be provided.")

for text in subfig.texts[:]:
subfig.texts.remove(text)

axes_to_remove = [
ax for ax in subfig.axes if ax.get_label() not in axes_to_keep
]

for ax in axes_to_remove:
subfig.delaxes(ax)

if figure_file_path:
with Path(figure_file_path).open("wb") as f:
dill.dump(subfig, f)

return subfig


def save_subfigures(
figure_file_path: Path | str = "main_figure.dill.zst",
vector_field_summary_file_path: Path
| str = "extracted_vector_field_summary.pdf",
gene_selection_file_path: Path | str = "extracted_gene_selection.pdf",
parameter_posteriors_file_path: Path
| str = "extracted_parameter_posteriors.pdf",
rainbow_file_path: Path | str = "extracted_rainbow.pdf",
):
fig = CompressedPickle.load(figure_file_path)
buffer = dill.dumps(fig)

subfig_parameter_posteriors = extract_subfigures(
subfig_vector_field_summary = extract_subfigures(
buffer=buffer,
axes_to_keep=["vector_field"],
)
subfig_parameter_posteriors.savefig(
"extracted_vector_field_summary.pdf", format="pdf"
subfig_vector_field_summary.savefig(
fname=vector_field_summary_file_path,
format="pdf",
)

subfig_parameter_posteriors = extract_subfigures(
subfig_gene_selection = extract_subfigures(
buffer=buffer,
axes_to_keep=["parameter_posteriors"],
axes_to_keep=["gene_selection"],
)
subfig_parameter_posteriors.savefig(
"extracted_parameter_posteriors.pdf", format="pdf"
subfig_gene_selection.savefig(
fname=gene_selection_file_path,
format="pdf",
)

subfig_gene_selection = extract_subfigures(
subfig_parameter_posteriors = extract_subfigures(
buffer=buffer,
axes_to_keep=["gene_selection"],
axes_to_keep=["parameter_posteriors"],
)
subfig_gene_selection.savefig("extracted_gene_selection.pdf", format="pdf")

subfig_rainbow = extract_subfigures(buffer=buffer, axes_to_keep=["rainbow"])
subfig_rainbow.savefig("extracted_rainbow.pdf", format="pdf")


def example_plot_manual():
"""
Create an example plot with a custom layout.
Each subplot in the gridspec grid may be labeled with a panel label
whose location is given in Figure-level coordinates.
"""
n_rows = 3
n_cols = 3
width = 8.5 - 1
height = (11 - 1) * 0.9
row_1_fraction = 0.1
row_2_fraction = 0.3
row_3_fraction = 0.6

col_1_fraction = 0.5
col_2_fraction = 0.25
col_3_fraction = 0.25

fig = plt.figure(figsize=(width, height))
gs = GridSpec(
figure=fig,
nrows=n_rows,
height_ratios=[
row_1_fraction,
row_2_fraction,
row_3_fraction,
],
ncols=n_cols,
width_ratios=[
col_1_fraction,
col_2_fraction,
col_3_fraction,
],
subfig_parameter_posteriors.savefig(
fname=parameter_posteriors_file_path,
format="pdf",
)

fig_1 = plt.figure(
figsize=(
(col_2_fraction + col_3_fraction) * width,
(row_2_fraction) * height,
)
subfig_rainbow = extract_subfigures(
buffer=buffer,
axes_to_keep=["rainbow"],
)
gs_1 = GridSpec(
figure=fig,
nrows=1,
height_ratios=[row_2_fraction],
ncols=2,
width_ratios=[
col_2_fraction,
col_3_fraction,
],
subfig_rainbow.savefig(
fname=rainbow_file_path,
format="pdf",
)

ax1 = fig.add_subplot(gs[0, :])
plot_wide_row(ax1)
fig2 = plt.figure()
ax = fig2.add_subplot(111)
plot_wide_row(ax)

ax2 = fig.add_subplot(gs[1, 0])
plot_small_cell(ax2)

ax3 = fig.add_subplot(gs[1, 1])
plot_narrow_column(ax3)
ax3_1 = fig_1.add_subplot(gs_1[0, 0])
plot_narrow_column(ax3_1)

ax4 = fig.add_subplot(gs[1, 2])
plot_narrow_column(ax4)
ax4_1 = fig_1.add_subplot(gs_1[0, 1])
plot_narrow_column(ax4_1)

ax5 = fig.add_subplot(gs[2, :])
plot_large_cell(ax5)

fig.tight_layout()
fig_1.tight_layout()

x_col1 = -0.015
y_row2 = 0.87
add_panel_label(fig, "a", x_col1, 1.00)
add_panel_label(fig, "b", x_col1, y_row2)
add_panel_label(fig, "c", 0.45, y_row2)
add_panel_label(fig, "d", 0.72, y_row2)
add_panel_label(fig, "e", x_col1, 0.57)

fig.savefig("example_plot_layout.pdf", format="pdf")
fig_1.savefig("example_plot_layout_1.pdf", format="pdf")


def add_panel_label(
fig: FigureBase,
Expand Down

0 comments on commit c5a284d

Please sign in to comment.