Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve type hints #1779

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,25 +488,26 @@ async def init(

if config_file:
config = cls._get_config_from_config_file(config_file)

if db_url:
elif db_url:
if not modules:
raise ConfigurationError('You must specify "db_url" and "modules" together')
config = generate_config(db_url, modules)
else:
henadzit marked this conversation as resolved.
Show resolved Hide resolved
assert config is not None # To improve type hints

try:
connections_config = config["connections"] # type: ignore
connections_config = config["connections"]
except KeyError:
raise ConfigurationError('Config must define "connections" section')

try:
apps_config = config["apps"] # type: ignore
apps_config = config["apps"]
except KeyError:
raise ConfigurationError('Config must define "apps" section')

use_tz = config.get("use_tz", use_tz) # type: ignore
timezone = config.get("timezone", timezone) # type: ignore
routers = config.get("routers", routers) # type: ignore
use_tz = config.get("use_tz", use_tz)
timezone = config.get("timezone", timezone)
routers = config.get("routers", routers)

cls.table_name_generator = table_name_generator

Expand Down
2 changes: 1 addition & 1 deletion tortoise/contrib/mysql/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
fields: Optional[Tuple[str, ...]] = None,
name: Optional[str] = None,
parser_name: Optional[str] = None,
):
) -> None:
super().__init__(*expressions, fields=fields, name=name)
if parser_name:
self.extra = f" WITH PARSER {parser_name}"
Expand Down
2 changes: 1 addition & 1 deletion tortoise/contrib/postgres/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
fields: Optional[Tuple[str, ...]] = None,
name: Optional[str] = None,
condition: Optional[dict] = None,
):
) -> None:
super().__init__(*expressions, fields=fields, name=name)
if condition:
cond = " WHERE "
Expand Down
105 changes: 46 additions & 59 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,7 @@
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum, auto
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Iterator, Type, cast

from pypika import Case as PypikaCase
from pypika import Field as PypikaField
Expand Down Expand Up @@ -48,15 +37,15 @@
class ResolveContext:
model: Type["Model"]
table: Table
annotations: Dict[str, Any]
custom_filters: Dict[str, FilterInfoDict]
annotations: dict[str, Any]
custom_filters: dict[str, FilterInfoDict]


@dataclass
class ResolveResult:
term: Term
joins: List[TableCriterionTuple] = dataclass_field(default_factory=list)
output_field: Optional[Field] = None
joins: list[TableCriterionTuple] = dataclass_field(default_factory=list)
output_field: Field | None = None


class Expression:
Expand Down Expand Up @@ -93,25 +82,25 @@ class CombinedExpression(Expression):
def __init__(self, left: Expression, connector: Connector, right: Any) -> None:
self.left = left
self.connector = connector
self.right: Expression
if isinstance(right, Expression):
self.right = right
else:
self.right = Value(right)
self.right = right if isinstance(right, Expression) else Value(right)

def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
left = self.left.resolve(resolve_context)
right = self.right.resolve(resolve_context)
left_output_field, right_output_field = left.output_field, right.output_field # type: ignore

if left.output_field and right.output_field: # type: ignore
if type(left.output_field) is not type(right.output_field): # type: ignore
raise FieldError("Cannot use arithmetic expression between different field type")
if (
left_output_field
and right_output_field
and type(left_output_field) is not type(right_output_field)
):
raise FieldError("Cannot use arithmetic expression between different field type")

operator_func = getattr(operator, self.connector.name)
return ResolveResult(
term=operator_func(left.term, right.term),
joins=list(set(left.joins + right.joins)), # dedup joins
output_field=right.output_field or left.output_field, # type: ignore
output_field=right_output_field or left_output_field,
)


Expand All @@ -129,7 +118,7 @@ def __init__(self, name: str) -> None:

def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
term: Term = PypikaField(self.name)
joins: List[TableCriterionTuple] = []
joins: list[TableCriterionTuple] = []
output_field = None
if self.name.split("__")[0] in resolve_context.model._meta.fetch_fields:
# field in the format of "related_field__field" or "related_field__another_rel_field__field"
Expand Down Expand Up @@ -158,7 +147,7 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
except KeyError:
raise FieldError(
f"There is no non-virtual field {self.name} on Model {resolve_context.model.__name__}"
)
) from None
henadzit marked this conversation as resolved.
Show resolved Hide resolved
return ResolveResult(term=term, output_field=output_field, joins=joins)

def _combine(self, other: Any, connector: Connector, right_hand: bool) -> CombinedExpression:
Expand Down Expand Up @@ -260,9 +249,9 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
if not all(isinstance(node, Q) for node in args):
raise OperationalError("All ordered arguments must be Q nodes")
#: Contains the sub-Q's that this Q is made up of
self.children: Tuple[Q, ...] = args
self.children: tuple[Q, ...] = args
#: Contains the filters applied to this Q
self.filters: Dict[str, FilterInfoDict] = kwargs
self.filters: dict[str, FilterInfoDict] = kwargs
if join_type not in {self.AND, self.OR}:
raise OperationalError("join_type must be AND or OR")
#: Specifies if this Q does an AND or OR on its children
Expand Down Expand Up @@ -357,7 +346,7 @@ def _resolve_custom_kwarg(

def _process_filter_kwarg(
self, model: "Type[Model]", key: str, value: Any, table: Table
) -> Tuple[Criterion, Optional[Tuple[Table, Criterion]]]:
) -> tuple[Criterion, tuple[Table, Criterion] | None]:
join = None

if value is None and f"{key}__isnull" in model._meta.filters:
Expand Down Expand Up @@ -408,7 +397,7 @@ def _resolve_regular_kwarg(

def _get_actual_filter_params(
self, resolve_context: ResolveContext, key: str, value: Table | FilterInfoDict
) -> Tuple[str, Any]:
) -> tuple[str, Any]:
filter_key = key
if (
key in resolve_context.model._meta.fk_fields
Expand Down Expand Up @@ -513,13 +502,13 @@ class Function(Expression):
populate_field_object = False

def __init__(
self, field: Union[str, F, CombinedExpression, "Function"], *default_values: Any
self, field: str | F | CombinedExpression | "Function", *default_values: Any
) -> None:
self.field = field
self.field_object: "Optional[Field]" = None
self.field_object: "Field | None" = None
self.default_values = default_values

def _get_function_field(self, field: Union[Term, str], *default_values) -> PypikaFunction:
def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction:
return self.database_func(field, *default_values) # type:ignore[arg-type]

def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
Expand Down Expand Up @@ -549,26 +538,22 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:

default_values = self._resolve_default_values(resolve_context)

res = None
if isinstance(self.field, str):
function_arg = self._resolve_nested_field(resolve_context, self.field)
term = self._get_function_field(function_arg.term, *default_values)
res = ResolveResult(
term=term,
joins=function_arg.joins,
output_field=function_arg.output_field, # type: ignore
)
else:
function_arg = self.field.resolve(resolve_context)
term = self._get_function_field(function_arg.term, *default_values)
res = ResolveResult(
term=term,
joins=function_arg.joins,
output_field=function_arg.output_field, # type: ignore
)
function_arg = (
self._resolve_nested_field(resolve_context, self.field)
if isinstance(self.field, str)
else self.field.resolve(resolve_context)
)
term = self._get_function_field(function_arg.term, *default_values)
res = ResolveResult(
term=term,
joins=function_arg.joins,
output_field=function_arg.output_field, # type:ignore[call-overload]
)

if self.populate_field_object and res.output_field: # type: ignore
self.field_object = res.output_field # type: ignore
if self.populate_field_object and (
res_output_field := res.output_field # type:ignore[call-overload]
):
self.field_object = res_output_field

return res

Expand All @@ -586,17 +571,17 @@ class Aggregate(Function):

def __init__(
self,
field: Union[str, F, CombinedExpression],
field: str | F | CombinedExpression,
*default_values: Any,
distinct: bool = False,
_filter: Optional[Q] = None,
_filter: Q | None = None,
) -> None:
super().__init__(field, *default_values)
self.distinct = distinct
self.filter = _filter

def _get_function_field( # type:ignore[override]
self, field: Union[ArithmeticExpression, PypikaField, str], *default_values
self, field: ArithmeticExpression | PypikaField | str, *default_values
) -> DistinctOptionFunction:
function = cast(DistinctOptionFunction, self.database_func(field, *default_values))
if self.distinct:
Expand Down Expand Up @@ -634,7 +619,7 @@ class When(Expression):
def __init__(
self,
*args: Q,
then: Union[str, F, CombinedExpression, Function],
then: str | F | CombinedExpression | Function,
negate: bool = False,
**kwargs: Any,
) -> None:
Expand All @@ -643,7 +628,7 @@ def __init__(
self.negate = negate
self.kwargs = kwargs

def _resolve_q_objects(self) -> List[Q]:
def _resolve_q_objects(self) -> list[Q]:
q_objects = []
for arg in self.args:
if not isinstance(arg, Q):
Expand Down Expand Up @@ -684,7 +669,9 @@ class Case(Expression):
"""

def __init__(
self, *args: When, default: Union[str, F, CombinedExpression, Function, None] = None
self,
*args: When,
default: str | F | CombinedExpression | Function | None = None,
) -> None:
self.args = args
self.default = default
Expand Down
8 changes: 5 additions & 3 deletions tortoise/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
*expressions: Term,
fields: Optional[Tuple[str, ...]] = None,
name: Optional[str] = None,
):
) -> None:
"""
All kinds of index parent class, default is BTreeIndex.

Expand All @@ -38,7 +38,9 @@ def __init__(
self.expressions = expressions
self.extra = ""

def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str:
def get_sql(
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
) -> str:
if self.fields:
fields = ", ".join(schema_generator.quote(f) for f in self.fields)
else:
Expand All @@ -65,7 +67,7 @@ def __init__(
fields: Optional[Tuple[str, ...]] = None,
name: Optional[str] = None,
condition: Optional[dict] = None,
):
) -> None:
super().__init__(*expressions, fields=fields, name=name)
if condition:
cond = " WHERE "
Expand Down