Skip to content

Commit

Permalink
(fix): use dask array for missing element in concatenation
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Nov 27, 2024
1 parent 41369da commit 3009058
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,21 +939,31 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
return reindexers


def missing_element(
    n: int,
    els: Iterable[
        SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray
    ],
    axis: Literal[0, 1] = 0,
) -> np.ndarray | DaskArray:
    """Generate the placeholder value used when an element is missing.

    Parameters
    ----------
    n
        Length of the placeholder along `axis`.
    els
        The elements being combined. They are only inspected to decide
        which array type to return: entries that are not `DaskArray`
        instances do not affect the result.
    axis
        The axis that gets length `n`; the other axis has length 0.

    Returns
    -------
    A zero-filled array of shape ``(n, 0)`` when `axis` is 0 or ``(0, n)``
    when `axis` is 1. It is a dask array if any entry of `els` is a
    `DaskArray`, otherwise a boolean NumPy array.
    """
    shape = (0, n) if axis else (n, 0)
    if any(isinstance(el, DaskArray) for el in els):
        # Local import: dask is only imported once a dask array is present.
        import dask.array as da

        # NOTE(review): no dtype is passed here, so dask's default dtype is
        # used instead of bool as in the NumPy branch below -- confirm that
        # this difference is intended.
        return da.zeros(shape)
    return np.zeros(shape, dtype=bool)


def outer_concat_aligned_mapping(
mappings, *, reindexers=None, index=None, axis=0, fill_value=None
):
result = {}
ns = [m.parent.shape[axis] for m in mappings]

def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
"""Generates value to use when there is a missing element."""
if axis == 0:
return np.zeros((n, 0), dtype=bool)
else:
return np.zeros((0, n), dtype=bool)

for k in union_keys(mappings):
els = [m.get(k, MissingVal) for m in mappings]
els = [m[k] if k in m else MissingVal for m in mappings]
if reindexers is None:
cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
else:
Expand All @@ -963,7 +973,7 @@ def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
# We should probably just handle missing elements for all types
result[k] = concat_arrays(
[
el if not_missing(el) else missing_element(n, axis=axis)
el if not_missing(el) else missing_element(n, axis=axis, els=els)
for el, n in zip(els, ns)
],
cur_reindexers,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,16 @@ def test_concat_different_types_dask(merge_strategy, array_type):
assert_equal(result2, target2)


# Checks that the placeholder for a missing element is a dask array
# whenever a dask array is among the elements being combined.
def test_impute_dask():
    """`missing_element` returns a `DaskArray` when given dask input."""
    import dask.array as da

    from anndata._core.merge import missing_element

    dask_inputs = [da.ones((5, 5))]
    placeholder = missing_element(5, dask_inputs, axis=0)
    assert isinstance(placeholder, DaskArray)


def test_outer_concat_with_missing_value_for_df():
# https://github.com/scverse/anndata/issues/901
# TODO: Extend this test to cover all cases of missing values
Expand Down

0 comments on commit 3009058

Please sign in to comment.