diff --git a/src/db/dtobase.py b/src/db/dtobase.py index 0abd18d..9592030 100644 --- a/src/db/dtobase.py +++ b/src/db/dtobase.py @@ -7,7 +7,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, HSTORE, INET, JSON, JSONB, MACADDR from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.mutable import Mutable -from sqlalchemy.orm import undefer +from sqlalchemy.orm import InstrumentedAttribute, undefer from sqlalchemy.sql.base import ExecutableOption from src._types import Order, QueryParams @@ -46,7 +46,7 @@ def register_table_params(table_name: str, params: InspectorTableConstraint) -> async def inspect_table(table_name: str) -> InspectorTableConstraint: """Reflect table schema to inspect unique constraints and many-to-one fks and cache in memory""" - if result := TABLE_PARAMS.get(table_name): + if result := TABLE_PARAMS.get(table_name, None): # noqa: PGH003 # type: ignore return result async with async_engine.connect() as conn: result: InspectorTableConstraint = {"unique_constraints": [], "foreign_keys": {}} @@ -84,7 +84,6 @@ def __init__(self, model: type[ModelT], undefer_load: bool = False) -> None: """ self.model = model self.undefer_load = undefer_load - self.id_attr = self.get_id_attribute_value(self.model) @overload @classmethod @@ -93,11 +92,23 @@ def get_id_attribute_value(cls, obj: ModelT, id_attribute: str | None = None) -> @overload @classmethod - def get_id_attribute_value(cls, obj: type[ModelT], id_attribute: str | None = None) -> str: + def get_id_attribute_value(cls, obj: type[ModelT], id_attribute: str | None = None) -> InstrumentedAttribute[PkIdT]: ... @classmethod - def get_id_attribute_value(cls, obj: ModelT | type[ModelT], id_attribute: str | None = None) -> str | PkIdT: + def get_id_attribute_value( + cls, obj: ModelT | type[ModelT], id_attribute: str | None = None + ) -> InstrumentedAttribute[PkIdT] | PkIdT: + """Get value of attribute named as :attr:`id_attribute` on ``obj``. + + Args: + item: Anything that should have an attribute named as :attr:`id_attribute` value. + id_attribute: Allows customization of the unique identifier to use for model fetching. + Defaults to `None`, but can reference any surrogate or candidate key for the table. + + Returns: + The value of attribute on ``obj`` named as :attr:`id_attribute `. + """ return getattr(obj, id_attribute if id_attribute is not None else cls.id_attribute) def _get_base_stmt(self) -> Select[tuple[ModelT]]: @@ -383,8 +394,9 @@ async def _check_unique_constraints( None: This function does not return anything. """ stmt = self._get_base_count_stmt() + id_str = self.get_id_attribute_value(self.model) if pk_id: - stmt = stmt.where(getattr(self.model, self.id_attr) != pk_id) + stmt = stmt.where(id_str != pk_id) for key, value in uq.items(): if isinstance(value, bool): @@ -534,6 +546,8 @@ async def create( session: AsyncSession, obj_in: CreateSchemaType, excludes: set[str] | None = None, + exclude_unset: bool = False, + exclude_none: bool = False, commit: bool | None = True, ) -> ModelT: """ @@ -554,7 +568,9 @@ async def create( insp = await inspect_table(self.model.__tablename__) await self._apply_foreign_keys_check(session, obj_in, insp) await self._apply_unique_constraints_when_create(session, obj_in, insp) - new_obj = self.model(**obj_in.model_dump(exclude_unset=True, exclude=excludes)) + new_obj = self.model( + **obj_in.model_dump(exclude_unset=exclude_unset, exclude_none=exclude_none, exclude=excludes) + ) if commit: return await self.commit(session, new_obj) return new_obj @@ -646,6 +662,8 @@ async def get_one_or_404(self, session: AsyncSession, pk_id: PkIdT, *options: Ex NotFoundError: If no instance with the given primary key (pk_id) is found in the database. """ stmt = self._get_base_stmt() + id_str = self.get_id_attribute_value(self.model) + stmt = stmt.where(id_str == pk_id) if options: stmt = self._apply_selectinload(stmt, *options) result = (await session.scalars(stmt)).one_or_none() @@ -668,6 +686,7 @@ async def get_none_or_409(self, session: AsyncSession, field: str, value: Any) - Raises: ExistError: If a record with the given value already exists in the database. """ + id_str = self.get_id_attribute_value(self.model) stmt = ( self._get_base_stmt() .where( @@ -675,7 +694,7 @@ async def get_none_or_409(self, session: AsyncSession, field: str, value: Any) - if isinstance(value, bool) or value is None else getattr(self.model, field) == value ) - .with_only_columns(getattr(self.model, self.id_attr)) + .with_only_columns(func.count(id_str), maintain_column_forms=True) ) result = (await session.execute(stmt)).one_or_none() @@ -705,6 +724,39 @@ async def get_by_filters( stmt = self._apply_selectinload(stmt, *options) return (await session.scalars(stmt)).all() + async def get_one_by_filter( + self, session: AsyncSession, filters: dict[str, Any], *options: ExecutableOption + ) -> ModelT | None: + stmt = self._get_base_stmt() + stmt = self._apply_filter(stmt=stmt, filters=filters) + stmt = self._apply_selectinload(stmt, *options) + return (await session.scalars(stmt)).one_or_none() + + async def get_multi_by_filter( + self, session: AsyncSession, filters: dict[str, Any], *options: ExecutableOption + ) -> Sequence[ModelT]: + stmt = self._get_base_stmt() + stmt = self._apply_filter(stmt=stmt, filters=filters) + stmt = self._apply_selectinload(stmt, *options) + return (await session.scalars(stmt)).all() + + async def get_multi_or_404( + self, session: AsyncSession, pk_ids: list[PkIdT], *options: ExecutableOption + ) -> Sequence[ModelT]: + stmt = self._get_base_stmt() + id_str = self.get_id_attribute_value(self.model) + stmt = stmt.where(id_str.in_(pk_ids)) + if options: + stmt = self._apply_selectinload(stmt, *options) + results = (await session.scalars(stmt)).all() + if not results: + raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids) + for r in results: + id_value = self.get_id_attribute_value(r) + if id_value not in pk_ids: + raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids) + return results + async def commit(self, session: AsyncSession, obj: ModelT) -> ModelT: """ Commits the changes made in the session and refreshes the given object.