Skip to content

Commit

Permalink
refactor: Improved SQL identifier (de)normalization (#2601)
Browse files Browse the repository at this point in the history
* refactor: WIP Improved SQL identifier (de)normalization

* Add tests

* Do not use db dialect to generate stream name

* Fix types
  • Loading branch information
edgarrmondragon authored Aug 13, 2024
1 parent e1812b6 commit c3c2351
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 54 deletions.
170 changes: 121 additions & 49 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import typing as t
import warnings
from collections import UserString
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
Expand All @@ -22,6 +23,89 @@
from sqlalchemy.engine.reflection import Inspector


class FullyQualifiedName(UserString):
    """A fully qualified table name.

    This class provides a simple way to represent a fully qualified table name
    as a single object. The string representation of this object is the fully
    qualified table name, with each part passed through the dialect's
    identifier preparer and the parts joined by ``delimiter`` (a period by
    default).

    The parts of the fully qualified table name are:

    - database
    - schema
    - table

    The database and schema are optional. If only the table name is provided,
    the string representation of the object will be the (prepared) table name
    alone. The original, unprepared parts stay available as the ``database``,
    ``schema`` and ``table`` attributes.

    All constructor arguments are keyword-only and ``dialect`` is required.

    Example:
        ```
        from sqlalchemy.dialects import sqlite

        table_name = FullyQualifiedName(
            table="my_table",
            schema="my_schema",
            database="my_db",
            dialect=sqlite.dialect(),
        )
        print(table_name)  # my_db.my_schema.my_table
        ```
    """

    def __init__(
        self,
        *,
        table: str = "",
        schema: str | None = None,
        database: str | None = None,
        delimiter: str = ".",
        dialect: sa.engine.Dialect,
    ) -> None:
        """Initialize the fully qualified table name.

        Args:
            table: The name of the table.
            schema: The name of the schema. Defaults to None.
            database: The name of the database. Defaults to None.
            delimiter: The delimiter to use between parts. Defaults to '.'.
            dialect: The SQLAlchemy dialect to use for quoting.

        Raises:
            ValueError: If the fully qualified name could not be generated,
                i.e. none of database, schema or table is a non-empty value.
        """
        # Keep the raw parts so callers can read them back without parsing
        # the rendered string.
        self.table = table
        self.schema = schema
        self.database = database
        self.delimiter = delimiter
        self.dialect = dialect

        # Most-general to most-specific order; empty/None parts are skipped.
        parts = [
            self.prepare_part(part)
            for part in (self.database, self.schema, self.table)
            if part
        ]

        if not parts:
            raise ValueError(
                "Could not generate fully qualified name: "
                + ":".join(
                    [
                        self.database or "(unknown-db)",
                        self.schema or "(unknown-schema)",
                        self.table or "(unknown-table-name)",
                    ],
                ),
            )

        # UserString stores the rendered name in ``self.data``.
        super().__init__(self.delimiter.join(parts))

    def prepare_part(self, part: str) -> str:
        """Prepare a part of the fully qualified name.

        The part is handed to the dialect's identifier preparer ``quote``
        method and the result is returned unchanged.

        Args:
            part: The part to prepare.

        Returns:
            The prepared part.
        """
        return self.dialect.identifier_preparer.quote(part)


class SQLConnector: # noqa: PLR0904
"""Base class for SQLAlchemy-based connectors.
Expand Down Expand Up @@ -238,13 +322,13 @@ def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
"""
return th.to_sql_type(jsonschema_type)

@staticmethod
def get_fully_qualified_name(
self,
table_name: str | None = None,
schema_name: str | None = None,
db_name: str | None = None,
delimiter: str = ".",
) -> str:
) -> FullyQualifiedName:
"""Concatenates a fully qualified name from the parts.
Args:
Expand All @@ -253,34 +337,16 @@ def get_fully_qualified_name(
db_name: The name of the database. Defaults to None.
delimiter: Generally: '.' for SQL names and '-' for Singer names.
Raises:
ValueError: If all 3 name parts not supplied.
Returns:
The fully qualified name as a string.
"""
parts = []

if db_name:
parts.append(db_name)
if schema_name:
parts.append(schema_name)
if table_name:
parts.append(table_name)

if not parts:
raise ValueError(
"Could not generate fully qualified name: "
+ ":".join(
[
db_name or "(unknown-db)",
schema_name or "(unknown-schema)",
table_name or "(unknown-table-name)",
],
),
)

return delimiter.join(parts)
return FullyQualifiedName(
table=table_name, # type: ignore[arg-type]
schema=schema_name,
database=db_name,
delimiter=delimiter,
dialect=self._dialect,
)

@property
def _dialect(self) -> sa.engine.Dialect:
Expand Down Expand Up @@ -429,12 +495,7 @@ def discover_catalog_entry(
`CatalogEntry` object for the given table or a view
"""
# Initialize unique stream name
unique_stream_id = self.get_fully_qualified_name(
db_name=None,
schema_name=schema_name,
table_name=table_name,
delimiter="-",
)
unique_stream_id = f"{schema_name}-{table_name}"

# Detect key properties
possible_primary_keys: list[list[str]] = []
Expand Down Expand Up @@ -528,7 +589,7 @@ def discover_catalog_entries(self) -> list[dict]:

def parse_full_table_name( # noqa: PLR6301
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
) -> tuple[str | None, str | None, str]:
"""Parse a fully qualified table name into its parts.
Expand All @@ -547,6 +608,13 @@ def parse_full_table_name( # noqa: PLR6301
A three part tuple (db_name, schema_name, table_name) with any unspecified
or unused parts returned as None.
"""
if isinstance(full_table_name, FullyQualifiedName):
return (
full_table_name.database,
full_table_name.schema,
full_table_name.table,
)

db_name: str | None = None
schema_name: str | None = None

Expand All @@ -560,7 +628,7 @@ def parse_full_table_name( # noqa: PLR6301

return db_name, schema_name, table_name

def table_exists(self, full_table_name: str) -> bool:
def table_exists(self, full_table_name: str | FullyQualifiedName) -> bool:
"""Determine if the target table already exists.
Args:
Expand All @@ -587,7 +655,7 @@ def schema_exists(self, schema_name: str) -> bool:

def get_table_columns(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_names: list[str] | None = None,
) -> dict[str, sa.Column]:
"""Return a list of table columns.
Expand Down Expand Up @@ -618,7 +686,7 @@ def get_table_columns(

def get_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_names: list[str] | None = None,
) -> sa.Table:
"""Return a table object.
Expand All @@ -643,7 +711,9 @@ def get_table(
schema=schema_name,
)

def column_exists(self, full_table_name: str, column_name: str) -> bool:
def column_exists(
self, full_table_name: str | FullyQualifiedName, column_name: str
) -> bool:
"""Determine if the target table already exists.
Args:
Expand All @@ -666,7 +736,7 @@ def create_schema(self, schema_name: str) -> None:

def create_empty_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
primary_keys: t.Sequence[str] | None = None,
partition_keys: list[str] | None = None,
Expand Down Expand Up @@ -715,7 +785,7 @@ def create_empty_table(

def _create_empty_column(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand Down Expand Up @@ -753,7 +823,7 @@ def prepare_schema(self, schema_name: str) -> None:

def prepare_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
primary_keys: t.Sequence[str],
partition_keys: list[str] | None = None,
Expand Down Expand Up @@ -797,7 +867,7 @@ def prepare_table(

def prepare_column(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand All @@ -822,7 +892,9 @@ def prepare_column(
sql_type=sql_type,
)

def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> None:
def rename_column(
self, full_table_name: str | FullyQualifiedName, old_name: str, new_name: str
) -> None:
"""Rename the provided columns.
Args:
Expand Down Expand Up @@ -951,7 +1023,7 @@ def _get_type_sort_key(

def _get_column_type(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
) -> sa.types.TypeEngine:
"""Get the SQL type of the declared column.
Expand All @@ -976,7 +1048,7 @@ def _get_column_type(

def get_column_add_ddl(
self,
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand Down Expand Up @@ -1009,7 +1081,7 @@ def get_column_add_ddl(

@staticmethod
def get_column_rename_ddl(
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
new_column_name: str,
) -> sa.DDL:
Expand Down Expand Up @@ -1037,7 +1109,7 @@ def get_column_rename_ddl(

@staticmethod
def get_column_alter_ddl(
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand Down Expand Up @@ -1096,7 +1168,7 @@ def update_collation(

def _adapt_column_type(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand Down Expand Up @@ -1187,7 +1259,7 @@ def deserialize_json(self, json_str: str) -> object: # noqa: PLR6301
def delete_old_versions(
self,
*,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
version_column_name: str,
current_version: int,
) -> None:
Expand Down
9 changes: 5 additions & 4 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if t.TYPE_CHECKING:
from sqlalchemy.sql import Executable

from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.target_base import Target

_C = t.TypeVar("_C", bound=SQLConnector)
Expand Down Expand Up @@ -109,7 +110,7 @@ def database_name(self) -> str | None:
# Assumes single-DB target context.

@property
def full_table_name(self) -> str:
def full_table_name(self) -> FullyQualifiedName:
"""Return the fully qualified table name.
Returns:
Expand All @@ -122,7 +123,7 @@ def full_table_name(self) -> str:
)

@property
def full_schema_name(self) -> str:
def full_schema_name(self) -> FullyQualifiedName:
"""Return the fully qualified schema name.
Returns:
Expand Down Expand Up @@ -269,7 +270,7 @@ def process_batch(self, context: dict) -> None:

def generate_insert_statement(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
) -> str | Executable:
"""Generate an insert statement for the given records.
Expand Down Expand Up @@ -297,7 +298,7 @@ def generate_insert_statement(

def bulk_insert_records(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
records: t.Iterable[dict[str, t.Any]],
) -> int | None:
Expand Down
3 changes: 2 additions & 1 deletion singer_sdk/streams/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from singer_sdk.streams.core import REPLICATION_INCREMENTAL, Stream

if t.TYPE_CHECKING:
from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.helpers.types import Context
from singer_sdk.tap_base import Tap

Expand Down Expand Up @@ -124,7 +125,7 @@ def primary_keys(self, new_value: t.Sequence[str]) -> None:
self._singer_catalog_entry.metadata.root.table_key_properties = new_value

@property
def fully_qualified_name(self) -> str:
def fully_qualified_name(self) -> FullyQualifiedName:
"""Generate the fully qualified version of the table name.
Raises:
Expand Down
Loading

0 comments on commit c3c2351

Please sign in to comment.