diff --git a/src/descriptor/bare.rs b/src/descriptor/bare.rs index 176137ad8..63cb8168a 100644 --- a/src/descriptor/bare.rs +++ b/src/descriptor/bare.rs @@ -12,7 +12,6 @@ use core::fmt; use bitcoin::script::{self, PushBytes}; use bitcoin::{Address, Network, ScriptBuf, Weight}; -use super::checksum::verify_checksum; use crate::descriptor::{write_descriptor, DefiniteDescriptorKey}; use crate::expression::{self, FromTree}; use crate::miniscript::context::{ScriptContext, ScriptContextError}; @@ -186,8 +185,7 @@ impl FromTree for Bare { impl core::str::FromStr for Bare { type Err = Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; Self::from_tree(&top) } } @@ -387,8 +385,7 @@ impl FromTree for Pkh { impl core::str::FromStr for Pkh { type Err = Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; Self::from_tree(&top) } } diff --git a/src/descriptor/checksum.rs b/src/descriptor/checksum.rs index 6a79194c2..6ae296433 100644 --- a/src/descriptor/checksum.rs +++ b/src/descriptor/checksum.rs @@ -14,46 +14,117 @@ use core::iter::FromIterator; use bech32::primitives::checksum::PackedFe32; use bech32::{Checksum, Fe32}; -pub use crate::expression::VALID_CHARS; use crate::prelude::*; -use crate::Error; const CHECKSUM_LENGTH: usize = 8; const CODE_LENGTH: usize = 32767; -/// Compute the checksum of a descriptor. +/// Map of valid characters in descriptor strings. /// -/// Note that this function does not check if the descriptor string is -/// syntactically correct or not. This only computes the checksum. -pub fn desc_checksum(desc: &str) -> Result { - let mut eng = Engine::new(); - eng.input(desc)?; - Ok(eng.checksum()) +/// The map starts at 32 (space) and runs up to 126 (tilde). 
+#[rustfmt::skip] +const CHAR_MAP: [u8; 95] = [ + 94, 59, 92, 91, 28, 29, 50, 15, 10, 11, 17, 51, 14, 52, 53, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 27, 54, 55, 56, 57, 58, + 26, 82, 83, 84, 85, 86, 87, 88, 89, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 12, 93, 13, 60, 61, + 90, 18, 19, 20, 21, 22, 23, 24, 25, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 30, 62, 31, 63, +]; + +/// Error validating descriptor checksum. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Error { + /// Character outside of descriptor charset. + InvalidCharacter { + /// The character in question. + ch: char, + /// Its position in the string. + pos: usize, + }, + /// Checksum had the incorrect length. + InvalidChecksumLength { + /// The length of the checksum in the string. + actual: usize, + /// The length of a valid descriptor checksum. + expected: usize, + }, + /// Checksum was invalid. + InvalidChecksum { + /// The checksum in the string. + actual: [char; CHECKSUM_LENGTH], + /// The checksum that should have been there, assuming the string is valid. + expected: [char; CHECKSUM_LENGTH], + }, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::InvalidCharacter { ch, pos } => { + write!(f, "invalid character '{}' (position {})", ch, pos) + } + Error::InvalidChecksumLength { actual, expected } => { + write!(f, "invalid checksum (length {}, expected {})", actual, expected) + } + Error::InvalidChecksum { actual, expected } => { + f.write_str("invalid checksum ")?; + for ch in actual { + ch.fmt(f)?; + } + f.write_str("; expected ")?; + for ch in expected { + ch.fmt(f)?; + } + Ok(()) + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn cause(&self) -> Option<&dyn std::error::Error> { None } } /// Helper function for `FromStr` for various descriptor types. 
/// /// Checks and verifies the checksum if it is present and returns the descriptor /// string without the checksum. -pub(super) fn verify_checksum(s: &str) -> Result<&str, Error> { - for ch in s.as_bytes() { - if *ch < 20 || *ch > 127 { - return Err(Error::Unprintable(*ch)); +pub fn verify_checksum(s: &str) -> Result<&str, Error> { + let mut last_hash_pos = s.len(); + for (pos, ch) in s.char_indices() { + if !(32..127).contains(&u32::from(ch)) { + return Err(Error::InvalidCharacter { ch, pos }); + } else if ch == '#' { + last_hash_pos = pos; } } + // After this point we know we have ASCII and can stop using character methods. + + if last_hash_pos < s.len() { + let checksum_str = &s[last_hash_pos + 1..]; + if checksum_str.len() != CHECKSUM_LENGTH { + return Err(Error::InvalidChecksumLength { + actual: checksum_str.len(), + expected: CHECKSUM_LENGTH, + }); + } + + let mut eng = Engine::new(); + eng.input_unchecked(s[..last_hash_pos].as_bytes()); - let mut parts = s.splitn(2, '#'); - let desc_str = parts.next().unwrap(); - if let Some(checksum_str) = parts.next() { - let expected_sum = desc_checksum(desc_str)?; - if checksum_str != expected_sum { - return Err(Error::BadDescriptor(format!( - "Invalid checksum '{}', expected '{}'", - checksum_str, expected_sum - ))); + let expected = eng.checksum_chars(); + let mut actual = ['_'; CHECKSUM_LENGTH]; + for (act, ch) in actual.iter_mut().zip(checksum_str.chars()) { + *act = ch; + } + + if expected != actual { + return Err(Error::InvalidChecksum { actual, expected }); } } - Ok(desc_str) + Ok(&s[..last_hash_pos]) } /// An engine to compute a checksum from a string. @@ -78,16 +149,18 @@ impl Engine { /// If this function returns an error, the `Engine` will be left in an indeterminate /// state! It is safe to continue feeding it data but the result will not be meaningful. 
pub fn input(&mut self, s: &str) -> Result<(), Error> { - for ch in s.chars() { - let pos = VALID_CHARS - .get(ch as usize) - .ok_or_else(|| { - Error::BadDescriptor(format!("Invalid character in checksum: '{}'", ch)) - })? - .ok_or_else(|| { - Error::BadDescriptor(format!("Invalid character in checksum: '{}'", ch)) - })? as u64; + for (pos, ch) in s.char_indices() { + if !(32..127).contains(&u32::from(ch)) { + return Err(Error::InvalidCharacter { ch, pos }); + } + } + self.input_unchecked(s.as_bytes()); + Ok(()) + } + fn input_unchecked(&mut self, s: &[u8]) { + for ch in s { + let pos = u64::from(CHAR_MAP[usize::from(*ch) - 32]); let fe = Fe32::try_from(pos & 31).expect("pos is valid because of the mask"); self.inner.input_fe(fe); @@ -100,7 +173,6 @@ impl Engine { self.clscount = 0; } } - Ok(()) } /// Obtains the checksum characters of all the data thus-far fed to the @@ -192,7 +264,9 @@ mod test { macro_rules! check_expected { ($desc: expr, $checksum: expr) => { - assert_eq!(desc_checksum($desc).unwrap(), $checksum); + let mut eng = Engine::new(); + eng.input_unchecked($desc.as_bytes()); + assert_eq!(eng.checksum(), $checksum); }; } @@ -229,8 +303,8 @@ mod test { let invalid_desc = format!("wpkh(tprv8ZgxMBicQKsPdpkqS7Eair4YxjcuuvDPNYmKX3sCniCf16tHEVrjjiSXEkFRnUH77yXc6ZcwHHcL{}fjdi5qUvw3VDfgYiH5mNsj5izuiu2N/1/2/*)", sparkle_heart); assert_eq!( - desc_checksum(&invalid_desc).err().unwrap().to_string(), - format!("Invalid descriptor: Invalid character in checksum: '{}'", sparkle_heart) + verify_checksum(&invalid_desc).err().unwrap().to_string(), + format!("invalid character '{}' (position 85)", sparkle_heart) ); } diff --git a/src/descriptor/mod.rs b/src/descriptor/mod.rs index b172aa69e..6c973cc5d 100644 --- a/src/descriptor/mod.rs +++ b/src/descriptor/mod.rs @@ -21,7 +21,6 @@ use bitcoin::{ }; use sync::Arc; -use self::checksum::verify_checksum; use crate::miniscript::decode::Terminal; use crate::miniscript::{satisfy, Legacy, Miniscript, Segwitv0}; use 
crate::plan::{AssetProvider, Plan}; @@ -988,8 +987,7 @@ impl FromStr for Descriptor { let desc = if s.starts_with("tr(") { Ok(Descriptor::Tr(Tr::from_str(s)?)) } else { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; expression::FromTree::from_tree(&top) }?; @@ -1053,8 +1051,7 @@ mod tests { use bitcoin::sighash::EcdsaSighashType; use bitcoin::{bip32, PublicKey, Sequence}; - use super::checksum::desc_checksum; - use super::*; + use super::{checksum, *}; use crate::hex_script; #[cfg(feature = "compiler")] use crate::policy; @@ -1066,10 +1063,10 @@ mod tests { let desc = Descriptor::::from_str(s).unwrap(); let output = desc.to_string(); let normalize_aliases = s.replace("c:pk_k(", "pk(").replace("c:pk_h(", "pkh("); - assert_eq!( - format!("{}#{}", &normalize_aliases, desc_checksum(&normalize_aliases).unwrap()), - output - ); + + let mut checksum_eng = checksum::Engine::new(); + checksum_eng.input(&normalize_aliases).unwrap(); + assert_eq!(format!("{}#{}", &normalize_aliases, checksum_eng.checksum()), output); } #[test] @@ -1841,7 +1838,7 @@ mod tests { ($secp: ident,$($desc: expr),*) => { $( match Descriptor::parse_descriptor($secp, $desc) { - Err(Error::BadDescriptor(_)) => {}, + Err(Error::ParseTree(crate::ParseTreeError::Checksum(_))) => {}, Err(e) => panic!("Expected bad checksum for {}, got '{}'", $desc, e), _ => panic!("Invalid checksum treated as valid: {}", $desc), }; diff --git a/src/descriptor/segwitv0.rs b/src/descriptor/segwitv0.rs index 2f2532b85..c8552eb4b 100644 --- a/src/descriptor/segwitv0.rs +++ b/src/descriptor/segwitv0.rs @@ -10,7 +10,6 @@ use core::fmt; use bitcoin::{Address, Network, ScriptBuf, Weight}; -use super::checksum::verify_checksum; use super::SortedMultiVec; use crate::descriptor::{write_descriptor, DefiniteDescriptorKey}; use crate::expression::{self, FromTree}; @@ -288,8 +287,7 @@ impl fmt::Display for Wsh { impl core::str::FromStr for Wsh { type Err 
= Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; Wsh::::from_tree(&top) } } @@ -505,8 +503,7 @@ impl crate::expression::FromTree for Wpkh { impl core::str::FromStr for Wpkh { type Err = Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; Self::from_tree(&top) } } diff --git a/src/descriptor/sh.rs b/src/descriptor/sh.rs index cf05c1b71..f9f21544d 100644 --- a/src/descriptor/sh.rs +++ b/src/descriptor/sh.rs @@ -13,7 +13,6 @@ use core::fmt; use bitcoin::script::PushBytes; use bitcoin::{script, Address, Network, ScriptBuf, Weight}; -use super::checksum::verify_checksum; use super::{SortedMultiVec, Wpkh, Wsh}; use crate::descriptor::{write_descriptor, DefiniteDescriptorKey}; use crate::expression::{self, FromTree}; @@ -109,8 +108,7 @@ impl crate::expression::FromTree for Sh { impl core::str::FromStr for Sh { type Err = Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; - let top = expression::Tree::from_str(desc_str)?; + let top = expression::Tree::from_str(s)?; Self::from_tree(&top) } } diff --git a/src/descriptor/tr.rs b/src/descriptor/tr.rs index 30d6c5c74..77cceb7e0 100644 --- a/src/descriptor/tr.rs +++ b/src/descriptor/tr.rs @@ -557,7 +557,9 @@ impl crate::expression::FromTree for Tr { impl core::str::FromStr for Tr { type Err = Error; fn from_str(s: &str) -> Result { - let desc_str = verify_checksum(s)?; + let desc_str = verify_checksum(s) + .map_err(From::from) + .map_err(Error::ParseTree)?; let top = parse_tr_tree(desc_str)?; Self::from_tree(&top) } @@ -587,8 +589,6 @@ impl fmt::Display for Tr { // Helper function to parse string into miniscript tree form fn parse_tr_tree(s: &str) -> Result { - expression::check_valid_chars(s)?; - if s.len() > 3 && &s[..3] == "tr(" && s.as_bytes()[s.len() - 1] 
== b')' { let rest = &s[3..s.len() - 1]; if !rest.contains(',') { diff --git a/src/expression/error.rs b/src/expression/error.rs index f82358647..d705e2df5 100644 --- a/src/expression/error.rs +++ b/src/expression/error.rs @@ -4,9 +4,116 @@ use core::fmt; +use crate::descriptor::checksum; use crate::prelude::*; use crate::ThresholdError; +/// An error parsing an expression tree. +#[derive(Debug, PartialEq, Eq)] +pub enum ParseTreeError { + /// Error validating the checksum or character set. + Checksum(checksum::Error), + /// Expression tree had depth exceeding our hard cap. + MaxRecursionDepthExceeded { + /// The depth of the tree that was attempted to be parsed. + actual: usize, + /// The maximum depth. + maximum: u32, + }, + /// After a close-paren, the only valid next characters are close-parens and commas. Got + /// something else. + ExpectedParenOrComma { + /// What we got instead. + ch: char, + /// Its byte-index into the string. + pos: usize, + }, + /// An open-parenthesis had no corresponding close-parenthesis. + UnmatchedOpenParen { + /// The character in question ('(' or '{') + ch: char, + /// Its byte-index into the string. + pos: usize, + }, + /// A close-parenthesis had no corresponding open-parenthesis. + UnmatchedCloseParen { + /// The character in question (')' or '}') + ch: char, + /// Its byte-index into the string. + pos: usize, + }, + /// A `(` was matched with a `}` or vice-versa. + MismatchedParens { + /// The opening parenthesis ('(' or '{') + open_ch: char, + /// The position of the opening parenthesis. + open_pos: usize, + /// The closing parenthesis (')' or '}') + close_ch: char, + /// The position of the closing parenthesis. + close_pos: usize, + }, + /// Data occurred after the final ). + TrailingCharacter { + /// The first trailing character. + ch: char, + /// Its byte-index into the string. 
+ pos: usize, + }, +} + +impl From for ParseTreeError { + fn from(e: checksum::Error) -> Self { Self::Checksum(e) } +} + +impl fmt::Display for ParseTreeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ParseTreeError::Checksum(ref e) => e.fmt(f), + ParseTreeError::MaxRecursionDepthExceeded { actual, maximum } => { + write!(f, "maximum recursion depth exceeded (max {}, got {})", maximum, actual) + } + ParseTreeError::ExpectedParenOrComma { ch, pos } => { + write!( + f, + "invalid character `{}` (position {}); expected comma or close-paren", + ch, pos + ) + } + ParseTreeError::UnmatchedOpenParen { ch, pos } => { + write!(f, "`{}` (position {}) not closed", ch, pos) + } + ParseTreeError::UnmatchedCloseParen { ch, pos } => { + write!(f, "`{}` (position {}) not opened", ch, pos) + } + ParseTreeError::MismatchedParens { open_ch, open_pos, close_ch, close_pos } => { + write!( + f, + "`{}` (position {}) closed by `{}` (position {})", + open_ch, open_pos, close_ch, close_pos + ) + } + ParseTreeError::TrailingCharacter { ch, pos } => { + write!(f, "trailing data `{}...` (position {})", ch, pos) + } + } + } +} +#[cfg(feature = "std")] +impl std::error::Error for ParseTreeError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ParseTreeError::Checksum(ref e) => Some(e), + ParseTreeError::MaxRecursionDepthExceeded { .. } + | ParseTreeError::ExpectedParenOrComma { .. } + | ParseTreeError::UnmatchedOpenParen { .. } + | ParseTreeError::UnmatchedCloseParen { .. } + | ParseTreeError::MismatchedParens { .. } + | ParseTreeError::TrailingCharacter { .. } => None, + } + } +} + /// Error parsing a threshold expression. 
#[derive(Clone, Debug, PartialEq, Eq)] pub enum ParseThresholdError { diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 3224a4f45..844b3c869 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -8,33 +8,14 @@ mod error; use core::fmt; use core::str::FromStr; -pub use self::error::ParseThresholdError; +pub use self::error::{ParseThresholdError, ParseTreeError}; +use crate::descriptor::checksum::verify_checksum; use crate::prelude::*; use crate::{errstr, Error, Threshold, MAX_RECURSION_DEPTH}; /// Allowed characters are descriptor strings. pub const INPUT_CHARSET: &str = "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVWXYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#\"\\ "; -/// Map of valid characters in descriptor strings. -#[rustfmt::skip] -pub const VALID_CHARS: [Option; 128] = [ - None, None, None, None, None, None, None, None, None, None, None, None, None, - None, None, None, None, None, None, None, None, None, None, None, None, None, - None, None, None, None, None, None, Some(94), Some(59), Some(92), Some(91), - Some(28), Some(29), Some(50), Some(15), Some(10), Some(11), Some(17), Some(51), - Some(14), Some(52), Some(53), Some(16), Some(0), Some(1), Some(2), Some(3), - Some(4), Some(5), Some(6), Some(7), Some(8), Some(9), Some(27), Some(54), - Some(55), Some(56), Some(57), Some(58), Some(26), Some(82), Some(83), - Some(84), Some(85), Some(86), Some(87), Some(88), Some(89), Some(32), Some(33), - Some(34), Some(35), Some(36), Some(37), Some(38), Some(39), Some(40), Some(41), - Some(42), Some(43), Some(44), Some(45), Some(46), Some(47), Some(48), Some(49), - Some(12), Some(93), Some(13), Some(60), Some(61), Some(90), Some(18), Some(19), - Some(20), Some(21), Some(22), Some(23), Some(24), Some(25), Some(64), Some(65), - Some(66), Some(67), Some(68), Some(69), Some(70), Some(71), Some(72), Some(73), - Some(74), Some(75), Some(76), Some(77), Some(78), Some(79), Some(80), Some(81), - Some(30), Some(62), Some(31), Some(63), None, -]; - 
#[derive(Debug)] /// A token of the form `x(...)` or `x` pub struct Tree<'a> { @@ -43,6 +24,20 @@ pub struct Tree<'a> { /// The comma-separated contents of the `(...)`, if any pub args: Vec>, } + +impl PartialEq for Tree<'_> { + fn eq(&self, other: &Self) -> bool { + let mut stack = vec![(self, other)]; + while let Some((me, you)) = stack.pop() { + if me.name != you.name || me.args.len() != you.args.len() { + return false; + } + stack.extend(me.args.iter().zip(you.args.iter())); + } + true + } +} +impl Eq for Tree<'_> {} // or_b(pk(A),pk(B)) // // A = musig(musig(B,C),D,E) @@ -131,13 +126,105 @@ impl<'a> Tree<'a> { Self::from_slice_delim(sl, 0u32, '(') } + /// Check that a string is a well-formed expression string, with optional + /// checksum. + /// + /// Returns the string with the checksum removed. + fn parse_pre_check(s: &str, open: u8, close: u8) -> Result<&str, ParseTreeError> { + // Do ASCII check first; after this we can use .bytes().enumerate() rather + // than .char_indices(), which is *significantly* faster. + let s = verify_checksum(s)?; + + let mut max_depth = 0; + let mut open_paren_stack = Vec::with_capacity(128); + for (pos, ch) in s.bytes().enumerate() { + if ch == open { + open_paren_stack.push((ch, pos)); + if max_depth < open_paren_stack.len() { + max_depth = open_paren_stack.len(); + } + } else if ch == close { + if let Some((open_ch, open_pos)) = open_paren_stack.pop() { + if (open_ch == b'(' && ch == b'}') || (open_ch == b'{' && ch == b')') { + return Err(ParseTreeError::MismatchedParens { + open_ch: open_ch.into(), + open_pos, + close_ch: ch.into(), + close_pos: pos, + }); + } + + if let Some(&(paren_ch, paren_pos)) = open_paren_stack.last() { + // not last paren; this should not be the end of the string, + // and the next character should be a , ) or }. 
+ if pos == s.len() - 1 { + return Err(ParseTreeError::UnmatchedOpenParen { + ch: paren_ch.into(), + pos: paren_pos, + }); + } else { + let next_byte = s.as_bytes()[pos + 1]; + if next_byte != b')' && next_byte != b'}' && next_byte != b',' { + return Err(ParseTreeError::ExpectedParenOrComma { + ch: next_byte.into(), + pos: pos + 1, + }); + // + } + } + } else { + // last paren; this SHOULD be the end of the string + if pos < s.len() - 1 { + return Err(ParseTreeError::TrailingCharacter { + ch: s.as_bytes()[pos + 1].into(), + pos: pos + 1, + }); + } + } + } else { + // In practice, this is only hit if there are no open parens at all. + // If there are open parens, like in "())", then on the first ), we + // would have returned TrailingCharacter in the previous clause. + // + // From a user point of view, UnmatchedCloseParen would probably be + // a clearer error to get, but it complicates the parser to do this, + // and "TrailingCharacter" is technically correct, so we leave it for + // now. + return Err(ParseTreeError::UnmatchedCloseParen { ch: ch.into(), pos }); + } + } else if ch == b',' && open_paren_stack.is_empty() { + // We consider commas outside of the tree to be "trailing characters" + return Err(ParseTreeError::TrailingCharacter { ch: ch.into(), pos }); + } + } + // Catch "early end of string" + if let Some((ch, pos)) = open_paren_stack.pop() { + return Err(ParseTreeError::UnmatchedOpenParen { ch: ch.into(), pos }); + } + + // FIXME should be able to remove this once we eliminate all recursion + // in the library. 
+ if u32::try_from(max_depth).unwrap_or(u32::MAX) > MAX_RECURSION_DEPTH { + return Err(ParseTreeError::MaxRecursionDepthExceeded { + actual: max_depth, + maximum: MAX_RECURSION_DEPTH, + }); + } + + Ok(s) + } + pub(crate) fn from_slice_delim( mut sl: &'a str, depth: u32, delim: char, ) -> Result<(Tree<'a>, &'a str), Error> { - if depth >= MAX_RECURSION_DEPTH { - return Err(Error::MaxRecursiveDepthExceeded); + if depth == 0 { + if delim == '{' { + sl = Self::parse_pre_check(sl, b'{', b'}').map_err(Error::ParseTree)?; + } else { + sl = Self::parse_pre_check(sl, b'(', b')').map_err(Error::ParseTree)?; + } } match next_expr(sl, delim) { @@ -157,7 +244,7 @@ impl<'a> Tree<'a> { ret.args.push(arg); if new_sl.is_empty() { - return Err(Error::ExpectedChar(closing_delim(delim))); + unreachable!() } sl = &new_sl[1..]; @@ -167,7 +254,7 @@ impl<'a> Tree<'a> { if last_byte == closing_delim(delim) as u8 { break; } else { - return Err(Error::ExpectedChar(closing_delim(delim))); + unreachable!() } } } @@ -180,13 +267,11 @@ impl<'a> Tree<'a> { /// Parses a tree from a string #[allow(clippy::should_implement_trait)] // Cannot use std::str::FromStr because of lifetimes. pub fn from_str(s: &'a str) -> Result, Error> { - check_valid_chars(s)?; - let (top, rem) = Tree::from_slice(s)?; if rem.is_empty() { Ok(top) } else { - Err(errstr(rem)) + unreachable!() } } @@ -220,23 +305,6 @@ impl<'a> Tree<'a> { } } -/// Filter out non-ASCII because we byte-index strings all over the -/// place and Rust gets very upset when you splinch a string. -pub fn check_valid_chars(s: &str) -> Result<(), Error> { - for ch in s.bytes() { - if !ch.is_ascii() { - return Err(Error::Unprintable(ch)); - } - // Index bounds: We know that ch is ASCII, so it is <= 127. 
- if VALID_CHARS[ch as usize].is_none() { - return Err(Error::Unexpected( - "Only characters in INPUT_CHARSET are allowed".to_string(), - )); - } - } - Ok(()) -} - /// Parse a string as a u32, for timelocks or thresholds pub fn parse_num(s: &str) -> Result { if s.len() > 1 { @@ -293,7 +361,12 @@ where #[cfg(test)] mod tests { - use super::parse_num; + use super::*; + + /// Test functions to manually build trees + fn leaf(name: &str) -> Tree { Tree { name, args: vec![] } } + + fn paren_node<'a>(name: &'a str, args: Vec>) -> Tree<'a> { Tree { name, args } } #[test] fn test_parse_num() { @@ -306,11 +379,113 @@ mod tests { } #[test] - fn test_valid_char_map() { - let mut valid_chars = [None; 128]; - for (i, ch) in super::INPUT_CHARSET.chars().enumerate() { - valid_chars[ch as usize] = Some(i as u8); - } - assert_eq!(valid_chars, super::VALID_CHARS); + fn parse_tree_basic() { + assert_eq!(Tree::from_str("thresh").unwrap(), leaf("thresh")); + + assert!(matches!( + Tree::from_str("thresh,").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: ',', pos: 6 }), + )); + + assert!(matches!( + Tree::from_str("thresh,thresh").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: ',', pos: 6 }), + )); + + assert!(matches!( + Tree::from_str("thresh()thresh()").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: 't', pos: 8 }), + )); + + assert_eq!(Tree::from_str("thresh()").unwrap(), paren_node("thresh", vec![leaf("")])); + + assert!(matches!( + Tree::from_str("thresh(a()b)"), + Err(Error::ParseTree(ParseTreeError::ExpectedParenOrComma { ch: 'b', pos: 10 })), + )); + + assert!(matches!( + Tree::from_str("thresh()xyz"), + Err(Error::ParseTree(ParseTreeError::TrailingCharacter { ch: 'x', pos: 8 })), + )); + } + + #[test] + fn parse_tree_parens() { + assert!(matches!( + Tree::from_str("a(").unwrap_err(), + Error::ParseTree(ParseTreeError::UnmatchedOpenParen { ch: '(', pos: 1 }), + )); + + assert!(matches!( + 
Tree::from_str(")").unwrap_err(), + Error::ParseTree(ParseTreeError::UnmatchedCloseParen { ch: ')', pos: 0 }), + )); + + assert!(matches!( + Tree::from_str("x(y))").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: ')', pos: 4 }), + )); + + /* Will be enabled in a later PR which unifies TR and non-TR parsing. + assert!(matches!( + Tree::from_str("a{").unwrap_err(), + Error::ParseTree(ParseTreeError::UnmatchedOpenParen { ch: '{', pos: 1 }), + )); + + assert!(matches!( + Tree::from_str("}").unwrap_err(), + Error::ParseTree(ParseTreeError::UnmatchedCloseParen { ch: '}', pos: 0 }), + )); + */ + + assert!(matches!( + Tree::from_str("x(y)}").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: '}', pos: 4 }), + )); + + /* Will be enabled in a later PR which unifies TR and non-TR parsing. + assert!(matches!( + Tree::from_str("x{y)").unwrap_err(), + Error::ParseTree(ParseTreeError::MismatchedParens { + open_ch: '{', + open_pos: 1, + close_ch: ')', + close_pos: 3, + }), + )); + */ + } + + #[test] + fn parse_tree_taproot() { + // This test will change in a later PR which unifies TR and non-TR parsing. 
+ assert!(matches!( + Tree::from_str("a{b(c),d}").unwrap_err(), + Error::ParseTree(ParseTreeError::TrailingCharacter { ch: ',', pos: 6 }), + )); + } + + #[test] + fn parse_tree_desc() { + let keys = [ + "02c2fd50ceae468857bb7eb32ae9cd4083e6c7e42fbbec179d81134b3e3830586c", + "0257f4a2816338436cccabc43aa724cf6e69e43e84c3c8a305212761389dd73a8a", + ]; + let desc = format!("wsh(t:or_c(pk({}),v:pkh({})))", keys[0], keys[1]); + + assert_eq!( + Tree::from_str(&desc).unwrap(), + paren_node( + "wsh", + vec![paren_node( + "t:or_c", + vec![ + paren_node("pk", vec![leaf(keys[0])]), + paren_node("v:pkh", vec![leaf(keys[1])]), + ] + )] + ), + ); } } diff --git a/src/lib.rs b/src/lib.rs index d566fa7c9..f15d99e46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ use bitcoin::{script, Opcode}; pub use crate::blanket_traits::FromStrKey; pub use crate::descriptor::{DefiniteDescriptorKey, Descriptor, DescriptorPublicKey}; -pub use crate::expression::ParseThresholdError; +pub use crate::expression::{ParseThresholdError, ParseTreeError}; pub use crate::interpreter::Interpreter; pub use crate::miniscript::analyzable::{AnalysisError, ExtParams}; pub use crate::miniscript::context::{BareCtx, Legacy, ScriptContext, Segwitv0, SigType, Tap}; @@ -492,6 +492,8 @@ pub enum Error { Threshold(ThresholdError), /// Invalid threshold. ParseThreshold(ParseThresholdError), + /// Invalid expression tree. 
+ ParseTree(ParseTreeError), } // https://github.com/sipa/miniscript/pull/5 for discussion on this number @@ -553,6 +555,7 @@ impl fmt::Display for Error { Error::RelativeLockTime(ref e) => e.fmt(f), Error::Threshold(ref e) => e.fmt(f), Error::ParseThreshold(ref e) => e.fmt(f), + Error::ParseTree(ref e) => e.fmt(f), } } } @@ -603,6 +606,7 @@ impl error::Error for Error { RelativeLockTime(e) => Some(e), Threshold(e) => Some(e), ParseThreshold(e) => Some(e), + ParseTree(e) => Some(e), } } } diff --git a/src/miniscript/mod.rs b/src/miniscript/mod.rs index 1b2bddc7d..03c242ec8 100644 --- a/src/miniscript/mod.rs +++ b/src/miniscript/mod.rs @@ -1315,7 +1315,7 @@ mod tests { assert!(Segwitv0Script::from_str_insane("🌏") .unwrap_err() .to_string() - .contains("unprintable character")); + .contains("invalid character")); } #[test] diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 578d821b4..deebe289c 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -835,8 +835,6 @@ impl fmt::Display for Policy { impl str::FromStr for Policy { type Err = Error; fn from_str(s: &str) -> Result, Error> { - expression::check_valid_chars(s)?; - let tree = expression::Tree::from_str(s)?; let policy: Policy = FromTree::from_tree(&tree)?; policy.check_timelocks().map_err(Error::ConcretePolicy)?; diff --git a/src/policy/semantic.rs b/src/policy/semantic.rs index 62487a54a..2eca31350 100644 --- a/src/policy/semantic.rs +++ b/src/policy/semantic.rs @@ -314,8 +314,6 @@ impl fmt::Display for Policy { impl str::FromStr for Policy { type Err = Error; fn from_str(s: &str) -> Result, Error> { - expression::check_valid_chars(s)?; - let tree = expression::Tree::from_str(s)?; expression::FromTree::from_tree(&tree) }