diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 4680148e..2233b2e9 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -617,7 +617,14 @@ def _( def _polygon_query( - sdata: SpatialData, polygon: Polygon, target_coordinate_system: str, filter_table: bool, shapes: bool, points: bool + sdata: SpatialData, + polygon: Polygon, + target_coordinate_system: str, + filter_table: bool, + shapes: bool, + points: bool, + images: bool, + labels: bool, ) -> SpatialData: from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.query.relational_query import _filter_table_by_elements @@ -669,11 +676,32 @@ def _polygon_query( set_transformation(ddf, transformation, target_coordinate_system) new_points[points_name] = ddf - if filter_table: + new_images = {} + if images: + for images_name, im in sdata.images.items(): + min_x, min_y, max_x, max_y = polygon.bounds + cropped = bounding_box_query( + im, + min_coordinate=[min_x, min_y], + max_coordinate=[max_x, max_y], + axes=("x", "y"), + target_coordinate_system=target_coordinate_system, + ) + new_images[images_name] = cropped + if labels: + for labels_name, l in sdata.labels.items(): + _ = labels_name + _ = l + raise NotImplementedError( + "labels=True is not implemented yet. If you encounter this error please open an " + "issue and we will prioritize the implementation." + ) + + if filter_table and sdata.table is not None: table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points}) else: table = sdata.table - return SpatialData(shapes=new_shapes, points=new_points, table=table) + return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table) # this function is currently excluded from the API documentation. TODO: add it after the refactoring @@ -684,6 +712,8 @@ def polygon_query( filter_table: bool = True, shapes: bool = True, points: bool = True, + images: bool = True, + labels: bool = True, ) -> SpatialData: """ Query a spatial data object by a polygon, filtering shapes and points. @@ -725,14 +755,21 @@ def polygon_query( filter_table=filter_table, shapes=shapes, points=points, + images=images, + labels=labels, ) # TODO: the performance for this case can be greatly improved by using the geopandas queries only once, and not # in a loop as done preliminarily here - if points: - raise NotImplementedError( - "points=True is not implemented when querying by multiple polygons. If you encounter this error, please" - " open an issue on GitHub and we will prioritize the implementation." + if points or images or labels: + logger.warning( + "Spatial querying of images, points and labels is not implemented when querying by multiple polygons " + 'simultaneously. You can silence this warning by setting "points=False, images=False, labels=False". If ' + "you need this implementation please open an issue on GitHub." ) + points = False + images = False + labels = False + sdatas = [] for polygon in tqdm(polygons): try: @@ -744,6 +781,8 @@ def polygon_query( filter_table=False, shapes=shapes, points=points, + images=images, + labels=labels, ) sdatas.append(queried_sdata) except ValueError as e: diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index ae9e2047..6db7e904 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -4,6 +4,7 @@ import pytest from anndata import AnnData from multiscale_spatial_image import MultiscaleSpatialImage +from shapely import Polygon from spatial_image import SpatialImage from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( @@ -379,7 +380,11 @@ def test_polygon_query_shapes(sdata_query_aggregation): circle_pol = circle.buffer(sdata["by_circles"].radius.iloc[0]) queried = polygon_query( - values_sdata, polygons=polygon, target_coordinate_system="global", shapes=True, points=False + values_sdata, + polygons=polygon, + target_coordinate_system="global", + shapes=True, + points=False, ) assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 @@ -432,11 +437,34 @@ def test_polygon_query_spatial_data(sdata_query_aggregation): assert len(queried.table) == 8 -@pytest.mark.skip -def test_polygon_query_image2d(): - # single image case - # multiscale case - pass +@pytest.mark.parametrize("n_channels", [1, 2, 3]) +def test_polygon_query_image2d(n_channels: int): + original_image = np.zeros((n_channels, 10, 10)) + # y: [5, 9], x: [0, 4] has value 1 + original_image[:, 5::, 0:5] = 1 + image_element = Image2DModel.parse(original_image) + image_element_multiscale = Image2DModel.parse(original_image, scale_factors=[2, 2]) + + polygon = Polygon([(3, 3), (3, 7), (5, 3)]) + for image in [image_element, image_element_multiscale]: + # bounding box: y: [5, 10[, x: [0, 5[ + image_result = polygon_query( + SpatialData(images={"my_image": image}), + polygons=polygon, + target_coordinate_system="global", + )["my_image"] + expected_image = original_image[:, 3:7, 3:5] # c dimension is preserved + if isinstance(image, SpatialImage): + assert isinstance(image, SpatialImage) + np.testing.assert_allclose(image_result, expected_image) + elif isinstance(image, MultiscaleSpatialImage): + assert isinstance(image_result, MultiscaleSpatialImage) + v = image_result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected_image) + else: + raise ValueError("Unexpected type") @pytest.mark.skip