diff --git a/pre-commit.sh b/pre-commit.sh index 30653a5..0993843 100755 --- a/pre-commit.sh +++ b/pre-commit.sh @@ -3,4 +3,4 @@ set -ex rye run black src rye run black tests rye run ruff src --fix -rye run ruff src --fix +rye run ruff tests --fix diff --git a/pyproject.toml b/pyproject.toml index 67da5fc..cf0e50e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dev-dependencies = [ "pytest-cov>=4.1.0", "pytest-asyncio>=0.23.3", "black>=23.12.1", + "mypy>=1.8.0", ] [tool.hatch.metadata] @@ -52,7 +53,7 @@ target-version = "py311" [tool.ruff.lint] select = ["ALL"] -ignore = ["D", "G002", "DTZ003", "ANN401", "ANN101", "ANN102", "EM101", "PD901", "COM812", "ISC001", "FBT001"] +ignore = ["D", "G002", "DTZ003", "ANN401", "ANN101", "ANN102", "EM101", "PD901", "COM812", "ISC001", "FBT"] fixable = ["ALL"] @@ -60,6 +61,7 @@ fixable = ["ALL"] "env.py" = ["INP001", "I001", "ERA001"] "tests/*.py" = ["S101"] "models.py" = ["A003"] +"mixins.py" = ["A003"] "exception_handlers.py" = ["ARG001"] [tool.ruff.flake8-bugbear] diff --git a/requirements-dev.lock b/requirements-dev.lock index c567660..b6366cf 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,7 +10,6 @@ alembic==1.13.1 annotated-types==0.6.0 anyio==4.2.0 -babel==2.14.0 black==23.12.1 certifi==2023.11.17 cfgv==3.4.0 @@ -27,6 +26,7 @@ idna==3.6 iniconfig==2.0.0 mako==1.3.0 markupsafe==2.1.3 +mypy==1.8.0 mypy-extensions==1.0.0 nodeenv==1.8.0 numpy==1.26.3 diff --git a/requirements.lock b/requirements.lock index 49ffef4..54d5ee2 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,7 +10,6 @@ alembic==1.13.1 annotated-types==0.6.0 anyio==4.2.0 -babel==2.14.0 certifi==2023.11.17 click==8.1.7 fastapi==0.108.0 diff --git a/src/_types.py b/src/_types.py index 011a409..4fd4dfd 100644 --- a/src/_types.py +++ b/src/_types.py @@ -1,6 +1,5 @@ from enum import Enum -from typing import Annotated, Generic, Literal, ParamSpec, TypeVar -from uuid import UUID +from typing import Annotated, Generic, 
Literal, ParamSpec, TypeAlias, TypeVar import pydantic from fastapi import Query @@ -12,6 +11,8 @@ T = TypeVar("T") P = ParamSpec("P") +Order: TypeAlias = Literal["descend", "ascend"] + StrList = Annotated[str | list[str], BeforeValidator(items_to_list)] IntList = Annotated[int | list[int], BeforeValidator(items_to_list)] MacAddress = Annotated[str, BeforeValidator(mac_address_validator)] @@ -53,17 +54,17 @@ class QueryParams(BaseModel): limit: int | None = Query(default=20, ge=0, le=1000, description="Number of results to return per request.") offset: int | None = Query(default=0, ge=0, description="The initial index from which return the results.") q: str | None = Query(default=None, description="Search for results.") - id: list[UUID] | None = Field(Query(default=[], description="request object unique ID")) + id: list[int] | None = Field(Query(default=[], description="request object unique ID")) order_by: str | None = Query(default=None, description="Which field to use when order the results") - order: Literal["descend", "ascend"] | None = Query(default="ascend", description="Order by dscend or ascend") + order: Order | None = Query(default="ascend", description="Order by dscend or ascend") class BatchDelete(BaseModel): - ids: list[UUID] + ids: list[int] class BatchUpdate(BaseModel): - ids: list[UUID] + ids: list[int] class I18nField(BaseModel): diff --git a/src/auth/models.py b/src/auth/models.py index b816163..2cfea7b 100644 --- a/src/auth/models.py +++ b/src/auth/models.py @@ -1,6 +1,5 @@ from datetime import datetime from typing import ClassVar -from uuid import UUID from sqlalchemy import DateTime, ForeignKey from sqlalchemy.dialects.postgresql import JSON @@ -13,22 +12,21 @@ class RolePermission(Base): __tablename__ = "role_permission" - __multi_tenant__: ClassVar = False - role_id: Mapped[UUID] = mapped_column(ForeignKey("role.id"), primary_key=True) - permission_id: Mapped[UUID] = mapped_column(ForeignKey("permission.id"), primary_key=True) + role_id: 
Mapped[int] = mapped_column(ForeignKey("role.id"), primary_key=True) + permission_id: Mapped[int] = mapped_column(ForeignKey("permission.id"), primary_key=True) class Role(Base): __tablename__ = "role" __search_fields__: ClassVar = {"name"} - id: Mapped[_types.uuid_pk] + id: Mapped[_types.int_pk] name: Mapped[str] permission: Mapped[list["Permission"]] = relationship(secondary=RolePermission, backref="role") class Permission(Base): __tablename__ = "permission" - id: Mapped[_types.uuid_pk] + id: Mapped[_types.int_pk] name: Mapped[str] url: Mapped[str] method: Mapped[str] @@ -38,24 +36,24 @@ class Permission(Base): class Group(Base): __tablename__ = "group" __search_fields__: ClassVar = {"name"} - id: Mapped[_types.uuid_pk] + id: Mapped[_types.int_pk] name: Mapped[str] - role_id: Mapped[UUID] = mapped_column(ForeignKey(Role.id, ondelete="CASCADE")) + role_id: Mapped[int] = mapped_column(ForeignKey(Role.id, ondelete="CASCADE")) role: Mapped["Role"] = relationship(backref="group", passive_deletes=True) class User(Base): __tablename__ = "user" __search_fields__: ClassVar = {"email", "name", "phone"} - id: Mapped[_types.uuid_pk] + id: Mapped[_types.int_pk] name: Mapped[str] email: Mapped[str | None] = mapped_column(unique=True) phone: Mapped[str | None] = mapped_column(unique=True) password: Mapped[str] avatar: Mapped[str | None] last_login: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) - group_id: Mapped[UUID] = mapped_column(ForeignKey(Group.id, ondelete="CASCADE")) + group_id: Mapped[int] = mapped_column(ForeignKey(Group.id, ondelete="CASCADE")) group: Mapped["Group"] = relationship(backref="user", passive_deletes=True) - role_id: Mapped[UUID] = mapped_column(ForeignKey(Role.id, ondelete="CASCADE")) + role_id: Mapped[int] = mapped_column(ForeignKey(Role.id, ondelete="CASCADE")) role: Mapped["Role"] = relationship(backref="user", passive_deletes=True) auth_info: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON)) diff --git a/src/config.py 
b/src/config.py index 90c3356..c648e56 100644 --- a/src/config.py +++ b/src/config.py @@ -39,4 +39,4 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=f"{PROJECT_DIR}/.env", case_sensitive=True, extra="allow") -settings = Settings() +settings = Settings() # type: ignore # noqa: PGH003 diff --git a/src/context.py b/src/context.py index fe21bdd..bc0f20b 100644 --- a/src/context.py +++ b/src/context.py @@ -2,7 +2,6 @@ from uuid import UUID request_id_ctx = ContextVar[str | None] = ContextVar("x-request-id", default=None) -tenant_ctx = ContextVar[UUID | None] = ContextVar("x-tenant-id", default=None) auth_user_ctx = ContextVar[UUID | None] = ContextVar("x-auth-user", default=None) locale_ctx = ContextVar[UUID | None] = ContextVar("Accept-Language", default="en") orm_diff_ctx = ContextVar[dict | None] = ContextVar("x-orm-diff", default=None) diff --git a/src/db/__init__.py b/src/db/__init__.py index e69de29..6f501d4 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -0,0 +1,9 @@ +from src.auth.models import Group, Permission, Role, RolePermission, User # noqa: F401 +from src.db.base import Base + + +def orm_by_table_name(table_name: str) -> type[Base] | None: + for m in Base.registry.mappers: + if getattr(m.class_, "__tablename__", None) == table_name: + return m.class_ + return None diff --git a/src/db/_types.py b/src/db/_types.py index bdae666..bd367db 100644 --- a/src/db/_types.py +++ b/src/db/_types.py @@ -2,7 +2,7 @@ from datetime import date, datetime from typing import Annotated, Any -from sqlalchemy import CHAR, Boolean, Date, DateTime, Dialect, String, func, type_coerce +from sqlalchemy import CHAR, Boolean, Date, DateTime, Dialect, Integer, String, func, type_coerce from sqlalchemy.dialects.postgresql import BYTEA, UUID from sqlalchemy.orm import mapped_column from sqlalchemy.sql import expression @@ -63,7 +63,7 @@ def process_result_value(self, value: Any, dialect: Dialect) -> UUID | None: # return value -uuid_pk = 
Annotated[uuid.UUID, mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)] +int_pk = Annotated[int, mapped_column(Integer, primary_key=True)] bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true())] bool_false = Annotated[bool, mapped_column(Boolean, server_default=expression.false())] datetime_required = Annotated[datetime, mapped_column(DateTime(timezone=True))] diff --git a/src/db/base.py b/src/db/base.py index 4eb4357..0195b68 100644 --- a/src/db/base.py +++ b/src/db/base.py @@ -1,21 +1,19 @@ -from typing import ClassVar +from typing import Any, ClassVar from uuid import UUID -from sqlalchemy import ForeignKey -from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column +from fastapi.encoders import jsonable_encoder +from sqlalchemy.orm import DeclarativeBase -from src.context import tenant_ctx from src.db._types import GUID class Base(DeclarativeBase): - __multi_tenant__: bool = True + __visible_name__: ClassVar = {} __search_fields__: ClassVar = set() type_annotation_map: ClassVar = {UUID: GUID} - @declared_attr - @classmethod - def tenant_id(cls) -> Mapped[UUID]: - if not cls.__multi_tenant__: - return None - return mapped_column(UUID, ForeignKey("tenant.id", ondelete="CASCADE"), index=True, default=tenant_ctx.get) + def dict(self, exclude: set[str] | None = None, native_dict: bool = False) -> dict[str, Any]: + """Return dict representation of model.""" + if not native_dict: + return jsonable_encoder(self, exclude=exclude) + return {c.name: getattr(self, c.name) for c in self.__table__.columns if c.name not in exclude} diff --git a/src/db/crud.py b/src/db/crud.py index 68c8403..322ba19 100644 --- a/src/db/crud.py +++ b/src/db/crud.py @@ -1,5 +1,350 @@ -from typing import TypeVar +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypedDict, TypeVar +from pydantic import BaseModel +from sqlalchemy import Row, Select, Text, cast, desc, func, 
inspect, not_, or_, select, text +from sqlalchemy.dialects.postgresql import ARRAY, HSTORE, INET, JSON, JSONB, MACADDR +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.mutable import Mutable +from sqlalchemy.orm import undefer + +from src._types import Order, QueryParams +from src.context import locale_ctx from src.db.base import Base +from src.db.session import async_engine +from src.exceptions import ExistError, NotFoundError + +if TYPE_CHECKING: + from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint, ReflectedUniqueConstraint + +ModelT = TypeVar("ModelT", bound=Base) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +QuerySchemaType = TypeVar("QuerySchemaType", bound=QueryParams) + +TABLE_PARAMS: dict[str, "InspectorTableConstraint"] = {} + + +class OrmField(NamedTuple): + field: str + value: Any + + +class InspectorTableConstraint(TypedDict, total=False): + foreign_keys: dict[str, tuple[str, str]] + unique_constraints: list[list[str]] + + +def register_table_params(table_name: str, params: InspectorTableConstraint) -> None: + if not TABLE_PARAMS.get(table_name): + TABLE_PARAMS[table_name] = params + + +async def inspect_table(table_name: str) -> dict[str, InspectorTableConstraint]: + """Reflect table schema to inspect unique constraints and many-to-one fks and cache in memory""" + if result := TABLE_PARAMS.get(table_name): + return result + async with async_engine.connect() as conn: + result: InspectorTableConstraint = {"unique_constraints": [], "foreign_keys": {}} + uq: list[ReflectedUniqueConstraint] = await conn.run_sync( + lambda sync_conn: inspect(sync_conn).get_unique_constraints(table_name=table_name) + ) + if uq: + result["unique_constraints"] = [_uq["column_names"] for _uq in uq] + fk: list[ReflectedForeignKeyConstraint] = await conn.run_sync( + lambda sync_conn: inspect(sync_conn).get_foreign_keys(table_name=table_name) + ) + if fk: + 
for _fk in fk: + fk_name = _fk["constrained_columns"][0] + referred_table = _fk["referred_table"] + referred_column = _fk["referred_columns"][0] + result["foreign_keys"][fk_name] = (referred_table, referred_column) + register_table_params(table_name=table_name, params=result) + return result + + +class DtoBase(Generic[ModelT, CreateSchemaType, UpdateSchemaType, QuerySchemaType]): + def __init__(self, model: type[ModelT], undefer_load: bool = False) -> None: + self.model = model + self.undefer_load = undefer_load + + def _get_base_stmt(self) -> Select[tuple[ModelT]]: + """Get base select statement of query""" + return select(self.model) + + def _get_base_count_stmt(self) -> Select[tuple[ModelT]]: + return select(func.count()).select_from(self.model) + + def _apply_search(self, stmt: Select[tuple[ModelT]], value: str, ignore_case: bool = True) -> Select[tuple[ModelT]]: + where_clauses = [] + search_text = f"%{value}%" + if self.model.__search_fields__: + for field in self.model.__search_fields__: + _t = getattr(self.model, field).type + if type(_t) in (HSTORE, JSON, JSONB, INET, MACADDR, ARRAY): + if ignore_case: + where_clauses.append(cast(getattr(self.model, field), Text).ilike(search_text)) + else: + where_clauses.append(cast(getattr(self.model, field), Text).like(search_text)) + if where_clauses: + return stmt.where(or_(False, *where_clauses)) + return stmt + + def _apply_order_by(self, stmt: Select[tuple[ModelT]], order_by: str, order: Order) -> Select[tuple[ModelT]]: + if order == "descend": + return stmt.order_by(desc(getattr(self.model, order_by))) + return stmt.order_by(getattr(self.model, order_by)) + + def _apply_pagination( + self, stmt: Select[tuple[ModelT]], limit: int | None = 20, offset: int | None = 0 + ) -> Select[tuple[ModelT]]: + return stmt.slice(offset, limit + offset) + + def _apply_operator_filter(self, stmt: Select[tuple[ModelT]], key: str, value: Any) -> Select[tuple[ModelT]]: + operators = { + "eq": lambda col, value: col.in_(value if 
isinstance(value, list) else [value]), + "ne": lambda col, value: ~col.in_(value if isinstance(value, list) else [value]), + "ic": lambda col, value: col.ilike(f"%{value}%"), + "nic": lambda col, value: not_(col.ilike(f"%{value}%")), + "le": lambda col, value: col < value, + "ge": lambda col, value: col > value, + "lte": lambda col, value: col <= value, + "gte": lambda col, value: col >= value, + } + filed_name, operator = key.split("__") + if not hasattr(self.model, filed_name): + return stmt + if operator_func := operators.get(operator, None): + col = getattr(self.model, filed_name) + return stmt.filter(operator_func(col, value)) + return stmt + + def _apply_filter(self, stmt: Select[tuple[ModelT]], filters: dict[str, Any]) -> Select[tuple[ModelT]]: + for key, value in filters.items(): + if "__" in key: + stmt = self._apply_operator_filter(stmt, key, value) + elif isinstance(value, bool): + stmt = stmt.where(getattr(self.model, key).is_(value)) + elif isinstance(value, list): + if value: + if key == "name" and type(getattr(self.model, key).type) is HSTORE: + stmt = stmt.where(or_(self.model.name["zh_CN"].in_(value), self.model.name["en_US"].in_(value))) + else: + stmt = stmt.where(getattr(self.model, key).in_(value)) + else: + stmt = stmt.where(getattr(self.model, key).in_(value)) + elif not value: + stmt = stmt.where(getattr(self.model, key).is_(None)) + else: + stmt = stmt.where(getattr(self.model, key) == value) + return stmt + + def _apply_selectinload(self, stmt: Select[tuple[ModelT]], options: tuple | None = None) -> Select[tuple[ModelT]]: + if options: + stmt = stmt.options(*options) + if self.undefer_load: + stmt = stmt.options(undefer("*")) + return stmt + + def _apply_list( + self, stmt: Select[tuple[ModelT]], query: QuerySchemaType, excludes: set[str] | None = None + ) -> Select[tuple[ModelT]]: + _excludes = {"limit", "offset", "q", "order", "order_by"} + if excludes: + _excludes.update(excludes) + filters = query.model_dump(exclude=_excludes, 
exclude_unset=True) + if filters: + stmt = self._apply_filter(stmt, filters) + return stmt + + @staticmethod + def _check_not_found(instance: ModelT | Row[Any] | None, table_name: str, column: str, value: Any) -> None: + if not instance: + raise NotFoundError(table_name, column, value) + + @staticmethod + def _check_exist(instance: ModelT | None, table_name: str, column: str, value: Any) -> None: + if instance: + raise ExistError(table_name, column, value) + + @staticmethod + def _update_mutable_tracking(update_schema: UpdateSchemaType, obj: ModelT, excludes: set[str]) -> ModelT: + for key, value in update_schema.model_dump(exclude_unset=True, exclude=excludes).items(): + if issubclass(type(getattr(obj, key)), Mutable): + field_value = getattr(obj, key).copy() + if isinstance(value, list | dict): + if isinstance(value, list): + setattr(obj, key, value) + else: + for k, v in value.items(): + field_value[k] = v + setattr(obj, key, field_value) + else: + setattr(obj, key, value) + return obj + + async def _check_unique_constraints( + self, + session: AsyncSession, + uq: dict[str, Any], + pk_id: int | None = None, + ) -> None: + stmt = self._get_base_count_stmt() + if pk_id: + stmt = stmt.where(self.model.id != pk_id) + for key, value in uq.items(): + if isinstance(value, bool): + stmt = stmt.where(getattr(self.model, key).is_(value)) + else: + stmt = stmt.where(getattr(self.model, key) == value) + result = await session.scalar(stmt) + if result > 0: + keys = ",".join(list(uq.keys())) + values = ",".join([f"{key}-{value}" for key, value in uq.items()]) + raise ExistError(self.model.__visible_name__[locale_ctx.get()], keys, values) + + async def _apply_unique_constraints_when_create( + self, + session: AsyncSession, + record: CreateSchemaType, + inspections: InspectorTableConstraint, + ) -> None: + """Apply unique constraints of given object in database. 
+ + Args: + session (AsyncSession): sqla session + record (CreateSchemaType) + inspections (InspectorTableConstraint) + """ + uniq_args = inspections.get("unique_constraints") + if not uniq_args: + return + record_dict = record.model_dump(exclude_unset=True) + for arg in uniq_args: + uq: dict[str, Any] = {} + for column in arg: + if column in record_dict: + if record_dict.get(column): + uq[column] = record_dict[column] + else: + uq = {} + break + else: + uq = {} + break + if uq: + await self._check_unique_constraints(session, uq) + + async def _apply_unique_constraints_when_update( + self, session: AsyncSession, record: UpdateSchemaType, inspections: InspectorTableConstraint, obj: ModelT + ) -> None: + uniq_args = inspections.get("unique_constraints") + if uniq_args: + record_dict = record.model_dump(exclude_unset=True) + for arg in uniq_args: + uq: dict[str, Any] = {} + for column in arg: + if column in record_dict: + if any([value := record_dict.get(column), value := getattr(obj, column)]): + uq[column] = value + else: + uq = {} + break + elif value := getattr(obj, column): + uq[column] = value + else: + uq = {} + break + if uq: + await self._check_unique_constraints(session, uq, obj.id) + + async def _apply_foreign_keys_check( + self, session: AsyncSession, record: CreateSchemaType | UpdateSchemaType, inspections: InspectorTableConstraint + ) -> None: + fk_args = inspections.get("foreign_keys") + if not fk_args: + return + record_dict = record.model_dump() + for fk_name, relation in fk_args.items(): + if value := record_dict.get(fk_name): + table_name, column = relation + stmt_text = f"SELECT 1 FROM {table_name} WHERE {column}='{value}'" # noqa: S608 + fk_result = (await session.execute(text(stmt_text))).one_or_none() + self._check_not_found(fk_result, table_name, column, value) + + async def list_and_count( + self, session: AsyncSession, query: QuerySchemaType, options: tuple | None = None + ) -> tuple[int, Sequence[ModelT]]: + stmt = self._get_base_stmt() + 
c_stmt = self._get_base_count_stmt() + stmt = self._apply_list(stmt, query) + c_stmt = self._apply_list(c_stmt, query) + if query.q: + stmt = self._apply_search(stmt, query.q, ignore_case=True) + c_stmt = self._apply_search(c_stmt, query.q, ignore_case=True) + if query.limit is not None and query.offset is not None: + stmt = self._apply_pagination(stmt, query.limit, query.offset) + if query.order_by and query.order: + stmt = self._apply_order_by(stmt, query.order_by, query.order) + stmt = self._apply_selectinload(stmt, options) + count: int = await session.scalar(c_stmt) # type: ignore # noqa: PGH003 + results = (await session.scalars(stmt)).all() + return count, results + + async def create(self, session: AsyncSession, obj_in: CreateSchemaType, excludes: set[str] | None = None) -> ModelT: + insp = await inspect_table(self.model.__tablename__) + await self._apply_foreign_keys_check(session, obj_in, insp) + await self._apply_unique_constraints_when_create(session, obj_in, insp) + new_obj = self.model(**obj_in.model_dump(exclude_unset=True, exclude=excludes)) + session.add(new_obj) + await session.commit() + await session.flush() + return new_obj + + async def update( + self, session: AsyncSession, db_obj: ModelT, obj_in: UpdateSchemaType, excludes: set | None = None + ) -> ModelT: + insp = await inspect_table(self.model.__tablename__) + await self._apply_foreign_keys_check(session, obj_in, insp) + await self._apply_unique_constraints_when_update(session, obj_in, insp, db_obj) + db_obj = self._update_mutable_tracking(obj_in, db_obj, excludes) + session.add(db_obj) + await session.commit() + return db_obj + + async def update_relationship_field( # noqa: PLR0913 + self, session: AsyncSession, obj: ModelT, m2m_model: type[ModelT], fk_name: str, fk_values: list[int] + ) -> ModelT: + local_fk_values = getattr(obj, fk_name) + local_fk_value_ids = [v.id for v in local_fk_values] + for fk_value in local_fk_values[::-1]: + if fk_value.id not in fk_values: + getattr(obj, fk_name).remove(fk_value) + for 
fk_value in fk_values: + if fk_value not in local_fk_value_ids: + target_obj = await session.get(m2m_model, fk_value) + if not target_obj: + raise NotFoundError(m2m_model.__visible_name__[locale_ctx.get()], "id", fk_value) + getattr(obj, fk_name).append(target_obj) + return obj + + async def get_one_or_404(self, session: AsyncSession, pk_id: int, options: tuple | None) -> ModelT: + stmt = self._get_base_stmt().where(self.model.id == pk_id) + if options: + stmt = self._apply_selectinload(stmt, options) + result = (await session.scalars(stmt)).one_or_none() + if not result: + raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], "id", pk_id) + return result -ModelType = TypeVar("ModelType", bound=Base) + async def get_none_or_409(self, session: AsyncSession, field: str, value: Any) -> None: + stmt = self._get_base_stmt() + if isinstance(value, bool) or value is None: + stmt = stmt.where(getattr(self.model, field).is_(value)) + else: + stmt = stmt.where(getattr(self.model, field) == value) + stmt = stmt.with_only_columns(self.model.id) + result = (await session.execute(stmt)).one_or_none() + if result: + raise ExistError(self.model.__visible_name__[locale_ctx.get()], field, value) diff --git a/src/exception_handlers.py b/src/exception_handlers.py index b59d1f7..7812b04 100644 --- a/src/exception_handlers.py +++ b/src/exception_handlers.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def log_exception(exc: _E, logger_trace_info: bool) -> None: +def log_exception(exc: _E | Exception, logger_trace_info: bool) -> None: ex_type, _tmp, ex_traceback = sys.exc_info trace_back = traceback.format_list(traceback.extract_tb(ex_traceback)[-1:])[-1] logger.warning("ErrorMessage: %s" % str(exc)) @@ -32,7 +32,7 @@ def default_exception_handler(request: Request, exc: Exception) -> JSONResponse: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=BaseResponse( - ERR_500.code, data=jsonable_encoder(str(exc)), message=_(ERR_500.message, request_id_ctx.get()) + code=ERR_500.code, 
data=jsonable_encoder(str(exc)), message=_(ERR_500.message, request_id_ctx.get()) ), ) diff --git a/src/exceptions.py b/src/exceptions.py index 5e725ed..06c9e1b 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -1,3 +1,20 @@ +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from typing import Any +from uuid import UUID + +from src.context import locale_ctx + + +def error_message_value_handler(value: Any) -> Any: + if isinstance(value, dict) and "en_US" in value: + return value[locale_ctx.get()] + if isinstance(value, IPv4Address | IPv6Address | IPv4Network | IPv6Network | IPv4Interface | IPv6Interface | UUID): + return str(value) + if isinstance(value, list): + return [str(_v) for _v in value] + return value + + class TokenNotProvideError(Exception): ... @@ -15,11 +32,23 @@ class PermissionDenyError(Exception): class NotFoundError(Exception): - ... + def __init__(self, name: str, field: str, value: Any) -> None: + self.name = name + self.field = field + self.value = error_message_value_handler(value) + + def __repr__(self) -> str: + return f"Object:{self.name} with field:{self.field}-value:{self.value} not found." class ExistError(Exception): - ... + def __init__(self, name: str, field: str, value: Any) -> None: + self.name = name + self.field = field + self.value = error_message_value_handler(value) + + def __repr__(self) -> str: + return f"Object:{self.name} with field:{self.field}-value:{self.value} already exist." 
sentry_ignore_errors = [ diff --git a/src/loggers.py b/src/loggers.py index 4e7fcc0..8e473ef 100644 --- a/src/loggers.py +++ b/src/loggers.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Hashable from logging import LogRecord, setLogRecordFactory from logging.config import dictConfig @@ -48,7 +47,7 @@ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: } -def configure_logger(config: Hashable | None = None) -> None: +def configure_logger(config: dict | None = None) -> None: if config is None: config = LOGGING dictConfig(config)