diff --git a/src/lava/build.rs b/src/lava/build.rs index 8bc8328..414f977 100644 --- a/src/lava/build.rs +++ b/src/lava/build.rs @@ -21,19 +21,13 @@ use std::fs::File; use std::io::{BufWriter, Seek, SeekFrom, Write}; use zstd::stream::encode_all; -use crate::vamana::{build_index_par, IndexParams, VamanaIndex}; -use crate::vamana::{EuclideanF32, InMemoryAccessMethodF32}; -use ndarray::Array2; - use rayon::prelude::*; fn get_tokenizer(tokenizer_file: Option) -> Result<(Tokenizer, Vec), LavaError> { // if the tokenizer file is provided, check if the file exists. If it does not exist, raise an Error let tokenizer = if let Some(tokenizer_file) = tokenizer_file { if !std::path::Path::new(&tokenizer_file).exists() { - return Err(LavaError::Parse( - "Tokenizer file does not exist".to_string(), - )); + return Err(LavaError::Parse("Tokenizer file does not exist".to_string())); } println!("Tokenizer file: {}", tokenizer_file); Tokenizer::from_file(tokenizer_file).unwrap() @@ -42,8 +36,7 @@ fn get_tokenizer(tokenizer_file: Option) -> Result<(Tokenizer, Vec), }; let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); - let compressed_tokenizer = - encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); + let compressed_tokenizer = encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); Ok((tokenizer, compressed_tokenizer)) } @@ -59,21 +52,15 @@ pub async fn build_lava_uuid( let array: &arrow_array::GenericByteArray> = array .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects string array as first argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; let uid: &arrow_array::PrimitiveArray = uid .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects uint64 array as second argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; if array.len() != uid.len() { - return Err(LavaError::Parse( - "The length of the array and the uid array must be the same".to_string(), - )); + return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); } let mut texts = Vec::with_capacity(array.len()); @@ -121,21 +108,15 @@ pub async fn build_lava_bm25( let array: &arrow_array::GenericByteArray> = array .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects string array as first argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; let uid = uid .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects uint64 array as second argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; if array.len() != uid.len() { - return Err(LavaError::Parse( - "The length of the array and the uid array must be the same".to_string(), - )); + return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); } let (tokenizer, compressed_tokenizer) = get_tokenizer(tokenizer_file)?; @@ -235,8 +216,7 @@ pub async fn build_lava_bm25( let compressed_plist_offsets_offset = file.seek(SeekFrom::Current(0))?; let serialized = bincode::serialize(&plist_offsets).unwrap(); - let compressed_plist_offsets = - encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); + let compressed_plist_offsets = encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); file.write_all(&compressed_plist_offsets)?; 
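// ---------------------------------------------------------------------------
// Illustrative aside (not part of the diff). The writes above follow the
// pattern build.rs uses for every index section: bincode-serialize a
// structure, zstd-compress it with encode_all, write the bytes, and record
// the byte offset so the file footer can point readers back at it. A minimal,
// self-contained sketch of that round trip; the helper names here are
// hypothetical and assume the bincode 1.x and zstd crates this crate already
// depends on.
use serde::{de::DeserializeOwned, Serialize};
use std::io::Read;

fn compress_section<T: Serialize>(value: &T) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let serialized = bincode::serialize(value)?;
    // level 0 mirrors the default compression level passed to encode_all above
    Ok(zstd::stream::encode_all(&serialized[..], 0)?)
}

fn decompress_section<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, Box<dyn std::error::Error>> {
    let mut decoder = zstd::stream::read::Decoder::new(bytes)?;
    let mut serialized = Vec::new();
    decoder.read_to_end(&mut serialized)?;
    Ok(bincode::deserialize(&serialized)?)
}

fn roundtrip_demo() -> Result<(), Box<dyn std::error::Error>> {
    let plist_offsets: Vec<u64> = vec![0, 128, 512];
    let compressed = compress_section(&plist_offsets)?;
    let recovered: Vec<u64> = decompress_section(&compressed)?;
    assert_eq!(plist_offsets, recovered);
    Ok(())
}
// --------------------------------------------------------------------- (end aside)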
file.write_all(&(compressed_term_dict_offset as u64).to_le_bytes())?; @@ -249,179 +229,159 @@ pub async fn build_lava_bm25( } #[tokio::main] -pub async fn build_lava_kmer( +pub async fn build_lava_substring_char( output_file_name: String, array: ArrayData, uid: ArrayData, - tokenizer_file: Option, -) -> Result<(), LavaError> { + char_skip_factor: Option, +) -> Result, LavaError> { let array = make_array(array); // let uid = make_array(ArrayData::from_pyarrow(uid)?); let uid = make_array(uid); - // if the tokenizer file is provided, check if the file exists. If it does not exist, raise an Error - let (tokenizer, compressed_tokenizer) = get_tokenizer(tokenizer_file)?; - // let vocab_size: usize = tokenizer.get_vocab_size(false); + let char_skip_factor = char_skip_factor.unwrap_or(1); let array: &arrow_array::GenericByteArray> = array .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects string array as first argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; let uid = uid .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects uint64 array as second argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; if array.len() != uid.len() { - return Err(LavaError::Parse( - "The length of the array and the uid array must be the same".to_string(), - )); + return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); } - let mut texts = Vec::with_capacity(array.len()); + let mut texts: Vec<(u64, &str)> = Vec::with_capacity(array.len()); for i in 0..array.len() { let text = array.value(i); - texts.push(text); + texts.push((uid.value(i), text)); } - let encodings = texts + // parallelize the string operations + let named_encodings = texts .into_maybe_par_iter() - .map(|text| { - let encoding = tokenizer.encode(text, false).unwrap(); - encoding.get_ids().to_vec() + .map(|(uid, text)| { + let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); + let result: Vec = if char_skip_factor == 1 { + text.chars().filter(|id| !SKIP.chars().contains(id)).map(|c| c as u8).collect() + } else { + text.chars() + .filter(|id| !SKIP.chars().contains(id)) + .enumerate() + .filter(|&(index, _)| index % char_skip_factor as usize == 1) + .map(|(_, c)| c as u8) + .collect() + }; + (vec![uid; result.len()], result) }) - .collect::>>(); + .collect::, Vec)>>(); - // get all trigrams. 
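// ---------------------------------------------------------------------------
// Illustrative aside (not part of the diff). build_lava_substring_char above
// reduces each document to a filtered byte stream before suffix-array
// construction: lowercase the text, drop characters in the SKIP constant, and
// (when char_skip_factor > 1) keep only a subsample of the remaining
// characters. A standalone sketch of that preprocessing; SKIP here is a
// stand-in for the crate's constant, and the exact subsampling phase may
// differ from the implementation above.
const SKIP: &str = " \t\n\r,.!?";

fn preprocess_chars(text: &str, char_skip_factor: usize) -> Vec<u8> {
    text.to_lowercase()
        .chars()
        .filter(|c| !SKIP.contains(*c))                       // drop skip characters
        .enumerate()
        .filter(|&(i, _)| char_skip_factor <= 1 || i % char_skip_factor == 0)
        .map(|(_, c)| c as u8)                                // narrow each char to one byte
        .collect()
}

// preprocess_chars("Hello, World!", 1) == b"helloworld".to_vec()
// preprocess_chars("Hello, World!", 2) keeps every other retained character.
// --------------------------------------------------------------------- (end aside)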
+ let uids: Vec = named_encodings.iter().map(|(uid, _)| uid).flatten().cloned().collect::>(); + let encodings: Vec = named_encodings.into_iter().map(|(_, text)| text).flatten().collect::>(); - let mut trigrams_inverted_index: BTreeMap<(u32, u32, u32), BTreeSet> = BTreeMap::new(); - - for (i, encoding) in encodings.iter().enumerate() { - let this_uid = uid.value(i) as usize; - for j in 0..(encoding.len() as i64 - 2) { - let j = j as usize; - // let trigram = (encoding[j], encoding[j + 1], encoding[j + 2]); - let bigram = (u32::MAX, encoding[j], encoding[j + 1]); - let unigram = (u32::MAX, u32::MAX, encoding[j]); - // trigrams_inverted_index - // .entry(trigram) - // .or_insert_with(BTreeSet::new) - // .insert(this_uid as u64); - trigrams_inverted_index - .entry(bigram) - .or_insert_with(BTreeSet::new) - .insert(this_uid as u64); - trigrams_inverted_index - .entry(unigram) - .or_insert_with(BTreeSet::new) - .insert(this_uid as u64); - } + let mut suffices: Vec> = vec![]; - if encoding.len() >= 2 { - let bigram = ( - u32::MAX, - encoding[encoding.len() - 2], - encoding[encoding.len() - 1], - ); - trigrams_inverted_index - .entry(bigram) - .or_insert_with(BTreeSet::new) - .insert(this_uid as u64); - let unigram = (u32::MAX, u32::MAX, encoding[encoding.len() - 2]); - trigrams_inverted_index - .entry(unigram) - .or_insert_with(BTreeSet::new) - .insert(this_uid as u64); - } + for i in 10..encodings.len() { + suffices.push(encodings[i - 10..i].to_vec()); + } - if encoding.len() >= 1 { - let unigram = (u32::MAX, u32::MAX, encoding[encoding.len() - 1]); - trigrams_inverted_index - .entry(unigram) - .or_insert_with(BTreeSet::new) - .insert(this_uid as u64); - } + for i in encodings.len()..encodings.len() + 10 { + let mut suffix = encodings[i - 10..encodings.len()].to_vec(); + suffix.append(&mut vec![0; i - encodings.len()]); + suffices.push(suffix); } - // figure out the average length of the inverted index posting lists - // filter the trigrams to include only things where the value length is smaller than 10 + let mut sa: Vec = (0..suffices.len()).collect(); - let trigrams_inverted_index = trigrams_inverted_index - .into_iter() - .filter(|(_, value)| value.len() < (uid.value(encodings.len() - 1) / 10 * 3) as usize) - .collect::>(); + sa.par_sort_by(|&a, &b| suffices[a].cmp(&suffices[b])); - let mut avg_len: f32 = 0.0; - let mut all_lens: Vec = vec![]; - for (_, value) in trigrams_inverted_index.iter() { - avg_len += value.len() as f32; - all_lens.push(value.len()); - } - //write out all_lens as a numpy file - let mut file = File::create("all_lens.npy")?; - for val in all_lens.iter() { - file.write_all(&val.to_le_bytes())?; + let mut idx: Vec = Vec::with_capacity(encodings.len()); + let mut bwt: Vec = Vec::with_capacity(encodings.len()); + for i in 0..sa.len() { + if sa[i] == 0 { + bwt.push(encodings[encodings.len() - 1]); + idx.push(uids[uids.len() - 1]); + } else { + bwt.push(encodings[(sa[i] - 1) as usize]); + idx.push(uids[(sa[i] - 1) as usize]); + } } - avg_len /= trigrams_inverted_index.len() as f32; - println!("Average length: {}", avg_len); - let mut file = File::create(output_file_name)?; - file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; - file.write_all(&compressed_tokenizer)?; - // Handle the compressed data (for example, saving to a file or sending over a network) - println!("Number of trigrams: {}", trigrams_inverted_index.len()); + let mut fm_chunk_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? 
as usize]; - let mut plist_offsets: Vec = vec![file.seek(SeekFrom::Current(0))?]; - let mut plist_elems: Vec = vec![0]; - let mut plist_chunk = PListChunk::new()?; - let mut counter: u64 = 0; + let mut current_chunk: Vec = vec![]; + let mut current_chunk_counts: HashMap = HashMap::new(); + let mut next_chunk_counts: HashMap = HashMap::new(); - let mut term_dictionary: Vec<(u32, u32, u32)> = Vec::new(); + for i in 0..bwt.len() { + let current_tok = bwt[i]; + next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); + current_chunk.push(current_tok); - for (key, value) in trigrams_inverted_index.iter() { - if value.len() < 5 { - continue; + if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { + let serialized_counts = bincode::serialize(¤t_chunk_counts)?; + let compressed_counts = encode_all(&serialized_counts[..], 10).expect("Compression failed"); + println!("chunk size: {}", compressed_counts.len()); + file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; + file.write_all(&compressed_counts)?; + let serialized_chunk = bincode::serialize(¤t_chunk)?; + let compressed_chunk = encode_all(&serialized_chunk[..], 10).expect("Compression failed"); + file.write_all(&compressed_chunk)?; + fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? as usize); + current_chunk_counts = next_chunk_counts.clone(); + current_chunk = vec![]; } - counter += 1; + } + // print out total file size so far + println!("total file size: {}", file.seek(SeekFrom::Current(0))?); - term_dictionary.push(*key); + let mut cumulative_counts: Vec = vec![0]; + for i in 0..current_chunk_counts.len() { + cumulative_counts.push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u8)).unwrap_or(&0)); + } - let written = plist_chunk.add_plist(&value.iter().map(|x| *x as u64).collect_vec())?; - if written > 1024 * 1024 || counter == trigrams_inverted_index.len() as u64 { - let bytes = plist_chunk.finalize_compression()?; - file.write_all(&bytes)?; - plist_offsets.push(plist_offsets[plist_offsets.len() - 1] + bytes.len() as u64); - plist_elems.push(counter); - plist_chunk = PListChunk::new()?; - } + let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; + + for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { + let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; + let serialized_slice = bincode::serialize(slice)?; + let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); + file.write_all(&compressed_slice)?; + posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); } - let serialized_term_dictionary = bincode::serialize(&term_dictionary).unwrap(); - let compressed_term_dictionary = encode_all(&serialized_term_dictionary[..], 0) - .expect("Compression of term dictionary failed"); + let cache_start = file.seek(SeekFrom::Current(0))? as usize; - plist_offsets.append(&mut plist_elems); - let compressed_term_dict_offset = file.seek(SeekFrom::Current(0))?; - file.write_all(&compressed_term_dictionary)?; + let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? 
as usize; + let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; + let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_fm_chunk_offsets)?; - let compressed_plist_offsets_offset = file.seek(SeekFrom::Current(0))?; - let serialized = bincode::serialize(&plist_offsets).unwrap(); - let compressed_plist_offsets = - encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); - file.write_all(&compressed_plist_offsets)?; + let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; + let compressed_posting_list_offsets = + encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_posting_list_offsets)?; - file.write_all(&(compressed_term_dict_offset as u64).to_le_bytes())?; - file.write_all(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(encodings.len() as u64).to_le_bytes())?; + let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_total_counts = bincode::serialize(&cumulative_counts)?; + let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + file.write_all(&compressed_total_counts)?; - Ok(()) + file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(total_counts_offset as u64).to_le_bytes())?; + file.write_all(&(bwt.len() as u64).to_le_bytes())?; + + let cache_end = file.seek(SeekFrom::Current(0))? as usize; + + Ok(vec![(cache_start, cache_end)]) } #[tokio::main] @@ -440,9 +400,7 @@ pub async fn build_lava_substring( let tokenizer = if let Some(tokenizer_file) = tokenizer_file { if !std::path::Path::new(&tokenizer_file).exists() { - return Err(LavaError::Parse( - "Tokenizer file does not exist".to_string(), - )); + return Err(LavaError::Parse("Tokenizer file does not exist".to_string())); } println!("Tokenizer file: {}", tokenizer_file); Tokenizer::from_file(tokenizer_file).unwrap() @@ -451,27 +409,20 @@ pub async fn build_lava_substring( }; let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); - let compressed_tokenizer = - encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); + let compressed_tokenizer = encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); let array: &arrow_array::GenericByteArray> = array .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects string array as first argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; let uid = uid .as_any() .downcast_ref::() - .ok_or(LavaError::Parse( - "Expects uint64 array as second argument".to_string(), - ))?; + .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; if array.len() != uid.len() { - return Err(LavaError::Parse( - "The length of the array and the uid array must be the same".to_string(), - )); + return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); } let mut texts: Vec<(u64, &str)> = Vec::with_capacity(array.len()); @@ -483,27 +434,9 @@ pub async fn build_lava_substring( let mut skip_tokens: HashSet = HashSet::new(); for char in SKIP.chars() { let char_str = char.to_string(); - skip_tokens.extend( - tokenizer - 
.encode(char_str.clone(), false) - .unwrap() - .get_ids() - .to_vec(), - ); - skip_tokens.extend( - tokenizer - .encode(format!(" {}", char_str), false) - .unwrap() - .get_ids() - .to_vec(), - ); - skip_tokens.extend( - tokenizer - .encode(format!("{} ", char_str), false) - .unwrap() - .get_ids() - .to_vec(), - ); + skip_tokens.extend(tokenizer.encode(char_str.clone(), false).unwrap().get_ids().to_vec()); + skip_tokens.extend(tokenizer.encode(format!(" {}", char_str), false).unwrap().get_ids().to_vec()); + skip_tokens.extend(tokenizer.encode(format!("{} ", char_str), false).unwrap().get_ids().to_vec()); } let named_encodings = texts @@ -513,27 +446,13 @@ pub async fn build_lava_substring( let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); let encoding = tokenizer.encode(lower, false).unwrap(); - let result: Vec = encoding - .get_ids() - .iter() - .filter(|id| !skip_tokens.contains(id)) - .cloned() - .collect(); + let result: Vec = encoding.get_ids().iter().filter(|id| !skip_tokens.contains(id)).cloned().collect(); (vec![uid; result.len()], result) }) .collect::, Vec)>>(); - let uids: Vec = named_encodings - .iter() - .map(|(uid, _)| uid) - .flatten() - .cloned() - .collect::>(); - let encodings: Vec = named_encodings - .into_iter() - .map(|(_, text)| text) - .flatten() - .collect::>(); + let uids: Vec = named_encodings.iter().map(|(uid, _)| uid).flatten().cloned().collect::>(); + let encodings: Vec = named_encodings.into_iter().map(|(_, text)| text).flatten().collect::>(); let mut suffices: Vec> = vec![]; @@ -566,20 +485,10 @@ pub async fn build_lava_substring( suffices.push(suffix); } - // for i in 11..encodings.len() { - // let mut suffix = encodings[i - 10..i].to_vec(); - // suffix.push(encodings[i - 11]); - // suffices.push(suffix); - // } - let mut sa: Vec = (0..suffices.len()).collect(); - // let start = std::time::Instant::now(); - sa.par_sort_by(|&a, &b| suffices[a].cmp(&suffices[b])); - // let duration = start.elapsed(); - let mut idx: Vec = Vec::with_capacity(encodings.len()); let mut bwt: Vec = Vec::with_capacity(encodings.len()); for i in 0..sa.len() { @@ -592,19 +501,6 @@ pub async fn build_lava_substring( } } - // write out the bwt to a numpy array - - // let file = File::create("output.bin")?; - // let mut writer = BufWriter::new(file); - - // // Write each u32 to the file as bytes - // for number in bwt.iter() { - // writer.write_u32::(*number)?; - // } - - // Flush the buffer to ensure all data is written to the file - // writer.flush()?; - let mut file = File::create(output_file_name)?; file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; file.write_all(&compressed_tokenizer)?; @@ -617,22 +513,17 @@ pub async fn build_lava_substring( for i in 0..bwt.len() { let current_tok = bwt[i]; - next_chunk_counts - .entry(current_tok) - .and_modify(|count| *count += 1) - .or_insert(1); + next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); current_chunk.push(current_tok); if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { let serialized_counts = bincode::serialize(¤t_chunk_counts)?; - let compressed_counts = - encode_all(&serialized_counts[..], 10).expect("Compression failed"); + let compressed_counts = encode_all(&serialized_counts[..], 10).expect("Compression failed"); println!("chunk size: {}", compressed_counts.len()); file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; file.write_all(&compressed_counts)?; let serialized_chunk = bincode::serialize(¤t_chunk)?; - let compressed_chunk = - 
encode_all(&serialized_chunk[..], 10).expect("Compression failed"); + let compressed_chunk = encode_all(&serialized_chunk[..], 10).expect("Compression failed"); file.write_all(&compressed_chunk)?; fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? as usize); current_chunk_counts = next_chunk_counts.clone(); @@ -644,8 +535,7 @@ pub async fn build_lava_substring( let mut cumulative_counts: Vec = vec![0]; for i in 0..tokenizer.get_vocab_size(false) { - cumulative_counts - .push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u32)).unwrap_or(&0)); + cumulative_counts.push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u32)).unwrap_or(&0)); } let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; @@ -662,8 +552,7 @@ pub async fn build_lava_substring( let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; - let compressed_fm_chunk_offsets = - encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); file.write_all(&compressed_fm_chunk_offsets)?; let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; @@ -674,8 +563,7 @@ pub async fn build_lava_substring( let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; let serialized_total_counts = bincode::serialize(&cumulative_counts)?; - let compressed_total_counts: Vec = - encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); file.write_all(&compressed_total_counts)?; file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; diff --git a/src/lava/fm_chunk.rs b/src/lava/fm_chunk.rs index 680ca4d..bc0d1a8 100644 --- a/src/lava/fm_chunk.rs +++ b/src/lava/fm_chunk.rs @@ -1,34 +1,37 @@ -use std::collections::HashMap; -use zstd::stream::read::Decoder; +use super::error::LavaError; use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::io::Read; use zstd::stream::encode_all; -use super::error::LavaError; -pub(crate) struct FMChunk { - pub counts_so_far : HashMap, - pub bwt_chunk : Vec, +use zstd::stream::read::Decoder; + +pub(crate) struct FMChunk +where + T: Serialize + for<'de> Deserialize<'de> + Clone + Eq + std::hash::Hash, +{ + pub counts_so_far: HashMap, + pub bwt_chunk: Vec, } -impl FMChunk { - pub fn new( - chunk : Bytes - ) -> Result { - let compressed_counts_size = u64::from_le_bytes(chunk[0 .. 8].try_into().unwrap()); - let compressed_counts = &chunk[8 .. 
(compressed_counts_size + 8) as usize]; +impl FMChunk +where + T: Serialize + for<'de> Deserialize<'de> + Clone + Eq + std::hash::Hash, +{ + pub fn new(chunk: Bytes) -> Result { + let compressed_counts_size = u64::from_le_bytes(chunk[0..8].try_into().unwrap()); + let compressed_counts = &chunk[8..(compressed_counts_size + 8) as usize]; let mut decompressor = Decoder::new(compressed_counts)?; let mut serialized_counts: Vec = Vec::with_capacity(compressed_counts_size as usize); decompressor.read_to_end(&mut serialized_counts)?; - let counts: HashMap = bincode::deserialize(&serialized_counts)?; - let compressed_fm_chunk = &chunk[(compressed_counts_size + 8) as usize ..]; + let counts: HashMap = bincode::deserialize(&serialized_counts)?; + let compressed_fm_chunk = &chunk[(compressed_counts_size + 8) as usize..]; let mut decompressor = Decoder::new(compressed_fm_chunk)?; let mut serialized_fm_chunk: Vec = Vec::with_capacity(compressed_fm_chunk.len() as usize); decompressor.read_to_end(&mut serialized_fm_chunk)?; - let fm_chunk: Vec = bincode::deserialize(&serialized_fm_chunk)?; + let fm_chunk: Vec = bincode::deserialize(&serialized_fm_chunk)?; - Ok(Self { - counts_so_far : counts, - bwt_chunk : fm_chunk, - }) + Ok(Self { counts_so_far: counts, bwt_chunk: fm_chunk }) } #[allow(dead_code)] @@ -44,15 +47,13 @@ impl FMChunk { Ok(result) } - pub fn search(& self, token: u32, pos: usize) -> Result { + pub fn search(&self, token: T, pos: usize) -> Result { let mut result = *self.counts_so_far.get(&token).unwrap_or(&0); - for j in 0 .. pos { + for j in 0..pos { if self.bwt_chunk[j] == token { result += 1; } } Ok(result) } - - -} \ No newline at end of file +} diff --git a/src/lava/logcloud.rs b/src/lava/logcloud.rs new file mode 100644 index 0000000..9a953f4 --- /dev/null +++ b/src/lava/logcloud.rs @@ -0,0 +1,449 @@ +// use log::{info, warn}; +// use parquet::{column::reader, format::DictionaryPageHeader}; +// use rand::Rng; +// use tokio::{task::JoinSet, time::sleep}; + +// use crate::{ +// formats::readers::{ +// get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ClonableAsyncReader, ReaderType, +// }, +// lava::error::LavaError, +// lava::logcloud_plist::{PListChunk, PlistSize}, +// }; +// use serde::de::DeserializeOwned; +// use std::{ +// collections::{HashMap, HashSet}, +// io::Read, +// time::{Duration, Instant}, +// }; +// use zstd::stream::{encode_all, read::Decoder}; + +// async fn read_and_decompress(reader: &mut AsyncReader, start: u64, size: u64) -> Result +// where +// T: DeserializeOwned, +// { +// let compressed = reader.read_range(start, start + size).await?; +// let mut decompressor = Decoder::new(&compressed[..]).unwrap(); +// let mut decompressed = Vec::new(); +// std::io::copy(&mut decompressor, &mut decompressed)?; +// let result: T = bincode::deserialize(&decompressed)?; +// Ok(result) +// } + +// pub struct LogCloud { +// pub kauai: AsyncReader, +// pub oahu: AsyncReader, +// pub hawaii: AsyncReader, +// } + +// // std::pair> search_kauai(VirtualFileRegion * vfr, std::string query, int k) { + +// async fn query_kauai( +// reader: &mut AsyncReader, +// file_size: usize, +// query: &str, +// k: u32, +// ) -> Result<(u32, Vec), LavaError> { +// let byte_offsets = reader.read_usize_from_end(6).await?; + +// let dictionary: String = read_and_decompress(reader, 0, byte_offsets[0]).await?; + +// // Read and decompress templates +// let template: String = read_and_decompress(reader, byte_offsets[0], byte_offsets[1] - byte_offsets[0]).await?; + +// // Read template 
posting lists +// let template_plist: Vec> = +// read_and_decompress(reader, byte_offsets[1], byte_offsets[2] - byte_offsets[1]).await?; + +// // Read and decompress outlier strings +// let outlier: String = read_and_decompress(reader, byte_offsets[2], byte_offsets[3] - byte_offsets[2]).await?; + +// // Read outlier posting lists +// let outlier_plist: Vec> = +// read_and_decompress(reader, byte_offsets[3], byte_offsets[4] - byte_offsets[3]).await?; + +// // Read and decompress outlier type strings +// let outlier_type: String = read_and_decompress(reader, byte_offsets[4], byte_offsets[5] - byte_offsets[4]).await?; + +// // Read outlier type posting lists +// let outlier_type_pl_size = file_size as u64 - byte_offsets[5] - 6 * std::mem::size_of::() as u64; +// let outlier_type_plist: Vec> = +// read_and_decompress(reader, byte_offsets[5], outlier_type_pl_size).await?; + +// for (_, line) in dictionary.lines().enumerate() { +// if line.contains(query) { +// println!("query matched dictionary item, brute force {}", query); +// return Ok((0, Vec::new())); +// } +// } + +// let mut matched_row_groups = Vec::new(); + +// let search_text = |query: &str, +// source_str: &str, +// plists: &[Vec], +// matched_row_groups: &mut Vec, +// write: bool| { +// if write { +// println!("{}", source_str); +// } +// for (line_no, line) in source_str.lines().enumerate() { +// if let Some(_) = line.find(query) { +// println!("{} {}", line, line_no); +// let posting_list = &plists[line_no]; +// for &row_group in posting_list { +// print!("{} ", row_group); +// matched_row_groups.push(row_group); +// } +// println!(); +// } +// } +// }; + +// search_text(query, &template, &template_plist, &mut matched_row_groups, false); + +// // Print matched row groups +// for &row_group in &matched_row_groups { +// print!("{} ", row_group); +// } +// println!(); + +// search_text(query, &outlier, &outlier_plist, &mut matched_row_groups, false); + +// if matched_row_groups.len() >= k.try_into().unwrap() { +// println!("inexact query for top K satisfied by template and outlier {}", query); +// return Ok((1, matched_row_groups)); +// } + +// // Search in outlier types +// search_text(query, &outlier_type, &outlier_type_plist, &mut matched_row_groups, false); + +// if matched_row_groups.len() >= k.try_into().unwrap() { +// println!("inexact query for top K satisfied by template, outlier and outlier types {}", query); +// Ok((1, matched_row_groups)) +// } else { +// println!("inexact query for top K not satisfied by template, outlier and outlier types {}", query); +// Ok((2, matched_row_groups)) +// } +// } + +// async fn search_oahu( +// reader: &mut AsyncReader, +// file_size: usize, +// query_type: i32, +// chunks: Vec, +// query_str: &str, +// ) -> Result, LavaError> { +// // Read the metadata page length +// let metadata_page_length = reader.read_usize_from_end(1).await?[0]; + +// // Read the metadata page +// let metadata_page = reader +// .read_range( +// file_size as u64 - metadata_page_length as u64 - std::mem::size_of::() as u64, +// file_size as u64 - std::mem::size_of::() as u64, +// ) +// .await?; + +// let mut decompressor = Decoder::new(&metadata_page[..]).unwrap(); +// let mut decompressed_metadata_page: Vec = Vec::with_capacity(metadata_page.len() as usize); +// decompressor.read_to_end(&mut decompressed_metadata_page).unwrap(); + +// // Read metadata +// let num_types = u64::from_le_bytes(decompressed_metadata_page[0..8].try_into().unwrap()) as usize; +// let num_blocks = 
u64::from_le_bytes(decompressed_metadata_page[8..16].try_into().unwrap()) as usize; + +// let mut offset = 16; +// let type_order: Vec = (0..num_types) +// .map(|i| { +// let start = offset + i * 8; +// i32::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()) +// }) +// .collect(); + +// offset += num_types * 8; +// let type_offsets: Vec = (0..=num_types) +// .map(|i| { +// let start = offset + i * 8; +// u64::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()) as usize +// }) +// .collect(); + +// offset += (num_types + 1) * 8; +// let block_offsets: Vec = (0..=num_blocks) +// .map(|i| { +// let start = offset + i * 8; +// u64::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()) as usize +// }) +// .collect(); + +// // Find query_type in type_order +// let type_index = type_order +// .iter() +// .position(|&x| x == query_type) +// .ok_or_else(|| LavaError::Parse("Query type not found".to_string()))?; + +// let type_offset = type_offsets[type_index]; +// let num_chunks = type_offsets[type_index + 1] - type_offset; + +// // Process blocks using JoinSet +// let mut set = JoinSet::new(); + +// for chunk in chunks.into_iter().take(num_chunks) { +// let block_offset = block_offsets[type_offset + chunk] as u64; +// let next_block_offset = block_offsets[type_offset + chunk + 1] as u64; +// let block_size = next_block_offset - block_offset; + +// let mut reader_clone = reader.clone(); // Assuming AsyncReader implements Clone +// let query_str_clone = query_str.to_string(); + +// set.spawn(async move { +// let block = reader_clone.read_range(block_offset, block_offset + block_size).await.unwrap(); + +// let compressed_strings_length = u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; +// let compressed_strings = &block[8..8 + compressed_strings_length]; + +// let mut decompressor = Decoder::new(compressed_strings).unwrap(); +// let mut decompressed_strings: Vec = Vec::with_capacity(compressed_strings.len() as usize); +// decompressor.read_to_end(&mut decompressed_strings).unwrap(); + +// let compressed_plist = &block[8 + compressed_strings_length..]; +// let plist = PListChunk::from_compressed(compressed_plist).unwrap(); + +// let mut row_groups = Vec::new(); +// for (line_number, line) in String::from_utf8_lossy(&decompressed_strings).lines().enumerate() { +// if format!("\n{}\n", line).contains(&query_str_clone) { +// row_groups.extend(plist.lookup(line_number).unwrap()); +// } +// } + +// row_groups +// }); +// } + +// let mut all_row_groups = Vec::new(); +// while let Some(result) = set.join_next().await { +// let result = result.unwrap(); +// all_row_groups.extend(result); +// } + +// Ok(all_row_groups) +// } + +// const B: usize = 1024 * 1024; +// const GIVEUP: usize = 100; + +// async fn search_vfr( +// reader: &mut AsyncReader, +// wavelet_offset: u64, +// wavelet_size: u64, +// logidx_offset: u64, +// logidx_size: u64, +// query: &str, +// ) -> Result, LavaError> { +// let start_time = Instant::now(); +// let compressed_offsets_byte_offset: usize = read_and_decompress(reader, logidx_offset + logidx_size - 8, 8).await?; +// let duration = start_time.elapsed(); +// println!( +// "log_idx decompress offsets took {} milliseconds, this could choke for concurrent requests!", +// duration.as_millis() +// ); + +// let compressed_offsets: Vec = read_and_decompress( +// reader, +// logidx_offset + compressed_offsets_byte_offset as u64, +// (logidx_size - compressed_offsets_byte_offset as u64 - 8) as u64, 
+// ) +// .await?; + +// let chunk_offsets: Vec = bincode::deserialize(&compressed_offsets)?; + +// async fn batch_log_idx_lookup( +// chunk_offsets: &[usize], +// reader: &mut AsyncReader, +// logidx_offset: u64, +// start_idx: usize, +// end_idx: usize, +// ) -> Result, LavaError> { +// let start_chunk_offset = chunk_offsets[start_idx / B]; +// let end_chunk_offset = chunk_offsets[end_idx / B + 1]; +// let total_chunks = end_idx / B - start_idx / B + 1; + +// let compressed_chunks: Vec = read_and_decompress( +// reader, +// logidx_offset + start_chunk_offset as u64, +// (end_chunk_offset - start_chunk_offset) as u64, +// ) +// .await?; + +// let mut results = Vec::new(); +// for i in 0..total_chunks { +// let chunk_start = chunk_offsets[start_idx / B + i] - start_chunk_offset; +// let chunk_end = chunk_offsets[start_idx / B + i + 1] - start_chunk_offset; +// let log_idx: Vec = bincode::deserialize(&compressed_chunks[chunk_start..chunk_end])?; + +// let start = if i == 0 { start_idx % B } else { 0 }; +// let end = if i == total_chunks - 1 { end_idx % B } else { log_idx.len() }; + +// results.extend_from_slice(&log_idx[start..end]); +// } + +// Ok(results) +// } + +// let (start, end) = search_wavelet_tree_file(reader, wavelet_offset, wavelet_size query).await?; + +// if start == -1 || end == -1 { +// info!("no matches"); +// return Ok(vec![usize::MAX]); +// } + +// let start = start as usize; +// let end = end as usize; + +// if false { +// // (end - start > GIVEUP) { +// warn!("too many matches, giving up"); +// Ok(vec![usize::MAX]) +// } else { +// let matched_pos = batch_log_idx_lookup(&chunk_offsets, log_idx_reader, start, end).await?; + +// info!("start: {}", start); +// info!("end: {}", end); + +// Ok(matched_pos) +// } +// } + +// async fn search_hawaii( +// reader: &mut AsyncReader, +// file_size: usize, +// types: Vec, +// query: String, +// ) -> Result>, LavaError> { +// // Read metadata page size +// let metadata_page_length = reader.read_usize_from_end(1).await?[0]; + +// // Read and decompress metadata page +// let decompressed_metadata_page: Vec = read_and_decompress( +// reader, +// file_size as u64 - metadata_page_length as u64 - std::mem::size_of::() as u64, +// metadata_page_length as u64, +// ) +// .await?; + +// // Parse metadata +// let mut offset = 0; +// let num_types = u64::from_le_bytes(decompressed_metadata_page[offset..offset + 8].try_into().unwrap()) as usize; +// offset += 8; +// println!("num types: {}", num_types); + +// let num_groups = u64::from_le_bytes(decompressed_metadata_page[offset..offset + 8].try_into().unwrap()) as usize; +// offset += 8; +// println!("num groups: {}", num_groups); + +// let type_order: Vec = (0..num_types) +// .map(|i| { +// let start = offset + i * 8; +// let type_value = i32::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()); +// println!("type order: {}", type_value); +// type_value +// }) +// .collect(); +// offset += num_types * 8; + +// let chunks_in_group: Vec = (0..num_types) +// .map(|i| { +// let start = offset + i * 8; +// let chunks = usize::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()); +// println!("chunks in group: {}", chunks); +// chunks +// }) +// .collect(); +// offset += num_types * 8; + +// let type_offsets: Vec = (0..=num_types) +// .map(|i| { +// let start = offset + i * 8; +// let type_offset = usize::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()); +// println!("type offsets: {}", type_offset); +// type_offset 
+// }) +// .collect(); +// offset += (num_types + 1) * 8; + +// let group_offsets: Vec = (0..num_groups * 2 + 1) +// .map(|i| { +// let start = offset + i * 8; +// let group_offset = usize::from_le_bytes(decompressed_metadata_page[start..start + 8].try_into().unwrap()); +// println!("group offsets: {}", group_offset); +// group_offset +// }) +// .collect(); + +// let mut set = JoinSet::new(); +// for &type_value in &types { +// let reader_clone = reader.clone(); // Assuming AsyncReader implements Clone +// let query_clone = query.clone(); +// let type_order_clone = type_order.clone(); +// let chunks_in_group_clone = chunks_in_group.clone(); +// let type_offsets_clone = type_offsets.clone(); +// let group_offsets_clone = group_offsets.clone(); + +// set.spawn(async move { +// let type_index = type_order.iter().position(|&x| x == type_value).unwrap_or(num_types); + +// if type_index == num_types { +// return Ok((type_value, HashSet::from([usize::MAX]))); +// } + +// let chunks_in_group_for_type = chunks_in_group[type_index]; +// let type_offset = type_offsets[type_index]; +// let num_iters = type_offsets[type_index + 1] - type_offsets[type_index]; + +// println!("searching wavelet tree {} {}", type_value, num_iters); + +// let mut chunks = HashSet::new(); +// for i in (type_offset..type_offset + num_iters).step_by(2) { +// // Random delay +// sleep(Duration::from_millis(rand::thread_rng().gen_range(0..1000))).await; + +// let group_id = (i - type_offset) / 2; +// let group_chunk_offset = group_id * chunks_in_group_for_type; + +// let wavelet_offset = group_offsets[i]; +// let logidx_offset = group_offsets[i + 1]; +// let next_wavelet_offset = group_offsets[i + 2]; +// let wavelet_size = logidx_offset - wavelet_offset; +// let logidx_size = next_wavelet_offset - logidx_offset; + +// let matched_pos = search_vfr(reader, wavelet_offset, logidx_offset, &query)?; + +// for pos in matched_pos { +// chunks.insert(group_chunk_offset + pos); +// } +// } + +// Ok((type_value, chunks)) +// }); +// } + +// let mut type_chunks = HashMap::new(); +// while let Some(result) = set.join_next().await { +// match result { +// Ok(Ok((type_value, chunks))) => { +// type_chunks.insert(type_value, chunks); +// } +// Ok(Err(e)) => eprintln!("Error processing type: {:?}", e), +// Err(e) => eprintln!("Task join error: {:?}", e), +// } +// } + +// Ok(type_chunks) +// } + +// impl LogCloud { +// pub fn new(kauai: AsyncReader, oahu: AsyncReader, hawaii: AsyncReader) -> Self { +// Self { kauai, oahu, hawaii } +// } +// } diff --git a/src/lava/logcloud_plist.rs b/src/lava/logcloud_plist.rs new file mode 100644 index 0000000..0ab3a9e --- /dev/null +++ b/src/lava/logcloud_plist.rs @@ -0,0 +1,147 @@ +use std::convert::TryInto; +use std::io::{Read, Write}; +use std::mem::size_of; +use zstd::stream::read::Decoder; +use zstd::stream::write::Encoder; + +pub type PlistSize = u32; + +pub struct PListChunk { + data: Vec>, +} + +impl PListChunk { + pub fn new(data: Vec>) -> Self { + PListChunk { data } + } + + pub fn from_compressed(compressed_data: &[u8]) -> Result> { + let mut decoder = Decoder::new(compressed_data)?; + let mut data = Vec::new(); + decoder.read_to_end(&mut data)?; + + let num_posting_lists = PlistSize::from_le_bytes(data[data.len() - size_of::()..].try_into()?); + + let mut bit_array = Vec::with_capacity(num_posting_lists as usize); + for i in 0..num_posting_lists { + bit_array.push((data[i as usize / 8] & (1 << (i % 8))) != 0); + } + + let mut cursor = (num_posting_lists as usize + 7) / 8; + + let mut count_array 
= Vec::with_capacity(num_posting_lists as usize);
+        for &bit in &bit_array {
+            if bit {
+                count_array.push(PlistSize::from_le_bytes(data[cursor..cursor + size_of::<PlistSize>()].try_into()?));
+                cursor += size_of::<PlistSize>();
+            } else {
+                count_array.push(1);
+            }
+        }
+
+        let mut plist_data = Vec::with_capacity(num_posting_lists as usize);
+        for &count in &count_array {
+            let mut posting_list = Vec::with_capacity(count as usize);
+            for _ in 0..count {
+                posting_list.push(PlistSize::from_le_bytes(data[cursor..cursor + size_of::<PlistSize>()].try_into()?));
+                cursor += size_of::<PlistSize>();
+            }
+            plist_data.push(posting_list);
+        }
+
+        Ok(PListChunk { data: plist_data })
+    }
+
+    pub fn serialize(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
+        let num_posting_lists = self.data.len();
+        let bit_array_size = (num_posting_lists + 7) / 8;
+        let count_array_size = self.data.iter().filter(|list| list.len() > 1).count() * size_of::<PlistSize>();
+        let posting_lists_size = self.data.iter().map(|list| list.len() * size_of::<PlistSize>()).sum::<usize>();
+        let size = bit_array_size + count_array_size + posting_lists_size + size_of::<PlistSize>();
+
+        let mut serialized = vec![0u8; size];
+
+        for (i, list) in self.data.iter().enumerate() {
+            if list.len() > 1 {
+                serialized[i / 8] |= 1 << (i % 8);
+            }
+        }
+
+        let mut cursor = bit_array_size;
+
+        for list in &self.data {
+            if list.len() > 1 {
+                serialized[cursor..cursor + size_of::<PlistSize>()]
+                    .copy_from_slice(&(list.len() as PlistSize).to_le_bytes());
+                cursor += size_of::<PlistSize>();
+            }
+        }
+
+        for list in &self.data {
+            for &item in list {
+                serialized[cursor..cursor + size_of::<PlistSize>()].copy_from_slice(&item.to_le_bytes());
+                cursor += size_of::<PlistSize>();
+            }
+        }
+
+        serialized[cursor..].copy_from_slice(&(num_posting_lists as PlistSize).to_le_bytes());
+
+        let mut encoder = Encoder::new(Vec::new(), 0)?;
+        encoder.write_all(&serialized)?;
+        Ok(encoder.finish()?)
+ } + + pub fn data(&self) -> &Vec> { + &self.data + } + + pub fn lookup(&self, key: usize) -> Option> { + self.data.get(key).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_and_data() { + let data = vec![vec![1, 2, 3], vec![4, 5], vec![6]]; + let chunk = PListChunk::new(data.clone()); + assert_eq!(chunk.data(), &data); + } + + #[test] + fn test_lookup() { + let data = vec![vec![1, 2, 3], vec![4, 5], vec![6]]; + let chunk = PListChunk::new(data); + assert_eq!(chunk.lookup(0), Some(vec![1, 2, 3])); + assert_eq!(chunk.lookup(1), Some(vec![4, 5])); + assert_eq!(chunk.lookup(2), Some(vec![6])); + assert_eq!(chunk.lookup(3), None); + } + + #[test] + fn test_serialize_and_from_compressed() -> Result<(), Box> { + let original_data = vec![vec![1, 2, 3], vec![4], vec![5, 6, 7, 8]]; + let chunk = PListChunk::new(original_data); + + let serialized = chunk.serialize()?; + let deserialized_chunk = PListChunk::from_compressed(&serialized)?; + + assert_eq!(chunk.data(), deserialized_chunk.data()); + Ok(()) + } + + #[test] + fn test_empty_and_large_lists() -> Result<(), Box> { + let original_data = vec![vec![], vec![1], vec![2, 3], (0..1000).collect()]; + let chunk = PListChunk::new(original_data); + + let serialized = chunk.serialize()?; + let deserialized_chunk = PListChunk::from_compressed(&serialized)?; + + assert_eq!(chunk.data(), deserialized_chunk.data()); + Ok(()) + } +} diff --git a/src/lava/merge.rs b/src/lava/merge.rs index d74a7aa..9aa891b 100644 --- a/src/lava/merge.rs +++ b/src/lava/merge.rs @@ -11,9 +11,7 @@ use std::sync::{Arc, Mutex}; use zstd::stream::encode_all; use zstd::stream::read::Decoder; -use crate::formats::readers::{ - get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ReaderType, -}; +use crate::formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ReaderType}; use crate::lava::constants::*; use crate::lava::error::LavaError; use crate::lava::fm_chunk::FMChunk; @@ -21,9 +19,7 @@ use crate::lava::plist::PListChunk; use crate::lava::trie::FastTrie; use std::collections::HashMap; -use crate::vamana::{ - access::InMemoryAccessMethodF32, merge_indexes_par, EuclideanF32, IndexParams, VamanaIndex, -}; +use crate::vamana::{access::InMemoryAccessMethodF32, merge_indexes_par, EuclideanF32, IndexParams, VamanaIndex}; // @Rain chore: we need to simplify all the iterator impls @@ -37,15 +33,8 @@ struct PListIterator { impl PListIterator { // take ownership of the data structures pub async fn new(mut reader: AsyncReader, plist_offsets: Vec) -> Result { - let plist_chunk = reader - .read_range_and_decompress(plist_offsets[0], plist_offsets[1]) - .await?; - Ok(Self { - reader: reader, - plist_offsets: plist_offsets, - current_chunk_offset: 0, - current_chunk: plist_chunk, - }) + let plist_chunk = reader.read_range_and_decompress(plist_offsets[0], plist_offsets[1]).await?; + Ok(Self { reader: reader, plist_offsets: plist_offsets, current_chunk_offset: 0, current_chunk: plist_chunk }) } pub async fn advance(&mut self) -> Result<(), LavaError> { @@ -68,18 +57,13 @@ struct FMChunkIterator { reader: AsyncReader, fm_chunk_offsets: Vec, current_chunk_offset: usize, - pub current_chunk: FMChunk, + pub current_chunk: FMChunk, } impl FMChunkIterator { // take ownership of the data structures - pub async fn new( - mut reader: AsyncReader, - fm_chunk_offsets: Vec, - ) -> Result { - let buffer3 = reader - .read_range(fm_chunk_offsets[0], fm_chunk_offsets[1]) - .await?; + pub async fn new(mut reader: AsyncReader, 
fm_chunk_offsets: Vec) -> Result { + let buffer3 = reader.read_range(fm_chunk_offsets[0], fm_chunk_offsets[1]).await?; let current_chunk = FMChunk::new(buffer3)?; Ok(Self { @@ -109,11 +93,8 @@ impl FMChunkIterator { } pub async fn reset(&mut self) -> Result<(), LavaError> { - self.current_chunk = FMChunk::new( - self.reader - .read_range(self.fm_chunk_offsets[0], self.fm_chunk_offsets[1]) - .await?, - )?; + self.current_chunk = + FMChunk::new(self.reader.read_range(self.fm_chunk_offsets[0], self.fm_chunk_offsets[1]).await?)?; self.current_chunk_offset = 0; Ok(()) @@ -138,12 +119,9 @@ impl PListChunkIterator { ) -> Result { // read the first chunk - let buffer3 = reader - .read_range(plist_offsets[0], plist_offsets[1]) - .await?; + let buffer3 = reader.read_range(plist_offsets[0], plist_offsets[1]).await?; let result: Vec> = - PListChunk::search_compressed(buffer3.to_vec(), &(0..plist_elems[1]).collect()) - .unwrap(); + PListChunk::search_compressed(buffer3.to_vec(), &(0..plist_elems[1]).collect()).unwrap(); Ok(Self { reader: reader, @@ -179,8 +157,7 @@ impl PListChunkIterator { self.current_chunk = PListChunk::search_compressed( buffer3.to_vec(), - &(0..(self.plist_elems[self.current_chunk_offset + 1] - - self.plist_elems[self.current_chunk_offset])) + &(0..(self.plist_elems[self.current_chunk_offset + 1] - self.plist_elems[self.current_chunk_offset])) .collect(), ) .unwrap(); @@ -200,10 +177,8 @@ async fn merge_lava_uuid( assert_eq!(lava_files.len(), 2); assert_eq!(uid_offsets.len(), 2); - let (file_size1, mut reader1) = - get_file_size_and_reader(lava_files[0].clone(), reader_type.clone()).await?; - let (file_size2, mut reader2) = - get_file_size_and_reader(lava_files[1].clone(), reader_type.clone()).await?; + let (file_size1, mut reader1) = get_file_size_and_reader(lava_files[0].clone(), reader_type.clone()).await?; + let (file_size2, mut reader2) = get_file_size_and_reader(lava_files[1].clone(), reader_type.clone()).await?; // let buffer: bytes::Bytes = reader1.read_range(0, file_size1 as u64).await?; // let mut fast_trie1 = FastTrie::deserialize(buffer.to_vec()); @@ -262,13 +237,11 @@ async fn merge_lava_bm25( let num_documents = results[2]; total_num_documents += num_documents; - let compressed_token_counts = reader - .read_range(compressed_term_dict_offset, compressed_plist_offsets_offset) - .await?; + let compressed_token_counts = + reader.read_range(compressed_term_dict_offset, compressed_plist_offsets_offset).await?; let mut decompressed_token_counts: Vec = Vec::new(); - let mut decompressor: Decoder<'_, BufReader<&[u8]>> = - Decoder::new(&compressed_token_counts[..])?; + let mut decompressor: Decoder<'_, BufReader<&[u8]>> = Decoder::new(&compressed_token_counts[..])?; decompressor.read_to_end(&mut decompressed_token_counts)?; let token_counts: Vec = bincode::deserialize(&decompressed_token_counts)?; @@ -281,16 +254,12 @@ async fn merge_lava_bm25( } } - let buffer2 = reader - .read_range(compressed_plist_offsets_offset, file_size - 24) - .await?; + let buffer2 = reader.read_range(compressed_plist_offsets_offset, file_size - 24).await?; decompressor = Decoder::new(&buffer2[..])?; - let mut decompressed_serialized_plist_offsets: Vec = - Vec::with_capacity(buffer2.len() as usize); + let mut decompressed_serialized_plist_offsets: Vec = Vec::with_capacity(buffer2.len() as usize); decompressor.read_to_end(&mut decompressed_serialized_plist_offsets)?; - let this_plist_offsets: Vec = - bincode::deserialize(&decompressed_serialized_plist_offsets)?; + let this_plist_offsets: Vec = 
bincode::deserialize(&decompressed_serialized_plist_offsets)?; if (this_plist_offsets.len() % 2) != 0 { let err = LavaError::Parse("data corruption".to_string()); @@ -299,8 +268,7 @@ async fn merge_lava_bm25( let num_elements = this_plist_offsets.len() / 2; let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; - let this_compressed_tokenizer: bytes::Bytes = - reader.read_range(8, 8 + compressed_tokenizer_size).await?; + let this_compressed_tokenizer: bytes::Bytes = reader.read_range(8, 8 + compressed_tokenizer_size).await?; match &compressed_tokenizer { Some(value) => assert!( @@ -380,21 +348,16 @@ async fn merge_lava_bm25( output_file.write(&compressed_token_counts)?; let serialized = bincode::serialize(&new_plist_offsets).unwrap(); - let compressed_plist_offsets = - encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); + let compressed_plist_offsets = encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); - let compressed_plist_offsets_offset = - compressed_term_dict_offset + compressed_token_counts.len() as u64; + let compressed_plist_offsets_offset = compressed_term_dict_offset + compressed_token_counts.len() as u64; output_file.write(&compressed_plist_offsets)?; output_file.write(&(compressed_term_dict_offset as u64).to_le_bytes())?; output_file.write(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; output_file.write(&(total_num_documents as u64).to_le_bytes())?; - Ok(vec![( - compressed_term_dict_offset as usize, - output_file.seek(SeekFrom::Current(0))? as usize, - )]) + Ok(vec![(compressed_term_dict_offset as usize, output_file.seek(SeekFrom::Current(0))? as usize)]) } async fn compute_interleave( @@ -491,8 +454,7 @@ async fn merge_lava_substring( // instead of bothering with wrapping this thing in Arc>. 
Lots of tech debt to clean up // needed for the FMChunkIterator and PListIterator let (_, mut reader) = get_file_size_and_reader(file.clone(), reader_type.clone()).await?; - let (file_size, reader1) = - get_file_size_and_reader(file.clone(), reader_type.clone()).await?; + let (file_size, reader1) = get_file_size_and_reader(file.clone(), reader_type.clone()).await?; let file_size = file_size as u64; let results = reader.read_usize_from_end(4).await?; @@ -504,8 +466,7 @@ async fn merge_lava_substring( ns.push(n); let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; - let this_compressed_tokenizer: bytes::Bytes = - reader.read_range(8, 8 + compressed_tokenizer_size).await?; + let this_compressed_tokenizer: bytes::Bytes = reader.read_range(8, 8 + compressed_tokenizer_size).await?; match &compressed_tokenizer { Some(value) => assert!( @@ -515,15 +476,12 @@ async fn merge_lava_substring( None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), } - let fm_chunk_offsets: Vec = reader - .read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset) - .await?; - let posting_list_offsets: Vec = reader - .read_range_and_decompress(posting_list_offsets_offset, total_counts_offset) - .await?; - let cumulative_counts: Vec = reader - .read_range_and_decompress(total_counts_offset, (file_size - 32) as u64) - .await?; + let fm_chunk_offsets: Vec = + reader.read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset).await?; + let posting_list_offsets: Vec = + reader.read_range_and_decompress(posting_list_offsets_offset, total_counts_offset).await?; + let cumulative_counts: Vec = + reader.read_range_and_decompress(total_counts_offset, (file_size - 32) as u64).await?; // println!("{} {}", file, cumulative_counts.len()); @@ -623,21 +581,16 @@ async fn merge_lava_substring( for i in 0..bwt_output.len() { let current_tok = bwt_output[i]; - next_chunk_counts - .entry(current_tok) - .and_modify(|count| *count += 1) - .or_insert(1); + next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); current_chunk.push(current_tok); if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt_output.len() - 1 { let serialized_counts = bincode::serialize(¤t_chunk_counts)?; - let compressed_counts = - encode_all(&serialized_counts[..], 0).expect("Compression failed"); + let compressed_counts = encode_all(&serialized_counts[..], 0).expect("Compression failed"); output_file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; output_file.write_all(&compressed_counts)?; let serialized_chunk = bincode::serialize(¤t_chunk)?; - let compressed_chunk = - encode_all(&serialized_chunk[..], 0).expect("Compression failed"); + let compressed_chunk = encode_all(&serialized_chunk[..], 0).expect("Compression failed"); output_file.write_all(&compressed_chunk)?; fm_chunk_offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); current_chunk_counts = next_chunk_counts.clone(); @@ -645,8 +598,7 @@ async fn merge_lava_substring( } } - let mut posting_list_offsets: Vec = - vec![output_file.seek(SeekFrom::Current(0))? as usize]; + let mut posting_list_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))? as usize]; for i in (0..index_output.len()).step_by(FM_CHUNK_TOKS) { let slice = &index_output[i..std::cmp::min(index_output.len(), i + FM_CHUNK_TOKS)]; @@ -660,8 +612,7 @@ async fn merge_lava_substring( let fm_chunk_offsets_offset = output_file.seek(SeekFrom::Current(0))? 
as usize; let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; - let compressed_fm_chunk_offsets = - encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); output_file.write_all(&compressed_fm_chunk_offsets)?; let posting_list_offsets_offset = output_file.seek(SeekFrom::Current(0))? as usize; @@ -672,8 +623,7 @@ async fn merge_lava_substring( let total_counts_offset = output_file.seek(SeekFrom::Current(0))? as usize; let serialized_total_counts = bincode::serialize(&combined_cumulative_counts)?; - let compressed_total_counts: Vec = - encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); output_file.write_all(&compressed_total_counts)?; output_file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; @@ -681,10 +631,7 @@ async fn merge_lava_substring( output_file.write_all(&(total_counts_offset as u64).to_le_bytes())?; output_file.write_all(&(bwt_output.len() as u64).to_le_bytes())?; - Ok(vec![( - cache_start, - output_file.seek(SeekFrom::Current(0))? as usize, - )]) + Ok(vec![(cache_start, output_file.seek(SeekFrom::Current(0))? as usize)]) } #[async_recursion] @@ -718,34 +665,17 @@ async fn async_parallel_merge_files( let merged_files_shared = Arc::new(Mutex::new(vec![])); let new_uid_offsets_shared = Arc::new(Mutex::new(vec![])); - let chunked_files: Vec> = files - .into_iter() - .chunks(k) - .into_iter() - .map(|chunk| chunk.collect()) - .collect(); - - let chunked_uid_offsets: Vec> = uid_offsets - .into_iter() - .chunks(k) - .into_iter() - .map(|chunk| chunk.collect()) - .collect(); - - for (file_chunk, uid_chunk) in chunked_files - .into_iter() - .zip(chunked_uid_offsets.into_iter()) - { + let chunked_files: Vec> = + files.into_iter().chunks(k).into_iter().map(|chunk| chunk.collect()).collect(); + + let chunked_uid_offsets: Vec> = + uid_offsets.into_iter().chunks(k).into_iter().map(|chunk| chunk.collect()).collect(); + + for (file_chunk, uid_chunk) in chunked_files.into_iter().zip(chunked_uid_offsets.into_iter()) { if file_chunk.len() == 1 { // If there's an odd file out, directly move it to the next level - merged_files_shared - .lock() - .unwrap() - .push(file_chunk[0].clone()); - new_uid_offsets_shared - .lock() - .unwrap() - .push(uid_chunk[0].clone()); + merged_files_shared.lock().unwrap().push(file_chunk[0].clone()); + new_uid_offsets_shared.lock().unwrap().push(uid_chunk[0].clone()); continue; } @@ -811,22 +741,15 @@ async fn async_parallel_merge_files( } // Wait for all tasks to complete, MUST BE IN ORDER due to cache_ranges! 
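// ---------------------------------------------------------------------------
// Illustrative aside (not part of the diff). The ordering requirement in the
// comment above rests on futures::future::join_all, which yields its outputs
// in the same order as the input futures even when they complete out of
// order. A minimal illustration of that guarantee (tokio and futures crates
// assumed):
async fn join_all_preserves_order() {
    use std::time::Duration;
    let tasks = vec![
        tokio::spawn(async { tokio::time::sleep(Duration::from_millis(30)).await; 0u32 }),
        tokio::spawn(async { tokio::time::sleep(Duration::from_millis(10)).await; 1u32 }),
        tokio::spawn(async { tokio::time::sleep(Duration::from_millis(20)).await; 2u32 }),
    ];
    let results: Vec<u32> = futures::future::join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()   // each JoinHandle yields Result<T, JoinError>
        .unwrap();
    assert_eq!(results, vec![0, 1, 2]);   // input order, not completion order
}
// --------------------------------------------------------------------- (end aside)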
- let cache_ranges: Vec> = futures::future::join_all(tasks) - .await - .into_iter() - .collect::, _>>() - .unwrap(); + let cache_ranges: Vec> = + futures::future::join_all(tasks).await.into_iter().collect::, _>>().unwrap(); // Extract the merged files for the next level of merging - let merged_files: Vec = Arc::try_unwrap(merged_files_shared) - .expect("Lock still has multiple owners") - .into_inner() - .unwrap(); + let merged_files: Vec = + Arc::try_unwrap(merged_files_shared).expect("Lock still has multiple owners").into_inner().unwrap(); - let new_uid_offsets = Arc::try_unwrap(new_uid_offsets_shared) - .expect("Lock still has multiple owners") - .into_inner() - .unwrap(); + let new_uid_offsets = + Arc::try_unwrap(new_uid_offsets_shared).expect("Lock still has multiple owners").into_inner().unwrap(); // Recurse with the newly merged files async_parallel_merge_files( @@ -854,17 +777,9 @@ pub async fn parallel_merge_files( reader_type: ReaderType, ) -> Result, LavaError> { let do_not_delete = BTreeSet::from_iter(files.clone().into_iter()); - let result = async_parallel_merge_files( - condensed_lava_file, - files, - do_not_delete, - uid_offsets, - k, - mode, - reader_type, - None, - ) - .await?; + let result = + async_parallel_merge_files(condensed_lava_file, files, do_not_delete, uid_offsets, k, mode, reader_type, None) + .await?; Ok(result) } @@ -890,10 +805,7 @@ mod tests { pub fn test_merge_lava_substring() { let res = parallel_merge_files( "merged.lava".to_string(), - vec![ - "chinese_index/0.lava".to_string(), - "chinese_index/1.lava".to_string(), - ], + vec!["chinese_index/0.lava".to_string(), "chinese_index/1.lava".to_string()], vec![0, 1000000], 2, 1, diff --git a/src/lava/mod.rs b/src/lava/mod.rs index 45e6052..9f9bc5c 100644 --- a/src/lava/mod.rs +++ b/src/lava/mod.rs @@ -2,20 +2,21 @@ mod build; mod constants; pub mod error; mod fm_chunk; +mod logcloud; +mod logcloud_plist; mod merge; mod plist; mod search; mod trie; -pub use build::build_lava_uuid; pub use build::build_lava_bm25; -pub use build::build_lava_kmer; pub use build::build_lava_substring; +pub use build::build_lava_uuid; pub use merge::parallel_merge_files; pub use search::get_tokenizer_vocab; pub use search::search_lava_bm25; pub use search::search_lava_substring; +pub use search::search_lava_uuid; pub use search::search_lava_vector; -pub use search::search_lava_uuid; \ No newline at end of file diff --git a/src/lava/search.rs b/src/lava/search.rs index 0ef25a1..8d6e0bf 100644 --- a/src/lava/search.rs +++ b/src/lava/search.rs @@ -31,22 +31,21 @@ enum QueryParam { Substring(Vec>), Uuid(String), } - -async fn get_tokenizer_async( - mut readers: Vec, -) -> Result<(Tokenizer, Vec), LavaError> { +use std::fmt::Debug; +async fn get_tokenizer_async(mut readers: Vec) -> Result<(Tokenizer, Vec), LavaError> { let mut compressed_tokenizer: Option> = None; for i in 0..readers.len() { // now interpret this as a usize // readers[i].seek(SeekFrom::Start(0)).await?; let compressed_tokenizer_size = readers[i].read_usize_from_start(0, 1).await?[0]; - let this_compressed_tokenizer: bytes::Bytes = readers[i] - .read_range(8, 8 + compressed_tokenizer_size) - .await?; + let this_compressed_tokenizer: bytes::Bytes = readers[i].read_range(8, 8 + compressed_tokenizer_size).await?; match &compressed_tokenizer { - Some(value) => assert!(this_compressed_tokenizer == value, "detected different tokenizers between different lava files, can't search across them."), - None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()) + 
Some(value) => assert!( + this_compressed_tokenizer == value, + "detected different tokenizers between different lava files, can't search across them." + ), + None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), } } @@ -66,15 +65,31 @@ async fn get_tokenizer_async( Ok((tokenizer, result)) } -async fn process_substring_query( - query: Vec, +use num_traits::{AsPrimitive, PrimInt, Unsigned}; +use serde::{Deserialize, Serialize}; +use std::ops::Add; + +async fn process_substring_query( + query: Vec, n: u64, fm_chunk_offsets: &[u64], cumulative_counts: &[u64], posting_list_offsets: &[u64], reader: &mut AsyncReader, file_id: u64, -) -> Vec<(u64, u64)> { +) -> Vec<(u64, u64)> +where + T: PrimInt + + Unsigned + + Serialize + + for<'de> Deserialize<'de> + + Clone + + Eq + + std::hash::Hash + + AsPrimitive + + 'static, + usize: AsPrimitive, +{ let mut res: Vec<(u64, u64)> = vec![]; let mut start: usize = 0; let mut end: usize = n as usize; @@ -90,16 +105,10 @@ async fn process_substring_query( let end_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS + 1]; let end_chunk = reader.read_range(start_byte, end_byte).await.unwrap(); - start = cumulative_counts[current_token as usize] as usize - + FMChunk::new(start_chunk) - .unwrap() - .search(current_token, start % FM_CHUNK_TOKS) - .unwrap() as usize; - end = cumulative_counts[current_token as usize] as usize - + FMChunk::new(end_chunk) - .unwrap() - .search(current_token, end % FM_CHUNK_TOKS) - .unwrap() as usize; + start = cumulative_counts[current_token.as_()] as usize + + FMChunk::::new(start_chunk).unwrap().search(current_token, start % FM_CHUNK_TOKS).unwrap() as usize; + end = cumulative_counts[current_token.as_()] as usize + + FMChunk::::new(end_chunk).unwrap().search(current_token, end % FM_CHUNK_TOKS).unwrap() as usize; if start >= end { return res; @@ -121,35 +130,23 @@ async fn process_substring_query( for i in 0..total_chunks { let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i]; let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1]; - let this_chunk = plist_chunks - [(this_start - start_offset) as usize..(this_end - start_offset) as usize] - .to_vec(); + let this_chunk = + plist_chunks[(this_start - start_offset) as usize..(this_end - start_offset) as usize].to_vec(); chunk_set.spawn(async move { let mut decompressor = Decoder::new(&this_chunk[..]).unwrap(); let mut serialized_plist_chunk: Vec = Vec::with_capacity(this_chunk.len()); - decompressor - .read_to_end(&mut serialized_plist_chunk) - .unwrap(); + decompressor.read_to_end(&mut serialized_plist_chunk).unwrap(); let plist_chunk: Vec = bincode::deserialize(&serialized_plist_chunk).unwrap(); let chunk_res: Vec<(u64, u64)> = if i == 0 { if total_chunks == 1 { - plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS] - .iter() - .map(|&uid| (file_id, uid)) - .collect() + plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS].iter().map(|&uid| (file_id, uid)).collect() } else { - plist_chunk[start % FM_CHUNK_TOKS..] 
- .iter() - .map(|&uid| (file_id, uid)) - .collect() + plist_chunk[start % FM_CHUNK_TOKS..].iter().map(|&uid| (file_id, uid)).collect() } } else if i == total_chunks - 1 { - plist_chunk[..end % FM_CHUNK_TOKS] - .iter() - .map(|&uid| (file_id, uid)) - .collect() + plist_chunk[..end % FM_CHUNK_TOKS].iter().map(|&uid| (file_id, uid)).collect() } else { plist_chunk.iter().map(|&uid| (file_id, uid)).collect() }; @@ -165,14 +162,26 @@ async fn process_substring_query( res } -async fn search_substring_one_file( +async fn search_substring_one_file( file_id: u64, mut reader: AsyncReader, file_size: usize, - queries: Vec>, -) -> Result, LavaError> { - // println!("executing on thread {:?}", std::thread::current().id()); - + queries: Vec>, +) -> Result, LavaError> +where + T: PrimInt + + Unsigned + + Serialize + + for<'de> Deserialize<'de> + + Clone + + Eq + + std::hash::Hash + + AsPrimitive + + Debug + + Send + + 'static, + usize: AsPrimitive, +{ println!("{:?}", queries); let results = reader.read_usize_from_end(4).await?; @@ -181,15 +190,12 @@ async fn search_substring_one_file( let total_counts_offset = results[2]; let n = results[3]; - let fm_chunk_offsets: Vec = reader - .read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset) - .await?; - let posting_list_offsets: Vec = reader - .read_range_and_decompress(posting_list_offsets_offset, total_counts_offset) - .await?; - let cumulative_counts: Vec = reader - .read_range_and_decompress(total_counts_offset, (file_size - 32) as u64) - .await?; + let fm_chunk_offsets: Vec = + reader.read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset).await?; + let posting_list_offsets: Vec = + reader.read_range_and_decompress(posting_list_offsets_offset, total_counts_offset).await?; + let cumulative_counts: Vec = + reader.read_range_and_decompress(total_counts_offset, (file_size - 32) as u64).await?; let mut query_set = JoinSet::new(); @@ -200,7 +206,7 @@ async fn search_substring_one_file( let mut reader = reader.clone(); query_set.spawn(async move { - process_substring_query( + process_substring_query::( query, n, &fm_chunk_offsets, @@ -230,8 +236,7 @@ async fn search_uuid_one_file( let mut result: Vec<(u64, u64)> = Vec::new(); let mut start_time = Instant::now(); - let this_result: Vec = - FastTrie::query_with_reader(file_size, &mut reader, &query).await?; + let this_result: Vec = FastTrie::query_with_reader(file_size, &mut reader, &query).await?; result.extend(this_result.iter().map(|x| (file_id, *x as u64))); // println!( @@ -257,20 +262,10 @@ async fn search_generic_async( match query { QueryParam::Substring(ref value) => { - join_set.spawn(search_substring_one_file( - file_id as u64, - reader, - file_size, - value.clone(), - )); + join_set.spawn(search_substring_one_file::(file_id as u64, reader, file_size, value.clone())); } QueryParam::Uuid(ref value) => { - join_set.spawn(search_uuid_one_file( - file_id as u64, - reader, - file_size, - value.clone(), - )); + join_set.spawn(search_uuid_one_file(file_id as u64, reader, file_size, value.clone())); } _ => panic!("invalid mode"), } @@ -320,23 +315,17 @@ async fn search_bm25_async( // now read the term dictionary let token_counts = readers[i] - .read_range_and_decompress( - compressed_term_dictionary_offset, - compressed_plist_offsets_offset, - ) + .read_range_and_decompress(compressed_term_dictionary_offset, compressed_plist_offsets_offset) .await?; for query_token in query_tokens.iter() { - total_token_counts.insert( - *query_token, - 
total_token_counts[query_token] + token_counts[*query_token as usize] as usize, - ); + total_token_counts + .insert(*query_token, total_token_counts[query_token] + token_counts[*query_token as usize] as usize); } total_documents += num_documents as usize; - let plist_offsets = readers[i] - .read_range_and_decompress(compressed_plist_offsets_offset, file_sizes[i] as u64 - 24) - .await?; + let plist_offsets = + readers[i].read_range_and_decompress(compressed_plist_offsets_offset, file_sizes[i] as u64 - 24).await?; if plist_offsets.len() % 2 != 0 { let err = LavaError::Parse("data corruption".to_string()); @@ -353,10 +342,7 @@ async fn search_bm25_async( Err(idx) => (idx - 1, tok - term_dict_len[idx - 1]), }; - chunks_to_search - .entry((i as usize, idx)) - .or_insert_with(Vec::new) - .push((*token, offset as u64)); + chunks_to_search.entry((i as usize, idx)).or_insert_with(Vec::new).push((*token, offset as u64)); } all_plist_offsets.push(plist_offsets); @@ -370,10 +356,7 @@ async fn search_bm25_async( idf.insert( query_token, query_weight - * ((total_documents as f32 - token_count as f32 + 0.5) - / (token_count as f32 + 0.5) - + 1.0) - .ln(), + * ((total_documents as f32 - token_count as f32 + 0.5) / (token_count as f32 + 0.5) + 1.0).ln(), ); } @@ -383,12 +366,10 @@ async fn search_bm25_async( let mut join_set: JoinSet, LavaError>> = JoinSet::new(); // need to parallelize this @Rain. for (file_id, chunk_id, tokens, offsets) in - chunks_to_search - .into_iter() - .map(|((file_id, chunk_id), token_offsets)| { - let (tokens, offsets): (Vec, Vec) = token_offsets.into_iter().unzip(); - (file_id, chunk_id, Arc::new(tokens), Arc::new(offsets)) - }) + chunks_to_search.into_iter().map(|((file_id, chunk_id), token_offsets)| { + let (tokens, offsets): (Vec, Vec) = token_offsets.into_iter().unzip(); + (file_id, chunk_id, Arc::new(tokens), Arc::new(offsets)) + }) { let reader_type = match readers[file_id].reader { ClonableAsyncReader::AwsSdk(_) => ReaderType::AwsSdk, @@ -399,10 +380,7 @@ async fn search_bm25_async( let mut reader = match reader_type { ReaderType::AwsSdk | ReaderType::Http => readers[file_id].clone(), ReaderType::Local => { - get_file_size_and_reader(readers[file_id].filename.clone(), reader_type) - .await - .unwrap() - .1 + get_file_size_and_reader(readers[file_id].filename.clone(), reader_type).await.unwrap().1 } }; let start = all_plist_offsets[file_id][chunk_id]; @@ -416,8 +394,7 @@ async fn search_bm25_async( // get all the second item in the offsets into its own vector - let results: Vec> = - PListChunk::search_compressed(buffer3.to_vec(), offsets.as_ref())?; + let results: Vec> = PListChunk::search_compressed(buffer3.to_vec(), offsets.as_ref())?; let mut res = vec![]; for (i, result) in results.iter().enumerate() { @@ -494,48 +471,18 @@ pub async fn search_lava_substring( let mut skip_tokens: HashSet = HashSet::new(); for char in SKIP.chars() { let char_str = char.to_string(); - skip_tokens.extend( - tokenizer - .encode(char_str.clone(), false) - .unwrap() - .get_ids() - .to_vec(), - ); - skip_tokens.extend( - tokenizer - .encode(format!(" {}", char_str), false) - .unwrap() - .get_ids() - .to_vec(), - ); - skip_tokens.extend( - tokenizer - .encode(format!("{} ", char_str), false) - .unwrap() - .get_ids() - .to_vec(), - ); + skip_tokens.extend(tokenizer.encode(char_str.clone(), false).unwrap().get_ids().to_vec()); + skip_tokens.extend(tokenizer.encode(format!(" {}", char_str), false).unwrap().get_ids().to_vec()); + skip_tokens.extend(tokenizer.encode(format!("{} ", char_str), 
false).unwrap().get_ids().to_vec()); } let lower: String = query.chars().flat_map(|c| c.to_lowercase()).collect(); let encoding = tokenizer.encode(lower, false).unwrap(); - let result: Vec = encoding - .get_ids() - .iter() - .filter(|id| !skip_tokens.contains(id)) - .cloned() - .collect(); + let result: Vec = encoding.get_ids().iter().filter(|id| !skip_tokens.contains(id)).cloned().collect(); let mut query: Vec> = if let Some(sample_factor) = sample_factor { (0..sample_factor) - .map(|offset| { - result - .iter() - .skip(offset) - .step_by(sample_factor) - .cloned() - .collect::>() - }) + .map(|offset| result.iter().skip(offset).step_by(sample_factor).cloned().collect::>()) .filter(|vec| !vec.is_empty()) .collect() } else { @@ -547,13 +494,7 @@ pub async fn search_lava_substring( if let Some(token_viable_limit) = token_viable_limit { query.iter_mut().for_each(|vec| { if vec.len() > token_viable_limit { - *vec = vec - .iter() - .rev() - .take(token_viable_limit) - .rev() - .cloned() - .collect(); + *vec = vec.iter().rev().take(token_viable_limit).rev().cloned().collect(); } }); } @@ -581,10 +522,7 @@ pub fn search_lava_vector( nprobes: usize, reader_type: ReaderType, ) -> Result<(Vec, Vec>, Vec<(usize, Array1)>), LavaError> { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); let res = rt.block_on(search_lava_vector_async(files, query, nprobes, reader_type)); rt.shutdown_background(); @@ -609,37 +547,29 @@ pub async fn search_lava_vector_async( futures.push(tokio::spawn(async move { let results = reader.read_usize_from_end(4).await.unwrap(); - let centroid_vectors_compressed_bytes = - reader.read_range(results[2], results[3]).await.unwrap(); + let centroid_vectors_compressed_bytes = reader.read_range(results[2], results[3]).await.unwrap(); // decompress them - let mut decompressor = - Decoder::new(centroid_vectors_compressed_bytes.as_ref()).unwrap(); - let mut centroid_vectors: Vec = - Vec::with_capacity(centroid_vectors_compressed_bytes.len() as usize); + let mut decompressor = Decoder::new(centroid_vectors_compressed_bytes.as_ref()).unwrap(); + let mut centroid_vectors: Vec = Vec::with_capacity(centroid_vectors_compressed_bytes.len() as usize); decompressor.read_to_end(&mut centroid_vectors).unwrap(); let centroid_vectors = bytes_to_f32_vec(¢roid_vectors); let num_vectors = centroid_vectors.len() / 128; - let array2 = - Array2::::from_shape_vec((num_vectors, 128), centroid_vectors).unwrap(); + let array2 = Array2::::from_shape_vec((num_vectors, 128), centroid_vectors).unwrap(); (num_vectors, array2) })); } - let result: Vec), tokio::task::JoinError>> = - futures::future::join_all(futures).await; + let result: Vec), tokio::task::JoinError>> = futures::future::join_all(futures).await; let end = Instant::now(); println!("Time stage 1 read: {:?}", end - start); let start = Instant::now(); - let arr_lens = result - .iter() - .map(|x| x.as_ref().unwrap().0) - .collect::>(); + let arr_lens = result.iter().map(|x| x.as_ref().unwrap().0).collect::>(); // get cumulative arr len starting from 0 let cumsum = arr_lens .iter() @@ -650,48 +580,24 @@ pub async fn search_lava_vector_async( .collect::>(); let arrays: Vec> = result.into_iter().map(|x| x.unwrap().1).collect(); - let centroids = concatenate( - Axis(0), - arrays - .iter() - .map(|array| array.view()) - .collect::>() - .as_slice(), - ) - .unwrap(); + let centroids = + concatenate(Axis(0), arrays.iter().map(|array| 
array.view()).collect::>().as_slice()).unwrap(); let query = Array1::::from_vec(query); let query_broadcast = query.broadcast(centroids.dim()).unwrap(); let difference = ¢roids - &query_broadcast; let norms = difference.map_axis(Axis(1), |row| row.dot(&row).sqrt()); - let mut indices_and_values: Vec<(usize, f32)> = norms - .iter() - .enumerate() - .map(|(idx, &val)| (idx, val)) - .collect(); + let mut indices_and_values: Vec<(usize, f32)> = norms.iter().enumerate().map(|(idx, &val)| (idx, val)).collect(); indices_and_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)); - let smallest_indices: Vec = indices_and_values - .iter() - .map(|&(idx, _)| idx) - .take(nprobes) - .collect(); + let smallest_indices: Vec = indices_and_values.iter().map(|&(idx, _)| idx).take(nprobes).collect(); let mut file_indices: Vec> = vec![vec![]; files.len()]; for idx in smallest_indices.iter() { // figure out which file idx based on cumsum. need to find the index of the thing that is just bigger than idx - let file_idx = cumsum - .iter() - .enumerate() - .find(|(_, &val)| val > *idx) - .unwrap() - .0; - let last_cumsum = if file_idx == 0 { - 0 - } else { - cumsum[file_idx - 1] - }; + let file_idx = cumsum.iter().enumerate().find(|(_, &val)| val > *idx).unwrap().0; + let last_cumsum = if file_idx == 0 { 0 } else { cumsum[file_idx - 1] }; let remainder = idx - last_cumsum; file_indices[file_idx].push(remainder); } @@ -719,14 +625,11 @@ pub async fn search_lava_vector_async( let pq_bytes = reader.read_range(results[0], results[1]).await.unwrap(); - let compressed_centroid_offset_bytes = - reader.read_range(results[1], results[2]).await.unwrap(); + let compressed_centroid_offset_bytes = reader.read_range(results[1], results[2]).await.unwrap(); let mut decompressor = Decoder::new(compressed_centroid_offset_bytes.as_ref()).unwrap(); let mut centroid_offsets_bytes: Vec = Vec::with_capacity(compressed_centroid_offset_bytes.len() as usize); - decompressor - .read_to_end(&mut centroid_offsets_bytes) - .unwrap(); + decompressor.read_to_end(&mut centroid_offsets_bytes).unwrap(); // now reinterpret centroid_offsets_bytes as a Vec @@ -749,8 +652,7 @@ pub async fn search_lava_vector_async( let result: Vec, Array1), tokio::task::JoinError>> = futures::future::join_all(futures).await; - let result: Vec<(Vec<(usize, u64, u64)>, Array1)> = - result.into_iter().map(|x| x.unwrap()).collect(); + let result: Vec<(Vec<(usize, u64, u64)>, Array1)> = result.into_iter().map(|x| x.unwrap()).collect(); let pq_bytes: Vec> = result.iter().map(|x| x.1.clone()).collect::>(); @@ -758,9 +660,7 @@ pub async fn search_lava_vector_async( println!("Time stage 2 read: {:?}", end - start); let start = Instant::now(); - let reader = get_reader(files[file_ids[0]].clone(), reader_type.clone()) - .await - .unwrap(); + let reader = get_reader(files[file_ids[0]].clone(), reader_type.clone()).await.unwrap(); let mut futures = FuturesUnordered::new(); for i in 0..result.len() { @@ -795,10 +695,7 @@ pub async fn search_lava_vector_async( } #[tokio::main] -pub async fn get_tokenizer_vocab( - files: Vec, - reader_type: ReaderType, -) -> Result, LavaError> { +pub async fn get_tokenizer_vocab(files: Vec, reader_type: ReaderType) -> Result, LavaError> { let (_file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type).await?; Ok(get_tokenizer_async(readers).await?.1) } @@ -814,14 +711,9 @@ mod tests { pub fn test_search_lava_one() { let file = "msmarco_index/1.lava"; - let res = search_lava_bm25( - vec![file.to_string()], - vec![6300, 
15050], - vec![0.1, 0.2], - 10, - ReaderType::default(), - ) - .unwrap(); + let res = + search_lava_bm25(vec![file.to_string()], vec![6300, 15050], vec![0.1, 0.2], 10, ReaderType::default()) + .unwrap(); println!("{:?}", res); } diff --git a/src/lava/trie.rs b/src/lava/trie.rs index 7a2426c..068ac41 100644 --- a/src/lava/trie.rs +++ b/src/lava/trie.rs @@ -9,8 +9,7 @@ use std::{ use crate::{ formats::readers::{ - get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ClonableAsyncReader, - ReaderType, + get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ClonableAsyncReader, ReaderType, }, lava::error::LavaError, }; @@ -37,14 +36,7 @@ pub struct FastTrie { // Helper function to take a node and replace it with None fn take_leaf_node(node: &mut Box>) -> Box> { - std::mem::replace( - node, - Box::new(BinaryTrieNode { - data: Vec::new(), - left: None, - right: None, - }), - ) + std::mem::replace(node, Box::new(BinaryTrieNode { data: Vec::new(), left: None, right: None })) } impl FastTrie { @@ -89,11 +81,7 @@ impl FastTrie { println!("{} {:?} {:?}", k.len(), k, v.1); } - FastTrie { - root_lut, - leaf_tree_roots, - root_levels, - } + FastTrie { root_lut, leaf_tree_roots, root_levels } } // structure is serialized trie | serialized trie | ... | serialized (lut, offsets) | metadata page offset @@ -112,11 +100,8 @@ impl FastTrie { let metadata_page_offset = offsets[offsets.len() - 1]; - let metadata: ( - &BTreeMap, Option)>, - &Vec, - usize, - ) = (&self.root_lut, &offsets, self.root_levels); + let metadata: (&BTreeMap, Option)>, &Vec, usize) = + (&self.root_lut, &offsets, self.root_levels); let serialized_metadata = bincode::serialize(&metadata).unwrap(); let compressed = encode_all(&serialized_metadata[..], 10).unwrap(); @@ -156,19 +141,13 @@ impl FastTrie { }; let metadata_page_offset = reader.read_usize_from_end(1).await?[0]; - let metadata_page_bytes = reader - .read_range(metadata_page_offset, file_size as u64 - 8) - .await?; + let metadata_page_bytes = reader.read_range(metadata_page_offset, file_size as u64 - 8).await?; let mut decompressor = Decoder::new(&metadata_page_bytes[..]).unwrap(); - let mut serialized_metadata: Vec = - Vec::with_capacity(metadata_page_bytes.len() as usize); + let mut serialized_metadata: Vec = Vec::with_capacity(metadata_page_bytes.len() as usize); decompressor.read_to_end(&mut serialized_metadata).unwrap(); - let metadata: ( - BTreeMap, Option)>, - Vec, - usize, - ) = bincode::deserialize(&serialized_metadata[..]).unwrap(); + let metadata: (BTreeMap, Option)>, Vec, usize) = + bincode::deserialize(&serialized_metadata[..]).unwrap(); let lut: BTreeMap, Option)> = metadata.0; let offsets: Vec = metadata.1; let root_levels = metadata.2; @@ -188,11 +167,9 @@ impl FastTrie { let compressed_trie_bytes = reader.read_range(start as u64, end as u64).await?; let mut decompressor = Decoder::new(&compressed_trie_bytes[..]).unwrap(); - let mut serialized_trie: Vec = - Vec::with_capacity(compressed_trie_bytes.len() as usize); + let mut serialized_trie: Vec = Vec::with_capacity(compressed_trie_bytes.len() as usize); decompressor.read_to_end(&mut serialized_trie).unwrap(); - let trie: Box> = - bincode::deserialize(&serialized_trie[..]).unwrap(); + let trie: Box> = bincode::deserialize(&serialized_trie[..]).unwrap(); let result = trie.query(&query[root_levels / 8..]); return Ok(result); } @@ -205,28 +182,15 @@ impl FastTrie { async fn read_metadata( file_size: usize, reader: &mut AsyncReader, - ) -> Result< - ( - BTreeMap, Option)>, - Vec, - usize, - ), - 
LavaError, - > { + ) -> Result<(BTreeMap, Option)>, Vec, usize), LavaError> { let metadata_page_offset = reader.read_usize_from_end(1).await?[0]; - let metadata_page_bytes = reader - .read_range(metadata_page_offset, file_size as u64 - 8) - .await?; + let metadata_page_bytes = reader.read_range(metadata_page_offset, file_size as u64 - 8).await?; let mut decompressor = Decoder::new(&metadata_page_bytes[..]).unwrap(); - let mut serialized_metadata: Vec = - Vec::with_capacity(metadata_page_bytes.len() as usize); + let mut serialized_metadata: Vec = Vec::with_capacity(metadata_page_bytes.len() as usize); decompressor.read_to_end(&mut serialized_metadata).unwrap(); - let metadata: ( - BTreeMap, Option)>, - Vec, - usize, - ) = bincode::deserialize(&serialized_metadata[..]).unwrap(); + let metadata: (BTreeMap, Option)>, Vec, usize) = + bincode::deserialize(&serialized_metadata[..]).unwrap(); Ok(metadata) } @@ -259,10 +223,7 @@ impl FastTrie { let (mut lut1, offsets1, root_levels1) = Self::read_metadata(file_size1, reader1).await?; let (mut lut2, offsets2, root_levels2) = Self::read_metadata(file_size2, reader2).await?; - assert_eq!( - root_levels1, root_levels2, - "Root levels must be the same for both tries" - ); + assert_eq!(root_levels1, root_levels2, "Root levels must be the same for both tries"); for (_, v) in lut1.iter_mut() { v.0.iter_mut().for_each(|x| *x += uid_offset_0); @@ -284,13 +245,7 @@ impl FastTrie { // read the thing from lut1 match offset { Some(x) => { - let node = Self::read_and_adjust_node( - reader1, - offsets1[x], - offsets1[x + 1], - uid_offset_0, - ) - .await?; + let node = Self::read_and_adjust_node(reader1, offsets1[x], offsets1[x + 1], uid_offset_0).await?; let serialized_node = bincode::serialize(&node).unwrap(); let _ = output_file.write(&encode_all(&serialized_node[..], 10).unwrap()); offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); @@ -310,13 +265,7 @@ impl FastTrie { // read the thing from lut1 match offset { Some(x) => { - let node = Self::read_and_adjust_node( - reader2, - offsets2[x], - offsets2[x + 1], - uid_offset_1, - ) - .await?; + let node = Self::read_and_adjust_node(reader2, offsets2[x], offsets2[x + 1], uid_offset_1).await?; let serialized_node = bincode::serialize(&node).unwrap(); let _ = output_file.write(&encode_all(&serialized_node[..], 10).unwrap()); offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); @@ -336,20 +285,10 @@ impl FastTrie { let (values2, offset2) = lut2.get(key).unwrap(); match (*offset1, *offset2) { (Some(x1), Some(x2)) => { - let node1 = Self::read_and_adjust_node( - reader1, - offsets1[x1], - offsets1[x1 + 1], - uid_offset_0, - ) - .await?; - let node2 = Self::read_and_adjust_node( - reader2, - offsets2[x2], - offsets2[x2 + 1], - uid_offset_1, - ) - .await?; + let node1 = + Self::read_and_adjust_node(reader1, offsets1[x1], offsets1[x1 + 1], uid_offset_0).await?; + let node2 = + Self::read_and_adjust_node(reader2, offsets2[x2], offsets2[x2 + 1], uid_offset_1).await?; let mut node = node1; node.extend(*node2); @@ -363,13 +302,8 @@ impl FastTrie { } (Some(x1), None) => { - let node = Self::read_and_adjust_node( - reader1, - offsets1[x1], - offsets1[x1 + 1], - uid_offset_0, - ) - .await?; + let node = + Self::read_and_adjust_node(reader1, offsets1[x1], offsets1[x1 + 1], uid_offset_0).await?; let serialized_node = bincode::serialize(&node).unwrap(); let _ = output_file.write(&encode_all(&serialized_node[..], 10).unwrap()); offsets.push(output_file.seek(SeekFrom::Current(0))? 
as usize); @@ -380,13 +314,8 @@ impl FastTrie { } (None, Some(x2)) => { - let node = Self::read_and_adjust_node( - reader2, - offsets2[x2], - offsets2[x2 + 1], - uid_offset_1, - ) - .await?; + let node = + Self::read_and_adjust_node(reader2, offsets2[x2], offsets2[x2 + 1], uid_offset_1).await?; let serialized_node = bincode::serialize(&node).unwrap(); let _ = output_file.write(&encode_all(&serialized_node[..], 10).unwrap()); offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); @@ -408,11 +337,8 @@ impl FastTrie { println!("{:?}", root_lut); - let metadata: ( - &BTreeMap, Option)>, - &Vec, - usize, - ) = (&root_lut, &offsets, root_levels1); + let metadata: (&BTreeMap, Option)>, &Vec, usize) = + (&root_lut, &offsets, root_levels1); let serialized_metadata = bincode::serialize(&metadata).unwrap(); let compressed = encode_all(&serialized_metadata[..], 10).unwrap(); @@ -428,20 +354,15 @@ impl FastTrie { } pub fn deserialize(bytes: Vec) -> Self { - let metadata_page_offset = - u64::from_le_bytes(bytes[(bytes.len() - 8)..].try_into().unwrap()); + let metadata_page_offset = u64::from_le_bytes(bytes[(bytes.len() - 8)..].try_into().unwrap()); let metadata_page_bytes = &bytes[(metadata_page_offset as usize)..bytes.len() - 8]; let mut decompressor = Decoder::new(&metadata_page_bytes[..]).unwrap(); - let mut serialized_metadata: Vec = - Vec::with_capacity(metadata_page_bytes.len() as usize); + let mut serialized_metadata: Vec = Vec::with_capacity(metadata_page_bytes.len() as usize); decompressor.read_to_end(&mut serialized_metadata).unwrap(); - let metadata: ( - BTreeMap, Option)>, - Vec, - usize, - ) = bincode::deserialize(&serialized_metadata[..]).unwrap(); + let metadata: (BTreeMap, Option)>, Vec, usize) = + bincode::deserialize(&serialized_metadata[..]).unwrap(); let lut: BTreeMap, Option)> = metadata.0; let offsets: Vec = metadata.1; @@ -459,11 +380,7 @@ impl FastTrie { leaf_tree_roots.push(node); } - FastTrie { - root_lut: lut, - leaf_tree_roots: leaf_tree_roots, - root_levels: metadata.2, - } + FastTrie { root_lut: lut, leaf_tree_roots: leaf_tree_roots, root_levels: metadata.2 } } // extend and consume the second FastTrie. In memory method. @@ -495,8 +412,7 @@ impl FastTrie { let new_offset = match offset { Some(x) => { let idx = self.leaf_tree_roots.len(); - self.leaf_tree_roots - .push(take_leaf_node(&mut t2.leaf_tree_roots[*x])); + self.leaf_tree_roots.push(take_leaf_node(&mut t2.leaf_tree_roots[*x])); Some(idx) } None => None, @@ -519,16 +435,14 @@ impl FastTrie { v.1 = match (v.1, v2.1) { (Some(x), Some(y)) => { println!("merging leaf tree roots {:?} {:?}", x, y); - let owned_t2_leaf_tree_root = - take_leaf_node(&mut t2.leaf_tree_roots[y]); + let owned_t2_leaf_tree_root = take_leaf_node(&mut t2.leaf_tree_roots[y]); self.leaf_tree_roots[x].extend(*owned_t2_leaf_tree_root); Some(x) } (Some(x), None) => Some(x), (None, Some(y)) => { let idx = self.leaf_tree_roots.len(); - self.leaf_tree_roots - .push(take_leaf_node(&mut t2.leaf_tree_roots[y])); + self.leaf_tree_roots.push(take_leaf_node(&mut t2.leaf_tree_roots[y])); Some(idx) } (None, None) => None, @@ -547,11 +461,7 @@ impl FastTrie { impl BinaryTrieNode { pub fn new() -> BinaryTrieNode { - BinaryTrieNode { - left: None, - right: None, - data: Vec::new(), - } + BinaryTrieNode { left: None, right: None, data: Vec::new() } } /// Builds a binary trie from a list of strings and their corresponding indices. 
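Before the build_extra hunk below: the trie construction there leans on two small helpers, big-endian bit access into a byte string and the bit-level longest common prefix between adjacent sorted keys, which decides how deep two neighbouring keys share a path. A minimal self-contained sketch of just those two helpers, mirroring the diff's get_bit/lcp logic (the surrounding trie code is omitted):

/// Big-endian bit access into a byte string, as build_extra uses it.
fn get_bit(s: &[u8], i: usize) -> bool {
    (s[i / 8] >> (7 - i % 8)) & 1 == 1
}

/// Longest common prefix, in bits, between two keys; for adjacent keys in
/// sorted order this determines where their shared trie path ends.
fn lcp_bits(a: &[u8], b: &[u8]) -> usize {
    let max = 8 * a.len().min(b.len());
    let mut j = 0;
    while j < max && get_bit(a, j) == get_bit(b, j) {
        j += 1;
    }
    j
}

fn main() {
    // 'a' = 0110_0001 and 'b' = 0110_0010 share their first 6 bits
    assert_eq!(lcp_bits(b"a", b"b"), 6);
    assert_eq!(lcp_bits(b"ab", b"abz"), 16);
    println!("lcp ok");
}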
@@ -562,11 +472,7 @@ impl BinaryTrieNode { } /// Build, specifying extra bits - pub fn build_extra( - strs: &[Vec], - str_data: &[Vec], - extra_bits: usize, - ) -> BinaryTrieNode { + pub fn build_extra(strs: &[Vec], str_data: &[Vec], extra_bits: usize) -> BinaryTrieNode { // big endian let get_bit = |stri: usize, i: usize| -> bool { let chr = i / 8; @@ -580,10 +486,7 @@ impl BinaryTrieNode { // lcp[0] := lcp[n] := 0 for i in 0..strs.len() - 1 { let mut j = 0; - while j < strs[i].len() * 8 - && j < strs[i + 1].len() * 8 - && get_bit(i, j) == get_bit(i + 1, j) - { + while j < strs[i].len() * 8 && j < strs[i + 1].len() * 8 && get_bit(i, j) == get_bit(i + 1, j) { j += 1; } lcp[i + 1] = j; @@ -691,10 +594,7 @@ impl Default for BinaryTrieNode { /// Merges two tries into a new trie. /// Indices are kept as-is. -pub fn merge_tries( - t1: &BinaryTrieNode, - t2: &BinaryTrieNode, -) -> BinaryTrieNode { +pub fn merge_tries(t1: &BinaryTrieNode, t2: &BinaryTrieNode) -> BinaryTrieNode { let mut output = BinaryTrieNode::new(); output.data.extend(t1.data.clone()); output.data.extend(t2.data.clone()); @@ -704,10 +604,7 @@ pub fn merge_tries( } else if t2.left.is_none() { output.left = t1.left.clone(); } else { - output.left = Some(Box::new(merge_tries( - t1.left.as_ref().unwrap(), - t2.left.as_ref().unwrap(), - ))); + output.left = Some(Box::new(merge_tries(t1.left.as_ref().unwrap(), t2.left.as_ref().unwrap()))); } if t1.right.is_none() { @@ -715,10 +612,7 @@ pub fn merge_tries( } else if t2.right.is_none() { output.right = t1.right.clone(); } else { - output.right = Some(Box::new(merge_tries( - t1.right.as_ref().unwrap(), - t2.right.as_ref().unwrap(), - ))); + output.right = Some(Box::new(merge_tries(t1.right.as_ref().unwrap(), t2.right.as_ref().unwrap()))); } output
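A few self-contained sketches of the mechanisms this diff touches, in the order they appear above. First, the shape of async_parallel_merge_files: inputs are grouped into chunks of k, each group is merged into one output, a lone leftover is promoted unchanged to the next level, and the whole thing recurses until one file remains. This synchronous sketch drops the uid-offset bookkeeping, caching, deletion set, and error handling; merge_group is a hypothetical stand-in for the real per-mode merge.

/// Level-by-level k-way merge, assuming at least one input file.
fn merge_group(group: &[String], level: usize, idx: usize) -> String {
    // stand-in for the real merge of one group of .lava files
    format!("merged-l{}-{}.lava ({})", level, idx, group.join("+"))
}

fn parallel_merge(mut files: Vec<String>, k: usize) -> String {
    let mut level = 0;
    while files.len() > 1 {
        let mut next = Vec::new();
        for (idx, group) in files.chunks(k).enumerate() {
            if group.len() == 1 {
                // odd file out: carry it unchanged to the next level
                next.push(group[0].clone());
            } else {
                next.push(merge_group(group, level, idx));
            }
        }
        files = next;
        level += 1;
    }
    files.pop().unwrap()
}

fn main() {
    let files: Vec<String> = (0..5).map(|i| format!("{i}.lava")).collect();
    println!("{}", parallel_merge(files, 2));
}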
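merge_lava_substring writes the merged BWT as fixed-size chunks of FM_CHUNK_TOKS symbols, each preceded by a compressed snapshot of the symbol counts seen so far, so a rank query needs one snapshot plus a scan of one chunk. A minimal in-memory sketch, assuming this reading of the layout; compression and serialization are omitted, and chunk_toks stands in for FM_CHUNK_TOKS.

use std::collections::HashMap;

/// For every chunk of `chunk_toks` BWT symbols, store the counts of each symbol
/// seen *before* the chunk, next to the chunk itself.
fn chunk_bwt(bwt: &[u32], chunk_toks: usize) -> Vec<(HashMap<u32, u64>, Vec<u32>)> {
    let mut chunks = Vec::new();
    let mut counts_so_far: HashMap<u32, u64> = HashMap::new();
    for chunk in bwt.chunks(chunk_toks) {
        let snapshot = counts_so_far.clone();
        for &tok in chunk {
            *counts_so_far.entry(tok).or_insert(0) += 1;
        }
        chunks.push((snapshot, chunk.to_vec()));
    }
    chunks
}

/// rank(c, i): occurrences of `c` in bwt[..i], answered from one snapshot plus
/// a partial scan of one chunk.
fn rank(chunks: &[(HashMap<u32, u64>, Vec<u32>)], chunk_toks: usize, c: u32, i: usize) -> u64 {
    let (snapshot, chunk) = &chunks[i / chunk_toks];
    let within = chunk[..i % chunk_toks].iter().filter(|&&t| t == c).count() as u64;
    snapshot.get(&c).copied().unwrap_or(0) + within
}

fn main() {
    let bwt = vec![1, 2, 1, 3, 1, 2, 2, 1];
    let chunks = chunk_bwt(&bwt, 4);
    assert_eq!(rank(&chunks, 4, 1, 5), 3); // three 1s in bwt[..5]
    println!("rank ok");
}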
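process_substring_query is the FM-index backward-search loop: the query is walked back to front, and at each step the [start, end) suffix range is re-derived from cumulative_counts plus a rank lookup inside one FM chunk. The sketch below keeps the whole BWT in memory and fixes the token type to u32, where the real function is generic over the token width; c_table plays the role of cumulative_counts.

fn rank(bwt: &[u32], t: u32, i: usize) -> usize {
    bwt[..i].iter().filter(|&&x| x == t).count()
}

fn backward_search(bwt: &[u32], c_table: &[usize], query: &[u32]) -> (usize, usize) {
    let (mut start, mut end) = (0usize, bwt.len());
    // walk the query back to front, shrinking the suffix-array range
    for &t in query.iter().rev() {
        start = c_table[t as usize] + rank(bwt, t, start);
        end = c_table[t as usize] + rank(bwt, t, end);
        if start >= end {
            return (start, start); // empty range: no match
        }
    }
    (start, end)
}

fn main() {
    // BWT of "banana$" over tokens $=0, a=1, b=2, n=3 is "annb$aa" -> [1,3,3,2,0,1,1]
    let bwt = vec![1, 3, 3, 2, 0, 1, 1];
    let c_table = vec![0, 1, 4, 5]; // number of symbols smaller than $, a, b, n
    let (s, e) = backward_search(&bwt, &c_table, &[1, 3]); // query "an"
    println!("{} occurrences of \"an\"", e - s);
}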
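The per-token weight search_bm25_async computes is the standard Okapi BM25 IDF with the query weight folded in: idf(t) = weight * ln((N - df + 0.5) / (df + 0.5) + 1). A small sketch of just that term; the tf and document-length normalisation handled by the posting-list chunks is not shown here.

fn weighted_idf(total_documents: usize, token_count: usize, query_weight: f32) -> f32 {
    let n = total_documents as f32;
    let df = token_count as f32;
    query_weight * ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
}

fn main() {
    // a token appearing in 100 of 1_000_000 documents is highly selective
    println!("{:.3}", weighted_idf(1_000_000, 100, 0.1));
    // a token appearing in half of them contributes almost nothing
    println!("{:.3}", weighted_idf(1_000_000, 500_000, 0.1));
}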
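search_lava_substring shapes the tokenized query with two knobs visible in the diff: sample_factor turns one token sequence into that many strided subsequences (offsets 0..sample_factor), matching an index built from every n-th character, and token_viable_limit keeps only the trailing tokens of each sequence to bound the number of backward-search steps. A standalone sketch of that shaping:

fn shape_queries(
    tokens: &[u32],
    sample_factor: Option<usize>,
    token_viable_limit: Option<usize>,
) -> Vec<Vec<u32>> {
    let mut queries: Vec<Vec<u32>> = match sample_factor {
        Some(f) => (0..f)
            .map(|offset| tokens.iter().skip(offset).step_by(f).cloned().collect::<Vec<_>>())
            .filter(|v| !v.is_empty())
            .collect(),
        None => vec![tokens.to_vec()],
    };
    if let Some(limit) = token_viable_limit {
        for q in queries.iter_mut() {
            if q.len() > limit {
                *q = q[q.len() - limit..].to_vec(); // keep only the trailing `limit` tokens
            }
        }
    }
    queries
}

fn main() {
    let toks: Vec<u32> = (1..=9).collect();
    // -> [[4, 7], [5, 8], [6, 9]]
    println!("{:?}", shape_queries(&toks, Some(3), Some(2)));
}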
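search_lava_vector_async stacks the centroids of all files into one matrix, ranks them by distance to the query, and then maps each selected global row back to (file id, local centroid index) through the running sum of per-file centroid counts. A sketch of just that index arithmetic:

/// Map a global centroid index to (file index, index within that file),
/// given the cumulative per-file centroid counts.
fn locate(global_idx: usize, cumsum: &[usize]) -> (usize, usize) {
    // first file whose cumulative count exceeds the global index
    let file_idx = cumsum.iter().position(|&c| c > global_idx).expect("index out of range");
    let prev = if file_idx == 0 { 0 } else { cumsum[file_idx - 1] };
    (file_idx, global_idx - prev)
}

fn main() {
    // three files with 4, 2 and 5 centroids -> cumulative counts 4, 6, 11
    let cumsum = [4usize, 6, 11];
    assert_eq!(locate(0, &cumsum), (0, 0));
    assert_eq!(locate(5, &cumsum), (1, 1));
    assert_eq!(locate(10, &cumsum), (2, 4));
    println!("mapping ok");
}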
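Finally, the FastTrie query path in trie.rs appears to work off a small metadata page: a root LUT keyed by the first root_levels bits of a key, a list of byte offsets delimiting the compressed leaf tries, and root_levels itself, so a point query fetches the metadata plus exactly one leaf-trie byte range. The sketch below encodes that reading of the layout; the concrete field types are an assumption, and decompression and deserialization of the leaf trie are omitted.

use std::collections::BTreeMap;

/// Assumed shape of the FastTrie metadata page.
struct TrieMetadata {
    root_lut: BTreeMap<Vec<u8>, (Vec<usize>, Option<usize>)>,
    offsets: Vec<usize>,
    root_levels: usize,
}

/// Returns the uids stored directly in the root LUT plus the byte range holding
/// the one leaf trie that still has to be read and queried with the remaining
/// query[root_levels / 8..] bytes.
fn plan_query(meta: &TrieMetadata, query: &[u8]) -> (Vec<usize>, Option<(usize, usize)>) {
    let prefix = query[..meta.root_levels / 8].to_vec();
    match meta.root_lut.get(&prefix) {
        Some((uids, Some(x))) => (uids.clone(), Some((meta.offsets[*x], meta.offsets[*x + 1]))),
        Some((uids, None)) => (uids.clone(), None),
        None => (Vec::new(), None),
    }
}

fn main() {
    let mut root_lut = BTreeMap::new();
    root_lut.insert(vec![b'a'], (vec![7], Some(0)));
    let meta = TrieMetadata { root_lut, offsets: vec![0, 128, 256], root_levels: 8 };
    println!("{:?}", plan_query(&meta, b"abcdef"));
}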