Skip to content

Commit

Permalink
Handle joins in arithmetic expressions (#1765)
Browse files Browse the repository at this point in the history
* Move query reuse tests to test_query_reuse.py

* Add test exposing issue

* Return joins from CombinedExpression.resolve
  • Loading branch information
henadzit authored Nov 14, 2024
1 parent 0ab03db commit b75ca5f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 100 deletions.
141 changes: 41 additions & 100 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Type

from tests.testmodels import (
Author,
Book,
Event,
IntFields,
MinRelation,
Expand All @@ -21,7 +23,7 @@
ParamsError,
)
from tortoise.expressions import F, RawSQL, Subquery
from tortoise.functions import Length
from tortoise.functions import Avg

# TODO: Test the many exceptions in QuerySet
# TODO: .filter(intnum_null=None) does not work as expected
Expand Down Expand Up @@ -771,122 +773,61 @@ async def test_annotation_field_priorior_to_model_field(self):
ret = await Tournament.filter(pk=t1.pk).annotate(id=RawSQL("id + 1")).values("id")
self.assertEqual(ret, [{"id": t1.pk + 1}])


class TestNotExist(test.TestCase):
exp_cls: Type[NotExistOrMultiple] = DoesNotExist

@test.requireCapability(dialect="sqlite")
def test_does_not_exist(self):
assert str(self.exp_cls("old format")) == "old format"
assert str(self.exp_cls(Tournament)) == self.exp_cls.TEMPLATE.format(Tournament.__name__)


class TestMultiple(TestNotExist):
exp_cls = MultipleObjectsReturned


class TestQueryReuse(test.TestCase):
async def test_annotations(self):
a = await Tournament.create(name="A")

base_query = Tournament.annotate(id_plus_one=F("id") + 1)
query1 = base_query.annotate(id_plus_two=F("id") + 2)
query2 = base_query.annotate(id_plus_three=F("id") + 3)
res = await query1.first()
self.assertEqual(res.id_plus_one, a.id + 1)
self.assertEqual(res.id_plus_two, a.id + 2)
with self.assertRaises(AttributeError):
getattr(res, "id_plus_three")

res = await query2.first()
self.assertEqual(res.id_plus_one, a.id + 1)
self.assertEqual(res.id_plus_three, a.id + 3)
with self.assertRaises(AttributeError):
getattr(res, "id_plus_two")

res = await query1.first()
with self.assertRaises(AttributeError):
getattr(res, "id_plus_three")

async def test_filters(self):
a = await Tournament.create(name="A")
b = await Tournament.create(name="B")
await Tournament.create(name="C")

base_query = Tournament.exclude(name="C")
tournaments = await base_query
self.assertSetEqual(set(tournaments), {a, b})

tournaments = await base_query.exclude(name="A")
self.assertSetEqual(set(tournaments), {b})

tournaments = await base_query.exclude(name="B")
self.assertSetEqual(set(tournaments), {a})

async def test_joins(self):
tournament_a = await Tournament.create(name="A")
tournament_b = await Tournament.create(name="B")
tournament_c = await Tournament.create(name="C")
event_a = await Event.create(name="A", tournament=tournament_a)
event_b = await Event.create(name="B", tournament=tournament_b)
await Event.create(name="C", tournament=tournament_c)

base_query = Event.exclude(tournament__name="C")
events = await base_query
self.assertSetEqual(set(events), {event_a, event_b})

events = await base_query.exclude(name="A")
self.assertSetEqual(set(events), {event_b})

events = await base_query.exclude(name="B")
self.assertSetEqual(set(events), {event_a})

async def test_order_by(self):
a = await Tournament.create(name="A")
b = await Tournament.create(name="B")

base_query = Tournament.all().order_by("name")
tournaments = await base_query
self.assertEqual(tournaments, [a, b])

tournaments = await base_query.order_by("-name")
self.assertEqual(tournaments, [b, a])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_values_with_annotations(self):
await Tournament.create(name="Championship")
await Tournament.create(name="Super Bowl")

base_query = Tournament.annotate(name_length=Length("name"))
tournaments = await base_query.values_list("name")
self.assertListSortEqual(tournaments, [("Championship",), ("Super Bowl",)])

tournaments = await base_query.values_list("name_length")
self.assertListSortEqual(tournaments, [(10,), (12,)])

async def test_f_annotation_referenced_in_annotation(self):
await IntFields.create(intnum=1)
instance = await IntFields.create(intnum=1)

events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).annotate(
intnum_plus_2=F("intnum_plus_1") + 1
events = (
await IntFields.filter(id=instance.id)
.annotate(intnum_plus_1=F("intnum") + 1)
.annotate(intnum_plus_2=F("intnum_plus_1") + 1)
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)

# in a single annotate call
events = await IntFields.annotate(
events = await IntFields.filter(id=instance.id).annotate(
intnum_plus_1=F("intnum") + 1, intnum_plus_2=F("intnum_plus_1") + 1
)
self.assertEqual(len(events), 1)
self.assertEqual(events[0].intnum_plus_1, 2)
self.assertEqual(events[0].intnum_plus_2, 3)

async def test_rawsql_annotation_referenced_in_annotation(self):
await IntFields.create(intnum=1)
instance = await IntFields.create(intnum=1)

events = await IntFields.annotate(ten=RawSQL("20 / 2")).annotate(ten_plus_1=F("ten") + 1)
events = (
await IntFields.filter(id=instance.id)
.annotate(ten=RawSQL("20 / 2"))
.annotate(ten_plus_1=F("ten") + 1)
)

self.assertEqual(len(events), 1)
self.assertEqual(events[0].ten, 10)
self.assertEqual(events[0].ten_plus_1, 11)

async def test_joins_in_arithmetic_expressions(self):
author = await Author.create(name="1")
await Book.create(name="1", author=author, rating=1)
await Book.create(name="2", author=author, rating=5)

ret = await Author.annotate(rating=Avg(F("books__rating") + 1))
self.assertEqual(len(ret), 1)
self.assertEqual(ret[0].rating, 4.0)

ret = await Author.annotate(rating=Avg(F("books__rating") * 2 - F("books__rating")))
self.assertEqual(len(ret), 1)
self.assertEqual(ret[0].rating, 3.0)


class TestNotExist(test.TestCase):
exp_cls: Type[NotExistOrMultiple] = DoesNotExist

@test.requireCapability(dialect="sqlite")
def test_does_not_exist(self):
assert str(self.exp_cls("old format")) == "old format"
assert str(self.exp_cls(Tournament)) == self.exp_cls.TEMPLATE.format(Tournament.__name__)


class TestMultiple(TestNotExist):
exp_cls = MultipleObjectsReturned
88 changes: 88 additions & 0 deletions tests/test_queryset_reuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from tests.testmodels import (
Event,
Tournament,
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import NotEQ
from tortoise.expressions import F
from tortoise.functions import Length


class TestQueryReuse(test.TestCase):
async def test_annotations(self):
a = await Tournament.create(name="A")

base_query = Tournament.annotate(id_plus_one=F("id") + 1)
query1 = base_query.annotate(id_plus_two=F("id") + 2)
query2 = base_query.annotate(id_plus_three=F("id") + 3)
res = await query1.first()
self.assertEqual(res.id_plus_one, a.id + 1)
self.assertEqual(res.id_plus_two, a.id + 2)
with self.assertRaises(AttributeError):
getattr(res, "id_plus_three")

res = await query2.first()
self.assertEqual(res.id_plus_one, a.id + 1)
self.assertEqual(res.id_plus_three, a.id + 3)
with self.assertRaises(AttributeError):
getattr(res, "id_plus_two")

res = await query1.first()
with self.assertRaises(AttributeError):
getattr(res, "id_plus_three")

async def test_filters(self):
a = await Tournament.create(name="A")
b = await Tournament.create(name="B")
await Tournament.create(name="C")

base_query = Tournament.exclude(name="C")
tournaments = await base_query
self.assertSetEqual(set(tournaments), {a, b})

tournaments = await base_query.exclude(name="A")
self.assertSetEqual(set(tournaments), {b})

tournaments = await base_query.exclude(name="B")
self.assertSetEqual(set(tournaments), {a})

async def test_joins(self):
tournament_a = await Tournament.create(name="A")
tournament_b = await Tournament.create(name="B")
tournament_c = await Tournament.create(name="C")
event_a = await Event.create(name="A", tournament=tournament_a)
event_b = await Event.create(name="B", tournament=tournament_b)
await Event.create(name="C", tournament=tournament_c)

base_query = Event.exclude(tournament__name="C")
events = await base_query
self.assertSetEqual(set(events), {event_a, event_b})

events = await base_query.exclude(name="A")
self.assertSetEqual(set(events), {event_b})

events = await base_query.exclude(name="B")
self.assertSetEqual(set(events), {event_a})

async def test_order_by(self):
a = await Tournament.create(name="A")
b = await Tournament.create(name="B")

base_query = Tournament.all().order_by("name")
tournaments = await base_query
self.assertEqual(tournaments, [a, b])

tournaments = await base_query.order_by("-name")
self.assertEqual(tournaments, [b, a])

@test.requireCapability(dialect=NotEQ("mssql"))
async def test_values_with_annotations(self):
await Tournament.create(name="Championship")
await Tournament.create(name="Super Bowl")

base_query = Tournament.annotate(name_length=Length("name"))
tournaments = await base_query.values_list("name")
self.assertListSortEqual(tournaments, [("Championship",), ("Super Bowl",)])

tournaments = await base_query.values_list("name_length")
self.assertListSortEqual(tournaments, [(10,), (12,)])
1 change: 1 addition & 0 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
operator_func = getattr(operator, self.connector.name)
return ResolveResult(
term=operator_func(left.term, right.term),
joins=list(set(left.joins + right.joins)), # dedup joins
output_field=right.output_field or left.output_field, # type: ignore
)

Expand Down

0 comments on commit b75ca5f

Please sign in to comment.