more vector index stuff
marsupialtail committed Jun 8, 2024
1 parent c66e217 commit d77fdc9
Showing 2 changed files with 42 additions and 9 deletions.
42 changes: 39 additions & 3 deletions python/rottnest/pele.py
@@ -167,7 +167,10 @@ def index_file_vector(file_path: str, column_name: str, name = uuid.uuid4().hex,
     dim = diffs.item() // dtype_size
     x = np.frombuffer(buffers[2], dtype = np.float32).reshape(len(arr), dim)
 
-    kmeans = faiss.Kmeans(128, len(arr) // 10_000, niter=30, verbose=True, gpu = gpu)
+    num_centroids = 1000 # len(arr) // 10_000
+
+    # kmeans = faiss.Kmeans(128, len(arr) // 10_000, niter=30, verbose=True, gpu = gpu)
+    kmeans = faiss.Kmeans(128, num_centroids, niter=30, verbose=True, gpu = gpu)
     kmeans.train(x)
     centroids = kmeans.centroids

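The hunk above swaps a dataset-scaled centroid count (len(arr) // 10_000) for a fixed 1,000. A minimal sketch of this faiss k-means training step, with stand-in data since the surrounding arr/buffers setup is not shown here:

import faiss
import numpy as np

# stand-in vectors; in pele.py, x comes from the Arrow buffers above
x = np.random.rand(50_000, 128).astype(np.float32)
num_centroids = 1000  # fixed, rather than len(arr) // 10_000
kmeans = faiss.Kmeans(128, num_centroids, niter=30, verbose=True, gpu=False)
kmeans.train(x)
centroids = kmeans.centroids  # (num_centroids, 128) cluster centers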
@@ -177,8 +177,8 @@ def index_file_vector(file_path: str, column_name: str, name = uuid.uuid4().hex,
 
     batch_size = 10_000
 
-    posting_lists = [[] for _ in range(len(arr) // 10_000)]
-    codes_lists = [[] for _ in range(len(arr) // 10_000)]
+    posting_lists = [[] for _ in range(num_centroids)]
+    codes_lists = [[] for _ in range(num_centroids)]
 
     for i in tqdm(range(len(arr) // batch_size)):
         batch = x[i * batch_size:(i + 1) * batch_size]
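The loop body is cut off by the diff. One plausible continuation, sketched under the assumption that each vector is assigned to its nearest trained centroid through kmeans.index (the posting lists sized num_centroids above make this an IVF-style build); codes_lists would be filled analogously:

for i in tqdm(range(len(arr) // batch_size)):
    batch = x[i * batch_size:(i + 1) * batch_size]
    # nearest-centroid lookup; faiss.Kmeans exposes the trained
    # centroids as a flat index on kmeans.index
    _, assignments = kmeans.index.search(batch, 1)
    for j, centroid_id in enumerate(assignments.ravel()):
        # file the global row id under its centroid's posting list
        posting_lists[centroid_id].append(i * batch_size + j)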
@@ -268,8 +271,43 @@ def merge_index_substring(new_index_name: str, index_names: List[str]):
 
+def merge_index_vector(new_index_name: str, index_names: List[str]):
+
+    try:
+        import faiss
+        import zstandard as zstd
+    except ImportError:
+        print("faiss and zstandard are not installed")
+        return
+
+    def read_range(file_handle, start, end):
+        # read the byte range [start, end) from an open file handle
+        file_handle.seek(start)
+        return file_handle.read(end - start)
+
+    offsets = merge_metadatas(new_index_name, index_names)
+
+    # assume things are on disk for now
+    assert len(index_names) == 2
+
+    index1 = open(f"{index_names[0]}.lava", "rb")
+    index2 = open(f"{index_names[1]}.lava", "rb")
+    output = open(f"{new_index_name}.lava", "wb")
+
+    # seek(0, 2) returns the file size; the footer is 8 * 4 bytes of
+    # uint64 section offsets, which must be decoded, not indexed as raw bytes
+    index1_size = index1.seek(0, 2)
+    index2_size = index2.seek(0, 2)
+    offsets1 = np.frombuffer(read_range(index1, index1_size - 8 * 4, index1_size), dtype=np.uint64)
+    offsets2 = np.frombuffer(read_range(index2, index2_size - 8 * 4, index2_size), dtype=np.uint64)
+
+    output.write(read_range(index1, 0, offsets1[0]))
+    unprocessed = read_range(index2, 0, offsets2[0])
+    while unprocessed:
+        # each chunk is prefixed with a uint32 byte length; consuming the
+        # chunks is still unimplemented here (work in progress)
+        l = np.frombuffer(unprocessed[:4], np.uint32).item()
+
+    decompressor = zstd.ZstdDecompressor()
+
 
 def get_result_from_index_result(metadata: polars.DataFrame, index_search_results: list):
 
     uids = polars.from_dict({"file_id": [i[0] for i in index_search_results], "uid": [i[1] for i in index_search_results]})
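The merge body above is plainly unfinished: the while loop never consumes unprocessed. Below is a self-contained sketch of the file layout it implies, namely a footer of four uint64 section offsets and uint32 length-prefixed zstd chunks. The helper names, and the assumption that each frame records its content size, are mine rather than confirmed by the source:

import numpy as np
import zstandard as zstd

def read_footer_offsets(path: str) -> np.ndarray:
    # assumed layout: the last 8 * 4 bytes are four uint64 section offsets
    with open(path, "rb") as f:
        size = f.seek(0, 2)
        f.seek(size - 8 * 4)
        return np.frombuffer(f.read(8 * 4), dtype=np.uint64)

def iter_chunks(buf: bytes):
    # walk uint32 length-prefixed zstd frames, consuming the buffer as we go;
    # one-shot decompress() requires the frame to embed its content size
    decompressor = zstd.ZstdDecompressor()
    while buf:
        l = np.frombuffer(buf[:4], np.uint32).item()
        yield decompressor.decompress(buf[4:4 + l])
        buf = buf[4 + l:]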
9 changes: 3 additions & 6 deletions src/lava/merge.rs
@@ -1,3 +1,4 @@
+use arrow::datatypes::ToByteSlice;
 use async_recursion::async_recursion;
 use bincode;
 use bit_vec::BitVec;
@@ -189,6 +190,7 @@ impl PListChunkIterator {
     }
 }
 
+
 async fn merge_lava_uuid(
     condensed_lava_file: &str,
     lava_files: Vec<String>,
@@ -431,13 +433,8 @@ async fn compute_interleave(
 
         // interleave_iterations += 1;
         // println!(
-        //     "{} {} ",
+        //     "{} ",
         //     interleave_iterations,
-        //     interleave
-        //         .iter()
-        //         .zip(new_interleave.iter())
-        //         .filter(|&(a_bit, b_bit)| a_bit != b_bit)
-        //         .count()
         // );
 
         if new_interleave == interleave {
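The surviving context shows compute_interleave iterating until two successive interleave bit-vectors agree, that is, until a fixed point, which is the usual shape of a BWT-merge interleave computation. A Python sketch of just the control flow (names are illustrative, not the Rust function's signature):

def fixed_point(step, interleave):
    # recompute the interleave until a full pass leaves it unchanged
    while True:
        new_interleave = step(interleave)
        if new_interleave == interleave:
            return interleave
        interleave = new_interleave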