From d44a27c4adf148cae1f53143256785d908c27d46 Mon Sep 17 00:00:00 2001 From: Bradley Lowekamp Date: Tue, 25 Jun 2024 07:39:05 -0400 Subject: [PATCH] Use local threads for dask array compute Add a compute_args configuration parameter to the HedwigZarrImages class to allow configuration of the dask scheduler used for dask array computation and ZARR operations. This addresses a deadlock issue when the Dask array is using the global scheduler that Prefect uses. This avoids workers recursively using the global scheduler while blocking a worker. --- pytools/HedwigZarrImage.py | 16 +++++++++++++--- pytools/HedwigZarrImages.py | 36 +++++++++++++++++++++++++++++------- pytools/utils/histogram.py | 19 ++++++++++++++++--- 3 files changed, 58 insertions(+), 13 deletions(-) diff --git a/pytools/HedwigZarrImage.py b/pytools/HedwigZarrImage.py index d48bc2a..a0796c3 100644 --- a/pytools/HedwigZarrImage.py +++ b/pytools/HedwigZarrImage.py @@ -16,6 +16,7 @@ import SimpleITK as sitk import zarr from typing import Tuple, Dict, List, Optional, Iterable +from types import MappingProxyType from pytools.utils import OMEInfo import logging import math @@ -33,10 +34,18 @@ class HedwigZarrImage: Represents a OME-NGFF Zarr pyramidal image. The members provide information useful for the Hedwig imaging pipelines. 
""" - def __init__(self, zarr_grp: zarr.Group, _ome_info: OMEInfo, _ome_idx: Optional[int] = None): + def __init__( + self, + zarr_grp: zarr.Group, + _ome_info: OMEInfo, + _ome_idx: Optional[int] = None, + *, + compute_args: Optional[Dict[str, str]] = MappingProxyType({"scheduler": "threads"}), + ): self.zarr_group = zarr_grp self.ome_info = _ome_info self.ome_idx = _ome_idx + self.compute_args = compute_args if compute_args is not None else {} assert "multiscales" in self.zarr_group.attrs @@ -211,7 +220,7 @@ def extract_2d( else: d_arr = dask.array.squeeze(d_arr, axis=(0, 1, 2)) - img = sitk.GetImageFromArray(d_arr.compute(), isVector=is_vector) + img = sitk.GetImageFromArray(d_arr.compute(**self.compute_args), isVector=is_vector) img.SetSpacing((spacing_tczyx[4], spacing_tczyx[3])) logger.debug(img) @@ -457,7 +466,8 @@ def _image_statistics(self, quantiles=None, channel=None, *, zero_black_quantile logger.debug(f"dask.config.global_config: {dask.config.global_config}") logger.info(f'Building histogram for "{self.path}"...') - h, bins = histo.compute_histogram(histogram_bin_edges=None, density=False) + + h, bins = histo.compute_histogram(histogram_bin_edges=None, density=False, compute_args=self.compute_args) mids = 0.5 * (bins[1:] + bins[:-1]) diff --git a/pytools/HedwigZarrImages.py b/pytools/HedwigZarrImages.py index 79ce0b4..8c8a225 100644 --- a/pytools/HedwigZarrImages.py +++ b/pytools/HedwigZarrImages.py @@ -13,7 +13,8 @@ from pathlib import Path import zarr -from typing import Optional, Iterable, Tuple, AnyStr, Union +from typing import Optional, Iterable, Tuple, AnyStr, Union, Dict +from types import MappingProxyType from pytools.utils import OMEInfo from pytools.HedwigZarrImage import HedwigZarrImage import logging @@ -27,9 +28,24 @@ class HedwigZarrImages: Represents the set of images in a OME-NGFF ZARR structure. 
""" - def __init__(self, zarr_path: Path, read_only=True): + def __init__( + self, + zarr_path: Path, + read_only=True, + *, + compute_args: Optional[Dict[str, str]] = MappingProxyType({"scheduler": "threads", "num_workers": 4}), + ): """ Initialized by the path to a root of an OME zarr structure. + + :param zarr_path: Path to the root of the ZARR structure. + :param read_only: If True, the ZARR structure is read only. + :param compute_args: A dictionary of arguments to be passed to dask.compute. + - The default uses a local threadpool scheduler with 4 threads. This provides reasonable performance and does + not oversubscribe the CPU when multiple operations are being performed concurrently. + - A 'synchronous' scheduler can be used for debugging or when no parallelism is required. + - If `None` then the global dask scheduler or Dask distributed scheduler will be used. + """ # check zarr is valid assert zarr_path.exists() @@ -37,6 +53,7 @@ def __init__(self, zarr_path: Path, read_only=True): self.zarr_store = zarr.DirectoryStore(zarr_path) self.zarr_root = zarr.Group(store=self.zarr_store, read_only=read_only) self._ome_info = None + self._compute_args = compute_args @property def ome_xml_path(self) -> Optional[Path]: @@ -80,11 +97,11 @@ def group(self, name: str) -> HedwigZarrImage: """ if self.ome_xml_path is None: - return HedwigZarrImage(self.zarr_root[name]) + return HedwigZarrImage(self.zarr_root[name], compute_args=self._compute_args) ome_index_to_zarr_group = self.zarr_root["OME"].attrs["series"] k_idx = ome_index_to_zarr_group.index(name) - return HedwigZarrImage(self.zarr_root[name], self.ome_info, k_idx) + return HedwigZarrImage(self.zarr_root[name], self.ome_info, k_idx, compute_args=self._compute_args) def __getitem__(self, item: Union[str, int]) -> HedwigZarrImage: """ @@ -92,16 +109,21 @@ def __getitem__(self, item: Union[str, int]) -> HedwigZarrImage: """ if "OME" not in self.zarr_root.group_keys(): - return HedwigZarrImage(self.zarr_root[item], 
self.ome_info, 404) + return HedwigZarrImage(self.zarr_root[item], self.ome_info, 404, compute_args=self._compute_args) elif isinstance(item, int): - return HedwigZarrImage(self.zarr_root[item], self.ome_info, item) + return HedwigZarrImage(self.zarr_root[item], self.ome_info, item, compute_args=self._compute_args) elif isinstance(item, str): ome_index_to_zarr_group = self.zarr_root["OME"].attrs["series"] for ome_idx, k in enumerate(self.get_series_keys()): if k == item: - return HedwigZarrImage(self.zarr_root[ome_index_to_zarr_group[ome_idx]], self.ome_info, ome_idx) + return HedwigZarrImage( + self.zarr_root[ome_index_to_zarr_group[ome_idx]], + self.ome_info, + ome_idx, + compute_args=self._compute_args, + ) raise KeyError(f"Series name {item} not found: {list(self.get_series_keys())}! ") def series(self) -> Iterable[Tuple[str, HedwigZarrImage]]: diff --git a/pytools/utils/histogram.py b/pytools/utils/histogram.py index 437d925..5f23a4d 100644 --- a/pytools/utils/histogram.py +++ b/pytools/utils/histogram.py @@ -115,7 +115,20 @@ def compute_min_max(self): def dtype(self): return self._arr.dtype - def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np.array, np.array]: + def compute_histogram( + self, histogram_bin_edges=None, density=False, *, compute_args=None + ) -> Tuple[np.array, np.array]: + """ + Compute the histogram of the array. + + :param histogram_bin_edges: The edges of the bins. If None, the edges are computed from the array, with integers + of 16 bits or less, the edges are computed from the dtype for exact bins. + :param density: If True, the histogram is normalized to form a probability density, otherwise the count of + samples in each bin. + :param compute_args: Additional arguments to pass to the dask compute method. 
+ """ + if compute_args is None: + compute_args = {} if histogram_bin_edges is None: if np.issubdtype(self.dtype, np.integer) and np.iinfo(self.dtype).bits <= 16: histogram_bin_edges = self.compute_histogram_bin_edges( @@ -127,7 +140,7 @@ def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np new_chunk = (None,) * (self._arr.ndim - 1) + (-1,) arr = self._arr.rechunk(new_chunk).ravel() return ( - dask.array.bincount(arr, minlength=len(histogram_bin_edges) - 1).compute(), + dask.array.bincount(arr, minlength=len(histogram_bin_edges) - 1).compute(**compute_args), histogram_bin_edges, ) @@ -135,7 +148,7 @@ def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np histogram_bin_edges = self.compute_histogram_bin_edges() h, bins = dask.array.histogram(self._arr, bins=histogram_bin_edges, density=density) - return h.compute(), bins + return h.compute(**compute_args), bins class ZARRHistogramHelper(DaskHistogramHelper):