diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py index 608baac..3ef4cca 100644 --- a/target_snowflake/connector.py +++ b/target_snowflake/connector.py @@ -1,17 +1,13 @@ -import os from operator import contains, eq -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, cast -from urllib.parse import urlparse -from uuid import uuid4 +from typing import Sequence, Tuple, cast import snowflake.sqlalchemy.custom_types as sct import sqlalchemy from singer_sdk import typing as th -from singer_sdk.batch import lazy_chunked_generator from singer_sdk.connectors import SQLConnector -from singer_sdk.helpers._batch import BaseBatchFileEncoding, BatchConfig -from singer_sdk.helpers._typing import conform_record_data_types from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer +from snowflake.sqlalchemy.snowdialect import SnowflakeDialect from sqlalchemy.engine import Engine from sqlalchemy.sql import text @@ -93,6 +89,40 @@ def create_engine(self) -> Engine: echo=False, ) + def prepare_column( + self, + full_table_name: str, + column_name: str, + sql_type: sqlalchemy.types.TypeEngine, + ) -> None: + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + # Make quoted column names upper case because we create them that way + # and the metadata that SQLAlchemy returns is case insensitive only for non-quoted + # column names so these will look like they don't exist yet. + if '"' in formatter.format_collation(column_name): + column_name = column_name.upper() + + super().prepare_column( + full_table_name, + column_name, + sql_type, + ) + + @staticmethod + def get_column_rename_ddl( + table_name: str, + column_name: str, + new_column_name: str, + ) -> sqlalchemy.DDL: + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + # Since we build the ddl manually we can't rely on SQLAlchemy to + # quote column names automatically. 
+ return SQLConnector.get_column_rename_ddl( + table_name, + formatter.format_collation(column_name), + formatter.format_collation(new_column_name), + ) + @staticmethod def get_column_alter_ddl( table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine @@ -109,11 +139,14 @@ def get_column_alter_ddl( Returns: A sqlalchemy DDL instance. """ + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + # Since we build the ddl manually we can't rely on SQLAlchemy to + # quote column names automatically. return sqlalchemy.DDL( "ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s SET DATA TYPE %(column_type)s", { "table_name": table_name, - "column_name": column_name, + "column_name": formatter.format_collation(column_name), "column_type": column_type, }, ) @@ -171,29 +204,35 @@ def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]: """Get Snowflake PUT statement.""" return (text(f"put '{file_uri}' '@~/target-snowflake/{sync_id}'"), {}) + def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPreparer) -> list: + column_selections = [] + for property_name, property_def in schema["properties"].items(): + clean_property_name = formatter.format_collation(property_name) + column_selections.append( + f"$1:{property_name}::{self.to_sql_type(property_def)} as {clean_property_name}" + ) + return column_selections + def _get_merge_from_stage_statement( self, full_table_name, schema, sync_id, file_format, key_properties ): """Get Snowflake MERGE statement.""" - # convert from case in JSON to UPPER column name - column_selections = [ - f"$1:{property_name}::{self.to_sql_type(property_def)} as {property_name.upper()}" - for property_name, property_def in schema["properties"].items() - ] + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + column_selections = self._get_column_selections(schema, formatter) # use UPPER from here onwards - upper_properties = [col.upper() for col in 
schema["properties"].keys()] - upper_key_properties = [col.upper() for col in key_properties] + formatted_properties = [formatter.format_collation(col) for col in schema["properties"].keys()] + formatted_key_properties = [formatter.format_collation(col) for col in key_properties] join_expr = " and ".join( - [f'd."{key}" = s."{key}"' for key in upper_key_properties] + [f'd.{key} = s.{key}' for key in formatted_key_properties] ) matched_clause = ", ".join( - [f'd."{col}" = s."{col}"' for col in upper_properties] + [f'd.{col} = s.{col}' for col in formatted_properties] ) - not_matched_insert_cols = ", ".join(upper_properties) + not_matched_insert_cols = ", ".join(formatted_properties) not_matched_insert_values = ", ".join( - [f's."{col}"' for col in upper_properties] + [f's.{col}' for col in formatted_properties] ) return ( text( @@ -210,11 +249,8 @@ def _get_merge_from_stage_statement( def _get_copy_statement(self, full_table_name, schema, sync_id, file_format): """Get Snowflake COPY statement.""" - # convert from case in JSON to UPPER column name - column_selections = [ - f"$1:{property_name}::{self.to_sql_type(property_def)} as {property_name.upper()}" - for property_name, property_def in schema["properties"].items() - ] + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + column_selections = self._get_column_selections(schema, formatter) return ( text( f"copy into {full_table_name} from " diff --git a/target_snowflake/tests/__init__.py b/target_snowflake/tests/__init__.py deleted file mode 100644 index 7b52a26..0000000 --- a/target_snowflake/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for target-snowflake.""" diff --git a/tests/target_test_streams/reserved_words.singer b/tests/target_test_streams/reserved_words.singer new file mode 100644 index 0000000..99b436a --- /dev/null +++ b/tests/target_test_streams/reserved_words.singer @@ -0,0 +1,6 @@ +{"type": "SCHEMA", "stream": "reserved_words", "schema": {"properties": {"NON_RESERVED": 
{"type": ["string", "null"]}, "SELECT": {"type": ["string", "null"]}, "INSERT": {"type": ["string", "null"]}, "UPDATE": {"type": ["string", "null"]}}, "type": "object"}, "key_properties": ["NON_RESERVED"]} +{"type": "RECORD", "stream": "reserved_words", "record": {"NON_RESERVED": "sample_1", "SELECT": "test1", "INSERT": "test2", "UPDATE": "test3"}, "time_extracted": "2023-06-09T17:15:57.857785+00:00"} +{"type": "STATE", "value": {"bookmarks": {"reserved_words": {}}}} +{"type": "SCHEMA", "stream": "reserved_words", "schema": {"properties": {"NON_RESERVED": {"type": ["string", "null"]}, "SELECT": {"type": ["string", "null"]}, "INSERT": {"type": ["string", "null"]}, "UPDATE": {"type": ["string", "null"]}, "ACCOUNT": {"type": ["string", "null"]}}, "type": "object"}, "key_properties": ["NON_RESERVED"]} +{"type": "RECORD", "stream": "reserved_words", "record": {"NON_RESERVED": "sample_2", "SELECT": "test10", "INSERT": "test20", "UPDATE": "test30", "ACCOUNT": "test40"}, "time_extracted": "2023-06-09T17:16:44.036424+00:00"} +{"type": "STATE", "value": {"bookmarks": {"reserved_words": {}}}} diff --git a/tests/target_test_streams/reserved_words_no_key_props.singer b/tests/target_test_streams/reserved_words_no_key_props.singer new file mode 100644 index 0000000..2c80fdb --- /dev/null +++ b/tests/target_test_streams/reserved_words_no_key_props.singer @@ -0,0 +1,3 @@ +{"type": "SCHEMA", "stream": "reserved_words_no_key_props", "schema": {"properties": {"NON_RESERVED": {"type": ["string", "null"]}, "SELECT": {"type": ["string", "null"]}, "INSERT": {"type": ["string", "null"]}, "UPDATE": {"type": ["string", "null"]}}, "type": "object"}, "key_properties": []} +{"type": "RECORD", "stream": "reserved_words_no_key_props", "record": {"NON_RESERVED": "sample_1", "SELECT": "test10", "INSERT": "test20", "UPDATE": "test30"}, "time_extracted": "2023-06-09T18:18:04.605072+00:00"} +{"type": "STATE", "value": {"bookmarks": {"reserved_words_no_key_props": {}}}} diff --git a/tests/test_impl.py 
b/tests/test_impl.py index 70a6809..28825ef 100644 --- a/tests/test_impl.py +++ b/tests/test_impl.py @@ -1,23 +1,24 @@ +from pathlib import Path + import pytest import snowflake.sqlalchemy.custom_types as sct import sqlalchemy from singer_sdk.testing.suites import TestSuite -from singer_sdk.testing.target_tests import ( - TargetArrayData, - TargetCamelcaseComplexSchema, - TargetCamelcaseTest, - TargetCliPrintsTest, - TargetDuplicateRecords, - TargetEncodedStringData, - TargetInvalidSchemaTest, - TargetNoPrimaryKeys, - TargetOptionalAttributes, - TargetRecordBeforeSchemaTest, - TargetRecordMissingKeyProperty, - TargetSchemaNoProperties, - TargetSchemaUpdates, - TargetSpecialCharsInAttributes, -) +from singer_sdk.testing.target_tests import (TargetArrayData, + TargetCamelcaseComplexSchema, + TargetCamelcaseTest, + TargetCliPrintsTest, + TargetDuplicateRecords, + TargetEncodedStringData, + TargetInvalidSchemaTest, + TargetNoPrimaryKeys, + TargetOptionalAttributes, + TargetRecordBeforeSchemaTest, + TargetRecordMissingKeyProperty, + TargetSchemaNoProperties, + TargetSchemaUpdates, + TargetSpecialCharsInAttributes) +from singer_sdk.testing.templates import TargetFileTestTemplate class SnowflakeTargetArrayData(TargetArrayData): @@ -232,6 +233,48 @@ def validate(self) -> None: isinstance(column.type, expected_types[column.name]) +class SnowflakeTargetReservedWords(TargetFileTestTemplate): + + # Contains reserved words from https://docs.snowflake.com/en/sql-reference/reserved-keywords + # Syncs records then alters schema by adding a non-reserved word column. 
+ name = "reserved_words" + + @property + def singer_filepath(self) -> Path: + current_dir = Path(__file__).resolve().parent + return current_dir / "target_test_streams" / f"{self.name}.singer" + + def validate(self) -> None: + connector = self.target.default_sink_class.connector_class(self.target.config) + table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper() + result = connector.connection.execute( + f"select * from {table}", + ) + assert result.rowcount == 2 + row = result.first() + assert len(row) == 11 + +class SnowflakeTargetReservedWordsNoKeyProps(TargetFileTestTemplate): + + # Contains reserved words from https://docs.snowflake.com/en/sql-reference/reserved-keywords + # TODO: Syncs records then alters schema by adding a non-reserved word column. + name = "reserved_words_no_key_props" + + @property + def singer_filepath(self) -> Path: + current_dir = Path(__file__).resolve().parent + return current_dir / "target_test_streams" / f"{self.name}.singer" + + def validate(self) -> None: + connector = self.target.default_sink_class.connector_class(self.target.config) + table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper() + result = connector.connection.execute( + f"select * from {table}", + ) + assert result.rowcount == 1 + row = result.first() + assert len(row) == 10 + target_tests = TestSuite( kind="target", tests=[ @@ -252,5 +295,7 @@ def validate(self) -> None: SnowflakeTargetSchemaNoProperties, SnowflakeTargetSchemaUpdates, TargetSpecialCharsInAttributes, # Implicitly asserts that special chars are handled + SnowflakeTargetReservedWords, + SnowflakeTargetReservedWordsNoKeyProps, ], )