Commit

attempt to implement merge and more async cancer
marsupialtail committed Apr 9, 2024
1 parent 99fe3f6 commit 4dcfecd
Showing 12 changed files with 1,011 additions and 710 deletions.
2 changes: 1 addition & 1 deletion python/rottnest/__init__.py
@@ -1,6 +1,6 @@
import rottnest as rottnest
from .pele import search_index_bm25, search_index_substring, search_index_vector, \
merge_index_bm25, merge_index_substring, \
merge_index_bm25, merge_index_substring, merge_index_vector, \
index_file_bm25, index_file_substring, index_file_vector

__doc__ = rottnest.__doc__
16 changes: 16 additions & 0 deletions python/rottnest/pele.py
@@ -200,6 +200,22 @@ def query_expansion_keyword(tokenizer_vocab: List[str], query: str):
print("Expanded tokens: ", tokens)
return tokens, token_ids, weights

def merge_index_vector(new_index_name: str, index_names: List[str]):

assert len(index_names) > 1

# first read the metadata files and merge those
metadatas = [read_metadata_file(f"{index_name}.meta").with_columns(polars.lit(i).alias("file_id").cast(polars.Int64)) for i, index_name in enumerate(index_names)]
data_page_rows = [np.cumsum(np.hstack([[0] , np.array(metadata["data_page_rows"])])) for metadata in metadatas]
uid_to_metadata = [[(a,b,c,d,e) for a,b,c,d,e in zip(metadata["file_path"], metadata["row_groups"], metadata["data_page_offsets"],
metadata["data_page_sizes"], metadata["dictionary_page_sizes"])] for metadata in metadatas]

metadata = polars.concat(metadatas)
assert len(metadata["column_name"].unique()) == 1, "index is not allowed to span multiple column names"
column_name = metadata["column_name"].unique()[0]

rottnest.merge_lava_vector(new_index_name, [f"{name}.lava" for name in index_names], column_name, data_page_rows, uid_to_metadata)

def search_index_vector(indices: List[str], query: np.array, K: int):

metadatas = [read_metadata_file(f"{index_name}.meta").with_columns(polars.lit(i).alias("file_id").cast(polars.Int64)) for i, index_name in enumerate(indices)]
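
The new Python entry point mirrors the existing merge_index_bm25 / merge_index_substring helpers: it concatenates the per-index metadata and hands the page layout down to the Rust merger. A minimal usage sketch, assuming two vector indices were previously built with index_file_vector (the index names, dimension, and K below are hypothetical placeholders):

```python
import numpy as np
import rottnest

# Assumption: "idx_a" and "idx_b" were built earlier with index_file_vector,
# so matching .lava / .meta files exist on disk.
rottnest.merge_index_vector("idx_merged", ["idx_a", "idx_b"])

# The merged index is then searched like any single vector index.
query = np.random.rand(128).astype(np.float32)  # dimension must match the indexed vectors
results = rottnest.search_index_vector(["idx_merged"], query, K=10)
```
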
269 changes: 262 additions & 7 deletions src/lava/merge.rs
@@ -1,28 +1,34 @@
use async_recursion::async_recursion;
use bincode;
use bit_vec::BitVec;
use itertools::Itertools;
use ndarray::Array2;
use opendal::raw::oio::ReadExt;
use opendal::services::Fs;
use opendal::{Operator, Writer};
use std::collections::BTreeSet;
use std::env;
use std::fs::File;
use std::io::{BufRead, BufReader, Cursor, Read, Seek, SeekFrom, Write};
use std::sync::{Arc, Mutex};
use tokio::io::AsyncReadExt;
use zstd::bulk::compress;
use zstd::stream::encode_all;
use zstd::stream::read::Decoder;

use async_recursion::async_recursion;
use itertools::Itertools;
use opendal::{Operator, Writer};
use std::env;
use std::sync::{Arc, Mutex};

use crate::formats::io::{AsyncReader, READER_BUFFER_SIZE, WRITER_BUFFER_SIZE};
use crate::formats::io::{
get_file_sizes_and_readers, AsyncReader, READER_BUFFER_SIZE, WRITER_BUFFER_SIZE,
};
use crate::lava::constants::*;
use crate::lava::error::LavaError;
use crate::lava::fm_chunk::FMChunk;
use crate::lava::plist::PListChunk;
use std::collections::HashMap;

use crate::vamana::{
access::ReaderAccessMethodF32, merge_indexes_par, EuclideanF32, IndexParams, VamanaIndex,
};

// @Rain chore: we need to simplify all the iterator impls

struct PListIterator {
@@ -655,6 +661,81 @@ async fn merge_lava_substring(
Ok(())
}

async fn merge_lava_vector(
condensed_lava_file: &str,
lava_files: Vec<String>,
column_name: &str,
uid_nrows: &Vec<Vec<usize>>,
uid_to_metadatas: &Vec<Vec<(String, usize, usize, usize, usize)>>,
) -> Result<(), LavaError> {
assert_eq!(lava_files.len(), 2);
assert_eq!(uid_nrows.len(), 2);
assert_eq!(uid_to_metadatas.len(), 2);

let (file_sizes, mut readers) = get_file_sizes_and_readers(&lava_files).await?;

let mut indices: Vec<VamanaIndex<f32, EuclideanF32, ReaderAccessMethodF32>> = vec![];

let mut all_dim: Option<u64> = None;

for i in 0..2 {
let mut reader = readers.remove(0);
let num_points = reader.read_u64_le().await?;
let dim = reader.read_u64_le().await?;
match all_dim {
Some(d) => assert_eq!(dim, d),
None => {
all_dim.replace(dim);
}
}
let start = reader.read_u64_le().await?;

let compressed_nlist = reader.read_range(24, file_sizes[i] as u64).await?;
let mut decompressor = Decoder::new(&compressed_nlist[..])?;
let mut serialized_nlist: Vec<u8> = Vec::with_capacity(compressed_nlist.len() as usize);
decompressor.read_to_end(&mut serialized_nlist)?;
let nlist: Array2<usize> = bincode::deserialize(&serialized_nlist)?;

let reader_access_method = ReaderAccessMethodF32 {
dim: dim as usize,
num_points: num_points as usize,
column_name: column_name.to_string(),
uid_nrows: &uid_nrows[i],
// uid to (file_path, row_group, page_offset, page_size, dict_page_size)
uid_to_metadata: &uid_to_metadatas[i],
};
let index: VamanaIndex<f32, EuclideanF32, _> = VamanaIndex::hydrate(
reader_access_method,
IndexParams {
num_neighbors: 32,
search_frontier_size: 32,
pruning_threshold: 2.0,
},
nlist,
start as usize,
);
indices.push(index);
}

let index0 = indices.remove(0);
let index1 = indices.remove(0);
let index = merge_indexes_par(index0, index1);

let num_points = index.num_points();
let start = index.start;
let nlist = index.neighbors;
let bytes = bincode::serialize(&nlist)?;
let compressed_nlist: Vec<u8> = encode_all(&bytes[..], 0).expect("Compression failed");

let mut file = File::create(condensed_lava_file)?;
file.write_all(&(num_points as u64).to_le_bytes())?;
file.write_all(&(all_dim.unwrap() as u64).to_le_bytes())?;
file.write_all(&(start as u64).to_le_bytes())?;
file.write_all(&compressed_nlist)?;

Ok(())
}

#[async_recursion]
async fn async_parallel_merge_files(
condensed_lava_file: String,
@@ -745,6 +826,7 @@ async fn async_parallel_merge_files(
}
}

// no race condition since everybody pushes the same value to new_uid_offsets_clone
merged_files_clone.lock().unwrap().push(merged_filename);
new_uid_offsets_clone.lock().unwrap().push(0);
Result::<(), LavaError>::Ok(())
@@ -785,6 +867,157 @@ async fn async_parallel_merge_files(
}
}

#[async_recursion]
async fn async_parallel_merge_vector_files(
condensed_lava_file: String,
files: Vec<String>,
do_not_delete: BTreeSet<String>,
column_name: &str,
uid_nrows: Vec<Vec<usize>>,
uid_to_metadatas: Vec<Vec<(String, usize, usize, usize, usize)>>,
K: usize,
) -> Result<(), LavaError> {
assert_eq!(K, 2);

match files.len() {
0 => Err(LavaError::Parse("out of chunks".to_string())), // Assuming LavaError can be constructed like this
1 => {
// the recursion will end here in this case. rename the files[0] to the supposed output name
std::fs::rename(files[0].clone(), condensed_lava_file).unwrap();
Ok(())
}
_ => {
// More than one file, need to merge
let mut tasks = vec![];
let merged_files = Arc::new(Mutex::new(vec![]));
let merged_uid_nrows = Arc::new(Mutex::new(vec![]));
let merged_uid_to_metadata = Arc::new(Mutex::new(vec![]));

let chunked_files: Vec<Vec<String>> = files
.into_iter()
.chunks(K)
.into_iter()
.map(|chunk| chunk.collect())
.collect();

let chunked_uid_nrows: Vec<Vec<Vec<usize>>> = uid_nrows
.into_iter() // consume the Vec so each chunk can be moved into its spawned task
.chunks(K) // This comes from itertools
.into_iter()
.map(|chunk| chunk.collect::<Vec<_>>()) // Collect each chunk into a Vec
.collect();

let chunked_uid_metadata: Vec<Vec<Vec<(String, usize, usize, usize, usize)>>> =
uid_to_metadatas
.into_iter()
.chunks(K)
.into_iter()
.map(|chunk| chunk.collect::<Vec<_>>())
.collect();

for ((file_chunk, uid_nrows_chunk), uid_metadata_chunk) in chunked_files
.into_iter()
.zip(chunked_uid_nrows.into_iter())
.zip(chunked_uid_metadata.into_iter())
{
if file_chunk.len() == 1 {
// If there's an odd file out, directly move it to the next level
merged_files.lock().unwrap().push(file_chunk[0].clone());
merged_uid_nrows
.lock()
.unwrap()
.push(uid_nrows_chunk[0].clone());
merged_uid_to_metadata
.lock()
.unwrap()
.push(uid_metadata_chunk[0].clone());
continue;
}

let merged_files_clone = Arc::clone(&merged_files);
let new_uid_nrows_clone = Arc::clone(&merged_uid_nrows);
let new_uid_metadata_clone = Arc::clone(&merged_uid_to_metadata);
let do_not_delete_clone = do_not_delete.clone();
let column_name = column_name.to_string(); // Convert &str to String

let task = tokio::spawn(async move {
let my_uuid = uuid::Uuid::new_v4();
let merged_filename = my_uuid.to_string(); // temporary output name for this merge step

println!("merging {:?}", file_chunk);

merge_lava_vector(
&merged_filename,
file_chunk.clone(),
&column_name,
&uid_nrows_chunk,
&uid_metadata_chunk,
)
.await
.unwrap();

// now go delete the input files

for file in file_chunk {
if !do_not_delete_clone.contains(&file) {
println!("deleting {}", file);
std::fs::remove_file(file).unwrap();
}
}

merged_files_clone.lock().unwrap().push(merged_filename);
new_uid_nrows_clone
.lock()
.unwrap()
.push(uid_nrows_chunk.concat());
new_uid_metadata_clone
.lock()
.unwrap()
.push(uid_metadata_chunk.concat());
Result::<(), LavaError>::Ok(())
});

tasks.push(task);
}

// Wait for all tasks to complete
let _: Vec<_> = futures::future::join_all(tasks)
.await
.into_iter()
.collect::<Result<_, _>>()
.unwrap();

// Extract the merged files for the next level of merging
let merged_files: Vec<String> = Arc::try_unwrap(merged_files)
.expect("Lock still has multiple owners")
.into_inner()
.unwrap();

let new_uid_nrows = Arc::try_unwrap(merged_uid_nrows)
.expect("Lock still has multiple owners")
.into_inner()
.unwrap();

let new_uid_metadatas = Arc::try_unwrap(merged_uid_to_metadata)
.expect("Lock still has multiple owners")
.into_inner()
.unwrap();

// Recurse with the newly merged files
async_parallel_merge_vector_files(
condensed_lava_file,
merged_files,
do_not_delete,
column_name,
new_uid_nrows,
new_uid_metadatas,
K,
)
.await
}
}
}
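
Structurally, async_parallel_merge_vector_files above is a pairwise (K = 2) merge tree: the inputs are chunked into pairs, each pair is merged into a temporary file named by a fresh UUID, the per-file uid_nrows / uid_to_metadata lists are concatenated in the same order, and the function recurses on the outputs until a single file remains, which is renamed to the requested output. A synchronous Python sketch of just that control flow (merge_pair is a hypothetical stand-in for merge_lava_vector; the real code runs each pair concurrently via tokio::spawn):

```python
import os
import uuid

def merge_tree(output, files, payloads, merge_pair, keep):
    # Base case: a single file left -- rename it to the requested output name.
    if len(files) == 1:
        os.rename(files[0], output)
        return
    next_files, next_payloads = [], []
    for i in range(0, len(files), 2):
        pair, pair_payloads = files[i:i + 2], payloads[i:i + 2]
        if len(pair) == 1:  # odd file out: promote it unchanged to the next level
            next_files.append(pair[0])
            next_payloads.append(pair_payloads[0])
            continue
        merged = str(uuid.uuid4())
        merge_pair(merged, pair, pair_payloads)  # stand-in for merge_lava_vector
        for f in pair:
            if f not in keep:  # the original inputs are never deleted
                os.remove(f)
        next_files.append(merged)
        # Mirrors uid_nrows_chunk.concat() / uid_metadata_chunk.concat() above.
        next_payloads.append(pair_payloads[0] + pair_payloads[1])
    merge_tree(output, next_files, next_payloads, merge_pair, keep)
```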

#[tokio::main]
pub async fn parallel_merge_files(
condensed_lava_file: String,
@@ -806,6 +1039,28 @@ pub async fn parallel_merge_files(
Ok(result)
}

#[tokio::main]
pub async fn parallel_merge_vector_files(
condensed_lava_file: String,
files: Vec<String>,
column_name: &str,
uid_nrows: Vec<Vec<usize>>,
uid_to_metadatas: Vec<Vec<(String, usize, usize, usize, usize)>>,
) -> Result<(), LavaError> {
let do_not_delete = BTreeSet::from_iter(files.clone().into_iter());
let result = async_parallel_merge_vector_files(
condensed_lava_file,
files,
do_not_delete,
column_name,
uid_nrows,
uid_to_metadatas,
2,
)
.await?;
Ok(result)
}

#[cfg(test)]
mod tests {
use crate::lava::merge::parallel_merge_files;
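
For reference, the vector index file that merge_lava_vector reads and writes has a small fixed layout: three little-endian u64 values (num_points, dim, start) followed by a zstd-compressed, bincode-serialized neighbor-list matrix occupying the rest of the file. A hedged Python sketch for inspecting that header (assumes the zstandard PyPI package; the bincode payload is left opaque here, since decoding it would require mirroring the Rust types):

```python
import struct
import zstandard  # assumption: the zstandard package is installed

def peek_vector_index(path: str):
    """Read the 24-byte header of a vector .lava file and report payload size."""
    with open(path, "rb") as f:
        num_points, dim, start = struct.unpack("<QQQ", f.read(24))
        compressed_nlist = f.read()  # bytes 24 .. EOF, as in read_range(24, file_size)
    # Decompress only to show the framing; the result is bincode-encoded on the Rust side.
    raw_nlist = zstandard.ZstdDecompressor().decompressobj().decompress(compressed_nlist)
    return num_points, dim, start, len(raw_nlist)
```
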
2 changes: 2 additions & 0 deletions src/lava/mod.rs
@@ -11,6 +11,8 @@ pub use build::build_lava_substring
pub use build::build_lava_vector;

pub use merge::parallel_merge_files;
pub use merge::parallel_merge_vector_files;

pub use search::get_tokenizer_vocab;
pub use search::search_lava_bm25;
pub use search::search_lava_substring;
Expand Down
6 changes: 2 additions & 4 deletions src/lava/search.rs
@@ -396,12 +396,10 @@ async fn search_vector_async(

let futures: Vec<_> = results
.iter()
.map(|(file_id, n)| reader_access_methods[*file_id].get_vec_async(*n))
.map(|(file_id, n)| reader_access_methods[*file_id].get_vec(*n))
.collect();

let vectors: Vec<Result<Vec<f32>, LavaError>> = futures::future::join_all(futures).await;
let vectors: Result<Vec<Vec<f32>>, LavaError> = vectors.into_iter().collect();
let vectors: Vec<Vec<f32>> = vectors?;
let vectors: Vec<Vec<f32>> = futures::future::join_all(futures).await;
let rows = vectors.len();
let cols = vectors[0].len();
let vectors: Vec<f32> = vectors.into_iter().flatten().collect();