From 44f44374baf00d7af1e228ece517f938f7493630 Mon Sep 17 00:00:00 2001 From: Ajit Johnson Nirmal Date: Tue, 19 Nov 2024 21:53:24 -0500 Subject: [PATCH] adding new function napariGater --- docs/Functions/pl/napariGater.md | 5 + docs/Tools Shortcut.md | 1 + pyproject.toml | 2 +- scimap/plotting/__init__.py | 3 +- scimap/plotting/gate_finder.py | 9 + scimap/plotting/napariGater.py | 435 ++++++++++++++++++ scimap/preprocessing/rescale.py | 19 +- .../tests/scimapExampleData/manual_gates.csv | 0 .../scimapExampleData/phenotype_workflow.csv | 0 9 files changed, 468 insertions(+), 6 deletions(-) create mode 100644 docs/Functions/pl/napariGater.md create mode 100644 scimap/plotting/napariGater.py mode change 100644 => 100755 scimap/tests/scimapExampleData/manual_gates.csv mode change 100644 => 100755 scimap/tests/scimapExampleData/phenotype_workflow.csv diff --git a/docs/Functions/pl/napariGater.md b/docs/Functions/pl/napariGater.md new file mode 100644 index 00000000..d308397f --- /dev/null +++ b/docs/Functions/pl/napariGater.md @@ -0,0 +1,5 @@ +--- +hide: + - toc # Hide table of contents +--- +::: scimap.plotting.napariGater \ No newline at end of file diff --git a/docs/Tools Shortcut.md b/docs/Tools Shortcut.md index c4f016ce..f776b69b 100644 --- a/docs/Tools Shortcut.md +++ b/docs/Tools Shortcut.md @@ -51,6 +51,7 @@ import scimap as sm | [`sm.pl.image_viewer`](Functions/pl/image_viewer.md) | Integrates with `napari` to offer an interactive platform for enhanced image viewing and annotation with data overlays. | | [`sm.pl.addROI_image`](Functions/pl/addROI_image.md) | Facilitates the addition of Regions of Interest (ROIs) through `napari`, enriching spatial analyses with precise locational data. | | [`sm.pl.gate_finder`](Functions/pl/gate_finder.md) | Aids in the manual gating process by overlaying marker positivity on images, simplifying the identification and analysis of cellular subsets. 
| +| [`sm.pl.napariGater`](Functions/pl/napariGater.md) | Modified version of gate_finder and soon to replace it. | | [`sm.pl.heatmap`](Functions/pl/heatmap.md) | Creates heatmaps to visually explore marker expression or feature distributions across different groups. | | [`sm.pl.markerCorrelation`](Functions/pl/markerCorrelation.md) | Computes and visualizes the correlation among selected markers. | | [`sm.pl.groupCorrelation`](Functions/pl/groupCorrelation.md) | Calculates and displays the correlation between the abundances of groups across user defined conditions. | diff --git a/pyproject.toml b/pyproject.toml index 570dae2f..e17ee0dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "SCIMAP" -version = "2.1.3" +version = "2.2.0" description = "Spatial Single-Cell Analysis Toolkit" license = "MIT" diff --git a/scimap/plotting/__init__.py b/scimap/plotting/__init__.py index b1aa35a3..e05b77eb 100644 --- a/scimap/plotting/__init__.py +++ b/scimap/plotting/__init__.py @@ -16,4 +16,5 @@ from .heatmap import heatmap from .markerCorrelation import markerCorrelation from .groupCorrelation import groupCorrelation -from .spatialInteractionNetwork import spatialInteractionNetwork \ No newline at end of file +from .spatialInteractionNetwork import spatialInteractionNetwork +from .napariGater import napariGater \ No newline at end of file diff --git a/scimap/plotting/gate_finder.py b/scimap/plotting/gate_finder.py index 6665437c..9f45e56f 100644 --- a/scimap/plotting/gate_finder.py +++ b/scimap/plotting/gate_finder.py @@ -13,6 +13,15 @@ ## Function """ +import warnings + +warnings.warn( + "gate_finder() is deprecated and will be removed in a future version. 
" + "Please use sm.pl.napariGater() instead.", + FutureWarning, + stacklevel=2, +) + try: import napari except: diff --git a/scimap/plotting/napariGater.py b/scimap/plotting/napariGater.py new file mode 100644 index 00000000..9abd0357 --- /dev/null +++ b/scimap/plotting/napariGater.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Created on Tue May 12 23:47:52 2020 +# @author: Ajit Johnson Nirmal +""" +!!! abstract "Short Description" + `sm.pl.gate_finder`: This function leverages Napari to display OME-TIFF images, + overlaying points that assist in manually determining gating thresholds for specific markers. + By visualizing marker expression spatially, users can more accurately define gates. + Subsequently, the identified gating parameters can be applied to the dataset using `sm.pp.rescale`, + enabling precise control over data segmentation and analysis based on marker expression levels. + +## Function +""" + +try: + import napari +except: + pass + +import pandas as pd +import tifffile as tiff +import numpy as np + +import dask.array as da +from dask.cache import Cache +import zarr +import os +from tqdm.auto import tqdm + +# cache = Cache(2e9) # Leverage two gigabytes of memory +# cache.register() + + +def get_marker_data(marker, adata, layer, log, verbose=False): + # """Helper function to get consistent data for a marker""" + if layer == 'raw': + data = pd.DataFrame( + adata.raw.X, index=adata.obs.index, columns=adata.var.index + )[[marker]] + if verbose: + print( + f"Raw data range - min: {float(data.min().iloc[0])}, max: {float(data.max().iloc[0])}" + ) + else: + data = pd.DataFrame( + adata.layers[layer], index=adata.obs.index, columns=adata.var.index + )[[marker]] + if verbose: + print( + f"Layer data range - min: {float(data.min().iloc[0])}, max: {float(data.max().iloc[0])}" + ) + + if log: + data = np.log1p(data) + if verbose: + print( + f"After log transform - min: {float(data.min().iloc[0])}, max: {float(data.max().iloc[0])}" + ) 
+ + return data + + +def initialize_gates(adata, imageid): + # """Initialize gates DataFrame if it doesn't exist""" + from sklearn.mixture import GaussianMixture + from sklearn.preprocessing import StandardScaler + + # Create gates DataFrame if it doesn't exist + if 'gates' not in adata.uns: + print("Initializing gates with GMM...") + adata.uns['gates'] = pd.DataFrame( + index=adata.var.index, columns=adata.obs[imageid].unique(), dtype=float + ) + adata.uns['gates'].iloc[:, :] = np.nan + + # Run GMM for each marker + markers = list(adata.var.index) # Convert to list for tqdm + for marker in tqdm(markers, desc="Computing gates", ncols=80): + # Get log-transformed data + data = get_marker_data(marker, adata, 'raw', log=True, verbose=False) + + # Preprocess for GMM + values = data.values.flatten() + values = values[~np.isnan(values)] + + # Cut outliers + p01, p99 = np.percentile(values, [0.1, 99.9]) + values = values[(values >= p01) & (values <= p99)] + + # Scale data + scaler = StandardScaler() + values_scaled = scaler.fit_transform(values.reshape(-1, 1)) + + # Fit GMM + gmm = GaussianMixture(n_components=3, random_state=42) + gmm.fit(values_scaled) + + # Sort components by their means + means = scaler.inverse_transform(gmm.means_) + sorted_idx = np.argsort(means.flatten()) + sorted_means = means[sorted_idx] + + # Calculate gate as midpoint between middle and high components + gate_value = np.mean([sorted_means[1], sorted_means[2]]) + + # Ensure gate value is within data range + min_val = float(data.min().iloc[0]) + max_val = float(data.max().iloc[0]) + gate_value = np.clip(gate_value, min_val, max_val) + + # Store gate value for all images + adata.uns['gates'].loc[marker, :] = gate_value + + return adata + + +def calculate_auto_contrast(img, percentile_low=1, percentile_high=99, padding=0.1): + # """Calculate contrast limits using histogram analysis with padding""" + # If image is dask or zarr array, compute on smallest pyramid if available + if isinstance(img, 
(da.Array, zarr.Array)): + # Get smallest pyramid level if available + if hasattr(img, 'shape') and len(img.shape) > 2: + img = img[-1] # Use smallest pyramid level + # Compute statistics on a subset of data + sample = img[::10, ::10] # Sample every 10th pixel + if hasattr(sample, 'compute'): + sample = sample.compute() + else: + sample = img + + # Calculate percentiles for contrast + low = np.percentile(sample, percentile_low) + high = np.percentile(sample, percentile_high) + + # Add padding + range_val = high - low + low = max(0, low - (range_val * padding)) # Ensure we don't go below 0 + high = high + (range_val * padding) + + return low, high + + +def initialize_contrast_settings( + adata, img, channel_names, imageid='imageid', subset=None +): + """Initialize contrast settings if they don't exist""" + if 'image_contrast_settings' not in adata.uns: + print("Initializing contrast settings...") + adata.uns['image_contrast_settings'] = {} + + # Get current image ID + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + + # Initialize settings for current image if not already present + if current_image not in adata.uns['image_contrast_settings']: + contrast_settings = {} + for i, channel in enumerate(channel_names): + if isinstance(img, (da.Array, zarr.Array)): + channel_img = img[i] + else: + channel_img = img[i] if len(img.shape) > 2 else img + + low, high = calculate_auto_contrast(channel_img) + contrast_settings[channel] = { + 'low': float(low), + 'high': float(high), + } + + adata.uns['image_contrast_settings'][current_image] = contrast_settings + + return adata + + +def napariGater( + image_path, + adata, + layer='raw', + log=True, + x_coordinate='X_centroid', + y_coordinate='Y_centroid', + imageid='imageid', + subset=None, + flip_y=True, + channel_names='default', + point_size=10, +): + """ + !!! 
abstract "Short Description" + `sm.pl.napariGater`: This function provides an interactive interface using Napari to visualize and set + gating thresholds for markers in imaging data. It features automatic gate initialization using Gaussian + Mixture Models (GMM), real-time visualization of gated cells, and the ability to save gate settings. + The function supports both single and multi-channel images, with automatic contrast adjustment and + customizable display options. + + Parameters: + image_path (str): + Path to the high-resolution image file (supports formats like TIFF, OME.TIFF, ZARR). + + adata (anndata.AnnData): + The annotated data matrix. + + layer (str, optional): + Specifies the layer in `adata` containing expression data. Defaults to 'raw'. + + log (bool, optional): + Applies log transformation to expression data if True. Defaults to True. + + x_coordinate, y_coordinate (str, optional): + Columns in `adata.obs` specifying cell coordinates. Defaults are 'X_centroid' and 'Y_centroid'. + + imageid (str, optional): + Column in `adata.obs` identifying images for datasets with multiple images. Defaults to 'imageid'. + + subset (str, optional): + Specific image identifier for targeted analysis. Defaults to None. + + flip_y (bool, optional): + Inverts the Y-axis to match image coordinates if True. Defaults to True. + + channel_names (list or str, optional): + Names of the channels in the image. Defaults to 'default', using `adata.uns['all_markers']`. + + point_size (int, optional): + Size of points in the visualization. Defaults to 10. + + Returns: + None: + Updates `adata.uns['gates']` with the gating thresholds. 
+ + Example: + ```python + # Basic usage with default mcmicro parameters + sm.pl.napariGater( + image_path='path/to/image.ome.tif', + adata=adata + ) + + # Custom settings with specific channels and coordinate columns + sm.pl.napariGater( + image_path='path/to/image.ome.tif', + adata=adata, + x_coordinate='X_position', + y_coordinate='Y_position', + channel_names=['DAPI', 'CD45', 'CD3', 'CD8'], + point_size=15 + ) + + # Working with specific image from a multi-image dataset + sm.pl.napariGater( + image_path='path/to/image.ome.tif', + adata=adata, + subset='sample1', + imageid='imageid' + ) + ``` + """ + import napari + from magicgui import magicgui + import time + + start_time = time.time() + + # Initialize gates with GMM if needed + adata = initialize_gates(adata, imageid) + + print(f"Opening napari viewer...") + + # Recover the channel names from adata + if channel_names == 'default': + channel_names = adata.uns['all_markers'] + else: + channel_names = channel_names + + # Load the image + if isinstance(image_path, str): + if image_path.endswith(('.tiff', '.tif')): + img = tiff.imread(image_path) + multiscale = False + elif image_path.endswith(('.zarr', '.zr')): + img = zarr.open(image_path, mode='r') + multiscale = True + else: + img = image_path + multiscale = False + + # Initialize contrast settings under the same image key that is read back below (forward imageid/subset, otherwise the lookup raises KeyError) + adata = initialize_contrast_settings(adata, img, channel_names, imageid=imageid, subset=subset) + + # Create the viewer and add the image + viewer = napari.Viewer() + + # Define a list of colormaps to cycle through + colormaps = ['magenta', 'cyan', 'yellow', 'red', 'green', 'blue'] + + # Add each channel as a separate layer with saved contrast settings + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + for i, channel_name in enumerate(channel_names): + contrast_limits = ( + adata.uns['image_contrast_settings'][current_image][channel_name]['low'], + adata.uns['image_contrast_settings'][current_image][channel_name]['high'], + ) + + viewer.add_image( + img[i],
name=channel_name, + visible=False, + colormap=colormaps[i % len(colormaps)], + blending='additive', + contrast_limits=contrast_limits, + ) + + # Create points layer + points_layer = viewer.add_points( + np.zeros((0, 2)), + size=point_size, + face_color='white', + name='gated_points', + visible=True, + ) + + # Create initial marker data before creating GUI + initial_marker = list(adata.var.index)[0] + initial_data = get_marker_data(initial_marker, adata, 'raw', log, verbose=False) + + # Calculate initial min/max from expression values + marker_data = pd.DataFrame(adata.raw.X, columns=adata.var.index)[initial_marker] + if log: + marker_data = np.log1p(marker_data) + min_val = float(marker_data.min()) + max_val = float(marker_data.max()) + + # Get initial gate value + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + initial_gate = adata.uns['gates'].loc[initial_marker, current_image] + if pd.isna(initial_gate) or initial_gate < min_val or initial_gate > max_val: + initial_gate = min_val + + @magicgui( + auto_call=True, + marker={'choices': list(adata.var.index), 'value': initial_marker}, + gate={ + 'widget_type': 'FloatSpinBox', + 'min': min_val, + 'max': max_val, + 'value': initial_gate, + 'step': 0.1, + }, + confirm_gate={'widget_type': 'PushButton', 'text': 'Confirm Gate'}, + finish={'widget_type': 'PushButton', 'text': 'Finish Gating'}, + ) + def gate_controls( + marker: str, + gate: float = initial_gate, + confirm_gate=False, + finish=False, + ): + # Get data using helper function + data = get_marker_data(marker, adata, layer, log) + + # Apply gate + mask = data.values >= gate + cells = data.index[mask.flatten()] + + # Update points + coordinates = adata[cells] + if flip_y: + coordinates = pd.DataFrame( + {'y': coordinates.obs[y_coordinate], 'x': coordinates.obs[x_coordinate]} + ) + else: + coordinates = pd.DataFrame( + {'x': coordinates.obs[x_coordinate], 'y': coordinates.obs[y_coordinate]} + ) + points_layer.data = coordinates.values 
+ + # Add a separate handler for marker changes + @gate_controls.marker.changed.connect + def _on_marker_change(marker: str): + # Calculate min/max from expression values + marker_data = pd.DataFrame(adata.raw.X, columns=adata.var.index)[marker] + if log: + marker_data = np.log1p(marker_data) + min_val = float(marker_data.min()) + max_val = float(marker_data.max()) + + # Get existing gate value + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + existing_gate = adata.uns['gates'].loc[marker, current_image] + if pd.isna(existing_gate) or existing_gate < min_val or existing_gate > max_val: + value = min_val + else: + value = existing_gate + + # Update the spinbox properties + gate_controls.gate.min = min_val + gate_controls.gate.max = max_val + gate_controls.gate.value = value + + # Update layer visibility + for layer in viewer.layers: + if isinstance(layer, napari.layers.Image): + if layer.name == marker: + layer.visible = True + else: + layer.visible = False + + # Force viewer update + viewer.reset_view() + + @gate_controls.confirm_gate.clicked.connect + def _on_confirm(): + marker = gate_controls.marker.value + gate = gate_controls.gate.value + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + adata.uns['gates'].loc[marker, current_image] = float(gate) + # print(f"Gate for {marker} in image {current_image} set to {gate}") + + # Add handler for finish button + @gate_controls.finish.clicked.connect + def _on_finish(): + viewer.close() + + # Initialize with empty points + points_layer.data = np.zeros((0, 2)) + + # Add the GUI to the viewer + viewer.window.add_dock_widget(gate_controls) + + # Start the viewer + napari.run() + + print(f"Napari viewer initialized in {time.time() - start_time:.2f} seconds") + + # return adata diff --git a/scimap/preprocessing/rescale.py b/scimap/preprocessing/rescale.py index f968dc82..bbe683b7 100644 --- a/scimap/preprocessing/rescale.py +++ b/scimap/preprocessing/rescale.py @@ -180,11 
+180,22 @@ def clipping (x): # Find GMM based gates def gmm_gating (marker, data): - if verbose: - print('Finding the optimal gate by GMM for ' + str(marker)) + """Return the gate for `marker`: the midpoint of the two highest means of a 3-component GMM.""" + # Prepare data for GMM: a single column with NaNs removed data_gm = data[marker].values.reshape(-1, 1) - gmm = GaussianMixture(n_components=2, random_state=random_state).fit(data_gm) - gate = np.mean(gmm.means_) + data_gm = data_gm[~np.isnan(data_gm)].reshape(-1, 1)  # boolean masking flattens to 1-D; fit() needs (n_samples, 1) + + # Fit GMM with 3 components + gmm = GaussianMixture(n_components=3, random_state=random_state).fit(data_gm) + + # Sort components by their means + means = gmm.means_.flatten() + sorted_idx = np.argsort(means) + sorted_means = means[sorted_idx] + + # Calculate gate as midpoint between middle and high components + gate = np.mean([sorted_means[1], sorted_means[2]]) + return gate # Running gmm_gating on the dataset diff --git a/scimap/tests/scimapExampleData/manual_gates.csv b/scimap/tests/scimapExampleData/manual_gates.csv old mode 100644 new mode 100755 diff --git a/scimap/tests/scimapExampleData/phenotype_workflow.csv b/scimap/tests/scimapExampleData/phenotype_workflow.csv old mode 100644 new mode 100755