Skip to content

Commit

Permalink
feat(init): update project
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jan 16, 2024
1 parent 2b098c8 commit cc09194
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ fixable = ["ALL"]
"env.py" = ["INP001", "I001", "ERA001"]
"tests/*.py" = ["S101"]
"exception_handlers.py" = ["ARG001"]
"models.py" = ["RUF012"]
"api.py" = ["A002"]

[tool.ruff.flake8-bugbear]
extend-immutable-calls=[
Expand All @@ -73,6 +75,7 @@ extend-immutable-calls=[
"fastapi.params_functions.Form",
"fastapi.File",
"fastapi.Path",
"fastapi.params.Depends",
]

[tool.black]
Expand Down
7 changes: 6 additions & 1 deletion src/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Generic, Literal, ParamSpec, TypeAlias, TypeVar
from typing import Annotated, Generic, Literal, ParamSpec, TypeAlias, TypedDict, TypeVar

import pydantic
from fastapi import Query
Expand Down Expand Up @@ -77,3 +77,8 @@ class BatchUpdate(BaseModel):
class I18nField(BaseModel):
en_US: str # noqa: N815
zh_CN: str # noqa: N815


class VisibleName(TypedDict, total=True):
en_US: str
zh_CN: str
2 changes: 1 addition & 1 deletion src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: # noqa: ARG001
)
app.include_router(router, prefix="/api")
for handler in exception_handlers:
app.add_exception_handler(exc_class_or_status_code=handler["name"], handler=handler[handler])
app.add_exception_handler(exc_class_or_status_code=handler["exception"], handler=handler["handler"])
app.add_middleware(RequestMiddleware)
app.add_middleware(ServerErrorMiddleware, handler=default_exception_handler)
app.add_middleware(
Expand Down
57 changes: 56 additions & 1 deletion src/auth/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,58 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src import errors
from src._types import BaseListResponse, BaseResponse
from src.auth import schemas
from src.auth.models import Group, Role, User
from src.auth.services import user_dto
from src.cbv import cbv
from src.deps import auth, get_session
from src.exceptions import GenerError

router = APIRouter()


@cbv(router)
class UserCBV:
user: User = Depends(auth)
session: AsyncSession = Depends(get_session)

@router.post("/users", operation_id="e0fe80d5-cbe0-4c2c-9eff-57e80ecba522")
async def create_user(self, user: schemas.UserCreate) -> BaseResponse[int]:
new_user = await user_dto.create(self.session, user)
return BaseResponse(data=new_user.id)

@router.get("/users/{id}", operation_id="8057d614-150f-42ee-984c-d0af35796da3")
async def get_user(self, id: int) -> BaseResponse[schemas.UserDetail]:
db_user = await user_dto.get_one_or_404(
self.session,
id,
options=(
selectinload(User.role).load_only(Role.id, Role.name),
selectinload(User.group).load_only(Group.id, Group.name),
),
)
return BaseResponse(data=db_user)

@router.get("/users", operation_id="c5f793b1-7adf-4b4e-a498-732b0fa7d758")
async def get_users(self, query: schemas.UserQuery) -> BaseListResponse[list[schemas.UserDetail]]:
count, results = await user_dto.list_and_count(
self.session,
query,
options=(
selectinload(User.role).load_only(Role.id, Role.name),
selectinload(User.group).load_only(Group.id, Group.name),
),
)
return BaseListResponse(count=count, results=results)

@router.put("/users/{id}", operation_id="2fda2e00-ad86-4296-a1d4-c7f02366b52e")
async def update_user(self, id: int, user: schemas.UserUpdate) -> BaseResponse[int]:
update_user = user.model_dump(exclude_unset=True)
if "password" in update_user and update_user["password"] is None:
raise GenerError(errors.ERR_10006, status_code=status.HTTP_406_NOT_ACCEPTABLE)
db_user = await user_dto.get_one_or_404(self.session, id)
await user_dto.update(self.session, db_user, user)
return BaseResponse(data=id)
10 changes: 8 additions & 2 deletions src/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import ClassVar
from uuid import UUID

from sqlalchemy import DateTime, ForeignKey
from sqlalchemy.dialects.postgresql import JSON
Expand All @@ -14,21 +15,24 @@
class RolePermission(Base):
__tablename__ = "role_permission"
role_id: Mapped[int] = mapped_column(ForeignKey("role.id"), primary_key=True)
permission_id: Mapped[int] = mapped_column(ForeignKey("permission.id"), primary_key=True)
permission_id: Mapped[UUID] = mapped_column(ForeignKey("permission.id"), primary_key=True)


class Role(Base, AuditTimeMixin):
__tablename__ = "role"
__search_fields__: ClassVar = {"name"}
__visible_name__ = {"en_US": "Role", "zh_CN": "用户角色"}
id: Mapped[_types.int_pk]
name: Mapped[str]
slug: Mapped[str]
description: Mapped[str | None]
permission: Mapped[list["Permission"]] = relationship(secondary=RolePermission, backref="role")


class Permission(Base):
__tablename__ = "permission"
id: Mapped[_types.int_pk]
__visible_name__ = {"en_US": "Permission", "zh_CN": "权限"}
id: Mapped[_types.uuid_pk]
name: Mapped[str]
url: Mapped[str]
method: Mapped[str]
Expand All @@ -38,6 +42,7 @@ class Permission(Base):
class Group(Base, AuditTimeMixin):
__tablename__ = "group"
__search_fields__: ClassVar = {"name"}
__visible_name__ = {"en_US": "Group", "zh_CN": "用户组"}
id: Mapped[_types.int_pk]
name: Mapped[str]
description: Mapped[str | None]
Expand All @@ -48,6 +53,7 @@ class Group(Base, AuditTimeMixin):
class User(Base, AuditTimeMixin):
__tablename__ = "user"
__search_fields__: ClassVar = {"email", "name", "phone"}
__visible_name__ = {"en_US": "User", "zh_CN": "用户"}
id: Mapped[_types.int_pk]
name: Mapped[str]
email: Mapped[str | None] = mapped_column(unique=True)
Expand Down
1 change: 1 addition & 0 deletions src/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GroupBrief(BaseModel):

class RoleBase(BaseModel):
name: str
slug: str
description: str


Expand Down
1 change: 1 addition & 0 deletions src/db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def process_result_value(self, value: Any, dialect: Dialect) -> UUID | None: #
return value


uuid_pk = Annotated[UUID, mapped_column(GUID, primary_key=True)]
int_pk = Annotated[int, mapped_column(Integer, primary_key=True)]
bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true())]
bool_false = Annotated[bool, mapped_column(Boolean, server_default=expression.false())]
Expand Down
6 changes: 4 additions & 2 deletions src/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import DeclarativeBase

from src._types import VisibleName
from src.db._types import GUID


class Base(DeclarativeBase):
__visible_name__: ClassVar = {}
__search_fields__: ClassVar = set()
__visible_name__: ClassVar[VisibleName] = {}
__search_fields__: ClassVar[set[str]] = set()
__i18n_files__: ClassVar[set[str]] = set()
type_annotation_map: ClassVar = {UUID: GUID}

def dict(self, exclude: set[str] | None = None, native_dict: bool = False) -> dict[str, Any]:
Expand Down
15 changes: 11 additions & 4 deletions src/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.orm import undefer
from sqlalchemy.sql.base import ExecutableOption

from src._types import Order, QueryParams
from src.context import locale_ctx
Expand Down Expand Up @@ -128,8 +129,10 @@ def _apply_filter(self, stmt: Select[tuple[ModelT]], filters: dict[str, Any]) ->
stmt = stmt.where(getattr(self.model, key).is_(value))
elif isinstance(value, list):
if value:
if key == "name" and type(getattr(self.model, key).type) is HSTORE:
stmt = stmt.where(or_(self.model.name["zh_CN"].in_(value), self.model.name["en_US"].in_(value)))
if key in self.model.__i18n_files__ and type(getattr(self.model, key).type) is HSTORE:
stmt = stmt.where(
or_(getattr(self.model)["zh_CN"].in_(value), getattr(self.model)["zh_CN"].in_(value))
)
else:
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
Expand All @@ -140,7 +143,9 @@ def _apply_filter(self, stmt: Select[tuple[ModelT]], filters: dict[str, Any]) ->
stmt = stmt.where(getattr(self.model, key) == value)
return stmt

def _apply_selectinload(self, stmt: Select[tuple[ModelT]], options: tuple | None = None) -> Select[tuple[ModelT]]:
def _apply_selectinload(
self, stmt: Select[tuple[ModelT]], options: tuple[ExecutableOption] | None = None
) -> Select[tuple[ModelT]]:
if options:
stmt = stmt.options(*options)
if self.undefer_load:
Expand Down Expand Up @@ -329,7 +334,9 @@ async def update_relationship_field( # noqa: PLR0913
getattr(obj, fk_name).append(target_obj)
return obj

async def get_one_or_404(self, session: AsyncSession, pk_id: int, options: tuple | None) -> ModelT:
async def get_one_or_404(
self, session: AsyncSession, pk_id: int, options: tuple[ExecutableOption] | None
) -> ModelT:
stmt = self._get_base_stmt()
if options:
stmt = self._apply_selectinload(options)
Expand Down
69 changes: 69 additions & 0 deletions src/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections.abc import AsyncGenerator
from datetime import UTC, datetime

import jwt
from fastapi import Request
from fastapi.security import Depends, HTTPBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src import exceptions
from src.auth.models import RolePermission, User
from src.config import settings
from src.context import locale_ctx
from src.db.session import async_session
from src.enums import ReservedRoleSlug
from src.security import API_WHITE_LISTS, JWT_ALGORITHM, JwtTokenPayload
from src.utils.cache import CacheNamespace, redis_client

token = HTTPBearer()


async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
yield session


async def auth(request: Request, session: AsyncSession = Depends(get_session), token: str = Depends(token)) -> User: # noqa: B008
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[JWT_ALGORITHM])
except jwt.DecodeError as e:
raise exceptions.TokenInvalidError from e
token_data = JwtTokenPayload(**payload)
if token_data.refresh:
raise exceptions.TokenInvalidError
now = datetime.now(tz=UTC)
if now < token_data.issued_at or now > token_data.expires_at:
raise exceptions.TokenExpireError
user = await session.get(User, token_data.sub, options=[selectinload(User.role)])
if not user:
raise exceptions.NotFoundError(user.__visible_name__[locale_ctx.get()], "id", id)
operation_id = request.scope.get("operation_id")
privileged = check_privileged_role(user.role.slug, operation_id)
if privileged:
return User
check_privileged_role(user.role_id, session, operation_id)
return User


def check_privileged_role(slug: str, operation_id: str) -> bool:
if slug == ReservedRoleSlug.ADMIN:
return True
if operation_id in API_WHITE_LISTS:
return True
return False


async def check_role_permissions(role_id: int, session: AsyncSession, operation_id: str) -> None:
permissions: list[str] | None = await redis_client.get_cache(name=str(role_id), namespace=CacheNamespace.ROLE_CACHE)
if not permissions:
permissions = (
await session.scalars(select(RolePermission.permission_id).where(RolePermission.role_id == role_id))
).all()
if not permissions:
raise exceptions.PermissionDenyError
permissions = [str(p) for p in permissions]
redis_client.set_nx(name=str(role_id), value=permissions, namespace=CacheNamespace.ROLE_CACHE)
if operation_id not in permissions:
raise exceptions.PermissionDenyError
6 changes: 6 additions & 0 deletions src/enums.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from enum import IntEnum

from src._types import AppStrEnum


class Env(IntEnum):
PRD = 0
DEV = 1


class ReservedRoleSlug(AppStrEnum):
ADMIN = "admin"
17 changes: 15 additions & 2 deletions src/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import NamedTuple
from typing import Any, NamedTuple


class ErrorCode(NamedTuple):
code: int
error: int
message: str
details: list[Any] | None = None

def dict(self): # noqa: ANN201
return self._asdict()


ERR_404 = ErrorCode(404, "app.not_found")
ERR_409 = ErrorCode(409, "app.already_exist")
ERR_500 = ErrorCode(500, "app.internal_server_error")
ERR_10001 = ErrorCode(10001, "User's password can not set as null.")
ERR_10002 = ErrorCode(10002, "Invalid bearer token.")
ERR_10003 = ErrorCode(10003, "Bearer token was expired.")
ERR_10004 = ErrorCode(10004, "Bearer token is invalid for refresh token was provided.")
ERR_10005 = ErrorCode(10005, "Permission deny, user with limited access for current API.")
ERR_10005 = ErrorCode(10005, "Permission deny, user with limited access for current API.")
ERR_10006 = ErrorCode(10006, "Update user failed, password can not be null.")
Loading

0 comments on commit cc09194

Please sign in to comment.