Skip to content

Commit

Permalink
expose cache dir option, bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
marsupialtail committed Apr 29, 2024
1 parent 65900c7 commit b30a2c0
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rottnest"
version = "1.3.0"
version = "1.3.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "rottnest"
version = '1.3.0'
version = '1.3.1'
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
Expand All @@ -20,4 +20,4 @@ features = ["pyo3/extension-module"]

# Defining optional dependencies for feature flags
[project.optional-dependencies]
bm25 = ["duckdb"]
bm25 = ["duckdb", ""]
5 changes: 2 additions & 3 deletions python/rottnest/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def embed_batch_openai(tokens: List[str], model = "text-embedding-3-large"):

return np.vstack(all_vecs)

def query_expansion_llm(tokenizer_vocab: List[str], query: str, method = "bge", expansion_tokens = 20):
def query_expansion_llm(tokenizer_vocab: List[str], query: str, method = "bge", expansion_tokens = 20, cache_dir = None):

assert type(query) == str, "query must be string. If you have a list of keywords, concatenate them with spaces."

cache_dir = os.path.expanduser('~/.cache')
cache_dir = os.path.join(os.path.expanduser('~/.cache'), 'rottnest') if cache_dir is None else cache_dir
# make sure the cache directory exists (default: ~/.cache/rottnest; a caller-supplied cache_dir is used as-is)
cache_dir = os.path.join(cache_dir, 'rottnest')
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

Expand Down
4 changes: 2 additions & 2 deletions python/rottnest/pele.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,14 @@ def search_index_substring(indices: List[str], query: str, K: int):

return polars.from_arrow(result).filter(polars.col("text").str.to_lowercase().str.contains(query.lower()))

def search_index_bm25(indices: List[str], query: str, K: int, query_expansion = "bge", quality_factor = 0.2, expansion_tokens = 20):
def search_index_bm25(indices: List[str], query: str, K: int, query_expansion = "bge", quality_factor = 0.2, expansion_tokens = 20, cache_dir = None):

assert query_expansion in {"bge", "openai", "keyword", "none"}

tokenizer_vocab = rottnest.get_tokenizer_vocab([f"{index_name}.lava" for index_name in indices])

if query_expansion in {"bge","openai"}:
tokens, token_ids, weights = query_expansion_llm(tokenizer_vocab, query, method = query_expansion, expansion_tokens=expansion_tokens)
tokens, token_ids, weights = query_expansion_llm(tokenizer_vocab, query, method = query_expansion, expansion_tokens=expansion_tokens, cache_dir = cache_dir)
elif query_expansion == "keyword":
tokens, token_ids, weights = query_expansion_keyword(tokenizer_vocab, query)
else:
Expand Down

0 comments on commit b30a2c0

Please sign in to comment.