diff --git a/src/stairlight/map.py b/src/stairlight/map.py index c267026..115e062 100644 --- a/src/stairlight/map.py +++ b/src/stairlight/map.py @@ -103,21 +103,17 @@ def write_by_template_source(self, template_source: TemplateSource) -> None: template_source (TemplateSource): Template source """ for template in template_source.search_templates(): - # Check if the template is in the mapping_config - if not self._mapping_config: - self.add_unmapped_params(template=template) - elif template.is_mapped(): - for table_attributes in template.find_mapped_table_attributes(): - unmapped_params = self.detect_unmapped_params( - template=template, table_attributes=table_attributes - ) - if unmapped_params: - self.add_unmapped_params( - template=template, params=unmapped_params - ) - self.remap(template=template, table_attributes=table_attributes) - else: + if not self._mapping_config or not template.mapped: self.add_unmapped_params(template=template) + continue + + for table_attributes in template.find_mapped_table_attributes(): + unmapped_params = self.detect_unmapped_params( + template=template, table_attributes=table_attributes + ) + if unmapped_params: + self.add_unmapped_params(template=template, params=unmapped_params) + self.remap(template=template, table_attributes=table_attributes) def remap( self, template: Template, table_attributes: MappingConfigMappingTable diff --git a/src/stairlight/source/template.py b/src/stairlight/source/template.py index 50c9095..cf09980 100644 --- a/src/stairlight/source/template.py +++ b/src/stairlight/source/template.py @@ -15,6 +15,8 @@ StairlightConfig, ) +logger = getLogger(__name__) + class TemplateSourceType(enum.Enum): """Query template source type""" @@ -109,7 +111,8 @@ def find_mapped_table_attributes(self) -> Iterator[MappingConfigMappingTable]: yield table_attributes break - def is_mapped(self) -> bool: + @property + def mapped(self) -> bool: """Check if the template is set to mapping configuration Returns: @@ -160,17 +163,20 @@ def render_by_base_loader( Returns: str: rendered query string """ + result: str = "" + jinja_template = Environment(loader=BaseLoader()).from_string(template_str) try: - jinja_template = Environment(loader=BaseLoader()).from_string(template_str) - return jinja_template.render(params) + result = jinja_template.render(params) except UndefinedError as undefined_error: - raise RenderingTemplateException( + logger.warning( ( f"{undefined_error.message}, " f"source_type: {source_type}, " f"key: {key}" ) - ) from None + ) + result = template_str + return result @staticmethod def ignore_params_from_template_str( diff --git a/tests/stairlight/source/file/test_template.py b/tests/stairlight/source/file/test_template.py index 238f011..b88c6c0 100644 --- a/tests/stairlight/source/file/test_template.py +++ b/tests/stairlight/source/file/test_template.py @@ -21,7 +21,7 @@ @pytest.mark.parametrize( - ("key", "expected_is_mapped"), + ("key", "expected_mapped"), [ ("tests/sql/cte_multi_line_params.sql", True), ("tests/sql/cte_multi_line_params_copy.sql", True), @@ -41,15 +41,15 @@ def file_template( self, mapping_config: MappingConfig, key: str, - expected_is_mapped: bool, + expected_mapped: bool, ): return FileTemplate( mapping_config=mapping_config, key=key, ) - def test_is_mapped(self, file_template: FileTemplate, expected_is_mapped: bool): - assert file_template.is_mapped() == expected_is_mapped + def test_mapped(self, file_template: FileTemplate, expected_mapped: bool): + assert file_template.mapped == expected_mapped def test_detect_jinja_params(self, file_template: FileTemplate): template_str = file_template.get_template_str() @@ -206,7 +206,7 @@ def test_ignore_params_from_template_str( ], ids=["tests/sql/cte_multi_line.sql"], ) -class TestFileTemplateRenderException: +class TestFileTemplateUndefinedError: @pytest.fixture(scope="function") def file_template( self, @@ -223,14 +223,10 @@ def test_render( file_template: FileTemplate, key: str, params: dict[str, Any], + caplog, ): - with pytest.raises(RenderingTemplateException) as exception: - _ = file_template.render(params=params) - assert exception.value.args[0] == ( - f"'execution_date' is undefined, " - f"source_type: {file_template.source_type}, " - f"key: {key}" - ) + _ = file_template.render(params=params) + assert "undefined" in caplog.text @pytest.mark.parametrize( diff --git a/tests/stairlight/source/gcs/test_template.py b/tests/stairlight/source/gcs/test_template.py index d09c88e..7dfc11b 100644 --- a/tests/stairlight/source/gcs/test_template.py +++ b/tests/stairlight/source/gcs/test_template.py @@ -73,8 +73,8 @@ def gcs_template( ) -> GcsTemplate: return GcsTemplate(mapping_config=mapping_config, bucket=bucket, key=key) - def test_is_mapped(self, gcs_template: GcsTemplate): - assert gcs_template.is_mapped() + def test_mapped(self, gcs_template: GcsTemplate): + assert gcs_template.mapped def test_get_uri( self, diff --git a/tests/stairlight/source/s3/test_template.py b/tests/stairlight/source/s3/test_template.py index 375fdbb..8627d36 100644 --- a/tests/stairlight/source/s3/test_template.py +++ b/tests/stairlight/source/s3/test_template.py @@ -60,11 +60,11 @@ def s3_template( key=key, ) - def test_is_mapped( + def test_mapped( self, s3_template: S3Template, ): - assert s3_template.is_mapped() + assert s3_template.mapped def test_get_uri( self,