Adding attrs at the SpatialData object level #711

Open · wants to merge 9 commits into main
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -51,6 +51,10 @@ and this project adheres to [Semantic Versioning][].

## [0.2.3] - 2024-09-25

### Major

- Added attributes at the SpatialData object level (`.attrs`)

### Minor

- Added `clip: bool = False` parameter to `polygon_query()` #670
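To make the changelog entry concrete, here is a minimal sketch of the object-level `.attrs` introduced by this PR; the keys used below (`sample_id`, `protocol`, `operator`) are illustrative only, not part of any schema:

```python
from spatialdata import SpatialData

# Object-level metadata can be passed at construction time...
sdata = SpatialData(attrs={"sample_id": "S1", "protocol": "visium"})

# ...and behaves like a plain dict afterwards.
sdata.attrs["operator"] = "lab-A"
print(sdata.attrs)
# {'sample_id': 'S1', 'protocol': 'visium', 'operator': 'lab-A'}
```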
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
Submodule notebooks updated 162 files
10 changes: 9 additions & 1 deletion src/spatialdata/_core/concatenate.py
@@ -1,14 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
from warnings import warn

import numpy as np
from anndata import AnnData
from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
@@ -80,6 +81,7 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
**kwargs: Any,
) -> SpatialData:
"""
@@ -108,6 +110,8 @@
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
attrs_merge
How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html)
kwargs
See :func:`anndata.concat` for more details.

@@ -188,12 +192,16 @@
else:
merged_tables[k] = v

attrs_merge = resolve_merge_strategy(attrs_merge)
attrs = attrs_merge([sdata.attrs for sdata in sdatas])

sdata = SpatialData(
images=merged_images,
labels=merged_labels,
points=merged_points,
shapes=merged_shapes,
tables=merged_tables,
attrs=attrs,
)
if obs_names_make_unique:
for table in sdata.tables.values():
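A sketch of how the new `attrs_merge` argument could be used. The attribute keys are invented; `"same"` is one of the `uns_merge`-style strategies of `anndata.concat`, and, per the annotation added above, a callable receiving the list of `.attrs` dicts is accepted as well:

```python
from spatialdata import SpatialData, concatenate

sdata1 = SpatialData(attrs={"organism": "mouse", "batch": 1})
sdata2 = SpatialData(attrs={"organism": "mouse", "batch": 2})

# "same" keeps only the keys whose values agree across all inputs.
merged = concatenate([sdata1, sdata2], attrs_merge="same")
print(merged.attrs)  # expected: {'organism': 'mouse'}

# A custom merge function receives the list of .attrs dicts and returns one dict.
merged_custom = concatenate(
    [sdata1, sdata2],
    attrs_merge=lambda all_attrs: {"batches": [a["batch"] for a in all_attrs]},
)
print(merged_custom.attrs)  # expected: {'batches': [1, 2]}
```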
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/rasterize.py
@@ -304,7 +304,7 @@ def rasterize(
new_labels[new_name] = rasterized
else:
raise RuntimeError(f"Unsupported model {model} detected as return type of rasterize().")
return SpatialData(images=new_images, labels=new_labels, tables=data.tables)
return SpatialData(images=new_images, labels=new_labels, tables=data.tables, attrs=data.attrs)

parsed_data = _parse_element(element=data, sdata=sdata, element_var_name="data", sdata_var_name="sdata")
model = get_model(parsed_data)
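The one-line change above forwards `.attrs` when `rasterize()` is applied to a whole `SpatialData` object. A rough sketch of the behaviour it enables, assuming the usual `Image2DModel.parse` workflow; the bounds, names, and keys are illustrative and not taken from this PR:

```python
import numpy as np
from spatialdata import SpatialData, rasterize
from spatialdata.models import Image2DModel

image = Image2DModel.parse(np.zeros((1, 64, 64)), dims=("c", "y", "x"))
sdata = SpatialData(images={"img": image}, attrs={"pipeline_version": "0.1.0"})

rasterized = rasterize(
    sdata,
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[32, 32],
    target_coordinate_system="global",
    target_unit_to_pixels=1.0,
)
# With this change, the object-level metadata survives rasterization.
assert rasterized.attrs == {"pipeline_version": "0.1.0"}
```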
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/transform.py
@@ -299,7 +299,7 @@ def _(
new_elements[element_type][k] = transform(
v, transformation, to_coordinate_system=to_coordinate_system, maintain_positioning=maintain_positioning
)
return SpatialData(**new_elements)
return SpatialData(**new_elements, attrs=data.attrs)


@transform.register(DataArray)
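Similarly, the `SpatialData` overload of `transform()` now carries `.attrs` into the returned object. A minimal sketch, assuming the element keeps its default identity transformation to the `"global"` coordinate system; names and keys are invented:

```python
import numpy as np
from spatialdata import SpatialData, transform
from spatialdata.models import Image2DModel

image = Image2DModel.parse(np.zeros((1, 16, 16)), dims=("c", "y", "x"))
sdata = SpatialData(images={"img": image}, attrs={"source": "demo"})

# Transform every element into the target coordinate system; the result
# is expected to carry over the original object-level .attrs.
transformed = transform(sdata, to_coordinate_system="global")
assert transformed.attrs == {"source": "demo"}
```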
4 changes: 2 additions & 2 deletions src/spatialdata/_core/query/spatial_query.py
@@ -529,7 +529,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@bounding_box_query.register(DataArray)
@@ -885,7 +885,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@polygon_query.register(DataArray)
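Both query functions now propagate `.attrs` to the queried subset. A hedged sketch using `bounding_box_query`; the bounds and keys are illustrative:

```python
import numpy as np
from spatialdata import SpatialData, bounding_box_query
from spatialdata.models import Image2DModel

image = Image2DModel.parse(np.zeros((1, 64, 64)), dims=("c", "y", "x"))
sdata = SpatialData(images={"img": image}, attrs={"roi": "tumor-core"})

cropped = bounding_box_query(
    sdata,
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[32, 32],
    target_coordinate_system="global",
)
# The subset keeps the object-level metadata; the same holds for polygon_query.
assert cropped.attrs == {"roi": "tumor-core"}
```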
48 changes: 42 additions & 6 deletions src/spatialdata/_core/spatialdata.py
@@ -3,7 +3,7 @@
import hashlib
import os
import warnings
from collections.abc import Generator
from collections.abc import Generator, Mapping
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
@@ -122,6 +122,7 @@ def __init__(
points: dict[str, DaskDataFrame] | None = None,
shapes: dict[str, GeoDataFrame] | None = None,
tables: dict[str, AnnData] | Tables | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> None:
self._path: Path | None = None

@@ -131,6 +132,7 @@ def __init__(
self._points: Points = Points(shared_keys=self._shared_keys)
self._shapes: Shapes = Shapes(shared_keys=self._shared_keys)
self._tables: Tables = Tables(shared_keys=self._shared_keys)
self._attrs: dict[Any, Any] = dict(attrs) if attrs else {}

# Workaround to allow for backward compatibility
if isinstance(tables, AnnData):
@@ -712,7 +714,7 @@ def filter_by_coordinate_system(
set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system
)

return SpatialData(**elements, tables=tables)
return SpatialData(**elements, tables=tables, attrs=self.attrs)

# TODO: move to relational query with refactor
def _filter_tables(
@@ -954,7 +956,7 @@ def transform_to_coordinate_system(
if element_type not in elements:
elements[element_type] = {}
elements[element_type][element_name] = transformed
return SpatialData(**elements, tables=sdata.tables)
return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs)

def elements_are_self_contained(self) -> dict[str, bool]:
"""
@@ -1179,7 +1181,8 @@ def write(
self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)

store = parse_url(file_path, mode="w").store
_ = zarr.group(store=store, overwrite=overwrite)
zarr_group = zarr.group(store=store, overwrite=overwrite)
self.write_attrs(zarr_group=zarr_group)
store.close()

for element_type, element_name, element in self.gen_elements():
@@ -1583,7 +1586,28 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[str, str]:
element_type, element_name = element_path.split("/")
return element_type, element_name

def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None:
def write_attrs(self, overwrite: bool = True, zarr_group: zarr.Group | None = None) -> None:
store = None

if zarr_group is None:
assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs."
store = parse_url(self.path, mode="w").store
zarr_group = zarr.group(store=store, overwrite=overwrite)

try:
zarr_group.attrs.put(self.attrs)
except TypeError as e:
raise TypeError("Invalid attribute in SpatialData.attrs") from e

if store is not None:
store.close()

def write_metadata(
self,
element_name: str | None = None,
consolidate_metadata: bool | None = None,
write_attrs: bool = True,
) -> None:
"""
Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data.

@@ -1618,6 +1642,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None:
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.

if write_attrs:
self.write_attrs()

if consolidate_metadata is None and self.has_consolidated_metadata():
consolidate_metadata = True
if consolidate_metadata:
@@ -2173,7 +2200,7 @@ def subset(
include_orphan_tables,
elements_dict=elements_dict,
)
return SpatialData(**elements_dict, tables=tables)
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)

def __getitem__(self, item: str) -> SpatialElement:
"""
@@ -2255,6 +2282,15 @@ def __delitem__(self, key: str) -> None:
element_type, _, _ = self._find_element(key)
getattr(self, element_type).__delitem__(key)

@property
def attrs(self) -> dict[Any, Any]:
"""Dictionary of global attributes on this SpatialData object."""
return self._attrs

@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
self._attrs = dict(value)


class QueryManager:
"""Perform queries on SpatialData objects."""
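Putting the `spatialdata.py` changes together: `attrs` can be passed to the constructor or assigned through the new property (any `Mapping` is copied into a plain `dict`), `write()` persists it via `write_attrs()`, and objects derived through `filter_by_coordinate_system`, `transform_to_coordinate_system`, and `subset` keep it. A sketch under those assumptions; the temporary path and keys are invented:

```python
import tempfile
from pathlib import Path
from types import MappingProxyType

import numpy as np
from spatialdata import SpatialData
from spatialdata.models import Image2DModel

image = Image2DModel.parse(np.zeros((1, 16, 16)), dims=("c", "y", "x"))
sdata = SpatialData(images={"img": image})

# The setter accepts any Mapping and stores a plain dict copy.
sdata.attrs = MappingProxyType({"experiment": "demo"})
assert isinstance(sdata.attrs, dict)

with tempfile.TemporaryDirectory() as tmp:
    sdata.write(Path(tmp) / "data.zarr")  # write() now also writes the root attrs

    # Once the object is backed by a Zarr store, metadata edits can be pushed
    # without rewriting the data.
    sdata.attrs["qc_passed"] = True
    sdata.write_attrs()

# Values that cannot be serialized into the Zarr attributes surface as a
# TypeError raised by write_attrs().
```

Note that `write_metadata()` gains a `write_attrs: bool = True` flag, so refreshing element metadata also refreshes the object-level attrs by default.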
1 change: 1 addition & 0 deletions src/spatialdata/_io/io_zarr.py
@@ -144,6 +144,7 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = None) -> SpatialData:
points=points,
shapes=shapes,
tables=tables,
attrs=f.attrs.asdict(),
)
sdata.path = Path(store)
return sdata
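And the read side: `read_zarr()` now loads the root group attributes back into `.attrs`, completing the round trip. A small sketch; the path is a throwaway temporary directory and the key is illustrative:

```python
import tempfile
from pathlib import Path

import numpy as np
from spatialdata import SpatialData, read_zarr
from spatialdata.models import Image2DModel

image = Image2DModel.parse(np.zeros((1, 16, 16)), dims=("c", "y", "x"))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "roundtrip.zarr"
    SpatialData(images={"img": image}, attrs={"version": 3}).write(path)

    sdata = read_zarr(path)
    assert sdata.attrs["version"] == 3  # restored from the root Zarr attributes
```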