diff --git a/src/stairlight/map.py b/src/stairlight/map.py index 115e062..6c9130b 100644 --- a/src/stairlight/map.py +++ b/src/stairlight/map.py @@ -292,7 +292,7 @@ def add_unmapped_params( """ if not params: template_str = template.get_template_str() - params = template.detect_jinja_params(template_str=template_str) + params = template.get_jinja_params(template_str=template_str) self.unmapped.append( { MapKey.TEMPLATE: template, @@ -314,7 +314,7 @@ def detect_unmapped_params( list[str]: Unmapped parameters """ template_str: str = template.get_template_str() - template_params: list[str] = template.detect_jinja_params(template_str) + template_params: list[str] = template.get_jinja_params(template_str) if not template_params: return [] diff --git a/src/stairlight/source/template.py b/src/stairlight/source/template.py index cf09980..bcf6da1 100644 --- a/src/stairlight/source/template.py +++ b/src/stairlight/source/template.py @@ -4,6 +4,7 @@ import re from abc import ABC, abstractmethod from logging import getLogger +from string import Template as StringTemplate from typing import Any, Iterator from jinja2 import BaseLoader, Environment @@ -125,8 +126,8 @@ def mapped(self) -> bool: return result @staticmethod - def detect_jinja_params(template_str: str) -> list: - """Detect jinja parameters from template string + def get_jinja_params(template_str: str) -> list: + """get jinja parameters from template string Args: template_str (str): Template string @@ -142,18 +143,14 @@ def detect_jinja_params(template_str: str) -> list: for param in re.findall("[^{}]+", jinja_expressions, re.IGNORECASE) ] - @staticmethod - def render_by_base_loader( - source_type: TemplateSourceType, - key: str, + def render_by_jinja( + self, template_str: str, params: dict[str, Any], ) -> str: - """Render query string from template string + """Render query string by jinja2 Args: - source_type (str): source type - key (str): key path template_str (str): template string params (dict[str, Any]): Jinja parameters @@ -163,42 +160,95 @@ def render_by_base_loader( Returns: str: rendered query string """ - result: str = "" - jinja_template = Environment(loader=BaseLoader()).from_string(template_str) + if not self.get_jinja_params(template_str=template_str): + return template_str + + rendered_str: str = template_str try: - result = jinja_template.render(params) + env = Environment(loader=BaseLoader()) + template = env.from_string(template_str) + rendered_str = template.render(params) except UndefinedError as undefined_error: logger.warning( ( f"{undefined_error.message}, " - f"source_type: {source_type}, " - f"key: {key}" + f"source_type: {self.source_type}, " + f"key: {self.key}" ) ) - result = template_str - return result + return rendered_str - @staticmethod - def ignore_params_from_template_str( - template_str: str, ignore_params: list[str] + def render_by_string_template( + self, + template_str: str, + params: dict[str, Any], ) -> str: - """ignore parameters from template string + """_summary_ + + Args: + template_str (str): template string + params (dict[str, Any]): mapping dict + + Returns: + str: rendered query string + """ + s = StringTemplate(template=template_str) + if not s.get_identifiers(): + return template_str + + try: + rendered_str = s.substitute(params) + except KeyError as e: + logger.warning( + ( + f"{e.message}, " + f"source_type: {self.source_type}, " + f"key: {self.key}" + ) + ) + rendered_str = s.safe_substitute(params) + return rendered_str + + @staticmethod + def ignore_jinja_params(template_str: str, ignore_params: list[str]) -> str: + """ignore jinja parameters Args: template_str (str): template string ignore_params (list[str]): ignore parameters Returns: - str: replaced template string + str: ignored template string """ if not ignore_params: - ignore_params = [] - replaced_str = template_str + return template_str + + ignored_str = template_str for ignore_param in ignore_params: - replaced_str = replaced_str.replace( + ignored_str = ignored_str.replace( "{{{{ {} }}}}".format(ignore_param), "ignored" ) - return replaced_str + return ignored_str + + @staticmethod + def ignore_string_template_params( + template_str: str, ignore_params: list[str] + ) -> str: + """ignore string.Template parameters + + Args: + template_str (str): template string + ignore_params (list[str]): ignore parameters + + Returns: + str: ignored template string + """ + if not ignore_params: + return template_str + + s = StringTemplate(template=template_str) + ignore_mapping = {k: "ignored" for k in ignore_params} + return s.safe_substitute(ignore_mapping) @abstractmethod def get_uri(self) -> str: @@ -218,23 +268,30 @@ def render(self, params: dict[str, Any], ignore_params: list[str] = None) -> str Returns: str: Query statement """ - if not ignore_params: - ignore_params = [] - template_str = self.get_template_str() - replaced_template_str = self.ignore_params_from_template_str( - template_str=template_str, + rendered_str = self.get_template_str() + rendered_str = self.ignore_jinja_params( + template_str=rendered_str, ignore_params=ignore_params, ) - if params: - results = self.render_by_base_loader( - source_type=self.source_type, - key=self.key, - template_str=replaced_template_str, - params=params, - ) - else: - results = replaced_template_str - return results + rendered_str = self.ignore_string_template_params( + template_str=rendered_str, + ignore_params=ignore_params, + ) + + if not params: + return rendered_str + + rendered_str = self.render_by_jinja( + template_str=rendered_str, + params=params, + ) + + rendered_str = self.render_by_string_template( + template_str=rendered_str, + params=params, + ) + + return rendered_str class RenderingTemplateException(Exception): diff --git a/tests/sql/cte_multi_line_identifiers.sql b/tests/sql/cte_multi_line_identifiers.sql new file mode 100644 index 0000000..1d71e1e --- /dev/null +++ b/tests/sql/cte_multi_line_identifiers.sql @@ -0,0 +1,38 @@ +WITH c AS ( + SELECT + test_id, + col_c + FROM + $sub_table_01 + WHERE + 0 = 0 +), +d AS ( + SELECT + test_id, + col_d + FROM + ${sub_table_02} + WHERE + 0 = 0 +), +e AS ( + SELECT + test_id, + col_d + FROM + $sub_table_02 + WHERE + 0 = 0 +) + +SELECT + * +FROM + ${main_table} AS b + INNER JOIN c + ON b.test_id = c.test_id + INNER JOIN d + ON b.test_id = d.test_id +WHERE + 1 = 1 diff --git a/tests/stairlight/source/file/test_template.py b/tests/stairlight/source/file/test_template.py index a7f75a9..cf17472 100644 --- a/tests/stairlight/source/file/test_template.py +++ b/tests/stairlight/source/file/test_template.py @@ -50,139 +50,235 @@ def file_template( def test_mapped(self, file_template: FileTemplate, expected_mapped: bool): assert file_template.mapped == expected_mapped - def test_detect_jinja_params(self, file_template: FileTemplate): + def test_get_jinja_params(self, file_template: FileTemplate): template_str = file_template.get_template_str() - assert len(file_template.detect_jinja_params(template_str)) > 0 + assert len(file_template.get_jinja_params(template_str)) > 0 -@pytest.mark.parametrize( - ("key", "params", "ignore_params", "expected_table", "detected_params"), - [ +class TestFileTemplateRender: + @pytest.mark.parametrize( ( - "tests/sql/cte_multi_line_params.sql", - { - "params": { + "key", + "params", + "ignore_params", + "expected_table", + ), + [ + ( + "tests/sql/cte_multi_line_params.sql", + { + "params": { + "main_table": "PROJECT_P.DATASET_Q.TABLE_R", + "sub_table_01": "PROJECT_S.DATASET_T.TABLE_U", + "sub_table_02": "PROJECT_V.DATASET_W.TABLE_X", + } + }, + [], + "PROJECT_P.DATASET_Q.TABLE_R", + ), + ( + "tests/sql/cte_multi_line.sql", + { + "params": { + "PROJECT": "PROJECT_g", + "DATASET": "DATASET_h", + "TABLE": "TABLE_i", + } + }, + [ + "execution_date.add(days=1).isoformat()", + "execution_date.add(days=2).isoformat()", + ], + "PROJECT_g.DATASET_h.TABLE_i", + ), + ( + "tests/sql/params_with_default_value.sql", + { + "params": { + "main_table": "PROJECT_P.DATASET_Q.TABLE_R", + "sub_table_01": "PROJECT_S.DATASET_T.TABLE_U", + "sub_table_02": "PROJECT_V.DATASET_W.TABLE_X", + } + }, + [ + "params.target_column | default('\"top\"')", + 'params.target_column or "top"', + ], + "PROJECT_P.DATASET_Q.TABLE_R", + ), + ( + "tests/sql/nested_join.sql", + None, + [], + "PROJECT_B.DATASET_B.TABLE_B", + ), + ( + "tests/sql/cte_multi_line_identifiers.sql", + { "main_table": "PROJECT_P.DATASET_Q.TABLE_R", "sub_table_01": "PROJECT_S.DATASET_T.TABLE_U", "sub_table_02": "PROJECT_V.DATASET_W.TABLE_X", - } - }, - [], - "PROJECT_P.DATASET_Q.TABLE_R", - [ - "params.sub_table_01", - "params.sub_table_02", - "params.main_table", - ], - ), - ( + }, + [], + "PROJECT_P.DATASET_Q.TABLE_R", + ), + ], + ids=[ + "tests/sql/cte_multi_line_params.sql", "tests/sql/cte_multi_line.sql", - { - "params": { - "PROJECT": "PROJECT_g", - "DATASET": "DATASET_h", - "TABLE": "TABLE_i", - } - }, - [ - "execution_date.add(days=1).isoformat()", - "execution_date.add(days=2).isoformat()", - ], - "PROJECT_g.DATASET_h.TABLE_i", - [ - "execution_date.add(days=1).isoformat()", - "execution_date.add(days=2).isoformat()", - "params.PROJECT", - "params.DATASET", - "params.TABLE", - ], - ), - ( "tests/sql/params_with_default_value.sql", - { - "params": { - "main_table": "PROJECT_P.DATASET_Q.TABLE_R", - "sub_table_01": "PROJECT_S.DATASET_T.TABLE_U", - "sub_table_02": "PROJECT_V.DATASET_W.TABLE_X", - } - }, - [ - "params.target_column | default('\"top\"')", - 'params.target_column or "top"', - ], - "PROJECT_P.DATASET_Q.TABLE_R", - [ - "params.sub_table_01", - "params.sub_table_02", - "params.main_table", - "params.target_column | default('\"top\"')", - 'params.target_column_2 or "top"', - 'params.target_column_2 or "latest"', - ], - ), - ( "tests/sql/nested_join.sql", - None, - [], - "PROJECT_B.DATASET_B.TABLE_B", - [], - ), - ], - ids=[ - "tests/sql/cte_multi_line_params.sql", - "tests/sql/cte_multi_line.sql", - "tests/sql/params_with_default_value.sql", - "tests/sql/nested_join.sql", - ], -) -class TestFileTemplateRender: - @pytest.fixture(scope="function") - def file_template( - self, - mapping_config: MappingConfig, - key: str, - ) -> FileTemplate: - return FileTemplate( - mapping_config=mapping_config, - key=key, - ) - + "tests/sql/cte_multi_line_identifiers.sql", + ], + ) def test_render( self, - file_template: FileTemplate, + mapping_config: MappingConfig, + key, params, ignore_params, expected_table, - detected_params, ): + file_template = FileTemplate( + mapping_config=mapping_config, + key=key, + ) actual = file_template.render( params=params, ignore_params=ignore_params, ) assert expected_table in actual - def test_detect_jinja_params( + @pytest.mark.parametrize( + ("key", "detected_params"), + [ + ( + "tests/sql/cte_multi_line_params.sql", + [ + "params.sub_table_01", + "params.sub_table_02", + "params.main_table", + ], + ), + ( + "tests/sql/cte_multi_line.sql", + [ + "execution_date.add(days=1).isoformat()", + "execution_date.add(days=2).isoformat()", + "params.PROJECT", + "params.DATASET", + "params.TABLE", + ], + ), + ( + "tests/sql/params_with_default_value.sql", + [ + "params.sub_table_01", + "params.sub_table_02", + "params.main_table", + "params.target_column | default('\"top\"')", + 'params.target_column_2 or "top"', + 'params.target_column_2 or "latest"', + ], + ), + ( + "tests/sql/nested_join.sql", + [], + ), + ( + "tests/sql/cte_multi_line_identifiers.sql", + [], + ), + ], + ids=[ + "tests/sql/cte_multi_line_params.sql", + "tests/sql/cte_multi_line.sql", + "tests/sql/params_with_default_value.sql", + "tests/sql/nested_join.sql", + "tests/sql/cte_multi_line_identifiers.sql", + ], + ) + def test_get_jinja_params( self, - file_template: FileTemplate, - params, - ignore_params, - expected_table, + mapping_config, + key, detected_params, ): + file_template = FileTemplate( + mapping_config=mapping_config, + key=key, + ) template_str = file_template.get_template_str() - actual = file_template.detect_jinja_params(template_str=template_str) + actual = file_template.get_jinja_params(template_str=template_str) assert actual == detected_params - def test_ignore_params_from_template_str( + @pytest.mark.parametrize( + ("key", "ignore_params"), + [ + ( + "tests/sql/cte_multi_line.sql", + [ + "execution_date.add(days=1).isoformat()", + "execution_date.add(days=2).isoformat()", + ], + ), + ( + "tests/sql/params_with_default_value.sql", + [ + "params.target_column | default('\"top\"')", + 'params.target_column or "top"', + ], + ), + ( + "tests/sql/nested_join.sql", + [], + ), + ], + ids=[ + "tests/sql/cte_multi_line.sql", + "tests/sql/params_with_default_value.sql", + "tests/sql/nested_join.sql", + ], + ) + def test_ignore_jinja_params( self, - file_template: FileTemplate, - params, + mapping_config, + key, + ignore_params, + ): + file_template = FileTemplate( + mapping_config=mapping_config, + key=key, + ) + template_str = file_template.get_template_str() + actual = file_template.ignore_jinja_params( + template_str=template_str, + ignore_params=ignore_params, + ) + assert all(ignore_param not in actual for ignore_param in ignore_params) + + @pytest.mark.parametrize( + ("key", "ignore_params"), + [ + ( + "tests/sql/cte_multi_line_identifiers.sql", + ["main_table", "sub_table_02"], + ), + ], + ids=["tests/sql/cte_multi_line_identifiers.sql"], + ) + def test_ignore_string_template_params( + self, + mapping_config, + key, ignore_params, - expected_table, - detected_params, ): + file_template = FileTemplate( + mapping_config=mapping_config, + key=key, + ) template_str = file_template.get_template_str() - actual = file_template.ignore_params_from_template_str( + actual = file_template.ignore_string_template_params( template_str=template_str, ignore_params=ignore_params, ) diff --git a/tests/stairlight/source/gcs/test_template.py b/tests/stairlight/source/gcs/test_template.py index 7dfc11b..c1dbd62 100644 --- a/tests/stairlight/source/gcs/test_template.py +++ b/tests/stairlight/source/gcs/test_template.py @@ -84,7 +84,7 @@ def test_get_uri( ): assert gcs_template.uri == f"{GCS_URI_SCHEME}{bucket}/{key}" - def test_detect_jinja_params( + def test_get_jinja_params( self, mocker, gcs_template: GcsTemplate, local_file_path: str ): with open(local_file_path, "r") as test_file: @@ -94,12 +94,12 @@ def test_detect_jinja_params( return_value=test_file_str.encode(), ) template_str = gcs_template.get_template_str() - assert len(gcs_template.detect_jinja_params(template_str)) > 0 + assert len(gcs_template.get_jinja_params(template_str)) > 0 @pytest.mark.integration - def test_detect_jinja_params_integration(self, gcs_template: GcsTemplate): + def test_get_jinja_params_integration(self, gcs_template: GcsTemplate): template_str = gcs_template.get_template_str() - assert len(gcs_template.detect_jinja_params(template_str)) > 0 + assert len(gcs_template.get_jinja_params(template_str)) > 0 def test_render( self, diff --git a/tests/stairlight/source/s3/test_template.py b/tests/stairlight/source/s3/test_template.py index 8627d36..21918ae 100644 --- a/tests/stairlight/source/s3/test_template.py +++ b/tests/stairlight/source/s3/test_template.py @@ -75,7 +75,7 @@ def test_get_uri( assert s3_template.uri == f"{S3_URI_SCHEME}{bucket}/{key}" @mock_s3 - def test_detect_jinja_params( + def test_get_jinja_params( self, s3_template: S3Template, key: str, @@ -85,15 +85,15 @@ def test_detect_jinja_params( s3_bucket = s3_client.Bucket(BUCKET_NAME) s3_bucket.upload_file("tests/sql/gcs/cte/cte_multi_line.sql", key) template_str = s3_template.get_template_str() - assert len(s3_template.detect_jinja_params(template_str)) > 0 + assert len(s3_template.get_jinja_params(template_str)) > 0 @pytest.mark.integration - def test_detect_jinja_params_integration( + def test_get_jinja_params_integration( self, s3_template: S3Template, ): template_str = s3_template.get_template_str() - assert len(s3_template.detect_jinja_params(template_str)) > 0 + assert len(s3_template.get_jinja_params(template_str)) > 0 @mock_s3 def test_render(