Commit

added zarr capability
ajitjohnson committed Nov 20, 2024
1 parent 98d7ceb commit 4d277b7
Showing 1 changed file with 131 additions and 61 deletions.
192 changes: 131 additions & 61 deletions scimap/plotting/napariGater.py
@@ -62,7 +62,7 @@ def get_marker_data(marker, adata, layer, log, verbose=False):


def initialize_gates(adata, imageid):
# """Initialize gates DataFrame if it doesn't exist"""
"""Initialize gates DataFrame if it doesn't exist"""
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

@@ -74,43 +74,45 @@ def initialize_gates(adata, imageid):
)
adata.uns['gates'].iloc[:, :] = np.nan

-    # Run GMM for each marker
-    markers = list(adata.var.index)  # Convert to list for tqdm
-    for marker in tqdm(markers, desc="Computing gates", ncols=80):
-        # Get log-transformed data
-        data = get_marker_data(marker, adata, 'raw', log=True, verbose=False)
+    # Convert to list and use tqdm properly
+    markers = list(adata.var.index)
+    with tqdm(total=len(markers), desc="Computing gates", leave=False) as pbar:
+        for marker in markers:
+            # Get log-transformed data
+            data = get_marker_data(marker, adata, 'raw', log=True, verbose=False)

-        # Preprocess for GMM
-        values = data.values.flatten()
-        values = values[~np.isnan(values)]
+            # Preprocess for GMM
+            values = data.values.flatten()
+            values = values[~np.isnan(values)]

-        # Cut outliers
-        p01, p99 = np.percentile(values, [0.1, 99.9])
-        values = values[(values >= p01) & (values <= p99)]
+            # Cut outliers
+            p01, p99 = np.percentile(values, [0.1, 99.9])
+            values = values[(values >= p01) & (values <= p99)]

-        # Scale data
-        scaler = StandardScaler()
-        values_scaled = scaler.fit_transform(values.reshape(-1, 1))
+            # Scale data
+            scaler = StandardScaler()
+            values_scaled = scaler.fit_transform(values.reshape(-1, 1))

-        # Fit GMM
-        gmm = GaussianMixture(n_components=3, random_state=42)
-        gmm.fit(values_scaled)
+            # Fit GMM
+            gmm = GaussianMixture(n_components=3, random_state=42)
+            gmm.fit(values_scaled)

-        # Sort components by their means
-        means = scaler.inverse_transform(gmm.means_)
-        sorted_idx = np.argsort(means.flatten())
-        sorted_means = means[sorted_idx]
+            # Sort components by their means
+            means = scaler.inverse_transform(gmm.means_)
+            sorted_idx = np.argsort(means.flatten())
+            sorted_means = means[sorted_idx]

-        # Calculate gate as midpoint between middle and high components
-        gate_value = np.mean([sorted_means[1], sorted_means[2]])
+            # Calculate gate as midpoint between middle and high components
+            gate_value = np.mean([sorted_means[1], sorted_means[2]])

-        # Ensure gate value is within data range
-        min_val = float(data.min().iloc[0])
-        max_val = float(data.max().iloc[0])
-        gate_value = np.clip(gate_value, min_val, max_val)
+            # Ensure gate value is within data range
+            min_val = float(data.min().iloc[0])
+            max_val = float(data.max().iloc[0])
+            gate_value = np.clip(gate_value, min_val, max_val)

-        # Store gate value for all images
-        adata.uns['gates'].loc[marker, :] = gate_value
+            # Store gate value for all images
+            adata.uns['gates'].loc[marker, :] = gate_value
+            pbar.update(1)

return adata
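
Note: a minimal, self-contained sketch of the GMM-midpoint gating idea implemented above, run on synthetic log-intensity values (the synthetic mixture and variable names are illustrative, not part of the commit):

import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
values = np.concatenate([
    rng.normal(1.0, 0.3, 500),   # background population
    rng.normal(3.0, 0.4, 300),   # intermediate population
    rng.normal(5.0, 0.3, 200),   # positive population
])

scaler = StandardScaler()
scaled = scaler.fit_transform(values.reshape(-1, 1))

gmm = GaussianMixture(n_components=3, random_state=42).fit(scaled)
means = np.sort(scaler.inverse_transform(gmm.means_).flatten())

# Gate = midpoint of the middle and high component means, clipped to the data range
gate = float(np.clip(np.mean(means[1:]), values.min(), values.max()))
print(f"estimated gate: {gate:.2f}")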

@@ -144,28 +146,34 @@ def calculate_auto_contrast(img, percentile_low=1, percentile_high=99, padding=0
def initialize_contrast_settings(
adata, img, channel_names, imageid='imageid', subset=None
):
# """Initialize contrast settings if they don't exist"""
"""Initialize contrast settings if they don't exist"""
if 'image_contrast_settings' not in adata.uns:
print("Initializing contrast settings...")
adata.uns['image_contrast_settings'] = {}

# Get current image ID
current_image = adata.obs[imageid].iloc[0] if subset is None else subset

# Initialize settings for current image if not already present
if current_image not in adata.uns['image_contrast_settings']:
+        tiff_file = img._store._source
contrast_settings = {}
-        for i, channel in enumerate(channel_names):
-            if isinstance(img, (da.Array, zarr.Array)):
-                channel_img = img[i]
-            else:
-                channel_img = img[i] if len(img.shape) > 2 else img
-
-            low, high = calculate_auto_contrast(channel_img)
-            contrast_settings[channel] = {
-                'low': float(low),
-                'high': float(high),
-            }

+        # Use tqdm for contrast calculation progress
+        for i, channel in enumerate(
+            tqdm(channel_names, desc="Calculating contrast", leave=False)
+        ):
+            try:
+                channel_data = tiff_file.series[0].pages[i].asarray()
+                low, high = calculate_auto_contrast(channel_data)
+                contrast_settings[channel] = {
+                    'low': float(low),
+                    'high': float(high),
+                }
+            except Exception as e:
+                # Set default contrast values if calculation fails
+                contrast_settings[channel] = {
+                    'low': 0.0,
+                    'high': 1.0,
+                }

adata.uns['image_contrast_settings'][current_image] = contrast_settings
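
Note: a rough sketch of the per-channel auto-contrast step added here; the percentile cut stands in for calculate_auto_contrast (whose body is not shown in this hunk), and the file path and marker names are placeholders:

import numpy as np
import tifffile as tiff

tf = tiff.TiffFile('multichannel.ome.tif', is_ome=False)   # placeholder path
contrast_settings = {}
for i, channel in enumerate(['DAPI', 'CD3', 'CD20']):       # placeholder channel names
    page = tf.series[0].pages[i].asarray()                   # read one channel as a numpy array
    low, high = np.percentile(page, [1, 99])                 # assumed percentile-based contrast
    contrast_settings[channel] = {'low': float(low), 'high': float(high)}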

@@ -268,17 +276,42 @@ def napariGater(
# Load the image
if isinstance(image_path, str):
if image_path.endswith(('.tiff', '.tif')):
-            img = tiff.imread(image_path)
-            multiscale = False
+            image = tiff.TiffFile(image_path, is_ome=False)
+            store = image.aszarr()
+            img = zarr.open(store, mode='r')
+            # Store the TiffFile object for later use
+            img._store._source = image
+
+            # Get shape from the TiffFile object
+            shape = image.series[0].shape
+            ndim = len(shape)
+            is_multichannel = ndim > 2
+            num_channels = shape[0] if is_multichannel else 1
+
+            print(f"Image shape: {shape}")
+            print(f"Number of channels: {num_channels}")

elif image_path.endswith(('.zarr', '.zr')):
img = zarr.open(image_path, mode='r')
-            multiscale = True
+            shape = img.shape
+            ndim = len(shape)
+            is_multichannel = ndim > 2
+            num_channels = shape[0] if is_multichannel else 1
else:
img = image_path
-        multiscale = False
+        shape = img.shape
+        ndim = len(shape)
+        is_multichannel = ndim > 2
+        num_channels = shape[0] if is_multichannel else 1
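
Note: the core of the new zarr capability is wrapping the TIFF in a zarr store instead of reading it fully with tiff.imread; a minimal sketch of that pattern (the path is a placeholder):

import tifffile as tiff
import zarr

tf = tiff.TiffFile('whole_slide.ome.tif', is_ome=False)   # placeholder path
store = tf.aszarr()                # zarr store backed by the TIFF pages (lazy, no full read)
img = zarr.open(store, mode='r')   # array-like view; channels indexed as img[i]

shape = tf.series[0].shape         # e.g. (channels, y, x)
num_channels = shape[0] if len(shape) > 2 else 1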

# Initialize contrast settings if needed
-    adata = initialize_contrast_settings(adata, img, channel_names)
+    adata = initialize_contrast_settings(
+        adata,
+        img,
+        channel_names,
+        imageid=imageid,
+        subset=subset,
+    )

# Create the viewer and add the image
viewer = napari.Viewer()
@@ -288,20 +321,57 @@

# Add each channel as a separate layer with saved contrast settings
current_image = adata.obs[imageid].iloc[0] if subset is None else subset
-    for i, channel_name in enumerate(channel_names):
-        contrast_limits = (
-            adata.uns['image_contrast_settings'][current_image][channel_name]['low'],
-            adata.uns['image_contrast_settings'][current_image][channel_name]['high'],
-        )
-
-        viewer.add_image(
-            img[i],
-            name=channel_name,
-            visible=False,
-            colormap=colormaps[i % len(colormaps)],
-            blending='additive',
-            contrast_limits=contrast_limits,
-        )
+    # Get the TiffFile object
+    tiff_file = img._store._source
+
+    # Suppress progress bar output
+    with tqdm(total=len(channel_names), desc="Loading channels", leave=False) as pbar:
+        for i, channel_name in enumerate(channel_names):
+            try:
+                contrast_limits = (
+                    adata.uns['image_contrast_settings'][current_image][channel_name][
+                        'low'
+                    ],
+                    adata.uns['image_contrast_settings'][current_image][channel_name][
+                        'high'
+                    ],
+                )
+
+                try:
+                    # Try direct page access first
+                    channel_data = tiff_file.series[0].pages[i].asarray()
+                except:
+                    # Fallback to zarr array if direct access fails
+                    channel_data = img[i]
+                    if isinstance(channel_data, zarr.core.Array):
+                        channel_data = channel_data[:]
+
+                viewer.add_image(
+                    channel_data,
+                    name=channel_name,
+                    visible=False,
+                    colormap=colormaps[i % len(colormaps)],
+                    blending='additive',
+                    contrast_limits=contrast_limits,
+                )
+                pbar.update(1)
+            except Exception as e:
+                print(f"Failed to load channel {channel_name}: {type(e).__name__}")
+                pbar.update(1)
+                continue
+
+    # Verify loaded channels
+    loaded_channels = [
+        layer.name for layer in viewer.layers if isinstance(layer, napari.layers.Image)
+    ]
+    if len(loaded_channels) != len(channel_names):
+        print(
+            f"\nWarning: Only loaded {len(loaded_channels)}/{len(channel_names)} channels"
+        )
+        missing = set(channel_names) - set(loaded_channels)
+        if missing:
+            print(f"Missing channels: {', '.join(missing)}")

# Create points layer
points_layer = viewer.add_points(