0.1.8.1 support for batch input queries
vprelovac committed Oct 30, 2023
1 parent 69a32bc commit 5f91db6
Showing 8 changed files with 112 additions and 90 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -85,13 +85,14 @@ Save content to memory. Metadata will be automatically optimized to use less resources
- `metadata`: *Optional.* Metadata or list of metadata associated with the texts (a usage sketch follows this list).
- `memory_file`: *Optional.* Path to persist the memory file. By default
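
A minimal sketch of `save` with per-text metadata and persistence; the file path and texts here are illustrative, not from the repo:

```python
from vectordb import Memory

memory = Memory(memory_file="memory.pkl")  # illustrative path; memory is re-saved to disk on each save()
memory.save(
    ["apples are green or red", "bananas are yellow"],
    metadata=[{"source": "fruit-notes"}, {"source": "fruit-notes"}],
)
```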

**Memory.search(query, top_n=5, unique=False)**
**Memory.search(query, top_n=5, unique=False, batch_results="flatten")**

Search inside memory.

- `query`: *Required.* Query text.
- `query`: *Required.* Query text, or a list of query texts (see the `batch_results` option below for how results are handled for a list).
- `top_n`: *Optional.* Number of most similar chunks to return (default: 5).
- `unique`: *Optional.* Return only chunks from unique original texts (additional chunks coming from the same text are ignored). Note this may return fewer chunks than requested (default: False).
- `batch_results`: *Optional.* When the input is a list of queries, results can be combined with the "flatten" or "diverse" algorithm. "flatten" returns the true nearest neighbours across all input queries, so all results could come from just one query. "diverse" spreads the results out, so that each query's nearest neighbours are added in turn (the nearest from every query first, then every second nearest, and so on); see the sketch after this list (default: "flatten").
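
A hedged sketch of the new batch query path; the texts and queries are invented for illustration:

```python
from vectordb import Memory

memory = Memory()
memory.save(["apples are green or red", "bananas are yellow", "the sky is blue"])

# Passing a list of queries triggers batch mode. "flatten" (the default)
# returns the global nearest neighbours; "diverse" interleaves each
# query's hits so both queries are represented.
results = memory.search(["fruit colour", "sky"], top_n=4, batch_results="diverse")
for hit in results:
    print(hit["chunk"], hit["distance"])
```

Each result is a dict with `chunk`, `metadata`, and `distance` keys (see the `memory.py` diff below).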

**Memory.clear()**

2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@

setup(
name="vectordb2",
version="0.1.7",
version="0.1.8.1",
packages=find_packages(),
install_requires=[
"torch>=1.9.0",
2 changes: 1 addition & 1 deletion vectordb/__init__.py
@@ -1,3 +1,3 @@
#pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals
# pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals

from .memory import Memory
4 changes: 2 additions & 2 deletions vectordb/chunking.py
@@ -1,4 +1,4 @@
#pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals
# pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals

from typing import List
import re
@@ -39,7 +39,7 @@ def clean_text(text: str) -> str:
def __call__(self, text: str) -> List[str]:
if self.strategy == "paragraph":
return self.paragraph_chunking(text)

return self.sliding_window_chunking(text)

def paragraph_chunking(self, text: str) -> List[str]:
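
The `__call__` dispatch above routes a "paragraph" strategy to paragraph chunking and everything else to the sliding window. A sketch of selecting a strategy at construction time; the exact `chunking_strategy` keys are an assumption, not confirmed by this diff:

```python
from vectordb import Memory

# Assumed dict schema; see vectordb/chunking.py for the authoritative keys.
memory = Memory(
    chunking_strategy={"mode": "sliding_window", "window_size": 240, "overlap": 8}
)
paragraph_memory = Memory(chunking_strategy={"mode": "paragraph"})
```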
13 changes: 5 additions & 8 deletions vectordb/embedding.py
@@ -2,7 +2,7 @@
This module provides classes for generating text embeddings using various pre-trained models.
"""

#pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals
# pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals

from abc import ABC, abstractmethod
from typing import List
@@ -13,12 +13,12 @@

class BaseEmbedder(ABC):
"""Base class for Embedder."""

@abstractmethod
def embed_text(self, chunks: List[str]) -> List[List[float]]:
"""Generates embeddings for a list of text chunks."""



class Embedder(BaseEmbedder):
"""
This class provides a way to generate embeddings for given text chunks using a specified
@@ -39,19 +39,16 @@ def __init__(self, model_name: str = "normal"):
"https://tfhub.dev/google/universal-sentence-encoder/4"
)
self.sbert = False
elif model_name == "multilingual" :
self.model = hub.load(
"universal-sentence-encoder-multilingual-large/3"
)
elif model_name == "multilingual":
self.model = hub.load("universal-sentence-encoder-multilingual-large/3")
self.sbert = False
else:
#if model_name == "normal":
# if model_name == "normal":
# model_name = "sentence-transformers/all-MiniLM-L6-v2"
if model_name == "normal":
model_name = "BAAI/bge-small-en-v1.5"
elif model_name == "best":
model_name = "BAAI/bge-base-en-v1.5"


self.model = SentenceTransformer(model_name)

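
Per the mapping above, the `embeddings` string picks the backend: "normal" and "best" load BGE models through SentenceTransformer, "multilingual" loads a Universal Sentence Encoder from TF Hub, and any other string is passed to SentenceTransformer as a model name. A brief sketch:

```python
from vectordb import Memory

memory = Memory(embeddings="normal")  # BAAI/bge-small-en-v1.5 (per this diff)
best = Memory(embeddings="best")      # BAAI/bge-base-en-v1.5
custom = Memory(embeddings="sentence-transformers/all-MiniLM-L6-v2")  # any SBERT model name
```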
73 changes: 29 additions & 44 deletions vectordb/memory.py
@@ -3,7 +3,7 @@
for text and associated metadata, with functionality for saving, searching, and
managing memory entries.
"""
#pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals
# pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals

from typing import List, Dict, Any, Union
import itertools
@@ -25,7 +25,6 @@ def __init__(
memory_file: str = None,
chunking_strategy: dict = None,
embeddings: Union[BaseEmbedder, str] = "normal",
embed_on_save: bool = True
):
"""
Initializes the Memory class.
@@ -35,7 +34,6 @@
:param embeddings: the embedding preset name ("normal", "best", "multilingual") or a BaseEmbedder instance (default: "normal").
"""
self.memory_file = memory_file
self.embed_on_save = embed_on_save

self.memory = (
[] if memory_file is None else Storage(memory_file).load_from_disk()
@@ -46,7 +44,7 @@ def __init__(

self.metadata_memory = []
self.metadata_index_counter = 0
self.text_index_counter = 0
self.text_index_counter = 0

if isinstance(embeddings, str):
self.embedder = Embedder(embeddings)
@@ -74,7 +72,6 @@ def save(
if not isinstance(texts, list):
texts = [texts]


if metadata is None:
metadata = []
elif not isinstance(metadata, list):
@@ -86,12 +83,13 @@
for meta in metadata:
self.metadata_memory.append(meta)

meta_index_start = (
self.metadata_index_counter
) # Starting index for this save operation
self.metadata_index_counter += len(
metadata
) # Update the counter for future save operations



meta_index_start = self.metadata_index_counter # Starting index for this save operation
self.metadata_index_counter += len(metadata) # Update the counter for future save operations

if memory_file is None:
memory_file = self.memory_file

Expand All @@ -100,16 +98,13 @@ def save(

flatten_chunks = list(itertools.chain.from_iterable(text_chunks))

if self.embed_on_save:
embeddings = self.embedder.embed_text(flatten_chunks)
else:
embeddings = [None] * len(flatten_chunks) # Placeholder for future embedding

embeddings = self.embedder.embed_text(flatten_chunks)

text_index_start = self.text_index_counter # Starting index for this save operation
text_index_start = (
self.text_index_counter
) # Starting index for this save operation
self.text_index_counter += len(texts)


# accumulated size is end_index of each chunk
for size, end_index, chunks, meta_index, text_index in zip(
chunks_size,
@@ -118,57 +113,48 @@ def save(
range(meta_index_start, self.metadata_index_counter),
range(text_index_start, self.text_index_counter),
):

start_index = end_index - size
chunks_embedding = embeddings[start_index:end_index]

for chunk, embedding in zip(chunks, chunks_embedding):

entry = {
"chunk": chunk,
"embedding": embedding,
"metadata_index": meta_index,
"text_index": text_index,
}
self.memory.append(entry)

if memory_file is not None:
Storage(memory_file).save_to_disk(self.memory)
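
For readers tracing the zip above: `chunks_size` holds each text's chunk count, and the accumulated sizes mark where each text's chunks end in the flattened embedding list. A tiny standalone illustration of that index arithmetic (not repo code):

```python
import itertools

chunks_size = [2, 3]  # text 0 produced 2 chunks, text 1 produced 3
flat_chunks = ["t0c0", "t0c1", "t1c0", "t1c1", "t1c2"]

# accumulated size is the end index of each text's chunks
for size, end_index in zip(chunks_size, itertools.accumulate(chunks_size)):
    start_index = end_index - size
    print(flat_chunks[start_index:end_index])
# -> ['t0c0', 't0c1'] then ['t1c0', 't1c1', 't1c2']
```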

def search(self, query: str, top_n: int = 5, unique: bool = False) -> List[Dict[str, Any]]:
def search(
self, query: str, top_n: int = 5, unique: bool = False, batch_results: str = "flatten"
) -> List[Dict[str, Any]]:
"""
Searches for the most similar chunks to the given query in memory.
:param query: a string containing the query text.
:param top_n: the number of most similar chunks to return. (default: 5)
:param unique: chunks are filtered out to unique texts (default: False)
:param batch_results: if the input is a list of queries, results are combined using the "flatten" or "diverse" algorithm (default: "flatten")
:return: a list of dictionaries containing the top_n most similar chunks and their associated metadata.
"""

if not self.embed_on_save: # We need to dynamically create
all_chunks = [entry['chunk'] for entry in self.memory] # Gather all stored chunks

all_chunks.append(query) # Add the query for simultaneous embedding

all_embeddings = self.embedder.embed_text(all_chunks) # Embed all at once

query_embedding = all_embeddings[-1] # Last is the query
embeddings = all_embeddings[:-1] # All but last are the stored chunks

# # Update stored embeddings
# for i, entry in enumerate(self.memory):
# entry['embedding'] = all_embeddings[i]
if isinstance(query, list):
query_embedding = self.embedder.embed_text(query)
else:
query_embedding = self.embedder.embed_text([query])[0]
embeddings = [entry["embedding"] for entry in self.memory]

indices = self.vector_search.search_vectors(query_embedding, embeddings, top_n)
print(indices)

embeddings = [entry["embedding"] for entry in self.memory]

indices = self.vector_search.search_vectors(query_embedding, embeddings, top_n, batch_results)
if unique:
unique_indices = []
seen_text_indices = set()  # Track text indices already returned
for i in indices:
text_index = self.memory[i][
text_index = self.memory[i[0]][
"text_index"
] # Use text_index instead of metadata_index
if (
@@ -184,10 +170,11 @@ def search(self, query: str, top_n: int = 5, unique: bool = False) -> List[Dict[str, Any]]:
{
"chunk": self.memory[i[0]]["chunk"],
"metadata": self.metadata_memory[self.memory[i[0]]["metadata_index"]],
"distance": i[1]
"distance": i[1],
}
for i in indices
]

return results

def clear(self):
@@ -197,7 +184,7 @@ def clear(self):
self.memory = []
self.metadata_memory = []
self.metadata_index_counter = 0
self.text_index_counter = 0
self.text_index_counter = 0

if self.memory_file is not None:
Storage(self.memory_file).save_to_disk(self.memory)
@@ -211,8 +198,6 @@ def dump(self):
print("Embedding Length:", len(entry["embedding"]))
print("Metadata:", self.metadata_memory[entry["metadata_index"]])
print("-" * 40)

print("Total entries: ", len(self.memory))
print("Total metadata: ", len(self.metadata_memory))


2 changes: 1 addition & 1 deletion vectordb/storage.py
@@ -2,7 +2,7 @@
This module provides the Storage class for saving and loading data to and from a disk.
"""

#pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals
# pylint: disable = line-too-long, trailing-whitespace, trailing-newlines, line-too-long, missing-module-docstring, import-error, too-few-public-methods, too-many-instance-attributes, too-many-locals

from typing import List, Dict, Any
import pickle
