diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 9b96b6e0f..8752510f9 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -488,25 +488,26 @@ async def init( if config_file: config = cls._get_config_from_config_file(config_file) - - if db_url: + elif db_url: if not modules: raise ConfigurationError('You must specify "db_url" and "modules" together') config = generate_config(db_url, modules) + else: + assert config is not None # To improve type hints try: - connections_config = config["connections"] # type: ignore + connections_config = config["connections"] except KeyError: raise ConfigurationError('Config must define "connections" section') try: - apps_config = config["apps"] # type: ignore + apps_config = config["apps"] except KeyError: raise ConfigurationError('Config must define "apps" section') - use_tz = config.get("use_tz", use_tz) # type: ignore - timezone = config.get("timezone", timezone) # type: ignore - routers = config.get("routers", routers) # type: ignore + use_tz = config.get("use_tz", use_tz) + timezone = config.get("timezone", timezone) + routers = config.get("routers", routers) cls.table_name_generator = table_name_generator diff --git a/tortoise/contrib/mysql/indexes.py b/tortoise/contrib/mysql/indexes.py index 17b612a46..4a76298d9 100644 --- a/tortoise/contrib/mysql/indexes.py +++ b/tortoise/contrib/mysql/indexes.py @@ -14,7 +14,7 @@ def __init__( fields: Optional[Tuple[str, ...]] = None, name: Optional[str] = None, parser_name: Optional[str] = None, - ): + ) -> None: super().__init__(*expressions, fields=fields, name=name) if parser_name: self.extra = f" WITH PARSER {parser_name}" diff --git a/tortoise/contrib/postgres/indexes.py b/tortoise/contrib/postgres/indexes.py index f746c3dd2..c02f4f2e0 100644 --- a/tortoise/contrib/postgres/indexes.py +++ b/tortoise/contrib/postgres/indexes.py @@ -16,7 +16,7 @@ def __init__( fields: Optional[Tuple[str, ...]] = None, name: Optional[str] = None, condition: Optional[dict] = None, - ): + ) -> None: super().__init__(*expressions, fields=fields, name=name) if condition: cond = " WHERE " diff --git a/tortoise/expressions.py b/tortoise/expressions.py index f7049e4be..51893b581 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -4,18 +4,7 @@ from dataclasses import dataclass from dataclasses import field as dataclass_field from enum import Enum, auto -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterator, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Iterator, Type, cast from pypika import Case as PypikaCase from pypika import Field as PypikaField @@ -48,15 +37,15 @@ class ResolveContext: model: Type["Model"] table: Table - annotations: Dict[str, Any] - custom_filters: Dict[str, FilterInfoDict] + annotations: dict[str, Any] + custom_filters: dict[str, FilterInfoDict] @dataclass class ResolveResult: term: Term - joins: List[TableCriterionTuple] = dataclass_field(default_factory=list) - output_field: Optional[Field] = None + joins: list[TableCriterionTuple] = dataclass_field(default_factory=list) + output_field: Field | None = None class Expression: @@ -93,25 +82,25 @@ class CombinedExpression(Expression): def __init__(self, left: Expression, connector: Connector, right: Any) -> None: self.left = left self.connector = connector - self.right: Expression - if isinstance(right, Expression): - self.right = right - else: - self.right = Value(right) + self.right = right if isinstance(right, Expression) else Value(right) def resolve(self, resolve_context: ResolveContext) -> ResolveResult: left = self.left.resolve(resolve_context) right = self.right.resolve(resolve_context) + left_output_field, right_output_field = left.output_field, right.output_field # type: ignore - if left.output_field and right.output_field: # type: ignore - if type(left.output_field) is not type(right.output_field): # type: ignore - raise FieldError("Cannot use arithmetic expression between different field type") + if ( + left_output_field + and right_output_field + and type(left_output_field) is not type(right_output_field) + ): + raise FieldError("Cannot use arithmetic expression between different field type") operator_func = getattr(operator, self.connector.name) return ResolveResult( term=operator_func(left.term, right.term), joins=list(set(left.joins + right.joins)), # dedup joins - output_field=right.output_field or left.output_field, # type: ignore + output_field=right_output_field or left_output_field, ) @@ -129,7 +118,7 @@ def __init__(self, name: str) -> None: def resolve(self, resolve_context: ResolveContext) -> ResolveResult: term: Term = PypikaField(self.name) - joins: List[TableCriterionTuple] = [] + joins: list[TableCriterionTuple] = [] output_field = None if self.name.split("__")[0] in resolve_context.model._meta.fetch_fields: # field in the format of "related_field__field" or "related_field__another_rel_field__field" @@ -158,7 +147,7 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult: except KeyError: raise FieldError( f"There is no non-virtual field {self.name} on Model {resolve_context.model.__name__}" - ) + ) from None return ResolveResult(term=term, output_field=output_field, joins=joins) def _combine(self, other: Any, connector: Connector, right_hand: bool) -> CombinedExpression: @@ -260,9 +249,9 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None: if not all(isinstance(node, Q) for node in args): raise OperationalError("All ordered arguments must be Q nodes") #: Contains the sub-Q's that this Q is made up of - self.children: Tuple[Q, ...] = args + self.children: tuple[Q, ...] = args #: Contains the filters applied to this Q - self.filters: Dict[str, FilterInfoDict] = kwargs + self.filters: dict[str, FilterInfoDict] = kwargs if join_type not in {self.AND, self.OR}: raise OperationalError("join_type must be AND or OR") #: Specifies if this Q does an AND or OR on its children @@ -357,7 +346,7 @@ def _resolve_custom_kwarg( def _process_filter_kwarg( self, model: "Type[Model]", key: str, value: Any, table: Table - ) -> Tuple[Criterion, Optional[Tuple[Table, Criterion]]]: + ) -> tuple[Criterion, tuple[Table, Criterion] | None]: join = None if value is None and f"{key}__isnull" in model._meta.filters: @@ -408,7 +397,7 @@ def _resolve_regular_kwarg( def _get_actual_filter_params( self, resolve_context: ResolveContext, key: str, value: Table | FilterInfoDict - ) -> Tuple[str, Any]: + ) -> tuple[str, Any]: filter_key = key if ( key in resolve_context.model._meta.fk_fields @@ -513,13 +502,13 @@ class Function(Expression): populate_field_object = False def __init__( - self, field: Union[str, F, CombinedExpression, "Function"], *default_values: Any + self, field: str | F | CombinedExpression | "Function", *default_values: Any ) -> None: self.field = field - self.field_object: "Optional[Field]" = None + self.field_object: "Field | None" = None self.default_values = default_values - def _get_function_field(self, field: Union[Term, str], *default_values) -> PypikaFunction: + def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction: return self.database_func(field, *default_values) # type:ignore[arg-type] def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult: @@ -549,26 +538,22 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult: default_values = self._resolve_default_values(resolve_context) - res = None - if isinstance(self.field, str): - function_arg = self._resolve_nested_field(resolve_context, self.field) - term = self._get_function_field(function_arg.term, *default_values) - res = ResolveResult( - term=term, - joins=function_arg.joins, - output_field=function_arg.output_field, # type: ignore - ) - else: - function_arg = self.field.resolve(resolve_context) - term = self._get_function_field(function_arg.term, *default_values) - res = ResolveResult( - term=term, - joins=function_arg.joins, - output_field=function_arg.output_field, # type: ignore - ) + function_arg = ( + self._resolve_nested_field(resolve_context, self.field) + if isinstance(self.field, str) + else self.field.resolve(resolve_context) + ) + term = self._get_function_field(function_arg.term, *default_values) + res = ResolveResult( + term=term, + joins=function_arg.joins, + output_field=function_arg.output_field, # type:ignore[call-overload] + ) - if self.populate_field_object and res.output_field: # type: ignore - self.field_object = res.output_field # type: ignore + if self.populate_field_object and ( + res_output_field := res.output_field # type:ignore[call-overload] + ): + self.field_object = res_output_field return res @@ -586,17 +571,17 @@ class Aggregate(Function): def __init__( self, - field: Union[str, F, CombinedExpression], + field: str | F | CombinedExpression, *default_values: Any, distinct: bool = False, - _filter: Optional[Q] = None, + _filter: Q | None = None, ) -> None: super().__init__(field, *default_values) self.distinct = distinct self.filter = _filter def _get_function_field( # type:ignore[override] - self, field: Union[ArithmeticExpression, PypikaField, str], *default_values + self, field: ArithmeticExpression | PypikaField | str, *default_values ) -> DistinctOptionFunction: function = cast(DistinctOptionFunction, self.database_func(field, *default_values)) if self.distinct: @@ -634,7 +619,7 @@ class When(Expression): def __init__( self, *args: Q, - then: Union[str, F, CombinedExpression, Function], + then: str | F | CombinedExpression | Function, negate: bool = False, **kwargs: Any, ) -> None: @@ -643,7 +628,7 @@ def __init__( self.negate = negate self.kwargs = kwargs - def _resolve_q_objects(self) -> List[Q]: + def _resolve_q_objects(self) -> list[Q]: q_objects = [] for arg in self.args: if not isinstance(arg, Q): @@ -684,7 +669,9 @@ class Case(Expression): """ def __init__( - self, *args: When, default: Union[str, F, CombinedExpression, Function, None] = None + self, + *args: When, + default: str | F | CombinedExpression | Function | None = None, ) -> None: self.args = args self.default = default diff --git a/tortoise/indexes.py b/tortoise/indexes.py index ffab65da2..6561b76f2 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -18,7 +18,7 @@ def __init__( *expressions: Term, fields: Optional[Tuple[str, ...]] = None, name: Optional[str] = None, - ): + ) -> None: """ All kinds of index parent class, default is BTreeIndex. @@ -38,7 +38,9 @@ def __init__( self.expressions = expressions self.extra = "" - def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str: + def get_sql( + self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + ) -> str: if self.fields: fields = ", ".join(schema_generator.quote(f) for f in self.fields) else: @@ -65,7 +67,7 @@ def __init__( fields: Optional[Tuple[str, ...]] = None, name: Optional[str] = None, condition: Optional[dict] = None, - ): + ) -> None: super().__init__(*expressions, fields=fields, name=name) if condition: cond = " WHERE "