Skip to content

Commit

Permalink
Add unwrap mode for sqlalchemy ext (#1316)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo authored Oct 1, 2024
1 parent 40ddbe5 commit 954f1e8
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 14 deletions.
66 changes: 54 additions & 12 deletions fastapi_pagination/ext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@

import warnings
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, overload
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, TypeVar, Union, overload

from sqlalchemy import func, select, text
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Query, Session, noload, scoped_session
from sqlalchemy.sql.elements import TextClause
from typing_extensions import TypeAlias, deprecated, no_type_check
from typing_extensions import Literal, TypeAlias, deprecated, no_type_check

from ..api import apply_items_transformer, create_page
from ..bases import AbstractPage, AbstractParams, is_cursor
Expand Down Expand Up @@ -70,6 +70,13 @@ def __init__(self, *_: Any, **__: Any) -> None:
AsyncConn: TypeAlias = "Union[AsyncSession, AsyncConnection, async_scoped_session]"
SyncConn: TypeAlias = "Union[Session, Connection, scoped_session]"

UnwrapMode: TypeAlias = Literal[
"auto", # default, unwrap only if select is select(model)
"legacy", # legacy mode, unwrap only when there is one column in select
"no-unwrap", # never unwrap
"unwrap", # always unwrap
]

Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]"


Expand Down Expand Up @@ -152,6 +159,33 @@ def _maybe_unique(result: Any, unique: bool) -> Any:
raise


_TSeq = TypeVar("_TSeq", bound=Sequence[Any])


def _unwrap_items(
items: _TSeq,
query: Selectable,
unwrap_mode: Optional[UnwrapMode] = None,
) -> _TSeq:
# for raw queries we will use legacy mode by default
# because we can't determine if we should unwrap or not
if isinstance(query, (TextClause, FromStatement)): # noqa: SIM108
unwrap_mode = unwrap_mode or "legacy"
else:
unwrap_mode = unwrap_mode or "auto"

if unwrap_mode == "legacy":
items = unwrap_scalars(items)
elif unwrap_mode == "no-unwrap":
pass
elif unwrap_mode == "unwrap":
items = unwrap_scalars(items, force_unwrap=True)
elif unwrap_mode == "auto" and _should_unwrap_scalars(query):
items = unwrap_scalars(items, force_unwrap=True)

return items


def exec_pagination(
query: Selectable,
count_query: Optional[Selectable],
Expand All @@ -162,6 +196,7 @@ def exec_pagination(
subquery_count: bool = True,
unique: bool = True,
async_: bool = False,
unwrap_mode: Optional[UnwrapMode] = None,
) -> AbstractPage[Any]:
raw_params = params.to_raw_params()

Expand Down Expand Up @@ -192,8 +227,7 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:
page=raw_params.cursor, # type: ignore[arg-type]
)
items = [*page]
if _should_unwrap_scalars(query):
items = unwrap_scalars(items)
items = _unwrap_items(items, query, unwrap_mode)
items = _apply_items_transformer(items, transformer)

return create_page(
Expand All @@ -209,8 +243,7 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:

query = create_paginate_query(query, params)
items = _maybe_unique(conn.execute(query), unique)
if _should_unwrap_scalars(query):
items = unwrap_scalars(items)
items = _unwrap_items(items, query, unwrap_mode)
items = _apply_items_transformer(items, transformer)

return create_page(
Expand Down Expand Up @@ -241,6 +274,7 @@ def paginate(
params: Optional[AbstractParams] = None,
*,
subquery_count: bool = True,
unwrap_mode: Optional[UnwrapMode] = None,
transformer: Optional[SyncItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
unique: bool = True,
Expand All @@ -256,6 +290,7 @@ def paginate(
*,
count_query: Optional[Selectable] = None,
subquery_count: bool = True,
unwrap_mode: Optional[UnwrapMode] = None,
transformer: Optional[SyncItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
unique: bool = True,
Expand All @@ -271,6 +306,7 @@ async def paginate(
*,
count_query: Optional[Selectable] = None,
subquery_count: bool = True,
unwrap_mode: Optional[UnwrapMode] = None,
transformer: Optional[AsyncItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
unique: bool = True,
Expand All @@ -282,12 +318,12 @@ def paginate(*args: Any, **kwargs: Any) -> Any:
try:
assert args
assert isinstance(args[0], Query)
query, count_query, conn, params, transformer, additional_data, unique, subquery_count = _old_paginate_sign(
*args, **kwargs
query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode = (
_old_paginate_sign(*args, **kwargs)
)
except (TypeError, AssertionError):
query, count_query, conn, params, transformer, additional_data, unique, subquery_count = _new_paginate_sign(
*args, **kwargs
query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode = (
_new_paginate_sign(*args, **kwargs)
)

params, raw_params = verify_params(params, "limit-offset", "cursor")
Expand All @@ -307,6 +343,7 @@ def paginate(*args: Any, **kwargs: Any) -> Any:
additional_data,
subquery_count,
unique,
unwrap_mode=unwrap_mode,
async_=True,
)

Expand All @@ -319,6 +356,7 @@ def paginate(*args: Any, **kwargs: Any) -> Any:
additional_data,
subquery_count,
unique,
unwrap_mode=unwrap_mode,
async_=False,
)

Expand All @@ -328,6 +366,7 @@ def _old_paginate_sign(
params: Optional[AbstractParams] = None,
*,
subquery_count: bool = True,
unwrap_mode: Optional[UnwrapMode] = None,
transformer: Optional[ItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
unique: bool = True,
Expand All @@ -340,6 +379,7 @@ def _old_paginate_sign(
AdditionalData,
bool,
bool,
Optional[UnwrapMode],
]:
if query.session is None:
raise ValueError("query.session is None")
Expand All @@ -356,7 +396,7 @@ def _old_paginate_sign(
with suppress(AttributeError):
query = query._statement_20() # type: ignore[attr-defined]

return query, None, session, params, transformer, additional_data, unique, subquery_count # type: ignore
return query, None, session, params, transformer, additional_data, unique, subquery_count, unwrap_mode # type: ignore


def _new_paginate_sign(
Expand All @@ -365,6 +405,7 @@ def _new_paginate_sign(
params: Optional[AbstractParams] = None,
*,
subquery_count: bool = True,
unwrap_mode: Optional[UnwrapMode] = None,
count_query: Optional[Selectable] = None,
transformer: Optional[ItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
Expand All @@ -378,5 +419,6 @@ def _new_paginate_sign(
AdditionalData,
bool,
bool,
Optional[UnwrapMode],
]:
return query, count_query, conn, params, transformer, additional_data, unique, subquery_count
return query, count_query, conn, params, transformer, additional_data, unique, subquery_count, unwrap_mode
8 changes: 6 additions & 2 deletions fastapi_pagination/ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ def len_or_none(obj: Any) -> Optional[int]:


@no_type_check
def unwrap_scalars(items: Sequence[Sequence[T]]) -> Union[Sequence[T], Sequence[Sequence[T]]]:
return [item[0] if len_or_none(item) == 1 else item for item in items]
def unwrap_scalars(
items: Sequence[Sequence[T]],
*,
force_unwrap: bool = False,
) -> Union[Sequence[T], Sequence[Sequence[T]]]:
return [item[0] if force_unwrap or len_or_none(item) == 1 else item for item in items]


@no_type_check
Expand Down
63 changes: 63 additions & 0 deletions tests/ext/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,66 @@ def test_unwrap_raw_results(self, sa_session, sa_user, query, validate):

assert page.items
assert validate(sa_user, page.items[0])

@mark.parametrize(
("unwrap_mode", "validate"),
[
(None, lambda item, sa_user: isinstance(item, sa_user)),
("auto", lambda item, sa_user: isinstance(item, sa_user)),
("legacy", lambda item, sa_user: isinstance(item, sa_user)),
("unwrap", lambda item, sa_user: isinstance(item, sa_user)),
("no-unwrap", lambda item, sa_user: len(item) == 1 and isinstance(item[0], sa_user)),
],
)
def test_unwrap_mode_select_scalar_model(self, sa_session, sa_user, unwrap_mode, validate):
with closing(sa_session()) as session, set_page(Page[Any]):
page = paginate(
session,
select(sa_user),
params=Params(page=1, size=1),
unwrap_mode=unwrap_mode,
)

assert validate(page.items[0], sa_user)

@mark.parametrize(
("unwrap_mode", "validate"),
[
(None, lambda item, sa_user: len(item) == 1),
("auto", lambda item, sa_user: len(item) == 1),
("legacy", lambda item, sa_user: isinstance(item, str)),
("unwrap", lambda item, sa_user: isinstance(item, str)),
("no-unwrap", lambda item, sa_user: len(item) == 1),
],
)
def test_unwrap_mode_select_scalar_column(self, sa_session, sa_user, unwrap_mode, validate):
with closing(sa_session()) as session, set_page(Page[Any]):
page = paginate(
session,
select(sa_user.name),
params=Params(page=1, size=1),
unwrap_mode=unwrap_mode,
)

assert validate(page.items[0], sa_user)

@mark.parametrize(
("unwrap_mode", "validate"),
[
(None, lambda item, sa_user: len(item) == 2),
("auto", lambda item, sa_user: len(item) == 2),
("legacy", lambda item, sa_user: len(item) == 2),
("unwrap", lambda item, sa_user: isinstance(item, sa_user)),
("no-unwrap", lambda item, sa_user: len(item) == 2),
],
)
def test_unwrap_mode_select_non_scalar(self, sa_session, sa_user, unwrap_mode, validate):
with closing(sa_session()) as session, set_page(Page[Any]):
page = paginate(
session,
select(sa_user, sa_user.name),
params=Params(page=1, size=1),
unwrap_mode=unwrap_mode,
)

assert validate(page.items[0], sa_user)

0 comments on commit 954f1e8

Please sign in to comment.