in the middle of vector search perf engineering
marsupialtail committed Jun 9, 2024
1 parent d77fdc9 commit d071091
Showing 5 changed files with 163 additions and 73 deletions.
62 changes: 36 additions & 26 deletions python/rottnest/pele.py
@@ -35,15 +35,15 @@ def get_fs_from_file_path(filepath):
def get_daft_io_config_from_file_path(filepath):

if filepath.startswith("s3://"):
fs = daft.io.IOConfig(s3 = daft.io.S3Config(force_virtual_addressing = (True if os.getenv('AWS_VIRTUAL_HOST_STYLE') else False), endpoint_override = os.getenv('AWS_ENDPOINT_URL')))
fs = daft.io.IOConfig(s3 = daft.io.S3Config(force_virtual_addressing = (True if os.getenv('AWS_VIRTUAL_HOST_STYLE') else False), endpoint_url = os.getenv('AWS_ENDPOINT_URL')))
else:
fs = daft.io.IOConfig()

return fs

def read_metadata_file(file_path: str):

table = pq.read_table(file_path.lstrip("s3://"), filesystem = get_fs_from_file_path(file_path))
table = pq.read_table(file_path.replace("s3://",''), filesystem = get_fs_from_file_path(file_path))

try:
cache_ranges = json.loads(table.schema.metadata[b'cache_ranges'].decode())
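
The lstrip-to-replace change above fixes a real footgun: Python's str.lstrip treats its argument as a set of characters to strip, not as a prefix, so it can eat into a bucket name that happens to start with 's', '3', ':' or '/'. A minimal illustration (the bucket name below is made up):

    # str.lstrip strips a character set, not a prefix.
    path = "s3://s3-staging/data.parquet"
    path.lstrip("s3://")        # '-staging/data.parquet'   -- bucket name mangled
    path.replace("s3://", "")   # 's3-staging/data.parquet' -- what the new code does
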
@@ -55,7 +55,7 @@ def read_metadata_file(file_path: str):
def read_columns(file_paths: list, row_groups: list, row_nr: list[list]):

def read_parquet_file(file, row_group, row_nr):
f = pq.ParquetFile(file.lstrip("s3://"), filesystem=get_fs_from_file_path(file))
f = pq.ParquetFile(file.replace("s3://",''), filesystem=get_fs_from_file_path(file))
return f.read_row_group(row_group, columns=['id']).take(row_nr)

with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: # Control the number of parallel threads
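
The rest of this block is collapsed in the diff view. Purely as a sketch of the pattern, a parallel fan-out over read_parquet_file with this executor could look like the following; the committed body may differ:

    # Sketch only -- the committed body is collapsed in this diff view.
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = [executor.submit(read_parquet_file, f, rg, nr)
                   for f, rg, nr in zip(file_paths, row_groups, row_nr)]
        tables = [fut.result() for fut in futures]
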
@@ -115,6 +115,7 @@ def index_file_bm25(file_path: str, column_name: str, name = uuid.uuid4().hex, i

cache_ranges = rottnest.build_lava_bm25(f"{name}.lava", arr, uid, tokenizer_file)

# do not attempt to manually edit the metadata. It is Parquet, but it is Varsity Parquet to ensure performance.
file_data = file_data.to_arrow()
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd')
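
The cache ranges ride along as Parquet schema-level metadata on the .meta file rather than in a separate sidecar, which is what read_metadata_file pulls back out. A minimal round-trip sketch of that pattern with pyarrow (the file name and range values here are invented):

    import json
    import pyarrow as pa
    import pyarrow.parquet as pq

    table = pa.table({"id": [1, 2, 3]})
    # replace_schema_metadata drops existing metadata and attaches the new key/value pairs.
    table = table.replace_schema_metadata({"cache_ranges": json.dumps([[0, 4096]])})
    pq.write_table(table, "example.meta", write_statistics=False, compression="zstd")

    # Reading it back mirrors read_metadata_file above; keys come back as bytes.
    meta = pq.read_table("example.meta")
    cache_ranges = json.loads(meta.schema.metadata[b"cache_ranges"].decode())
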
@@ -331,7 +332,12 @@ def get_result_from_index_result(metadata: polars.DataFrame, index_search_result
return result, column_name, metadata.with_row_count('__metadata_key__')

def get_metadata_and_populate_cache(indices: List[str]):
metadatas = [read_metadata_file(f"{index_name}.meta") for i, index_name in enumerate(indices)]

# metadatas = [read_metadata_file(f"{index_name}.meta") for i, index_name in enumerate(indices)]

metadatas = daft.table.read_parquet_into_pyarrow_bulk([f"{index_name}.meta" for index_name in indices], io_config = get_daft_io_config_from_file_path(indices[0]))
metadatas = [(polars.from_arrow(i), json.loads(i.schema.metadata[b'cache_ranges'].decode())) for i in metadatas]

metadata = polars.concat([f[0].with_columns(polars.lit(i).alias("file_id").cast(polars.Int64)) for i, f in enumerate(metadatas)])
cache_dir = os.getenv("ROTTNEST_CACHE_DIR")
if cache_dir:
@@ -346,7 +352,7 @@ def search_index_uuid(indices: List[str], query: str, K: int):

metadata = get_metadata_and_populate_cache(indices)

index_search_results = rottnest.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K)
index_search_results = rottnest.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K, "aws")
print(index_search_results)

if len(index_search_results) == 0:
@@ -362,7 +368,7 @@ def search_index_substring(indices: List[str], query: str, K: int):

metadata = get_metadata_and_populate_cache(indices)

index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K)
index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws")
print(index_search_results)

if len(index_search_results) == 0:
@@ -387,38 +393,40 @@ def search_index_vector(indices: List[str], query: np.array, K: int, columns = [
# uids and codes are list of lists, where each sublist corresponds to an index. pq is a list of bytes
# length is the same as the list of indices
start = time.time()
results = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], query, nprobes)
valid_file_ids, pq_bytes, arrs = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], query, nprobes, "aws")

# print(results)
print("INDEX SEARCH TIME", time.time() - start)

file_ids = []
uids = []
codes = []

pqs = {}

start = time.time()
for i, result in enumerate(results):
if result is None:
continue
else:
f = open("tmp.pq", "wb")
f.write(result[1].tobytes())
pq = faiss.read_ProductQuantizer("tmp.pq")
# os.remove("tmp.pq")

for arr in result[0]:
plist_length = np.frombuffer(arr[:4], dtype = np.uint32).item()
plist = np.frombuffer(arr[4: plist_length * 4 + 4], dtype = np.uint32)
this_codes = np.frombuffer(arr[plist_length * 4 + 4:], dtype = np.uint8).reshape((plist_length, -1))

decoded = pq.decode(this_codes)
this_norms = np.linalg.norm(decoded - query, axis = 1).argsort()[:refine]
codes.append(decoded[this_norms])
uids.append(plist[this_norms])
file_ids.append(np.ones(len(this_norms)) * i)
for i, pq_bytes in zip(valid_file_ids, pq_bytes):
f = open("tmp.pq", "wb")
f.write(pq_bytes.tobytes())
pqs[i] = faiss.read_ProductQuantizer("tmp.pq")
os.remove("tmp.pq")

for (file_id, arr) in arrs:
plist_length = np.frombuffer(arr[:4], dtype = np.uint32).item()
plist = np.frombuffer(arr[4: plist_length * 4 + 4], dtype = np.uint32)
this_codes = np.frombuffer(arr[plist_length * 4 + 4:], dtype = np.uint8).reshape((plist_length, -1))

decoded = pqs[file_id].decode(this_codes)
this_norms = np.linalg.norm(decoded - query, axis = 1).argsort()[:refine]
codes.append(decoded[this_norms])
uids.append(plist[this_norms])
file_ids.append(np.ones(len(this_norms)) * file_id)

file_ids = np.hstack(file_ids).astype(np.int64)
uids = np.hstack(uids).astype(np.int64)
codes = np.vstack(codes)
fp_rerank = np.linalg.norm(query - codes, axis = 1).argsort()[:refine]

print("PQ COMPUTE TIME", time.time() - start)

file_ids = file_ids[fp_rerank]
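
Each blob returned for a probed centroid is a length-prefixed posting list followed by flat PQ codes, and the loop above peels it apart with np.frombuffer. A self-contained sketch of packing and unpacking one such blob (the three postings and the 16-byte code width are invented for illustration):

    import numpy as np

    plist = np.array([10, 42, 77], dtype=np.uint32)                    # row ids in the posting list
    codes = np.random.randint(0, 256, size=(3, 16), dtype=np.uint8)    # PQ codes, one row per posting
    blob = np.uint32(len(plist)).tobytes() + plist.tobytes() + codes.tobytes()

    # Unpacking mirrors the loop in search_index_vector.
    n = np.frombuffer(blob[:4], dtype=np.uint32).item()
    parsed_plist = np.frombuffer(blob[4:4 + n * 4], dtype=np.uint32)
    parsed_codes = np.frombuffer(blob[4 + n * 4:], dtype=np.uint8).reshape((n, -1))
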
@@ -428,6 +436,8 @@ def search_index_vector(indices: List[str], query: np.array, K: int, columns = [

index_search_results = list(set([(file_id, uid) for file_id, uid in zip(file_ids, uids)]))

print(index_search_results)

start = time.time()
result, column_name, metadata = get_result_from_index_result(metadata, index_search_results)
print("RESULT TIME", time.time() - start)
2 changes: 1 addition & 1 deletion src/formats/parquet.rs
@@ -211,7 +211,7 @@ async fn parse_metadatas(
})
.collect::<Vec<_>>()
.await;
let res = futures::future::join_all(handles).await;
let res: Vec<Result<(String, ParquetMetaData), tokio::task::JoinError>> = futures::future::join_all(handles).await;

let mut metadatas = HashMap::new();

37 changes: 37 additions & 0 deletions src/formats/readers/mod.rs
@@ -213,6 +213,43 @@ pub async fn get_file_sizes_and_readers(
Ok((file_sizes, readers))
}

pub async fn get_readers(
files: &[String],
reader_type: ReaderType,
) -> Result<Vec<AsyncReader>, LavaError> {
let tasks: Vec<_> = files
.iter()
.map(|file| {
let file = file.clone();
let reader_type = reader_type.clone();
tokio::spawn(async move { get_reader(file, reader_type).await })
})
.collect();

// Wait for all tasks to complete
let results = futures::future::join_all(tasks).await;

// Process results, separating out file sizes and readers
let mut readers = Vec::new();

for result in results {
match result {
Ok(Ok(reader)) => {
readers.push(reader);
}
Ok(Err(e)) => return Err(e), // Handle error from inner task
Err(e) => {
return Err(LavaError::Parse(format!(
"Task join error: {}",
e.to_string()
)))
} // Handle join error
}
}

Ok(readers)
}

pub async fn get_file_size_and_reader(
file: String,
reader_type: ReaderType,
117 changes: 79 additions & 38 deletions src/lava/search.rs
@@ -16,12 +16,12 @@ use crate::lava::plist::PListChunk;
use crate::vamana::vamana::VectorAccessMethod;
use crate::vamana::{access::ReaderAccessMethodF32, access::InMemoryAccessMethodF32, EuclideanF32, IndexParams, VamanaIndex};
use crate::{
formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ClonableAsyncReader, ReaderType},
formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, get_readers, get_reader, AsyncReader, ClonableAsyncReader, ReaderType},
lava::error::LavaError,
};
use std::time::Instant;
use tokenizers::tokenizer::Tokenizer;

use std::collections::BTreeMap;
use ordered_float::OrderedFloat;
use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};

@@ -496,19 +496,20 @@ pub async fn search_lava_vector(
query: &Vec<f32>,
nprobes: usize,
reader_type: ReaderType,
) -> Result<Vec<Option<(Vec<Array1<u8>>, Array1<u8>)>>, LavaError> {
) -> Result<(Vec<usize>, Vec<Array1<u8>>, Vec<(usize, Array1<u8>)>), LavaError> {

let start = Instant::now();

let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?;

let mut join_set = JoinSet::new();

for file_id in 0..readers.len() {
let mut futures = Vec::new();

for _ in 0..readers.len() {
let mut reader = readers.remove(0);

join_set.spawn(async move {
futures.push(tokio::spawn(async move {
let results = reader.read_usize_from_end(4).await.unwrap();

println!("results {:?}", results);

let centroid_vectors_compressed_bytes = reader
.read_range(results[2], results[3])
.await
Expand All @@ -523,18 +524,17 @@ pub async fn search_lava_vector(
let array2 = Array2::<f32>::from_shape_vec((1000,128), centroid_vectors).unwrap();

(results, array2)
});
}));
}

let mut result: Vec<(Vec<u64>, Array2<f32>)> = vec![];
while let Some(res) = join_set.join_next().await {
let res = res.unwrap();
result.push(res);
}
let result: Vec<Result<(Vec<u64>, Array2<f32>), tokio::task::JoinError>> = futures::future::join_all(futures).await;

join_set.shutdown().await;
let end = Instant::now();
println!("Time stage 1 read: {:?}", end - start);

let arrays: Vec<Array2<f32>> = result.into_iter().map(|(_, array)| array).collect();
let start = Instant::now();

let arrays: Vec<Array2<f32>> = result.into_iter().map(|x| x.unwrap().1).collect();
let centroids = concatenate(Axis(0), arrays.iter().map(|array| array.view()).collect::<Vec<_>>().as_slice()).unwrap();
let query = Array1::<f32>::from_vec(query.clone());
let query_broadcast = query.broadcast(centroids.dim()).unwrap();
@@ -551,19 +551,27 @@ file_indices[*idx / 1000].push(*idx % 1000 as usize);
file_indices[*idx / 1000].push(*idx % 1000 as usize);
}

let end = Instant::now();
println!("Time math: {:?}", end - start);


let start = Instant::now();

let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?;
let mut join_set = JoinSet::new();

let mut file_ids = vec![];
let mut futures = Vec::new();

for file_id in 0..readers.len() {

let mut reader = readers.remove(0);
if file_indices[file_id].len() == 0 {
continue;
}
let my_idx: Vec<usize> = file_indices[file_id].clone();
file_ids.push(file_id);

join_set.spawn(async move {

if my_idx.len() == 0 {
return None;
}
futures.push(tokio::spawn(async move {

let results = reader.read_usize_from_end(4).await.unwrap();

@@ -590,29 +598,62 @@ centroid_offsets.push(value);
centroid_offsets.push(value);
}

let mut this_result: Vec<Array1<u8>> = vec![];
let mut this_result: Vec<(usize, u64, u64)> = vec![];

for idx in my_idx.iter() {
let codes_and_plist = reader
.read_range(centroid_offsets[*idx], centroid_offsets[*idx + 1])
.await
.unwrap();
let arr = Array1::<u8>::from_vec(codes_and_plist.to_vec());
this_result.push(arr);
this_result.push((file_id, centroid_offsets[*idx], centroid_offsets[*idx + 1]));
}
Some((this_result, Array1::<u8>::from_vec(pq_bytes.to_vec())))
});
(this_result, Array1::<u8>::from_vec(pq_bytes.to_vec()))
}));
}

let mut result: Vec<Option<(Vec<Array1<u8>>, Array1<u8>)>> = vec![];
while let Some(res) = join_set.join_next().await {
let res = res.unwrap();
result.push(res);
let result: Vec<Result<(Vec<(usize, u64, u64)>, Array1<u8>),tokio::task::JoinError>> = futures::future::join_all(futures).await;
let result: Vec<(Vec<(usize, u64, u64)>, Array1<u8>)> = result.into_iter().map(|x| x.unwrap()).collect();

let pq_bytes: Vec<Array1<u8>> = result
.iter()
.map(|x| x.1.clone())
.collect::<Vec<_>>();

let end = Instant::now();
println!("Time stage 2 read: {:?}", end - start);

let start = Instant::now();

let mut readers_map: BTreeMap<usize, AsyncReader> = BTreeMap::new();
for file_id in file_ids.iter() {
let reader = get_reader(files[*file_id].clone(), reader_type.clone()).await.unwrap();
readers_map.insert(*file_id, reader);
}

join_set.shutdown().await;
let mut futures = Vec::new();
for i in 0 .. result.len() {
let to_read = result[i].0.clone();
for (file_id, start, end) in to_read.into_iter() {
// let file_name = files[file_id].clone();
// let my_reader_type = reader_type.clone();
// let mut reader = get_reader(file_name, my_reader_type).await.unwrap();
let start_time = Instant::now();
let mut reader = readers_map.get_mut(&file_id).unwrap().clone();
println!("Time to get reader {:?}", Instant::now() - start_time);
futures.push(tokio::spawn(async move {

let start_time = Instant::now();
let codes_and_plist = reader.read_range(start, end).await.unwrap();
println!("Time to read {:?}", Instant::now() - start_time);
(file_id, Array1::<u8>::from_vec(codes_and_plist.to_vec()))
}));
}
}

Ok(result)
let ranges: Vec<Result<(usize, Array1<u8>), tokio::task::JoinError>> = futures::future::join_all(futures).await;
let ranges: Vec<(usize, Array1<u8>)> = ranges.into_iter().map(|x| x.unwrap()).collect();


let end = Instant::now();
println!("Time stage 3 read: {:?}", end - start);

Ok((file_ids, pq_bytes, ranges))
}
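
The probe-selection math above (stack each file's 1000x128 centroid block, rank all centroids by distance to the query, then map each chosen flat index back to a file and a centroid within that file) reduces to a few lines of array code. An equivalent NumPy sketch, where the 1000-centroids-per-file and 128-dim figures come from the code above and everything else is illustrative:

    import numpy as np

    n_files, nprobes = 4, 8
    centroids = np.random.rand(n_files * 1000, 128).astype(np.float32)  # stacked per-file centroid blocks
    query = np.random.rand(128).astype(np.float32)

    # Distance from the query to every centroid, keep the nprobes closest.
    closest = np.argsort(np.linalg.norm(centroids - query, axis=1))[:nprobes]

    file_indices: dict[int, list[int]] = {}
    for idx in closest:
        # idx // 1000 selects the file, idx % 1000 the centroid within that file.
        file_indices.setdefault(int(idx) // 1000, []).append(int(idx) % 1000)
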

#[tokio::main]
