From c07f87898f0a4a8989ad4bcdf0c0a2180f6b13e9 Mon Sep 17 00:00:00 2001 From: Lars Schwegmann Date: Wed, 2 Oct 2024 15:20:19 +0200 Subject: [PATCH 1/8] fix functions.Function not passing down kwargs during sql generation of function --- pypika/terms.py | 2 +- tests/test_parameter.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index c0eccb0..d1e2605 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1510,7 +1510,7 @@ def get_sql(self, **kwargs: Any) -> str: # FIXME escape function_sql = self.get_function_sql( - with_namespace=with_namespace, quote_char=quote_char, dialect=dialect + with_namespace=with_namespace, quote_char=quote_char, dialect=dialect, **kwargs ) if self.schema is not None: diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 920803a..1e6d0aa 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -11,7 +11,8 @@ Query, Tables, ) -from pypika.terms import ListParameter, ParameterValueWrapper +from pypika.functions import Upper +from pypika.terms import ListParameter, ParameterValueWrapper, ValueWrapper class ParametrizedTests(unittest.TestCase): @@ -212,3 +213,15 @@ def test_pyformat_parameter(self): 'INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql ) self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) + + def test_function_parameter(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == Upper(ValueWrapper("foobar"))) + ) + p = ListParameter("%s") + sql = q.get_sql(parameter=p) + self.assertEqual('SELECT * FROM "abc" WHERE "category"=UPPER(%s)', sql) + + self.assertEqual(["foobar"], p.get_parameters()) From db6c42e5948198013d05309f79404c3dab8bbc23 Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 14 Nov 2024 00:26:20 +0100 Subject: [PATCH 2/8] Replace ListParameter with Parameterizer * Do not escape parameter values --- pypika/__init__.py | 7 +- pypika/terms.py | 184 +++++++++++++--------------------------- tests/test_parameter.py | 139 ++++++++++++++---------------- 3 files changed, 125 insertions(+), 205 deletions(-) diff --git a/pypika/__init__.py b/pypika/__init__.py index f4ef90c..b1ef3e5 100644 --- a/pypika/__init__.py +++ b/pypika/__init__.py @@ -21,19 +21,16 @@ CustomFunction, EmptyCriterion, Field, - FormatParameter, Index, Interval, - NamedParameter, Not, NullValue, - NumericParameter, Parameter, - PyformatParameter, - QmarkParameter, + Parameterizer, Rollup, SystemTimeValue, Tuple, + ValueWrapper, ) NULL = NullValue() diff --git a/pypika/terms.py b/pypika/terms.py index d1e2605..7b40b21 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -5,7 +5,7 @@ import re import sys import uuid -from datetime import date, datetime, time +from datetime import date, time from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, Type, TypeVar, cast @@ -316,114 +316,71 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() -def idx_placeholder_gen(idx: int) -> str: - return str(idx + 1) - - -def named_placeholder_gen(idx: int) -> str: - return f"param{idx + 1}" - - class Parameter(Term): - is_aggregate = None - - def __init__(self, placeholder: str | int) -> None: - super().__init__() - self._placeholder = placeholder - - @property - def placeholder(self) -> str | int: - return self._placeholder - - def get_sql(self, **kwargs: Any) -> str: - return str(self.placeholder) - - def update_parameters(self, param_key: Any, param_value: Any, **kwargs) -> None: - pass - - def get_param_key(self, placeholder: Any, **kwargs) -> str | int: - return placeholder - - -class ListParameter(Parameter): - def __init__(self, placeholder: str | int | Callable[[int], str] = idx_placeholder_gen) -> None: - super().__init__(placeholder=placeholder) # type:ignore[arg-type] - self._parameters: list = [] - - @property - def placeholder(self) -> str: - if callable(self._placeholder): - return self._placeholder(len(self._parameters)) - - return str(self._placeholder) - - def get_parameters(self, **kwargs) -> list: - return self._parameters - - def update_parameters(self, value: Any, **kwargs) -> None: # type:ignore[override] - self._parameters.append(value) - - -class DictParameter(Parameter): - def __init__( - self, placeholder: str | int | Callable[[int], str] = named_placeholder_gen - ) -> None: - super().__init__(placeholder=placeholder) # type:ignore[arg-type] - self._parameters: dict = {} - - @property - def placeholder(self) -> str: - if callable(self._placeholder): - return self._placeholder(len(self._parameters)) - - return str(self._placeholder) - - def get_parameters(self, **kwargs) -> dict: - return self._parameters - - def get_param_key(self, placeholder: Any, **kwargs) -> str: - return placeholder[1:] - - def update_parameters( # type:ignore[override] - self, param_key: Any, value: Any, **kwargs - ) -> None: - self._parameters[param_key] = value - - -class QmarkParameter(ListParameter): - def get_sql(self, **kwargs) -> str: - return "?" - + """ + Represents a parameter in a query. The placeholder can be specified with the `placeholder` argument or + will be determined based on the dialect if not provided. + """ -class NumericParameter(ListParameter): - """Numeric, positional style, e.g. ...WHERE name=:1""" + IDX_PLACEHOLDERS = { + Dialects.ORACLE: lambda _: "?", + Dialects.MSSQL: lambda _: "?", + Dialects.MYSQL: lambda _: "%s", + Dialects.POSTGRESQL: lambda idx: f"${idx}", + Dialects.SQLITE: lambda _: "?", + } + DEFAULT_PLACEHOLDER = "?" + is_aggregate = None - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) + def __init__(self, placeholder: str | None = None, idx: int | None = None) -> None: + if not placeholder and not idx: + raise ValueError("Must provide either a placeholder or an idx") + if placeholder and idx: + raise ValueError("Cannot provide both a placeholder and an idx") -class FormatParameter(ListParameter): - """ANSI C printf format codes, e.g. ...WHERE name=%s""" + self._placeholder = placeholder + self._idx = idx def get_sql(self, **kwargs: Any) -> str: - return "%s" + if self._placeholder: + return self._placeholder + dialect = kwargs.get("dialect", None) + return self.IDX_PLACEHOLDERS.get(dialect, lambda _: self.DEFAULT_PLACEHOLDER)(self._idx) -class NamedParameter(DictParameter): - """Named style, e.g. ...WHERE name=:name""" - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) +class Parameterizer: + """ + Parameterizer can be used to replace values with parameters in a query: +>>>>>>> 94fab2d (Replace ListParameter with Parameterizer) + + >>> parameterizer = Parameterizer() + >>> customers = Table("customers") + >>> sql = Query.from_(customers).select(customers.id)\ + ... .where(customers.lname == "Mustermann")\ + ... .get_sql(parameterizer=parameterizer, dialect=Dialects.SQLITE) + >>> sql, parameterizer.values + ('SELECT "id" FROM "customers" WHERE "lname"=?', ['Mustermann']) + + Parameterizer remembers the values it has seen and replaces them with parameters. The values can + be accessed via the `values` attribute. + """ + def __init__(self) -> None: + self.values = [] -class PyformatParameter(DictParameter): - """Python extended format codes, e.g. ...WHERE name=%(name)s""" + def should_parameterize(self, value: Any) -> bool: + if isinstance(value, Enum): + return False - def get_sql(self, **kwargs: Any) -> str: - return "%({placeholder})s".format(placeholder=self.placeholder) + if isinstance(value, str) and value == "*": + return False + return True - def get_param_key(self, placeholder: T, **kwargs) -> T: - return placeholder[2:-2] # type:ignore[index] + def create_param(self, value: Any) -> Parameter: + self.values.append(value) + return Parameter(idx=len(self.values)) class Negative(Term): @@ -475,48 +432,23 @@ def get_formatted_value(cls, value: Any, **kwargs) -> str: return "null" return str(value) - def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: - param_sql = parameter.get_sql(**kwargs) - param_key = parameter.get_param_key(placeholder=param_sql) - - return param_sql, param_key # type:ignore[return-value] - def get_sql( self, quote_char: str | None = None, secondary_quote_char: str = "'", - parameter: Parameter | None = None, + parameterizer: Parameterizer | None = None, **kwargs: Any, ) -> str: - if parameter is None: + if parameterizer is None or not parameterizer.should_parameterize(self.value): sql = self.get_value_sql( quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs ) return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) - # Don't stringify number or date values when using a parameter - if isinstance(self.value, (int, float, date, time, datetime)): - value_sql = self.value - else: - value_sql = self.get_value_sql( - quote_char=quote_char, **kwargs - ) # type:ignore[assignment] - param_sql, param_key = self._get_param_data(parameter, **kwargs) - parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) - - return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) - - -class ParameterValueWrapper(ValueWrapper): - def __init__(self, parameter: Parameter, value: Any, alias: str | None = None) -> None: - super().__init__(value, alias) - self._parameter = parameter - - def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: - param_sql = self._parameter.get_sql(**kwargs) - param_key = self._parameter.get_param_key(placeholder=param_sql) - - return param_sql, param_key # type:ignore[return-value] + param = parameterizer.create_param(self.value) + return format_alias_sql( + param.get_sql(**kwargs), self.alias, quote_char=quote_char, **kwargs + ) class JSON(Term): diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 1e6d0aa..0851ac2 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,18 +1,10 @@ import unittest from datetime import date -from pypika import ( - FormatParameter, - NamedParameter, - NumericParameter, - Parameter, - PyformatParameter, - QmarkParameter, - Query, - Tables, -) +from pypika import Parameter, Query, Tables, ValueWrapper +from pypika.enums import Dialects from pypika.functions import Upper -from pypika.terms import ListParameter, ParameterValueWrapper, ValueWrapper +from pypika.terms import Case, Parameterizer class ParametrizedTests(unittest.TestCase): @@ -87,34 +79,41 @@ def test_join(self): ) def test_qmark_parameter(self): - self.assertEqual("?", QmarkParameter().get_sql()) + self.assertEqual("?", Parameter("?").get_sql()) - def test_numeric_parameter(self): - self.assertEqual(":14", NumericParameter("14").get_sql()) - self.assertEqual(":15", NumericParameter(15).get_sql()) + def test_oracle(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.ORACLE)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.ORACLE)) - def test_named_parameter(self): - self.assertEqual(":buz", NamedParameter("buz").get_sql()) + def test_mssql(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.MSSQL)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.MSSQL)) - def test_format_parameter(self): - self.assertEqual("%s", FormatParameter().get_sql()) + def test_mysql(self): + self.assertEqual("%s", Parameter(idx=1).get_sql(dialect=Dialects.MYSQL)) + self.assertEqual("%s", Parameter(idx=2).get_sql(dialect=Dialects.MYSQL)) - def test_pyformat_parameter(self): - self.assertEqual("%(buz)s", PyformatParameter("buz").get_sql()) + def test_postgres(self): + self.assertEqual("$1", Parameter(idx=1).get_sql(dialect=Dialects.POSTGRESQL)) + self.assertEqual("$2", Parameter(idx=2).get_sql(dialect=Dialects.POSTGRESQL)) + def test_sqlite(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.SQLITE)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.SQLITE)) -class ParametrizedTestsWithValues(unittest.TestCase): + +class ParameterizerTests(unittest.TestCase): table_abc, table_efg = Tables("abc", "efg") def test_param_insert(self): q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") - parameter = QmarkParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) - self.assertEqual([1, 2.2, "foo"], parameter.get_parameters()) + self.assertEqual([1, 2.2, "foo"], parameterizer.values) - def test_param_select_join(self): + def test_select_join_in_mysql(self): q = ( Query.from_(self.table_abc) .select("*") @@ -125,15 +124,15 @@ def test_param_select_join(self): .limit(10) ) - parameter = FormatParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.MYSQL) self.assertEqual( 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameter.get_parameters()) + self.assertEqual(["foobar", date(2024, 2, 22)], parameterizer.values) - def test_param_select_subquery(self): + def test_select_subquery_in_postgres(self): q = ( Query.from_(self.table_abc) .select("*") @@ -148,15 +147,15 @@ def test_param_select_subquery(self): .limit(10) ) - parameter = ListParameter(placeholder=lambda idx: f"&{idx+1}") - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) self.assertEqual( - 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + 'SELECT * FROM "abc" WHERE "category"=$1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=$2) LIMIT 10', sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameter.get_parameters()) + self.assertEqual(["foobar", date(2024, 2, 22)], parameterizer.values) - def test_join(self): + def test_join_in_postgres(self): subquery = ( Query.from_(self.table_efg) .select(self.table_efg.fiz, self.table_efg.buz) @@ -171,57 +170,49 @@ def test_join(self): .where(self.table_abc.bar == "bar") ) - parameter = NamedParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) self.assertEqual( - 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' - ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=$1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=$2', sql, ) - self.assertEqual({"param1": "buz", "param2": "bar"}, parameter.get_parameters()) - - def test_join_with_parameter_value_wrapper(self): - subquery = ( - Query.from_(self.table_efg) - .select(self.table_efg.fiz, self.table_efg.buz) - .where(self.table_efg.buz == ParameterValueWrapper(Parameter(":buz"), "buz")) - ) + self.assertEqual(["buz", "bar"], parameterizer.values) + def test_function_parameter(self): q = ( Query.from_(self.table_abc) - .join(subquery) - .on(self.table_abc.bar == subquery.buz) - .select(self.table_abc.foo, subquery.fiz) - .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter("bar"), "bar")) - ) - - parameter = NamedParameter() - sql = q.get_sql(parameter=parameter) - self.assertEqual( - 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' - ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', - sql, + .select("*") + .where(self.table_abc.category == Upper(ValueWrapper("foobar"))) ) - self.assertEqual({":buz": "buz", "bar": "bar"}, parameter.get_parameters()) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT * FROM "abc" WHERE "category"=UPPER(?)', sql) - def test_pyformat_parameter(self): - q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") + self.assertEqual(["foobar"], parameterizer.values) - parameter = PyformatParameter() - sql = q.get_sql(parameter=parameter) - self.assertEqual( - 'INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql + def test_case_when_in_select(self): + q = Query.from_(self.table_abc).select( + Case().when(self.table_abc.category == "foobar", 1).else_(2) ) - self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT CASE WHEN "category"=? THEN ? ELSE ? END FROM "abc"', sql) + self.assertEqual(["foobar", 1, 2], parameterizer.values) - def test_function_parameter(self): + def test_case_when_in_where(self): q = ( Query.from_(self.table_abc) .select("*") - .where(self.table_abc.category == Upper(ValueWrapper("foobar"))) + .where( + self.table_abc.category_int + > Case().when(self.table_abc.category == "foobar", 1).else_(2) + ) ) - p = ListParameter("%s") - sql = q.get_sql(parameter=p) - self.assertEqual('SELECT * FROM "abc" WHERE "category"=UPPER(%s)', sql) - - self.assertEqual(["foobar"], p.get_parameters()) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual( + 'SELECT * FROM "abc" WHERE "category_int">CASE WHEN "category"=? THEN ? ELSE ? END', + sql, + ) + self.assertEqual(["foobar", 1, 2], parameterizer.values) From 7c6e6fc909cb154f50abeb98183ff6e6395487d2 Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 14 Nov 2024 00:27:18 +0100 Subject: [PATCH 3/8] Add missing whitespace for MSSQL when pagination without ordering --- pypika/dialects/mssql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/dialects/mssql.py b/pypika/dialects/mssql.py index 610c918..a633711 100644 --- a/pypika/dialects/mssql.py +++ b/pypika/dialects/mssql.py @@ -47,7 +47,7 @@ def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return] def _offset_sql(self) -> str: order_by = "" if not self._orderbys: - order_by = "ORDER BY (SELECT 0)" + order_by = " ORDER BY (SELECT 0)" return order_by + " OFFSET {offset} ROWS".format(offset=self._offset or 0) def _limit_sql(self) -> str: From 8b1701be4ce0a94a8b19d37fd4946b59af46e75d Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 14 Nov 2024 01:25:16 +0100 Subject: [PATCH 4/8] Wrap limit and offset so they can be parametrized --- pypika/dialects/mssql.py | 16 ++++++------ pypika/dialects/oracle.py | 8 +++--- pypika/queries.py | 52 ++++++++++++++++++++------------------- tests/test_parameter.py | 33 +++++++++++++++++++------ 4 files changed, 66 insertions(+), 43 deletions(-) diff --git a/pypika/dialects/mssql.py b/pypika/dialects/mssql.py index a633711..76bcfa3 100644 --- a/pypika/dialects/mssql.py +++ b/pypika/dialects/mssql.py @@ -44,23 +44,25 @@ def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return] # Overridden to provide a more domain-specific API for T-SQL users self._limit = limit - def _offset_sql(self) -> str: + def _offset_sql(self, **kwargs) -> str: order_by = "" if not self._orderbys: order_by = " ORDER BY (SELECT 0)" - return order_by + " OFFSET {offset} ROWS".format(offset=self._offset or 0) + return order_by + " OFFSET {offset} ROWS".format( + offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0 + ) - def _limit_sql(self) -> str: - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) - def _apply_pagination(self, querystring: str) -> str: + def _apply_pagination(self, querystring: str, **kwargs) -> str: # Note: Overridden as MSSQL specifies offset before the fetch next limit if self._limit is not None or self._offset: # Offset has to be present if fetch next is specified in a MSSQL query - querystring += self._offset_sql() + querystring += self._offset_sql(**kwargs) if self._limit is not None: - querystring += self._limit_sql() + querystring += self._limit_sql(**kwargs) return querystring diff --git a/pypika/dialects/oracle.py b/pypika/dialects/oracle.py index 5ab3867..6830812 100644 --- a/pypika/dialects/oracle.py +++ b/pypika/dialects/oracle.py @@ -28,8 +28,8 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: kwargs["groupby_alias"] = False return super().get_sql(*args, **kwargs) - def _offset_sql(self) -> str: - return " OFFSET {offset} ROWS".format(offset=self._offset) + def _offset_sql(self, **kwargs) -> str: + return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs)) - def _limit_sql(self) -> str: - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) diff --git a/pypika/queries.py b/pypika/queries.py index f64f7a1..2a140fa 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -552,12 +552,12 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "Self": # type:ignore[retur self._orderbys.append((field, kwargs.get("order"))) @builder - def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = limit + def limit(self, limit: int) -> "Self": # type:ignore[return] + self._limit = self.wrap_constant(limit) @builder - def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = offset + def offset(self, offset: int) -> "Self": # type:ignore[return] + self._offset = self.wrap_constant(offset) @builder def union(self, other: Selectable) -> "Self": # type:ignore[return] @@ -625,10 +625,10 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring += self._orderby_sql(**kwargs) if self._limit is not None: - querystring += self._limit_sql() + querystring += self._limit_sql(**kwargs) if self._offset: - querystring += self._offset_sql() + querystring += self._offset_sql(**kwargs) if subquery: querystring = "({query})".format(query=querystring, **kwargs) @@ -668,11 +668,11 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " ORDER BY {orderby}".format(orderby=",".join(clauses)) - def _offset_sql(self) -> str: - return " OFFSET {offset}".format(offset=self._offset) + def _offset_sql(self, **kwargs) -> str: + return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) - def _limit_sql(self) -> str: - return " LIMIT {limit}".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) class QueryBuilder(Selectable, Term): # type:ignore[misc] @@ -1222,12 +1222,12 @@ def hash_join(self, item: Table | "QueryBuilder" | AliasedQuery) -> "Joiner": return self.join(item, JoinType.hash) @builder - def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = limit + def limit(self, limit: int) -> "Self": # type:ignore[return] + self._limit = self.wrap_constant(limit) @builder - def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = offset + def offset(self, offset: int) -> "Self": # type:ignore[return] + self._offset = self.wrap_constant(offset) @builder def union(self, other: Self) -> _SetOperation: @@ -1265,8 +1265,10 @@ def __sub__(self, other: Self) -> _SetOperation: # type:ignore[override] @builder def slice(self, slice: slice) -> "Self": # type:ignore[return] - self._offset = slice.start - self._limit = slice.stop + if slice.start is not None: + self._offset = self.wrap_constant(slice.start) + if slice.stop is not None: + self._limit = self.wrap_constant(slice.stop) def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override] if not isinstance(item, slice): @@ -1512,7 +1514,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An if self._orderbys: querystring += self._orderby_sql(**kwargs) - querystring = self._apply_pagination(querystring) + querystring = self._apply_pagination(querystring, **kwargs) if self._for_update: querystring += self._for_update_sql(**kwargs) @@ -1532,12 +1534,12 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring - def _apply_pagination(self, querystring: str) -> str: + def _apply_pagination(self, querystring: str, **kwargs) -> str: if self._limit is not None: - querystring += self._limit_sql() + querystring += self._limit_sql(**kwargs) - if self._offset: - querystring += self._offset_sql() + if self._offset is not None: + querystring += self._offset_sql(**kwargs) return querystring @@ -1750,11 +1752,11 @@ def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: having = self._havings.get_sql(quote_char=quote_char, **kwargs) # type:ignore[union-attr] return f" HAVING {having}" - def _offset_sql(self) -> str: - return " OFFSET {offset}".format(offset=self._offset) + def _offset_sql(self, **kwargs) -> str: + return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) - def _limit_sql(self) -> str: - return " LIMIT {limit}".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) def _set_sql(self, **kwargs: Any) -> str: return " SET {set}".format( diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 0851ac2..e832a90 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -2,6 +2,9 @@ from datetime import date from pypika import Parameter, Query, Tables, ValueWrapper +from pypika.dialects.mssql import MSSQLQuery +from pypika.dialects.mysql import MySQLQuery +from pypika.dialects.postgresql import PostgreSQLQuery from pypika.enums import Dialects from pypika.functions import Upper from pypika.terms import Case, Parameterizer @@ -115,7 +118,7 @@ def test_param_insert(self): def test_select_join_in_mysql(self): q = ( - Query.from_(self.table_abc) + MySQLQuery.from_(self.table_abc) .select("*") .where(self.table_abc.category == "foobar") .join(self.table_efg) @@ -127,14 +130,14 @@ def test_select_join_in_mysql(self): parameterizer = Parameterizer() sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.MYSQL) self.assertEqual( - 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + "SELECT * FROM `abc` JOIN `efg` ON `abc`.`id`=`efg`.`abc_id` WHERE `abc`.`category`=%s AND `efg`.`date`>=%s LIMIT %s", sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameterizer.values) + self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) def test_select_subquery_in_postgres(self): q = ( - Query.from_(self.table_abc) + PostgreSQLQuery.from_(self.table_abc) .select("*") .where(self.table_abc.category == "foobar") .where( @@ -150,10 +153,10 @@ def test_select_subquery_in_postgres(self): parameterizer = Parameterizer() sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) self.assertEqual( - 'SELECT * FROM "abc" WHERE "category"=$1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=$2) LIMIT 10', + 'SELECT * FROM "abc" WHERE "category"=$1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=$2) LIMIT $3', sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameterizer.values) + self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) def test_join_in_postgres(self): subquery = ( @@ -163,7 +166,7 @@ def test_join_in_postgres(self): ) q = ( - Query.from_(self.table_abc) + PostgreSQLQuery.from_(self.table_abc) .join(subquery) .on(self.table_abc.bar == subquery.buz) .select(self.table_abc.foo, subquery.fiz) @@ -216,3 +219,19 @@ def test_case_when_in_where(self): sql, ) self.assertEqual(["foobar", 1, 2], parameterizer.values) + + def test_limit_and_offest(self): + q = Query.from_(self.table_abc).select("*").limit(10).offset(5) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT * FROM "abc" LIMIT ? OFFSET ?', sql) + self.assertEqual([10, 5], parameterizer.values) + + def test_limit_and_offest_in_mssql(self): + q = MSSQLQuery.from_(self.table_abc).select("*").limit(10).offset(5) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual( + 'SELECT * FROM "abc" ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY', sql + ) + self.assertEqual([5, 10], parameterizer.values) From 485b8548de4ed76de7e3fad77a9de6c10064c717 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 18 Nov 2024 10:17:44 +0100 Subject: [PATCH 5/8] Add placeholder_factory to Parameterizer --- pypika/queries.py | 3 ++- pypika/terms.py | 8 ++++++-- tests/test_parameter.py | 5 +++++ tests/test_updates.py | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 2a140fa..0845268 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1252,7 +1252,8 @@ def minus(self, other: Self) -> _SetOperation: @builder def set(self, field: Field | str, value: Any) -> "Self": # type:ignore[return] field = Field(field) if not isinstance(field, Field) else field - self._updates.append((field, self._wrapper_cls(value))) + value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls) + self._updates.append((field, value)) def __add__(self, other: Self) -> _SetOperation: # type:ignore[override] return self.union(other) diff --git a/pypika/terms.py b/pypika/terms.py index 7b40b21..3f7e5bd 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -367,7 +367,8 @@ class Parameterizer: be accessed via the `values` attribute. """ - def __init__(self) -> None: + def __init__(self, placeholder_factory: Optional[Callable[[int], str]] = None) -> None: + self.placeholder_factory = placeholder_factory self.values = [] def should_parameterize(self, value: Any) -> bool: @@ -380,7 +381,10 @@ def should_parameterize(self, value: Any) -> bool: def create_param(self, value: Any) -> Parameter: self.values.append(value) - return Parameter(idx=len(self.values)) + if self.placeholder_factory: + return Parameter(self.placeholder_factory(len(self.values))) + else: + return Parameter(idx=len(self.values)) class Negative(Term): diff --git a/tests/test_parameter.py b/tests/test_parameter.py index e832a90..e3bdc9f 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -235,3 +235,8 @@ def test_limit_and_offest_in_mssql(self): 'SELECT * FROM "abc" ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY', sql ) self.assertEqual([5, 10], parameterizer.values) + + def test_placeholder_factory(self): + parameterizer = Parameterizer(placeholder_factory=lambda _: "%s") + param = parameterizer.create_param(1) + self.assertEqual("%s", param.get_sql()) diff --git a/tests/test_updates.py b/tests/test_updates.py index 175a256..a722554 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -43,7 +43,7 @@ def test_update__table_schema(self): def test_update_with_none(self): q = Query.update("abc").set("foo", None) - self.assertEqual('UPDATE "abc" SET "foo"=null', str(q)) + self.assertEqual('UPDATE "abc" SET "foo"=NULL', str(q)) def test_update_from(self): from_table = Table("from_table") From 61b08e4f6c06aa275b18751a78f652ee80dee3c9 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 18 Nov 2024 12:29:51 +0100 Subject: [PATCH 6/8] Introduce allow_parametrize to ValueWrapper --- pypika/terms.py | 27 ++++++++++++++++++++++++--- tests/test_terms.py | 14 +++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 3f7e5bd..43aca83 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -333,9 +333,12 @@ class Parameter(Term): is_aggregate = None def __init__(self, placeholder: str | None = None, idx: int | None = None) -> None: - if not placeholder and not idx: + if not placeholder and idx is None: raise ValueError("Must provide either a placeholder or an idx") + if idx is not None and idx < 1: + raise ValueError("idx must start at 1") + if placeholder and idx: raise ValueError("Cannot provide both a placeholder and an idx") @@ -403,9 +406,23 @@ def get_sql(self, **kwargs: Any) -> str: class ValueWrapper(Term): is_aggregate = None - def __init__(self, value: Any, alias: str | None = None) -> None: + def __init__( + self, value: Any, alias: str | None = None, allow_parametrize: bool = True + ) -> None: + """ + A wrapper for a constant value such as a string or number. + + :param value: + The value to be wrapped. + :param alias: + An optional alias for the value. + :param allow_parametrize: + Whether the value should be replaced with a parameter in the query if parameterizer + is used. + """ super().__init__(alias) self.value = value + self.allow_parametrize = allow_parametrize def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @@ -443,7 +460,11 @@ def get_sql( parameterizer: Parameterizer | None = None, **kwargs: Any, ) -> str: - if parameterizer is None or not parameterizer.should_parameterize(self.value): + if ( + parameterizer is None + or not parameterizer.should_parameterize(self.value) + or not self.allow_parametrize + ): sql = self.get_value_sql( quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs ) diff --git a/tests/test_terms.py b/tests/test_terms.py index a5dea35..295e39d 100644 --- a/tests/test_terms.py +++ b/tests/test_terms.py @@ -1,7 +1,7 @@ from unittest import TestCase from pypika import Field, Query, Table -from pypika.terms import AtTimezone +from pypika.terms import AtTimezone, Parameterizer, ValueWrapper class FieldAliasTests(TestCase): @@ -49,3 +49,15 @@ def test_passes_kwargs_to_field_get_sql(self): 'FROM "customers" JOIN "accounts" ON "customers"."account_id"="accounts"."account_id"', query.get_sql(with_namespace=True), ) + + +class ValueWrapperTests(TestCase): + def test_allow_parametrize(self): + value = ValueWrapper("foo") + self.assertEqual("'foo'", value.get_sql()) + + value = ValueWrapper("foo") + self.assertEqual("?", value.get_sql(parameterizer=Parameterizer())) + + value = ValueWrapper("foo", allow_parametrize=False) + self.assertEqual("'foo'", value.get_sql(parameterizer=Parameterizer())) From 31eea5a7d1299d33ce1776e97aab44879c54de35 Mon Sep 17 00:00:00 2001 From: henadzit Date: Tue, 19 Nov 2024 12:29:59 +0100 Subject: [PATCH 7/8] Fix typing issues --- pypika/dialects/mssql.py | 7 ++++-- pypika/dialects/oracle.py | 4 +++ pypika/queries.py | 51 ++++++++++++++++++++------------------- pypika/terms.py | 4 +-- 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/pypika/dialects/mssql.py b/pypika/dialects/mssql.py index 76bcfa3..dc316be 100644 --- a/pypika/dialects/mssql.py +++ b/pypika/dialects/mssql.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from pypika.enums import Dialects from pypika.exceptions import QueryException from pypika.queries import Query, QueryBuilder +from pypika.terms import ValueWrapper from pypika.utils import builder @@ -42,7 +43,7 @@ def top(self, value: str | int) -> MSSQLQueryBuilder: # type:ignore[return] @builder def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return] # Overridden to provide a more domain-specific API for T-SQL users - self._limit = limit + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) def _offset_sql(self, **kwargs) -> str: order_by = "" @@ -53,6 +54,8 @@ def _offset_sql(self, **kwargs) -> str: ) def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) def _apply_pagination(self, querystring: str, **kwargs) -> str: diff --git a/pypika/dialects/oracle.py b/pypika/dialects/oracle.py index 6830812..174fa2c 100644 --- a/pypika/dialects/oracle.py +++ b/pypika/dialects/oracle.py @@ -29,7 +29,11 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return super().get_sql(*args, **kwargs) def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs)) def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) diff --git a/pypika/queries.py b/pypika/queries.py index 0845268..5419afb 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -535,8 +535,8 @@ def __init__( self._set_operation = [(set_operation, set_operation_query)] self._orderbys: list[tuple[Field, Order | None]] = [] - self._limit: int | None = None - self._offset: int | None = None + self._limit: ValueWrapper | None = None + self._offset: ValueWrapper | None = None self._wrapper_cls = wrapper_cls @@ -552,12 +552,12 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "Self": # type:ignore[retur self._orderbys.append((field, kwargs.get("order"))) @builder - def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = self.wrap_constant(limit) + def limit(self, limit: int) -> "Self": # type:ignore[return] + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) @builder - def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = self.wrap_constant(offset) + def offset(self, offset: int) -> "Self": # type:ignore[return] + self._offset = cast(ValueWrapper, self.wrap_constant(offset)) @builder def union(self, other: Selectable) -> "Self": # type:ignore[return] @@ -624,11 +624,8 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An if self._orderbys: querystring += self._orderby_sql(**kwargs) - if self._limit is not None: - querystring += self._limit_sql(**kwargs) - - if self._offset: - querystring += self._offset_sql(**kwargs) + querystring += self._limit_sql(**kwargs) + querystring += self._offset_sql(**kwargs) if subquery: querystring = "({query})".format(query=querystring, **kwargs) @@ -669,9 +666,13 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " ORDER BY {orderby}".format(orderby=",".join(clauses)) def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) @@ -725,8 +726,8 @@ def __init__( self._joins: list[Join] = [] self._unions: list = [] - self._limit: int | None = None - self._offset: int | None = None + self._limit: ValueWrapper | None = None + self._offset: ValueWrapper | None = None self._updates: list[tuple] = [] @@ -1222,12 +1223,12 @@ def hash_join(self, item: Table | "QueryBuilder" | AliasedQuery) -> "Joiner": return self.join(item, JoinType.hash) @builder - def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = self.wrap_constant(limit) + def limit(self, limit: int) -> "Self": # type:ignore[return] + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) @builder - def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = self.wrap_constant(offset) + def offset(self, offset: int) -> "Self": # type:ignore[return] + self._offset = cast(ValueWrapper, self.wrap_constant(offset)) @builder def union(self, other: Self) -> _SetOperation: @@ -1267,9 +1268,9 @@ def __sub__(self, other: Self) -> _SetOperation: # type:ignore[override] @builder def slice(self, slice: slice) -> "Self": # type:ignore[return] if slice.start is not None: - self._offset = self.wrap_constant(slice.start) + self._offset = cast(ValueWrapper, self.wrap_constant(slice.start)) if slice.stop is not None: - self._limit = self.wrap_constant(slice.stop) + self._limit = cast(ValueWrapper, self.wrap_constant(slice.stop)) def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override] if not isinstance(item, slice): @@ -1536,12 +1537,8 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring def _apply_pagination(self, querystring: str, **kwargs) -> str: - if self._limit is not None: - querystring += self._limit_sql(**kwargs) - - if self._offset is not None: - querystring += self._offset_sql(**kwargs) - + querystring += self._limit_sql(**kwargs) + querystring += self._offset_sql(**kwargs) return querystring def _with_sql(self, **kwargs: Any) -> str: @@ -1754,9 +1751,13 @@ def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return f" HAVING {having}" def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) def _set_sql(self, **kwargs: Any) -> str: diff --git a/pypika/terms.py b/pypika/terms.py index 43aca83..f6b12df 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -370,9 +370,9 @@ class Parameterizer: be accessed via the `values` attribute. """ - def __init__(self, placeholder_factory: Optional[Callable[[int], str]] = None) -> None: + def __init__(self, placeholder_factory: Callable[[int], str] | None = None) -> None: self.placeholder_factory = placeholder_factory - self.values = [] + self.values: list = [] def should_parameterize(self, value: Any) -> bool: if isinstance(value, Enum): From d196b33c2cf8fe2b32a89f67c8c2a7426f130287 Mon Sep 17 00:00:00 2001 From: henadzit Date: Thu, 21 Nov 2024 18:12:15 +0100 Subject: [PATCH 8/8] Release 0.3.0 --- CHANGELOG.md | 9 +++++++++ pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc2953e..f4e1752 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # ChangeLog +## 0.3 + +## 0.3.0 +- Add `Parameterizer` +- Uppdate `Parameter` to be dialect-aware +- Remove `ListParameter`, `DictParameter`, `QmarkParameter`, etc. +- Wrap query's offset and limit with ValueWrapper so they can be parametrized +- Fix a missing whitespace for MSSQL when pagination without ordering is used + ## 0.2 ### 0.2.2 diff --git a/pyproject.toml b/pyproject.toml index 5f1bbfa..8887e9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pypika-tortoise" -version = "0.2.2" +version = "0.3.0" description = "Forked from pypika and streamline just for tortoise-orm" authors = ["long2ice "] license = "Apache-2.0"