From 0c0540580875645fba662dcd21b8fbbf0f24bd98 Mon Sep 17 00:00:00 2001 From: He Date: Fri, 1 Nov 2024 10:07:27 +0100 Subject: [PATCH] Do not call validate in to_python_value (#1750) --- tests/fields/subclass_fields.py | 4 --- tests/test_aggregation.py | 46 ++++++++++++++++++++++++++-- tests/test_filtering.py | 2 +- tests/test_validators.py | 28 +++++++++++++++++ tortoise/backends/mssql/executor.py | 1 + tortoise/backends/sqlite/executor.py | 4 +++ tortoise/expressions.py | 2 +- tortoise/fields/base.py | 1 - tortoise/fields/data.py | 10 ------ 9 files changed, 79 insertions(+), 19 deletions(-) diff --git a/tests/fields/subclass_fields.py b/tests/fields/subclass_fields.py index 8c8f6e6b6..265127737 100644 --- a/tests/fields/subclass_fields.py +++ b/tests/fields/subclass_fields.py @@ -31,8 +31,6 @@ def to_db_value(self, value, instance): return value.value def to_python_value(self, value): - self.validate(value) - if value is None or isinstance(value, self.enum_type): return value @@ -67,8 +65,6 @@ def to_db_value(self, value: Any, instance) -> Any: return value.value def to_python_value(self, value: Any) -> Any: - self.validate(value) - if value is None or isinstance(value, self.enum_type): return value diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index c4c932898..45be84cb5 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -1,8 +1,18 @@ -from tests.testmodels import Author, Book, Event, MinRelation, Team, Tournament +from decimal import Decimal + +from tests.testmodels import ( + Author, + Book, + Event, + MinRelation, + Team, + Tournament, + ValidatorModel, +) from tortoise.contrib import test from tortoise.contrib.test.condition import In from tortoise.exceptions import ConfigurationError -from tortoise.expressions import Q +from tortoise.expressions import F, Q from tortoise.functions import Avg, Coalesce, Concat, Count, Lower, Max, Min, Sum, Trim @@ -243,3 +253,35 @@ async def test_count_without_matching(self) -> None: query = Tournament.annotate(events_count=Count("events")).filter(events_count__gt=0).count() result = await query assert result == 0 + + async def test_int_sum_on_models_with_validators(self) -> None: + await ValidatorModel.create(max_value=2) + await ValidatorModel.create(max_value=2) + + query = ValidatorModel.annotate(sum=Sum("max_value")).values("sum") + result = await query + self.assertEqual(result, [{"sum": 4}]) + + async def test_int_sum_math_on_models_with_validators(self) -> None: + await ValidatorModel.create(max_value=4) + await ValidatorModel.create(max_value=4) + + query = ValidatorModel.annotate(sum=Sum(F("max_value") * F("max_value"))).values("sum") + result = await query + self.assertEqual(result, [{"sum": 32}]) + + async def test_decimal_sum_on_models_with_validators(self) -> None: + await ValidatorModel.create(min_value_decimal=2.0) + + query = ValidatorModel.annotate(sum=Sum("min_value_decimal")).values("sum") + result = await query + self.assertEqual(result, [{"sum": Decimal("2.0")}]) + + async def test_decimal_sum_with_math_on_models_with_validators(self) -> None: + await ValidatorModel.create(min_value_decimal=2.0) + + query = ValidatorModel.annotate( + sum=Sum(F("min_value_decimal") - F("min_value_decimal") * F("min_value_decimal")) + ).values("sum") + result = await query + self.assertEqual(result, [{"sum": Decimal("-2.0")}]) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index b998368b3..fff888207 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -10,7 +10,7 @@ ) from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ -from tortoise.expressions import F, Q, Case, When +from tortoise.expressions import Case, F, Q, When from tortoise.functions import Coalesce, Count, Length, Lower, Max, Trim, Upper diff --git a/tests/test_validators.py b/tests/test_validators.py index c45dda3b5..51703a71c 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -52,3 +52,31 @@ async def test_validator_comma_separated_integer_list(self): with self.assertRaises(ValidationError): await ValidatorModel.create(comma_separated_integer_list="aaaaaa") await ValidatorModel.create(comma_separated_integer_list="1,2,3") + + async def test__prevent_saving(self): + with self.assertRaises(ValidationError): + await ValidatorModel.create(min_value_decimal=Decimal("0.9")) + + self.assertEqual(await ValidatorModel.all().count(), 0) + + async def test_save(self): + with self.assertRaises(ValidationError): + record = ValidatorModel(min_value_decimal=Decimal("0.9")) + await record.save() + + record.min_value_decimal = Decimal("1.5") + await record.save() + + async def test_save_with_update_fields(self): + record = await ValidatorModel.create(min_value_decimal=Decimal("2")) + + record.min_value_decimal = Decimal("0.9") + with self.assertRaises(ValidationError): + await record.save(update_fields=["min_value_decimal"]) + + async def test_update(self): + record = await ValidatorModel.create(min_value_decimal=Decimal("2")) + + record.min_value_decimal = Decimal("0.9") + with self.assertRaises(ValidationError): + await record.save() diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index bc14c9cfa..3b18ff9f1 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -11,6 +11,7 @@ def to_db_bool( self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] ) -> Optional[int]: + self.validate(value) if value is None: return None return int(bool(value)) diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index c1af22317..86236d2b8 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -21,6 +21,7 @@ def to_db_bool( self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] ) -> Optional[int]: + self.validate(value) if value is None: return None return int(bool(value)) @@ -31,6 +32,7 @@ def to_db_decimal( value: Optional[Union[str, float, int, Decimal]], instance: Union[Type[Model], Model], ) -> Optional[str]: + self.validate(value) if value is None: return None return str(Decimal(value).quantize(self.quant).normalize()) @@ -39,6 +41,7 @@ def to_db_decimal( def to_db_datetime( self: DatetimeField, value: Optional[datetime.datetime], instance: Union[Type[Model], Model] ) -> Optional[str]: + self.validate(value) # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( self.auto_now @@ -58,6 +61,7 @@ def to_db_datetime( def to_db_time( self: TimeField, value: Optional[datetime.time], instance: Union[Type[Model], Model] ) -> Optional[str]: + self.validate(value) if hasattr(instance, "_saved_in_db") and ( self.auto_now or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None) diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 2adae5300..58783355e 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass import operator +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 1887e63ae..67972a179 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -269,7 +269,6 @@ def to_python_value(self, value: Any) -> Any: """ if value is not None and not isinstance(value, self.field_type): value = self.field_type(value) # pylint: disable=E1102 - self.validate(value) return value def validate(self, value: Any): diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index a197c6f06..10846fb29 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -286,7 +286,6 @@ def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: def to_python_value(self, value: Any) -> Optional[Decimal]: if value is not None: value = Decimal(value).quantize(self.quant).normalize() - self.validate(value) return value @property @@ -355,7 +354,6 @@ def to_python_value(self, value: Any) -> Optional[datetime.datetime]: value = timezone.make_aware(value, get_timezone()) else: value = localtime(value) - self.validate(value) return value def to_db_value( @@ -406,7 +404,6 @@ class DateField(Field[datetime.date], datetime.date): def to_python_value(self, value: Any) -> Optional[datetime.date]: if value is not None and not isinstance(value, datetime.date): value = parse_datetime(value).date() - self.validate(value) return value def to_db_value( @@ -444,7 +441,6 @@ def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime. return value if timezone.is_naive(value): value = value.replace(tzinfo=get_default_timezone()) - self.validate(value) return value def to_db_value( @@ -493,8 +489,6 @@ class _db_oracle: SQL_TYPE = "NUMBER(19)" def to_python_value(self, value: Any) -> Optional[datetime.timedelta]: - self.validate(value) - if value is None or isinstance(value, datetime.timedelta): return value return datetime.timedelta(microseconds=value) @@ -589,7 +583,6 @@ def to_python_value( f"Value {value if isinstance(value, str) else value.decode()} is invalid json value." ) - self.validate(value) return value @@ -671,7 +664,6 @@ def __init__( def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]: value = self.enum_type(value) if value is not None else None - self.validate(value) return value def to_db_value( @@ -736,8 +728,6 @@ def __init__( self.enum_type = enum_type def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]: - self.validate(value) - return self.enum_type(value) if value is not None else None def to_db_value(