Skip to content

Commit

Permalink
feature: Added custom chunk postprocessing function argument (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Karol-G committed May 21, 2024
1 parent d24e2bd commit 19abf27
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion patchly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.16"
__version__ = "0.0.17"

from patchly.sampler import GridSampler, SamplingMode
from patchly.aggregator import Aggregator
Expand Down
21 changes: 16 additions & 5 deletions patchly/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class PatchStatus(Enum):

class Aggregator:
def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.ArrayLike]] = None, output: Optional[npt.ArrayLike] = None,
weights: Union[str, Callable] = 'avg', softmax_dim: Optional[int] = None, has_batch_dim: bool = False, spatial_first: bool = True, device: str = 'cpu'):
weights: Union[str, Callable] = 'avg', softmax_dim: Optional[int] = None, has_batch_dim: bool = False, spatial_first: bool = True,
chunk_postprocess_func: Optional[Callable] = None, device: str = 'cpu'):
"""
Initializes the Aggregator object for aggregating patches into a larger output image.
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.
self.array_type = self.set_array_type(output)
self.output_h = self.set_output(output, output_size)
self.weight_patch_s, self.weight_map_s = self.set_weights(weights)
self.chunk_postprocess_func = chunk_postprocess_func
self.check_sanity()
self.aggregator = self.set_aggregator(self.sampler, self.output_h, self.softmax_dim)

Expand Down Expand Up @@ -140,6 +142,10 @@ def check_sanity(self) -> None:
raise RuntimeError("The spatial size of the given output {} is unequal to the given spatial size {}.".format(self.output_h.shape[-len(self.image_size_s):], self.image_size_s))
if self.has_batch_dim and self.spatial_first:
raise RuntimeError("The arguments has_batch_dim and spatial_first cannot both be true at the same time.")
if self.chunk_postprocess_func is not None and self.chunk_size_s is None:
raise RuntimeError("The argument chunk_postprocess_func can only be used with active chunking.")
if self.chunk_postprocess_func is not None and not isinstance(self.chunk_postprocess_func, Callable):
raise RuntimeError("The argument chunk_postprocess_func must either be 'None' or a callable.")
if self.mode.name.startswith('PAD_') and self.chunk_size_s is not None:
raise RuntimeError("The given sampling mode ({}) is not compatible with chunk sampling.".format(self.mode))

Expand All @@ -160,7 +166,7 @@ def set_aggregator(self, sampler: GridSampler, output_h: npt.ArrayLike, softmax_
elif self.mode.name.startswith('SAMPLE_') and self.chunk_size_s is not None:
aggregator = _ChunkAggregator(sampler=sampler, image_size_s=self.image_size_s, patch_size_s=self.patch_size_s, step_size_s=self.step_size_s,
chunk_size_s=self.chunk_size_s, output_h=output_h, spatial_first=self.spatial_first, softmax_dim=softmax_dim,
has_batch_dim=self.has_batch_dim, weight_patch_s=self.weight_patch_s, device=self.device, array_type=self.array_type)
has_batch_dim=self.has_batch_dim, weight_patch_s=self.weight_patch_s, chunk_postprocess_func=self.chunk_postprocess_func, device=self.device, array_type=self.array_type)
elif self.mode.name.startswith('PAD_') and self.chunk_size_s is None:
raise NotImplementedError("The given sampling mode ({}) is not supported.".format(self.mode))
elif self.mode.name.startswith('PAD_') and self.chunk_size_s is not None:
Expand Down Expand Up @@ -347,7 +353,7 @@ def verify_array_types(self, array_type: type) -> None:
class _ChunkAggregator(_Aggregator):
def __init__(self, sampler: GridSampler, image_size_s: Union[Tuple, npt.ArrayLike], patch_size_s: Union[Tuple, npt.ArrayLike], step_size_s: Union[Tuple, npt.ArrayLike], chunk_size_s: Union[Tuple, npt.ArrayLike],
output_h: Optional[npt.ArrayLike] = None, spatial_first: bool = True,
softmax_dim: Optional[int] = None, has_batch_dim: bool = False, weight_patch_s: npt.ArrayLike = None, device: str = 'cpu', array_type = None):
softmax_dim: Optional[int] = None, has_batch_dim: bool = False, weight_patch_s: npt.ArrayLike = None, chunk_postprocess_func: Optional[Callable] = None, device: str = 'cpu', array_type = None):
"""
Initializes the _ChunkAggregator object, a subclass of _Aggregator, for managing patch aggregation in chunks.
Expand All @@ -372,6 +378,7 @@ def __init__(self, sampler: GridSampler, image_size_s: Union[Tuple, npt.ArrayLik
self.chunk_size_s = chunk_size_s
self.chunk_dtype = self.set_chunk_dtype()
self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict = self.sampler.chunk_sampler, self.sampler.chunk_patch_dict, self.sampler.patch_chunk_dict
self.chunk_postprocess_func = chunk_postprocess_func
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)

def set_chunk_dtype(self) -> np.dtype:
Expand Down Expand Up @@ -468,8 +475,12 @@ def process_chunk(self, chunk_id: int) -> None:
chunk_h = chunk_h / weight_map_h.astype(chunk_h.dtype)
chunk_h = chunk_h.nan_to_num()
else:
# Argmax the softmax chunk
chunk_h = chunk_h.argmax(axis=self.softmax_dim).astype("uint16")
if self.chunk_postprocess_func is None:
# Argmax the softmax chunk
chunk_h = chunk_h.argmax(axis=self.softmax_dim).astype("uint16")
else:
# Apply custom chunk postprocessing function
chunk_h = self.chunk_postprocess_func(chunk_h)
chunk_bbox_h = self.chunk_sampler.__getitem__(chunk_id)
chunk_bbox_h = utils.bbox_s_to_bbox_h(chunk_bbox_h, self.output_h, self.spatial_first)
self.output_h[slicer(self.output_h, chunk_bbox_h)] = chunk_h.astype(self.output_h.dtype)
Expand Down

0 comments on commit 19abf27

Please sign in to comment.