Skip to content

Commit

Permalink
Refactor ArraySubset iterators (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
LDeakin authored Feb 12, 2024
1 parent 8d5a27c commit ca01e66
Show file tree
Hide file tree
Showing 22 changed files with 1,128 additions and 616 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `codec::{Encode,Decode,PartialDecode,PartialDecoder}Options`
- Added new `Array::opt` methods which can use new encode/decode options
- **Breaking** Existing `Array` `_opt` use new encode/decode options instead of `parallel: bool`
- Implement `DoubleEndedIterator` for `{Indices,LinearisedIndices,ContiguousIndices,ContiguousLinearisedIndicesIterator}Iterator`
- Add `ParIndicesIterator` and `ParChunksIterator`

### Changed
- Dependency bumps
Expand All @@ -43,6 +45,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking**: `_opt` variants use new `codec::{Encode,Decode,PartialDecode,PartialDecoder}Options` instead of `parallel: bool`
- variants without prefix/suffix are no longer serial variants but parallel
- TODO: Remove these?
- **Major breaking**: refactor array subset iterators:
- `ArraySubset::iter_` methods no longer have an `iter_` prefix and return structures implementing `IntoIterator` including
- `Indices`, `LinearisedIndices`, `ContiguousIndices`, `ContiguousLinearisedIndices`, `Chunks`
- `Indices` and `Chunks` implement `IntoParIter`
- **Breaking**: array subset iterators are moved into public `array_subset::iterators` and no longer in the `array_subset` namespace

### Removed
- **Breaking**: remove `InvalidArraySubsetError` and `ArrayExtractElementsError`
Expand Down
2 changes: 1 addition & 1 deletion benches/array_subset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn array_subset_indices_iterator(c: &mut Criterion) {
group.throughput(Throughput::Elements(array_subset.num_elements()));
group.bench_function(BenchmarkId::new("size", array_subset_size), |b| {
b.iter(|| {
array_subset.iter_indices().for_each(|indices| {
array_subset.indices().into_iter().for_each(|indices| {
black_box(indices.first().unwrap());
})
});
Expand Down
52 changes: 30 additions & 22 deletions src/array/array_async_readable.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::sync::Arc;

use futures::{stream::FuturesUnordered, StreamExt};
use itertools::Itertools;

use crate::{
array_subset::ArraySubset,
Expand Down Expand Up @@ -301,7 +300,7 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Array<TStorage> {
));
}

let array_subset = self.chunks_subset(chunks)?;
let array_subset = Arc::new(self.chunks_subset(chunks)?);

// Retrieve chunk bytes
let num_chunks = chunks.num_elements();
Expand All @@ -321,16 +320,20 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Array<TStorage> {
std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<u8>(), size_output)
};
let output_slice = UnsafeCellSlice::new(output_slice);
let indices = chunks.iter_indices().collect_vec();
let mut futures = indices
.iter()
let mut futures = chunks
.indices()
.into_iter()
.map(|chunk_indices| {
self._async_decode_chunk_into_array_subset(
chunk_indices,
&array_subset,
unsafe { output_slice.get() },
options,
)
let array_subset = array_subset.clone();
async move {
self._async_decode_chunk_into_array_subset(
&chunk_indices,
&array_subset,
unsafe { output_slice.get() },
options,
)
.await
}
})
.collect::<FuturesUnordered<_>>();
while let Some(item) = futures.next().await {
Expand Down Expand Up @@ -437,12 +440,14 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Array<TStorage> {
let chunk_subset_in_array_subset =
unsafe { overlap.relative_to_unchecked(array_subset.start()) };
let mut decoded_offset = 0;
for (array_subset_element_index, num_elements) in unsafe {
let contiguous_indices = unsafe {
chunk_subset_in_array_subset
.iter_contiguous_linearised_indices_unchecked(array_subset.shape())
} {
.contiguous_linearised_indices_unchecked(array_subset.shape())
};
let length =
usize::try_from(contiguous_indices.contiguous_elements() * element_size).unwrap();
for (array_subset_element_index, _num_elements) in &contiguous_indices {
let output_offset = usize::try_from(array_subset_element_index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
debug_assert!((output_offset + length) <= output.len());
debug_assert!((decoded_offset + length) <= decoded_bytes.len());
output[output_offset..output_offset + length]
Expand Down Expand Up @@ -535,7 +540,8 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Array<TStorage> {
};
let output_slice = UnsafeCellSlice::new(output_slice);
let mut futures = chunks
.iter_indices()
.indices()
.into_iter()
.map(|chunk_indices| {
async move {
// Get the subset of the array corresponding to the chunk
Expand Down Expand Up @@ -589,16 +595,18 @@ impl<TStorage: ?Sized + AsyncReadableStorageTraits + 'static> Array<TStorage> {
let chunk_subset_in_array_subset =
unsafe { overlap.relative_to_unchecked(array_subset.start()) };
let mut decoded_offset = 0;
for (array_subset_element_index, num_elements) in unsafe {
let contiguous_indices = unsafe {
chunk_subset_in_array_subset
.iter_contiguous_linearised_indices_unchecked(
array_subset.shape(),
)
} {
.contiguous_linearised_indices_unchecked(array_subset.shape())
};
let length = usize::try_from(
contiguous_indices.contiguous_elements() * element_size,
)
.unwrap();
for (array_subset_element_index, _num_elements) in &contiguous_indices {
let output_offset =
usize::try_from(array_subset_element_index * element_size)
.unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
debug_assert!((output_offset + length) <= size_output);
debug_assert!((decoded_offset + length) <= decoded_bytes.len());
unsafe {
Expand Down
13 changes: 8 additions & 5 deletions src/array/array_async_readable_writable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ impl<TStorage: ?Sized + AsyncReadableWritableStorageTraits + 'static> Array<TSto
}
} else {
let chunks_to_update = chunks
.iter_indices()
.indices()
.into_iter()
.map(|chunk_indices| {
let chunk_subset_in_array = unsafe {
self.chunk_grid()
Expand Down Expand Up @@ -310,11 +311,13 @@ impl<TStorage: ?Sized + AsyncReadableWritableStorageTraits + 'static> Array<TSto
// Update the intersecting subset of the chunk
let element_size = self.data_type().size() as u64;
let mut offset = 0;
for (chunk_element_index, num_elements) in unsafe {
chunk_subset.iter_contiguous_linearised_indices_unchecked(&chunk_shape)
} {
let contiguous_indices =
unsafe { chunk_subset.contiguous_linearised_indices_unchecked(&chunk_shape) };
let length =
usize::try_from(contiguous_indices.contiguous_elements() * element_size)
.unwrap();
for (chunk_element_index, _num_elements) in &contiguous_indices {
let chunk_offset = usize::try_from(chunk_element_index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
debug_assert!(chunk_offset + length <= chunk_bytes.len());
debug_assert!(offset + length <= chunk_subset_bytes.len());
chunk_bytes[chunk_offset..chunk_offset + length]
Expand Down
33 changes: 20 additions & 13 deletions src/array/array_async_writable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ impl<TStorage: ?Sized + AsyncWritableStorageTraits + 'static> Array<TStorage> {

let element_size = self.data_type().size();

let chunks_to_update = chunks.iter_indices().collect::<Vec<_>>();
let mut futures = chunks_to_update
.iter()
let mut futures = chunks
.indices()
.into_iter()
.map(|chunk_indices| {
let chunk_subset_in_array = unsafe {
self.chunk_grid()
.subset_unchecked(chunk_indices, self.shape())
.subset_unchecked(&chunk_indices, self.shape())
.unwrap()
};
let overlap = unsafe { array_subset.overlap_unchecked(&chunk_subset_in_array) };
Expand All @@ -226,7 +226,10 @@ impl<TStorage: ?Sized + AsyncWritableStorageTraits + 'static> Array<TStorage> {
chunk_subset_in_array_subset.num_elements()
);

self.async_store_chunk_opt(chunk_indices, chunk_bytes, options)
async move {
self.async_store_chunk_opt(&chunk_indices, chunk_bytes, options)
.await
}
})
.collect::<FuturesUnordered<_>>();
while let Some(item) = futures.next().await {
Expand Down Expand Up @@ -335,16 +338,20 @@ impl<TStorage: ?Sized + AsyncWritableStorageTraits + 'static> Array<TStorage> {
let storage_transformer = self
.storage_transformers()
.create_async_writable_transformer(storage_handle);
let chunks = chunks.iter_indices().collect::<Vec<_>>();
let mut futures = chunks
.iter()
.indices()
.into_iter()
.map(|chunk_indices| {
crate::storage::async_erase_chunk(
&*storage_transformer,
self.path(),
chunk_indices,
self.chunk_key_encoding(),
)
let storage_transformer = storage_transformer.clone();
async move {
crate::storage::async_erase_chunk(
&*storage_transformer,
self.path(),
&chunk_indices,
self.chunk_key_encoding(),
)
.await
}
})
.collect::<FuturesUnordered<_>>();
while let Some(item) = futures.next().await {
Expand Down
41 changes: 13 additions & 28 deletions src/array/array_sync_readable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayToBytesCodecTraits, DecodeOptions,
PartialDecoderOptions, StoragePartialDecoder,
},
transmute_from_bytes_vec, unravel_index,
transmute_from_bytes_vec,
unsafe_cell_slice::UnsafeCellSlice,
validate_element_size, Array, ArrayCreateError, ArrayError, ArrayMetadata,
};
Expand Down Expand Up @@ -307,19 +307,9 @@ impl<TStorage: ?Sized + ReadableStorageTraits + 'static> Array<TStorage> {
};
if options.is_parallel() {
let output = UnsafeCellSlice::new(output_slice);
(0..chunks.shape().iter().product())
chunks
.indices()
.into_par_iter()
.map(|chunk_index| {
std::iter::zip(
unravel_index(chunk_index, chunks.shape()),
chunks.start(),
)
.map(|(chunk_indices, chunks_start)| chunk_indices + chunks_start)
.collect::<Vec<_>>()
})
// chunks
// .iter_indices()
// .par_bridge()
.try_for_each(|chunk_indices| {
self._decode_chunk_into_array_subset(
&chunk_indices,
Expand All @@ -329,7 +319,7 @@ impl<TStorage: ?Sized + ReadableStorageTraits + 'static> Array<TStorage> {
)
})?;
} else {
for chunk_indices in chunks.iter_indices() {
for chunk_indices in &chunks.indices() {
self._decode_chunk_into_array_subset(
&chunk_indices,
&array_subset,
Expand Down Expand Up @@ -449,12 +439,14 @@ impl<TStorage: ?Sized + ReadableStorageTraits + 'static> Array<TStorage> {
let chunk_subset_in_array_subset =
unsafe { overlap.relative_to_unchecked(array_subset.start()) };
let mut decoded_offset = 0;
for (array_subset_element_index, num_elements) in unsafe {
let contiguous_indices = unsafe {
chunk_subset_in_array_subset
.iter_contiguous_linearised_indices_unchecked(array_subset.shape())
} {
.contiguous_linearised_indices_unchecked(array_subset.shape())
};
let length =
usize::try_from(contiguous_indices.contiguous_elements() * element_size).unwrap();
for (array_subset_element_index, _num_elements) in &contiguous_indices {
let output_offset = usize::try_from(array_subset_element_index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
debug_assert!((output_offset + length) <= output.len());
debug_assert!((decoded_offset + length) <= decoded_bytes.len());
output[output_offset..output_offset + length]
Expand Down Expand Up @@ -547,16 +539,9 @@ impl<TStorage: ?Sized + ReadableStorageTraits + 'static> Array<TStorage> {
// FIXME: Constrain concurrency here based on parallelism internally vs externally

let output = UnsafeCellSlice::new(output_slice);
(0..chunks.shape().iter().product())
chunks
.indices()
.into_par_iter()
.map(|chunk_index| {
std::iter::zip(
unravel_index(chunk_index, chunks.shape()),
chunks.start(),
)
.map(|(chunk_indices, chunks_start)| chunk_indices + chunks_start)
.collect::<Vec<_>>()
})
.try_for_each(|chunk_indices| {
self._decode_chunk_into_array_subset(
&chunk_indices,
Expand All @@ -566,7 +551,7 @@ impl<TStorage: ?Sized + ReadableStorageTraits + 'static> Array<TStorage> {
)
})?;
} else {
for chunk_indices in chunks.iter_indices() {
for chunk_indices in &chunks.indices() {
self._decode_chunk_into_array_subset(
&chunk_indices,
array_subset,
Expand Down
35 changes: 13 additions & 22 deletions src/array/array_sync_readable_writable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{

use super::{
codec::{DecodeOptions, EncodeOptions},
unravel_index, Array, ArrayError,
Array, ArrayError,
};

impl<TStorage: ?Sized + ReadableWritableStorageTraits + 'static> Array<TStorage> {
Expand Down Expand Up @@ -122,23 +122,13 @@ impl<TStorage: ?Sized + ReadableWritableStorageTraits + 'static> Array<TStorage>

Ok(())
};
if encode_options.is_parallel() {
(0..chunks.shape().iter().product())
.into_par_iter()
.map(|chunk_index| {
std::iter::zip(unravel_index(chunk_index, chunks.shape()), chunks.start())
.map(|(chunk_indices, chunks_start)| chunk_indices + chunks_start)
.collect::<Vec<_>>()
})
// chunks
// .iter_indices()
// .par_bridge()
.try_for_each(store_chunk)?;
} else {
for chunk_indices in chunks.iter_indices() {
store_chunk(chunk_indices)?;
}
}
let indices = chunks.indices();
rayon_iter_concurrent_limit::iter_concurrent_limit!(
encode_options.concurrent_limit().get(),
indices.into_par_iter(),
try_for_each,
store_chunk
)?;
}
Ok(())
}
Expand Down Expand Up @@ -301,11 +291,12 @@ impl<TStorage: ?Sized + ReadableWritableStorageTraits + 'static> Array<TStorage>
// Update the intersecting subset of the chunk
let element_size = self.data_type().size() as u64;
let mut offset = 0;
for (chunk_element_index, num_elements) in
unsafe { chunk_subset.iter_contiguous_linearised_indices_unchecked(&chunk_shape) }
{
let contiguous_iterator =
unsafe { chunk_subset.contiguous_linearised_indices_unchecked(&chunk_shape) };
let length =
usize::try_from(contiguous_iterator.contiguous_elements() * element_size).unwrap();
for (chunk_element_index, _num_elements) in &contiguous_iterator {
let chunk_offset = usize::try_from(chunk_element_index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
debug_assert!(chunk_offset + length <= chunk_bytes.len());
debug_assert!(offset + length <= chunk_subset_bytes.len());
chunk_bytes[chunk_offset..chunk_offset + length]
Expand Down
Loading

0 comments on commit ca01e66

Please sign in to comment.