Skip to content

Commit

Permalink
chore: improve type hints (#1784)
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng authored Nov 24, 2024
1 parent 7f077c1 commit 1e6089d
Showing 1 changed file with 47 additions and 64 deletions.
111 changes: 47 additions & 64 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import importlib
import importlib.metadata as importlib_metadata
Expand All @@ -7,20 +9,9 @@
from copy import deepcopy
from inspect import isclass
from types import ModuleType
from typing import (
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from typing import Any, Callable, Coroutine, Iterable, Type, cast

from pypika import Table
from pypika import Query, Table

from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.backends.base.config_generator import expand_db_url, generate_config
Expand All @@ -40,8 +31,8 @@


class Tortoise:
apps: Dict[str, Dict[str, Type["Model"]]] = {}
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None
apps: dict[str, dict[str, Type["Model"]]] = {}
table_name_generator: Callable[[Type["Model"]], str] | None = None
_inited: bool = False

@classmethod
Expand All @@ -60,7 +51,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:
@classmethod
def describe_model(
cls, model: Type["Model"], serializable: bool = True
) -> dict: # pragma: nocoverage
) -> dict[str, Any]: # pragma: nocoverage
"""
Describes the given list of models or ALL registered models.
Expand All @@ -85,8 +76,8 @@ def describe_model(

@classmethod
def describe_models(
cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True
) -> Dict[str, dict]:
cls, models: list[Type["Model"]] | None = None, serializable: bool = True
) -> dict[str, dict[str, Any]]:
"""
Describes the given list of models or ALL registered models.
Expand Down Expand Up @@ -142,7 +133,7 @@ def get_related_model(related_app_name: str, related_model_name: str) -> Type["M
f" app '{related_app_name}'."
)

def split_reference(reference: str) -> Tuple[str, str]:
def split_reference(reference: str) -> tuple[str, str]:
"""
Validate, if reference follow the official naming conventions. Throws a
ConfigurationError with a hopefully helpful message. If successful,
Expand All @@ -158,12 +149,9 @@ def split_reference(reference: str) -> Tuple[str, str]:
return items[0], items[1]

def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
if is_o2o:
fk_object: Union[OneToOneFieldInstance, ForeignKeyFieldInstance] = cast(
OneToOneFieldInstance, model._meta.fields_map[field]
)
else:
fk_object = cast(ForeignKeyFieldInstance, model._meta.fields_map[field])
fk_object = cast(
"OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field]
)
related_app_name, related_model_name = split_reference(fk_object.model_name)
related_model = get_related_model(related_app_name, related_model_name)

Expand Down Expand Up @@ -206,24 +194,24 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
f'backward relation "{backward_relation_name}" duplicates in'
f" model {related_model_name}"
)
if is_o2o:
fk_relation: Union[BackwardOneToOneRelation, BackwardFKRelation] = (
BackwardOneToOneRelation(
model,
key_field,
key_fk_object.source_field,
null=True,
description=fk_object.description,
)

fk_relation = (
BackwardOneToOneRelation(
model,
key_field,
key_fk_object.source_field,
null=True,
description=fk_object.description,
)
else:
fk_relation = BackwardFKRelation(
if is_o2o
else BackwardFKRelation(
model,
key_field,
key_fk_object.source_field,
null=fk_object.null,
description=fk_object.description,
)
)
fk_relation.to_field_instance = fk_object.to_field_instance # type:ignore
related_model._meta.add_field(backward_relation_name, fk_relation)
if is_o2o and fk_object.pk:
Expand Down Expand Up @@ -251,8 +239,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
m2m_object = cast(ManyToManyFieldInstance, model._meta.fields_map[field])
if m2m_object._generated:
continue
backward_key = m2m_object.backward_key
if not backward_key:
if not (backward_key := m2m_object.backward_key):
backward_key = f"{model._meta.db_table}_id"
if backward_key == m2m_object.forward_key:
backward_key = f"{model._meta.db_table}_rel_id"
Expand All @@ -264,8 +251,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:

m2m_object.related_model = related_model

backward_relation_name = m2m_object.related_name
if not backward_relation_name:
if not (backward_relation_name := m2m_object.related_name):
backward_relation_name = m2m_object.related_name = (
f"{model._meta.db_table}s"
)
Expand Down Expand Up @@ -295,9 +281,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
related_model._meta.add_field(backward_relation_name, m2m_relation)

@classmethod
def _discover_models(
cls, models_path: Union[ModuleType, str], app_label: str
) -> List[Type["Model"]]:
def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]:
if isinstance(models_path, ModuleType):
module = models_path
else:
Expand All @@ -306,11 +290,11 @@ def _discover_models(
except ImportError:
raise ConfigurationError(f'Module "{models_path}" not found')
discovered_models = []
possible_models = getattr(module, "__models__", None)
try:
possible_models = [*possible_models] # type:ignore
except TypeError:
possible_models = None
if possible_models := getattr(module, "__models__", None):
try:
possible_models = [*possible_models]
except TypeError:
possible_models = None
if not possible_models:
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in possible_models:
Expand All @@ -326,7 +310,7 @@ def _discover_models(
@classmethod
def init_models(
cls,
models_paths: Iterable[Union[ModuleType, str]],
models_paths: Iterable[ModuleType | str],
app_label: str,
_init_relations: bool = True,
) -> None:
Expand All @@ -342,7 +326,7 @@ def init_models(
:raises ConfigurationError: If models are invalid.
"""
app_models: List[Type[Model]] = []
app_models: list[Type[Model]] = []
for models_path in models_paths:
app_models += cls._discover_models(models_path, app_label)

Expand All @@ -352,7 +336,7 @@ def init_models(
cls._init_relations()

@classmethod
def _init_apps(cls, apps_config: dict) -> None:
def _init_apps(cls, apps_config: dict[str, dict[str, Any]]) -> None:
for name, info in apps_config.items():
try:
connections.get(info.get("default_connection", "default"))
Expand Down Expand Up @@ -396,23 +380,23 @@ def _build_initial_querysets(cls) -> None:
model._meta.finalise_model()
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery = basequery # type:ignore[assignment]
model._meta.basequery_all_fields = basequery.select(
*model._meta.db_fields
) # type:ignore[assignment]
model._meta.basequery = cast(Query, basequery)
model._meta.basequery_all_fields = cast(
Query, basequery.select(*model._meta.db_fields)
)

@classmethod
async def init(
cls,
config: Optional[dict] = None,
config_file: Optional[str] = None,
config: dict[str, Any] | None = None,
config_file: str | None = None,
_create_db: bool = False,
db_url: Optional[str] = None,
modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None,
db_url: str | None = None,
modules: dict[str, Iterable[str | ModuleType]] | None = None,
use_tz: bool = False,
timezone: str = "UTC",
routers: Optional[List[Union[str, Type]]] = None,
table_name_generator: Optional[Callable[[Type["Model"]], str]] = None,
routers: list[str | type] | None = None,
table_name_generator: Callable[[Type["Model"]], str] | None = None,
) -> None:
"""
Sets up Tortoise-ORM.
Expand Down Expand Up @@ -516,8 +500,7 @@ async def init(
for name, info in connections_config.items():
if isinstance(info, str):
info = expand_db_url(info)
password = info.get("credentials", {}).get("password")
if password:
if password := info.get("credentials", {}).get("password"):
passwords.append(password)

str_connection_config = str(connections_config)
Expand All @@ -542,7 +525,7 @@ async def init(
cls._inited = True

@classmethod
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
def _init_routers(cls, routers: list[str | type] | None = None) -> None:
from tortoise.router import router

routers = routers or []
Expand Down

0 comments on commit 1e6089d

Please sign in to comment.