Skip to content

Commit

Permalink
fix(dto): fix dto id issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jan 19, 2024
1 parent a7de740 commit 79ff3a0
Showing 1 changed file with 60 additions and 8 deletions.
68 changes: 60 additions & 8 deletions src/db/dtobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.dialects.postgresql import ARRAY, HSTORE, INET, JSON, JSONB, MACADDR
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.orm import undefer
from sqlalchemy.orm import InstrumentedAttribute, undefer
from sqlalchemy.sql.base import ExecutableOption

from src._types import Order, QueryParams
Expand Down Expand Up @@ -46,7 +46,7 @@ def register_table_params(table_name: str, params: InspectorTableConstraint) ->

async def inspect_table(table_name: str) -> InspectorTableConstraint:
"""Reflect table schema to inspect unique constraints and many-to-one fks and cache in memory"""
if result := TABLE_PARAMS.get(table_name):
if result := TABLE_PARAMS.get(table_name, None): # noqa: PGH003 # type: ignore
return result
async with async_engine.connect() as conn:
result: InspectorTableConstraint = {"unique_constraints": [], "foreign_keys": {}}
Expand Down Expand Up @@ -84,7 +84,6 @@ def __init__(self, model: type[ModelT], undefer_load: bool = False) -> None:
"""
self.model = model
self.undefer_load = undefer_load
self.id_attr = self.get_id_attribute_value(self.model)

@overload
@classmethod
Expand All @@ -93,11 +92,23 @@ def get_id_attribute_value(cls, obj: ModelT, id_attribute: str | None = None) ->

@overload
@classmethod
def get_id_attribute_value(cls, obj: type[ModelT], id_attribute: str | None = None) -> str:
def get_id_attribute_value(cls, obj: type[ModelT], id_attribute: str | None = None) -> InstrumentedAttribute[PkIdT]:
...

@classmethod
def get_id_attribute_value(cls, obj: ModelT | type[ModelT], id_attribute: str | None = None) -> str | PkIdT:
def get_id_attribute_value(
cls, obj: ModelT | type[ModelT], id_attribute: str | None = None
) -> InstrumentedAttribute[PkIdT] | PkIdT:
"""Get value of attribute named as :attr:`id_attribute` on ``obj``.
Args:
item: Anything that should have an attribute named as :attr:`id_attribute` value.
id_attribute: Allows customization of the unique identifier to use for model fetching.
Defaults to `None`, but can reference any surrogate or candidate key for the table.
Returns:
The value of attribute on ``obj`` named as :attr:`id_attribute <AbstractAsyncRepository.id_attribute>`.
"""
return getattr(obj, id_attribute if id_attribute is not None else cls.id_attribute)

def _get_base_stmt(self) -> Select[tuple[ModelT]]:
Expand Down Expand Up @@ -383,8 +394,9 @@ async def _check_unique_constraints(
None: This function does not return anything.
"""
stmt = self._get_base_count_stmt()
id_str = self.get_id_attribute_value(self.model)
if pk_id:
stmt = stmt.where(getattr(self.model, self.id_attr) != pk_id)
stmt = stmt.where(id_str != pk_id)

for key, value in uq.items():
if isinstance(value, bool):
Expand Down Expand Up @@ -534,6 +546,8 @@ async def create(
session: AsyncSession,
obj_in: CreateSchemaType,
excludes: set[str] | None = None,
exclude_unset: bool = False,
exclude_none: bool = False,
commit: bool | None = True,
) -> ModelT:
"""
Expand All @@ -554,7 +568,9 @@ async def create(
insp = await inspect_table(self.model.__tablename__)
await self._apply_foreign_keys_check(session, obj_in, insp)
await self._apply_unique_constraints_when_create(session, obj_in, insp)
new_obj = self.model(**obj_in.model_dump(exclude_unset=True, exclude=excludes))
new_obj = self.model(
**obj_in.model_dump(exclude_unset=exclude_unset, exclude_none=exclude_none, exclude=excludes)
)
if commit:
return await self.commit(session, new_obj)
return new_obj
Expand Down Expand Up @@ -646,6 +662,8 @@ async def get_one_or_404(self, session: AsyncSession, pk_id: PkIdT, *options: Ex
NotFoundError: If no instance with the given primary key (pk_id) is found in the database.
"""
stmt = self._get_base_stmt()
id_str = self.get_id_attribute_value(self.model)
stmt = stmt.where(id_str == pk_id)
if options:
stmt = self._apply_selectinload(stmt, *options)
result = (await session.scalars(stmt)).one_or_none()
Expand All @@ -668,14 +686,15 @@ async def get_none_or_409(self, session: AsyncSession, field: str, value: Any) -
Raises:
ExistError: If a record with the given value already exists in the database.
"""
id_str = self.get_id_attribute_value(self.model)
stmt = (
self._get_base_stmt()
.where(
getattr(self.model, field).is_(value)
if isinstance(value, bool) or value is None
else getattr(self.model, field) == value
)
.with_only_columns(getattr(self.model, self.id_attr))
.with_only_columns(func.count(id_str), maintain_column_forms=True)
)

result = (await session.execute(stmt)).one_or_none()
Expand Down Expand Up @@ -705,6 +724,39 @@ async def get_by_filters(
stmt = self._apply_selectinload(stmt, *options)
return (await session.scalars(stmt)).all()

async def get_one_by_filter(
self, session: AsyncSession, filters: dict[str, Any], *options: ExecutableOption
) -> ModelT | None:
stmt = self._get_base_stmt()
stmt = self._apply_filter(stmt=stmt, filters=filters)
stmt = self._apply_selectinload(stmt, *options)
return (await session.scalars(stmt)).one_or_none()

async def get_multi_by_filter(
self, session: AsyncSession, filters: dict[str, Any], *options: ExecutableOption
) -> Sequence[ModelT]:
stmt = self._get_base_stmt()
stmt = self._apply_filter(stmt=stmt, filters=filters)
stmt = self._apply_selectinload(stmt, *options)
return (await session.scalars(stmt)).all()

async def get_multi_or_404(
self, session: AsyncSession, pk_ids: list[PkIdT], *options: ExecutableOption
) -> Sequence[ModelT]:
stmt = self._get_base_stmt()
id_str = self.get_id_attribute_value(self.model)
stmt = stmt.where(id_str.in_(pk_ids))
if options:
stmt = self._apply_selectinload(stmt, *options)
results = (await session.scalars(stmt)).all()
if not results:
raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids)
for r in results:
id_value = self.get_id_attribute_value(r)
if id_value not in pk_ids:
raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids)
return results

async def commit(self, session: AsyncSession, obj: ModelT) -> ModelT:
"""
Commits the changes made in the session and refreshes the given object.
Expand Down

0 comments on commit 79ff3a0

Please sign in to comment.