Skip to content

Commit

Permalink
Fix/anndata serialization (#1122)
Browse files Browse the repository at this point in the history
* Fix AnnData I/O for Lineage, bump version

* Remove variable

* Add test

* Add `zarr` to test reqs
  • Loading branch information
michalk8 authored Sep 20, 2023
1 parent bd35386 commit c3ced63
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks
Submodule notebooks updated 1 files
+1 −1 README.rst
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ maintainers = [


dependencies = [
"anndata>=0.8",
"anndata>=0.9",
"docrep>=0.3.0",
"joblib>=0.13.1",
"matplotlib>=3.5.0,<3.7.2",
Expand Down Expand Up @@ -74,6 +74,7 @@ test = [
"pytest-mock>=3.5.0",
"pytest-cov>=4",
"coverage[toml]>=7",
"zarr",
"igraph",
"leidenalg",
"Pillow",
Expand Down
14 changes: 6 additions & 8 deletions src/cellrank/_utils/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,17 +1195,15 @@ def _mutual_info(reference, query):
return weights


_SPEC = IOSpec("array", "0.2.0")


@_REGISTRY.register_write(H5Group, Lineage, _SPEC)
@_REGISTRY.register_write(H5Group, LineageView, _SPEC)
@_REGISTRY.register_write(ZarrGroup, Lineage, _SPEC)
@_REGISTRY.register_write(ZarrGroup, LineageView, _SPEC)
@_REGISTRY.register_write(H5Group, Lineage, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, LineageView, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, Lineage, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(ZarrGroup, LineageView, IOSpec("array", "0.2.0"))
def _write_lineage(
f: Any,
k: str,
elem: Union[Lineage, LineageView],
_writer: Any,
dataset_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> None:
write_basic(f, k, elem=elem.X, dataset_kwargs=dataset_kwargs)
write_basic(f, k, elem=elem.X, _writer=_writer, dataset_kwargs=dataset_kwargs)
22 changes: 21 additions & 1 deletion tests/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from matplotlib import colors

from anndata import AnnData, read_h5ad, read_zarr

from cellrank._utils import Lineage
from cellrank._utils._colors import _compute_mean_color, _create_categorical_colors
from cellrank._utils._lineage import _HT_CELLS, LineageView, PrimingDegree
Expand Down Expand Up @@ -973,7 +975,7 @@ def test_double_view_owner(self, lineage: Lineage):
assert x.owner is lineage


class TestPickling:
class TestIO:
def test_pickle_normal(self, lineage: Lineage):
handle = io.BytesIO()

Expand Down Expand Up @@ -1015,6 +1017,24 @@ def test_pickle_transposed(self, lineage: Lineage):
assert res._n_lineages == lineage._n_lineages
assert res._is_transposed == lineage._is_transposed

def test_anndata_write(self, lineage: Lineage, tmp_path):
rng = np.random.default_rng(0)
adata = AnnData(rng.normal(size=(lineage.shape[0], 13)))
adata.obsm["lin"] = lineage

assert isinstance(adata.obsm["lin"], Lineage)

adata.write_h5ad(tmp_path / "tmp.h5ad")
adata.write_zarr(tmp_path / "tmp.zarr")

adata_h5ad = read_h5ad(tmp_path / "tmp.h5ad")
adata_zarr = read_zarr(tmp_path / "tmp.zarr")

assert isinstance(adata_h5ad.obsm["lin"], np.ndarray)
assert isinstance(adata_zarr.obsm["lin"], np.ndarray)
np.testing.assert_array_equal(adata_h5ad.obsm["lin"], adata.obsm["lin"].X)
np.testing.assert_array_equal(adata_zarr.obsm["lin"], adata.obsm["lin"].X)


class TestPriming:
def test_invalid_method(self, lineage: Lineage):
Expand Down

0 comments on commit c3ced63

Please sign in to comment.