diff --git a/src/lib.rs b/src/lib.rs index 02f731f..7f96f85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,7 @@ extern crate serde_yaml; use std::borrow::ToOwned; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; +use std::fmt::Debug; use std::fs::File; use std::hash::Hash; use std::io::prelude::*; @@ -56,8 +57,8 @@ use serde::Serialize; use serde_yaml as yaml; /// The definition of all types that can be used in a `Chain`. -pub trait Chainable: Eq + Hash + Clone {} -impl Chainable for T where T: Eq + Hash + Clone {} +pub trait Chainable: Eq + Hash + Clone + Debug {} +impl Chainable for T where T: Eq + Hash + Clone + Debug {} type Token = Option; @@ -106,6 +107,36 @@ where } } + /// Returns a vector of current counts of token T + fn rank>(&self, tokens: S) -> Vec<(&Token, &usize)> { + let tokens = tokens.as_ref(); + if tokens.is_empty() { + return Vec::new(); + } + let mut toks = vec![]; + toks.extend(tokens.iter().map(|token| Some(token.clone()))); + if !self.map.contains_key(&toks) { + return Vec::new(); + } + println!("Tokens {:?}", toks); + println!("Map {:?}", self.map); + let result = self.map.get(&toks).unwrap(); + let sorted: Vec<_> = result + .iter() + .sorted_by(|&a, &b| Ord::cmp(a.1, b.1).reverse()) + .collect(); + sorted + } + + /// Get iterator over all tokens following a given set of tokens + /// (sorted by count) + pub fn iter_rank>(&self, tokens: S) -> RankIterator { + // TODO: + // The iterator is not stable. If elements have the same count, they + // have a different order when put out + RankIterator::new(self, tokens) + } + /// Determines whether or not the chain is empty. A chain is considered empty if nothing has /// been fed into it. pub fn is_empty(&self) -> bool { @@ -432,6 +463,44 @@ where } } +#[derive(Debug)] +/// Iterator over tokens sorted by rank given a token (sorted by highest probability) +pub struct RankIterator<'a, T> +where + T: Chainable + 'a, +{ + chain: Vec<(&'a Token, &'a usize)>, + count: usize, +} + +impl<'a, T> RankIterator<'a, T> +where + T: Chainable + 'a, +{ + /// Generate rank iterator + pub fn new>(chain: &'a Chain, tokens: S) -> Self { + let m = chain.rank(tokens); + RankIterator { chain: m, count: 0 } + } +} + +impl<'a, T> Iterator for RankIterator<'a, T> +where + T: Chainable + 'a, +{ + type Item = &'a Token; + + fn next(&mut self) -> Option { + if self.count >= self.chain.len() { + None + } else { + let r = Some(self.chain[self.count].0); + self.count += 1; + r + } + } +} + #[cfg(test)] mod test { use super::Chain; @@ -456,6 +525,71 @@ mod test { chain.feed(vec![3, 5, 10]).feed(vec![5, 12]); } + #[test] + fn rank() { + let mut chain = Chain::new(); + chain.feed(vec![3, 5, 10]).feed(vec![5, 12]); + let vec = chain.rank(vec![3]); + let mut iter = vec.iter(); + assert_eq!(iter.next(), Some(&(&Some(5), &1usize))); + assert_eq!(iter.next(), None); + + chain.feed(vec![3, 10, 3, 11, 3, 11, 3, 10, 3, 11]); + let vec = chain.rank(vec![3]); + let mut iter = vec.iter(); + assert_eq!(iter.next(), Some(&(&Some(11), &3usize))); + assert_eq!(iter.next(), Some(&(&Some(10), &2usize))); + assert_eq!(iter.next(), Some(&(&Some(5), &1usize))); + assert_eq!(iter.next(), None); + } + + #[test] + fn iter_rank() { + let mut chain = Chain::new(); + chain.feed(vec![3, 5, 10]).feed(vec![5, 3, 12, 3, 5]); + let mut iter = chain.iter_rank(vec![3]); + assert_eq!(iter.next(), Some(&Some(5))); + assert_eq!(iter.next(), Some(&Some(12))); + assert_eq!(iter.next(), None); + + let mut iter = chain.iter_rank(vec![3]).take(1); + assert_eq!(iter.next(), Some(&Some(5))); + assert_eq!(iter.next(), None); + + let mut iter = chain.iter_rank(vec![]).take(1); + assert_eq!(iter.next(), None); + } + + #[test] + fn iter_rank_higher_order() { + let mut chain = Chain::of_order(2); + chain.feed(vec![3, 5, 10]).feed(vec![5, 12]); + let mut iter = chain.iter_rank(vec![3, 5]); + assert_eq!(iter.next(), Some(&Some(10))); + assert_eq!(iter.next(), None); + + chain.feed(vec![3, 10, 3, 11, 3, 11, 3, 10, 3, 11]); + let mut iter = chain.iter_rank(vec![3, 10]); + assert_eq!(iter.next(), Some(&Some(3))); + assert_eq!(iter.next(), None); + } + + #[test] + fn rank_higher_order() { + let mut chain = Chain::of_order(2); + chain.feed(vec![3, 5, 10]).feed(vec![5, 12]); + let vec = chain.rank(vec![3, 5]); + let mut iter = vec.iter(); + assert_eq!(iter.next(), Some(&(&Some(10), &1usize))); + assert_eq!(iter.next(), None); + + chain.feed(vec![3, 10, 3, 11, 3, 11, 3, 10, 3, 11]); + let vec = chain.rank(vec![3, 10]); + let mut iter = vec.iter(); + assert_eq!(iter.next(), Some(&(&Some(3), &2usize))); + assert_eq!(iter.next(), None); + } + #[test] fn generate() { let mut chain = Chain::new();