Skip to content

Commit

Permalink
Merge branch 'DOC-4345-json-intro' of github.com:andy-stark-redis/red…
Browse files Browse the repository at this point in the history
…is-py into DOC-4345-json-intro
  • Loading branch information
andy-stark-redis committed Nov 5, 2024
2 parents c639228 + 976063d commit 1ffe56c
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packaging>=20.4
pytest
pytest-asyncio>=0.23.0,<0.24.0
pytest-cov
pytest-profiling
pytest-profiling==1.7.0
pytest-timeout
ujson>=4.2.0
uvloop
Expand Down
124 changes: 124 additions & 0 deletions doctests/query_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# EXAMPLE: query_combined
# HIDE_START
import json
import numpy as np
import redis
import warnings
from redis.commands.json.path import Path
from redis.commands.search.field import NumericField, TagField, TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer


def embed_text(model, text):
return np.array(model.encode(text)).astype(np.float32).tobytes()

warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces.*")
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
query = "Bike for small kids"
query_vector = embed_text(model, query)

r = redis.Redis(decode_responses=True)

# create index
schema = (
TextField("$.description", no_stem=True, as_name="model"),
TagField("$.condition", as_name="condition"),
NumericField("$.price", as_name="price"),
VectorField(
"$.description_embeddings",
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": 384,
"DISTANCE_METRIC": "COSINE",
},
as_name="vector",
),
)

index = r.ft("idx:bicycle")
index.create_index(
schema,
definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON),
)

# load data
with open("data/query_vector.json") as f:
bicycles = json.load(f)

pipeline = r.pipeline(transaction=False)
for bid, bicycle in enumerate(bicycles):
pipeline.json().set(f'bicycle:{bid}', Path.root_path(), bicycle)
pipeline.execute()
# HIDE_END

# STEP_START combined1
q = Query("@price:[500 1000] @condition:{new}")
res = index.search(q)
print(res.total) # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined2
q = Query("kids @price:[500 1000] @condition:{used}")
res = index.search(q)
print(res.total) # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined3
q = Query("(kids | small) @condition:{used}")
res = index.search(q)
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined4
q = Query("@description:(kids | small) @condition:{used}")
res = index.search(q)
print(res.total) # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined5
q = Query("@description:(kids | small) @condition:{new | used}")
res = index.search(q)
print(res.total) # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined6
q = Query("@price:[500 1000] -@condition:{new}")
res = index.search(q)
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined7
q = Query("(@price:[500 1000] -@condition:{new})=>[KNN 3 @vector $query_vector]").dialect(2)
# put query string here
res = index.search(q,{ 'query_vector': query_vector })
print(res.total) # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# REMOVE_START
# destroy index and data
r.ft("idx:bicycle").dropindex(delete_documents=True)
# REMOVE_END
16 changes: 16 additions & 0 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
self._cursor = []
self._dialect = None
self._add_scores = False
self._scorer = "TFIDF"

def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Expand Down Expand Up @@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
self._add_scores = True
return self

def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self

def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
Expand All @@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
if self._verbatim:
ret.append("VERBATIM")

if self._scorer:
ret.extend(["SCORER", self._scorer])

if self._add_scores:
ret.append("ADDSCORES")

Expand All @@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
if self._loadall:
ret.append("LOAD")
ret.append("*")

elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
Expand Down
55 changes: 55 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
assert await decoded_r.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

assert await decoded_r.hset(
"doc1",
mapping={
"name": "cat book",
"description": "an animal book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
assert await decoded_r.hset(
"doc2",
mapping={
"name": "dog book",
"description": "an animal book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = await decoded_r.ft().aggregate(
req,
query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
)

if isinstance(res, dict):
assert len(res["results"]) == 2
else:
assert len(res.rows) == 2
for row in res.rows:
len(row) == 6


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
Expand Down
55 changes: 55 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(client):
client.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

client.hset(
"doc1",
mapping={
"name": "cat book",
"description": "an animal book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
client.hset(
"doc2",
mapping={
"name": "dog book",
"description": "an animal book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = client.ft().aggregate(
req,
query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()},
)

if isinstance(res, dict):
assert len(res["results"]) == 2
else:
assert len(res.rows) == 2
for row in res.rows:
len(row) == 6


@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_index_definition(client):
Expand Down

0 comments on commit 1ffe56c

Please sign in to comment.