From d9c74748817f756b9b2131db9752ed2e469f6451 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 18 Nov 2024 10:17:44 +0100 Subject: [PATCH] Add placeholder_factory to Parameterizer --- pypika/queries.py | 3 ++- pypika/terms.py | 9 +++++++-- tests/test_parameter.py | 5 +++++ tests/test_updates.py | 2 +- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 79258d1..29e2e04 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1228,7 +1228,8 @@ def minus(self, other: "QueryBuilder") -> _SetOperation: @builder def set(self, field: Union[Field, str], value: Any) -> "QueryBuilder": field = Field(field) if not isinstance(field, Field) else field - self._updates.append((field, self._wrapper_cls(value))) + value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls) + self._updates.append((field, value)) def __add__(self, other: "QueryBuilder") -> _SetOperation: return self.union(other) diff --git a/pypika/terms.py b/pypika/terms.py index f8b9d49..9dfb2ef 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -7,6 +7,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Iterable, Iterator, List, @@ -333,7 +334,8 @@ class Parameterizer: be accessed via the `values` attribute. """ - def __init__(self) -> None: + def __init__(self, placeholder_factory: Optional[Callable[[int], str]] = None) -> None: + self.placeholder_factory = placeholder_factory self.values = [] def should_parameterize(self, value: Any) -> bool: @@ -347,7 +349,10 @@ def should_parameterize(self, value: Any) -> bool: def create_param(self, value: Any) -> Parameter: self.values.append(value) - return Parameter(idx=len(self.values)) + if self.placeholder_factory: + return Parameter(self.placeholder_factory(len(self.values))) + else: + return Parameter(idx=len(self.values)) class Negative(Term): diff --git a/tests/test_parameter.py b/tests/test_parameter.py index e832a90..e3bdc9f 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -235,3 +235,8 @@ def test_limit_and_offest_in_mssql(self): 'SELECT * FROM "abc" ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY', sql ) self.assertEqual([5, 10], parameterizer.values) + + def test_placeholder_factory(self): + parameterizer = Parameterizer(placeholder_factory=lambda _: "%s") + param = parameterizer.create_param(1) + self.assertEqual("%s", param.get_sql()) diff --git a/tests/test_updates.py b/tests/test_updates.py index 175a256..a722554 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -43,7 +43,7 @@ def test_update__table_schema(self): def test_update_with_none(self): q = Query.update("abc").set("foo", None) - self.assertEqual('UPDATE "abc" SET "foo"=null', str(q)) + self.assertEqual('UPDATE "abc" SET "foo"=NULL', str(q)) def test_update_from(self): from_table = Table("from_table")