merge cache ranges, refactor python

marsupialtail committed Jun 12, 2024
1 parent 2cacc12 commit 48bac29
Showing 7 changed files with 219 additions and 232 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -81,6 +81,7 @@ numpy = "0.21.0"
num-traits = "0.2.18"
ordered-float = "4.2.0"
reqwest = "0.12.4"

[profile.release]
lto = false
bit-vec = "0.6.3"
22 changes: 11 additions & 11 deletions demo.py
@@ -6,20 +6,20 @@
# result = rottnest.search_index_bm25(["merged_index"], "cell phones", K = 10,query_expansion = "openai", reader_type = "aws")
# print(result)

# rottnest.index_file_substring("example_data/real.parquet", "text", "index0", token_skip_factor = 3)
rottnest.index_file_substring("example_data/real.parquet", "text", "index0", token_skip_factor = 10)
# rottnest.index_file_substring("example_data/1.parquet", "body", "index1")
# rottnest.merge_index_substring("merged_index", ["index0", "index1"])
# result = rottnest.search_index_substring(["index0"], "did fake", K = 10)
# print(result)

# rottnest.index_file_uuid("a.parquet", "hashes", "index0")
# rottnest.index_file_uuid("b.parquet", "hashes", "index1")
# rottnest.index_file_uuid("uuid_data/a.parquet", "hashes", "index0")
# rottnest.index_file_uuid("uuid_data/b.parquet", "hashes", "index1")
# rottnest.merge_index_uuid("merged_index", ["index0", "index1"])
result = rottnest.search_index_uuid(["index1"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10)
print(result)
result = rottnest.search_index_uuid(["merged_index"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10)
print(result)
result = rottnest.search_index_uuid(["index0", "index1"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10)
print(result)
result = rottnest.search_index_uuid(["merged_index"], "32b8fd4d808300b97b2dff451cba4185faee842a1248c84c1ab544632957eb8904dccb5880f0d4a9a7317c3a4490b0222e4deb5047abc1788665a46176009a07", K = 10)
print(result)
# result = rottnest.search_index_uuid(["index1"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10)
# print(result)
# result = rottnest.search_index_uuid(["merged_index"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10)
# print(result)
# result = rottnest.search_index_uuid(["index0", "index1"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10)
# print(result)
# result = rottnest.search_index_uuid(["merged_index"], "32b8fd4d808300b97b2dff451cba4185faee842a1248c84c1ab544632957eb8904dccb5880f0d4a9a7317c3a4490b0222e4deb5047abc1788665a46176009a07", K = 10)
# print(result)
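The demo now points the uuid examples at uuid_data/ and exercises the merged index. For reference, a condensed sketch of the same flow, including the optional columns argument that the search entry points gain in pele.py below (the query value and the column name are placeholders):

import rottnest

# Build one index per source file, then merge them (paths mirror the demo above).
rottnest.index_file_uuid("uuid_data/a.parquet", "hashes", "index0")
rottnest.index_file_uuid("uuid_data/b.parquet", "hashes", "index1")
rottnest.merge_index_uuid("merged_index", ["index0", "index1"])

# columns asks the search to also fetch those columns from the source Parquet files
# (via return_full_result); the column name used here is hypothetical.
result = rottnest.search_index_uuid(["merged_index"], "0123abcd" * 16, K = 10, columns = ["id"])
print(result)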
205 changes: 37 additions & 168 deletions python/rottnest/pele.py
@@ -6,108 +6,15 @@
import uuid
import polars
import numpy as np
import boto3
from botocore.config import Config
import os
import pyarrow.compute as pac
from pyarrow.fs import S3FileSystem, LocalFileSystem
from .nlp import query_expansion_keyword, query_expansion_llm
import json
import time
import daft
from concurrent.futures import ThreadPoolExecutor

def get_fs_from_file_path(filepath):

if filepath.startswith("s3://"):
if os.getenv('AWS_VIRTUAL_HOST_STYLE'):
try:
s3fs = S3FileSystem(endpoint_override = os.getenv('AWS_ENDPOINT_URL'), force_virtual_addressing = True )
except:
raise ValueError("Requires pyarrow >= 16.0.0 for virtual addressing.")
else:
s3fs = S3FileSystem(endpoint_override = os.getenv('AWS_ENDPOINT_URL'))
else:
s3fs = LocalFileSystem()

return s3fs

def get_daft_io_config_from_file_path(filepath):

if filepath.startswith("s3://"):
fs = daft.io.IOConfig(s3 = daft.io.S3Config(force_virtual_addressing = (True if os.getenv('AWS_VIRTUAL_HOST_STYLE') else False), endpoint_url = os.getenv('AWS_ENDPOINT_URL')))
else:
fs = daft.io.IOConfig()

return fs
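# Usage sketch (the endpoint value is hypothetical): both helpers above read the same environment
# variables, so an S3-compatible object store can be targeted with, for example:
#   os.environ["AWS_ENDPOINT_URL"] = "http://localhost:9000"
#   os.environ["AWS_VIRTUAL_HOST_STYLE"] = "1"   # virtual addressing needs pyarrow >= 16.0.0
#   fs = get_fs_from_file_path("s3://bucket/index0.meta")
#   io_config = get_daft_io_config_from_file_path("s3://bucket/index0.meta")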

def read_metadata_file(file_path: str):

table = pq.read_table(file_path.replace("s3://",''), filesystem = get_fs_from_file_path(file_path))
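# cache_ranges, when present, is a JSON-encoded list of [start, end] byte ranges stored under the
# b'cache_ranges' key of the Parquet schema metadata, e.g. [[0, 4096], [8192, 12288]] (hypothetical values).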

try:
cache_ranges = json.loads(table.schema.metadata[b'cache_ranges'].decode())
except:
cache_ranges = []

return polars.from_arrow(table), cache_ranges

def read_columns(file_paths: list, row_groups: list, row_nr: list[list]):

def read_parquet_file(file, row_group, row_nr):
f = pq.ParquetFile(file.replace("s3://",''), filesystem=get_fs_from_file_path(file))
return f.read_row_group(row_group, columns=['id']).take(row_nr)
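# NOTE: the column list is currently hard-coded to ['id']; row_nr selects individual rows within the row group.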

with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: # Control the number of parallel threads
results = list(executor.map(read_parquet_file, file_paths, row_groups, row_nr))

return pyarrow.concat_tables(results)


def get_physical_layout(file_path: str, column_name: str, type = "str"):

assert type in {"str", "binary"}
arrs, layout = rottnest.get_parquet_layout(column_name, file_path)
arr = pyarrow.concat_arrays([i.cast(pyarrow.large_string() if type == 'str' else pyarrow.large_binary()) for i in arrs])
data_page_num_rows = np.array(layout.data_page_num_rows)
uid = np.repeat(np.arange(len(data_page_num_rows)), data_page_num_rows) + 1

# Code tries to compute the starting row offset of each page in its row group.
# The following three lines are definitely easier to read than to write.

x = np.cumsum(np.hstack([[0],layout.data_page_num_rows[:-1]]))
y = np.repeat(x[np.cumsum(np.hstack([[0],layout.row_group_data_pages[:-1]])).astype(np.uint64)], layout.row_group_data_pages)
page_row_offsets_in_row_group = x - y
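# Worked example (hypothetical layout): with row_group_data_pages = [3, 2] and
# data_page_num_rows = [100, 80, 60, 90, 50]:
#   x = [0, 100, 180, 240, 330]   (cumulative row offset of each page across the file)
#   y = [0, 0, 0, 240, 240]       (row offset at which each page's row group starts)
#   page_row_offsets_in_row_group = x - y = [0, 100, 180, 0, 90]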

file_data = polars.from_dict({
"uid": np.arange(len(data_page_num_rows) + 1),
"file_path": [file_path] * (len(data_page_num_rows) + 1),
"column_name": [column_name] * (len(data_page_num_rows) + 1),
# TODO: figure out a better way to handle this. It is currently not a bottleneck, but the write amplification factor is almost 10x.
# Writing just one row followed by a bunch of Nones doesn't help, likely because Parquet is already smart enough to dict-encode the column.
# but we should probably still do this to save memory once loaded in!
"metadata_bytes": [layout.metadata_bytes] + [None] * (len(data_page_num_rows)),
"data_page_offsets": [-1] + layout.data_page_offsets,
"data_page_sizes": [-1] + layout.data_page_sizes,
"dictionary_page_sizes": [-1] + layout.dictionary_page_sizes,
"row_groups": np.hstack([[-1] , np.repeat(np.arange(layout.num_row_groups), layout.row_group_data_pages)]),
"page_row_offset_in_row_group": np.hstack([[-1], page_row_offsets_in_row_group])
}
)

return arr, pyarrow.array(uid.astype(np.uint64)), file_data

def get_virtual_layout(file_path: str, column_name: str, key_column_name: str, type = "str", stride = 500):

fs = get_fs_from_file_path(file_path)
table = pq.read_table(file_path, filesystem=fs, columns = [key_column_name, column_name])
table = table.with_row_count('__row_count__').with_columns((polars.col('__row_count__') // stride).alias('__uid__'))

arr = table[column_name].to_arrow().cast(pyarrow.large_string() if type == 'str' else pyarrow.large_binary())
uid = table['__uid__'].to_arrow().cast(pyarrow.uint64())
from .nlp import query_expansion_keyword, query_expansion_llm
from .utils import get_daft_io_config_from_file_path, get_fs_from_file_path, get_physical_layout, get_virtual_layout, read_columns, read_metadata_file,\
get_metadata_and_populate_cache, get_result_from_index_result, return_full_result

metadata = table.groupby("__uid__").agg([polars.col(key_column_name).min().alias("min"), polars.col(key_column_name).max().alias("max")]).sort("__uid__")
return arr, uid, metadata

def index_file_bm25(file_path: str, column_name: str, name = uuid.uuid4().hex, index_mode = "physical", tokenizer_file = None):

@@ -124,7 +31,7 @@ def index_file_substring(file_path: str, column_name: str, name = uuid.uuid4().h

arr, uid, file_data = get_physical_layout(file_path, column_name) if index_mode == "physical" else get_virtual_layout(file_path, column_name, "uid")

cache_ranges = rottnest.build_lava_substring(f"{name}.lava", arr, pyarrow.array(uid.astype(np.uint64)), tokenizer_file, token_skip_factor)
cache_ranges = rottnest.build_lava_substring(f"{name}.lava", arr, uid, tokenizer_file, token_skip_factor)

file_data = file_data.to_arrow()
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
@@ -239,36 +146,46 @@ def index_file_vector(file_path: str, column_name: str, name = uuid.uuid4().hex,

cache_end = f.tell()

file_data = file_data.to_arrow()
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps([(cache_start, cache_end)])})
file_data = file_data.to_arrow().replace_schema_metadata({"cache_ranges": json.dumps([(cache_start, cache_end)])})
pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd')


def merge_metadatas(new_index_name: str, index_names: List[str]):
assert len(index_names) > 1
# first read the metadata files and merge those
# discard the cache ranges
metadatas = [read_metadata_file(f"{name}.meta")[0] for name in index_names]
metadatas = daft.table.read_parquet_into_pyarrow_bulk([f"{index_name}.meta" for index_name in index_names], io_config = get_daft_io_config_from_file_path(index_names[0]))
# discard cache ranges in metadata, don't need them
metadatas = [polars.from_arrow(i) for i in metadatas]
metadata_lens = [len(metadata) for metadata in metadatas]
offsets = np.cumsum([0] + metadata_lens)[:-1]
metadatas = [metadata.with_columns(polars.col("uid") + offsets[i]) for i, metadata in enumerate(metadatas)]
polars.concat(metadatas).write_parquet(f"{new_index_name}.meta", statistics=False)
return offsets
return offsets, polars.concat(metadatas)
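# Worked example (hypothetical sizes): metadata tables with 5, 3 and 4 rows give
# offsets = cumsum([0, 5, 3, 4])[:-1] = [0, 5, 8]; shifting each table's uid column by its offset keeps
# uids globally unique in the concatenated metadata. The callers now write the merged .meta file
# themselves, together with the merged cache ranges.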

def merge_index_bm25(new_index_name: str, index_names: List[str]):

offsets = merge_metadatas(new_index_name, index_names)
rottnest.merge_lava_bm25(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets)

def merge_index_uuid(new_index_name: str, index_names: List[str]):
offsets, file_data = merge_metadatas(new_index_name, index_names)

offsets = merge_metadatas(new_index_name, index_names)
rottnest.merge_lava_uuid(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets)
cache_ranges = rottnest.merge_lava_generic(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets, 0)

file_data = file_data.to_arrow().replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{new_index_name}.meta", write_statistics = False, compression = 'zstd')

def merge_index_substring(new_index_name: str, index_names: List[str]):

offsets = merge_metadatas(new_index_name, index_names)
rottnest.merge_lava_substring(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets)
offsets, file_data = merge_metadatas(new_index_name, index_names)

cache_ranges = rottnest.merge_lava_generic(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets, 1)

file_data = file_data.to_arrow().replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{new_index_name}.meta", write_statistics = False, compression = 'zstd')

def merge_index_uuid(new_index_name: str, index_names: List[str]):

offsets, file_data = merge_metadatas(new_index_name, index_names)

cache_ranges = rottnest.merge_lava_generic(f"{new_index_name}.lava", [f"{name}.lava" for name in index_names], offsets, 2)

file_data = file_data.to_arrow().replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{new_index_name}.meta", write_statistics = False, compression = 'zstd')
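# The three merge entry points above now share one shape; a hedged sketch of the common pattern
# (the kind codes 0 / 1 / 2 are inferred from the calls above, and the helper name is hypothetical):
#
# def _merge_index(new_index_name: str, index_names: List[str], kind: int):
#     offsets, file_data = merge_metadatas(new_index_name, index_names)
#     cache_ranges = rottnest.merge_lava_generic(f"{new_index_name}.lava",
#                                                [f"{name}.lava" for name in index_names], offsets, kind)
#     file_data = file_data.to_arrow().replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
#     pq.write_table(file_data, f"{new_index_name}.meta", write_statistics = False, compression = 'zstd')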

def merge_index_vector(new_index_name: str, index_names: List[str]):

@@ -279,6 +196,8 @@ def merge_index_vector(new_index_name: str, index_names: List[str]):
print("faiss zstandard not installed")
return

raise NotImplementedError

def read_range(file_handle, start, end):
file_handle.seek(start)
return file_handle.read(end - start)
@@ -306,49 +225,7 @@ def read_range(file_handle, start, end):

decompressor = zstd.ZstdDecompressor()


def get_result_from_index_result(metadata: polars.DataFrame, index_search_results: list):
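# index_search_results is a list of (file_id, uid) pairs, e.g. [(0, 3), (1, 17)] (hypothetical values),
# identifying the data pages that may contain matches for the query.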

uids = polars.from_dict({"file_id": [i[0] for i in index_search_results], "uid": [i[1] for i in index_search_results]})
file_metadatas = metadata.filter(polars.col("metadata_bytes").is_not_null()).group_by("file_path").first().select(["file_path", "metadata_bytes"])
metadata = metadata.join(uids, on = ["file_id", "uid"])

assert len(metadata["column_name"].unique()) == 1, "index is not allowed to span multiple column names"
column_name = metadata["column_name"].unique()[0]

file_metadatas = {d["file_path"]: d["metadata_bytes"] for d in file_metadatas.to_dicts()}

result = rottnest.read_indexed_pages(column_name, metadata["file_path"].to_list(), metadata["row_groups"].to_list(),
metadata["data_page_offsets"].to_list(), metadata["data_page_sizes"].to_list(), metadata["dictionary_page_sizes"].to_list(),
"aws", file_metadatas)

row_group_rownr = [pyarrow.array(np.arange(metadata['page_row_offset_in_row_group'][i], metadata['page_row_offset_in_row_group'][i] + len(arr))) for i, arr in enumerate(result)]

metadata_key = [pyarrow.array(np.ones(len(arr)).astype(np.uint32) * i) for i, arr in enumerate(result)]

result = pyarrow.table([pyarrow.chunked_array(result), pyarrow.chunked_array(row_group_rownr),
pyarrow.chunked_array(metadata_key)], names = [column_name, '__row_group_rownr__', '__metadata_key__'])

return result, column_name, metadata.with_row_count('__metadata_key__')

def get_metadata_and_populate_cache(indices: List[str]):

# metadatas = [read_metadata_file(f"{index_name}.meta") for i, index_name in enumerate(indices)]

metadatas = daft.table.read_parquet_into_pyarrow_bulk([f"{index_name}.meta" for index_name in indices], io_config = get_daft_io_config_from_file_path(indices[0]))
metadatas = [(polars.from_arrow(i), json.loads(i.schema.metadata[b'cache_ranges'].decode())) for i in metadatas]

metadata = polars.concat([f[0].with_columns(polars.lit(i).alias("file_id").cast(polars.Int64)) for i, f in enumerate(metadatas)])
cache_dir = os.getenv("ROTTNEST_CACHE_DIR")
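# If ROTTNEST_CACHE_DIR is set (e.g. "/tmp/rottnest_cache", a hypothetical path), the byte ranges recorded
# at index-build time are handed to rottnest.populate_cache so the hot parts of each .lava file are
# fetched into the local cache before searching.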
if cache_dir:
cache_ranges = {f"{indices[i]}.lava": f[1] for i, f in enumerate(metadatas) if len(f[1]) > 0}
cached_files = list(cache_ranges.keys())
ranges = [[tuple(k) for k in cache_ranges[f]] for f in cached_files]
rottnest.populate_cache(cached_files, ranges, cache_dir, "aws")

return metadata

def search_index_uuid(indices: List[str], query: str, K: int):
def search_index_uuid(indices: List[str], query: str, K: int, columns = []):

metadata = get_metadata_and_populate_cache(indices)

@@ -361,10 +238,10 @@ def search_index_uuid(indices: List[str], query: str, K: int):
result, column_name, metadata = get_result_from_index_result(metadata, index_search_results)
result = polars.from_arrow(result).filter(polars.col(column_name) == query)

return result
return return_full_result(result, metadata, column_name, columns)


def search_index_substring(indices: List[str], query: str, K: int):
def search_index_substring(indices: List[str], query: str, K: int, columns = []):

metadata = get_metadata_and_populate_cache(indices)

@@ -377,7 +254,7 @@
result, column_name, metadata = get_result_from_index_result(metadata, index_search_results)
result = polars.from_arrow(result).filter(polars.col(column_name).str.to_lowercase().str.contains(query.lower()))

return result
return return_full_result(result, metadata, column_name, columns)

def search_index_vector(indices: List[str], query: np.array, K: int, columns = [], nprobes = 500, refine = 500):

@@ -451,17 +328,9 @@ def search_index_vector(indices: List[str], query: np.array, K: int, columns = [
dim = diffs.item() // 4
vecs = np.frombuffer(buffers[2], dtype = np.float32).reshape(len(result), dim)
results = np.linalg.norm(query - vecs, axis = 1).argsort()[:K]
result = polars.from_arrow(result)[results].join(metadata.select(["__metadata_key__", "file_path", "row_groups"]), on = "__metadata_key__", how = "left")

if columns != []:
grouped = result.groupby(["file_path", "row_groups"]).agg([polars.col('__metadata_key__'), polars.col('__row_group_rownr__')])
collected_results = polars.from_arrow(read_columns(grouped["file_path"].to_list(), grouped["row_groups"].to_list(), grouped["__row_group_rownr__"].to_list()))
unnested_metadata_key = grouped['__metadata_key__'].explode()
unnested_row_group_rownr = grouped['__row_group_rownr__'].explode()
collected_results = collected_results.with_columns([unnested_metadata_key, unnested_row_group_rownr])
result = result.join(collected_results, on = ["__metadata_key__", "__row_group_rownr__"], how = "left")
result = polars.from_arrow(result)[results]

return result.select(columns + [column_name])
return return_full_result(result, metadata, column_name, columns)


def search_index_bm25(indices: List[str], query: str, K: int, query_expansion = "bge", quality_factor = 0.2, expansion_tokens = 20, cache_dir = None, reader_type = None):