From 3075b22744af3bbbaf048f8c9433e4895db5c156 Mon Sep 17 00:00:00 2001 From: Itamar Shefi Date: Wed, 17 Apr 2024 16:39:27 +0300 Subject: [PATCH] Add tests --- logic/game_logic.py | 2 +- mock/mock_db.py | 26 +++++++++++++-------- poetry.lock | 35 +++++++++++++++++++++++++++- pyproject.toml | 1 + routers/admin_routes.py | 47 +++++++++++++++++++++++++++----------- routers/base.py | 8 ++++--- scripts/semantle.py | 3 +-- tests/test_secret_logic.py | 10 ++++---- 8 files changed, 98 insertions(+), 34 deletions(-) diff --git a/logic/game_logic.py b/logic/game_logic.py index ec0c624..5e09def 100644 --- a/logic/game_logic.py +++ b/logic/game_logic.py @@ -37,7 +37,7 @@ async def get_secret(self) -> str: @staticmethod @lru_cache - def _get_cached_secret(session: Session, date: datetime.date) -> str | None: + def _get_cached_secret(session: Session, date: datetime.date) -> str: # TODO: this function is accessing db but is NOT ASYNC, which might be # problematic if we choose to do async stuff with sql in the future. # the reason for that is `@lru_cache` does not support async. diff --git a/mock/mock_db.py b/mock/mock_db.py index 1f65a82..6dd6d1f 100644 --- a/mock/mock_db.py +++ b/mock/mock_db.py @@ -1,21 +1,23 @@ from __future__ import annotations + +import sqlite3 from typing import TYPE_CHECKING -from sqlalchemy import event from sqlalchemy import Engine +from sqlalchemy import event +from sqlmodel import Session +from sqlmodel import SQLModel from sqlmodel import StaticPool from sqlmodel import create_engine -from sqlmodel import Session - -from common import tables - if TYPE_CHECKING: + from typing import Any from typing import TypeVar - T = TypeVar("T", bound=tables.SQLModel) + + T = TypeVar("T", bound=SQLModel) -def collation(string1, string2): +def collation(string1: str, string2: str) -> int: if string1 == string2: return 0 elif string1 > string2: @@ -23,16 +25,20 @@ def collation(string1, string2): else: return -1 + @event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection, dummy_connection_record): +def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, dummy_connection_record: Any +) -> None: dbapi_connection.create_collation("Hebrew_100_CI_AI_SC_UTF8", collation) dbapi_connection.create_collation("Hebrew_CI_AI", collation) cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() + class MockDb: - def __init__(self): + def __init__(self) -> None: self.db_uri = "sqlite:///:memory:?cache=shared" self.engine = create_engine( self.db_uri, @@ -44,7 +50,7 @@ def __init__(self): expire_on_commit=False, autoflush=True, ) - tables.SQLModel.metadata.create_all(self.engine) + SQLModel.metadata.create_all(self.engine) def add(self, entity: T) -> T: self.session.begin() diff --git a/poetry.lock b/poetry.lock index b3a3ad2..b8e2669 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1626,6 +1626,25 @@ pluggy = ">=1.4,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-sugar" +version = "1.0.0" +description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)." +optional = false +python-versions = "*" +files = [ + {file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"}, + {file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"}, +] + +[package.dependencies] +packaging = ">=21.3" +pytest = ">=6.2.0" +termcolor = ">=2.1.0" + +[package.extras] +dev = ["black", "flake8", "pre-commit"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2061,6 +2080,20 @@ files = [ ply = ">=3.4" six = ">=1.12.0" +[[package]] +name = "termcolor" +version = "2.4.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.8" +files = [ + {file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"}, + {file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "types-pyopenssl" version = "24.0.0.20240311" @@ -2585,4 +2618,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "ab76e08263ed9cbe83f7231e777ee072c1d19958a79ea395bf58ee12b165646a" +content-hash = "f22b27dbc406d1f1e30627897959ab3b040ea5270469bbcc89602a5e1712d04b" diff --git a/pyproject.toml b/pyproject.toml index 510df73..d6bd7a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ types-redis = "^4.6.0.11" ruff = "0.3.4" alembic = "^1.13.1" pytest = "^8.1.1" +pytest-sugar = "^1.0.0" [tool.ruff] fix = true diff --git a/routers/admin_routes.py b/routers/admin_routes.py index 36b4924..629ef53 100644 --- a/routers/admin_routes.py +++ b/routers/admin_routes.py @@ -7,14 +7,15 @@ from fastapi.responses import HTMLResponse from pydantic import BaseModel from sqlmodel import Session +from sqlmodel import select +from common import tables from common.session import hs_transaction -from logic.game_logic import SecretLogic, CacheSecretLogic +from logic.game_logic import CacheSecretLogic +from logic.game_logic import SecretLogic from model import GensimModel from routers.base import render from routers.base import super_admin -from sqlmodel import select -from common import tables TOP_SAMPLE = 10000 @@ -25,22 +26,34 @@ async def index(request: Request) -> HTMLResponse: model = request.app.state.model secret_logic = SecretLogic(request.app.state.session) - all_secrets = await secret_logic.get_all_secrets(with_future=True) - potential_secrets = [] + all_secrets = [ + secret[0] for secret in await secret_logic.get_all_secrets(with_future=True) + ] + potential_secrets: list[str] = [] while len(potential_secrets) < 45: secret = await get_random_word(model) # todo: in batches if secret not in all_secrets: potential_secrets.append(secret) - return render(name="set_secret.html", request=request, potential_secrets=potential_secrets) + return render( + name="set_secret.html", request=request, potential_secrets=potential_secrets + ) @admin_router.get("/model", include_in_schema=False) -async def get_word_data(request: Request, word: str) -> dict[str, list[str] | datetime.date]: +async def get_word_data( + request: Request, word: str +) -> dict[str, list[str] | datetime.date]: session = request.app.state.session redis = request.app.state.redis model = request.app.state.model - logic = CacheSecretLogic(session=session, redis=redis, secret=word, dt=await get_date(session), model=model) + logic = CacheSecretLogic( + session=session, + redis=redis, + secret=word, + dt=await get_date(session), + model=model, + ) await logic.simulate_set_secret(force=False) cache = await logic.cache return { @@ -48,26 +61,34 @@ async def get_word_data(request: Request, word: str) -> dict[str, list[str] | da "data": cache[::-1], } + class SetSecretRequest(BaseModel): secret: str clues: list[str] + @admin_router.post("/set-secret", include_in_schema=False) -async def set_new_secret(request: Request, set_secret: SetSecretRequest): +async def set_new_secret(request: Request, set_secret: SetSecretRequest) -> str: session = request.app.state.session redis = request.app.state.redis model = request.app.state.model - logic = CacheSecretLogic(session=session, redis=redis, secret=set_secret.secret, dt=await get_date(session), model=model) + logic = CacheSecretLogic( + session=session, + redis=redis, + secret=set_secret.secret, + dt=await get_date(session), + model=model, + ) await logic.simulate_set_secret(force=False) await logic.do_populate(set_secret.clues) return f"Set '{set_secret.secret}' with clues '{set_secret.clues}' on {logic.date_}" - # TODO: everything below here should be in a separate file, and set_secret script should be updated to use it async def get_random_word(model: GensimModel) -> str: rand_index = random.randint(0, TOP_SAMPLE) - return model.model.index_to_key[rand_index] + word: str = model.model.index_to_key[rand_index] + return word async def get_date(session: Session) -> datetime.date: @@ -77,4 +98,4 @@ async def get_date(session: Session) -> datetime.date: latest: datetime.date = s.exec(query).first() dt = latest + datetime.timedelta(days=1) - return dt \ No newline at end of file + return dt diff --git a/routers/base.py b/routers/base.py index ae02f96..4d83a57 100644 --- a/routers/base.py +++ b/routers/base.py @@ -6,11 +6,11 @@ from fastapi import FastAPI from fastapi import HTTPException +from fastapi import Request from fastapi import status -from fastapi.templating import Jinja2Templates from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates -from fastapi import Request from logic.game_logic import CacheSecretLogic from logic.game_logic import VectorLogic from logic.user_logic import UserLogic @@ -32,7 +32,9 @@ async def get_logics( delta += app.state.days_delta date = get_date(delta) logic = VectorLogic(app.state.session, dt=date, model=app.state.model) - secret = await logic.secret_logic.get_secret() # TODO: raise a user-friendly exception + secret = ( + await logic.secret_logic.get_secret() + ) # TODO: raise a user-friendly exception cache_logic = CacheSecretLogic( app.state.session, app.state.redis, diff --git a/scripts/semantle.py b/scripts/semantle.py index 49d0729..9d7d4e7 100644 --- a/scripts/semantle.py +++ b/scripts/semantle.py @@ -3,13 +3,12 @@ import sys from datetime import datetime - base = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.extend([base]) -from common.session import get_session # noqa: E402 from common.session import get_model # noqa: E402 from common.session import get_redis # noqa: E402 +from common.session import get_session # noqa: E402 from logic.game_logic import CacheSecretLogic # noqa: E402 from logic.game_logic import VectorLogic # noqa: E402 diff --git a/tests/test_secret_logic.py b/tests/test_secret_logic.py index e598ce4..5093b68 100644 --- a/tests/test_secret_logic.py +++ b/tests/test_secret_logic.py @@ -1,5 +1,6 @@ -import unittest import datetime +import unittest + import pytest from sqlmodel import Session @@ -10,7 +11,7 @@ class TestGameLogic(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): + async def asyncSetUp(self) -> None: self.db = MockDb() self.date = datetime.date(2021, 1, 1) self.testee = SecretLogic(session=self.db.session, dt=self.date) @@ -31,12 +32,13 @@ async def test_get_secret(self) -> None: # assert self.assertEqual(db_secret.word, secret) - async def test_get_secret__cache(self): + async def test_get_secret__cache(self) -> None: # arrange cached = self.db.add(tables.SecretWord(word="cached", game_date=self.date)) await self.testee.get_secret() with Session(self.db.engine) as session: db_secret = session.get(tables.SecretWord, cached.id) + assert db_secret is not None db_secret.word = "not_cached" session.add(db_secret) session.commit() @@ -47,7 +49,7 @@ async def test_get_secret__cache(self): # assert self.assertEqual("cached", secret) - async def test_get_secret__dont_cache_if_no_secret(self): + async def test_get_secret__dont_cache_if_no_secret(self) -> None: # arrange try: await self.testee.get_secret()