Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fill_value handling for complex dtypes #2200

Merged
merged 12 commits into from
Sep 25, 2024
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@ fixture/
.DS_Store
tests/.hypothesis
.hypothesis/

zarr/version.py
7 changes: 0 additions & 7 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,6 @@ async def _create_v3(
shape = parse_shapelike(shape)
codecs = list(codecs) if codecs is not None else [BytesCodec()]

if fill_value is None:
normanrz marked this conversation as resolved.
Show resolved Hide resolved
if dtype == np.dtype("bool"):
fill_value = False
else:
fill_value = 0

if chunk_key_encoding is None:
chunk_key_encoding = ("default", "/")
assert chunk_key_encoding is not None
Expand All @@ -281,7 +275,6 @@ async def _create_v3(
)

array = cls(metadata=metadata, store_path=store_path)

await array._save_metadata(metadata)
return array

Expand Down
7 changes: 6 additions & 1 deletion src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,14 @@ def __repr__(self) -> str:

def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
    """Compare to `other` using np.array_equal.

    Parameters
    ----------
    other
        Value to compare the wrapped array (``self._data``) against. It is
        broadcast against ``self._data`` first, so a scalar is compared with
        every element. ``None`` never compares equal.
    equal_nan
        Whether NaN entries compare equal to each other. Only honoured for
        float, complex, datetime and timedelta data; for every other dtype
        kind the comparison is done without NaN handling.

    Returns
    -------
    bool
        True if ``self._data`` and the broadcast ``other`` have the same shape
        and equal elements, False otherwise.
    """
    if other is None:
        # Handle None fill_value for Zarr V2
        return False
    # Only the broadcast `other` is needed; the data itself is compared as-is.
    _, other = np.broadcast_arrays(self._data, other)
    # use array_equal to obtain equal_nan=True functionality.
    # With equal_nan=True, np.array_equal calls np.isnan, which raises
    # TypeError for dtypes that cannot represent NaN/NaT (strings, objects,
    # void). Restrict NaN-aware comparison to kinds that can hold NaN/NaT.
    # (Review note on the original change: this guard is also needed for V2.)
    nan_aware = equal_nan and self._data.dtype.kind in "fcmM"
    return bool(np.array_equal(self._data, other, equal_nan=nan_aware))

def fill(self, value: Any) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def parse_fill_value(
if fill_value is None:
return dtype.type(0)
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
if dtype in (np.complex64, np.complex128):
if dtype.type in (np.complex64, np.complex128):
dtype = cast(COMPLEX_DTYPE, dtype)
if len(fill_value) == 2:
# complex datatypes serialize to JSON arrays with two elements
Expand Down Expand Up @@ -391,7 +391,7 @@ def parse_fill_value(
pass
elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value):
pass
elif dtype.kind == "f":
elif dtype.kind in "cf":
# float comparison is not exact, especially when dtype <float64
# so we use np.isclose for this comparison.
# this also allows us to compare nan fill_values
Expand Down
112 changes: 62 additions & 50 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any
from typing import Any, Literal

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
Expand All @@ -19,6 +18,35 @@
max_leaves=3,
)


def v3_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy for the NumPy dtypes used when generating Zarr v3 arrays.

    The strategy is the union of boolean, integer, unsigned-integer, floating
    and complex dtypes, all requested with native ("=") endianness where the
    underlying strategy accepts an endianness.
    """
    # Not generated here: byte-string, unicode-string, datetime64 and
    # timedelta64 dtypes.
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness="="),
        npst.unsigned_integer_dtypes(endianness="="),
        npst.floating_dtypes(endianness="="),
        npst.complex_number_dtypes(endianness="="),
    )


def v2_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy for the NumPy dtypes used when generating Zarr v2 arrays.

    The strategy is the union of boolean, integer, unsigned-integer, floating,
    complex, byte-string, unicode-string and datetime64 dtypes. Native ("=")
    endianness is requested wherever the underlying strategy is given one.
    """
    # Not generated here: timedelta64 dtypes.
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness="="),
        npst.unsigned_integer_dtypes(endianness="="),
        npst.floating_dtypes(endianness="="),
        npst.complex_number_dtypes(endianness="="),
        npst.byte_string_dtypes(endianness="="),
        npst.unicode_string_dtypes(endianness="="),
        npst.datetime64_dtypes(),
    )


# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
# 1. must not be the empty string ("")
# 2. must not include the character "/"
Expand All @@ -33,21 +61,29 @@
array_names = node_names
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
np_arrays = npst.arrays(
# TODO: re-enable timedeltas once they are supported
dtype=npst.scalar_dtypes().filter(
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
),
shape=npst.array_shapes(max_dims=4),
)
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
compressors = st.sampled_from([None, "default"])
format = st.sampled_from([2, 3])
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
array_shapes = npst.array_shapes(max_dims=4)


@st.composite  # type: ignore[misc]
def numpy_arrays(
    draw: st.DrawFn,
    *,
    shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
    zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
) -> Any:
    """
    Generate numpy arrays that can be saved in the provided Zarr format.

    A format is drawn from ``zarr_formats`` first; the dtype strategy is then
    ``v3_dtypes()`` for format 3 and ``v2_dtypes()`` otherwise, and the array
    shape is drawn from ``shapes``.
    """
    if draw(zarr_formats) == 3:
        dtype_strategy = v3_dtypes()
    else:
        dtype_strategy = v2_dtypes()
    array_strategy = npst.arrays(dtype=dtype_strategy, shape=shapes)
    return draw(array_strategy)


@st.composite # type: ignore[misc]
def np_array_and_chunks(
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = np_arrays
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
) -> tuple[np.ndarray, tuple[int]]: # type: ignore[type-arg]
"""A hypothesis strategy to generate small sized random arrays.

Expand All @@ -66,73 +102,49 @@ def np_array_and_chunks(
def arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
compressors: st.SearchStrategy = compressors,
stores: st.SearchStrategy[StoreLike] = stores,
arrays: st.SearchStrategy[np.ndarray] = np_arrays,
paths: st.SearchStrategy[None | str] = paths,
array_names: st.SearchStrategy = array_names,
arrays: st.SearchStrategy | None = None,
attrs: st.SearchStrategy = attrs,
format: st.SearchStrategy = format,
zarr_formats: st.SearchStrategy = zarr_formats,
) -> Array:
store = draw(stores)
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
path = draw(paths)
name = draw(array_names)
attributes = draw(attrs)
zarr_format = draw(format)
zarr_format = draw(zarr_formats)
if arrays is None:
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
# test that None works too.
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
# compressor = draw(compressors)

# TODO: clean this up
# if path is None and name is None:
# array_path = None
# array_name = None
# elif path is None and name is not None:
# array_path = f"{name}"
# array_name = f"/{name}"
# elif path is not None and name is None:
# array_path = path
# array_name = None
# elif path == "/":
# assert name is not None
# array_path = name
# array_name = "/" + name
# else:
# assert name is not None
# array_path = f"{path}/{name}"
# array_name = "/" + array_path

expected_attrs = {} if attributes is None else attributes

array_path = path + ("/" if not path.endswith("/") else "") + name
root = Group.from_store(store, zarr_format=zarr_format)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
m = re.search(r"\[(.+)\]", nparray.dtype.str)
if not m:
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")

fill_value_args = (
# e.g. ns, D
m.groups()[0],
)

a = root.create_array(
array_path,
shape=nparray.shape,
chunks=chunks,
dtype=nparray.dtype.str,
dtype=nparray.dtype,
attributes=attributes,
# compressor=compressor, # TODO: FIXME
fill_value=nparray.dtype.type(0, *fill_value_args),
# compressor=compressor, # FIXME
fill_value=fill_value,
)

assert isinstance(a, Array)
assert a.fill_value is not None
assert isinstance(root[array_path], Array)
assert nparray.shape == a.shape
assert chunks == a.chunks
assert array_path == a.path, (path, name, array_path, a.name, a.path)
# assert array_path == a.name, (path, name, array_path, a.name, a.path)
# assert a.basename is None # TODO
# assert a.store == normalize_store_arg(store)
assert a.basename == name, (a.basename, name)
assert dict(a.attrs) == expected_attrs

a[:] = nparray
Expand Down
10 changes: 5 additions & 5 deletions tests/v3/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import hypothesis.strategies as st # noqa: E402
from hypothesis import given # noqa: E402

from zarr.testing.strategies import arrays, basic_indices, np_arrays # noqa: E402
from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats # noqa: E402


@given(data=st.data(), zarr_format=zarr_formats)
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
    """Check that an array written for a given Zarr format reads back unchanged.

    The numpy array and the Zarr array are both drawn for the same
    ``zarr_format`` so that the generated dtype is one that format's dtype
    strategy produces.
    """
    # Pin both strategies to the single format drawn for this example.
    same_format = st.just(zarr_format)
    nparray = data.draw(numpy_arrays(zarr_formats=same_format))
    zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=same_format))
    assert_array_equal(nparray, zarray[:])


Expand Down