From b7e7c862c7808994fefccb108192591ee660ea38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20BRIOL?=
Date: Sun, 29 Oct 2023 07:55:34 +0100
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Validate=20partitions.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 zcollection/collection/__init__.py       | 55 +++++++++++++++++++
 .../collection/tests/test_collection.py  | 32 +++++++++++
 zcollection/storage.py                   | 28 +++++++++-
 3 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/zcollection/collection/__init__.py b/zcollection/collection/__init__.py
index e3fc5f3..12f3327 100644
--- a/zcollection/collection/__init__.py
+++ b/zcollection/collection/__init__.py
@@ -105,6 +105,23 @@ def _infer_callable(
     return tuple(func_result)
 
 
+def _check_partition(
+    fs: fsspec.AbstractFileSystem,
+    partition: str,
+) -> tuple[str, bool]:
+    """Check if a given partition is a valid Zarr group.
+
+    Args:
+        fs: The file system to use.
+        partition: The partition to check.
+
+    Returns:
+        A tuple containing the partition and a boolean indicating whether it is
+        a valid Zarr group.
+    """
+    return partition, storage.check_zarr_group(partition, fs)
+
+
 class Collection(ReadOnlyCollection):
     """This class manages a collection of files in Zarr format stored in a set
     of subdirectories. These subdirectories split the data, by cycles or dates
@@ -722,3 +739,41 @@ def worker_task(args: Sequence[tuple[str, str]]) -> None:
                           mode=mode,
                           filesystem=filesystem,
                           synchronizer=synchronizer)
+
+    def validate_partitions(self,
+                            filters: PartitionFilter | None = None,
+                            fix: bool = False) -> list[str]:
+        """Validates partitions in the collection by checking if they exist and
+        are readable. If `fix` is True, invalid partitions will be removed from
+        the collection.
+
+        Args:
+            filters: The predicate used to filter the partitions to
+                validate. By default, all partitions are validated.
+            fix: Whether to fix invalid partitions by removing them from
+                the collection.
+
+        Returns:
+            A list of invalid partitions.
+        """
+        partitions = tuple(self.partitions(filters=filters))
+        if not partitions:
+            return []
+        client: dask.distributed.Client = dask_utils.get_client()
+        futures: list[dask.distributed.Future] = [
+            client.submit(_check_partition, self.fs, partition)
+            for partition in partitions
+        ]
+        invalid_partitions: list[str] = []
+        for item in dask.distributed.as_completed(futures):
+            partition, valid = item.result()  # type: ignore
+            if not valid:
+                warnings.warn(f'Invalid partition: {partition}',
+                              category=RuntimeWarning)
+                invalid_partitions.append(partition)
+
+        if fix and invalid_partitions:
+            for item in invalid_partitions:
+                _LOGGER.info('Removing invalid partition: %s', item)
+                self.fs.rm(item, recursive=True)
+        return invalid_partitions
diff --git a/zcollection/collection/tests/test_collection.py b/zcollection/collection/tests/test_collection.py
index cf5a6b1..d1c6401 100644
--- a/zcollection/collection/tests/test_collection.py
+++ b/zcollection/collection/tests/test_collection.py
@@ -1008,3 +1008,35 @@ def new_shape(
 
     with pytest.raises(RuntimeError, match='Try to re-load'):
         _ = zds['time'].values
+
+
+def test_invalid_partitions(
+        dask_client,  # pylint: disable=redefined-outer-name,unused-argument
+        tmpdir) -> None:
+    fs = fsspec.filesystem('file')
+    datasets = list(create_test_dataset())
+    zds = datasets.pop(0)
+    zds.concat(datasets, 'num_lines')
+    base_dir = str(tmpdir / 'test')
+    zcollection = collection.Collection('time',
+                                        zds.metadata(),
+                                        partitioning.Date(('time', ), 'D'),
+                                        base_dir,
+                                        filesystem=fs)
+    zcollection.insert(zds)
+    partitions = tuple(zcollection.partitions())
+    choices = numpy.random.choice(len(partitions), size=2, replace=False)
+    for idx in choices:
+        var2 = fs.sep.join((partitions[idx], 'var2', '0.0'))
+        with fs.open(var2, 'wb') as file:
+            file.write(b'invalid')
+    with pytest.raises(ValueError):
+        _ = zcollection.load(delayed=False)
+    with pytest.warns(RuntimeWarning, match='Invalid partition'):
+        invalid_partitions = zcollection.validate_partitions()
+    assert len(invalid_partitions) == 2
+    assert sorted(invalid_partitions) == sorted(partitions[ix]
+                                                for ix in choices)
+    with pytest.warns(RuntimeWarning, match='Invalid partition'):
+        zcollection.validate_partitions(fix=True)
+    assert zcollection.load() is not None
diff --git a/zcollection/storage.py b/zcollection/storage.py
index 8f73e2e..d93fac0 100644
--- a/zcollection/storage.py
+++ b/zcollection/storage.py
@@ -26,7 +26,7 @@
 
 from . import dataset, meta, sync
 from .fs_utils import join_path
-from .type_hints import ArrayLike
+from .type_hints import ArrayLike, NDArray
 
 #: Name of the attribute storing the names of the dimensions of an array.
 DIMENSIONS = '_ARRAY_DIMENSIONS'
@@ -439,3 +439,29 @@ def add_zarr_array(
               filters=variable.filters)
     write_zattrs(dirname, variable, fs)
     zarr.consolidate_metadata(fs.get_mapper(dirname))  # type: ignore[arg-type]
+
+
+def check_zarr_group(
+    dirname: str,
+    fs: fsspec.AbstractFileSystem,
+) -> bool:
+    """Check if a directory contains a valid Zarr group.
+
+    Args:
+        dirname: The name of the directory containing the Zarr group to check.
+        fs: The file system to use.
+
+    Returns:
+        True if the directory contains a valid Zarr group, False otherwise.
+    """
+    try:
+        store: zarr.Group = zarr.open_consolidated(  # type: ignore
+            fs.get_mapper(dirname),
+            mode='r',
+        )
+        for _, array in store.arrays():
+            data: NDArray = array[...]  # type: ignore
+            del data
+    except (ValueError, TypeError, RuntimeError):
+        return False
+    return True
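
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new
Collection.validate_partitions method could be exercised once the change is applied. The
collection path is a placeholder, and the zcollection.open_collection entry point plus the
local Dask client are assumptions about the caller's setup; only validate_partitions, its
filters/fix arguments and the RuntimeWarning it emits come from this change.

    import dask.distributed
    import fsspec
    import zcollection

    # validate_partitions() submits one check per partition to the active
    # Dask client, so a client must be running.
    client = dask.distributed.Client(processes=False)

    # Open an existing, writable collection (path is a placeholder).
    zc = zcollection.open_collection('/path/to/collection',
                                     mode='w',
                                     filesystem=fsspec.filesystem('file'))

    # List partitions whose consolidated Zarr groups can no longer be read;
    # each bad partition is reported with a RuntimeWarning.
    invalid = zc.validate_partitions()

    # A second pass with fix=True deletes the unreadable partitions from storage.
    if invalid:
        zc.validate_partitions(fix=True)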
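
The new storage.check_zarr_group helper can also be called on its own to probe a single
partition directory. It opens the consolidated group and reads every array in full, so the
cost of a validation pass grows with partition size. The path below is hypothetical and
assumes the year=.../month=.../day=... layout produced by date partitioning.

    import fsspec
    from zcollection import storage

    fs = fsspec.filesystem('file')
    # Returns False when the group or one of its arrays cannot be read.
    ok = storage.check_zarr_group(
        '/path/to/collection/year=2023/month=10/day=29', fs)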