Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enums not quoted #1776

Merged
merged 10 commits into from
Nov 19, 2024
13 changes: 6 additions & 7 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ Changelog
------
Fixed
^^^^^
- Fix enums not quoted (#1776)
- Primary key field should not be nullable (#1778)

Added
^^^^^
- JSONField adds optional generic support, and supports OpenAPI document generation by specifying `field_type` as a pydantic BaseModel (#1763)

Changed
^^^^^^^
- Change old pydantic docs link to new one (#1775).


0.21.7
0.21.7 <../0.21.7>`_ - 2024-10-14
------
Fixed
^^^^^
Expand All @@ -36,11 +39,7 @@ Added
- Add POSIX Regex support for PostgreSQL and MySQL (#1714)
- support app=None for tortoise.contrib.fastapi.RegisterTortoise (#1733)

Changed
^^^^^^^
- Change old pydantic docs link to new one (#1775).

0.21.6
0.21.6 <../0.21.6>`_ - 2024-08-17
------
Fixed
^^^^^
Expand Down
284 changes: 142 additions & 142 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.2.1"
pypika-tortoise = "^0.2.2"
iso8601 = "^2.1.0"
aiosqlite = ">=0.16.0, <0.21.0"
pytz = "*"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from decimal import Decimal
from enum import Enum

from tests.testmodels import (
BooleanFields,
Expand All @@ -9,6 +10,15 @@
)
from tortoise.contrib import test
from tortoise.exceptions import FieldError
from tortoise.fields.base import StrEnum


class MyEnum(str, Enum):
moo = "moo"


class MyStrEnum(StrEnum):
moo = "moo"


class TestCharFieldFilters(test.TestCase):
Expand All @@ -29,6 +39,14 @@ async def test_equal(self):
set(await CharFields.filter(char="moo").values_list("char", flat=True)), {"moo"}
)

async def test_enum(self):
self.assertEqual(
set(await CharFields.filter(char=MyEnum.moo).values_list("char", flat=True)), {"moo"}
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
)
self.assertEqual(
set(await CharFields.filter(char=MyStrEnum.moo).values_list("char", flat=True)), {"moo"}
)

async def test_not(self):
self.assertEqual(
set(await CharFields.filter(char__not="moo").values_list("char", flat=True)),
Expand Down
10 changes: 8 additions & 2 deletions tests/test_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,16 @@ class TestQCall(TestCase):
def setUp(self) -> None:
super().setUp()
self.int_fields_context = ResolveContext(
model=IntFields, table=IntFields._meta.basequery, annotations={}, custom_filters={}
model=IntFields,
table=IntFields._meta.basequery, # type:ignore[arg-type]
annotations={},
custom_filters={},
)
self.char_fields_context = ResolveContext(
model=CharFields, table=CharFields._meta.basequery, annotations={}, custom_filters={}
model=CharFields,
table=CharFields._meta.basequery, # type:ignore[arg-type]
henadzit marked this conversation as resolved.
Show resolved Hide resolved
annotations={},
custom_filters={},
)

def test_q_basic(self):
Expand Down
9 changes: 5 additions & 4 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,11 @@ def _build_initial_querysets(cls) -> None:
for model in app.values():
model._meta.finalise_model()
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
model._meta.basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery_all_fields = model._meta.basequery.select(
basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery = basequery # type:ignore[assignment]
model._meta.basequery_all_fields = basequery.select(
*model._meta.db_fields
)
) # type:ignore[assignment]

@classmethod
async def init(
Expand Down Expand Up @@ -517,7 +518,7 @@ async def init(
cls._inited = True

@classmethod
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None):
def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None:
from tortoise.router import router

routers = routers or []
Expand Down
13 changes: 7 additions & 6 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ def __init__(
self.column_map[column] = field_object.to_db_value

table = self.model._meta.basetable
basequery = cast(QueryBuilder, self.model._meta.basequery)
self.delete_query = str(
self.model._meta.basequery.where(
table[self.model._meta.db_pk_column] == self.parameter(0)
).delete()
basequery.where(table[self.model._meta.db_pk_column] == self.parameter(0)).delete()
)
self.update_cache: Dict[str, str] = {}

Expand All @@ -121,13 +120,13 @@ def __init__(
) = EXECUTOR_CACHE[key]

async def execute_explain(self, query: Query) -> Any:
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql()))
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql())) # type:ignore[attr-defined]
return (await self.db.execute_query(sql))[1]

async def execute_select(
self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None
) -> list:
_, raw_results = await self.db.execute_query(query.get_sql())
_, raw_results = await self.db.execute_query(query.get_sql()) # type:ignore[union-attr]
instance_list = []
for row in raw_results:
if self.select_related_idx:
Expand Down Expand Up @@ -543,7 +542,9 @@ def _make_prefetch_queries(self) -> None:
relation_field = self.model._meta.fields_map[field_name]
related_model: "Type[Model]" = relation_field.related_model # type: ignore
related_query = related_model.all().using_db(self.db)
related_query.query = copy(related_query.model._meta.basequery)
related_query.query = copy(
related_query.model._meta.basequery
) # type:ignore[assignment]
if forwarded_prefetches:
related_query = related_query.prefetch_related(*forwarded_prefetches)
self._prefetch_queries.setdefault(field_name, []).append((to_attr, related_query))
Expand Down
17 changes: 6 additions & 11 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Optional, Sequence
from typing import Optional, Sequence, cast

from pypika import Parameter
from pypika.dialects import PostgreSQLQueryBuilder
Expand All @@ -23,7 +23,7 @@
)


def postgres_search(field: Term, value: Term):
def postgres_search(field: Term, value: Term) -> SearchCriterion:
return SearchCriterion(field, expr=value)


Expand All @@ -44,15 +44,10 @@ def parameter(self, pos: int) -> Parameter:
def _prepare_insert_statement(
self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False
) -> PostgreSQLQueryBuilder:
query = (
self.db.query_class.into(self.model._meta.basetable)
.columns(*columns)
.insert(*[self.parameter(i) for i in range(len(columns))])
)
if has_generated:
generated_fields = self.model._meta.generated_db_fields
if generated_fields:
query = query.returning(*generated_fields)
builder = cast(PostgreSQLQueryBuilder, self.db.query_class.into(self.model._meta.basetable))
query = builder.columns(*columns).insert(*[self.parameter(i) for i in range(len(columns))])
if has_generated and (generated_fields := self.model._meta.generated_db_fields):
query = query.returning(*generated_fields)
if ignore_conflicts:
query = query.on_conflict().do_nothing()
return query
Expand Down
10 changes: 5 additions & 5 deletions tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
)


class StrWrapper(ValueWrapper): # type: ignore
class StrWrapper(ValueWrapper):
"""
Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL
"""

def get_value_sql(self, **kwargs):
def get_value_sql(self, **kwargs) -> str:
quote_char = kwargs.get("secondary_quote_char") or ""
value = self.value.replace(quote_char, quote_char * 2)
return format_quotes(value, quote_char)
Expand Down Expand Up @@ -92,12 +92,12 @@ def mysql_insensitive_ends_with(field: Term, value: str) -> Criterion:
)


def mysql_search(field: Term, value: str):
def mysql_search(field: Term, value: str) -> SearchCriterion:
return SearchCriterion(field, expr=StrWrapper(value))


def mysql_posix_regex(field: Term, value: str):
return BasicCriterion(" REGEXP ", field, StrWrapper(value))
def mysql_posix_regex(field: Term, value: str) -> BasicCriterion:
return BasicCriterion(" REGEXP ", field, StrWrapper(value)) # type:ignore[arg-type]


class MySQLExecutor(BaseExecutor):
Expand Down
2 changes: 1 addition & 1 deletion tortoise/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self) -> None:
self._db_config: Optional["DBConfigType"] = None
self._create_db: bool = False

async def _init(self, db_config: "DBConfigType", create_db: bool):
async def _init(self, db_config: "DBConfigType", create_db: bool) -> None:
if self._db_config is None:
self._db_config = db_config
else:
Expand Down
6 changes: 3 additions & 3 deletions tortoise/contrib/mysql/functions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Union
from __future__ import annotations

from pypika.terms import Function, Parameter


class Rand(Function): # type: ignore
class Rand(Function):
"""
Generate random number, with optional seed.

:samp:`Rand()`
"""

def __init__(self, seed: Union[int, None] = None, alias=None) -> None:
def __init__(self, seed: int | None = None, alias=None) -> None:
super().__init__("RAND", seed, alias=alias)
self.args = [self.wrap_constant(seed) if seed is not None else Parameter("")]
43 changes: 17 additions & 26 deletions tortoise/contrib/mysql/json_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import operator
from typing import Any, Dict, List
Expand All @@ -7,27 +9,28 @@
from pypika.terms import Function as PypikaFunction
from pypika.terms import Term, ValueWrapper

from tortoise.filters import not_equal
from tortoise.filters import get_json_filter_operator, not_equal


class JSONContains(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, target_list: Term):
super(JSONContains, self).__init__("JSON_CONTAINS", column_name, target_list)
class JSONContains(PypikaFunction):
def __init__(self, column_name: Term, target_list: Term) -> None:
super().__init__("JSON_CONTAINS", column_name, target_list)


class JSONExtract(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, query_list: List[Term]):
class JSONExtract(PypikaFunction):
def __init__(self, column_name: Term, query_list: List[int | str | Term]) -> None:
query = self.make_query(query_list)
super(JSONExtract, self).__init__("JSON_EXTRACT", column_name, query)
super().__init__("JSON_EXTRACT", column_name, query)

@classmethod
def serialize_value(cls, value: Any):
def serialize_value(cls, value: Any) -> str:
if isinstance(value, int):
return f"[{value}]"
if isinstance(value, str):
return f".{value}"
return str(value)

def make_query(self, query_list: List[Term]):
def make_query(self, query_list: List[Term | int | str]) -> str:
query = ["$"]
for value in query_list:
query.append(self.serialize_value(value))
Expand All @@ -39,7 +42,7 @@ def mysql_json_contains(field: Term, value: str) -> Criterion:
return JSONContains(field, ValueWrapper(value))


def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
def mysql_json_contained_by(field: Term, value_str: str) -> JSONContains | None:
values = json.loads(value_str)
contained_by = None
for value in values:
Expand All @@ -50,14 +53,14 @@ def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
return contained_by


def _mysql_json_is_null(left: Term, is_null: bool):
def _mysql_json_is_null(left: Term, is_null: bool) -> Criterion:
if is_null:
return operator.eq(left, Cast("null", "JSON"))
else:
return not_equal(left, Cast("null", "JSON"))


def _mysql_json_not_is_null(left: Term, is_null: bool):
def _mysql_json_not_is_null(left: Term, is_null: bool) -> Criterion:
return _mysql_json_is_null(left, not is_null)


Expand All @@ -68,18 +71,6 @@ def _mysql_json_not_is_null(left: Term, is_null: bool):
}


def _serialize_value(value: Any):
if type(value) in [dict, list]:
return json.dumps(value)
return value


def mysql_json_filter(field: Term, value: Dict) -> Criterion:
((key, filter_value),) = value.items()
filter_value = _serialize_value(filter_value)
key_parts = [int(item) if item.isdigit() else str(item) for item in key.split("__")]
operator_ = operator.eq
if key_parts[-1] in operator_keywords:
operator_ = operator_keywords[str(key_parts.pop(-1))] # type: ignore

return operator_(JSONExtract(field, key_parts), filter_value)
key_parts, filter_value, operator_ = get_json_filter_operator(value, operator_keywords)
return operator_(JSONExtract(field, key_parts), filter_value) # type:ignore[arg-type]
14 changes: 7 additions & 7 deletions tortoise/contrib/mysql/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pypika.terms import Term


class Comp(Comparator): # type: ignore
class Comp(Comparator):
search = " "


Expand All @@ -18,13 +18,13 @@ class Mode(Enum):
WITH_QUERY_EXPRESSION = "WITH QUERY EXPANSION"


class Match(PypikaFunction): # type: ignore
def __init__(self, *columns: Term):
class Match(PypikaFunction):
def __init__(self, *columns: Term) -> None:
super(Match, self).__init__("MATCH", *columns)


class Against(PypikaFunction): # type: ignore
def __init__(self, expr: Term, mode: Optional[Mode] = None):
class Against(PypikaFunction):
def __init__(self, expr: Term, mode: Optional[Mode] = None) -> None:
super(Against, self).__init__("AGAINST", expr)
self.mode = mode

Expand All @@ -34,10 +34,10 @@ def get_special_params_sql(self, **kwargs: Any) -> Any:
return self.mode.value


class SearchCriterion(BasicCriterion): # type: ignore
class SearchCriterion(BasicCriterion):
"""
Only support for CharField, TextField with full search indexes.
"""

def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None):
def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None) -> None:
super().__init__(Comp.search, Match(*columns), Against(expr, mode))
Loading