Skip to content

Commit

Permalink
Add support for CompoundSelect (#1361)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo authored Nov 16, 2024
1 parent c702d83 commit c7fd9e4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
14 changes: 10 additions & 4 deletions fastapi_pagination/ext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from sqlalchemy import func, select, text
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Query, Session, noload, scoped_session
from sqlalchemy.sql import CompoundSelect, Select
from sqlalchemy.sql.elements import TextClause
from typing_extensions import Literal, TypeAlias, deprecated, no_type_check
from typing_extensions import Literal, TypeAlias, deprecated, get_args, no_type_check

from ..api import apply_items_transformer, create_page
from ..bases import AbstractPage, AbstractParams, is_cursor
Expand All @@ -29,7 +30,6 @@
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.sql import Select


try:
Expand Down Expand Up @@ -77,11 +77,17 @@ def __init__(self, *_: Any, **__: Any) -> None:
"unwrap", # always unwrap
]

Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]"
Selectable: TypeAlias = Union[Select, TextClause, FromStatement, CompoundSelect]


@no_type_check
def _should_unwrap_scalars(query: Selectable) -> bool:
if not isinstance(query, get_args(Selectable)):
return False

if isinstance(query, CompoundSelect):
return False

try:
cols_desc = query.column_descriptions
all_cols = query._all_selected_columns
Expand Down Expand Up @@ -157,7 +163,7 @@ def create_count_query(query: Selectable, *, use_subquery: bool = True) -> Selec
if use_subquery:
return select(func.count()).select_from(query.subquery())

return query.with_only_columns( # type: ignore[call-arg] # noqa: PIE804
return query.with_only_columns( # type: ignore[call-arg,union-attr] # noqa: PIE804
func.count(),
maintain_column_froms=True,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/ext/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def test_non_scalar_not_unwrapped(self, sa_session, sa_user, entities):
lambda sa_user: select(sa_user).from_statement(select(sa_user)),
lambda sa_user, item: isinstance(item, sa_user),
),
(
lambda sa_user: select(sa_user.id).union_all(select(sa_user.id)),
lambda sa_user, item: len(item) == 1,
),
],
)
def test_unwrap_raw_results(self, sa_session, sa_user, query, validate):
Expand Down

0 comments on commit c7fd9e4

Please sign in to comment.