From c7fd9e41ce35493191200581b71072026319777b Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Sat, 16 Nov 2024 11:26:39 +0100 Subject: [PATCH] Add support for CompoundSelect (#1361) --- fastapi_pagination/ext/sqlalchemy.py | 14 ++++++++++---- tests/ext/test_sqlalchemy.py | 4 ++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/fastapi_pagination/ext/sqlalchemy.py b/fastapi_pagination/ext/sqlalchemy.py index 984f221f..f0b0116e 100644 --- a/fastapi_pagination/ext/sqlalchemy.py +++ b/fastapi_pagination/ext/sqlalchemy.py @@ -17,8 +17,9 @@ from sqlalchemy import func, select, text from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import Query, Session, noload, scoped_session +from sqlalchemy.sql import CompoundSelect, Select from sqlalchemy.sql.elements import TextClause -from typing_extensions import Literal, TypeAlias, deprecated, no_type_check +from typing_extensions import Literal, TypeAlias, deprecated, get_args, no_type_check from ..api import apply_items_transformer, create_page from ..bases import AbstractPage, AbstractParams, is_cursor @@ -29,7 +30,6 @@ if TYPE_CHECKING: from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession - from sqlalchemy.sql import Select try: @@ -77,11 +77,17 @@ def __init__(self, *_: Any, **__: Any) -> None: "unwrap", # always unwrap ] -Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]" +Selectable: TypeAlias = Union[Select, TextClause, FromStatement, CompoundSelect] @no_type_check def _should_unwrap_scalars(query: Selectable) -> bool: + if not isinstance(query, get_args(Selectable)): + return False + + if isinstance(query, CompoundSelect): + return False + try: cols_desc = query.column_descriptions all_cols = query._all_selected_columns @@ -157,7 +163,7 @@ def create_count_query(query: Selectable, *, use_subquery: bool = True) -> Selec if use_subquery: return select(func.count()).select_from(query.subquery()) - return query.with_only_columns( # type: ignore[call-arg] # noqa: PIE804 + return query.with_only_columns( # type: ignore[call-arg,union-attr] # noqa: PIE804 func.count(), maintain_column_froms=True, ) diff --git a/tests/ext/test_sqlalchemy.py b/tests/ext/test_sqlalchemy.py index 99834ad3..a4754ed7 100644 --- a/tests/ext/test_sqlalchemy.py +++ b/tests/ext/test_sqlalchemy.py @@ -78,6 +78,10 @@ def test_non_scalar_not_unwrapped(self, sa_session, sa_user, entities): lambda sa_user: select(sa_user).from_statement(select(sa_user)), lambda sa_user, item: isinstance(item, sa_user), ), + ( + lambda sa_user: select(sa_user.id).union_all(select(sa_user.id)), + lambda sa_user, item: len(item) == 1, + ), ], ) def test_unwrap_raw_results(self, sa_session, sa_user, query, validate):