Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend for async SQLAlchemy ORM #121

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
132 changes: 125 additions & 7 deletions fastapi_crudrouter/core/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
import inspect

try:
from sqlalchemy.orm import Session
from sqlalchemy.ext.declarative import DeclarativeMeta as Model
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy import __version__ as sqlalchemy_version

if sqlalchemy_version >= "1.4":
from sqlalchemy.future import select
except ImportError:
Model = None
Session = None
Expand Down Expand Up @@ -39,6 +44,7 @@ def __init__(
update_route: Union[bool, DEPENDENCIES] = True,
delete_one_route: Union[bool, DEPENDENCIES] = True,
delete_all_route: Union[bool, DEPENDENCIES] = True,
use_async: Optional[bool] = None, # if not set, try autodetect
**kwargs: Any
) -> None:
assert (
Expand All @@ -47,6 +53,12 @@ def __init__(

self.db_model = db_model
self.db_func = db
if use_async == None:
self.use_async = (
inspect.isasyncgenfunction(db) or inspect.isasyncgen(db)
) and sqlalchemy_version >= "1.4" # autodetect async mode
else:
self.use_async = use_async
self._pk: str = db_model.__table__.primary_key.columns.keys()[0]
self._pk_type: type = _utils.get_pk_type(schema, self._pk)

Expand Down Expand Up @@ -82,7 +94,32 @@ def route(
)
return db_models

return route
async def async_route(
db: Session = Depends(self.db_func),
pagination: PAGINATION = self.pagination,
) -> List[Model]:
skip, limit = pagination.get("skip"), pagination.get("limit")

res = await db.execute(
select(self.db_model)
.order_by(getattr(self.db_model, self._pk))
.limit(limit)
.offset(skip)
)
res = res.all()

model: Model
db_models: List[Model] = []
for row in res:
(model,) = row
db_models.append(model)

return db_models

if self.use_async:
return async_route
else:
return route

def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(
Expand All @@ -95,7 +132,28 @@ def route(
else:
raise NOT_FOUND from None

return route
async def async_route(
item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore
) -> Model:
model: Model
try:
(model,) = (
await db.execute(
select(self.db_model).where(self.db_model.id == item_id)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just fyi... my team noticed this doesn't work with PK's not named "id". I think this code might work better:

Suggested change
select(self.db_model).where(self.db_model.id == item_id)
select(self.db_model).where(getattr(self.db_model, self._pk) == item_id)

)
).one()
except NoResultFound:
model = None

if model:
return model
else:
raise NOT_FOUND from None

if self.use_async:
return async_route
else:
return route

def _create(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(
Expand All @@ -112,7 +170,24 @@ def route(
db.rollback()
raise HTTPException(422, "Key already exists") from None

return route
async def async_route(
model: self.create_schema, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
try:
db_model: Model = self.db_model(**model.dict())
db.add(db_model)
await db.commit()
await db.refresh(db_model)
return db_model
except IntegrityError:
await db.rollback()
raise HTTPException(422, "Key already exists") from None

if self.use_async:
return async_route
else:
return route

def _update(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(
Expand All @@ -135,7 +210,30 @@ def route(
db.rollback()
self._raise(e)

return route
async def async_route(
item_id: self._pk_type, # type: ignore
model: self.update_schema, # type: ignore
db: Session = Depends(self.db_func),
) -> Model:
try:
db_model: Model = await self._get_one()(item_id, db)

for key, value in model.dict(exclude={self._pk}).items():
if hasattr(db_model, key):
setattr(db_model, key, value)

await db.commit()
await db.refresh(db_model)

return db_model
except IntegrityError as e:
await db.rollback()
self._raise(e)

if self.use_async:
return async_route
else:
return route

def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
def route(db: Session = Depends(self.db_func)) -> List[Model]:
Expand All @@ -144,7 +242,15 @@ def route(db: Session = Depends(self.db_func)) -> List[Model]:

return self._get_all()(db=db, pagination={"skip": 0, "limit": None})

return route
async def async_route(db: Session = Depends(self.db_func)) -> List[Model]:
await db.execute("delete from " + self.db_model.__tablename__)
await db.commit()
return await self._get_all()(db=db, pagination={"skip": 0, "limit": None})

if self.use_async:
return async_route
else:
return route

def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE:
def route(
Expand All @@ -156,4 +262,16 @@ def route(

return db_model

return route
async def async_route(
item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore
) -> Model:
db_model: Model = await self._get_one()(item_id, db)
await db.delete(db_model)
await db.commit()

return db_model

if self.use_async:
return async_route
else:
return route