Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improves dataloader performance #687

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 142 additions & 24 deletions src/spatialdata/_core/query/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

from typing import Any

import numba as nb
import numpy as np
from anndata import AnnData
from datatree import DataTree
from xarray import DataArray

from spatialdata._core._elements import Tables
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata._utils import Number, _parse_list_into_array
from spatialdata.transformations._utils import compute_coordinates
from spatialdata.transformations.transformations import (
BaseTransformation,
Sequence,
Translation,
)


def get_bounding_box_corners(
Expand Down Expand Up @@ -36,37 +45,146 @@ def get_bounding_box_corners(
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

if len(min_coordinate) not in (2, 3):
if min_coordinate.ndim == 1:
min_coordinate = min_coordinate[np.newaxis, :]
max_coordinate = max_coordinate[np.newaxis, :]

if min_coordinate.shape[1] not in (2, 3):
raise ValueError("bounding box must be 2D or 3D")

if len(min_coordinate) == 2:
num_boxes = min_coordinate.shape[0]
num_dims = min_coordinate.shape[1]

if num_dims == 2:
# 2D bounding box
assert len(axes) == 2
return DataArray(
corners = np.array(
[
[min_coordinate[0], min_coordinate[1]],
[min_coordinate[0], max_coordinate[1]],
[max_coordinate[0], max_coordinate[1]],
[max_coordinate[0], min_coordinate[1]],
],
coords={"corner": range(4), "axis": list(axes)},
[min_coordinate[:, 0], min_coordinate[:, 1]],
[min_coordinate[:, 0], max_coordinate[:, 1]],
[max_coordinate[:, 0], max_coordinate[:, 1]],
[max_coordinate[:, 0], min_coordinate[:, 1]],
]
)

# 3D bounding cube
assert len(axes) == 3
return DataArray(
[
[min_coordinate[0], min_coordinate[1], min_coordinate[2]],
[min_coordinate[0], min_coordinate[1], max_coordinate[2]],
[min_coordinate[0], max_coordinate[1], max_coordinate[2]],
[min_coordinate[0], max_coordinate[1], min_coordinate[2]],
[max_coordinate[0], min_coordinate[1], min_coordinate[2]],
[max_coordinate[0], min_coordinate[1], max_coordinate[2]],
[max_coordinate[0], max_coordinate[1], max_coordinate[2]],
[max_coordinate[0], max_coordinate[1], min_coordinate[2]],
],
coords={"corner": range(8), "axis": list(axes)},
corners = np.transpose(corners, (2, 0, 1))
else:
# 3D bounding cube
assert len(axes) == 3
corners = np.array(
[
[min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]],
[min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]],
[min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]],
[min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]],
[max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]],
[max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]],
[max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]],
[max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]],
]
)
corners = np.transpose(corners, (2, 0, 1))
output = DataArray(
corners,
coords={
"box": range(num_boxes),
"corner": range(corners.shape[1]),
"axis": list(axes),
},
)
if num_boxes > 1:
return output
return output.squeeze().drop_vars("box")


@nb.njit(parallel=False)
def _create_slices_and_translation(
    min_values: nb.types.Array[nb.float64, nb.float64],
    max_values: nb.types.Array[nb.float64, nb.float64],
) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]:
    """
    Pack per-box bounds into a slice array and derive a translation vector for each box.

    Parameters
    ----------
    min_values
        2D array of shape ``(n_boxes, n_dims)`` holding the lower bound of every box along every dimension.
    max_values
        2D array holding the upper bounds; it is indexed with the same ``(box, dim)`` indices as ``min_values``.

    Returns
    -------
    A tuple ``(slices, translation_vectors)`` of ``float64`` arrays:

    - ``slices`` has shape ``(n_boxes, n_dims, 2)``; ``slices[i, j, 0]`` is the lower bound and
      ``slices[i, j, 1]`` the upper bound of box ``i`` along dimension ``j``.
    - ``translation_vectors`` has shape ``(n_boxes, n_dims)``; each entry is the lower bound clamped to be
      non-negative and then rounded up to the next integer value.

    Notes
    -----
    ``nb.njit`` always compiles in nopython mode, so no explicit ``nopython=True`` is passed: Numba ignores that
    keyword on ``njit`` and emits a ``RuntimeWarning`` when it is supplied.
    """
    n_boxes, n_dims = min_values.shape
    slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64)  # (n_boxes, n_dims, [min, max])
    translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64)  # (n_boxes, n_dims)

    for i in range(n_boxes):
        for j in range(n_dims):
            slices[i, j, 0] = min_values[i, j]
            slices[i, j, 1] = max_values[i, j]
            # Negative lower bounds are clamped to 0 before rounding up
            # (presumably because the queried data cannot start before index 0 -- confirm against callers).
            translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0))

    return slices, translation_vectors


def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None:
    """
    Clean up a multiscale query result by dropping empty and non-contiguous scales.

    Every child of ``query_result`` must hold exactly one data variable. Scales whose array has a zero-length
    dimension are discarded. Of the remaining scales only the leading run ``scale0, scale1, ...`` (in iteration
    order) is kept; everything from the first gap onwards is dropped. The surviving scales are rebuilt into a new
    ``DataTree`` and their ``"image"`` variable is rechunked with ``chunk("auto")``.

    Parameters
    ----------
    query_result
        Mapping-like tree from scale name to a node containing a single array.

    Returns
    -------
    The cleaned ``DataTree``, or ``None`` when ``scale0`` is empty or absent.
    """
    non_empty = {}
    for scale_key, subtree in query_result.items():
        arrays = subtree.values()
        assert len(arrays) == 1
        array = next(iter(arrays))
        if 0 not in array.shape:
            non_empty[scale_key] = array
        elif scale_key == "scale0":
            # Without data at full resolution there is nothing meaningful to return.
            return None

    # Keep only the contiguous prefix scale0, scale1, ...; stop at the first missing scale.
    kept = []
    for position, scale_key in enumerate(non_empty):
        if scale_key != f"scale{position}":
            break
        kept.append(scale_key)

    # scale0 itself is missing even though other scales may exist.
    if not kept:
        return None

    result = DataTree.from_dict({scale_key: non_empty[scale_key] for scale_key in kept})

    # Rechunk the data to avoid irregular chunks.
    for scale in result:
        result[scale]["image"] = result[scale]["image"].chunk("auto")

    return result


def _process_query_result(
    result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...]
) -> DataArray | DataTree | None:
    """
    Post-process a raster query result: drop empty results, rechunk, and update transformations.

    Parameters
    ----------
    result
        The queried data. A ``DataArray`` with a zero-length dimension yields ``None``; otherwise it is rechunked
        with ``chunk("auto")``. A ``DataTree`` is cleaned with ``_process_data_tree_query_result``, which may
        also yield ``None``.
    translation_vector
        Offset to record on the result. When it is not (numerically) all zeros, every existing transformation is
        replaced by a ``Sequence`` whose first element is a ``Translation`` built from this vector and whose
        second element is the original transformation.
    axes
        Axis names passed to the ``Translation``.

    Returns
    -------
    The processed element with recomputed coordinates and updated transformations, or ``None`` if the query
    produced no data.
    """
    from spatialdata.transformations import get_transformation, set_transformation

    if isinstance(result, DataArray):
        if 0 in result.shape:
            return None
        # Rechunk the data to avoid irregular chunks.
        result = result.chunk("auto")
    elif isinstance(result, DataTree):
        result = _process_data_tree_query_result(result)
        if result is None:
            return None

    result = compute_coordinates(result)

    needs_shift = not np.allclose(np.array(translation_vector), 0)
    if needs_shift:
        shift = Translation(translation=translation_vector, axes=axes)

        existing = get_transformation(result, get_all=True)
        assert isinstance(existing, dict)

        shifted: dict[str, BaseTransformation] = {
            coordinate_system: Sequence([shift, original])
            for coordinate_system, original in existing.items()
        }
        set_transformation(result, shifted, set_all=True)

    # Re-set a copy of the transformations so that the original object is not modified.
    current = get_transformation(result, get_all=True)
    assert isinstance(current, dict)
    set_transformation(result, current.copy(), set_all=True)

    return result


def _get_filtered_or_unfiltered_tables(
Expand Down
Loading
Loading