diff --git a/CHANGELOG.md b/CHANGELOG.md index bc2953e..f4e1752 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # ChangeLog +## 0.3 + +## 0.3.0 +- Add `Parameterizer` +- Uppdate `Parameter` to be dialect-aware +- Remove `ListParameter`, `DictParameter`, `QmarkParameter`, etc. +- Wrap query's offset and limit with ValueWrapper so they can be parametrized +- Fix a missing whitespace for MSSQL when pagination without ordering is used + ## 0.2 ### 0.2.2 diff --git a/pypika/__init__.py b/pypika/__init__.py index f4ef90c..b1ef3e5 100644 --- a/pypika/__init__.py +++ b/pypika/__init__.py @@ -21,19 +21,16 @@ CustomFunction, EmptyCriterion, Field, - FormatParameter, Index, Interval, - NamedParameter, Not, NullValue, - NumericParameter, Parameter, - PyformatParameter, - QmarkParameter, + Parameterizer, Rollup, SystemTimeValue, Tuple, + ValueWrapper, ) NULL = NullValue() diff --git a/pypika/dialects/mssql.py b/pypika/dialects/mssql.py index 610c918..dc316be 100644 --- a/pypika/dialects/mssql.py +++ b/pypika/dialects/mssql.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from pypika.enums import Dialects from pypika.exceptions import QueryException from pypika.queries import Query, QueryBuilder +from pypika.terms import ValueWrapper from pypika.utils import builder @@ -42,25 +43,29 @@ def top(self, value: str | int) -> MSSQLQueryBuilder: # type:ignore[return] @builder def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return] # Overridden to provide a more domain-specific API for T-SQL users - self._limit = limit + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) - def _offset_sql(self) -> str: + def _offset_sql(self, **kwargs) -> str: order_by = "" if not self._orderbys: - order_by = "ORDER BY (SELECT 0)" - return order_by + " OFFSET {offset} ROWS".format(offset=self._offset or 0) + order_by = " ORDER BY (SELECT 0)" + return order_by + " OFFSET {offset} ROWS".format( + offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0 + ) - def _limit_sql(self) -> str: - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) - def _apply_pagination(self, querystring: str) -> str: + def _apply_pagination(self, querystring: str, **kwargs) -> str: # Note: Overridden as MSSQL specifies offset before the fetch next limit if self._limit is not None or self._offset: # Offset has to be present if fetch next is specified in a MSSQL query - querystring += self._offset_sql() + querystring += self._offset_sql(**kwargs) if self._limit is not None: - querystring += self._limit_sql() + querystring += self._limit_sql(**kwargs) return querystring diff --git a/pypika/dialects/oracle.py b/pypika/dialects/oracle.py index 5ab3867..174fa2c 100644 --- a/pypika/dialects/oracle.py +++ b/pypika/dialects/oracle.py @@ -28,8 +28,12 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: kwargs["groupby_alias"] = False return super().get_sql(*args, **kwargs) - def _offset_sql(self) -> str: - return " OFFSET {offset} ROWS".format(offset=self._offset) - - def _limit_sql(self) -> str: - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) + def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" + return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs)) + + def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) diff --git a/pypika/queries.py b/pypika/queries.py index f64f7a1..5419afb 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -535,8 +535,8 @@ def __init__( self._set_operation = [(set_operation, set_operation_query)] self._orderbys: list[tuple[Field, Order | None]] = [] - self._limit: int | None = None - self._offset: int | None = None + self._limit: ValueWrapper | None = None + self._offset: ValueWrapper | None = None self._wrapper_cls = wrapper_cls @@ -553,11 +553,11 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "Self": # type:ignore[retur @builder def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = limit + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) @builder def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = offset + self._offset = cast(ValueWrapper, self.wrap_constant(offset)) @builder def union(self, other: Selectable) -> "Self": # type:ignore[return] @@ -624,11 +624,8 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An if self._orderbys: querystring += self._orderby_sql(**kwargs) - if self._limit is not None: - querystring += self._limit_sql() - - if self._offset: - querystring += self._offset_sql() + querystring += self._limit_sql(**kwargs) + querystring += self._offset_sql(**kwargs) if subquery: querystring = "({query})".format(query=querystring, **kwargs) @@ -668,11 +665,15 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " ORDER BY {orderby}".format(orderby=",".join(clauses)) - def _offset_sql(self) -> str: - return " OFFSET {offset}".format(offset=self._offset) + def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" + return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) - def _limit_sql(self) -> str: - return " LIMIT {limit}".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" + return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) class QueryBuilder(Selectable, Term): # type:ignore[misc] @@ -725,8 +726,8 @@ def __init__( self._joins: list[Join] = [] self._unions: list = [] - self._limit: int | None = None - self._offset: int | None = None + self._limit: ValueWrapper | None = None + self._offset: ValueWrapper | None = None self._updates: list[tuple] = [] @@ -1223,11 +1224,11 @@ def hash_join(self, item: Table | "QueryBuilder" | AliasedQuery) -> "Joiner": @builder def limit(self, limit: int) -> "Self": # type:ignore[return] - self._limit = limit + self._limit = cast(ValueWrapper, self.wrap_constant(limit)) @builder def offset(self, offset: int) -> "Self": # type:ignore[return] - self._offset = offset + self._offset = cast(ValueWrapper, self.wrap_constant(offset)) @builder def union(self, other: Self) -> _SetOperation: @@ -1252,7 +1253,8 @@ def minus(self, other: Self) -> _SetOperation: @builder def set(self, field: Field | str, value: Any) -> "Self": # type:ignore[return] field = Field(field) if not isinstance(field, Field) else field - self._updates.append((field, self._wrapper_cls(value))) + value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls) + self._updates.append((field, value)) def __add__(self, other: Self) -> _SetOperation: # type:ignore[override] return self.union(other) @@ -1265,8 +1267,10 @@ def __sub__(self, other: Self) -> _SetOperation: # type:ignore[override] @builder def slice(self, slice: slice) -> "Self": # type:ignore[return] - self._offset = slice.start - self._limit = slice.stop + if slice.start is not None: + self._offset = cast(ValueWrapper, self.wrap_constant(slice.start)) + if slice.stop is not None: + self._limit = cast(ValueWrapper, self.wrap_constant(slice.stop)) def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override] if not isinstance(item, slice): @@ -1512,7 +1516,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An if self._orderbys: querystring += self._orderby_sql(**kwargs) - querystring = self._apply_pagination(querystring) + querystring = self._apply_pagination(querystring, **kwargs) if self._for_update: querystring += self._for_update_sql(**kwargs) @@ -1532,13 +1536,9 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring - def _apply_pagination(self, querystring: str) -> str: - if self._limit is not None: - querystring += self._limit_sql() - - if self._offset: - querystring += self._offset_sql() - + def _apply_pagination(self, querystring: str, **kwargs) -> str: + querystring += self._limit_sql(**kwargs) + querystring += self._offset_sql(**kwargs) return querystring def _with_sql(self, **kwargs: Any) -> str: @@ -1750,11 +1750,15 @@ def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: having = self._havings.get_sql(quote_char=quote_char, **kwargs) # type:ignore[union-attr] return f" HAVING {having}" - def _offset_sql(self) -> str: - return " OFFSET {offset}".format(offset=self._offset) + def _offset_sql(self, **kwargs) -> str: + if self._offset is None: + return "" + return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) - def _limit_sql(self) -> str: - return " LIMIT {limit}".format(limit=self._limit) + def _limit_sql(self, **kwargs) -> str: + if self._limit is None: + return "" + return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) def _set_sql(self, **kwargs: Any) -> str: return " SET {set}".format( diff --git a/pypika/terms.py b/pypika/terms.py index c0eccb0..f6b12df 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -5,7 +5,7 @@ import re import sys import uuid -from datetime import date, datetime, time +from datetime import date, time from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, Type, TypeVar, cast @@ -316,114 +316,78 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() -def idx_placeholder_gen(idx: int) -> str: - return str(idx + 1) - - -def named_placeholder_gen(idx: int) -> str: - return f"param{idx + 1}" - - class Parameter(Term): - is_aggregate = None - - def __init__(self, placeholder: str | int) -> None: - super().__init__() - self._placeholder = placeholder - - @property - def placeholder(self) -> str | int: - return self._placeholder - - def get_sql(self, **kwargs: Any) -> str: - return str(self.placeholder) - - def update_parameters(self, param_key: Any, param_value: Any, **kwargs) -> None: - pass - - def get_param_key(self, placeholder: Any, **kwargs) -> str | int: - return placeholder - - -class ListParameter(Parameter): - def __init__(self, placeholder: str | int | Callable[[int], str] = idx_placeholder_gen) -> None: - super().__init__(placeholder=placeholder) # type:ignore[arg-type] - self._parameters: list = [] - - @property - def placeholder(self) -> str: - if callable(self._placeholder): - return self._placeholder(len(self._parameters)) - - return str(self._placeholder) - - def get_parameters(self, **kwargs) -> list: - return self._parameters - - def update_parameters(self, value: Any, **kwargs) -> None: # type:ignore[override] - self._parameters.append(value) - - -class DictParameter(Parameter): - def __init__( - self, placeholder: str | int | Callable[[int], str] = named_placeholder_gen - ) -> None: - super().__init__(placeholder=placeholder) # type:ignore[arg-type] - self._parameters: dict = {} - - @property - def placeholder(self) -> str: - if callable(self._placeholder): - return self._placeholder(len(self._parameters)) - - return str(self._placeholder) - - def get_parameters(self, **kwargs) -> dict: - return self._parameters - - def get_param_key(self, placeholder: Any, **kwargs) -> str: - return placeholder[1:] - - def update_parameters( # type:ignore[override] - self, param_key: Any, value: Any, **kwargs - ) -> None: - self._parameters[param_key] = value - - -class QmarkParameter(ListParameter): - def get_sql(self, **kwargs) -> str: - return "?" + """ + Represents a parameter in a query. The placeholder can be specified with the `placeholder` argument or + will be determined based on the dialect if not provided. + """ + IDX_PLACEHOLDERS = { + Dialects.ORACLE: lambda _: "?", + Dialects.MSSQL: lambda _: "?", + Dialects.MYSQL: lambda _: "%s", + Dialects.POSTGRESQL: lambda idx: f"${idx}", + Dialects.SQLITE: lambda _: "?", + } + DEFAULT_PLACEHOLDER = "?" + is_aggregate = None -class NumericParameter(ListParameter): - """Numeric, positional style, e.g. ...WHERE name=:1""" + def __init__(self, placeholder: str | None = None, idx: int | None = None) -> None: + if not placeholder and idx is None: + raise ValueError("Must provide either a placeholder or an idx") - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) + if idx is not None and idx < 1: + raise ValueError("idx must start at 1") + if placeholder and idx: + raise ValueError("Cannot provide both a placeholder and an idx") -class FormatParameter(ListParameter): - """ANSI C printf format codes, e.g. ...WHERE name=%s""" + self._placeholder = placeholder + self._idx = idx def get_sql(self, **kwargs: Any) -> str: - return "%s" + if self._placeholder: + return self._placeholder + dialect = kwargs.get("dialect", None) + return self.IDX_PLACEHOLDERS.get(dialect, lambda _: self.DEFAULT_PLACEHOLDER)(self._idx) -class NamedParameter(DictParameter): - """Named style, e.g. ...WHERE name=:name""" - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) +class Parameterizer: + """ + Parameterizer can be used to replace values with parameters in a query: +>>>>>>> 94fab2d (Replace ListParameter with Parameterizer) + + >>> parameterizer = Parameterizer() + >>> customers = Table("customers") + >>> sql = Query.from_(customers).select(customers.id)\ + ... .where(customers.lname == "Mustermann")\ + ... .get_sql(parameterizer=parameterizer, dialect=Dialects.SQLITE) + >>> sql, parameterizer.values + ('SELECT "id" FROM "customers" WHERE "lname"=?', ['Mustermann']) + + Parameterizer remembers the values it has seen and replaces them with parameters. The values can + be accessed via the `values` attribute. + """ + def __init__(self, placeholder_factory: Callable[[int], str] | None = None) -> None: + self.placeholder_factory = placeholder_factory + self.values: list = [] -class PyformatParameter(DictParameter): - """Python extended format codes, e.g. ...WHERE name=%(name)s""" + def should_parameterize(self, value: Any) -> bool: + if isinstance(value, Enum): + return False - def get_sql(self, **kwargs: Any) -> str: - return "%({placeholder})s".format(placeholder=self.placeholder) + if isinstance(value, str) and value == "*": + return False + return True - def get_param_key(self, placeholder: T, **kwargs) -> T: - return placeholder[2:-2] # type:ignore[index] + def create_param(self, value: Any) -> Parameter: + self.values.append(value) + if self.placeholder_factory: + return Parameter(self.placeholder_factory(len(self.values))) + else: + return Parameter(idx=len(self.values)) class Negative(Term): @@ -442,9 +406,23 @@ def get_sql(self, **kwargs: Any) -> str: class ValueWrapper(Term): is_aggregate = None - def __init__(self, value: Any, alias: str | None = None) -> None: + def __init__( + self, value: Any, alias: str | None = None, allow_parametrize: bool = True + ) -> None: + """ + A wrapper for a constant value such as a string or number. + + :param value: + The value to be wrapped. + :param alias: + An optional alias for the value. + :param allow_parametrize: + Whether the value should be replaced with a parameter in the query if parameterizer + is used. + """ super().__init__(alias) self.value = value + self.allow_parametrize = allow_parametrize def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @@ -475,48 +453,27 @@ def get_formatted_value(cls, value: Any, **kwargs) -> str: return "null" return str(value) - def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: - param_sql = parameter.get_sql(**kwargs) - param_key = parameter.get_param_key(placeholder=param_sql) - - return param_sql, param_key # type:ignore[return-value] - def get_sql( self, quote_char: str | None = None, secondary_quote_char: str = "'", - parameter: Parameter | None = None, + parameterizer: Parameterizer | None = None, **kwargs: Any, ) -> str: - if parameter is None: + if ( + parameterizer is None + or not parameterizer.should_parameterize(self.value) + or not self.allow_parametrize + ): sql = self.get_value_sql( quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs ) return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) - # Don't stringify number or date values when using a parameter - if isinstance(self.value, (int, float, date, time, datetime)): - value_sql = self.value - else: - value_sql = self.get_value_sql( - quote_char=quote_char, **kwargs - ) # type:ignore[assignment] - param_sql, param_key = self._get_param_data(parameter, **kwargs) - parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) - - return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) - - -class ParameterValueWrapper(ValueWrapper): - def __init__(self, parameter: Parameter, value: Any, alias: str | None = None) -> None: - super().__init__(value, alias) - self._parameter = parameter - - def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: - param_sql = self._parameter.get_sql(**kwargs) - param_key = self._parameter.get_param_key(placeholder=param_sql) - - return param_sql, param_key # type:ignore[return-value] + param = parameterizer.create_param(self.value) + return format_alias_sql( + param.get_sql(**kwargs), self.alias, quote_char=quote_char, **kwargs + ) class JSON(Term): @@ -1510,7 +1467,7 @@ def get_sql(self, **kwargs: Any) -> str: # FIXME escape function_sql = self.get_function_sql( - with_namespace=with_namespace, quote_char=quote_char, dialect=dialect + with_namespace=with_namespace, quote_char=quote_char, dialect=dialect, **kwargs ) if self.schema is not None: diff --git a/pyproject.toml b/pyproject.toml index 5f1bbfa..8887e9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pypika-tortoise" -version = "0.2.2" +version = "0.3.0" description = "Forked from pypika and streamline just for tortoise-orm" authors = ["long2ice "] license = "Apache-2.0" diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 920803a..e3bdc9f 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,17 +1,13 @@ import unittest from datetime import date -from pypika import ( - FormatParameter, - NamedParameter, - NumericParameter, - Parameter, - PyformatParameter, - QmarkParameter, - Query, - Tables, -) -from pypika.terms import ListParameter, ParameterValueWrapper +from pypika import Parameter, Query, Tables, ValueWrapper +from pypika.dialects.mssql import MSSQLQuery +from pypika.dialects.mysql import MySQLQuery +from pypika.dialects.postgresql import PostgreSQLQuery +from pypika.enums import Dialects +from pypika.functions import Upper +from pypika.terms import Case, Parameterizer class ParametrizedTests(unittest.TestCase): @@ -86,36 +82,43 @@ def test_join(self): ) def test_qmark_parameter(self): - self.assertEqual("?", QmarkParameter().get_sql()) + self.assertEqual("?", Parameter("?").get_sql()) - def test_numeric_parameter(self): - self.assertEqual(":14", NumericParameter("14").get_sql()) - self.assertEqual(":15", NumericParameter(15).get_sql()) + def test_oracle(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.ORACLE)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.ORACLE)) - def test_named_parameter(self): - self.assertEqual(":buz", NamedParameter("buz").get_sql()) + def test_mssql(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.MSSQL)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.MSSQL)) - def test_format_parameter(self): - self.assertEqual("%s", FormatParameter().get_sql()) + def test_mysql(self): + self.assertEqual("%s", Parameter(idx=1).get_sql(dialect=Dialects.MYSQL)) + self.assertEqual("%s", Parameter(idx=2).get_sql(dialect=Dialects.MYSQL)) - def test_pyformat_parameter(self): - self.assertEqual("%(buz)s", PyformatParameter("buz").get_sql()) + def test_postgres(self): + self.assertEqual("$1", Parameter(idx=1).get_sql(dialect=Dialects.POSTGRESQL)) + self.assertEqual("$2", Parameter(idx=2).get_sql(dialect=Dialects.POSTGRESQL)) + def test_sqlite(self): + self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.SQLITE)) + self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.SQLITE)) -class ParametrizedTestsWithValues(unittest.TestCase): + +class ParameterizerTests(unittest.TestCase): table_abc, table_efg = Tables("abc", "efg") def test_param_insert(self): q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") - parameter = QmarkParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) - self.assertEqual([1, 2.2, "foo"], parameter.get_parameters()) + self.assertEqual([1, 2.2, "foo"], parameterizer.values) - def test_param_select_join(self): + def test_select_join_in_mysql(self): q = ( - Query.from_(self.table_abc) + MySQLQuery.from_(self.table_abc) .select("*") .where(self.table_abc.category == "foobar") .join(self.table_efg) @@ -124,17 +127,17 @@ def test_param_select_join(self): .limit(10) ) - parameter = FormatParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.MYSQL) self.assertEqual( - 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + "SELECT * FROM `abc` JOIN `efg` ON `abc`.`id`=`efg`.`abc_id` WHERE `abc`.`category`=%s AND `efg`.`date`>=%s LIMIT %s", sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameter.get_parameters()) + self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) - def test_param_select_subquery(self): + def test_select_subquery_in_postgres(self): q = ( - Query.from_(self.table_abc) + PostgreSQLQuery.from_(self.table_abc) .select("*") .where(self.table_abc.category == "foobar") .where( @@ -147,15 +150,15 @@ def test_param_select_subquery(self): .limit(10) ) - parameter = ListParameter(placeholder=lambda idx: f"&{idx+1}") - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) self.assertEqual( - 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + 'SELECT * FROM "abc" WHERE "category"=$1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=$2) LIMIT $3', sql, ) - self.assertEqual(["foobar", date(2024, 2, 22)], parameter.get_parameters()) + self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) - def test_join(self): + def test_join_in_postgres(self): subquery = ( Query.from_(self.table_efg) .select(self.table_efg.fiz, self.table_efg.buz) @@ -163,52 +166,77 @@ def test_join(self): ) q = ( - Query.from_(self.table_abc) + PostgreSQLQuery.from_(self.table_abc) .join(subquery) .on(self.table_abc.bar == subquery.buz) .select(self.table_abc.foo, subquery.fiz) .where(self.table_abc.bar == "bar") ) - parameter = NamedParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) self.assertEqual( - 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' - ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=$1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=$2', sql, ) - self.assertEqual({"param1": "buz", "param2": "bar"}, parameter.get_parameters()) + self.assertEqual(["buz", "bar"], parameterizer.values) - def test_join_with_parameter_value_wrapper(self): - subquery = ( - Query.from_(self.table_efg) - .select(self.table_efg.fiz, self.table_efg.buz) - .where(self.table_efg.buz == ParameterValueWrapper(Parameter(":buz"), "buz")) + def test_function_parameter(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == Upper(ValueWrapper("foobar"))) ) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT * FROM "abc" WHERE "category"=UPPER(?)', sql) + + self.assertEqual(["foobar"], parameterizer.values) + def test_case_when_in_select(self): + q = Query.from_(self.table_abc).select( + Case().when(self.table_abc.category == "foobar", 1).else_(2) + ) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT CASE WHEN "category"=? THEN ? ELSE ? END FROM "abc"', sql) + self.assertEqual(["foobar", 1, 2], parameterizer.values) + + def test_case_when_in_where(self): q = ( Query.from_(self.table_abc) - .join(subquery) - .on(self.table_abc.bar == subquery.buz) - .select(self.table_abc.foo, subquery.fiz) - .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter("bar"), "bar")) + .select("*") + .where( + self.table_abc.category_int + > Case().when(self.table_abc.category == "foobar", 1).else_(2) + ) ) - - parameter = NamedParameter() - sql = q.get_sql(parameter=parameter) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) self.assertEqual( - 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' - ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', + 'SELECT * FROM "abc" WHERE "category_int">CASE WHEN "category"=? THEN ? ELSE ? END', sql, ) - self.assertEqual({":buz": "buz", "bar": "bar"}, parameter.get_parameters()) + self.assertEqual(["foobar", 1, 2], parameterizer.values) - def test_pyformat_parameter(self): - q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") + def test_limit_and_offest(self): + q = Query.from_(self.table_abc).select("*").limit(10).offset(5) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) + self.assertEqual('SELECT * FROM "abc" LIMIT ? OFFSET ?', sql) + self.assertEqual([10, 5], parameterizer.values) - parameter = PyformatParameter() - sql = q.get_sql(parameter=parameter) + def test_limit_and_offest_in_mssql(self): + q = MSSQLQuery.from_(self.table_abc).select("*").limit(10).offset(5) + parameterizer = Parameterizer() + sql = q.get_sql(parameterizer=parameterizer) self.assertEqual( - 'INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql + 'SELECT * FROM "abc" ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY', sql ) - self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) + self.assertEqual([5, 10], parameterizer.values) + + def test_placeholder_factory(self): + parameterizer = Parameterizer(placeholder_factory=lambda _: "%s") + param = parameterizer.create_param(1) + self.assertEqual("%s", param.get_sql()) diff --git a/tests/test_terms.py b/tests/test_terms.py index a5dea35..295e39d 100644 --- a/tests/test_terms.py +++ b/tests/test_terms.py @@ -1,7 +1,7 @@ from unittest import TestCase from pypika import Field, Query, Table -from pypika.terms import AtTimezone +from pypika.terms import AtTimezone, Parameterizer, ValueWrapper class FieldAliasTests(TestCase): @@ -49,3 +49,15 @@ def test_passes_kwargs_to_field_get_sql(self): 'FROM "customers" JOIN "accounts" ON "customers"."account_id"="accounts"."account_id"', query.get_sql(with_namespace=True), ) + + +class ValueWrapperTests(TestCase): + def test_allow_parametrize(self): + value = ValueWrapper("foo") + self.assertEqual("'foo'", value.get_sql()) + + value = ValueWrapper("foo") + self.assertEqual("?", value.get_sql(parameterizer=Parameterizer())) + + value = ValueWrapper("foo", allow_parametrize=False) + self.assertEqual("'foo'", value.get_sql(parameterizer=Parameterizer())) diff --git a/tests/test_updates.py b/tests/test_updates.py index 175a256..a722554 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -43,7 +43,7 @@ def test_update__table_schema(self): def test_update_with_none(self): q = Query.update("abc").set("foo", None) - self.assertEqual('UPDATE "abc" SET "foo"=null', str(q)) + self.assertEqual('UPDATE "abc" SET "foo"=NULL', str(q)) def test_update_from(self): from_table = Table("from_table")