Skip to content

Commit

Permalink
feat(dto): update dto
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jan 19, 2024
1 parent 64f57eb commit 6fa76b7
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/db/dtobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,15 @@ def _apply_list(
stmt = self._apply_filter(stmt, filters)
return stmt

def _check_not_found(self, instance: ModelT | Row[Any] | None, column: str, value: Any) -> None:
@overload
def _check_not_found(self, instance: ModelT | None, column: str, value: Any) -> ModelT:
...

@overload
def _check_not_found(self, instance: Row[Any] | None, column: str, value: Any) -> Row[Any]:
...

def _check_not_found(self, instance: ModelT | Row[Any] | None, column: str, value: Any) -> ModelT | Row[Any]:
"""
Check if the given instance is not found in the specified table.
Expand All @@ -297,6 +305,7 @@ def _check_not_found(self, instance: ModelT | Row[Any] | None, column: str, valu
"""
if not instance:
raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], column, value)
return instance

def _check_exist(self, instance: ModelT | Row[Any] | None, column: str, value: Any) -> None:
"""
Expand Down Expand Up @@ -757,6 +766,24 @@ async def get_multi_or_404(
raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids)
return results

async def get_one_and_delete(self, session: AsyncSession, pk_id: PkIdT) -> None:
stmt = self._get_base_stmt()
id_str = self.get_id_attribute_value(self.model)
result = (await session.scalars(stmt.where(id_str == pk_id))).one_or_none()
result = self._check_not_found(result, self.id_attribute, pk_id)
await self.delete(session, result)

async def get_multi_and_delete(self, session: AsyncSession, pk_ids: list[PkIdT]) -> None:
stmt = self._get_base_stmt()
id_str = self.get_id_attribute_value(self.model)
results = (await session.scalars(stmt.where(id_str.in_(pk_ids)))).all()
for r in results:
id_value = self.get_id_attribute_value(r)
if id_value not in pk_ids:
raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, id_value)
await session.delete(r)
await session.commit()

async def commit(self, session: AsyncSession, obj: ModelT) -> ModelT:
"""
Commits the changes made in the session and refreshes the given object.
Expand Down

0 comments on commit 6fa76b7

Please sign in to comment.