From 76ffc3f85d926aabcfbeab95ca6e78499ee28a1b Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Fri, 13 Sep 2024 16:25:12 -0400 Subject: [PATCH] fix(plots): synchronize vector field color map with rainbowplots Signed-off-by: Cameron Smith --- src/pyrovelocity/plots/_vector_fields.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/pyrovelocity/plots/_vector_fields.py b/src/pyrovelocity/plots/_vector_fields.py index 0de0b7637..1373d5677 100644 --- a/src/pyrovelocity/plots/_vector_fields.py +++ b/src/pyrovelocity/plots/_vector_fields.py @@ -8,7 +8,7 @@ import seaborn as sns from anndata import AnnData from beartype import beartype -from beartype.typing import Dict, List, Optional, Tuple +from beartype.typing import Any, Dict, List, Optional, Tuple from matplotlib.axes import Axes from matplotlib.colors import Normalize from matplotlib.figure import FigureBase @@ -23,7 +23,7 @@ get_posterior_sample_angle_uncertainty, ) from pyrovelocity.styles import configure_matplotlib_style -from pyrovelocity.utils import quartile_coefficient_of_dispersion +from pyrovelocity.utils import quartile_coefficient_of_dispersion, setup_colors __all__ = [ "plot_vector_field_summary", @@ -82,7 +82,7 @@ def plot_vector_field_summary( vector_field_basis: str, plot_name: Optional[PathLike | str] = None, cell_state: str = "cell_type", - state_color_dict: Optional[Dict[str, str]] = None, + state_color_dict: Optional[Dict[str, Any]] = None, fig: Optional[FigureBase] = None, gs: Optional[SubplotSpec] = None, default_fontsize: int = 7 if matplotlib.rcParams["text.usetex"] else 6, @@ -114,6 +114,9 @@ def plot_vector_field_summary( } ) + if state_color_dict is None: + state_color_dict = setup_colors(adata, cell_state) + sns.scatterplot( x="X1", y="X2",