Skip to content

Commit

Permalink
create test runners for all, database and routes
Browse files Browse the repository at this point in the history
  • Loading branch information
a-s-g93 committed May 8, 2024
1 parent a86263e commit ff28f5d
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/main/app/backend/run_all_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -m unittest tests.runner
1 change: 1 addition & 0 deletions src/main/app/backend/run_database_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -m unittest tests.database_runner
1 change: 1 addition & 0 deletions src/main/app/backend/run_route_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -m unittest tests.route_runner
14 changes: 14 additions & 0 deletions src/main/app/backend/tests/database_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest

mods = ["test_graph_reader", "test_graph_writer"]
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()

# add tests to the test suite
for mod in mods:
suite.addTests(loader.loadTestsFromName(f"tests.{mod}"))

# initialize a runner, pass it your suite and run it
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
15 changes: 15 additions & 0 deletions src/main/app/backend/tests/route_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import unittest


mods = ["test_llm_route", "test_rating_route"]
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()

# add tests to the test suite
for mod in mods:
suite.addTests(loader.loadTestsFromName(f"tests.{mod}"))

# initialize a runner, pass it your suite and run it
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
19 changes: 19 additions & 0 deletions src/main/app/backend/tests/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# tests/runner.py
import os
import unittest

# import your test modules
from . import *

mods = [x[:-3] for x in os.listdir("tests/") if x.startswith("test_")]
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()

# add tests to the test suite
for mod in mods:
suite.addTests(loader.loadTestsFromName(f"tests.{mod}"))

# initialize a runner, pass it your suite and run it
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
10 changes: 6 additions & 4 deletions src/main/app/backend/tests/test_graph_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from langchain_community.embeddings import FakeEmbeddings
from database.communicator import GraphReader
from tools.secret_manager import SecretManager


class TestGraphReader(unittest.TestCase):
Expand All @@ -13,13 +14,14 @@ def setUpClass(cls) -> None:
os.environ.get("DATABASE_TYPE") == "dev"
), f"Current db is {os.environ.get('DATABASE_TYPE')}. Please change to dev for testing."
cls.embedder = FakeEmbeddings(size=768)
cls.sm = SecretManager()

def test_init(self) -> None:
gr = GraphReader()
gr = GraphReader(secret_manager=self.sm)
gr.close_driver()

def test_standard_context_retrieval(self) -> None:
gr = GraphReader()
gr = GraphReader(secret_manager=self.sm)

context = gr.retrieve_context_documents(
question_embedding=self.embedder.embed_query("What is gds?"),
Expand All @@ -35,7 +37,7 @@ def test_standard_context_retrieval(self) -> None:
gr.close_driver()

def test_topics_context_retrieval(self) -> None:
# gr = GraphReader()
# gr = GraphReader(secret_manager=self.sm)

# context = gr.retrieve_context_documents_by_topic(
# question_embedding=self.embedder.embed_query("What is gds?"),
Expand All @@ -54,7 +56,7 @@ def test_topics_context_retrieval(self) -> None:
pass # not implemented

def test_match_by_id(self) -> None:
gr = GraphReader()
gr = GraphReader(secret_manager=self.sm)
ids = [
"conv-20aa11bb-d65b-4c77-a6f3-58a39d8d0205",
"conv-692266c4-33ba-4f6f-bf41-fcea75fd2579",
Expand Down
5 changes: 4 additions & 1 deletion src/main/app/backend/tests/test_llm_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tools.embedding import FakeEmbeddingService
from tools.llm import LLM
from objects.nodes import UserMessage, AssistantMessage
from objects.rating import Rating
from routers.llm import get_embedding_service, get_llm, get_reader, get_writer

client = TestClient(app)
Expand All @@ -30,6 +31,8 @@ def log_assistant(
) -> None:
pass

def rate_message(rating: Rating) -> None:
pass

class GraphReaderMock:
def retrieve_context_documents(
Expand Down Expand Up @@ -94,4 +97,4 @@ def test_llm_dummy_route(self) -> None:
def test_llm_route(self) -> None:
resp = client.post("/llm", json=self.question)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.json()["content"], "GDS is cool.")
self.assertEqual(resp.json()["content"], "GDS is cool.")
10 changes: 2 additions & 8 deletions src/main/app/backend/tests/test_rating_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,10 @@

from main import app
from routers.llm import get_writer
from objects.rating import Rating
from tests.test_llm_route import GraphWriterMock

client = TestClient(app)


class GraphWriterMock:
def rate_message(rating: Rating) -> None:
pass


def override_get_writer():
return GraphWriterMock()

Expand All @@ -36,4 +30,4 @@ def setUpClass(cls) -> None:

def test_rating_route(self) -> None:
resp = client.post("/rating", json=self.rating)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.status_code, 200)
4 changes: 2 additions & 2 deletions src/main/app/backend/tools/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _init_llm(self):
case _:
raise ValueError("Please provide a valid LLM type.")

def _format_llm_input(self, question: str, context: Optional[pd.DataFrame]) -> str:
def _format_llm_input(self, question: str, context: Optional[pd.DataFrame] = None) -> str:
"""
Format the LLM input and return the input along with the context IDs if they exist.
"""
Expand All @@ -120,7 +120,7 @@ def get_response(
question: Question,
user_id: str,
assistant_id: str,
context: Optional[pd.DataFrame],
context: Optional[pd.DataFrame] = None,
) -> str:
"""
Get a response from the LLM.
Expand Down

0 comments on commit ff28f5d

Please sign in to comment.