Skip to content

Commit

Permalink
Add v2, v3 specific dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 24, 2024
1 parent 6db8225 commit 1087178
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
4 changes: 3 additions & 1 deletion src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
return False
# use array_equal to obtain equal_nan=True functionality
data, other = np.broadcast_arrays(self._data, other)
result = np.array_equal(self._data, other, equal_nan=equal_nan)
result = np.array_equal(
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
)
return result

def fill(self, value: Any) -> None:
Expand Down
49 changes: 39 additions & 10 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
Expand All @@ -19,19 +19,34 @@
)


def dtypes() -> st.SearchStrategy[np.dtype]:
def v3_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy that draws numpy dtypes for Zarr format 3 arrays.

    The strategy is the union of boolean, signed integer, unsigned integer,
    floating point and complex dtypes. Every family that accepts a byte-order
    argument is requested with ``endianness="="``.
    """
    byte_order = "="
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness=byte_order),
        npst.unsigned_integer_dtypes(endianness=byte_order),
        npst.floating_dtypes(endianness=byte_order),
        npst.complex_number_dtypes(endianness=byte_order),
        # Dtype families that are deliberately left disabled here:
        #   npst.byte_string_dtypes(endianness="=")
        #   npst.unicode_string_dtypes()
        #   npst.datetime64_dtypes()
        #   npst.timedelta64_dtypes()
    )


def v2_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy that draws numpy dtypes for Zarr format 2 arrays.

    The strategy is the union of boolean, signed integer, unsigned integer,
    floating point, complex, byte-string, unicode-string and datetime64
    dtypes. All families except booleans and datetime64 are requested with
    ``endianness="="``; datetime64 uses the library default.
    """
    byte_order = "="
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness=byte_order),
        npst.unsigned_integer_dtypes(endianness=byte_order),
        npst.floating_dtypes(endianness=byte_order),
        npst.complex_number_dtypes(endianness=byte_order),
        npst.byte_string_dtypes(endianness=byte_order),
        npst.unicode_string_dtypes(endianness=byte_order),
        npst.datetime64_dtypes(),
        # Dtype family that is deliberately left disabled here:
        #   npst.timedelta64_dtypes()
    )


# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
# 1. must not be the empty string ("")
# 2. must not include the character "/"
Expand All @@ -46,18 +61,29 @@ def dtypes() -> st.SearchStrategy[np.dtype]:
array_names = node_names
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
np_arrays = npst.arrays(
dtype=dtypes(),
shape=npst.array_shapes(max_dims=4),
)
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
compressors = st.sampled_from([None, "default"])
zarr_formats = st.sampled_from([2, 3])
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
array_shapes = npst.array_shapes(max_dims=4)


@st.composite  # type: ignore[misc]
def numpy_arrays(
    draw: st.DrawFn,
    *,
    shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
    zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
) -> Any:
    """
    Draw a numpy array whose dtype can be saved in the drawn Zarr format.

    Parameters
    ----------
    draw
        Draw function injected by ``st.composite``.
    shapes
        Strategy producing the shape of the generated array.
    zarr_formats
        Strategy producing the Zarr format; a drawn value of 3 selects
        ``v3_dtypes()``, any other value selects ``v2_dtypes()``.

    Returns
    -------
    The array drawn from ``npst.arrays`` with the selected dtype strategy.
    """
    # The format is drawn first because it decides which dtype pool applies.
    if draw(zarr_formats) == 3:
        dtype_pool = v3_dtypes()
    else:
        dtype_pool = v2_dtypes()
    return draw(npst.arrays(dtype=dtype_pool, shape=shapes))


@st.composite # type: ignore[misc]
def np_array_and_chunks(
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = np_arrays
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
) -> tuple[np.ndarray, tuple[int]]: # type: ignore[type-arg]
"""A hypothesis strategy to generate small sized random arrays.
Expand All @@ -76,20 +102,23 @@ def np_array_and_chunks(
def arrays(
draw: st.DrawFn,
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
compressors: st.SearchStrategy = compressors,
stores: st.SearchStrategy[StoreLike] = stores,
arrays: st.SearchStrategy[np.ndarray] = np_arrays,
paths: st.SearchStrategy[None | str] = paths,
array_names: st.SearchStrategy = array_names,
arrays: st.SearchStrategy | None = None,
attrs: st.SearchStrategy = attrs,
zarr_formats: st.SearchStrategy = zarr_formats,
) -> Array:
store = draw(stores)
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
path = draw(paths)
name = draw(array_names)
attributes = draw(attrs)
zarr_format = draw(zarr_formats)
if arrays is None:
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
# test that None works too.
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
# compressor = draw(compressors)
Expand Down
11 changes: 6 additions & 5 deletions tests/v3/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import hypothesis.extra.numpy as npst # noqa
import hypothesis.strategies as st # noqa
from hypothesis import given, settings # noqa
from zarr.testing.strategies import arrays, np_arrays, basic_indices # noqa
from zarr.testing.strategies import arrays, numpy_arrays, basic_indices, zarr_formats # noqa


@given(st.data())
def test_roundtrip(data: st.DataObject) -> None:
nparray = data.draw(np_arrays)
zarray = data.draw(arrays(arrays=st.just(nparray)))
@given(data=st.data(), zarr_format=zarr_formats)
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
    """Check that a Zarr array reads back equal to the numpy array it was drawn from."""
    # Pin both strategies to the same format so the numpy dtype is valid for
    # the format used when building the Zarr array.
    pinned_format = st.just(zarr_format)
    expected = data.draw(numpy_arrays(zarr_formats=pinned_format))
    stored = data.draw(arrays(arrays=st.just(expected), zarr_formats=pinned_format))
    assert_array_equal(expected, stored[:])


Expand All @@ -31,6 +31,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@settings(report_multiple_bugs=False)
@given(data=st.data())
def test_vindex(data: st.DataObject) -> None:
zarray = data.draw(arrays())
Expand Down
15 changes: 15 additions & 0 deletions zarr/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# file generated by setuptools_scm
# don't change, don't track in version control

# Defined locally as False so the module needs no imports; the first branch
# below is therefore never executed at runtime.
TYPE_CHECKING = False
if TYPE_CHECKING:
    VERSION_TUPLE = tuple[int | str, ...]
else:
    VERSION_TUPLE = object

# Declarations for the four public names assigned below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Version string, exposed under both spellings.
version = "2.18.2"
__version__ = version

# The same version as a tuple of components, exposed under both spellings.
version_tuple = (2, 18, 2)
__version_tuple__ = version_tuple

0 comments on commit 1087178

Please sign in to comment.