-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add experimental bitround codec
- Loading branch information
Showing
9 changed files
with
441 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
//! Array to array codecs. | ||
|
||
#[cfg(feature = "bitround")] | ||
pub mod bitround; | ||
#[cfg(feature = "transpose")] | ||
pub mod transpose; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
//! The bitround array to array codec. | ||
//! | ||
//! Rounds the mantissa of floating point data types to the specified number of bits. | ||
//! | ||
//! This codec requires the `bitround` feature, which is disabled by default. | ||
//! | ||
//! The current implementation does not write its metadata to the array metadata, so the array can be imported by tools which do not presently support this codec. | ||
//! This functionality will be changed when the bitround codec is in the zarr specification and supported by multiple implementations. | ||
//! | ||
//! See [`BitroundCodecConfigurationV1`] for example `JSON` metadata. | ||
//! | ||
|
||
mod bitround_codec; | ||
mod bitround_configuration; | ||
mod bitround_partial_decoder; | ||
|
||
pub use bitround_codec::BitroundCodec; | ||
pub use bitround_configuration::{BitroundCodecConfiguration, BitroundCodecConfigurationV1}; | ||
|
||
use crate::{ | ||
array::{ | ||
codec::{Codec, CodecError, CodecPlugin}, | ||
DataType, | ||
}, | ||
metadata::Metadata, | ||
plugin::PluginCreateError, | ||
}; | ||
|
||
const IDENTIFIER: &str = "bitround"; | ||
|
||
// Register the codec. | ||
inventory::submit! { | ||
CodecPlugin::new(IDENTIFIER, is_name_bitround, create_codec_bitround) | ||
} | ||
|
||
fn is_name_bitround(name: &str) -> bool { | ||
name.eq(IDENTIFIER) | ||
} | ||
|
||
fn create_codec_bitround(metadata: &Metadata) -> Result<Codec, PluginCreateError> { | ||
let configuration: BitroundCodecConfiguration = metadata.to_configuration()?; | ||
let codec = Box::new(BitroundCodec::new_with_configuration(&configuration)); | ||
Ok(Codec::ArrayToArray(codec)) | ||
} | ||
|
||
fn round_bits16(mut input: u16, keepbits: u32) -> u16 { | ||
let maxbits = 10; | ||
if keepbits >= maxbits { | ||
input | ||
} else { | ||
let maskbits = maxbits - keepbits; | ||
let all_set = u16::MAX; | ||
let mask = (all_set >> maskbits) << maskbits; | ||
let half_quantum1 = (1 << (maskbits - 1)) - 1; | ||
input += ((input >> maskbits) & 1) + half_quantum1; | ||
input &= mask; | ||
input | ||
} | ||
} | ||
|
||
fn round_bits32(mut input: u32, keepbits: u32) -> u32 { | ||
let maxbits = 23; | ||
if keepbits >= maxbits { | ||
input | ||
} else { | ||
let maskbits = maxbits - keepbits; | ||
let all_set = u32::MAX; | ||
let mask = (all_set >> maskbits) << maskbits; | ||
let half_quantum1 = (1 << (maskbits - 1)) - 1; | ||
input += ((input >> maskbits) & 1) + half_quantum1; | ||
input &= mask; | ||
input | ||
} | ||
} | ||
|
||
fn round_bits64(mut input: u64, keepbits: u32) -> u64 { | ||
let maxbits = 52; | ||
if keepbits >= maxbits { | ||
input | ||
} else { | ||
let maskbits = maxbits - keepbits; | ||
let all_set = u64::MAX; | ||
let mask = (all_set >> maskbits) << maskbits; | ||
let half_quantum1 = (1 << (maskbits - 1)) - 1; | ||
input += ((input >> maskbits) & 1) + half_quantum1; | ||
input &= mask; | ||
input | ||
} | ||
} | ||
|
||
fn round_bytes(bytes: &mut [u8], data_type: &DataType, keepbits: u32) -> Result<(), CodecError> { | ||
match data_type { | ||
DataType::Float16 | DataType::BFloat16 => { | ||
let round = |chunk: &mut [u8]| { | ||
let element = u16::from_ne_bytes(chunk.try_into().unwrap()); | ||
let element = u16::to_ne_bytes(round_bits16(element, keepbits)); | ||
chunk.copy_from_slice(&element); | ||
}; | ||
bytes.chunks_exact_mut(2).for_each(round); | ||
Ok(()) | ||
} | ||
DataType::Float32 | DataType::Complex64 => { | ||
let round = |chunk: &mut [u8]| { | ||
let element = u32::from_ne_bytes(chunk.try_into().unwrap()); | ||
let element = u32::to_ne_bytes(round_bits32(element, keepbits)); | ||
chunk.copy_from_slice(&element); | ||
}; | ||
bytes.chunks_exact_mut(4).for_each(round); | ||
Ok(()) | ||
} | ||
DataType::Float64 | DataType::Complex128 => { | ||
let round = |chunk: &mut [u8]| { | ||
let element = u64::from_ne_bytes(chunk.try_into().unwrap()); | ||
let element = u64::to_ne_bytes(round_bits64(element, keepbits)); | ||
chunk.copy_from_slice(&element); | ||
}; | ||
bytes.chunks_exact_mut(8).for_each(round); | ||
Ok(()) | ||
} | ||
_ => Err(CodecError::UnsupportedDataType( | ||
data_type.clone(), | ||
IDENTIFIER.to_string(), | ||
)), | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use array_representation::ArrayRepresentation; | ||
use itertools::Itertools; | ||
|
||
use crate::{ | ||
array::{ | ||
array_representation, | ||
codec::{ | ||
ArrayCodecTraits, ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, BytesCodec, | ||
}, | ||
DataType, | ||
}, | ||
array_subset::ArraySubset, | ||
}; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
fn codec_bitround_float() { | ||
// 1 sign bit, 8 exponent, 3 mantissa | ||
const JSON: &'static str = r#"{ "keepbits": 3 }"#; | ||
let array_representation = | ||
ArrayRepresentation::new(vec![4], DataType::Float32, 0.0f32.into()).unwrap(); | ||
let elements: Vec<f32> = vec![ | ||
// | | ||
0.0, | ||
// 1.23456789 -> 001111111001|11100000011001010010 | ||
// 1.25 -> 001111111010 | ||
1.23456789, | ||
// -8.3587192 -> 110000010000|01011011110101010000 | ||
// -8.0 -> 110000010000 | ||
-8.3587192834, | ||
// 98765.43210-> 010001111100|00001110011010110111 | ||
// 98304.0 -> 010001111100 | ||
98765.43210, | ||
]; | ||
let bytes = safe_transmute::transmute_to_bytes(&elements).to_vec(); | ||
|
||
let codec_configuration: BitroundCodecConfiguration = serde_json::from_str(JSON).unwrap(); | ||
let codec = BitroundCodec::new_with_configuration(&codec_configuration); | ||
|
||
let encoded = codec.encode(bytes.clone(), &array_representation).unwrap(); | ||
let decoded = codec | ||
.decode(encoded.clone(), &array_representation) | ||
.unwrap(); | ||
let decoded_elements = safe_transmute::transmute_many_permissive::<f32>(&decoded) | ||
.unwrap() | ||
.to_vec(); | ||
assert_eq!(decoded_elements, &[0.0f32, 1.25f32, -8.0f32, 98304.0f32]); | ||
} | ||
|
||
#[test] | ||
fn codec_bitround_partial_decode() { | ||
const JSON: &'static str = r#"{ "keepbits": 2 }"#; | ||
let codec_configuration: BitroundCodecConfiguration = serde_json::from_str(JSON).unwrap(); | ||
let codec = BitroundCodec::new_with_configuration(&codec_configuration); | ||
|
||
let elements: Vec<f32> = (0..32).map(|i| i as f32).collect(); | ||
let bytes = safe_transmute::transmute_to_bytes(&elements).to_vec(); | ||
let array_representation = ArrayRepresentation::new( | ||
vec![elements.len().try_into().unwrap()], | ||
DataType::Float32, | ||
0.0f32.into(), | ||
) | ||
.unwrap(); | ||
|
||
let encoded = codec.encode(bytes.clone(), &array_representation).unwrap(); | ||
let decoded_regions = [ | ||
ArraySubset::new_with_start_shape(vec![3], vec![2]).unwrap(), | ||
ArraySubset::new_with_start_shape(vec![17], vec![4]).unwrap(), | ||
]; | ||
let input_handle = Box::new(std::io::Cursor::new(encoded)); | ||
let bytes_codec = BytesCodec::default(); | ||
let input_handle = bytes_codec.partial_decoder(input_handle); | ||
let partial_decoder = codec.partial_decoder(input_handle); | ||
let decoded_partial_chunk = partial_decoder | ||
.partial_decode(&array_representation, &decoded_regions) | ||
.unwrap(); | ||
let decoded_partial_chunk = decoded_partial_chunk | ||
.iter() | ||
.map(|bytes| { | ||
safe_transmute::transmute_many_permissive::<f32>(&bytes) | ||
.unwrap() | ||
.to_vec() | ||
}) | ||
.collect_vec(); | ||
let answer: &[Vec<f32>] = &[vec![3.0, 4.0], vec![16.0, 16.0, 20.0, 20.0]]; | ||
assert_eq!(answer, decoded_partial_chunk); | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
src/array/codec/array_to_array/bitround/bitround_codec.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
use crate::{ | ||
array::{ | ||
codec::{ | ||
ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayToArrayCodecTraits, CodecError, | ||
CodecTraits, | ||
}, | ||
ArrayRepresentation, DataType, | ||
}, | ||
metadata::Metadata, | ||
}; | ||
|
||
use super::{bitround_partial_decoder, round_bytes, BitroundCodecConfiguration, IDENTIFIER}; | ||
|
||
/// A `bitround` codec implementation. | ||
#[derive(Clone, Debug, Default)] | ||
pub struct BitroundCodec { | ||
keepbits: u32, | ||
} | ||
|
||
impl BitroundCodec { | ||
/// Create a new bitround codec. | ||
/// | ||
/// `keepbits` is the number of bits to round to in the floating point mantissa. | ||
#[must_use] | ||
pub fn new(keepbits: u32) -> Self { | ||
Self { keepbits } | ||
} | ||
|
||
/// Create a new bitround codec from a configuration. | ||
#[must_use] | ||
pub fn new_with_configuration(configuration: &BitroundCodecConfiguration) -> Self { | ||
let BitroundCodecConfiguration::V1(configuration) = configuration; | ||
Self { | ||
keepbits: configuration.keepbits, | ||
} | ||
} | ||
} | ||
|
||
impl CodecTraits for BitroundCodec { | ||
fn create_metadata(&self) -> Option<Metadata> { | ||
// FIXME: Output the metadata when the bitround codec is in the zarr specification and supported by multiple implementations. | ||
// let configuration = BitroundCodecConfigurationV1 { | ||
// keepbits: self.keepbits, | ||
// }; | ||
// Some(Metadata::new_with_serializable_configuration(IDENTIFIER, &configuration).unwrap()) | ||
None | ||
} | ||
|
||
fn partial_decoder_should_cache_input(&self) -> bool { | ||
false | ||
} | ||
|
||
fn partial_decoder_decodes_all(&self) -> bool { | ||
false | ||
} | ||
} | ||
|
||
impl ArrayCodecTraits for BitroundCodec { | ||
fn encode( | ||
&self, | ||
mut decoded_value: Vec<u8>, | ||
decoded_representation: &ArrayRepresentation, | ||
) -> Result<Vec<u8>, CodecError> { | ||
round_bytes( | ||
&mut decoded_value, | ||
decoded_representation.data_type(), | ||
self.keepbits, | ||
)?; | ||
Ok(decoded_value) | ||
} | ||
|
||
fn decode( | ||
&self, | ||
encoded_value: Vec<u8>, | ||
_decoded_representation: &ArrayRepresentation, | ||
) -> Result<Vec<u8>, CodecError> { | ||
Ok(encoded_value) | ||
} | ||
} | ||
|
||
impl ArrayToArrayCodecTraits for BitroundCodec { | ||
fn partial_decoder<'a>( | ||
&'a self, | ||
input_handle: Box<dyn ArrayPartialDecoderTraits + 'a>, | ||
) -> Box<dyn ArrayPartialDecoderTraits + 'a> { | ||
Box::new(bitround_partial_decoder::BitroundPartialDecoder::new( | ||
input_handle, | ||
self.keepbits, | ||
)) | ||
} | ||
|
||
fn compute_encoded_size( | ||
&self, | ||
decoded_representation: &ArrayRepresentation, | ||
) -> Result<ArrayRepresentation, CodecError> { | ||
let data_type = decoded_representation.data_type(); | ||
match data_type { | ||
DataType::Float16 | DataType::BFloat16 | DataType::Float32 | DataType::Float64 => { | ||
Ok(decoded_representation.clone()) | ||
} | ||
_ => Err(CodecError::UnsupportedDataType( | ||
data_type.clone(), | ||
IDENTIFIER.to_string(), | ||
)), | ||
} | ||
} | ||
} |
Oops, something went wrong.