From 5a66aa60bc189073c48f06aa29c9382168f04509 Mon Sep 17 00:00:00 2001 From: Vladimir Ulupov Date: Thu, 4 Feb 2021 14:58:55 +0300 Subject: [PATCH] fix save with F expression and field with source_field (#630) * fix save with F expression and field with source_field * fix style for executor.py file * add some tests for "fix save with F expression and field with source_field" --- tests/test_source_field.py | 17 ++++++++++++++++- tests/testmodels.py | 4 ++++ tortoise/backends/base/executor.py | 8 ++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/test_source_field.py b/tests/test_source_field.py index 761606a47..7825038a6 100644 --- a/tests/test_source_field.py +++ b/tests/test_source_field.py @@ -4,7 +4,7 @@ This is to test that behaviour doesn't change when one defined source_field parameters. """ -from tests.testmodels import SourceFields, StraightFields +from tests.testmodels import NumberSourceField, SourceFields, StraightFields from tortoise.contrib import test from tortoise.expressions import F from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper @@ -259,3 +259,18 @@ async def test_values_by_fk(self): class SourceFieldTests(StraightFieldTests): def setUp(self) -> None: self.model = SourceFields # type: ignore + + +class NumberSourceFieldTests(test.TestCase): + def setUp(self) -> None: + self.model = NumberSourceField + + async def test_f_expression_save(self): + obj1 = await self.model.create() + obj1.number = F("number") + 1 + await obj1.save() + + async def test_f_expression_save_update_fields(self): + obj1 = await self.model.create() + obj1.number = F("number") + 1 + await obj1.save(update_fields=["number"]) diff --git a/tests/testmodels.py b/tests/testmodels.py index 74c6848c5..659180849 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -725,3 +725,7 @@ class ValidatorModel(Model): comma_separated_integer_list = fields.CharField( max_length=100, null=True, validators=[CommaSeparatedIntegerListValidator()] ) + + +class NumberSourceField(Model): + number = fields.IntField(source_field="counter", default=0) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 17099017d..cdd58e1ab 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -23,6 +23,7 @@ from pypika.terms import ArithmeticExpression, Function from tortoise.exceptions import OperationalError +from tortoise.expressions import F from tortoise.fields.base import Field from tortoise.fields.relational import ( BackwardFKRelation, @@ -267,11 +268,14 @@ def get_update_sql( db_column = self.model._meta.fields_db_projection[field] field_object = self.model._meta.fields_map[field] if not field_object.pk: - if db_column not in arithmetic_or_function.keys(): + if field not in arithmetic_or_function.keys(): query = query.set(db_column, self.parameter(count)) count += 1 else: - query = query.set(db_column, arithmetic_or_function.get(db_column)) + value = F.resolver_arithmetic_expression( + self.model, arithmetic_or_function.get(field) + )[0] + query = query.set(db_column, value) query = query.where(table[self.model._meta.db_pk_column] == self.parameter(count))