-
Notifications
You must be signed in to change notification settings - Fork 361
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #295 from AnimeKaizoku/KigyoDev
Add CachingQuery class to implement caching in SQL queries
- Loading branch information
Showing
1 changed file
with
52 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,73 @@ | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import sessionmaker, scoped_session | ||
|
||
from sqlalchemy.orm import sessionmaker, scoped_session, Query | ||
from tg_bot import DB_URI, KInit, log | ||
|
||
|
||
class CachingQuery(Query): | ||
""" | ||
A subclass of Query that implements caching using the cache-aside caching pattern. | ||
Attributes: | ||
cache (dict): A dictionary used for caching query results. | ||
Methods: | ||
__iter__(): Overrides the __iter__ method of the parent class to implement caching. | ||
cache_key(): Generates a cache key based on the query's SQL statement and parameters. | ||
""" | ||
|
||
def __init__(self, *args, cache=None, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.cache = cache or {} | ||
|
||
def __iter__(self): | ||
""" | ||
Overrides the __iter__ method of the parent class to implement caching. | ||
Returns: | ||
iter: An iterator over the cached query results. | ||
""" | ||
cache_key = self.cache_key() | ||
result = self.cache.get(cache_key) | ||
|
||
if result is None: | ||
result = list(super().__iter__()) | ||
self.cache[cache_key] = result | ||
|
||
return iter(result) | ||
|
||
def cache_key(self): | ||
""" | ||
Generates a cache key based on the query's SQL statement and parameters. | ||
Returns: | ||
str: The cache key. | ||
""" | ||
stmt = self.with_labels().statement | ||
compiled = stmt.compile() | ||
params = compiled.params | ||
return " ".join([str(compiled)] + [str(params[k]) for k in sorted(params)]) | ||
|
||
|
||
if DB_URI and DB_URI.startswith("postgres://"): | ||
DB_URI = DB_URI.replace("postgres://", "postgresql://", 1) | ||
|
||
|
||
def start() -> scoped_session: | ||
engine = create_engine(DB_URI, client_encoding="utf8", echo=KInit.DEBUG) | ||
log.info("[PostgreSQL] Connecting to database......") | ||
BASE.metadata.bind = engine | ||
BASE.metadata.create_all(engine) | ||
return scoped_session(sessionmaker(bind=engine, autoflush=False)) | ||
return scoped_session( | ||
sessionmaker(bind=engine, autoflush=False, query_cls=CachingQuery) | ||
) | ||
|
||
|
||
BASE = declarative_base() | ||
try: | ||
SESSION: scoped_session = start() | ||
except Exception as e: | ||
log.exception(f'[PostgreSQL] Failed to connect due to {e}') | ||
log.exception(f"[PostgreSQL] Failed to connect due to {e}") | ||
exit() | ||
|
||
log.info("[PostgreSQL] Connection successful, session started.") |