Commit 46c1d57: working vector search

marsupialtail committed Apr 4, 2024
1 parent 0d5a256 commit 46c1d57

Showing 6 changed files with 113 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -70,7 +70,7 @@ bitvector = "0.1.5"
ndarray = { version = "0.15.6", features = ["rayon", "serde"] }
numpy = "0.20.0"
num-traits = "0.2.18"

ordered-float = "4.2.0"

[profile.release]
lto = false
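Side note on the new `ordered-float` dependency: it appears to exist so that `f32` distances can be stored in a `BTreeSet` in `search.rs` below, since raw `f32` is only `PartialOrd`. A minimal, self-contained sketch of that pattern — the values here are made up:

```rust
use ordered_float::OrderedFloat;
use std::collections::BTreeSet;

fn main() {
    // (distance, file_id, vector_id) triples stay sorted by distance,
    // because OrderedFloat<f32> implements Ord (NaN sorts as the greatest value).
    let mut results: BTreeSet<(OrderedFloat<f32>, usize, usize)> = BTreeSet::new();
    results.insert((OrderedFloat(0.42_f32), 0, 17));
    results.insert((OrderedFloat(0.08_f32), 1, 3));

    // Iteration yields the closest match first.
    for (d, file_id, vec_id) in results.iter().take(1) {
        println!("best: d={} file={} vec={}", d.0, file_id, vec_id);
    }
}
```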
8 changes: 4 additions & 4 deletions python/rottnest/pele.py
@@ -211,17 +211,17 @@ def search_index_vector(indices: List[str], query: np.array, K: int):
assert len(metadata["column_name"].unique()) == 1, "index is not allowed to span multiple column names"
column_name = metadata["column_name"].unique()[0]

index_search_results = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], column_name, data_page_rows, uid_to_metadata, query, K)
print(index_search_results)
index_search_results, vectors = rottnest.search_lava_vector([f"{index_name}.lava" for index_name in indices], column_name, data_page_rows, uid_to_metadata, query, K)

import pdb; pdb.set_trace()
print(index_search_results)
print(vectors)

if len(index_search_results) == 0:
return None

uids = polars.from_dict({"file_id": [i[0] for i in index_search_results], "uid": [i[1] for i in index_search_results]})



metadata = metadata.join(uids, on = ["file_id", "uid"])

result = pyarrow.chunked_array(rottnest.read_indexed_pages(column_name, metadata["file_path"].to_list(), metadata["row_groups"].to_list(),
90 changes: 70 additions & 20 deletions src/formats/io.rs
@@ -3,10 +3,10 @@ use opendal::services::{Fs, S3};
use opendal::Operator;
use opendal::Reader;
use std::env;
use std::io::{Read,SeekFrom};
use std::io::{Read, SeekFrom};
use std::ops::{Deref, DerefMut};
use zstd::stream::read::Decoder;
use tokio::pin;
use zstd::stream::read::Decoder;

use tokio::io::{AsyncReadExt, AsyncSeekExt};

@@ -72,10 +72,15 @@ impl AsyncReader {
}

// theoretically we should try to return different types here, but Vec<u64> is def. the most common
pub async fn read_range_and_decompress(&mut self, from: u64, to: u64) -> Result<Vec<u64>, LavaError> {
pub async fn read_range_and_decompress(
&mut self,
from: u64,
to: u64,
) -> Result<Vec<u64>, LavaError> {
let compressed_posting_list_offsets = self.read_range(from, to).await?;
let mut decompressor = Decoder::new(&compressed_posting_list_offsets[..])?;
let mut serialized_posting_list_offsets: Vec<u8> = Vec::with_capacity(compressed_posting_list_offsets.len() as usize);
let mut serialized_posting_list_offsets: Vec<u8> =
Vec::with_capacity(compressed_posting_list_offsets.len() as usize);
decompressor.read_to_end(&mut serialized_posting_list_offsets)?;
let result: Vec<u64> = bincode::deserialize(&serialized_posting_list_offsets)?;
Ok(result)
@@ -112,7 +117,7 @@ impl From<&str> for S3Builder {
if let Ok(_value) = env::var("AWS_VIRTUAL_HOST_STYLE") {
builder.enable_virtual_host_style();
}

S3Builder(builder)
}
}
@@ -156,18 +161,63 @@ impl From<FsBuilder> for Operators {
}
}

pub fn get_operator_and_filename_from_file(file: String) -> (Operator, String) {
let mut operator = if file.starts_with("s3://") {
Operators::from(S3Builder::from(file.as_str())).into_inner()
} else {
let current_path = env::current_dir().unwrap();
Operators::from(FsBuilder::from(current_path.to_str().expect("no path"))).into_inner()
};

let filename = if file.starts_with("s3://") {
file[5..].split("/").collect::<Vec<&str>>().join("/")
} else {
file.to_string()
};
(operator, filename)
}
pub(crate) async fn get_file_sizes_and_readers(
files: &[String],
) -> Result<(Vec<usize>, Vec<AsyncReader>), LavaError> {
let tasks: Vec<_> = files
.iter()
.map(|file| {
let file = file.clone(); // Clone file name to move into the async block
tokio::spawn(async move {
// Determine the operator based on the file scheme
let operator = if file.starts_with("s3://") {
Operators::from(S3Builder::from(file.as_str())).into_inner()
} else {
let current_path = env::current_dir()?;
Operators::from(FsBuilder::from(current_path.to_str().expect("no path")))
.into_inner()
};

// Extract filename
let filename = if file.starts_with("s3://") {
file[5..].split('/').collect::<Vec<_>>()[1..].join("/")
} else {
file.clone()
};

// Create the reader
let reader: AsyncReader = operator
.clone()
.reader_with(&filename)
.buffer(READER_BUFFER_SIZE)
.await?
.into();

// Get the file size
let file_size: u64 = operator.stat(&filename).await?.content_length();

Ok::<_, LavaError>((file_size as usize, reader))
})
})
.collect();

// Wait for all tasks to complete
let results = futures::future::join_all(tasks).await;

// Process results, separating out file sizes and readers
let mut file_sizes = Vec::new();
let mut readers = Vec::new();

for result in results {
match result {
Ok(Ok((size, reader))) => {
file_sizes.push(size);
readers.push(reader);
}
Ok(Err(e)) => return Err(e), // Handle error from inner task
Err(e) => return Err(LavaError::Parse("Task join error: {}".to_string())), // Handle join error
}
}

Ok((file_sizes, readers))
}
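A usage sketch for the relocated `get_file_sizes_and_readers` helper, matching how `search.rs` calls it below. This assumes a crate-internal async context (the function is `pub(crate)`); the file names are illustrative:

```rust
// Sketch only: assumes this runs inside the crate on a Tokio runtime.
use crate::formats::io::get_file_sizes_and_readers;
use crate::lava::error::LavaError;

async fn open_indices() -> Result<(), LavaError> {
    let files = vec![
        "index_0.lava".to_string(),                // resolved against the current directory
        "s3://my-bucket/index_1.lava".to_string(), // hypothetical bucket/key
    ];

    // Files are opened concurrently (one tokio::spawn per file); the result
    // pairs each file's size in bytes with a buffered AsyncReader.
    let (file_sizes, readers) = get_file_sizes_and_readers(&files).await?;
    assert_eq!(file_sizes.len(), readers.len());

    for (i, size) in file_sizes.iter().enumerate() {
        println!("file {} is {} bytes", files[i], size);
    }
    Ok(())
}
```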
103 changes: 31 additions & 72 deletions src/lava/search.rs
@@ -1,33 +1,27 @@
use bytes::Bytes;
use core::num;
use futures::{FutureExt, SinkExt};
use itertools::Itertools;
use ndarray::Array2;
use std::collections::BTreeSet;
use std::env;
use std::{
collections::{HashMap, HashSet},
io::{BufRead, BufReader, Cursor, Read, SeekFrom},
};
use tokio::task::JoinSet;
use zstd::stream::read::Decoder;

use crate::formats::parquet::read_indexed_pages;
use crate::lava::constants::*;
use crate::lava::fm_chunk::FMChunk;
use crate::vamana::vamana::{Distance, VectorAccessMethod};
use crate::vamana::{access::ReaderAccessMethodF32, EuclideanF32, IndexParams, VamanaIndex};
use crate::{formats::io::READER_BUFFER_SIZE, lava::plist::PListChunk};
use crate::{
formats::io::{AsyncReader, FsBuilder, Operators, S3Builder},
formats::io::{get_file_sizes_and_readers, AsyncReader, FsBuilder, Operators, S3Builder},
lava::error::LavaError,
};

use futures::future::{AbortHandle, Abortable, Aborted, Join};
use std::sync::Arc;
use std::time::Duration;
use tokenizers::tokenizer::Tokenizer;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio::sync::{self, mpsc, Mutex};

use ordered_float::OrderedFloat;

async fn get_tokenizer_async(
mut readers: Vec<AsyncReader>,
@@ -343,8 +337,9 @@ async fn search_vector_async(
uid_to_metadatas: &Vec<Vec<(String, usize, usize, usize, usize)>>,
query: &Vec<f32>,
k: usize,
) -> Result<Vec<usize>, LavaError> {
let mut results: Vec<usize> = vec![];
) -> Result<(Vec<(usize, usize)>, Array2<f32>), LavaError> {
let mut results: BTreeSet<(OrderedFloat<f32>, usize, usize)> = BTreeSet::new();
let mut reader_access_methods: Vec<ReaderAccessMethodF32> = vec![];

for i in 0..readers.len() {
let num_points = readers[i].read_u64_le().await?;
@@ -365,6 +360,7 @@ async fn search_vector_async(
// uid to (file_path, row_group, page_offset, page_size, dict_page_size)
uid_to_metadata: &uid_to_metadatas[i],
};
reader_access_methods.push(reader_access_method.clone());

// we probably want to serialize and deserialize the indexparams too
// upon merging if they are not the same throw an error
@@ -382,73 +378,36 @@ async fn search_vector_async(

let mut ctx = index.get_search_context();
let _ = index.search(&mut ctx, query.as_slice()).await;
let mut local_results: Vec<usize> = ctx.frontier.iter().map(|(v, _d)| *v).collect();
let local_results: Vec<(OrderedFloat<f32>, usize, usize)> = ctx
.frontier
.iter()
.map(|(v, d)| (OrderedFloat(*d as f32), i, *v))
.collect();

results.append(&mut local_results);
results.extend(local_results);
}

Ok(results)
}

async fn get_file_sizes_and_readers(
files: &[String],
) -> Result<(Vec<usize>, Vec<AsyncReader>), LavaError> {
let tasks: Vec<_> = files
let results: Vec<(usize, usize)> = results
.iter()
.map(|file| {
let file = file.clone(); // Clone file name to move into the async block
tokio::spawn(async move {
// Determine the operator based on the file scheme
let operator = if file.starts_with("s3://") {
Operators::from(S3Builder::from(file.as_str())).into_inner()
} else {
let current_path = env::current_dir()?;
Operators::from(FsBuilder::from(current_path.to_str().expect("no path")))
.into_inner()
};

// Extract filename
let filename = if file.starts_with("s3://") {
file[5..].split('/').collect::<Vec<_>>()[1..].join("/")
} else {
file.clone()
};

// Create the reader
let reader: AsyncReader = operator
.clone()
.reader_with(&filename)
.buffer(READER_BUFFER_SIZE)
.await?
.into();

// Get the file size
let file_size: u64 = operator.stat(&filename).await?.content_length();

Ok::<_, LavaError>((file_size as usize, reader))
})
})
.take(k)
.cloned()
.map(|(_v, i, d)| (i, d))
.collect();

// Wait for all tasks to complete
let results = futures::future::join_all(tasks).await;

// Process results, separating out file sizes and readers
let mut file_sizes = Vec::new();
let mut readers = Vec::new();
let futures: Vec<_> = results
.iter()
.map(|(file_id, n)| reader_access_methods[*file_id].get_vec_async(*n))
.collect();

for result in results {
match result {
Ok(Ok((size, reader))) => {
file_sizes.push(size);
readers.push(reader);
}
Ok(Err(e)) => return Err(e), // Handle error from inner task
Err(e) => return Err(LavaError::Parse("Task join error: {}".to_string())), // Handle join error
}
}
let vectors: Vec<Result<Vec<f32>, LavaError>> = futures::future::join_all(futures).await;
let vectors: Result<Vec<Vec<f32>>, LavaError> = vectors.into_iter().collect();
let vectors: Vec<Vec<f32>> = vectors?;
let rows = vectors.len();
let cols = vectors[0].len();
let vectors: Vec<f32> = vectors.into_iter().flatten().collect();
let vectors = Array2::from_shape_vec((rows, cols), vectors).unwrap();

Ok((file_sizes, readers))
Ok((results, vectors))
}

#[tokio::main]
@@ -520,7 +479,7 @@ pub async fn search_lava_vector(
uid_to_metadatas: &Vec<Vec<(String, usize, usize, usize, usize)>>,
query: &Vec<f32>,
k: usize,
) -> Result<Vec<usize>, LavaError> {
) -> Result<(Vec<(usize, usize)>, Array2<f32>), LavaError> {
let (file_sizes, readers) = get_file_sizes_and_readers(&files).await?;
search_vector_async(
column_name,
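The search path now returns the global top-k as `(file_id, vector_id)` pairs plus the matching vectors packed into one `Array2<f32>`. The following is a standalone sketch of that merge-and-pack step, not the committed implementation; `pack_top_k` and `vectors_by_id` are hypothetical names standing in for the per-file Vamana results and the vector fetch path:

```rust
use ndarray::Array2;
use ordered_float::OrderedFloat;
use std::collections::BTreeSet;

fn pack_top_k(
    per_file_candidates: Vec<Vec<(f32, usize)>>,        // (distance, vector_id) per file
    vectors_by_id: impl Fn(usize, usize) -> Vec<f32>,   // hypothetical lookup: (file_id, vector_id) -> vector
    k: usize,
) -> (Vec<(usize, usize)>, Array2<f32>) {
    // Merge every file's candidates into one set ordered by distance.
    let mut merged: BTreeSet<(OrderedFloat<f32>, usize, usize)> = BTreeSet::new();
    for (file_id, candidates) in per_file_candidates.into_iter().enumerate() {
        for (d, vec_id) in candidates {
            merged.insert((OrderedFloat(d), file_id, vec_id));
        }
    }

    // Keep the k closest (file_id, vector_id) pairs.
    let top_k: Vec<(usize, usize)> = merged.iter().take(k).map(|(_d, f, v)| (*f, *v)).collect();

    // Fetch the corresponding vectors and flatten them into a (k, dim) matrix.
    // Assumes at least one candidate survives.
    let rows: Vec<Vec<f32>> = top_k.iter().map(|(f, v)| vectors_by_id(*f, *v)).collect();
    let dim = rows[0].len();
    let flat: Vec<f32> = rows.into_iter().flatten().collect();
    let matrix = Array2::from_shape_vec((top_k.len(), dim), flat).unwrap();

    (top_k, matrix)
}
```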
9 changes: 6 additions & 3 deletions src/lava_py/lava.rs
@@ -9,6 +9,7 @@ use ndarray::{Array2, ArrayD, Ix2};
use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::types::PyBytes;
use pyo3::IntoPy;
use pyo3::Py;

#[pyfunction]
pub fn search_lava_bm25(
@@ -40,10 +41,12 @@ pub fn search_lava_vector(
uid_to_metadatas: Vec<Vec<(String, usize, usize, usize, usize)>>,
query: Vec<f32>,
k: usize,
) -> Result<Vec<usize>, LavaError> {
py.allow_threads(|| {
) -> Result<(Vec<(usize, usize)>, Py<PyArray2<f32>>), LavaError> {
let (metadata, array) = py.allow_threads(|| {
lava::search_lava_vector(files, column_name, &uid_nrows, &uid_to_metadatas, &query, k)
})
})?;

Ok((metadata, array.into_pyarray(py).to_owned()))
}

#[pyfunction]
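The Python binding now materializes the result matrix as a NumPy array. A small sketch of the same pattern — compute without the GIL via `allow_threads`, then convert with `into_pyarray` once the GIL token is back in scope; `demo_vectors` is a hypothetical function, not part of the crate:

```rust
use ndarray::Array2;
use numpy::{IntoPyArray, PyArray2};
use pyo3::prelude::*;

/// Hypothetical helper mirroring the binding's pattern: do the heavy work
/// without the GIL, then hand the ndarray result to Python as a NumPy array.
#[pyfunction]
fn demo_vectors(py: Python<'_>) -> PyResult<Py<PyArray2<f32>>> {
    // Compute while other Python threads can keep running.
    let array: Array2<f32> = py.allow_threads(|| Array2::zeros((4, 3)));

    // into_pyarray needs the GIL token, so it runs after allow_threads returns.
    Ok(array.into_pyarray(py).to_owned())
}
```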
2 changes: 1 addition & 1 deletion src/vamana/access.rs
@@ -39,6 +39,7 @@ pub struct InMemoryAccessMethodF32 {
pub data: Array2<f32>,
}

#[derive(Clone)]
pub struct ReaderAccessMethodF32<'a> {
pub dim: usize,
pub num_points: usize,
@@ -49,7 +50,6 @@ pub struct ReaderAccessMethodF32<'a> {
}

impl VectorAccessMethod<f32> for ReaderAccessMethodF32<'_> {

fn get_vec<'a>(&'a self, idx: usize) -> &'a [f32] {
unimplemented!("get_vec not implemented for ReaderAccessMethodF32")
}
