Commit

about done with natural language index

marsupialtail committed Feb 14, 2024
1 parent c46fc46 commit 0607bc1
Showing 6 changed files with 90 additions and 26 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,20 @@
name: Rust Build with Maturin

on: [push, pull_request]

jobs:
build:
name: Build Rust Project
runs-on: ubuntu-latest

container:
image: ghcr.io/pyo3/maturin:latest

steps:
- uses: actions/checkout@v2
name: Checkout code

- name: Build with Maturin
run: |
maturin build --release --features py
4 changes: 3 additions & 1 deletion src/formats/parquet.rs
@@ -368,7 +368,9 @@ pub async fn get_parquet_layout(
let mut column_chunk_pages: Vec<parquet::column::page::Page> = Vec::new();

let end = end - start;
let column_chunk_offset = start;
start = 0;


while start != end {
// this takes a slice of the entire thing for each page, granted it won't read the entire thing,
@@ -403,7 +405,7 @@ pub async fn get_parquet_layout(
parquet_layout
.data_page_sizes
.push(compressed_page_size as usize + header_len);
parquet_layout.data_page_offsets.push(start as usize);
parquet_layout.data_page_offsets.push((column_chunk_offset + start) as usize);

parquet_layout
.dictionary_page_sizes
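The parquet.rs change above saves the column chunk's absolute position in column_chunk_offset before start is reset to walk the sliced chunk buffer, so the values pushed into data_page_offsets become positions in the whole file rather than positions relative to the chunk. A minimal sketch of that arithmetic, with made-up numbers and illustrative names (it ignores the page header length that the real code also adds):

    # Illustrative only: absolute page offset = chunk start + offset within the chunk.
    column_chunk_offset = 4096            # byte position of the column chunk in the file
    page_sizes = [700, 650, 720]          # compressed page sizes inside the chunk

    start = 0                             # cursor within the sliced chunk buffer
    data_page_offsets = []
    for size in page_sizes:
        data_page_offsets.append(column_chunk_offset + start)   # the fix: add the chunk start back
        start += size

    print(data_page_offsets)              # [4096, 4796, 5446]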
13 changes: 8 additions & 5 deletions src/lava/build.rs
@@ -157,20 +157,22 @@ pub fn build_lava_natural_language(

// Handle the compressed data (for example, saving to a file or sending over a network)
println!(
"Compressed term dictionary length: {}",
compressed_term_dictionary.len()
"Compressed term dictionary size: {} len: {}",
compressed_term_dictionary.len(),
inverted_index.len()
);

let mut plist_offsets: Vec<u64> = vec![0];
let mut plist_elems: Vec<u64> = vec![0];
let mut plist = PList::new()?;
let mut counter: u64 = 0;

for (_, value) in inverted_index.iter() {
for (key, value) in inverted_index.iter() {
// this usually saves around 20% of the space. Don't remember things that happen more than 1/4 of the time.
// but let's not do this because it makes everything else more complicated

let mut value_all = BTreeSet::new();
value_all.insert(u64::MAX);
value_all.insert(0u64);

let value_vec = if value.len() <= (num_unique_uids / 4) as usize {
//@Rain can we get rid of this clone
@@ -184,7 +186,6 @@ pub fn build_lava_natural_language(
counter += 1;

// value_vec.sort();
// println!("{}", key);
let written = plist.add_plist(value_vec)?;
if written > 1024 * 1024 || counter == inverted_index.len() as u64 {
let bytes = plist.finalize_compression()?;
@@ -196,6 +197,8 @@
}




plist_offsets.append(&mut plist_elems);

let compressed_term_dict_offset = file.seek(SeekFrom::Current(0))?;
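For context on the build.rs loop: each term's posting list is appended to an in-memory PList, and once the written size passes roughly 1 MB (or the last term is reached, which is what the counter comparison against inverted_index.len() checks) the chunk is compressed and flushed, with plist_offsets recording chunk byte offsets and plist_elems recording how many posting lists each chunk holds. A rough sketch of that chunking; the helper names below are hypothetical, not the crate's API:

    CHUNK_LIMIT = 1024 * 1024

    def write_posting_lists(posting_lists, compress, out):
        # posting_lists: one list of uids per term, in term-dictionary order
        plist_offsets, plist_elems = [0], [0]
        chunk, written = [], 0
        for counter, plist in enumerate(posting_lists, start=1):
            chunk.append(plist)
            written += 8 * len(plist)                      # crude size estimate, 8 bytes per uid
            if written > CHUNK_LIMIT or counter == len(posting_lists):
                blob = compress(chunk)                     # e.g. zstd over the serialized chunk
                out.write(blob)
                plist_offsets.append(plist_offsets[-1] + len(blob))
                plist_elems.append(counter)
                chunk, written = [], 0
        # build.rs appends plist_elems after plist_offsets in the .lava file
        return plist_offsets + plist_elems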
14 changes: 9 additions & 5 deletions src/lava/merge.rs
@@ -61,7 +61,8 @@ impl PListChunkIterator {
}
let mut buffer3: Vec<u8> = vec![0u8; (self.plist_offsets[self.current_chunk_offset + 1] - self.plist_offsets[self.current_chunk_offset]) as usize];
self.reader.read(&mut buffer3).await?;
self.current_chunk = PList::search_compressed(buffer3, (0..self.plist_elems[self.current_chunk_offset + 1]).collect()).unwrap();
self.current_chunk = PList::search_compressed(buffer3,
(0..(self.plist_elems[self.current_chunk_offset + 1] - self.plist_elems[self.current_chunk_offset])).collect()).unwrap();
}

Ok(())
@@ -73,6 +74,7 @@ async fn hoa(
condensed_lava_file: &str,
operator: &mut Operator,
lava_files: Vec<Cow<str>>,
uid_offsets: Vec<u64>
) -> Result<(), LavaError> // hawaiian for lava condensation
{
// instantiate a list of readers from lava_files
@@ -134,7 +136,6 @@ async fn hoa(

let mut term_dictionary = String::new();
let mut current_lines: Vec<Option<String>> = vec![None; decompressed_term_dictionaries.len()];
let mut plist_cursor: Vec<u64> = vec![0; decompressed_term_dictionaries.len()];

// Initialize the current line for each reader
for (i, reader) in decompressed_term_dictionaries.iter_mut().enumerate() {
@@ -172,9 +173,9 @@ async fn hoa(
// we need to read and decompress the plists

let this_plist: Vec<u64> = plist_chunk_iterators[i].get_current();

// println!("{:?} {:?}", this_plist, uid_offsets[i]);
for item in this_plist {
plist.insert(item);
plist.insert(item + uid_offsets[i]);
}

let _ = plist_chunk_iterators[i].increase_cursor().await;
@@ -209,6 +210,8 @@ async fn hoa(
let compressed_term_dictionary = encode_all(bytes, 0).expect("Compression failed");
let compressed_term_dict_offset = output_file.seek(SeekFrom::Current(0))?;

println!("merged compress dict size {:?} len {:?}", compressed_term_dictionary.len(), counter);

output_file.write_all(&compressed_term_dictionary)?;

let compressed_plist_offsets_offset = output_file.seek(SeekFrom::Current(0))?;
@@ -226,6 +229,7 @@
pub fn merge_lava(
condensed_lava_file: Cow<str>,
lava_files: Vec<Cow<str>>,
uid_offsets: Vec<u64>
) -> Result<(), LavaError> {
// you should only merge them on local disk. It's not worth random accessing S3 for this because of the request costs.
// worry about running out of disk later. Assume you have a fast SSD for now.
@@ -234,5 +238,5 @@
builder.root(current_path.to_str().expect("no path"));
let mut operator = Operator::new(builder)?.finish();

hoa(condensed_lava_file.as_ref(), &mut operator, lava_files)
hoa(condensed_lava_file.as_ref(), &mut operator, lava_files, uid_offsets)
}
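The substantive change in merge.rs is the new uid_offsets argument: while the term dictionaries are merged, every posting-list entry coming from the i-th input index is shifted by uid_offsets[i], so uids from different source files land in disjoint ranges of the merged index. A small sketch of that remapping under those assumptions (illustrative names, not the crate's API):

    from collections import defaultdict

    def merge_posting_lists(indexes, uid_offsets):
        # indexes: one dict per input .lava file, mapping term -> sorted local uids
        merged = defaultdict(set)
        for index, offset in zip(indexes, uid_offsets):
            for term, uids in index.items():
                # the plist.insert(item + uid_offsets[i]) step above
                merged[term].update(uid + offset for uid in uids)
        return {term: sorted(merged[term]) for term in sorted(merged)}

    # Two hypothetical inputs of ten pages each, so the second file is shifted by 10.
    first  = {"hello": [1, 3], "world": [2]}
    second = {"hello": [1], "brutes": [4]}
    print(merge_posting_lists([first, second], [0, 10]))
    # {'brutes': [14], 'hello': [1, 3, 11], 'world': [2]}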
3 changes: 3 additions & 0 deletions src/lava_py/lava.rs
@@ -14,10 +14,13 @@ pub fn search_lava(file: &str, query: &str) -> Result<Vec<u64>, LavaError> {
pub fn merge_lava(
condensed_lava_file: &PyString,
lava_files: Vec<&PyString>,
uid_offsets: Vec<u64>
) -> Result<(), LavaError> {
println!("{:?}", uid_offsets);
lava::merge_lava(
condensed_lava_file.to_string_lossy(),
lava_files.iter().map(|x| x.to_string_lossy()).collect(),
uid_offsets
)
}

62 changes: 47 additions & 15 deletions test.py
@@ -2,6 +2,7 @@
import polars
import rottnest_rs
import pyarrow.parquet as pq
import numpy as np

def basic_test():
a = polars.from_dict({"a":["你是一只小猪","hello you are happy", "hello, et tu, brutes?"]}).to_arrow()
@@ -37,43 +38,72 @@ def merge_test():

print("search result", rottnest_rs.search_lava("1.lava", "d"))

rottnest_rs.merge_lava("merged.lava", ["1.lava", "2.lava", "3.lava"])
rottnest_rs.merge_lava("merged.lava", ["1.lava", "2.lava", "3.lava"], [0,10,20])

assert rottnest_rs.search_lava("merged.lava", "d") == [2]
assert rottnest_rs.search_lava("merged.lava", "f") == [5,20] # the second one will be 20 because short uid list

# basic_test()
# merge_test()

from typing import List, Optional
import numpy as np
import uuid

def index_file_natural_language(file_path: str, column_name: str, name: Optional[str] = None):

arr, layout = rottnest_rs.get_parquet_layout(column_name, file_path)
data_page_num_rows = np.array(layout.data_page_num_rows)
uid = np.repeat(np.arange(len(data_page_num_rows)), data_page_num_rows)
uid = np.repeat(np.arange(len(data_page_num_rows)), data_page_num_rows) + 1

file_data = polars.from_dict({
"uid": np.arange(len(data_page_num_rows)),
"data_page_offsets": layout.data_page_offsets,
"data_page_sizes": layout.data_page_sizes,
"dictionary_page_sizes": layout.dictionary_page_sizes,
"row_groups": np.repeat(np.arange(layout.num_row_groups), layout.row_group_data_pages),
"uid": np.arange(len(data_page_num_rows) + 1),
"file_path": [file_path] * (len(data_page_num_rows) + 1),
"column_name": [column_name] * (len(data_page_num_rows) + 1),
"data_page_offsets": [-1] + layout.data_page_offsets,
"data_page_sizes": [-1] + layout.data_page_sizes,
"dictionary_page_sizes": [-1] + layout.dictionary_page_sizes,
"row_groups": np.hstack([[-1] , np.repeat(np.arange(layout.num_row_groups), layout.row_group_data_pages)]),
}
)

name = uuid.uuid4().hex if name is None else name

file_data.write_parquet(f"{name}.parquet")
file_data.write_parquet(f"{name}.meta")
print(rottnest_rs.build_lava_natural_language(f"{name}.lava", arr, pyarrow.array(uid.astype(np.uint64))))

def search_index_natural_language(metadata_path: str, index_path: str, query):
def merge_index_natural_language(new_index_name: str, index_names: List[str]):
assert len(index_names) > 1

# first read the metadata files and merge those
metadatas = [polars.read_parquet(f"{name}.meta") for name in index_names]
metadata_lens = [len(metadata) for metadata in metadatas]
offsets = np.cumsum([0] + metadata_lens)[:-1]
metadatas = [metadata.with_columns(polars.col("uid") + offsets[i]) for i, metadata in enumerate(metadatas)]

rottnest_rs.merge_lava(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets)
polars.concat(metadatas).write_parquet(f"{new_index_name}.meta")

uids = rottnest_rs.search_lava(index_path, query)
def search_index_natural_language(index_name, query, mode = "exact"):

assert mode in {"exact", "substring"}

metadata_file = f"{index_name}.meta"
index_file = f"{index_name}.lava"
uids = polars.from_dict({"uid":rottnest_rs.search_lava(index_file, query if mode == "substring" else f"^{query}$")})

print(uids)
# rottnest_rs.search_indexed_pages(query, )
if len(uids) == 0:
return

metadata = polars.read_parquet(metadata_file).join(uids, on = "uid")
assert len(metadata["column_name"].unique()) == 1, "index is not allowed to span multiple column names"

# now we need to do something special about -1 values that indicate we have to search the entire file

column_name = metadata["column_name"].unique()[0]
result = rottnest_rs.search_indexed_pages(query, column_name, metadata["file_path"].to_list(), metadata["row_groups"].to_list(),
metadata["data_page_offsets"].to_list(), metadata["data_page_sizes"].to_list(), metadata["dictionary_page_sizes"].to_list())
print([item.matched for item in result])



@@ -94,9 +124,11 @@ def search_index_natural_language(metadata_path: str, index_path: str, query):

# index_name = "bump1"
# index_file_natural_language("data/part-03060-21668627-949b-4858-97ce-a4b0f4fc2df4-c000.gz.parquet","text", name = index_name)
# search_index_natural_language(f"{index_name}.parquet", f"{index_name}.lava", "C1X")
# search_index_natural_language(index_name, "Publish")
# index_name = "bump2"
# index_file_natural_language("data/part-03062-21668627-949b-4858-97ce-a4b0f4fc2df4-c000.gz.parquet","text", name = index_name)
# search_index_natural_language(f"{index_name}.parquet", f"{index_name}.lava", "C1X")
# search_index_natural_language(index_name, "C1X")

rottnest_rs.merge_lava("merged.lava", ["bump1.lava", "bump2.lava"])
# merge_index_natural_language("merged", ["bump1", "bump2"])
search_index_natural_language("merged", "disparity")
# search_index_natural_language("bump2", "Eric")
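One note on merge_index_natural_language above: the per-file uid offsets are the cumulative sum of the metadata row counts, and the same offsets are used both to shift each file's uid column and as the uid_offsets passed to rottnest_rs.merge_lava, which keeps the concatenated .meta rows and the merged .lava posting lists aligned. A worked example with made-up lengths:

    import numpy as np

    metadata_lens = [6, 4, 9]                      # rows in each .meta file, including the uid-0 whole-file row
    offsets = np.cumsum([0] + metadata_lens)[:-1]  # array([ 0,  6, 10])
    # file 1 keeps uids 0..5, file 2's uids become 6..9, file 3's become 10..18;
    # handing the same offsets to merge_lava shifts the posting lists in lockstep.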
