From 3009058edbeb2e8704f9c029455276a8f6c77c31 Mon Sep 17 00:00:00 2001
From: ilan-gold
Date: Wed, 27 Nov 2024 10:59:12 +0100
Subject: [PATCH] (fix): use dask array for missing element in concatenation

---
 src/anndata/_core/merge.py | 28 +++++++++++++++++++---------
 tests/test_concatenate.py  | 10 ++++++++++
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 0dfa5dab2..dba1c41c7 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -939,21 +939,31 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
     return reindexers
 
 
+def missing_element(
+    n: int,
+    els: Iterable[
+        SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray
+    ],
+    axis: Literal[0, 1] = 0,
+) -> np.ndarray | DaskArray:
+    """Generates value to use when there is a missing element."""
+    should_return_dask = any(isinstance(el, DaskArray) for el in els)
+    shape = (0, n) if axis else (n, 0)
+    if should_return_dask:
+        import dask.array as da
+
+        return da.zeros(shape)
+    return np.zeros(shape, dtype=bool)
+
+
 def outer_concat_aligned_mapping(
     mappings, *, reindexers=None, index=None, axis=0, fill_value=None
 ):
     result = {}
     ns = [m.parent.shape[axis] for m in mappings]
 
-    def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
-        """Generates value to use when there is a missing element."""
-        if axis == 0:
-            return np.zeros((n, 0), dtype=bool)
-        else:
-            return np.zeros((0, n), dtype=bool)
-
     for k in union_keys(mappings):
-        els = [m.get(k, MissingVal) for m in mappings]
+        els = [m[k] if k in m else MissingVal for m in mappings]
         if reindexers is None:
             cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
         else:
@@ -963,7 +973,7 @@ def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
         # We should probably just handle missing elements for all types
         result[k] = concat_arrays(
             [
-                el if not_missing(el) else missing_element(n, axis=axis)
+                el if not_missing(el) else missing_element(n, axis=axis, els=els)
                 for el, n in zip(els, ns)
             ],
             cur_reindexers,
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index d9f399dd6..a4be37bd5 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -1533,6 +1533,16 @@ def test_concat_different_types_dask(merge_strategy, array_type):
     assert_equal(result2, target2)
 
 
+# Tests that a missing element is imputed as a dask array when any element is dask.
+def test_impute_dask():
+    import dask.array as da
+
+    from anndata._core.merge import missing_element
+
+    els = [da.ones((5, 5))]
+    assert isinstance(missing_element(5, els, axis=0), DaskArray)
+
+
 def test_outer_concat_with_missing_value_for_df():
     # https://github.com/scverse/anndata/issues/901
     # TODO: Extend this test to cover all cases of missing values