Skip to content

Commit

Permalink
Use local theads for dask array compute
Browse files Browse the repository at this point in the history
Add a compute_args configuration parameter to the HedwigImages class
to allow configuration of the dask schedualer used for dask array
computation and ZARR operations.

This addresses a deadlock issue when the Dask array is using the
global schedualer that Prefect uses. This avoids workers recusivedly
using the globlal schedualer while blocking a worker.
  • Loading branch information
blowekamp committed Jun 25, 2024
1 parent 89ce9c0 commit d44a27c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 13 deletions.
16 changes: 13 additions & 3 deletions pytools/HedwigZarrImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import SimpleITK as sitk
import zarr
from typing import Tuple, Dict, List, Optional, Iterable
from types import MappingProxyType
from pytools.utils import OMEInfo
import logging
import math
Expand All @@ -33,10 +34,18 @@ class HedwigZarrImage:
Represents a OME-NGFF Zarr pyramidal image. The members provide information useful for the Hedwig imaging pipelines.
"""

def __init__(self, zarr_grp: zarr.Group, _ome_info: OMEInfo, _ome_idx: Optional[int] = None):
def __init__(
self,
zarr_grp: zarr.Group,
_ome_info: OMEInfo,
_ome_idx: Optional[int] = None,
*,
compute_args: Optional[Dict[str, str]] = MappingProxyType({"scheduler": "threads"}),
):
self.zarr_group = zarr_grp
self.ome_info = _ome_info
self.ome_idx = _ome_idx
self.compute_args = compute_args if compute_args is not None else {}

assert "multiscales" in self.zarr_group.attrs

Expand Down Expand Up @@ -211,7 +220,7 @@ def extract_2d(
else:
d_arr = dask.array.squeeze(d_arr, axis=(0, 1, 2))

img = sitk.GetImageFromArray(d_arr.compute(), isVector=is_vector)
img = sitk.GetImageFromArray(d_arr.compute(**self.compute_args), isVector=is_vector)
img.SetSpacing((spacing_tczyx[4], spacing_tczyx[3]))

logger.debug(img)
Expand Down Expand Up @@ -457,7 +466,8 @@ def _image_statistics(self, quantiles=None, channel=None, *, zero_black_quantile
logger.debug(f"dask.config.global_config: {dask.config.global_config}")

logger.info(f'Building histogram for "{self.path}"...')
h, bins = histo.compute_histogram(histogram_bin_edges=None, density=False)

h, bins = histo.compute_histogram(histogram_bin_edges=None, density=False, compute_args=self.compute_args)

mids = 0.5 * (bins[1:] + bins[:-1])

Expand Down
36 changes: 29 additions & 7 deletions pytools/HedwigZarrImages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from pathlib import Path
import zarr
from typing import Optional, Iterable, Tuple, AnyStr, Union
from typing import Optional, Iterable, Tuple, AnyStr, Union, Dict
from types import MappingProxyType
from pytools.utils import OMEInfo
from pytools.HedwigZarrImage import HedwigZarrImage
import logging
Expand All @@ -27,16 +28,32 @@ class HedwigZarrImages:
Represents the set of images in a OME-NGFF ZARR structure.
"""

def __init__(self, zarr_path: Path, read_only=True):
def __init__(
self,
zarr_path: Path,
read_only=True,
*,
compute_args: Optional[Dict[str, str]] = MappingProxyType({"scheduler": "threads", "num_workers": 4}),
):
"""
Initialized by the path to a root of an OME zarr structure.
:param zarr_path: Path to the root of the ZARR structure.
:param read_only: If True, the ZARR structure is read only.
:param compute_args: A dictionary of arguments to be passed to dask.compute.
- The default uses a local threadpool scheduler with 4 threads. This provides reasonable performance and does
not oversubscribe the CPU when multiple operations are being performed concurrently.
- A 'synchronous' scheduler can be used for debugging or when no parallelism is required.
- If `None` then the global dask scheduler or Dask distributed scheduler will be used.
"""
# check zarr is valid
assert zarr_path.exists()
assert zarr_path.is_dir()
self.zarr_store = zarr.DirectoryStore(zarr_path)
self.zarr_root = zarr.Group(store=self.zarr_store, read_only=read_only)
self._ome_info = None
self._compute_args = compute_args

@property
def ome_xml_path(self) -> Optional[Path]:
Expand Down Expand Up @@ -80,28 +97,33 @@ def group(self, name: str) -> HedwigZarrImage:
"""

if self.ome_xml_path is None:
return HedwigZarrImage(self.zarr_root[name])
return HedwigZarrImage(self.zarr_root[name], compute_args=self._compute_args)

ome_index_to_zarr_group = self.zarr_root["OME"].attrs["series"]
k_idx = ome_index_to_zarr_group.index(name)
return HedwigZarrImage(self.zarr_root[name], self.ome_info, k_idx)
return HedwigZarrImage(self.zarr_root[name], self.ome_info, k_idx, compute_args=self._compute_args)

def __getitem__(self, item: Union[str, int]) -> HedwigZarrImage:
"""
Returns a HedwigZarrImage from the given the OME series name or a ZARR index.
"""

if "OME" not in self.zarr_root.group_keys():
return HedwigZarrImage(self.zarr_root[item], self.ome_info, 404)
return HedwigZarrImage(self.zarr_root[item], self.ome_info, 404, compute_args=self._compute_args)

elif isinstance(item, int):
return HedwigZarrImage(self.zarr_root[item], self.ome_info, item)
return HedwigZarrImage(self.zarr_root[item], self.ome_info, item, compute_args=self._compute_args)

elif isinstance(item, str):
ome_index_to_zarr_group = self.zarr_root["OME"].attrs["series"]
for ome_idx, k in enumerate(self.get_series_keys()):
if k == item:
return HedwigZarrImage(self.zarr_root[ome_index_to_zarr_group[ome_idx]], self.ome_info, ome_idx)
return HedwigZarrImage(
self.zarr_root[ome_index_to_zarr_group[ome_idx]],
self.ome_info,
ome_idx,
compute_args=self._compute_args,
)
raise KeyError(f"Series name {item} not found: {list(self.get_series_keys())}! ")

def series(self) -> Iterable[Tuple[str, HedwigZarrImage]]:
Expand Down
19 changes: 16 additions & 3 deletions pytools/utils/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,20 @@ def compute_min_max(self):
def dtype(self):
return self._arr.dtype

def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np.array, np.array]:
def compute_histogram(
self, histogram_bin_edges=None, density=False, *, compute_args=None
) -> Tuple[np.array, np.array]:
"""
Compute the histogram of the array.
:param histogram_bin_edges: The edges of the bins. If None, the edges are computed from the array, with integers
of 16 bits or less, the edges are computed from the dtype for exact bins.
:param density: If True, the histogram is normalized to form a probability density, otherwise the count of
samples in each bin.
:param compute_args: Additional arguments to pass to the dask compute method.
"""
if compute_args is None:
compute_args = {}
if histogram_bin_edges is None:
if np.issubdtype(self.dtype, np.integer) and np.iinfo(self.dtype).bits <= 16:
histogram_bin_edges = self.compute_histogram_bin_edges(
Expand All @@ -127,15 +140,15 @@ def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np
new_chunk = (None,) * (self._arr.ndim - 1) + (-1,)
arr = self._arr.rechunk(new_chunk).ravel()
return (
dask.array.bincount(arr, minlength=len(histogram_bin_edges) - 1).compute(),
dask.array.bincount(arr, minlength=len(histogram_bin_edges) - 1).compute(**compute_args),
histogram_bin_edges,
)

else:
histogram_bin_edges = self.compute_histogram_bin_edges()

h, bins = dask.array.histogram(self._arr, bins=histogram_bin_edges, density=density)
return h.compute(), bins
return h.compute(**compute_args), bins


class ZARRHistogramHelper(DaskHistogramHelper):
Expand Down

0 comments on commit d44a27c

Please sign in to comment.