diff --git a/fastapi_crudrouter/core/sqlalchemy.py b/fastapi_crudrouter/core/sqlalchemy.py index 58270f34..04552cdc 100644 --- a/fastapi_crudrouter/core/sqlalchemy.py +++ b/fastapi_crudrouter/core/sqlalchemy.py @@ -4,11 +4,16 @@ from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA +import inspect try: from sqlalchemy.orm import Session from sqlalchemy.ext.declarative import DeclarativeMeta as Model - from sqlalchemy.exc import IntegrityError + from sqlalchemy.exc import IntegrityError, NoResultFound + from sqlalchemy import __version__ as sqlalchemy_version + + if sqlalchemy_version >= "1.4": + from sqlalchemy.future import select except ImportError: Model = None Session = None @@ -39,6 +44,7 @@ def __init__( update_route: Union[bool, DEPENDENCIES] = True, delete_one_route: Union[bool, DEPENDENCIES] = True, delete_all_route: Union[bool, DEPENDENCIES] = True, + use_async: Optional[bool] = None, # if not set, try autodetect **kwargs: Any ) -> None: assert ( @@ -47,6 +53,12 @@ def __init__( self.db_model = db_model self.db_func = db + if use_async == None: + self.use_async = ( + inspect.isasyncgenfunction(db) or inspect.isasyncgen(db) + ) and sqlalchemy_version >= "1.4" # autodetect async mode + else: + self.use_async = use_async self._pk: str = db_model.__table__.primary_key.columns.keys()[0] self._pk_type: type = _utils.get_pk_type(schema, self._pk) @@ -82,7 +94,32 @@ def route( ) return db_models - return route + async def async_route( + db: Session = Depends(self.db_func), + pagination: PAGINATION = self.pagination, + ) -> List[Model]: + skip, limit = pagination.get("skip"), pagination.get("limit") + + res = await db.execute( + select(self.db_model) + .order_by(getattr(self.db_model, self._pk)) + .limit(limit) + .offset(skip) + ) + res = res.all() + + model: Model + db_models: List[Model] = [] + for row in res: + (model,) = row + db_models.append(model) + + return db_models + + if self.use_async: + return async_route + else: + return route def _get_one(self, *args: Any, **kwargs: Any) -> CALLABLE: def route( @@ -95,7 +132,28 @@ def route( else: raise NOT_FOUND from None - return route + async def async_route( + item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore + ) -> Model: + model: Model + try: + (model,) = ( + await db.execute( + select(self.db_model).where(self.db_model.id == item_id) + ) + ).one() + except NoResultFound: + model = None + + if model: + return model + else: + raise NOT_FOUND from None + + if self.use_async: + return async_route + else: + return route def _create(self, *args: Any, **kwargs: Any) -> CALLABLE: def route( @@ -112,7 +170,24 @@ def route( db.rollback() raise HTTPException(422, "Key already exists") from None - return route + async def async_route( + model: self.create_schema, # type: ignore + db: Session = Depends(self.db_func), + ) -> Model: + try: + db_model: Model = self.db_model(**model.dict()) + db.add(db_model) + await db.commit() + await db.refresh(db_model) + return db_model + except IntegrityError: + await db.rollback() + raise HTTPException(422, "Key already exists") from None + + if self.use_async: + return async_route + else: + return route def _update(self, *args: Any, **kwargs: Any) -> CALLABLE: def route( @@ -135,7 +210,30 @@ def route( db.rollback() self._raise(e) - return route + async def async_route( + item_id: self._pk_type, # type: ignore + model: self.update_schema, # type: ignore + db: Session = Depends(self.db_func), + ) -> Model: + try: + db_model: Model = await self._get_one()(item_id, db) + + for key, value in model.dict(exclude={self._pk}).items(): + if hasattr(db_model, key): + setattr(db_model, key, value) + + await db.commit() + await db.refresh(db_model) + + return db_model + except IntegrityError as e: + await db.rollback() + self._raise(e) + + if self.use_async: + return async_route + else: + return route def _delete_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: def route(db: Session = Depends(self.db_func)) -> List[Model]: @@ -144,7 +242,15 @@ def route(db: Session = Depends(self.db_func)) -> List[Model]: return self._get_all()(db=db, pagination={"skip": 0, "limit": None}) - return route + async def async_route(db: Session = Depends(self.db_func)) -> List[Model]: + await db.execute("delete from " + self.db_model.__tablename__) + await db.commit() + return await self._get_all()(db=db, pagination={"skip": 0, "limit": None}) + + if self.use_async: + return async_route + else: + return route def _delete_one(self, *args: Any, **kwargs: Any) -> CALLABLE: def route( @@ -156,4 +262,16 @@ def route( return db_model - return route + async def async_route( + item_id: self._pk_type, db: Session = Depends(self.db_func) # type: ignore + ) -> Model: + db_model: Model = await self._get_one()(item_id, db) + await db.delete(db_model) + await db.commit() + + return db_model + + if self.use_async: + return async_route + else: + return route