diff --git a/src/main/app/backend/run_all_tests.sh b/src/main/app/backend/run_all_tests.sh new file mode 100755 index 00000000..70f7c1df --- /dev/null +++ b/src/main/app/backend/run_all_tests.sh @@ -0,0 +1 @@ + python3 -m unittest tests.runner \ No newline at end of file diff --git a/src/main/app/backend/run_database_tests.sh b/src/main/app/backend/run_database_tests.sh new file mode 100755 index 00000000..6c9c35b3 --- /dev/null +++ b/src/main/app/backend/run_database_tests.sh @@ -0,0 +1 @@ + python3 -m unittest tests.database_runner \ No newline at end of file diff --git a/src/main/app/backend/run_route_tests.sh b/src/main/app/backend/run_route_tests.sh new file mode 100755 index 00000000..6c6ee552 --- /dev/null +++ b/src/main/app/backend/run_route_tests.sh @@ -0,0 +1 @@ + python3 -m unittest tests.route_runner \ No newline at end of file diff --git a/src/main/app/backend/tests/database_runner.py b/src/main/app/backend/tests/database_runner.py new file mode 100644 index 00000000..92817e4c --- /dev/null +++ b/src/main/app/backend/tests/database_runner.py @@ -0,0 +1,14 @@ +import unittest + +mods = ["test_graph_reader", "test_graph_writer"] +# initialize the test suite +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# add tests to the test suite +for mod in mods: + suite.addTests(loader.loadTestsFromName(f"tests.{mod}")) + +# initialize a runner, pass it your suite and run it +runner = unittest.TextTestRunner(verbosity=3) +result = runner.run(suite) \ No newline at end of file diff --git a/src/main/app/backend/tests/route_runner.py b/src/main/app/backend/tests/route_runner.py new file mode 100644 index 00000000..31588a79 --- /dev/null +++ b/src/main/app/backend/tests/route_runner.py @@ -0,0 +1,15 @@ +import unittest + + +mods = ["test_llm_route", "test_rating_route"] +# initialize the test suite +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# add tests to the test suite +for mod in mods: + suite.addTests(loader.loadTestsFromName(f"tests.{mod}")) + +# initialize a runner, pass it your suite and run it +runner = unittest.TextTestRunner(verbosity=3) +result = runner.run(suite) \ No newline at end of file diff --git a/src/main/app/backend/tests/runner.py b/src/main/app/backend/tests/runner.py new file mode 100644 index 00000000..8594e5c3 --- /dev/null +++ b/src/main/app/backend/tests/runner.py @@ -0,0 +1,19 @@ +# tests/runner.py +import os +import unittest + +# import your test modules +from . import * + +mods = [x[:-3] for x in os.listdir("tests/") if x.startswith("test_")] +# initialize the test suite +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# add tests to the test suite +for mod in mods: + suite.addTests(loader.loadTestsFromName(f"tests.{mod}")) + +# initialize a runner, pass it your suite and run it +runner = unittest.TextTestRunner(verbosity=3) +result = runner.run(suite) \ No newline at end of file diff --git a/src/main/app/backend/tests/test_graph_reader.py b/src/main/app/backend/tests/test_graph_reader.py index 8b004f8c..1b762b57 100644 --- a/src/main/app/backend/tests/test_graph_reader.py +++ b/src/main/app/backend/tests/test_graph_reader.py @@ -3,6 +3,7 @@ from langchain_community.embeddings import FakeEmbeddings from database.communicator import GraphReader +from tools.secret_manager import SecretManager class TestGraphReader(unittest.TestCase): @@ -13,13 +14,14 @@ def setUpClass(cls) -> None: os.environ.get("DATABASE_TYPE") == "dev" ), f"Current db is {os.environ.get('DATABASE_TYPE')}. Please change to dev for testing." cls.embedder = FakeEmbeddings(size=768) + cls.sm = SecretManager() def test_init(self) -> None: - gr = GraphReader() + gr = GraphReader(secret_manager=self.sm) gr.close_driver() def test_standard_context_retrieval(self) -> None: - gr = GraphReader() + gr = GraphReader(secret_manager=self.sm) context = gr.retrieve_context_documents( question_embedding=self.embedder.embed_query("What is gds?"), @@ -35,7 +37,7 @@ def test_standard_context_retrieval(self) -> None: gr.close_driver() def test_topics_context_retrieval(self) -> None: - # gr = GraphReader() + # gr = GraphReader(secret_manager=self.sm) # context = gr.retrieve_context_documents_by_topic( # question_embedding=self.embedder.embed_query("What is gds?"), @@ -54,7 +56,7 @@ def test_topics_context_retrieval(self) -> None: pass # not implemented def test_match_by_id(self) -> None: - gr = GraphReader() + gr = GraphReader(secret_manager=self.sm) ids = [ "conv-20aa11bb-d65b-4c77-a6f3-58a39d8d0205", "conv-692266c4-33ba-4f6f-bf41-fcea75fd2579", diff --git a/src/main/app/backend/tests/test_llm_route.py b/src/main/app/backend/tests/test_llm_route.py index d2a3a9db..d61c4514 100644 --- a/src/main/app/backend/tests/test_llm_route.py +++ b/src/main/app/backend/tests/test_llm_route.py @@ -8,6 +8,7 @@ from tools.embedding import FakeEmbeddingService from tools.llm import LLM from objects.nodes import UserMessage, AssistantMessage +from objects.rating import Rating from routers.llm import get_embedding_service, get_llm, get_reader, get_writer client = TestClient(app) @@ -30,6 +31,8 @@ def log_assistant( ) -> None: pass + def rate_message(rating: Rating) -> None: + pass class GraphReaderMock: def retrieve_context_documents( @@ -94,4 +97,4 @@ def test_llm_dummy_route(self) -> None: def test_llm_route(self) -> None: resp = client.post("/llm", json=self.question) self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json()["content"], "GDS is cool.") + self.assertEqual(resp.json()["content"], "GDS is cool.") \ No newline at end of file diff --git a/src/main/app/backend/tests/test_rating_route.py b/src/main/app/backend/tests/test_rating_route.py index 42b0f556..a860316b 100644 --- a/src/main/app/backend/tests/test_rating_route.py +++ b/src/main/app/backend/tests/test_rating_route.py @@ -5,16 +5,10 @@ from main import app from routers.llm import get_writer -from objects.rating import Rating +from tests.test_llm_route import GraphWriterMock client = TestClient(app) - -class GraphWriterMock: - def rate_message(rating: Rating) -> None: - pass - - def override_get_writer(): return GraphWriterMock() @@ -36,4 +30,4 @@ def setUpClass(cls) -> None: def test_rating_route(self) -> None: resp = client.post("/rating", json=self.rating) - self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.status_code, 200) \ No newline at end of file diff --git a/src/main/app/backend/tools/llm.py b/src/main/app/backend/tools/llm.py index 2db3763b..ea86feed 100644 --- a/src/main/app/backend/tools/llm.py +++ b/src/main/app/backend/tools/llm.py @@ -103,7 +103,7 @@ def _init_llm(self): case _: raise ValueError("Please provide a valid LLM type.") - def _format_llm_input(self, question: str, context: Optional[pd.DataFrame]) -> str: + def _format_llm_input(self, question: str, context: Optional[pd.DataFrame] = None) -> str: """ Format the LLM input and return the input along with the context IDs if they exist. """ @@ -120,7 +120,7 @@ def get_response( question: Question, user_id: str, assistant_id: str, - context: Optional[pd.DataFrame], + context: Optional[pd.DataFrame] = None, ) -> str: """ Get a response from the LLM.