diff --git a/fastapi_pagination/ext/sqlalchemy.py b/fastapi_pagination/ext/sqlalchemy.py index 4ec5a292..e53b7c63 100644 --- a/fastapi_pagination/ext/sqlalchemy.py +++ b/fastapi_pagination/ext/sqlalchemy.py @@ -12,13 +12,13 @@ import warnings from contextlib import suppress -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, TypeVar, Union, overload from sqlalchemy import func, select, text from sqlalchemy.exc import InvalidRequestError from sqlalchemy.orm import Query, Session, noload, scoped_session from sqlalchemy.sql.elements import TextClause -from typing_extensions import TypeAlias, deprecated, no_type_check +from typing_extensions import Literal, TypeAlias, deprecated, no_type_check from ..api import apply_items_transformer, create_page from ..bases import AbstractPage, AbstractParams, is_cursor @@ -70,6 +70,13 @@ def __init__(self, *_: Any, **__: Any) -> None: AsyncConn: TypeAlias = "Union[AsyncSession, AsyncConnection, async_scoped_session]" SyncConn: TypeAlias = "Union[Session, Connection, scoped_session]" +UnwrapMode: TypeAlias = Literal[ + "auto", # default, unwrap only if select is select(model) + "legacy", # legacy mode, unwrap only when there is one column in select + "no-unwrap", # never unwrap + "unwrap", # always unwrap +] + Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]" @@ -152,6 +159,33 @@ def _maybe_unique(result: Any, unique: bool) -> Any: raise +_TSeq = TypeVar("_TSeq", bound=Sequence[Any]) + + +def _unwrap_items( + items: _TSeq, + query: Selectable, + unwrap_mode: Optional[UnwrapMode] = None, +) -> _TSeq: + # for raw queries we will use legacy mode by default + # because we can't determine if we should unwrap or not + if isinstance(query, (TextClause, FromStatement)): # noqa: SIM108 + unwrap_mode = unwrap_mode or "legacy" + else: + unwrap_mode = unwrap_mode or "auto" + + if unwrap_mode == "legacy": + items = unwrap_scalars(items) + elif unwrap_mode == "no-unwrap": + pass + elif unwrap_mode == "unwrap": + items = unwrap_scalars(items, force_unwrap=True) + elif unwrap_mode == "auto" and _should_unwrap_scalars(query): + items = unwrap_scalars(items, force_unwrap=True) + + return items + + def exec_pagination( query: Selectable, count_query: Optional[Selectable], @@ -162,6 +196,7 @@ def exec_pagination( subquery_count: bool = True, unique: bool = True, async_: bool = False, + unwrap_mode: Optional[UnwrapMode] = None, ) -> AbstractPage[Any]: raw_params = params.to_raw_params() @@ -192,8 +227,7 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: page=raw_params.cursor, # type: ignore[arg-type] ) items = [*page] - if _should_unwrap_scalars(query): - items = unwrap_scalars(items) + items = _unwrap_items(items, query, unwrap_mode) items = _apply_items_transformer(items, transformer) return create_page( @@ -209,8 +243,7 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: query = create_paginate_query(query, params) items = _maybe_unique(conn.execute(query), unique) - if _should_unwrap_scalars(query): - items = unwrap_scalars(items) + items = _unwrap_items(items, query, unwrap_mode) items = _apply_items_transformer(items, transformer) return create_page( @@ -241,6 +274,7 @@ def paginate( params: Optional[AbstractParams] = None, *, subquery_count: bool = True, + unwrap_mode: Optional[UnwrapMode] = None, transformer: Optional[SyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, unique: bool = True, @@ -256,6 +290,7 @@ def paginate( *, count_query: Optional[Selectable] = None, subquery_count: bool = True, + unwrap_mode: Optional[UnwrapMode] = None, transformer: Optional[SyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, unique: bool = True, @@ -271,6 +306,7 @@ async def paginate( *, count_query: Optional[Selectable] = None, subquery_count: bool = True, + unwrap_mode: Optional[UnwrapMode] = None, transformer: Optional[AsyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, unique: bool = True, @@ -282,12 +318,12 @@ def paginate(*args: Any, **kwargs: Any) -> Any: try: assert args assert isinstance(args[0], Query) - query, count_query, conn, params, transformer, additional_data, unique, subquery_count = _old_paginate_sign( - *args, **kwargs + query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode = ( + _old_paginate_sign(*args, **kwargs) ) except (TypeError, AssertionError): - query, count_query, conn, params, transformer, additional_data, unique, subquery_count = _new_paginate_sign( - *args, **kwargs + query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode = ( + _new_paginate_sign(*args, **kwargs) ) params, raw_params = verify_params(params, "limit-offset", "cursor") @@ -307,6 +343,7 @@ def paginate(*args: Any, **kwargs: Any) -> Any: additional_data, subquery_count, unique, + unwrap_mode=unwrap_mode, async_=True, ) @@ -319,6 +356,7 @@ def paginate(*args: Any, **kwargs: Any) -> Any: additional_data, subquery_count, unique, + unwrap_mode=unwrap_mode, async_=False, ) @@ -328,6 +366,7 @@ def _old_paginate_sign( params: Optional[AbstractParams] = None, *, subquery_count: bool = True, + unwrap_mode: Optional[UnwrapMode] = None, transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, unique: bool = True, @@ -340,6 +379,7 @@ def _old_paginate_sign( AdditionalData, bool, bool, + Optional[UnwrapMode], ]: if query.session is None: raise ValueError("query.session is None") @@ -356,7 +396,7 @@ def _old_paginate_sign( with suppress(AttributeError): query = query._statement_20() # type: ignore[attr-defined] - return query, None, session, params, transformer, additional_data, unique, subquery_count # type: ignore + return query, None, session, params, transformer, additional_data, unique, subquery_count, unwrap_mode # type: ignore def _new_paginate_sign( @@ -365,6 +405,7 @@ def _new_paginate_sign( params: Optional[AbstractParams] = None, *, subquery_count: bool = True, + unwrap_mode: Optional[UnwrapMode] = None, count_query: Optional[Selectable] = None, transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, @@ -378,5 +419,6 @@ def _new_paginate_sign( AdditionalData, bool, bool, + Optional[UnwrapMode], ]: - return query, count_query, conn, params, transformer, additional_data, unique, subquery_count + return query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode diff --git a/fastapi_pagination/ext/utils.py b/fastapi_pagination/ext/utils.py index 8e6520fc..85041058 100644 --- a/fastapi_pagination/ext/utils.py +++ b/fastapi_pagination/ext/utils.py @@ -21,8 +21,12 @@ def len_or_none(obj: Any) -> Optional[int]: @no_type_check -def unwrap_scalars(items: Sequence[Sequence[T]]) -> Union[Sequence[T], Sequence[Sequence[T]]]: - return [item[0] if len_or_none(item) == 1 else item for item in items] +def unwrap_scalars( + items: Sequence[Sequence[T]], + *, + force_unwrap: bool = False, +) -> Union[Sequence[T], Sequence[Sequence[T]]]: + return [item[0] if force_unwrap or len_or_none(item) == 1 else item for item in items] @no_type_check diff --git a/tests/ext/test_sqlalchemy.py b/tests/ext/test_sqlalchemy.py index 20e5ad8d..99834ad3 100644 --- a/tests/ext/test_sqlalchemy.py +++ b/tests/ext/test_sqlalchemy.py @@ -88,3 +88,66 @@ def test_unwrap_raw_results(self, sa_session, sa_user, query, validate): assert page.items assert validate(sa_user, page.items[0]) + + @mark.parametrize( + ("unwrap_mode", "validate"), + [ + (None, lambda item, sa_user: isinstance(item, sa_user)), + ("auto", lambda item, sa_user: isinstance(item, sa_user)), + ("legacy", lambda item, sa_user: isinstance(item, sa_user)), + ("unwrap", lambda item, sa_user: isinstance(item, sa_user)), + ("no-unwrap", lambda item, sa_user: len(item) == 1 and isinstance(item[0], sa_user)), + ], + ) + def test_unwrap_mode_select_scalar_model(self, sa_session, sa_user, unwrap_mode, validate): + with closing(sa_session()) as session, set_page(Page[Any]): + page = paginate( + session, + select(sa_user), + params=Params(page=1, size=1), + unwrap_mode=unwrap_mode, + ) + + assert validate(page.items[0], sa_user) + + @mark.parametrize( + ("unwrap_mode", "validate"), + [ + (None, lambda item, sa_user: len(item) == 1), + ("auto", lambda item, sa_user: len(item) == 1), + ("legacy", lambda item, sa_user: isinstance(item, str)), + ("unwrap", lambda item, sa_user: isinstance(item, str)), + ("no-unwrap", lambda item, sa_user: len(item) == 1), + ], + ) + def test_unwrap_mode_select_scalar_column(self, sa_session, sa_user, unwrap_mode, validate): + with closing(sa_session()) as session, set_page(Page[Any]): + page = paginate( + session, + select(sa_user.name), + params=Params(page=1, size=1), + unwrap_mode=unwrap_mode, + ) + + assert validate(page.items[0], sa_user) + + @mark.parametrize( + ("unwrap_mode", "validate"), + [ + (None, lambda item, sa_user: len(item) == 2), + ("auto", lambda item, sa_user: len(item) == 2), + ("legacy", lambda item, sa_user: len(item) == 2), + ("unwrap", lambda item, sa_user: isinstance(item, sa_user)), + ("no-unwrap", lambda item, sa_user: len(item) == 2), + ], + ) + def test_unwrap_mode_select_non_scalar(self, sa_session, sa_user, unwrap_mode, validate): + with closing(sa_session()) as session, set_page(Page[Any]): + page = paginate( + session, + select(sa_user, sa_user.name), + params=Params(page=1, size=1), + unwrap_mode=unwrap_mode, + ) + + assert validate(page.items[0], sa_user)