Elastic as db instead of default SQLite #1996
Unanswered
arunvenkat1911
asked this question in
Q&A
Replies: 1 comment
-
I used PostgreSQL to rewrite the SQLite storage:
import uuid
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extras import DictCursor, RealDictCursor
from contextlib import contextmanager
from psycopg2.extensions import register_adapter, AsIs
import json
logger = logging.getLogger(__name__)
class PostgresManager:
    """PostgreSQL-backed history store (drop-in replacement for the default SQLite manager).

    Uses a ThreadedConnectionPool so concurrent callers (e.g. the
    ThreadPoolExecutor used by Memory.add) can each borrow a connection.
    """

    def __init__(self, connection_string: str, min_conn: int = 2, max_conn: int = 20):
        """Initialize the PostgreSQL connection pool.

        Args:
            connection_string: libpq DSN / connection URL.
            min_conn: minimum number of pooled connections, default 2.
            max_conn: maximum number of pooled connections, default 20.
        """
        self.pool = ThreadedConnectionPool(
            minconn=min_conn,
            maxconn=max_conn,
            dsn=connection_string,
            cursor_factory=RealDictCursor,  # rows come back as dicts keyed by column name
        )
        self._create_history_table()

    @contextmanager
    def get_connection(self):
        """Borrow a pooled connection; commit on success, roll back on error.

        The connection is always returned to the pool, and any database
        exception is logged and re-raised.
        """
        conn = None
        try:
            conn = self.pool.getconn()
            yield conn
            conn.commit()
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {str(e)}")
            raise
        finally:
            if conn:
                self.pool.putconn(conn)

    def _create_history_table(self) -> None:
        """Create the history table and its indexes (idempotent)."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS history (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        memory_id TEXT NOT NULL,
                        old_memory JSONB,
                        new_memory JSONB,
                        event TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP WITH TIME ZONE,
                        is_deleted INTEGER DEFAULT 0,
                        CONSTRAINT event_type CHECK (event IN ('ADD', 'UPDATE', 'DELETE'))
                    );
                    CREATE INDEX IF NOT EXISTS idx_history_memory_id
                        ON history(memory_id);
                    CREATE INDEX IF NOT EXISTS idx_history_created_at
                        ON history(created_at DESC);
                    CREATE INDEX IF NOT EXISTS idx_history_event
                        ON history(event);
                    """
                )

    def add_history(
        self,
        memory_id: str,
        old_memory: Optional[Any],
        new_memory: Optional[Any],
        event: str,
        created_at: Optional[datetime] = None,
        updated_at: Optional[datetime] = None,
        is_deleted: bool = False,
    ) -> str:
        """Insert one history row.

        Args:
            memory_id: ID of the memory the event refers to.
            old_memory: previous value (JSON-serializable) or None.
            new_memory: new value (JSON-serializable) or None.
            event: one of 'ADD', 'UPDATE', 'DELETE' (enforced by the CHECK constraint).
            created_at: event timestamp; defaults to now().
            updated_at: optional update timestamp.
            is_deleted: soft-delete flag, stored as 0/1.

        Returns:
            str: ID of the newly created record.
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                record_id = str(uuid.uuid4())
                cursor.execute(
                    """
                    INSERT INTO history
                    (id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                    """,
                    (
                        record_id,
                        memory_id,
                        json.dumps(old_memory) if old_memory is not None else None,
                        json.dumps(new_memory) if new_memory is not None else None,
                        event,
                        # NOTE(review): datetime.now() is naive while the column is
                        # TIMESTAMPTZ — confirm the session timezone is acceptable.
                        created_at or datetime.now(),
                        updated_at,
                        1 if is_deleted else 0,
                    ),
                )
                return cursor.fetchone()["id"]

    def get_history(self, memory_id: str, limit: int = 100) -> List[Dict]:
        """Fetch history rows for one memory, newest first.

        Args:
            memory_id: memory ID to look up.
            limit: maximum number of rows returned, default 100.
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    SELECT
                        id,
                        memory_id,
                        old_memory as prev_value,
                        new_memory as new_value,
                        event,
                        created_at as timestamp,
                        is_deleted
                    FROM history
                    WHERE memory_id = %s
                        AND is_deleted = 0
                    ORDER BY created_at DESC
                    LIMIT %s
                    """,
                    (memory_id, limit),
                )
                # Fix: is_deleted is an INTEGER column, so the previous
                # "AND NOT is_deleted" predicate is a type error in PostgreSQL
                # (NOT requires a boolean operand); compare against 0 instead.
                return cursor.fetchall()

    def reset(self) -> None:
        """Drop and recreate the history table."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                # Dropping the table also drops its indexes, so the previous
                # separate DROP INDEX statements were redundant no-ops.
                cursor.execute("DROP TABLE IF EXISTS history CASCADE;")
        self._create_history_table()

    def __del__(self):
        """Close the connection pool on garbage collection (best effort)."""
        if hasattr(self, 'pool'):
            try:
                self.pool.closeall()
            except Exception as e:
                logger.error(f"Error closing connection pool: {str(e)}")


# Fix: "import concurrent" alone does not load the futures submodule that the
# code below uses (concurrent.futures.ThreadPoolExecutor); import it explicitly.
import concurrent.futures
import hashlib
import json
import logging
import os
import uuid
import warnings
from datetime import datetime
from typing import Any, Dict
import pytz
from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
from mem1.storage import PostgresManager
# Module-level logger for the Memory class below.
logger = logging.getLogger(__name__)
class Memory(MemoryBase):
    def __init__(self, config: MemoryConfig = MemoryConfig()):
        """Build the embedder, vector store, LLM, Postgres history DB and
        (optionally) the graph store from *config*.

        NOTE(review): the `MemoryConfig()` default is evaluated once at
        definition time, so every no-arg construction shares the same config
        object — confirm that sharing is intended before mutating it.

        Raises:
            ValueError: when the DATABASE_URL environment variable is unset.
        """
        self.config = config
        self.custom_prompt = self.config.custom_prompt
        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        # Fail fast with a clear error instead of handing dsn=None to the
        # connection pool deep inside PostgresManager.
        database_url = os.getenv("DATABASE_URL")
        if not database_url:
            raise ValueError("DATABASE_URL environment variable must be set to use the Postgres history store")
        self.db = PostgresManager(database_url)
        self.collection_name = self.config.vector_store.config.collection_name
        self.api_version = self.config.version
        self.enable_graph = False
        if self.api_version == "v1.1" and self.config.graph_store.config:
            from mem0.memory.graph_memory import MemoryGraph

            self.graph = MemoryGraph(self.config)
            self.enable_graph = True
        capture_event("mem1.init", self)
    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]):
        """Construct a Memory from a plain config dict.

        Args:
            config_dict: keyword arguments for MemoryConfig.

        Raises:
            pydantic.ValidationError: when the dict does not match the schema
                (logged before re-raising).
        """
        try:
            config = MemoryConfig(**config_dict)
        except ValidationError as e:
            logger.error(f"Configuration validation error: {e}")
            raise
        return cls(config)
def add(
self,
messages,
user_id=None,
agent_id=None,
run_id=None,
metadata=None,
filters=None,
prompt=None,
):
"""
Create a new memory.
Args:
messages (str or List[Dict[str, str]]): Messages to store in the memory.
user_id (str, optional): ID of the user creating the memory. Defaults to None.
agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
run_id (str, optional): ID of the run creating the memory. Defaults to None.
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
Returns:
dict: A dictionary containing the result of the memory addition operation.
result: dict of affected events with each dict has the following key:
'memories': affected memories
'graph': affected graph memories
'memories' and 'graph' is a dict, each with following subkeys:
'add': added memory
'update': updated memory
'delete': deleted memory
"""
if metadata is None:
metadata = {}
filters = filters or {}
if user_id:
filters["user_id"] = metadata["user_id"] = user_id
if agent_id:
filters["agent_id"] = metadata["agent_id"] = agent_id
if run_id:
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
concurrent.futures.wait([future1, future2])
vector_store_result = future1.result()
graph_result = future2.result()
if self.api_version == "v1.1":
return {
"results": vector_store_result,
"relations": graph_result,
}
else:
warnings.warn(
"The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return vector_store_result
    def _add_to_vector_store(self, messages, metadata, filters):
        """Extract facts from *messages* via the LLM and reconcile them with
        existing vector-store memories (ADD / UPDATE / DELETE / NONE events).

        Args:
            messages: list of {"role", "content"} chat messages.
            metadata: dict stored with any newly created memory.
            filters: scope filters (user_id / agent_id / run_id) for the search.

        Returns:
            list: dicts describing each affected memory and its event.
        """
        parsed_messages = parse_messages(messages)
        if self.custom_prompt:
            system_prompt = self.custom_prompt
            user_prompt = f"Input: {parsed_messages}"
        else:
            system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
        response = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )
        try:
            new_retrieved_facts = json.loads(response)["facts"]
        except Exception as e:
            # Best effort: a malformed LLM response degrades to "no new facts"
            # instead of failing the whole add() call.
            logging.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []
        retrieved_old_memory = []
        new_message_embeddings = {}
        # Embed each fact once and cache it; the cache is reused by
        # _create_memory / _update_memory below to avoid re-embedding.
        for new_mem in new_retrieved_facts:
            messages_embeddings = self.embedding_model.embed(new_mem)
            new_message_embeddings[new_mem] = messages_embeddings
            existing_memories = self.vector_store.search(
                query=messages_embeddings,
                limit=5,
                filters=filters,
            )
            for mem in existing_memories:
                retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
        logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
        # mapping UUIDs with integers for handling UUID hallucinations
        temp_uuid_mapping = {}
        for idx, item in enumerate(retrieved_old_memory):
            temp_uuid_mapping[str(idx)] = item["id"]
            retrieved_old_memory[idx]["id"] = str(idx)
        function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)
        new_memories_with_actions = self.llm.generate_response(
            messages=[{"role": "user", "content": function_calling_prompt}],
            response_format={"type": "json_object"},
        )
        new_memories_with_actions = json.loads(new_memories_with_actions)
        returned_memories = []
        try:
            # Apply each event the LLM proposed; per-event failures are logged
            # and skipped so one bad entry does not abort the batch.
            for resp in new_memories_with_actions["memory"]:
                logging.info(resp)
                try:
                    if resp["event"] == "ADD":
                        memory_id = self._create_memory(
                            data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
                        )
                        returned_memories.append(
                            {
                                "id": memory_id,
                                "memory": resp["text"],
                                "event": resp["event"],
                            }
                        )
                    elif resp["event"] == "UPDATE":
                        self._update_memory(
                            memory_id=temp_uuid_mapping[resp["id"]],
                            data=resp["text"],
                            existing_embeddings=new_message_embeddings,
                            metadata=metadata,
                        )
                        returned_memories.append(
                            {
                                "id": temp_uuid_mapping[resp["id"]],
                                "memory": resp["text"],
                                "event": resp["event"],
                                "previous_memory": resp["old_memory"],
                            }
                        )
                    elif resp["event"] == "DELETE":
                        self._delete_memory(memory_id=temp_uuid_mapping[resp["id"]])
                        returned_memories.append(
                            {
                                "id": temp_uuid_mapping[resp["id"]],
                                "memory": resp["text"],
                                "event": resp["event"],
                            }
                        )
                    elif resp["event"] == "NONE":
                        logging.info("NOOP for Memory.")
                except Exception as e:
                    logging.error(f"Error in new_memories_with_actions: {e}")
        except Exception as e:
            logging.error(f"Error in new_memories_with_actions: {e}")
        capture_event("mem1.add", self, {"version": self.api_version, "keys": list(filters.keys())})
        return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
if self.api_version == "v1.1" and self.enable_graph:
if filters["user_id"]:
self.graph.user_id = filters["user_id"]
elif filters["agent_id"]:
self.graph.agent_id = filters["agent_id"]
elif filters["run_id"]:
self.graph.run_id = filters["run_id"]
else:
self.graph.user_id = "USER"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
added_entities = self.graph.add(data, filters)
return added_entities
def get(self, memory_id):
"""
Retrieve a memory by ID.
Args:
memory_id (str): ID of the memory to retrieve.
Returns:
dict: Retrieved memory.
"""
capture_event("mem1.get", self, {"memory_id": memory_id})
memory = self.vector_store.get(vector_id=memory_id)
if not memory:
return None
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id,
memory=memory.payload["data"],
hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
Returns:
list: List of all memories.
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
capture_event("mem1.get_all", self, {"limit": limit, "keys": list(filters.keys())})
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = (
executor.submit(self.graph.get_all, filters, limit)
if self.api_version == "v1.1" and self.enable_graph
else None
)
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.api_version == "v1.1":
if self.enable_graph:
return {"results": all_memories, "relations": graph_entities}
else:
return {"results": all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return all_memories
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
all_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories[0]
]
return all_memories
def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
"""
Search for memories.
Args:
query (str): Query to search for.
user_id (str, optional): ID of the user to search for. Defaults to None.
agent_id (str, optional): ID of the agent to search for. Defaults to None.
run_id (str, optional): ID of the run to search for. Defaults to None.
limit (int, optional): Limit the number of results. Defaults to 100.
filters (dict, optional): Filters to apply to the search. Defaults to None.
Returns:
list: List of search results.
"""
filters = filters or {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
capture_event(
"mem1.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
)
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = (
executor.submit(self.graph.search, query, filters, limit)
if self.api_version == "v1.1" and self.enable_graph
else None
)
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.api_version == "v1.1":
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
else:
return {"results": original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return original_memories
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
original_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories
]
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
Args:
memory_id (str): ID of the memory to update.
data (dict): Data to update the memory with.
Returns:
dict: Updated memory.
"""
capture_event("mem1.update", self, {"memory_id": memory_id})
existing_embeddings = {data: self.embedding_model.embed(data)}
self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"}
    def delete(self, memory_id):
        """
        Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.

        Returns:
            dict: confirmation message.
        """
        capture_event("mem1.delete", self, {"memory_id": memory_id})
        self._delete_memory(memory_id)
        return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
Args:
user_id (str, optional): ID of the user to delete memories for. Defaults to None.
agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
run_id (str, optional): ID of the run to delete memories for. Defaults to None.
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not filters:
raise ValueError(
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)
capture_event("mem1.delete_all", self, {"keys": list(filters.keys())})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory(memory.id)
logger.info(f"Deleted {len(memories)} memories")
if self.api_version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
return {"message": "Memories deleted successfully!"}
    def history(self, memory_id):
        """
        Get the history of changes for a memory by ID.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: Rows from the Postgres history table for this memory,
            newest first (capped at the store's default limit).
        """
        capture_event("mem1.history", self, {"memory_id": memory_id})
        # Delegates to PostgresManager.get_history (default limit 100).
        return self.db.get_history(memory_id)
def _create_memory(self, data, existing_embeddings, metadata=None):
logging.info(f"Creating memory with {data=}")
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
self.vector_store.insert(
vectors=[embeddings],
ids=[memory_id],
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
capture_event("mem1._create_memory", self, {"memory_id": memory_id})
return memory_id
    def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
        """Replace the payload and vector of an existing memory.

        Args:
            memory_id: ID of the memory to update.
            data: new memory text.
            existing_embeddings: {text: embedding} cache to avoid re-embedding.
            metadata: optional extra payload fields for the updated record.

        Returns:
            str: the (unchanged) memory_id.

        Raises:
            ValueError: when the memory cannot be fetched from the vector store.
        """
        logger.info(f"Updating memory with {data=}")
        try:
            existing_memory = self.vector_store.get(vector_id=memory_id)
        except Exception:
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
        prev_value = existing_memory.payload.get("data")
        # Carry over identity and bookkeeping fields so the update does not
        # silently drop them from the payload.
        new_metadata = metadata or {}
        new_metadata["data"] = data
        new_metadata["hash"] = existing_memory.payload.get("hash")
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
        # NOTE(review): timestamps are pinned to US/Pacific — confirm this is
        # the intended timezone for the deployment.
        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
        if "user_id" in existing_memory.payload:
            new_metadata["user_id"] = existing_memory.payload["user_id"]
        if "agent_id" in existing_memory.payload:
            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
        if "run_id" in existing_memory.payload:
            new_metadata["run_id"] = existing_memory.payload["run_id"]
        if data in existing_embeddings:
            embeddings = existing_embeddings[data]
        else:
            embeddings = self.embedding_model.embed(data)
        self.vector_store.update(
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logger.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(
            memory_id,
            prev_value,
            data,
            "UPDATE",
            created_at=new_metadata["created_at"],
            updated_at=new_metadata["updated_at"],
        )
        capture_event("mem1._update_memory", self, {"memory_id": memory_id})
        return memory_id
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload["data"]
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("mem1._delete_memory", self, {"memory_id": memory_id})
return memory_id
def reset(self):
"""
Reset the memory store.
"""
logger.warning("Resetting all memories")
self.vector_store.delete_col()
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
self.db.reset()
capture_event("mem1.reset", self)
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.") |
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Hi,
Can we use Elasticsearch as the db instead of the default SQLite?
Arun.
Beta Was this translation helpful? Give feedback.
All reactions