Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enums not quoted #1776

Merged
merged 10 commits into from
Nov 19, 2024
286 changes: 145 additions & 141 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.2.1"
pypika-tortoise = { git = "https://github.com/waketzheng/pypika-tortoise.git", branch = "fix-enums-not-quoted" }
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
iso8601 = "^2.1.0"
aiosqlite = ">=0.16.0, <0.21.0"
pytz = "*"
Expand Down
178 changes: 56 additions & 122 deletions tests/contrib/test_pydantic.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from decimal import Decimal
from enum import Enum

from tests.testmodels import (
BooleanFields,
Expand All @@ -9,6 +10,15 @@
)
from tortoise.contrib import test
from tortoise.exceptions import FieldError
from tortoise.fields.base import StrEnum


class MyEnum(str, Enum):
moo = "moo"


class MyStrEnum(StrEnum):
moo = "moo"


class TestCharFieldFilters(test.TestCase):
Expand All @@ -29,6 +39,14 @@ async def test_equal(self):
set(await CharFields.filter(char="moo").values_list("char", flat=True)), {"moo"}
)

async def test_enum(self):
self.assertEqual(
set(await CharFields.filter(char=MyEnum.moo).values_list("char", flat=True)), {"moo"}
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
)
self.assertEqual(
set(await CharFields.filter(char=MyStrEnum.moo).values_list("char", flat=True)), {"moo"}
)

async def test_not(self):
self.assertEqual(
set(await CharFields.filter(char__not="moo").values_list("char", flat=True)),
Expand Down
10 changes: 8 additions & 2 deletions tests/test_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,16 @@ class TestQCall(TestCase):
def setUp(self) -> None:
super().setUp()
self.int_fields_context = ResolveContext(
model=IntFields, table=IntFields._meta.basequery, annotations={}, custom_filters={}
model=IntFields,
table=IntFields._meta.basequery, # type:ignore[arg-type]
annotations={},
custom_filters={},
)
self.char_fields_context = ResolveContext(
model=CharFields, table=CharFields._meta.basequery, annotations={}, custom_filters={}
model=CharFields,
table=CharFields._meta.basequery, # type:ignore[arg-type]
henadzit marked this conversation as resolved.
Show resolved Hide resolved
annotations={},
custom_filters={},
)

def test_q_basic(self):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_queryset_reuse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from tests.testmodels import (
Event,
Tournament,
)
from tests.testmodels import Event, Tournament
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.expressions import F
Expand Down
7 changes: 4 additions & 3 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,11 @@ def _build_initial_querysets(cls) -> None:
for model in app.values():
model._meta.finalise_model()
model._meta.basetable = Table(name=model._meta.db_table, schema=model._meta.schema)
model._meta.basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery_all_fields = model._meta.basequery.select(
basequery = model._meta.db.query_class.from_(model._meta.basetable)
model._meta.basequery = basequery # type:ignore[assignment]
model._meta.basequery_all_fields = basequery.select(
*model._meta.db_fields
)
) # type:ignore[assignment]

@classmethod
async def init(
Expand Down
13 changes: 7 additions & 6 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ def __init__(
self.column_map[column] = field_object.to_db_value

table = self.model._meta.basetable
basequery = cast(QueryBuilder, self.model._meta.basequery)
self.delete_query = str(
self.model._meta.basequery.where(
table[self.model._meta.db_pk_column] == self.parameter(0)
).delete()
basequery.where(table[self.model._meta.db_pk_column] == self.parameter(0)).delete()
)
self.update_cache: Dict[str, str] = {}

Expand All @@ -121,13 +120,13 @@ def __init__(
) = EXECUTOR_CACHE[key]

async def execute_explain(self, query: Query) -> Any:
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql()))
sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql())) # type:ignore[attr-defined]
return (await self.db.execute_query(sql))[1]

async def execute_select(
self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None
) -> list:
_, raw_results = await self.db.execute_query(query.get_sql())
_, raw_results = await self.db.execute_query(query.get_sql()) # type:ignore[union-attr]
instance_list = []
for row in raw_results:
if self.select_related_idx:
Expand Down Expand Up @@ -543,7 +542,9 @@ def _make_prefetch_queries(self) -> None:
relation_field = self.model._meta.fields_map[field_name]
related_model: "Type[Model]" = relation_field.related_model # type: ignore
related_query = related_model.all().using_db(self.db)
related_query.query = copy(related_query.model._meta.basequery)
related_query.query = copy(
related_query.model._meta.basequery
) # type:ignore[assignment]
if forwarded_prefetches:
related_query = related_query.prefetch_related(*forwarded_prefetches)
self._prefetch_queries.setdefault(field_name, []).append((to_attr, related_query))
Expand Down
17 changes: 6 additions & 11 deletions tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Optional, Sequence
from typing import Optional, Sequence, cast

from pypika import Parameter
from pypika.dialects import PostgreSQLQueryBuilder
Expand All @@ -23,7 +23,7 @@
)


def postgres_search(field: Term, value: Term):
def postgres_search(field: Term, value: Term) -> SearchCriterion:
return SearchCriterion(field, expr=value)


Expand All @@ -44,15 +44,10 @@ def parameter(self, pos: int) -> Parameter:
def _prepare_insert_statement(
self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False
) -> PostgreSQLQueryBuilder:
query = (
self.db.query_class.into(self.model._meta.basetable)
.columns(*columns)
.insert(*[self.parameter(i) for i in range(len(columns))])
)
if has_generated:
generated_fields = self.model._meta.generated_db_fields
if generated_fields:
query = query.returning(*generated_fields)
builder = cast(PostgreSQLQueryBuilder, self.db.query_class.into(self.model._meta.basetable))
query = builder.columns(*columns).insert(*[self.parameter(i) for i in range(len(columns))])
if has_generated and (generated_fields := self.model._meta.generated_db_fields):
query = query.returning(*generated_fields)
if ignore_conflicts:
query = query.on_conflict().do_nothing()
return query
Expand Down
10 changes: 5 additions & 5 deletions tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
)


class StrWrapper(ValueWrapper): # type: ignore
class StrWrapper(ValueWrapper):
"""
Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL
"""

def get_value_sql(self, **kwargs):
def get_value_sql(self, **kwargs) -> str:
quote_char = kwargs.get("secondary_quote_char") or ""
value = self.value.replace(quote_char, quote_char * 2)
return format_quotes(value, quote_char)
Expand Down Expand Up @@ -92,12 +92,12 @@ def mysql_insensitive_ends_with(field: Term, value: str) -> Criterion:
)


def mysql_search(field: Term, value: str):
def mysql_search(field: Term, value: str) -> SearchCriterion:
return SearchCriterion(field, expr=StrWrapper(value))


def mysql_posix_regex(field: Term, value: str):
return BasicCriterion(" REGEXP ", field, StrWrapper(value))
def mysql_posix_regex(field: Term, value: str) -> BasicCriterion:
return BasicCriterion(" REGEXP ", field, StrWrapper(value)) # type:ignore[arg-type]


class MySQLExecutor(BaseExecutor):
Expand Down
6 changes: 3 additions & 3 deletions tortoise/contrib/mysql/functions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Union
from __future__ import annotations

from pypika.terms import Function, Parameter


class Rand(Function): # type: ignore
class Rand(Function):
"""
Generate random number, with optional seed.

:samp:`Rand()`
"""

def __init__(self, seed: Union[int, None] = None, alias=None) -> None:
def __init__(self, seed: int | None = None, alias=None) -> None:
super().__init__("RAND", seed, alias=alias)
self.args = [self.wrap_constant(seed) if seed is not None else Parameter("")]
38 changes: 21 additions & 17 deletions tortoise/contrib/mysql/json_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import operator
from typing import Any, Dict, List
Expand All @@ -10,24 +12,25 @@
from tortoise.filters import not_equal


class JSONContains(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, target_list: Term):
super(JSONContains, self).__init__("JSON_CONTAINS", column_name, target_list)
class JSONContains(PypikaFunction):
def __init__(self, column_name: Term, target_list: Term) -> None:
super().__init__("JSON_CONTAINS", column_name, target_list)


class JSONExtract(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, query_list: List[Term]):
class JSONExtract(PypikaFunction):
def __init__(self, column_name: Term, query_list: List[int | str | Term]) -> None:
query = self.make_query(query_list)
super(JSONExtract, self).__init__("JSON_EXTRACT", column_name, query)
super().__init__("JSON_EXTRACT", column_name, query)

@classmethod
def serialize_value(cls, value: Any):
def serialize_value(cls, value: Any) -> str:
if isinstance(value, int):
return f"[{value}]"
if isinstance(value, str):
return f".{value}"
return str(value)

def make_query(self, query_list: List[Term]):
def make_query(self, query_list: List[Term | int | str]) -> str:
query = ["$"]
for value in query_list:
query.append(self.serialize_value(value))
Expand All @@ -39,7 +42,7 @@ def mysql_json_contains(field: Term, value: str) -> Criterion:
return JSONContains(field, ValueWrapper(value))


def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
def mysql_json_contained_by(field: Term, value_str: str) -> JSONContains | None:
values = json.loads(value_str)
contained_by = None
for value in values:
Expand All @@ -50,14 +53,14 @@ def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
return contained_by


def _mysql_json_is_null(left: Term, is_null: bool):
def _mysql_json_is_null(left: Term, is_null: bool) -> Criterion:
if is_null:
return operator.eq(left, Cast("null", "JSON"))
else:
return not_equal(left, Cast("null", "JSON"))


def _mysql_json_not_is_null(left: Term, is_null: bool):
def _mysql_json_not_is_null(left: Term, is_null: bool) -> Criterion:
return _mysql_json_is_null(left, not is_null)


Expand All @@ -68,7 +71,7 @@ def _mysql_json_not_is_null(left: Term, is_null: bool):
}


def _serialize_value(value: Any):
def _serialize_value(value: Any) -> str | Any:
if type(value) in [dict, list]:
return json.dumps(value)
return value
Expand All @@ -78,8 +81,9 @@ def mysql_json_filter(field: Term, value: Dict) -> Criterion:
((key, filter_value),) = value.items()
filter_value = _serialize_value(filter_value)
key_parts = [int(item) if item.isdigit() else str(item) for item in key.split("__")]
operator_ = operator.eq
if key_parts[-1] in operator_keywords:
operator_ = operator_keywords[str(key_parts.pop(-1))] # type: ignore

return operator_(JSONExtract(field, key_parts), filter_value)
operator_ = (
operator_keywords[str(key_parts.pop(-1))]
if key_parts[-1] in operator_keywords
else operator.eq
)
return operator_(JSONExtract(field, key_parts), filter_value) # type:ignore[arg-type]
14 changes: 7 additions & 7 deletions tortoise/contrib/mysql/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pypika.terms import Term


class Comp(Comparator): # type: ignore
class Comp(Comparator):
search = " "


Expand All @@ -18,13 +18,13 @@ class Mode(Enum):
WITH_QUERY_EXPRESSION = "WITH QUERY EXPANSION"


class Match(PypikaFunction): # type: ignore
def __init__(self, *columns: Term):
class Match(PypikaFunction):
def __init__(self, *columns: Term) -> None:
super(Match, self).__init__("MATCH", *columns)


class Against(PypikaFunction): # type: ignore
def __init__(self, expr: Term, mode: Optional[Mode] = None):
class Against(PypikaFunction):
def __init__(self, expr: Term, mode: Optional[Mode] = None) -> None:
super(Against, self).__init__("AGAINST", expr)
self.mode = mode

Expand All @@ -34,10 +34,10 @@ def get_special_params_sql(self, **kwargs: Any) -> Any:
return self.mode.value


class SearchCriterion(BasicCriterion): # type: ignore
class SearchCriterion(BasicCriterion):
"""
Only support for CharField, TextField with full search indexes.
"""

def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None):
def __init__(self, *columns: Term, expr: Term, mode: Optional[Mode] = None) -> None:
super().__init__(Comp.search, Match(*columns), Against(expr, mode))
20 changes: 10 additions & 10 deletions tortoise/contrib/postgres/functions.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
from pypika.terms import Function, Term


class ToTsVector(Function): # type: ignore
class ToTsVector(Function):
"""
to to_tsvector function
"""

def __init__(self, field: Term):
super(ToTsVector, self).__init__("TO_TSVECTOR", field)
def __init__(self, field: Term) -> None:
super().__init__("TO_TSVECTOR", field)


class ToTsQuery(Function): # type: ignore
class ToTsQuery(Function):
"""
to_tsquery function
"""

def __init__(self, field: Term):
super(ToTsQuery, self).__init__("TO_TSQUERY", field)
def __init__(self, field: Term) -> None:
super().__init__("TO_TSQUERY", field)


class PlainToTsQuery(Function): # type: ignore
class PlainToTsQuery(Function):
"""
plainto_tsquery function
"""

def __init__(self, field: Term):
super(PlainToTsQuery, self).__init__("PLAINTO_TSQUERY", field)
def __init__(self, field: Term) -> None:
super().__init__("PLAINTO_TSQUERY", field)


class Random(Function): # type: ignore
class Random(Function):
"""
Generate random number.

Expand Down
Loading