Skip to content

Commit

Permalink
Merge pull request #16 from henadzit/parameterization-changes
Browse files Browse the repository at this point in the history
Parameterization support
  • Loading branch information
henadzit authored Nov 21, 2024
2 parents def1617 + d196b33 commit c9b8d17
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 247 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# ChangeLog

## 0.3

## 0.3.0
- Add `Parameterizer`
- Uppdate `Parameter` to be dialect-aware
- Remove `ListParameter`, `DictParameter`, `QmarkParameter`, etc.
- Wrap query's offset and limit with ValueWrapper so they can be parametrized
- Fix a missing whitespace for MSSQL when pagination without ordering is used

## 0.2

### 0.2.2
Expand Down
7 changes: 2 additions & 5 deletions pypika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@
CustomFunction,
EmptyCriterion,
Field,
FormatParameter,
Index,
Interval,
NamedParameter,
Not,
NullValue,
NumericParameter,
Parameter,
PyformatParameter,
QmarkParameter,
Parameterizer,
Rollup,
SystemTimeValue,
Tuple,
ValueWrapper,
)

NULL = NullValue()
Expand Down
25 changes: 15 additions & 10 deletions pypika/dialects/mssql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from typing import Any
from typing import Any, cast

from pypika.enums import Dialects
from pypika.exceptions import QueryException
from pypika.queries import Query, QueryBuilder
from pypika.terms import ValueWrapper
from pypika.utils import builder


Expand Down Expand Up @@ -42,25 +43,29 @@ def top(self, value: str | int) -> MSSQLQueryBuilder: # type:ignore[return]
@builder
def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return]
# Overridden to provide a more domain-specific API for T-SQL users
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

def _offset_sql(self) -> str:
def _offset_sql(self, **kwargs) -> str:
order_by = ""
if not self._orderbys:
order_by = "ORDER BY (SELECT 0)"
return order_by + " OFFSET {offset} ROWS".format(offset=self._offset or 0)
order_by = " ORDER BY (SELECT 0)"
return order_by + " OFFSET {offset} ROWS".format(
offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0
)

def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))

def _apply_pagination(self, querystring: str) -> str:
def _apply_pagination(self, querystring: str, **kwargs) -> str:
# Note: Overridden as MSSQL specifies offset before the fetch next limit
if self._limit is not None or self._offset:
# Offset has to be present if fetch next is specified in a MSSQL query
querystring += self._offset_sql()
querystring += self._offset_sql(**kwargs)

if self._limit is not None:
querystring += self._limit_sql()
querystring += self._limit_sql(**kwargs)

return querystring

Expand Down
14 changes: 9 additions & 5 deletions pypika/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str:
kwargs["groupby_alias"] = False
return super().get_sql(*args, **kwargs)

def _offset_sql(self) -> str:
return " OFFSET {offset} ROWS".format(offset=self._offset)

def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))
68 changes: 36 additions & 32 deletions pypika/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def __init__(
self._set_operation = [(set_operation, set_operation_query)]
self._orderbys: list[tuple[Field, Order | None]] = []

self._limit: int | None = None
self._offset: int | None = None
self._limit: ValueWrapper | None = None
self._offset: ValueWrapper | None = None

self._wrapper_cls = wrapper_cls

Expand All @@ -553,11 +553,11 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "Self": # type:ignore[retur

@builder
def limit(self, limit: int) -> "Self": # type:ignore[return]
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

@builder
def offset(self, offset: int) -> "Self": # type:ignore[return]
self._offset = offset
self._offset = cast(ValueWrapper, self.wrap_constant(offset))

@builder
def union(self, other: Selectable) -> "Self": # type:ignore[return]
Expand Down Expand Up @@ -624,11 +624,8 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An
if self._orderbys:
querystring += self._orderby_sql(**kwargs)

if self._limit is not None:
querystring += self._limit_sql()

if self._offset:
querystring += self._offset_sql()
querystring += self._limit_sql(**kwargs)
querystring += self._offset_sql(**kwargs)

if subquery:
querystring = "({query})".format(query=querystring, **kwargs)
Expand Down Expand Up @@ -668,11 +665,15 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str:

return " ORDER BY {orderby}".format(orderby=",".join(clauses))

def _offset_sql(self) -> str:
return " OFFSET {offset}".format(offset=self._offset)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self) -> str:
return " LIMIT {limit}".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs))


class QueryBuilder(Selectable, Term): # type:ignore[misc]
Expand Down Expand Up @@ -725,8 +726,8 @@ def __init__(
self._joins: list[Join] = []
self._unions: list = []

self._limit: int | None = None
self._offset: int | None = None
self._limit: ValueWrapper | None = None
self._offset: ValueWrapper | None = None

self._updates: list[tuple] = []

Expand Down Expand Up @@ -1223,11 +1224,11 @@ def hash_join(self, item: Table | "QueryBuilder" | AliasedQuery) -> "Joiner":

@builder
def limit(self, limit: int) -> "Self": # type:ignore[return]
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

@builder
def offset(self, offset: int) -> "Self": # type:ignore[return]
self._offset = offset
self._offset = cast(ValueWrapper, self.wrap_constant(offset))

@builder
def union(self, other: Self) -> _SetOperation:
Expand All @@ -1252,7 +1253,8 @@ def minus(self, other: Self) -> _SetOperation:
@builder
def set(self, field: Field | str, value: Any) -> "Self": # type:ignore[return]
field = Field(field) if not isinstance(field, Field) else field
self._updates.append((field, self._wrapper_cls(value)))
value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls)
self._updates.append((field, value))

def __add__(self, other: Self) -> _SetOperation: # type:ignore[override]
return self.union(other)
Expand All @@ -1265,8 +1267,10 @@ def __sub__(self, other: Self) -> _SetOperation: # type:ignore[override]

@builder
def slice(self, slice: slice) -> "Self": # type:ignore[return]
self._offset = slice.start
self._limit = slice.stop
if slice.start is not None:
self._offset = cast(ValueWrapper, self.wrap_constant(slice.start))
if slice.stop is not None:
self._limit = cast(ValueWrapper, self.wrap_constant(slice.stop))

def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override]
if not isinstance(item, slice):
Expand Down Expand Up @@ -1512,7 +1516,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An
if self._orderbys:
querystring += self._orderby_sql(**kwargs)

querystring = self._apply_pagination(querystring)
querystring = self._apply_pagination(querystring, **kwargs)

if self._for_update:
querystring += self._for_update_sql(**kwargs)
Expand All @@ -1532,13 +1536,9 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An

return querystring

def _apply_pagination(self, querystring: str) -> str:
if self._limit is not None:
querystring += self._limit_sql()

if self._offset:
querystring += self._offset_sql()

def _apply_pagination(self, querystring: str, **kwargs) -> str:
querystring += self._limit_sql(**kwargs)
querystring += self._offset_sql(**kwargs)
return querystring

def _with_sql(self, **kwargs: Any) -> str:
Expand Down Expand Up @@ -1750,11 +1750,15 @@ def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str:
having = self._havings.get_sql(quote_char=quote_char, **kwargs) # type:ignore[union-attr]
return f" HAVING {having}"

def _offset_sql(self) -> str:
return " OFFSET {offset}".format(offset=self._offset)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self) -> str:
return " LIMIT {limit}".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs))

def _set_sql(self, **kwargs: Any) -> str:
return " SET {set}".format(
Expand Down
Loading

0 comments on commit c9b8d17

Please sign in to comment.