Skip to content

Commit

Permalink
refactor(plots): extract report subgrid plots
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Sep 4, 2024
1 parent f205304 commit 7b9f22a
Showing 1 changed file with 95 additions and 59 deletions.
154 changes: 95 additions & 59 deletions src/pyrovelocity/plots/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,40 @@
from matplotlib.axes import Axes
from matplotlib.figure import FigureBase
from matplotlib.gridspec import GridSpec
from pandas import DataFrame

from pyrovelocity.analysis.analyze import top_mae_genes
from pyrovelocity.io.compressedpickle import CompressedPickle
from pyrovelocity.plots._genes import plot_gene_ranking
from pyrovelocity.plots._parameters import (
plot_parameter_posterior_distributions,
)
from pyrovelocity.plots._rainbow import rainbowplot_module as rainbowplot
from pyrovelocity.plots._vector_fields import plot_vector_field_summary
from pyrovelocity.styles import configure_matplotlib_style
from pyrovelocity.styles.colors import LARRY_CELL_TYPE_COLORS
from pyrovelocity.utils import load_anndata_from_path

configure_matplotlib_style()

adata = load_anndata_from_path("models/larry_model2/postprocessed.h5ad")
posterior_samples = CompressedPickle.load(
"models/larry_model2/pyrovelocity.pkl.zst"
)
volcano_data: DataFrame = posterior_samples["gene_ranking"]
vector_field_basis = "emb"
cell_state = "state_info"

putative_marker_genes = top_mae_genes(
volcano_data=volcano_data,
mae_top_percentile=3,
min_genes_per_bin=3,
)
rainbow_genes = putative_marker_genes[:6]
rainbow_genes = ["Cyp11a1", "Csf2rb", "Osbpl8", "Lgals1", "Cmtm7", "Runx1"]
putative_marker_genes = list(set(putative_marker_genes + rainbow_genes))


__all__ = [
"plot_main",
"plot_subfigures",
Expand All @@ -34,16 +62,11 @@ def create_main_figure(
)

axes = {}
axes["ax1"] = fig.add_subplot(gs[0, :])
axes["ax2"] = fig.add_subplot(gs[1, 0])
axes["ax3"] = fig.add_subplot(gs[1, 1])
axes["ax4"] = fig.add_subplot(gs[1, 2])
axes["ax5"] = fig.add_subplot(gs[2, :])

for key, ax in axes.items():
ax.set_label(key)

return fig, axes
return fig, axes, gs


def extract_subfigures(
Expand Down Expand Up @@ -95,29 +118,61 @@ def plot_main(
width = 8.5 - 1
height = (11 - 1) * 0.9
layout = {
"height_ratios": [0.1, 0.3, 0.6],
"height_ratios": [0.12, 0.28, 0.6],
"width_ratios": [0.5, 0.25, 0.25],
}

fig, axes = create_main_figure(width, height, layout)
fig, axes, gs = create_main_figure(width, height, layout)
plot_vector_field_summary(
adata=adata,
posterior_samples=posterior_samples,
vector_field_basis="emb",
cell_state="state_info",
state_color_dict=LARRY_CELL_TYPE_COLORS,
fig=fig,
gs=gs[0, :],
default_fontsize=7,
)

plot_wide_row(axes["ax1"])
plot_small_cell(axes["ax2"])
plot_narrow_column(axes["ax3"])
plot_narrow_column(axes["ax4"])
plot_large_cell(axes["ax5"])
plot_gene_ranking(
posterior_samples=posterior_samples,
adata=adata,
fig=fig,
gs=gs[1, 0],
selected_genes=putative_marker_genes,
rainbow_genes=rainbow_genes,
time_correlation_with="st",
show_marginal_histograms=False,
)
plot_parameter_posterior_distributions(
posterior_samples=posterior_samples,
adata=adata,
geneset=rainbow_genes,
fig=fig,
gs=gs[1, 1:],
)
rainbowplot(
volcano_data=volcano_data,
adata=adata,
posterior_samples=posterior_samples,
genes=rainbow_genes,
data=["st", "ut"],
basis=vector_field_basis,
cell_state=cell_state,
fig=fig,
gs=gs[2, :],
)

fig.tight_layout()

x_col1 = -0.015
y_row2 = 0.87
x_col1 = -0.005
y_row2 = 0.84
add_panel_label(fig, "a", x_col1, 1.00)
add_panel_label(fig, "b", x_col1, y_row2)
add_panel_label(fig, "c", 0.45, y_row2)
add_panel_label(fig, "d", 0.72, y_row2)
add_panel_label(fig, "e", x_col1, 0.57)
add_panel_label(fig, "c", 0.47, y_row2)
add_panel_label(fig, "d", x_col1, 0.57)

fig.savefig("example_plot_layout.pdf", format="pdf")
fig.savefig("example_plot_report.pdf", format="pdf")
CompressedPickle.save(figure_file_path, fig)

return fig
Expand All @@ -127,20 +182,30 @@ def plot_subfigures(figure_file_path: Path | str = "main_figure.dill.zst"):
fig = CompressedPickle.load(figure_file_path)
buffer = dill.dumps(fig)

subfig1 = extract_subfigures(buffer=buffer, axes_to_keep=["ax1"])
subfig1.savefig("extracted_ax1.pdf", format="pdf")

subfig23 = extract_subfigures(buffer=buffer, axes_to_keep=["ax2", "ax3"])
subfig23.savefig("extracted_ax2_ax3.pdf", format="pdf")
subfig_parameter_posteriors = extract_subfigures(
buffer=buffer,
axes_to_keep=["vector_field"],
)
subfig_parameter_posteriors.savefig(
"extracted_vector_field_summary.pdf", format="pdf"
)

subfig34 = extract_subfigures(buffer=buffer, axes_to_keep=["ax3", "ax4"])
subfig34.savefig("extracted_ax3_ax4.pdf", format="pdf")
subfig_parameter_posteriors = extract_subfigures(
buffer=buffer,
axes_to_keep=["parameter_posteriors"],
)
subfig_parameter_posteriors.savefig(
"extracted_parameter_posteriors.pdf", format="pdf"
)

subfig15 = extract_subfigures(buffer=buffer, axes_to_keep=["ax1", "ax5"])
subfig15.savefig("extracted_ax1_ax5.pdf", format="pdf")
subfig_gene_selection = extract_subfigures(
buffer=buffer,
axes_to_keep=["gene_selection"],
)
subfig_gene_selection.savefig("extracted_gene_selection.pdf", format="pdf")

subfig5 = extract_subfigures(buffer=buffer, axes_to_keep=["ax5"])
subfig5.savefig("extracted_ax5.pdf", format="pdf")
subfig_rainbow = extract_subfigures(buffer=buffer, axes_to_keep=["rainbow"])
subfig_rainbow.savefig("extracted_rainbow.pdf", format="pdf")


def example_plot_manual():
Expand Down Expand Up @@ -233,35 +298,6 @@ def example_plot_manual():
fig_1.savefig("example_plot_layout_1.pdf", format="pdf")


def plot_wide_row(
ax: Axes,
title: str = "Wide Row",
):
ax.set_title(title)


def plot_small_cell(
ax: Axes,
title: str = "Small Cell",
):
ax.set_title(title)
ax.set_aspect("equal", adjustable="box")


def plot_narrow_column(
ax: Axes,
title: str = "Narrow Column",
):
ax.set_title(title)


def plot_large_cell(
ax: Axes,
title: str = "Large Cell",
):
ax.set_title(title)


def add_panel_label(
fig: FigureBase,
label: str,
Expand Down

0 comments on commit 7b9f22a

Please sign in to comment.