From 1e6089d04a8673b83e73ae0975409b32d7445e87 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Sun, 24 Nov 2024 15:46:52 +0800 Subject: [PATCH] chore: improve type hints (#1784) --- tortoise/__init__.py | 111 ++++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 64 deletions(-) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 8752510f9..2aace9a16 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import importlib import importlib.metadata as importlib_metadata @@ -7,20 +9,9 @@ from copy import deepcopy from inspect import isclass from types import ModuleType -from typing import ( - Callable, - Coroutine, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, Callable, Coroutine, Iterable, Type, cast -from pypika import Table +from pypika import Query, Table from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.backends.base.config_generator import expand_db_url, generate_config @@ -40,8 +31,8 @@ class Tortoise: - apps: Dict[str, Dict[str, Type["Model"]]] = {} - table_name_generator: Optional[Callable[[Type["Model"]], str]] = None + apps: dict[str, dict[str, Type["Model"]]] = {} + table_name_generator: Callable[[Type["Model"]], str] | None = None _inited: bool = False @classmethod @@ -60,7 +51,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: @classmethod def describe_model( cls, model: Type["Model"], serializable: bool = True - ) -> dict: # pragma: nocoverage + ) -> dict[str, Any]: # pragma: nocoverage """ Describes the given list of models or ALL registered models. @@ -85,8 +76,8 @@ def describe_model( @classmethod def describe_models( - cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True - ) -> Dict[str, dict]: + cls, models: list[Type["Model"]] | None = None, serializable: bool = True + ) -> dict[str, dict[str, Any]]: """ Describes the given list of models or ALL registered models. @@ -142,7 +133,7 @@ def get_related_model(related_app_name: str, related_model_name: str) -> Type["M f" app '{related_app_name}'." ) - def split_reference(reference: str) -> Tuple[str, str]: + def split_reference(reference: str) -> tuple[str, str]: """ Validate, if reference follow the official naming conventions. Throws a ConfigurationError with a hopefully helpful message. If successful, @@ -158,12 +149,9 @@ def split_reference(reference: str) -> Tuple[str, str]: return items[0], items[1] def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: - if is_o2o: - fk_object: Union[OneToOneFieldInstance, ForeignKeyFieldInstance] = cast( - OneToOneFieldInstance, model._meta.fields_map[field] - ) - else: - fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field]) + fk_object = cast( + "OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field] + ) related_app_name, related_model_name = split_reference(fk_object.model_name) related_model = get_related_model(related_app_name, related_model_name) @@ -206,24 +194,24 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: f'backward relation "{backward_relation_name}" duplicates in' f" model {related_model_name}" ) - if is_o2o: - fk_relation: Union[BackwardOneToOneRelation, BackwardFKRelation] = ( - BackwardOneToOneRelation( - model, - key_field, - key_fk_object.source_field, - null=True, - description=fk_object.description, - ) + + fk_relation = ( + BackwardOneToOneRelation( + model, + key_field, + key_fk_object.source_field, + null=True, + description=fk_object.description, ) - else: - fk_relation = BackwardFKRelation( + if is_o2o + else BackwardFKRelation( model, key_field, key_fk_object.source_field, null=fk_object.null, description=fk_object.description, ) + ) fk_relation.to_field_instance = fk_object.to_field_instance # type:ignore related_model._meta.add_field(backward_relation_name, fk_relation) if is_o2o and fk_object.pk: @@ -251,8 +239,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field]) if m2m_object._generated: continue - backward_key = m2m_object.backward_key - if not backward_key: + if not (backward_key := m2m_object.backward_key): backward_key = f"{model._meta.db_table}_id" if backward_key == m2m_object.forward_key: backward_key = f"{model._meta.db_table}_rel_id" @@ -264,8 +251,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: m2m_object.related_model = related_model - backward_relation_name = m2m_object.related_name - if not backward_relation_name: + if not (backward_relation_name := m2m_object.related_name): backward_relation_name = m2m_object.related_name = ( f"{model._meta.db_table}s" ) @@ -295,9 +281,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: related_model._meta.add_field(backward_relation_name, m2m_relation) @classmethod - def _discover_models( - cls, models_path: Union[ModuleType, str], app_label: str - ) -> List[Type["Model"]]: + def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]: if isinstance(models_path, ModuleType): module = models_path else: @@ -306,11 +290,11 @@ def _discover_models( except ImportError: raise ConfigurationError(f'Module "{models_path}" not found') discovered_models = [] - possible_models = getattr(module, "__models__", None) - try: - possible_models = [*possible_models] # type:ignore - except TypeError: - possible_models = None + if possible_models := getattr(module, "__models__", None): + try: + possible_models = [*possible_models] + except TypeError: + possible_models = None if not possible_models: possible_models = [getattr(module, attr_name) for attr_name in dir(module)] for attr in possible_models: @@ -326,7 +310,7 @@ def _discover_models( @classmethod def init_models( cls, - models_paths: Iterable[Union[ModuleType, str]], + models_paths: Iterable[ModuleType | str], app_label: str, _init_relations: bool = True, ) -> None: @@ -342,7 +326,7 @@ def init_models( :raises ConfigurationError: If models are invalid. """ - app_models: List[Type[Model]] = [] + app_models: list[Type[Model]] = [] for models_path in models_paths: app_models += cls._discover_models(models_path, app_label) @@ -352,7 +336,7 @@ def init_models( cls._init_relations() @classmethod - def _init_apps(cls, apps_config: dict) -> None: + def _init_apps(cls, apps_config: dict[str, dict[str, Any]]) -> None: for name, info in apps_config.items(): try: connections.get(info.get("default_connection", "default")) @@ -396,23 +380,23 @@ def _build_initial_querysets(cls) -> None: model._meta.finalise_model() model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema) basequery = model._meta.db.query_class.from_(model._meta.basetable) - model._meta.basequery = basequery # type:ignore[assignment] - model._meta.basequery_all_fields = basequery.select( - *model._meta.db_fields - ) # type:ignore[assignment] + model._meta.basequery = cast(Query, basequery) + model._meta.basequery_all_fields = cast( + Query, basequery.select(*model._meta.db_fields) + ) @classmethod async def init( cls, - config: Optional[dict] = None, - config_file: Optional[str] = None, + config: dict[str, Any] | None = None, + config_file: str | None = None, _create_db: bool = False, - db_url: Optional[str] = None, - modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, use_tz: bool = False, timezone: str = "UTC", - routers: Optional[List[Union[str, Type]]] = None, - table_name_generator: Optional[Callable[[Type["Model"]], str]] = None, + routers: list[str | type] | None = None, + table_name_generator: Callable[[Type["Model"]], str] | None = None, ) -> None: """ Sets up Tortoise-ORM. @@ -516,8 +500,7 @@ async def init( for name, info in connections_config.items(): if isinstance(info, str): info = expand_db_url(info) - password = info.get("credentials", {}).get("password") - if password: + if password := info.get("credentials", {}).get("password"): passwords.append(password) str_connection_config = str(connections_config) @@ -542,7 +525,7 @@ async def init( cls._inited = True @classmethod - def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None: + def _init_routers(cls, routers: list[str | type] | None = None) -> None: from tortoise.router import router routers = routers or []