✨ Validate partitions.
fbriol committed Oct 29, 2023
1 parent 51f4f18 commit b7e7c86
Showing 3 changed files with 114 additions and 1 deletion.
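For orientation, here is a minimal usage sketch of the feature this commit adds; it is not part of the diff. The collection path is hypothetical, and the sketch assumes the library's existing ``open_collection`` entry point, a local filesystem, and a running Dask cluster.

import dask.distributed
import fsspec

import zcollection

with dask.distributed.Client():  # the validation runs on Dask workers
    col = zcollection.open_collection(
        '/tmp/my_collection',  # hypothetical path to an existing collection
        mode='w',  # write access, so that fix=True may delete partitions
        filesystem=fsspec.filesystem('file'))
    # Report unreadable partitions; nothing is deleted unless fix=True.
    print(col.validate_partitions())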
55 changes: 55 additions & 0 deletions zcollection/collection/__init__.py
@@ -105,6 +105,23 @@ def _infer_callable(
return tuple(func_result)


def _check_partition(
    fs: fsspec.AbstractFileSystem,
    partition: str,
) -> tuple[str, bool]:
    """Check if a given partition is a valid Zarr group.

    Args:
        fs: The file system to use.
        partition: The partition to check.

    Returns:
        A tuple containing the partition and a boolean indicating whether it
        is a valid Zarr group.
    """
    return partition, storage.check_zarr_group(partition, fs)
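Returning the partition path together with the validity flag is what lets the caller below use ``dask.distributed.as_completed``, which yields futures in completion order rather than submission order: without the path in the result, a finished check could not be matched back to its partition.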


class Collection(ReadOnlyCollection):
"""This class manages a collection of files in Zarr format stored in a set
of subdirectories. These subdirectories split the data, by cycles or dates
@@ -722,3 +739,41 @@ def worker_task(args: Sequence[tuple[str, str]]) -> None:
mode=mode,
filesystem=filesystem,
synchronizer=synchronizer)

    def validate_partitions(self,
                            filters: PartitionFilter | None = None,
                            fix: bool = False) -> list[str]:
        """Validate partitions in the collection by checking that they exist
        and are readable. If ``fix`` is True, invalid partitions are removed
        from the collection.

        Args:
            filters: The predicate used to filter the partitions to
                validate. By default, all partitions are validated.
            fix: Whether to fix invalid partitions by removing them from
                the collection.

        Returns:
            A list of invalid partitions.
        """
        partitions = tuple(self.partitions(filters=filters))
        if not partitions:
            return []
        client: dask.distributed.Client = dask_utils.get_client()
        futures: list[dask.distributed.Future] = [
            client.submit(_check_partition, self.fs, partition)
            for partition in partitions
        ]
        invalid_partitions: list[str] = []
        for item in dask.distributed.as_completed(futures):
            partition, valid = item.result()  # type: ignore
            if not valid:
                warnings.warn(f'Invalid partition: {partition}',
                              category=RuntimeWarning)
                invalid_partitions.append(partition)

        if fix and invalid_partitions:
            for item in invalid_partitions:
                _LOGGER.info('Removing invalid partition: %s', item)
                self.fs.rm(item, recursive=True)
        return invalid_partitions
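A short sketch of the new method combined with a partition filter, reusing the ``col`` handle from the sketch above (hedged: the string-filter form follows the ``partitions`` API the method delegates to, and the date values are invented):

# Validate only the partitions for October 2023 and delete any broken ones.
invalid = col.validate_partitions(
    filters='year == 2023 and month == 10',  # forwarded to self.partitions()
    fix=True)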
32 changes: 32 additions & 0 deletions zcollection/collection/tests/test_collection.py
@@ -1008,3 +1008,35 @@ def new_shape(

    with pytest.raises(RuntimeError, match='Try to re-load'):
        _ = zds['time'].values


def test_invalid_partitions(
        dask_client,  # pylint: disable=redefined-outer-name,unused-argument
        tmpdir) -> None:
    fs = fsspec.filesystem('file')
    datasets = list(create_test_dataset())
    zds = datasets.pop(0)
    zds.concat(datasets, 'num_lines')
    base_dir = str(tmpdir / 'test')
    zcollection = collection.Collection('time',
                                        zds.metadata(),
                                        partitioning.Date(('time', ), 'D'),
                                        base_dir,
                                        filesystem=fs)
    zcollection.insert(zds)
    partitions = tuple(zcollection.partitions())
    # Corrupt the first chunk of ``var2`` in two randomly chosen partitions.
    choices = numpy.random.choice(len(partitions), size=2, replace=False)
    for idx in choices:
        var2 = fs.sep.join((partitions[idx], 'var2', '0.0'))
        with fs.open(var2, 'wb') as file:
            file.write(b'invalid')
    # Loading now fails on the corrupted chunks...
    with pytest.raises(ValueError):
        _ = zcollection.load(delayed=False)
    # ...and validation flags exactly the two corrupted partitions.
    with pytest.warns(RuntimeWarning, match='Invalid partition'):
        invalid_partitions = zcollection.validate_partitions()
    assert len(invalid_partitions) == 2
    assert sorted(invalid_partitions) == sorted(partitions[ix]
                                                for ix in choices)
    with pytest.warns(RuntimeWarning, match='Invalid partition'):
        zcollection.validate_partitions(fix=True)
    # With the invalid partitions removed, the collection loads again.
    assert zcollection.load() is not None
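The corruption trick above works because ``var2/0.0`` is, presumably, the first chunk file of a two-dimensional Zarr array: overwriting it with junk bytes leaves the metadata intact but makes decoding fail, which is exactly the failure mode ``check_zarr_group`` probes for by reading every array in full.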
28 changes: 27 additions & 1 deletion zcollection/storage.py
@@ -26,7 +26,7 @@

from . import dataset, meta, sync
from .fs_utils import join_path
-from .type_hints import ArrayLike
+from .type_hints import ArrayLike, NDArray

#: Name of the attribute storing the names of the dimensions of an array.
DIMENSIONS = '_ARRAY_DIMENSIONS'
@@ -439,3 +439,29 @@ def add_zarr_array(
filters=variable.filters)
write_zattrs(dirname, variable, fs)
zarr.consolidate_metadata(fs.get_mapper(dirname)) # type: ignore[arg-type]


def check_zarr_group(
    dirname: str,
    fs: fsspec.AbstractFileSystem,
) -> bool:
    """Check if a directory contains a valid Zarr group.

    Args:
        dirname: The name of the directory containing the Zarr group to
            check.
        fs: The file system to use.

    Returns:
        True if the directory contains a valid Zarr group, False otherwise.
    """
    try:
        store: zarr.Group = zarr.open_consolidated(  # type: ignore
            fs.get_mapper(dirname),
            mode='r',
        )
        # Read every array in full so that undecodable chunks are detected.
        for _, array in store.arrays():
            data: NDArray = array[...]  # type: ignore
            del data
    except (ValueError, TypeError, RuntimeError):
        return False
    return True
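For illustration, the helper can also be called on its own. A minimal sketch, assuming a local filesystem and a hypothetical date-partition path:

import fsspec

from zcollection import storage

fs = fsspec.filesystem('file')
# True only if the directory opens as a consolidated Zarr group and every
# array's chunks decode cleanly.
print(storage.check_zarr_group('/tmp/my_collection/year=2023/month=10/day=29',
                               fs))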
