Skip to content

Commit

Permalink
Fix ambiguous column name when grouping with joining (#1766)
Browse files Browse the repository at this point in the history
  • Loading branch information
henadzit authored Nov 14, 2024
1 parent b75ca5f commit 905daaa
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 7 deletions.
73 changes: 67 additions & 6 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ async def asyncSetUp(self) -> None:
await super(TestGroupBy, self).asyncSetUp()
self.a1 = await Author.create(name="author1")
self.a2 = await Author.create(name="author2")
for i in range(10):
await Book.create(name=f"book{i}", author=self.a1, rating=i)
for i in range(5):
await Book.create(name=f"book{i}", author=self.a2, rating=i)
self.books1 = [
await Book.create(name=f"book{i}", author=self.a1, rating=i) for i in range(10)
]
self.books2 = [
await Book.create(name=f"book{i}", author=self.a2, rating=i) for i in range(5)
]

async def test_count_group_by(self):
ret = (
Expand Down Expand Up @@ -249,10 +251,69 @@ async def test_group_by_requiring_nested_joins(self):
await team_second.events.add(event_second)
await team_third.events.add(event_third)

res = (
ret = (
await Tournament.annotate(avg=Avg("events__participants__alias"))
.group_by("desc")
.order_by("desc")
.values("desc", "avg")
)
self.assertEqual(res, [{"avg": 3, "desc": "d1"}, {"avg": 5, "desc": "d2"}])
self.assertEqual(ret, [{"avg": 3, "desc": "d1"}, {"avg": 5, "desc": "d2"}])

async def test_group_by_ambigious_column(self):
tournament_first = await Tournament.create(name="Tournament 1")
tournament_second = await Tournament.create(name="Tournament 2")

await Event.create(name="1", tournament=tournament_first)
await Event.create(name="2", tournament=tournament_first)
await Event.create(name="3", tournament=tournament_second)

base_query = (
Tournament.annotate(event_count=Count("events")).group_by("name").order_by("name")
)
ret = await base_query.values("name", "event_count")
self.assertEqual(
ret,
[
{"event_count": 2, "name": "Tournament 1"},
{"event_count": 1, "name": "Tournament 2"},
],
)

ret = await base_query.values_list("name", "event_count")
self.assertEqual(
ret,
[("Tournament 1", 2), ("Tournament 2", 1)],
)

async def test_group_by_nested_column(self):
tournament_first = await Tournament.create(name="A")
tournament_second = await Tournament.create(name="B")

await Event.create(name="1", tournament=tournament_first)
await Event.create(name="2", tournament=tournament_first)
await Event.create(name="3", tournament=tournament_first)
await Event.create(name="4", tournament=tournament_second)

base_query = (
Event.annotate(count=Count("event_id"))
.group_by("tournament__name")
.order_by("-tournament__name")
)
ret = await base_query.values("tournament__name", "count")
self.assertEqual(
ret,
[
{"count": 1, "tournament__name": "B"},
{"count": 3, "tournament__name": "A"},
],
)

ret = await base_query.values_list("tournament__name", "count")
self.assertEqual(
ret,
[("B", 1), ("A", 3)],
)

async def test_group_by_id_with_nested_filter(self):
ret = await Book.filter(author__name="author1").group_by("id").values_list("id")
self.assertEqual(set(ret), {(book.id,) for book in self.books1})
11 changes: 11 additions & 0 deletions tests/test_order_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ async def test_order_by_related(self):
tournaments = await Tournament.all().order_by("events__name")
self.assertEqual([t.name for t in tournaments], ["2", "1"])

async def test_order_by_ambigious_field_name(self):
tournament_first = await Tournament.create(name="Tournament 1", desc="d1")
tournament_second = await Tournament.create(name="Tournament 2", desc="d2")

event_third = await Event.create(name="3", tournament=tournament_second)
event_second = await Event.create(name="2", tournament=tournament_first)
event_first = await Event.create(name="1", tournament=tournament_first)

res = await Event.all().order_by("tournament__name", "name")
self.assertEqual(res, [event_first, event_second, event_third])

async def test_order_by_related_reversed(self):
tournament_first = await Tournament.create(name="1")
tournament_second = await Tournament.create(name="2")
Expand Down
15 changes: 15 additions & 0 deletions tests/test_source_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,21 @@ async def test_filter_with_field_f_annotation(self):
)
self.assertEqual(obj, ret_obj)

async def test_group_by(self):
await self.model.create(chars="aaa", blip="a")
await self.model.create(chars="aaa", blip="b")
await self.model.create(chars="bbb")

objs = (
await self.model.annotate(chars_count=Count("chars"))
.group_by("chars")
.order_by("chars")
.values("chars", "chars_count")
)
self.assertEqual(
objs, [{"chars": "aaa", "chars_count": 2}, {"chars": "bbb", "chars_count": 1}]
)


class SourceFieldTests(StraightFieldTests):
def setUp(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,9 @@ def _resolve_group_bys(self, *field_names: str):
field=field,
forwarded_fields=forwarded_fields,
)
field = related_table[related_db_field].as_(field_name)
field = related_table[related_db_field].as_(
f"{related_table.get_table_name()}__{field_name}"
)
group_bys.append(field)
return group_bys

Expand Down

0 comments on commit 905daaa

Please sign in to comment.