Skip to content

Commit

Permalink
Merge pull request #34 from neo4j-field/develop-backend-tests
Browse files Browse the repository at this point in the history
Develop backend tests
  • Loading branch information
a-s-g93 authored May 10, 2024
2 parents de1ea6d + 142a89f commit a6bcc69
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/main/app/backend/database/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Communicator:
Base class for graph reader and writer.
"""

def __init__(self, secret_manager: Optional[SecretManager] = None) -> None:
def __init__(self, secret_manager: Optional[SecretManager]) -> None:

if secret_manager is not None:
print("Grabbing secrets from GCP.")
Expand Down Expand Up @@ -159,7 +159,7 @@ def log_assistant(
message: AssistantMessage,
previous_message_id: str,
context_ids: List[str],
):
) -> None:
"""
This method logs a new assistant message to the neo4j database and
creates appropriate relationships.
Expand Down
8 changes: 8 additions & 0 deletions src/main/app/backend/example.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
OPENAI_API_KEY="sk-..."
OPENAI_VERSION="2023-03-15-preview"
NEO4J_URI="neo4j+s://abc123.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="password"
NEO4J_DATABASE="neo4j"
GCP_PROJECT_ID="proj-123"
GCP_REGION="us-central1"
78 changes: 53 additions & 25 deletions src/main/app/backend/routers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,60 @@
from typing import List, Dict, Union, Tuple
from uuid import uuid4

from fastapi import APIRouter, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, Depends

from database.communicator import GraphReader, GraphWriter
from objects.question import Question
from objects.response import Response
from objects.nodes import UserMessage, AssistantMessage
from resources.prompts.prompts import prompt_no_context_template, prompt_template
from tools.embedding import TextEmbeddingService
from tools.embedding import TextEmbeddingService, EmbeddingServiceProtocol
from tools.llm import LLM
from tools.secret_manager import SecretManager

PUBLIC = True

sm = SecretManager()
router = APIRouter()
reader = GraphReader(secret_manager=sm)
writer = GraphWriter(secret_manager=sm)


def get_reader():
    """FastAPI dependency: yield a request-scoped GraphReader.

    The driver is always closed once the request finishes, even when the
    endpoint raises.
    """
    graph_reader = GraphReader(secret_manager=sm)
    try:
        yield graph_reader
    finally:
        graph_reader.close_driver()


def get_writer():
    """FastAPI dependency: yield a request-scoped GraphWriter.

    The driver is always closed once the request finishes, even when the
    endpoint raises.
    """
    graph_writer = GraphWriter(secret_manager=sm)
    try:
        yield graph_writer
    finally:
        graph_writer.close_driver()


def get_embedding_service() -> EmbeddingServiceProtocol:
    """FastAPI dependency: provide the embedding service used for questions."""
    service = TextEmbeddingService()
    return service


def get_llm(question: Question) -> LLM:
    """FastAPI dependency: build an LLM from the request's type and temperature."""
    llm_type = question.llm_type
    temperature = question.temperature
    return LLM(llm_type=llm_type, temperature=temperature)


@router.get("/")
def get_default() -> str:
    """Health-check endpoint: confirm the backend is reachable."""
    liveness_message = "Agent Neo backend is live."
    return liveness_message


# Todo: Implement bearer tokens in the backend?
# right now anyone with the url and the endpoints can hit them
@router.post("/llm_dummy", response_model=Response)
async def get_response(
question: Question, background_tasks: BackgroundTasks
) -> Response:
async def get_response(question: Question) -> Response:
"""
Dummy test.
"""
print("dummy test")

question_embedding = [0.321, 0.123]
llm_response = "This call works!"

Expand All @@ -45,18 +70,14 @@ async def get_response(
assistant_message = AssistantMessage(
session_id=question.session_id,
conversation_id=question.conversation_id,
prompt="",
prompt=prompt_no_context_template,
content=llm_response,
public=PUBLIC,
vectorIndexSearch=True,
number_of_documents=question.number_of_documents,
temperature=question.temperature,
)

background_tasks.add_task(
write_notification, question.question, message="some notification"
)

return Response(
session_id=question.session_id,
conversation_id=question.conversation_id,
Expand All @@ -68,22 +89,26 @@ async def get_response(

@router.post("/llm", response_model=Response)
async def get_response(
question: Question, background_tasks: BackgroundTasks
question: Question,
background_tasks: BackgroundTasks,
reader: GraphReader = Depends(get_reader),
writer: GraphWriter = Depends(get_writer),
embedding_service: EmbeddingServiceProtocol = Depends(get_embedding_service),
llm: LLM = Depends(get_llm),
) -> Response:
"""
Gather context from the graph and retrieve a response from the designated LLM endpoint.
"""

print("real call")
question_embedding = TextEmbeddingService().get_embedding(text=question.question)
question_embedding = embedding_service.get_embedding(text=question.question)
print("got embedding...")
context = reader.retrieve_context_documents(
question_embedding=question_embedding,
number_of_context_documents=question.number_of_documents,
)
# print(context)
print("context retrieved...")
llm = LLM(llm_type=question.llm_type, temperature=question.temperature)
# llm = LLM(llm_type=question.llm_type, temperature=question.temperature)
print("llm initialized...")
user_id: str = "user-" + str(uuid4())
assistant_id: str = "llm-" + str(uuid4())
Expand Down Expand Up @@ -119,12 +144,14 @@ async def get_response(
question.message_history,
question.llm_type,
question.temperature,
writer,
)
background_tasks.add_task(
log_assistant_message,
assistant_message,
user_message.message_id,
list(context["index"]),
writer,
)
print("returning...")
return Response(
Expand All @@ -137,7 +164,11 @@ async def get_response(


def log_user_message(
message: UserMessage, message_history: List[str], llm_type: str, temperature: float
message: UserMessage,
message_history: List[str],
llm_type: str,
temperature: float,
writer: GraphWriter,
) -> None:
"""
Log a user message in the graph. If this is the first message, then also log the conversation and session.
Expand All @@ -153,7 +184,10 @@ def log_user_message(


def log_assistant_message(
message: AssistantMessage, previous_message_id: str, context_ids: List[str]
message: AssistantMessage,
previous_message_id: str,
context_ids: List[str],
writer: GraphWriter,
) -> None:
"""
Log an assistant message in the graph.
Expand All @@ -172,9 +206,3 @@ def get_prompt(context: List[str]) -> str:
"""

return prompt_no_context_template if len(context) < 1 else prompt_template


def write_notification(email: str, message=""):
with open("dummy_log.txt", mode="w") as email_file:
content = f"notification for {email}: {message}"
email_file.write(content)
16 changes: 13 additions & 3 deletions src/main/app/backend/routers/rating.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends

from database.communicator import GraphWriter
from objects.rating import Rating
from tools.secret_manager import SecretManager

sm = SecretManager()
router = APIRouter()
writer = GraphWriter(secret_manager=sm)
# writer = GraphWriter(secret_manager=sm)


def get_writer():
    """FastAPI dependency: yield a request-scoped GraphWriter.

    Guarantees the underlying driver is closed after the request completes.
    """
    graph_writer = GraphWriter(secret_manager=sm)
    try:
        yield graph_writer
    finally:
        graph_writer.close_driver()


@router.post("/rating")
async def rate_message(rating: Rating) -> None:
async def rate_message(
rating: Rating, writer: GraphWriter = Depends(get_writer)
) -> None:
"""
Write a message rating to the database.
"""
Expand Down
1 change: 1 addition & 0 deletions src/main/app/backend/run_all_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run the complete backend test suite (all test_* modules) via the aggregate runner.
python3 -m unittest tests.runner
1 change: 1 addition & 0 deletions src/main/app/backend/run_database_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run only the database-layer tests (graph reader/writer).
python3 -m unittest tests.database_runner
1 change: 1 addition & 0 deletions src/main/app/backend/run_route_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run only the API route tests (llm and rating endpoints).
python3 -m unittest tests.route_runner
14 changes: 14 additions & 0 deletions src/main/app/backend/tests/database_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Aggregate runner for the database-layer test modules.

Executed via ``python3 -m unittest tests.database_runner``; the suite runs
as a module-level side effect on import, matching the other runners.
"""
import unittest

# Database-focused test modules to execute.
mods = ["test_graph_reader", "test_graph_writer"]

# Build one suite from the named modules.
loader = unittest.TestLoader()
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromName(f"tests.{mod}") for mod in mods)

# Run the suite with verbose per-test output.
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
15 changes: 15 additions & 0 deletions src/main/app/backend/tests/route_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Aggregate runner for the API route test modules.

Executed via ``python3 -m unittest tests.route_runner``; the suite runs
as a module-level side effect on import, matching the other runners.
"""
import unittest

# Route-level test modules to execute.
mods = ["test_llm_route", "test_rating_route"]

# Build one suite from the named modules.
loader = unittest.TestLoader()
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromName(f"tests.{mod}") for mod in mods)

# Run the suite with verbose per-test output.
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
19 changes: 19 additions & 0 deletions src/main/app/backend/tests/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# tests/runner.py
"""Run every test_* module found in this tests package.

Executed via ``python3 -m unittest tests.runner``; the suite runs as a
module-level side effect on import.
"""
import os
import unittest

# import your test modules
from . import *

# Locate test modules relative to this file rather than the current working
# directory, and only accept real .py files — the old listdir("tests/") glob
# broke when run from another directory and blindly stripped three characters
# from any entry starting with "test_" (directories, .pyc files, etc.).
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
mods = [
    name[: -len(".py")]
    for name in os.listdir(_TESTS_DIR)
    if name.startswith("test_") and name.endswith(".py")
]

# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()

# add tests to the test suite
for mod in mods:
    suite.addTests(loader.loadTestsFromName(f"tests.{mod}"))

# initialize a runner, pass it your suite and run it
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
10 changes: 6 additions & 4 deletions src/main/app/backend/tests/test_graph_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from langchain_community.embeddings import FakeEmbeddings
from database.communicator import GraphReader
from tools.secret_manager import SecretManager


class TestGraphReader(unittest.TestCase):
Expand All @@ -13,13 +14,14 @@ def setUpClass(cls) -> None:
os.environ.get("DATABASE_TYPE") == "dev"
), f"Current db is {os.environ.get('DATABASE_TYPE')}. Please change to dev for testing."
cls.embedder = FakeEmbeddings(size=768)
cls.sm = SecretManager()

def test_init(self) -> None:
    """A GraphReader should construct and close cleanly against the dev db."""
    reader = GraphReader(secret_manager=self.sm)
    reader.close_driver()

def test_standard_context_retrieval(self) -> None:
gr = GraphReader()
gr = GraphReader(secret_manager=self.sm)

context = gr.retrieve_context_documents(
question_embedding=self.embedder.embed_query("What is gds?"),
Expand All @@ -35,7 +37,7 @@ def test_standard_context_retrieval(self) -> None:
gr.close_driver()

def test_topics_context_retrieval(self) -> None:
    """Placeholder: topic-scoped context retrieval is not implemented yet."""
    # Intended coverage once GraphReader.retrieve_context_documents_by_topic
    # exists: embed a sample question, retrieve documents constrained by
    # topic, and assert the expected result columns and row count.
    pass  # not implemented

def test_match_by_id(self) -> None:
gr = GraphReader()
gr = GraphReader(secret_manager=self.sm)
ids = [
"conv-20aa11bb-d65b-4c77-a6f3-58a39d8d0205",
"conv-692266c4-33ba-4f6f-bf41-fcea75fd2579",
Expand Down
Loading

0 comments on commit a6bcc69

Please sign in to comment.