Stop using rayon par_bridge() in various places
Testing potential performance improvements by creating parallel iterators directly.
LDeakin committed Oct 27, 2023
1 parent f68d6f1 commit cd89e73
Showing 3 changed files with 100 additions and 52 deletions.
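The change this commit tests is the same in each file: replace `par_bridge()` on a sequential chunk-index iterator with `into_par_iter()` over a linear index range. `par_bridge()` feeds items from one sequential source to the thread pool through a shared, synchronized buffer, so workers contend on the source and item order is lost; an integer `Range` implements `IntoParallelIterator` directly, so rayon can split the index space recursively with no shared queue. A minimal, self-contained sketch of the two approaches (illustrative only, not code from this repository):

```rust
use rayon::prelude::*; // rayon = "1"

fn main() {
    // par_bridge(): bridges a sequential iterator into the pool; items are
    // pulled one at a time through a synchronized buffer.
    let a: u64 = (0..1_000_000u64).par_bridge().map(|i| i % 7).sum();

    // into_par_iter(): an integer range is a true parallel iterator, so the
    // index space is split recursively across threads with no contention.
    let b: u64 = (0..1_000_000u64).into_par_iter().map(|i| i % 7).sum();

    assert_eq!(a, b);
}
```

The trade-off is that the parallel range yields bare integers, so each closure must first map its linear index back to the multidimensional chunk indices the old iterator produced directly, which is what the `unravel_index` calls below do.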
28 changes: 21 additions & 7 deletions src/array.rs
@@ -61,7 +61,7 @@ pub use self::{
};

use parking_lot::Mutex;
use rayon::prelude::{ParallelBridge, ParallelIterator};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use safe_transmute::TriviallyTransmutable;
use serde::Serialize;

@@ -707,9 +707,16 @@ impl<TStorage: ?Sized + ReadableStorageTraits> Array<TStorage> {
let mut output: Vec<u8> = vec![0; size_output];
if self.parallel_chunks {
let output = UnsafeCellSlice::new(output.as_mut_slice());
chunks
.iter_indices()
.par_bridge()
(0..chunks.shape().iter().product())
.into_par_iter()
.map(|chunk_index| {
std::iter::zip(unravel_index(chunk_index, chunks.shape()), chunks.start())
.map(|(chunk_indices, chunks_start)| chunk_indices + chunks_start)
.collect::<Vec<_>>()
})
// chunks
// .iter_indices()
// .par_bridge()
.map(|chunk_indices| decode_chunk(chunk_indices, unsafe { output.get() }))
.collect::<Result<Vec<_>, ArrayError>>()?;
} else {
@@ -1146,9 +1153,16 @@ impl<TStorage: ?Sized + ReadableStorageTraits + WritableStorageTraits> Array<TStorage> {
};

if self.parallel_chunks {
chunks
.iter_indices()
.par_bridge()
(0..chunks.shape().iter().product())
.into_par_iter()
.map(|chunk_index| {
std::iter::zip(unravel_index(chunk_index, chunks.shape()), chunks.start())
.map(|(chunk_indices, chunks_start)| chunk_indices + chunks_start)
.collect::<Vec<_>>()
})
// chunks
// .iter_indices()
// .par_bridge()
.map(store_chunk)
.collect::<Result<Vec<_>, _>>()?;
} else {
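Both hunks above follow the same recipe: the linear `chunk_index` is unravelled into per-dimension chunk indices and then offset by `chunks.start()` to land back in the subset being processed. A sketch of what a C-order (row-major) `unravel_index` computes, assuming that convention; the helper actually imported from the crate may differ in signature:

```rust
/// Hypothetical stand-in for the crate's `unravel_index`: convert a linear
/// index into per-dimension indices, assuming C (row-major) order.
fn unravel_index(mut index: u64, shape: &[u64]) -> Vec<u64> {
    let mut indices = vec![0u64; shape.len()];
    // Walk dimensions from fastest-varying (last) to slowest-varying (first).
    for (out, &dim) in indices.iter_mut().zip(shape).rev() {
        *out = index % dim;
        index /= dim;
    }
    indices
}

fn main() {
    // For shape [2, 3], linear index 5 = 1 * 3 + 2, so it unravels to [1, 2].
    assert_eq!(unravel_index(5, &[2, 3]), vec![1, 2]);
}
```

The sharding hunks below use the same helper but compute each chunk's start as `chunk_indices * chunk_shape` instead of adding a subset offset.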
122 changes: 78 additions & 44 deletions src/array/codec/array_to_bytes/sharding/sharding_codec.rs
@@ -4,7 +4,7 @@ use crate::{
ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayToBytesCodecTraits,
BytesPartialDecoderTraits, Codec, CodecChain, CodecError, CodecPlugin, CodecTraits,
},
ArrayRepresentation, BytesRepresentation, UnsafeCellSlice,
unravel_index, ArrayRepresentation, BytesRepresentation, UnsafeCellSlice,
},
array_subset::ArraySubset,
metadata::Metadata,
@@ -194,13 +194,27 @@ impl ArrayCodecTraits for ShardingCodec {
.map_err(|e| CodecError::Other(e.to_string()))?;

// Iterate over chunk indices
let shard_inner_chunks = unsafe {
ArraySubset::new_with_shape(shard_representation.shape().to_vec())
.iter_chunks_unchecked(&self.chunk_shape)
}
.enumerate()
.par_bridge()
.map(|(chunk_index, (_chunk_indices, chunk_subset))| {
let shard_inner_chunks =
// unsafe {
// ArraySubset::new_with_shape(shard_representation.shape().to_vec())
// .iter_chunks_unchecked(&self.chunk_shape)
// }
// .enumerate()
// .par_bridge()
// .map(|(chunk_index, (_chunk_indices, chunk_subset))| {
(0..chunks_per_shard.iter().product::<u64>().try_into().unwrap())
.into_par_iter()
.map(|chunk_index| {
let chunk_indices = unravel_index(chunk_index as u64, &chunks_per_shard);
let chunk_start = std::iter::zip(&chunk_indices, &self.chunk_shape)
.map(|(i, c)| i * c)
.collect();
let shape = self.chunk_shape.clone();
let chunk_subset =
unsafe { ArraySubset::new_with_start_shape_unchecked(chunk_start, shape) };
(chunk_index, chunk_subset)
})
.map(|(chunk_index, chunk_subset)| {
let bytes = unsafe {
chunk_subset.extract_bytes_unchecked(
&decoded_value,
@@ -405,50 +419,70 @@ impl ShardingCodec {
.repeat(shard_representation.num_elements_usize());
let shard_slice = UnsafeCellSlice::new(shard.as_mut_slice());

// Decode chunks
let chunk_repr = unsafe {
let chunk_representation = unsafe {
ArrayRepresentation::new_unchecked(
self.chunk_shape.clone(),
shard_representation.data_type().clone(),
shard_representation.fill_value().clone(),
)
};
unsafe {
ArraySubset::new_with_shape(shard_representation.shape().to_vec())
.iter_chunks_unchecked(&self.chunk_shape)
}
.enumerate()
.par_bridge()
.map(|(chunk_index, (_chunk_indices, chunk_subset))| {
let shard_slice = unsafe { shard_slice.get() };

// Read the offset/size
let offset = shard_index[chunk_index * 2];
let size = shard_index[chunk_index * 2 + 1];
if offset != u64::MAX || size != u64::MAX {
let offset: usize = offset.try_into().unwrap(); // safe
let size: usize = size.try_into().unwrap(); // safe
let encoded_chunk_slice = encoded_shard[offset..offset + size].to_vec();
// NOTE: Intentionally using single threaded decode, since parallelisation is in the loop
let decoded_chunk = self.inner_codecs.decode(encoded_chunk_slice, &chunk_repr)?;
let chunks_per_shard =
calculate_chunks_per_shard(shard_representation.shape(), chunk_representation.shape())
.map_err(|e| CodecError::Other(e.to_string()))?;

// Copy to subset of shard
let mut data_idx = 0;
let element_size = chunk_repr.element_size() as u64;
for (index, num_elements) in unsafe {
chunk_subset
.iter_contiguous_linearised_indices_unchecked(shard_representation.shape())
} {
let shard_offset = usize::try_from(index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
shard_slice[shard_offset..shard_offset + length]
.copy_from_slice(&decoded_chunk[data_idx..data_idx + length]);
data_idx += length;
// Decode chunks
(0..chunks_per_shard.iter().product::<u64>().try_into().unwrap())
.into_par_iter()
.map(|chunk_index| {
let chunk_indices = unravel_index(chunk_index as u64, &chunks_per_shard);
let chunk_start = std::iter::zip(&chunk_indices, &self.chunk_shape)
.map(|(i, c)| i * c)
.collect();
let shape = self.chunk_shape.clone();
let chunk_subset =
unsafe { ArraySubset::new_with_start_shape_unchecked(chunk_start, shape) };
(chunk_index, chunk_subset)
})
.map(|(chunk_index, chunk_subset)| {
// unsafe {
// ArraySubset::new_with_shape(shard_representation.shape().to_vec())
// .iter_chunks_unchecked(&self.chunk_shape)
// }
// .enumerate()
// .par_bridge()
// .map(|(chunk_index, (_chunk_indices, chunk_subset))| {
let shard_slice = unsafe { shard_slice.get() };

// Read the offset/size
let offset = shard_index[chunk_index * 2];
let size = shard_index[chunk_index * 2 + 1];
if offset != u64::MAX || size != u64::MAX {
let offset: usize = offset.try_into().unwrap(); // safe
let size: usize = size.try_into().unwrap(); // safe
let encoded_chunk_slice = encoded_shard[offset..offset + size].to_vec();
// NOTE: Intentionally using single threaded decode, since parallelisation is in the loop
let decoded_chunk = self
.inner_codecs
.decode(encoded_chunk_slice, &chunk_representation)?;

// Copy to subset of shard
let mut data_idx = 0;
let element_size = chunk_representation.element_size() as u64;
for (index, num_elements) in unsafe {
chunk_subset.iter_contiguous_linearised_indices_unchecked(
shard_representation.shape(),
)
} {
let shard_offset = usize::try_from(index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
shard_slice[shard_offset..shard_offset + length]
.copy_from_slice(&decoded_chunk[data_idx..data_idx + length]);
data_idx += length;
}
}
}
Ok::<_, CodecError>(())
})
.collect::<Result<Vec<_>, CodecError>>()?;
Ok::<_, CodecError>(())
})
.collect::<Result<Vec<_>, CodecError>>()?;

Ok(shard)
}
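The decode loop above treats the shard index as a flat `Vec<u64>` of interleaved (offset, size) pairs, one pair per inner chunk, with `u64::MAX` in both slots marking a chunk that was never stored, so the fill value already written into `shard` stays in place. A small sketch of just that lookup, using a hypothetical helper name:

```rust
/// Hypothetical helper mirroring the offset/size lookup in the decode loop.
/// Returns None for an empty (never-written) chunk, signalled by u64::MAX
/// sentinels in both slots of its index entry.
fn chunk_byte_range(shard_index: &[u64], chunk_index: usize) -> Option<(usize, usize)> {
    let offset = shard_index[chunk_index * 2];
    let size = shard_index[chunk_index * 2 + 1];
    if offset == u64::MAX && size == u64::MAX {
        None // empty chunk: leave the fill value untouched
    } else {
        Some((offset.try_into().unwrap(), size.try_into().unwrap()))
    }
}
```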
@@ -254,7 +254,7 @@ impl ArrayPartialDecoderTraits for ShardingPartialDecoder<'_> {
let mut out_array_subset = vec![0; array_subset_size];
let out_array_subset_slice = UnsafeCellSlice::new(out_array_subset.as_mut_slice());

// Decode those chunks if required and put in chunk cache
// Decode those chunks if required
unsafe { array_subset.iter_chunks_unchecked(chunk_representation.shape()) }
.par_bridge()
.map(|(chunk_indices, chunk_subset)| {
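In this last hunk only the comment changes; the `par_bridge()` call itself is left in place. One pattern shared by every parallel loop in the commit is the error handling: the closures return `Result`, and rayon collects a parallel iterator of `Result<T, E>` into a single `Result<Vec<T>, E>`, stopping further work as soon as possible after an error. A standalone sketch of that behaviour (the error value is made up):

```rust
use rayon::prelude::*;

fn main() {
    // Collecting parallel Results into one Result: once an Err is observed,
    // processing winds down and that error is returned.
    let r: Result<Vec<u64>, String> = (0u64..1_000)
        .into_par_iter()
        .map(|i| if i == 500 { Err("decode failed".to_owned()) } else { Ok(i * 2) })
        .collect();
    assert_eq!(r, Err("decode failed".to_owned()));
}
```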
