Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OER: improve decoding field presence tracking #375

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion macros/macros_impl/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn derive_struct_impl(
.constraints
.const_expr(crate_root)
.unwrap_or_else(|| quote!(#crate_root::types::Constraints::default()));
let constraint_name = format_ident!("DELEGATE_DECODE_CONSTRAINT");
let constraint_name = format_ident!("delegate_constraint");
let constraint_def = if generics.params.is_empty() {
quote! {
let #constraint_name: #crate_root::types::Constraints = const {<#ty as #crate_root::AsnType>::CONSTRAINTS.intersect(#constraints)}.intersect(constraints);
Expand Down
121 changes: 74 additions & 47 deletions src/oer/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
// the encoding itself without knowledge of the type being encoded ITU-T X.696 (6.2).

use alloc::{
collections::VecDeque,
string::{String, ToString},
vec::Vec,
};
Expand Down Expand Up @@ -73,9 +72,9 @@ impl DecoderOptions {
pub struct Decoder<'input, const RFC: usize = 0, const EFC: usize = 0> {
input: InputSlice<'input>,
options: DecoderOptions,
fields: VecDeque<(Field, bool)>,
fields: ([Option<Field>; RFC], usize),
extension_fields: Option<Fields<EFC>>,
extensions_present: Option<Option<VecDeque<(Field, bool)>>>,
extensions_present: Option<Option<([Option<Field>; EFC], usize)>>,
}

impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {
Expand All @@ -85,7 +84,7 @@ impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {
Self {
input: input.into(),
options,
fields: <_>::default(),
fields: ([None; RFC], 0),
extension_fields: <_>::default(),
extensions_present: <_>::default(),
}
Expand Down Expand Up @@ -388,30 +387,39 @@ impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {

#[track_caller]
fn require_field(&mut self, tag: Tag) -> Result<bool, DecodeError> {
if self
.fields
.front()
.is_some_and(|field| field.0.tag_tree.smallest_tag() == tag)
{
Ok(self.fields.pop_front().unwrap().1)
} else {
Err(DecodeError::missing_tag_class_or_value_in_sequence_or_set(
let (fields, index) = &mut self.fields;
let Some(field) = fields.get(*index) else {
return Err(DecodeError::missing_tag_class_or_value_in_sequence_or_set(
tag.class,
tag.value,
self.codec(),
))
));
};

*index += 1;
match field {
Some(field) if field.tag_tree.smallest_tag() == tag => Ok(true),
None => Ok(false),
_ => Err(DecodeError::missing_tag_class_or_value_in_sequence_or_set(
tag.class,
tag.value,
self.codec(),
)),
}
}

fn extension_is_present(&mut self) -> Result<Option<(Field, bool)>, DecodeError> {
fn extension_is_present(&mut self) -> Result<Option<&Field>, DecodeError> {
let codec = self.codec();
Ok(self
.extensions_present
.as_mut()
.ok_or_else(|| DecodeError::type_not_extensible(codec))?
.as_mut()
.ok_or_else(|| DecodeError::type_not_extensible(codec))?
.pop_front())
let Some(Some((fields, index))) = self.extensions_present.as_mut() else {
return Err(DecodeError::type_not_extensible(codec));
};

let field = fields
.get(*index)
.ok_or_else(|| DecodeError::type_not_extensible(codec))?;

*index += 1;
Ok(field.as_ref())
}

fn parse_extension_header(&mut self) -> Result<bool, DecodeError> {
Expand All @@ -427,9 +435,11 @@ impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {
"Extension length should be at least 1 byte".to_string(),
));
}
let extension_fields = self
.extension_fields
.ok_or_else(|| DecodeError::type_not_extensible(self.codec()))?;
// Must be at least 8 bits at this point or error is already raised
let bitfield = self.extract_data_by_length(extensions_length)?.to_bitvec();
// let mut missing_bits: bitvec::vec::BitVec<u8, bitvec::order::Msb0>;
let bitfield = self.extract_data_by_length(extensions_length)?;
// Initial octet
let (unused_bits, bitfield) = bitfield.split_at(8);
let unused_bits: usize = unused_bits.load();
Expand All @@ -439,26 +449,23 @@ impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {
"Invalid extension bitfield initial octet".to_string(),
));
}
let (bitfield, _) = bitfield.split_at(bitfield.len() - unused_bits);
let needed = bitfield.len() - unused_bits;
let bitfield = &bitfield[..needed];

let extensions_present: VecDeque<_> = self
.extension_fields
.as_ref()
.unwrap()
.iter()
.zip(bitfield.iter().map(|b| *b))
.collect();

for (field, is_present) in &extensions_present {
if field.is_not_optional_or_default() && !*is_present {
let mut fields: [Option<Field>; EFC] = [None; EFC];
for (i, field) in extension_fields.iter().enumerate() {
if field.is_not_optional_or_default() && !bitfield[i] {
return Err(DecodeError::required_extension_not_present(
field.tag,
self.codec(),
));
} else if bitfield[i] {
fields[i] = Some(field);
} else {
fields[i] = None;
}
}
self.extensions_present = Some(Some(extensions_present));

self.extensions_present = Some(Some((fields, 0)));
Ok(true)
}

Expand All @@ -481,7 +488,7 @@ impl<'input, const RFC: usize, const EFC: usize> Decoder<'input, RFC, EFC> {
} else {
bitmap.len()
};
self.drop_preamble_bits((8 - preamble_length % 8) % 8)?;
self.drop_preamble_bits((8 - (preamble_length & 7)) & 7)?;

debug_assert_eq!(self.input.len() % 8, 0);
Ok((bitmap, extensible_present))
Expand Down Expand Up @@ -608,10 +615,18 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp
// ### PREAMBLE ###
let (bitmap, extensible_present) = self.parse_preamble::<RC, EC, D>()?;
// ### ENDS
let fields = D::FIELDS
let mut fields = ([None; RC], 0);
for (i, (field, is_present)) in D::FIELDS
.optional_and_default_fields()
.zip(bitmap.into_iter().map(|b| *b))
.collect();
.enumerate()
{
if is_present {
fields.0[i] = Some(field);
} else {
fields.0[i] = None;
}
}

let value = {
let mut sequence_decoder = Decoder::new(self.input.0, self.options);
Expand Down Expand Up @@ -841,26 +856,38 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp
{
let (bitmap, extensible_present) = self.parse_preamble::<RC, EC, SET>()?;

let field_map = SET::FIELDS
let mut field_map: ([Option<Field>; RC], usize) = ([None; RC], 0);
for (i, (field, is_present)) in SET::FIELDS
.canonised()
.optional_and_default_fields()
.zip(bitmap.into_iter().map(|b| *b))
.collect::<alloc::collections::BTreeMap<_, _>>();
.enumerate()
{
if is_present {
field_map.0[i] = Some(field);
} else {
field_map.0[i] = None;
}
}

let fields = {
let extended_fields_len = SET::EXTENDED_FIELDS.map_or(0, |fields| fields.len());
let mut fields = Vec::with_capacity(SET::FIELDS.len() + extended_fields_len);
let mut set_decoder = Decoder::new(self.input.0, self.options);
set_decoder.extension_fields = SET::EXTENDED_FIELDS;
set_decoder.extensions_present = extensible_present.then_some(None);
set_decoder.fields = field_map.clone().into_iter().collect();
set_decoder.fields = field_map;

let mut opt_index = 0;
for field in SET::FIELDS.canonised().iter() {
match field_map.get(&field).copied() {
Some(true) | None => {
if field.is_optional_or_default() {
// Safe unwrap, we just created the field_map
if field_map.0.get(opt_index).unwrap().is_some() {
fields.push(decode_fn(&mut set_decoder, field.index, field.tag)?);
}
Some(false) => {}
opt_index += 1;
} else {
fields.push(decode_fn(&mut set_decoder, field.index, field.tag)?);
}
}
for (indice, field) in SET::EXTENDED_FIELDS
Expand Down Expand Up @@ -967,7 +994,7 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp
return Ok(None);
}

let extension_is_present = self.extension_is_present()?.is_some_and(|(_, b)| b);
let extension_is_present = self.extension_is_present()?.is_some();

if !extension_is_present {
return Ok(None);
Expand All @@ -992,7 +1019,7 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp
return Ok(None);
}

let extension_is_present = self.extension_is_present()?.is_some_and(|(_, b)| b);
let extension_is_present = self.extension_is_present()?.is_some();

if !extension_is_present {
return Ok(None);
Expand Down
Loading