-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
182 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING | ||
|
||
from sqlalchemy import event | ||
from sqlalchemy import Engine | ||
from sqlmodel import StaticPool | ||
from sqlmodel import create_engine | ||
from sqlmodel import Session | ||
|
||
from common import tables | ||
|
||
|
||
if TYPE_CHECKING: | ||
from typing import TypeVar | ||
T = TypeVar("T", bound=tables.SQLModel) | ||
|
||
|
||
def collation(string1, string2): | ||
if string1 == string2: | ||
return 0 | ||
elif string1 > string2: | ||
return 1 | ||
else: | ||
return -1 | ||
|
||
@event.listens_for(Engine, "connect") | ||
def set_sqlite_pragma(dbapi_connection, dummy_connection_record): | ||
dbapi_connection.create_collation("Hebrew_100_CI_AI_SC_UTF8", collation) | ||
dbapi_connection.create_collation("Hebrew_CI_AI", collation) | ||
cursor = dbapi_connection.cursor() | ||
cursor.execute("PRAGMA foreign_keys=ON") | ||
cursor.close() | ||
|
||
class MockDb: | ||
def __init__(self): | ||
self.db_uri = "sqlite:///:memory:?cache=shared" | ||
self.engine = create_engine( | ||
self.db_uri, | ||
connect_args={"check_same_thread": False}, | ||
poolclass=StaticPool, | ||
) | ||
self.session = Session( | ||
bind=self.engine, | ||
expire_on_commit=False, | ||
autoflush=True, | ||
) | ||
tables.SQLModel.metadata.create_all(self.engine) | ||
|
||
def add(self, entity: T) -> T: | ||
self.session.begin() | ||
self.session.add(entity) | ||
self.session.commit() | ||
return entity | ||
|
||
def add_many(self, entities: list[T]) -> None: | ||
self.session.begin() | ||
for entity in entities: | ||
self.session.add(entity) | ||
self.session.commit() | ||
for entity in entities: | ||
self.session.refresh(entity) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import unittest | ||
import datetime | ||
import pytest | ||
from sqlmodel import Session | ||
|
||
from common import tables | ||
from common.error import HSError | ||
from logic.game_logic import SecretLogic | ||
from mock.mock_db import MockDb | ||
|
||
|
||
class TestGameLogic(unittest.IsolatedAsyncioTestCase): | ||
async def asyncSetUp(self): | ||
self.db = MockDb() | ||
self.date = datetime.date(2021, 1, 1) | ||
self.testee = SecretLogic(session=self.db.session, dt=self.date) | ||
|
||
async def test_no_secret(self) -> None: | ||
# act & assert | ||
with pytest.raises(HSError): | ||
await self.testee.get_secret() | ||
|
||
async def test_get_secret(self) -> None: | ||
# arrange | ||
db_secret = tables.SecretWord(word="test", game_date=self.date) | ||
self.db.add(db_secret) | ||
|
||
# act | ||
secret = await self.testee.get_secret() | ||
|
||
# assert | ||
self.assertEqual(db_secret.word, secret) | ||
|
||
async def test_get_secret__cache(self): | ||
# arrange | ||
cached = self.db.add(tables.SecretWord(word="cached", game_date=self.date)) | ||
await self.testee.get_secret() | ||
with Session(self.db.engine) as session: | ||
db_secret = session.get(tables.SecretWord, cached.id) | ||
db_secret.word = "not_cached" | ||
session.add(db_secret) | ||
session.commit() | ||
|
||
# act | ||
secret = await self.testee.get_secret() | ||
|
||
# assert | ||
self.assertEqual("cached", secret) | ||
|
||
async def test_get_secret__dont_cache_if_no_secret(self): | ||
# arrange | ||
try: | ||
await self.testee.get_secret() | ||
except HSError: | ||
pass | ||
self.db.add(tables.SecretWord(word="not_cached", game_date=self.date)) | ||
|
||
# act | ||
secret = await self.testee.get_secret() | ||
|
||
# assert | ||
self.assertEqual("not_cached", secret) |