feat: Allow using zcollection without any dask cluster. #16

Draft · wants to merge 1 commit into develop
250 changes: 179 additions & 71 deletions zcollection/collection/__init__.py

Large diffs are not rendered by default.

85 changes: 61 additions & 24 deletions zcollection/collection/abc.py
@@ -553,6 +553,7 @@ def load(
filters: PartitionFilter = None,
indexer: Indexer | None = None,
selected_variables: Iterable[str] | None = None,
distributed: bool = True,
) -> dataset.Dataset | None:
"""Load the selected partitions.

@@ -564,6 +565,7 @@ def load(
indexer: The indexer to apply.
selected_variables: A list of variables to retain from the
collection. If None, all variables are kept.
distributed: Whether to use dask or not. Defaults to True.

Returns:
The dataset containing the selected partitions, or None if no
@@ -582,22 +584,42 @@ def load(
... filters=lambda keys: keys["year"] == 2019 and
... keys["month"] == 3 and keys["day"] % 2 == 0)
"""
client: dask.distributed.Client = dask_utils.get_client()
# Delayed has to be False if dask is disabled
if not distributed:
delayed = False

arrays: list[dataset.Dataset]
client: dask.distributed.Client

if indexer is None:
# No indexer, so the dataset is loaded directly for each
# selected partition.
selected_partitions = tuple(self.partitions(filters=filters))
if len(selected_partitions) == 0:
return None

# No indexer, so the dataset is loaded directly for each
# selected partition.
bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
self.partitions(filters=filters),
npartitions=dask_utils.dask_workers(client, cores_only=True))
arrays = bag.map(storage.open_zarr_group,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables).compute()
partitions = self.partitions(filters=filters)

if distributed:
client = dask_utils.get_client()
bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
partitions,
npartitions=dask_utils.dask_workers(client,
cores_only=True))
arrays = bag.map(
storage.open_zarr_group,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables).compute()
else:
arrays = [
storage.open_zarr_group(
dirname=partition,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables)
for partition in partitions
]
else:
# We're going to reuse the indexer variable, so ensure it is
# an iterable, not a generator.
@@ -617,21 +639,36 @@ def load(
if len(args) == 0:
return None

bag = dask.bag.core.from_sequence(
args,
npartitions=dask_utils.dask_workers(client, cores_only=True))

# Finally, load the selected partitions and apply the indexer.
arrays = list(
itertools.chain.from_iterable(
bag.map(
_load_and_apply_indexer,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables,
).compute()))
if distributed:
client = dask_utils.get_client()
bag = dask.bag.core.from_sequence(
args,
npartitions=dask_utils.dask_workers(client,
cores_only=True))

arrays = list(
itertools.chain.from_iterable(
bag.map(
_load_and_apply_indexer,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables,
).compute()))
else:
arrays = list(
itertools.chain.from_iterable([
_load_and_apply_indexer(
args=a,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables)
for a in args
]))

array: dataset.Dataset = arrays.pop(0)
if arrays:
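For context, a minimal sketch of how the new flag would be used once this change lands; the collection path and filter are illustrative, not taken from the PR:

import zcollection

# Open an existing collection (the path is illustrative).
collection = zcollection.open_collection("/path/to/my_collection", mode="r")

# With distributed=False, no dask.distributed client is required:
# partitions are opened sequentially in the current process, and
# delayed is forced to False so variables load eagerly instead of
# as dask arrays.
zds = collection.load(
    filters=lambda keys: keys["year"] == 2019 and keys["month"] == 3,
    distributed=False,
)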
13 changes: 10 additions & 3 deletions zcollection/collection/detail.py
@@ -394,6 +394,7 @@ def _insert(
fs: fsspec.AbstractFileSystem,
merge_callable: merging.MergeCallable | None,
partitioning_properties: PartitioningProperties,
distributed: bool = True,
**kwargs,
) -> None:
"""Insert or update a partition in the collection.
@@ -405,6 +406,7 @@ def _insert(
fs: The file system that the partition is stored on.
merge_callable: The merge callable.
partitioning_properties: The partitioning properties.
distributed: Whether to use dask or not. Defaults to True.
**kwargs: Additional keyword arguments to pass to the merge callable.
"""
partition: tuple[str, ...]
@@ -423,7 +425,8 @@
axis,
fs,
partitioning_properties.dim,
delayed=zds.delayed,
delayed=zds.delayed if distributed else False,
distributed=distributed,
merge_callable=merge_callable,
**kwargs)
return
@@ -434,7 +437,11 @@
zarr.storage.init_group(store=fs.get_mapper(dirname))

# The synchronization is done by the caller.
write_zarr_group(zds.isel(indexer), dirname, fs, sync.NoSync())
write_zarr_group(zds.isel(indexer),
dirname,
fs,
sync.NoSync(),
distributed=distributed)
except: # noqa: E722
# If the construction of the new dataset fails, the created
# partition is deleted, to guarantee the integrity of the
@@ -459,7 +466,7 @@ def _load_and_apply_indexer(
fs: The file system that the partition is stored on.
partition_handler: The partitioning handler.
partition_properties: The partitioning properties.
selected_variable: The selected variables to load.
selected_variables: The selected variables to load.

Returns:
The list of loaded datasets.
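Both _insert and load follow the same dispatch pattern: branch on distributed, using a dask bag when a cluster is available and a plain sequential call otherwise. A minimal generic sketch of that pattern, with assumed names (map_items is hypothetical, not part of zcollection):

from collections.abc import Callable, Sequence
from typing import Any

import dask.bag.core


def map_items(func: Callable[..., Any],
              items: Sequence[Any],
              *,
              distributed: bool = True,
              npartitions: int | None = None,
              **kwargs: Any) -> list[Any]:
    """Apply func to every item, through a dask bag or eagerly."""
    if distributed:
        # Schedule the work on the dask cluster.
        bag = dask.bag.core.from_sequence(items, npartitions=npartitions)
        return bag.map(func, **kwargs).compute()
    # No cluster available: plain sequential evaluation in-process.
    return [func(item, **kwargs) for item in items]

With distributed=False this runs without any dask scheduler at all, which is precisely what the PR enables for loading and inserting partitions.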