Skip to content

Commit

Permalink
chore: rollback version and add more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Nov 19, 2024
1 parent e1f398b commit 33e2c4b
Show file tree
Hide file tree
Showing 16 changed files with 97 additions and 86 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tortoise-orm"
version = "0.21.8"
version = "0.21.7"
description = "Easy async ORM for python, built with relations in mind"
authors = ["Andrey Bondar <andrey@bondar.ru>", "Nickolas Grigoriadis <nagrigoriadis@gmail.com>", "long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ async def init(
cls._inited = True

@classmethod
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None):
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
from tortoise.router import router

routers = routers or []
Expand Down
2 changes: 1 addition & 1 deletion tortoise/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self) -> None:
self._db_config: Optional["DBConfigType"] = None
self._create_db: bool = False

async def _init(self, db_config: "DBConfigType", create_db: bool):
async def _init(self, db_config: "DBConfigType", create_db: bool) -> None:
if self._db_config is None:
self._db_config = db_config
else:
Expand Down
2 changes: 1 addition & 1 deletion tortoise/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_escape_table[ord("'")] = "\\'"


def _escape_unicode(value: str, mapping=None):
def _escape_unicode(value: str, mapping=None) -> str:
"""escapes *value* without adding quote.
Value should be unicode
Expand Down
4 changes: 2 additions & 2 deletions tortoise/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ class ObjectDoesNotExistError(OperationalError, KeyError):
The DoesNotExist exception is raised when an item with the passed primary key does not exist
"""

def __init__(self, model: "Type[Model]", pk_name: str, pk_val: Any):
def __init__(self, model: "Type[Model]", pk_name: str, pk_val: Any) -> None:
self.model: "Type[Model]" = model
self.pk_name: str = pk_name
self.pk_val: Any = pk_val

def __str__(self):
def __str__(self) -> str:
return f"{self.model.__name__} has no object with {self.pk_name}={self.pk_val}"


Expand Down
34 changes: 17 additions & 17 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class Value(Expression):
Wrapper for a value that should be used as a term in a query.
"""

def __init__(self, value: Any):
def __init__(self, value: Any) -> None:
self.value = value

def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
Expand All @@ -90,7 +90,7 @@ class Connector(Enum):


class CombinedExpression(Expression):
def __init__(self, left: Expression, connector: Connector, right: Any):
def __init__(self, left: Expression, connector: Connector, right: Any) -> None:
self.left = left
self.connector = connector
self.right: Expression
Expand Down Expand Up @@ -124,7 +124,7 @@ class F(Expression):
:param name: The name of the field to reference.
"""

def __init__(self, name: str):
def __init__(self, name: str) -> None:
self.name = name

def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
Expand Down Expand Up @@ -169,43 +169,43 @@ def _combine(self, other: Any, connector: Connector, right_hand: bool) -> Combin
return CombinedExpression(other, connector, self)
return CombinedExpression(self, connector, other)

def __neg__(self):
def __neg__(self) -> CombinedExpression:
return self._combine(-1, Connector.mul, False)

def __add__(self, other):
def __add__(self, other) -> CombinedExpression:
return self._combine(other, Connector.add, False)

def __sub__(self, other):
def __sub__(self, other) -> CombinedExpression:
return self._combine(other, Connector.sub, False)

def __mul__(self, other):
def __mul__(self, other) -> CombinedExpression:
return self._combine(other, Connector.mul, False)

def __truediv__(self, other):
def __truediv__(self, other) -> CombinedExpression:
return self._combine(other, Connector.div, False)

def __mod__(self, other):
def __mod__(self, other) -> CombinedExpression:
return self._combine(other, Connector.mod, False)

def __pow__(self, other):
def __pow__(self, other) -> CombinedExpression:
return self._combine(other, Connector.pow, False)

def __radd__(self, other):
def __radd__(self, other) -> CombinedExpression:
return self._combine(other, Connector.add, True)

def __rsub__(self, other):
def __rsub__(self, other) -> CombinedExpression:
return self._combine(other, Connector.sub, True)

def __rmul__(self, other):
def __rmul__(self, other) -> CombinedExpression:
return self._combine(other, Connector.mul, True)

def __rtruediv__(self, other):
def __rtruediv__(self, other) -> CombinedExpression:
return self._combine(other, Connector.div, True)

def __rmod__(self, other):
def __rmod__(self, other) -> CombinedExpression:
return self._combine(other, Connector.mod, True)

def __rpow__(self, other):
def __rpow__(self, other) -> CombinedExpression:
return self._combine(other, Connector.pow, True)


Expand Down Expand Up @@ -519,7 +519,7 @@ def __init__(
self.field_object: "Optional[Field]" = None
self.default_values = default_values

def _get_function_field(self, field: Union[Term, str], *default_values):
def _get_function_field(self, field: Union[Term, str], *default_values) -> PypikaFunction:
return self.database_func(field, *default_values) # type:ignore[arg-type]

def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
Expand Down
4 changes: 2 additions & 2 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class OnDelete(StrEnum):

class _FieldMeta(type):
# TODO: Require functions to return field instances instead of this hack
def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict):
def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict) -> type:
if len(bases) > 1 and bases[0] is Field:
# Instantiate class with only the 1st base class (should be Field)
cls = type.__new__(mcs, name, (bases[0],), attrs)
Expand Down Expand Up @@ -271,7 +271,7 @@ def to_python_value(self, value: Any) -> Any:
value = self.field_type(value) # pylint: disable=E1102
return value

def validate(self, value: Any):
def validate(self, value: Any) -> None:
"""
Validate whether given value is valid
Expand Down
4 changes: 2 additions & 2 deletions tortoise/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def contains(field: Term, value: str) -> Criterion:
return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}%"))


def search(field: Term, value: str):
def search(field: Term, value: str) -> Any:
# will be override in each executor
pass


def posix_regex(field: Term, value: str):
def posix_regex(field: Term, value: str) -> Any:
# Will be overridden in each executor
raise NotImplementedError(
"The postgres_posix_regex filter operator is not supported by your database backend"
Expand Down
4 changes: 2 additions & 2 deletions tortoise/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self.expressions = expressions
self.extra = ""

def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool):
def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str:
if self.fields:
fields = ", ".join(schema_generator.quote(f) for f in self.fields)
else:
Expand All @@ -54,7 +54,7 @@ def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]",
extra=self.extra,
)

def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]"):
def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
return self.name or schema_generator._generate_index_name("idx", model, self.fields)


Expand Down
4 changes: 3 additions & 1 deletion tortoise/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from tortoise.queryset import QuerySet


Expand All @@ -14,5 +16,5 @@ def __init__(self, model=None) -> None:
def get_queryset(self) -> QuerySet:
return QuerySet(self._model)

def __getattr__(self, item):
def __getattr__(self, item: str) -> Any:
return getattr(self.get_queryset(), item)
10 changes: 5 additions & 5 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def _generate_filters(self) -> None:
class ModelMeta(type):
__slots__ = ()

def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict):
def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict) -> "ModelMeta":
fields_db_projection: Dict[str, str] = {}
fields_map: Dict[str, Field] = {}
filters: Dict[str, FilterInfoDict] = {}
Expand Down Expand Up @@ -676,7 +676,7 @@ def __init__(self, **kwargs: Any) -> None:
else:
setattr(self, key, deepcopy(field_object.default))

def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
# set field value override async default function
if hasattr(self, "_await_when_save"):
self._await_when_save.pop(key, None)
Expand Down Expand Up @@ -782,7 +782,7 @@ def __hash__(self) -> int:
raise TypeError("Model instances without id are unhashable")
return hash(self.pk)

def __iter__(self):
def __iter__(self) -> Iterable[Tuple]:
for field in self._meta.db_fields:
yield field, getattr(self, field)

Expand Down Expand Up @@ -850,7 +850,7 @@ def update_from_dict(self: MODEL, data: dict) -> MODEL:
return self

@classmethod
def register_listener(cls, signal: Signals, listener: Callable):
def register_listener(cls, signal: Signals, listener: Callable) -> None:
"""
Register listener to current model class for special Signal.
Expand Down Expand Up @@ -1020,7 +1020,7 @@ async def refresh_from_db(
setattr(self, field, getattr(obj, field, None))

@classmethod
def _choose_db(cls, for_write: bool = False):
def _choose_db(cls, for_write: bool = False) -> BaseDBAsyncClient:
"""
Return the connection that will be used if this query is executed now.
Expand Down
6 changes: 3 additions & 3 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _join_table_by_field(
self._join_table(join)
return joins[-1][0]

def _join_table(self, table_criterio_tuple: TableCriterionTuple):
def _join_table(self, table_criterio_tuple: TableCriterionTuple) -> None:
if table_criterio_tuple[0] not in self._joined_tables:
self.query = self.query.join(table_criterio_tuple[0], how=JoinType.left_outer).on(
table_criterio_tuple[1]
Expand Down Expand Up @@ -1491,7 +1491,7 @@ def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable:

raise FieldError(f'Unknown field "{field}" for model "{model}"')

def _resolve_group_bys(self, *field_names: str):
def _resolve_group_bys(self, *field_names: str) -> List:
group_bys = []
for field_name in field_names:
if field_name in self._annotations:
Expand Down Expand Up @@ -1777,7 +1777,7 @@ async def _execute(self) -> Union[List[dict], Dict]:
class RawSQLQuery(AwaitableQuery):
__slots__ = ("_sql", "_db")

def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str):
def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None:
super().__init__(model)
self._sql = sql
self._db = db
Expand Down
16 changes: 9 additions & 7 deletions tortoise/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, List, Optional, Type
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Type

from tortoise.connection import connections
from tortoise.exceptions import ConfigurationError
Expand All @@ -9,12 +11,12 @@

class ConnectionRouter:
def __init__(self) -> None:
self._routers: List[type] = None # type: ignore
self._routers: list[type] = None # type: ignore

def init_routers(self, routers: List[type]):
def init_routers(self, routers: list[Callable]) -> None:
self._routers = [r() for r in routers]

def _router_func(self, model: Type["Model"], action: str):
def _router_func(self, model: Type["Model"], action: str) -> Any:
for r in self._routers:
try:
method = getattr(r, action)
Expand All @@ -26,16 +28,16 @@ def _router_func(self, model: Type["Model"], action: str):
if chosen_db:
return chosen_db

def _db_route(self, model: Type["Model"], action: str):
def _db_route(self, model: Type["Model"], action: str) -> "BaseDBAsyncClient" | None:
try:
return connections.get(self._router_func(model, action))
except ConfigurationError:
return None

def db_for_read(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]:
def db_for_read(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None:
return self._db_route(model, "db_for_read")

def db_for_write(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]:
def db_for_write(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None:
return self._db_route(model, "db_for_write")


Expand Down
20 changes: 11 additions & 9 deletions tortoise/signals.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,64 @@
from enum import Enum
from typing import Callable
from typing import Callable, TypeVar

T = TypeVar("T")
FuncType = Callable[[T], T]
Signals = Enum("Signals", ["pre_save", "post_save", "pre_delete", "post_delete"])


def post_save(*senders) -> Callable:
def post_save(*senders) -> FuncType:
"""
Register given models post_save signal.
:param senders: Model class
"""

def decorator(f):
def decorator(f: T) -> T:
for sender in senders:
sender.register_listener(Signals.post_save, f)
return f

return decorator


def pre_save(*senders) -> Callable:
def pre_save(*senders) -> FuncType:
"""
Register given models pre_save signal.
:param senders: Model class
"""

def decorator(f):
def decorator(f: T) -> T:
for sender in senders:
sender.register_listener(Signals.pre_save, f)
return f

return decorator


def pre_delete(*senders) -> Callable:
def pre_delete(*senders) -> FuncType:
"""
Register given models pre_delete signal.
:param senders: Model class
"""

def decorator(f):
def decorator(f: T) -> T:
for sender in senders:
sender.register_listener(Signals.pre_delete, f)
return f

return decorator


def post_delete(*senders) -> Callable:
def post_delete(*senders) -> FuncType:
"""
Register given models post_delete signal.
:param senders: Model class
"""

def decorator(f):
def decorator(f: T) -> T:
for sender in senders:
sender.register_listener(Signals.post_delete, f)
return f
Expand Down
Loading

0 comments on commit 33e2c4b

Please sign in to comment.