From d07109172aa6e953e55e4678ebab2c0fafb7a693 Mon Sep 17 00:00:00 2001 From: Ziheng Wang Date: Sat, 8 Jun 2024 22:22:48 -0700 Subject: [PATCH] in middle of vector search perf eng --- python/rottnest/pele.py | 62 +++++++++++--------- src/formats/parquet.rs | 2 +- src/formats/readers/mod.rs | 37 ++++++++++++ src/lava/search.rs | 117 +++++++++++++++++++++++++------------ src/lava_py/lava.rs | 18 +++--- 5 files changed, 163 insertions(+), 73 deletions(-) diff --git a/python/rottnest/pele.py b/python/rottnest/pele.py index 7b2d143..431594f 100644 --- a/python/rottnest/pele.py +++ b/python/rottnest/pele.py @@ -35,7 +35,7 @@ def get_fs_from_file_path(filepath): def get_daft_io_config_from_file_path(filepath): if filepath.startswith("s3://"): - fs = daft.io.IOConfig(s3 = daft.io.S3Config(force_virtual_addressing = (True if os.getenv('AWS_VIRTUAL_HOST_STYLE') else False), endpoint_override = os.getenv('AWS_ENDPOINT_URL'))) + fs = daft.io.IOConfig(s3 = daft.io.S3Config(force_virtual_addressing = (True if os.getenv('AWS_VIRTUAL_HOST_STYLE') else False), endpoint_url = os.getenv('AWS_ENDPOINT_URL'))) else: fs = daft.io.IOConfig() @@ -43,7 +43,7 @@ def get_daft_io_config_from_file_path(filepath): def read_metadata_file(file_path: str): - table = pq.read_table(file_path.lstrip("s3://"), filesystem = get_fs_from_file_path(file_path)) + table = pq.read_table(file_path.replace("s3://",''), filesystem = get_fs_from_file_path(file_path)) try: cache_ranges = json.loads(table.schema.metadata[b'cache_ranges'].decode()) @@ -55,7 +55,7 @@ def read_metadata_file(file_path: str): def read_columns(file_paths: list, row_groups: list, row_nr: list[list]): def read_parquet_file(file, row_group, row_nr): - f = pq.ParquetFile(file.lstrip("s3://"), filesystem=get_fs_from_file_path(file)) + f = pq.ParquetFile(file.replace("s3://",''), filesystem=get_fs_from_file_path(file)) return f.read_row_group(row_group, columns=['id']).take(row_nr) with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: # Control the number of parallel threads @@ -115,6 +115,7 @@ def index_file_bm25(file_path: str, column_name: str, name = uuid.uuid4().hex, i cache_ranges = rottnest.build_lava_bm25(f"{name}.lava", arr, uid, tokenizer_file) + # do not attempt to manually edit the metadata. It is Parquet, but it is Varsity Parquet to ensure performance. 
file_data = file_data.to_arrow() file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)}) pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd') @@ -331,7 +332,12 @@ def get_result_from_index_result(metadata: polars.DataFrame, index_search_result return result, column_name, metadata.with_row_count('__metadata_key__') def get_metadata_and_populate_cache(indices: List[str]): - metadatas = [read_metadata_file(f"{index_name}.meta") for i, index_name in enumerate(indices)] + + # metadatas = [read_metadata_file(f"{index_name}.meta") for i, index_name in enumerate(indices)] + + metadatas = daft.table.read_parquet_into_pyarrow_bulk([f"{index_name}.meta" for index_name in indices], io_config = get_daft_io_config_from_file_path(indices[0])) + metadatas = [(polars.from_arrow(i), json.loads(i.schema.metadata[b'cache_ranges'].decode())) for i in metadatas] + metadata = polars.concat([f[0].with_columns(polars.lit(i).alias("file_id").cast(polars.Int64)) for i, f in enumerate(metadatas)]) cache_dir = os.getenv("ROTTNEST_CACHE_DIR") if cache_dir: @@ -346,7 +352,7 @@ def search_index_uuid(indices: List[str], query: str, K: int): metadata = get_metadata_and_populate_cache(indices) - index_search_results = rottnest.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K) + index_search_results = rottnest.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K, "aws") print(index_search_results) if len(index_search_results) == 0: @@ -362,7 +368,7 @@ def search_index_substring(indices: List[str], query: str, K: int): metadata = get_metadata_and_populate_cache(indices) - index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K) + index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws") print(index_search_results) if len(index_search_results) == 0: @@ -387,38 +393,40 @@ def search_index_vector(indices: List[str], query: np.array, K: int, columns = [ # uids and codes are list of lists, where each sublist corresponds to an index. 
pq is a list of bytes # length is the same as the list of indices start = time.time() - results = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], query, nprobes) + valid_file_ids, pq_bytes, arrs = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], query, nprobes, "aws") + + # print(results) print("INDEX SEARCH TIME", time.time() - start) file_ids = [] uids = [] codes = [] + pqs = {} + start = time.time() - for i, result in enumerate(results): - if result is None: - continue - else: - f = open("tmp.pq", "wb") - f.write(result[1].tobytes()) - pq = faiss.read_ProductQuantizer("tmp.pq") - # os.remove("tmp.pq") - - for arr in result[0]: - plist_length = np.frombuffer(arr[:4], dtype = np.uint32).item() - plist = np.frombuffer(arr[4: plist_length * 4 + 4], dtype = np.uint32) - this_codes = np.frombuffer(arr[plist_length * 4 + 4:], dtype = np.uint8).reshape((plist_length, -1)) - - decoded = pq.decode(this_codes) - this_norms = np.linalg.norm(decoded - query, axis = 1).argsort()[:refine] - codes.append(decoded[this_norms]) - uids.append(plist[this_norms]) - file_ids.append(np.ones(len(this_norms)) * i) + for i, pq_bytes in zip(valid_file_ids, pq_bytes): + f = open("tmp.pq", "wb") + f.write(pq_bytes.tobytes()) + pqs[i] = faiss.read_ProductQuantizer("tmp.pq") + os.remove("tmp.pq") + + for (file_id, arr) in arrs: + plist_length = np.frombuffer(arr[:4], dtype = np.uint32).item() + plist = np.frombuffer(arr[4: plist_length * 4 + 4], dtype = np.uint32) + this_codes = np.frombuffer(arr[plist_length * 4 + 4:], dtype = np.uint8).reshape((plist_length, -1)) + + decoded = pqs[file_id].decode(this_codes) + this_norms = np.linalg.norm(decoded - query, axis = 1).argsort()[:refine] + codes.append(decoded[this_norms]) + uids.append(plist[this_norms]) + file_ids.append(np.ones(len(this_norms)) * file_id) file_ids = np.hstack(file_ids).astype(np.int64) uids = np.hstack(uids).astype(np.int64) codes = np.vstack(codes) fp_rerank = np.linalg.norm(query - codes, axis = 1).argsort()[:refine] + print("PQ COMPUTE TIME", time.time() - start) file_ids = file_ids[fp_rerank] @@ -428,6 +436,8 @@ def search_index_vector(indices: List[str], query: np.array, K: int, columns = [ index_search_results = list(set([(file_id, uid) for file_id, uid in zip(file_ids, uids)])) + print(index_search_results) + start = time.time() result, column_name, metadata = get_result_from_index_result(metadata, index_search_results) print("RESULT TIME", time.time() - start) diff --git a/src/formats/parquet.rs b/src/formats/parquet.rs index 0d24f8c..73892ba 100644 --- a/src/formats/parquet.rs +++ b/src/formats/parquet.rs @@ -211,7 +211,7 @@ async fn parse_metadatas( }) .collect::>() .await; - let res = futures::future::join_all(handles).await; + let res: Vec> = futures::future::join_all(handles).await; let mut metadatas = HashMap::new(); diff --git a/src/formats/readers/mod.rs b/src/formats/readers/mod.rs index c5aafd2..b54b9cd 100644 --- a/src/formats/readers/mod.rs +++ b/src/formats/readers/mod.rs @@ -213,6 +213,43 @@ pub async fn get_file_sizes_and_readers( Ok((file_sizes, readers)) } +pub async fn get_readers( + files: &[String], + reader_type: ReaderType, +) -> Result, LavaError> { + let tasks: Vec<_> = files + .iter() + .map(|file| { + let file = file.clone(); + let reader_type = reader_type.clone(); + tokio::spawn(async move { get_reader(file, reader_type).await }) + }) + .collect(); + + // Wait for all tasks to complete + let results = futures::future::join_all(tasks).await; + + // Process 
results, separating out file sizes and readers + let mut readers = Vec::new(); + + for result in results { + match result { + Ok(Ok(reader)) => { + readers.push(reader); + } + Ok(Err(e)) => return Err(e), // Handle error from inner task + Err(e) => { + return Err(LavaError::Parse(format!( + "Task join error: {}", + e.to_string() + ))) + } // Handle join error + } + } + + Ok(readers) +} + pub async fn get_file_size_and_reader( file: String, reader_type: ReaderType, diff --git a/src/lava/search.rs b/src/lava/search.rs index a15c266..ba4784e 100644 --- a/src/lava/search.rs +++ b/src/lava/search.rs @@ -16,12 +16,12 @@ use crate::lava::plist::PListChunk; use crate::vamana::vamana::VectorAccessMethod; use crate::vamana::{access::ReaderAccessMethodF32, access::InMemoryAccessMethodF32, EuclideanF32, IndexParams, VamanaIndex}; use crate::{ - formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ClonableAsyncReader, ReaderType}, + formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, get_readers, get_reader, AsyncReader, ClonableAsyncReader, ReaderType}, lava::error::LavaError, }; use std::time::Instant; use tokenizers::tokenizer::Tokenizer; - +use std::collections::BTreeMap; use ordered_float::OrderedFloat; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; @@ -496,19 +496,20 @@ pub async fn search_lava_vector( query: &Vec, nprobes: usize, reader_type: ReaderType, -) -> Result>, Array1)>>, LavaError> { +) -> Result<(Vec, Vec>, Vec<(usize, Array1)>), LavaError> { + + let start = Instant::now(); + let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; - - let mut join_set = JoinSet::new(); - for file_id in 0..readers.len() { + let mut futures = Vec::new(); + + for _ in 0..readers.len() { let mut reader = readers.remove(0); - join_set.spawn(async move { + futures.push(tokio::spawn(async move { let results = reader.read_usize_from_end(4).await.unwrap(); - println!("results {:?}", results); - let centroid_vectors_compressed_bytes = reader .read_range(results[2], results[3]) .await @@ -523,18 +524,17 @@ pub async fn search_lava_vector( let array2 = Array2::::from_shape_vec((1000,128), centroid_vectors).unwrap(); (results, array2) - }); + })); } - let mut result: Vec<(Vec, Array2)> = vec![]; - while let Some(res) = join_set.join_next().await { - let res = res.unwrap(); - result.push(res); - } + let result: Vec, Array2), tokio::task::JoinError>> = futures::future::join_all(futures).await; - join_set.shutdown().await; + let end = Instant::now(); + println!("Time stage 1 read: {:?}", end - start); - let arrays: Vec> = result.into_iter().map(|(_, array)| array).collect(); + let start = Instant::now(); + + let arrays: Vec> = result.into_iter().map(|x| x.unwrap().1).collect(); let centroids = concatenate(Axis(0), arrays.iter().map(|array| array.view()).collect::>().as_slice()).unwrap(); let query = Array1::::from_vec(query.clone()); let query_broadcast = query.broadcast(centroids.dim()).unwrap(); @@ -551,19 +551,27 @@ pub async fn search_lava_vector( file_indices[*idx / 1000].push(*idx % 1000 as usize); } + let end = Instant::now(); + println!("Time math: {:?}", end - start); + + + let start = Instant::now(); + let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; - let mut join_set = JoinSet::new(); + + let mut file_ids = vec![]; + let mut futures = Vec::new(); for file_id in 0..readers.len() { let mut reader = readers.remove(0); + if file_indices[file_id].len() == 0 { + continue; + } let 
my_idx: Vec = file_indices[file_id].clone(); + file_ids.push(file_id); - join_set.spawn(async move { - - if my_idx.len() == 0 { - return None; - } + futures.push(tokio::spawn(async move { let results = reader.read_usize_from_end(4).await.unwrap(); @@ -590,29 +598,62 @@ pub async fn search_lava_vector( centroid_offsets.push(value); } - let mut this_result: Vec> = vec![]; + let mut this_result: Vec<(usize, u64, u64)> = vec![]; for idx in my_idx.iter() { - let codes_and_plist = reader - .read_range(centroid_offsets[*idx], centroid_offsets[*idx + 1]) - .await - .unwrap(); - let arr = Array1::::from_vec(codes_and_plist.to_vec()); - this_result.push(arr); + this_result.push((file_id, centroid_offsets[*idx], centroid_offsets[*idx + 1])); } - Some((this_result, Array1::::from_vec(pq_bytes.to_vec()))) - }); + (this_result, Array1::::from_vec(pq_bytes.to_vec())) + })); } - let mut result: Vec>, Array1)>> = vec![]; - while let Some(res) = join_set.join_next().await { - let res = res.unwrap(); - result.push(res); + let result: Vec, Array1),tokio::task::JoinError>> = futures::future::join_all(futures).await; + let result: Vec<(Vec<(usize, u64, u64)>, Array1)> = result.into_iter().map(|x| x.unwrap()).collect(); + + let pq_bytes: Vec> = result + .iter() + .map(|x| x.1.clone()) + .collect::>(); + + let end = Instant::now(); + println!("Time stage 2 read: {:?}", end - start); + + let start = Instant::now(); + + let mut readers_map: BTreeMap = BTreeMap::new(); + for file_id in file_ids.iter() { + let reader = get_reader(files[*file_id].clone(), reader_type.clone()).await.unwrap(); + readers_map.insert(*file_id, reader); } - join_set.shutdown().await; + let mut futures = Vec::new(); + for i in 0 .. result.len() { + let to_read = result[i].0.clone(); + for (file_id, start, end) in to_read.into_iter() { + // let file_name = files[file_id].clone(); + // let my_reader_type = reader_type.clone(); + // let mut reader = get_reader(file_name, my_reader_type).await.unwrap(); + let start_time = Instant::now(); + let mut reader = readers_map.get_mut(&file_id).unwrap().clone(); + println!("Time to get reader {:?}", Instant::now() - start_time); + futures.push(tokio::spawn(async move { + + let start_time = Instant::now(); + let codes_and_plist = reader.read_range(start, end).await.unwrap(); + println!("Time to read {:?}", Instant::now() - start_time); + (file_id, Array1::::from_vec(codes_and_plist.to_vec())) + })); + } + } - Ok(result) + let ranges: Vec), tokio::task::JoinError>> = futures::future::join_all(futures).await; + let ranges: Vec<(usize, Array1)> = ranges.into_iter().map(|x| x.unwrap()).collect(); + + + let end = Instant::now(); + println!("Time stage 3 read: {:?}", end - start); + + Ok((file_ids, pq_bytes, ranges)) } #[tokio::main] diff --git a/src/lava_py/lava.rs b/src/lava_py/lava.rs index 5946484..4f9e61a 100644 --- a/src/lava_py/lava.rs +++ b/src/lava_py/lava.rs @@ -59,10 +59,10 @@ pub fn search_lava_vector( query: Vec, nprobes: usize, reader_type: Option<&PyString>, -) -> Result>>, Py>)>>, LavaError> { +) -> Result<(Vec, Vec>>, Vec<(usize, Py>)>), LavaError> { let reader_type = reader_type.map(|x| x.to_string()).unwrap_or_default(); - let result: Vec>, Array1)>> = py.allow_threads(|| { + let result: (Vec, Vec>, Vec<(usize, Array1)>) = py.allow_threads(|| { lava::search_lava_vector( files, &query, @@ -71,16 +71,18 @@ pub fn search_lava_vector( ) })?; - let result = result + let x = result.1 .into_iter() - .map(|x| match x { - Some((x, y)) => Some((x.into_iter().map(|x| 
x.into_pyarray(py).to_owned()).collect(), y.into_pyarray(py).to_owned())),
-            None => None,
-        })
+        .map(|x| x.into_pyarray(py).to_owned())
+        .collect();
+
+    let y = result.2
+        .into_iter()
+        .map(|(x, y)| (x, y.into_pyarray(py).to_owned()))
         .collect();
-    Ok(result)
+    Ok((result.0, x, y))
 }
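
Note on the new return format of search_lava_vector: the Python consumer in search_index_vector above unpacks each posting-list buffer as [u32 count | u32 row ids | u8 PQ codes] and re-ranks candidates with the per-file product quantizer before the final exact re-rank. Below is a minimal sketch of that flow, assuming only the buffer layout used in this patch; load_pq and decode_plist_buffer are illustrative helper names, not functions in the codebase.

    import os
    import tempfile
    import numpy as np
    import faiss

    def load_pq(pq_bytes: np.ndarray) -> "faiss.ProductQuantizer":
        # Mirror the tmp.pq round-trip above, but close the file before faiss reads it
        # so the serialized quantizer is fully flushed to disk.
        with tempfile.NamedTemporaryFile(suffix=".pq", delete=False) as f:
            f.write(pq_bytes.tobytes())
            path = f.name
        try:
            return faiss.read_ProductQuantizer(path)
        finally:
            os.remove(path)

    def decode_plist_buffer(arr: np.ndarray, pq, query: np.ndarray, refine: int):
        # arr is one uint8 buffer returned per probed centroid:
        # [u32 count | count * u32 row ids | count * code_size u8 PQ codes]
        plist_length = np.frombuffer(arr[:4], dtype=np.uint32).item()
        row_ids = np.frombuffer(arr[4: plist_length * 4 + 4], dtype=np.uint32)
        codes = np.frombuffer(arr[plist_length * 4 + 4:], dtype=np.uint8).reshape((plist_length, -1))
        decoded = pq.decode(codes)  # approximate reconstructions, shape (count, d)
        keep = np.linalg.norm(decoded - query, axis=1).argsort()[:refine]
        return row_ids[keep], decoded[keep]

    # usage, following the loop in search_index_vector:
    # pqs = {file_id: load_pq(b) for file_id, b in zip(valid_file_ids, pq_bytes)}
    # for file_id, arr in arrs:
    #     uids, codes = decode_plist_buffer(arr, pqs[file_id], query, refine)

Closing the temp file before calling faiss.read_ProductQuantizer avoids depending on the bytes of tmp.pq being flushed before they are read back, which the in-patch version leaves implicit.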