From bbabb4872caa9e48a718d6375deb4d2e1d125d53 Mon Sep 17 00:00:00 2001 From: Chris Riccomini Date: Tue, 13 Jun 2023 13:32:38 -0700 Subject: [PATCH] Replace SQLAlchemyReader with `dbapi` readers After some discussion in #257, I decided to replace SQLAlchemy with `dbapi` readers. The SQLAlchemy types weren't mapping as nicely as I'd hoped to Recap's. Now, DB types are mapped straight to Recap types. This PR adds a `PostgresqlReader` and `SnowflakeReader`. The PG reader has a very basic test, but the Snowflake reader is untested. I plan to use `fakesnow` for Snowflake in the future. --- .github/workflows/ci.yaml | 15 +++ pdm.lock | 2 +- pyproject.toml | 1 - recap/readers/dbapi.py | 107 +++++++++++++++++++ recap/readers/postgresql.py | 61 +++++++++++ recap/readers/snowflake.py | 80 +++++++++++++++ recap/readers/sqlalchemy.py | 89 ---------------- recap/types.py | 11 +- tests/readers/test_postgresql.py | 152 +++++++++++++++++++++++++++ tests/readers/test_sqlalchemy.py | 170 ------------------------------- 10 files changed, 421 insertions(+), 267 deletions(-) create mode 100644 recap/readers/dbapi.py create mode 100644 recap/readers/postgresql.py create mode 100644 recap/readers/snowflake.py delete mode 100644 recap/readers/sqlalchemy.py create mode 100644 tests/readers/test_postgresql.py delete mode 100644 tests/readers/test_sqlalchemy.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e4eb20fd..9bbcb80d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -11,6 +11,21 @@ jobs: test: runs-on: ubuntu-latest + services: + postgres: + image: postgres:14 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_DB: testdb + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: matrix: python-version: ['3.10'] diff --git a/pdm.lock b/pdm.lock index 157960e4..03882f41 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1146,7 +1146,7 @@ dependencies = [ lock_version = "4.2" cross_platform = true groups = ["default", "dbs", "docs", "fss", "gcp", "kafka", "style", "tests"] -content_hash = "sha256:77822620ea9fca5f566b074de05030fc4de7c75a3baf8dba6dda581ef4b1577a" +content_hash = "sha256:45c38d1aa35fca1edaca58d7f6628097c24ccd5cd28b19815bce4c5849b9d28f" [metadata.files] "aiobotocore 2.5.0" = [ diff --git a/pyproject.toml b/pyproject.toml index f4776063..c5ca51d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ dependencies = [ "uvicorn[standard]>=0.20.0", "httpx>=0.23.1", "typer>=0.7.0", - "sqlalchemy>=1.4.45", "rich>=12.6.0", "setuptools>=65.6.3", "starlette>=0.22.0", diff --git a/recap/readers/dbapi.py b/recap/readers/dbapi.py new file mode 100644 index 00000000..269953c4 --- /dev/null +++ b/recap/readers/dbapi.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from typing import Any, List, Protocol, Tuple + +from recap.types import NullType, RecapType, StructType, UnionType + + +class DbapiReader(ABC): + def __init__(self, connection: Connection) -> None: + self.connection = connection + + def struct(self, table: str, schema: str, catalog: str) -> StructType: + cursor = self.connection.cursor() + cursor.execute( + f""" + SELECT + * + FROM information_schema.columns + WHERE table_name = {self.param_style} + AND table_schema = {self.param_style} + AND table_catalog = {self.param_style} + ORDER BY ordinal_position ASC + """, + (table, schema, catalog), + ) + + names = [name[0].upper() for name in cursor.description] + fields = [] + for row in cursor.fetchall(): + column_props = dict(zip(names, row)) + base_type = self.get_recap_type(column_props) + is_nullable = column_props["IS_NULLABLE"].upper() == "YES" + + if is_nullable: + base_type = UnionType([NullType(), base_type]) + + if column_props["COLUMN_DEFAULT"] is not None or is_nullable: + base_type.extra_attrs["default"] = column_props["COLUMN_DEFAULT"] + + base_type.extra_attrs["name"] = column_props["COLUMN_NAME"] + + fields.append(base_type) + + return StructType(fields=fields) + + @property + def param_style(cls) -> str: + return "%s" + + @abstractmethod + def get_recap_type(self, column_props: dict[str, Any]) -> RecapType: + ... + + def _get_time_unit(self, params: list[str] | None) -> str | None: + match params: + case [unit, _] if int(unit) == 0: + return "second" + case [unit, _] if int(unit) <= 3: + return "millisecond" + case [unit, _] if int(unit) <= 6: + return "microsecond" + + def _parse_parameters(self, col_type: str) -> list[str] | None: + """ + Parse types that have parameters. + """ + match = re.search(r"\((.*?)\)", col_type) + if match: + return [p.strip() for p in match.group(1).split(",")] + return None + + +class Connection(Protocol): + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self) -> None: + ... + + def cursor(self) -> Cursor: + ... + + +class Cursor(Protocol): + def execute(self, query: str, parameters: Tuple = ()) -> None: + ... + + def executemany(self, query: str, parameter_list: List[Tuple]) -> None: + ... + + def fetchone(self) -> Tuple: + ... + + def fetchall(self) -> List[Tuple]: + ... + + def fetchmany(self, size: int) -> List[Tuple]: + ... + + @property + def description(self) -> List[Tuple]: + ... diff --git a/recap/readers/postgresql.py b/recap/readers/postgresql.py new file mode 100644 index 00000000..d9c28859 --- /dev/null +++ b/recap/readers/postgresql.py @@ -0,0 +1,61 @@ +from math import ceil +from typing import Any + +from recap.readers.dbapi import DbapiReader +from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType + +MAX_FIELD_SIZE = 1073741824 + + +class PostgresqlReader(DbapiReader): + def get_recap_type(self, column_props: dict[str, Any]) -> RecapType: + data_type = column_props["DATA_TYPE"].lower() + octet_length = column_props["CHARACTER_OCTET_LENGTH"] + max_length = column_props["CHARACTER_MAXIMUM_LENGTH"] + + if data_type in ["bigint", "int8", "bigserial", "serial8"]: + base_type = IntType(bits=64, signed=True) + elif data_type in ["integer", "int", "int4", "serial", "serial4"]: + base_type = IntType(bits=32, signed=True) + elif data_type in ["smallint", "smallserial", "serial2"]: + base_type = IntType(bits=16, signed=True) + elif data_type in ["double precision", "float8"]: + base_type = FloatType(bits=64) + elif data_type in ["real", "float4"]: + base_type = FloatType(bits=32) + elif data_type == "boolean": + base_type = BoolType() + elif ( + data_type in ["text", "json", "jsonb"] + or data_type.startswith("character varying") + or data_type.startswith("varchar") + ): + base_type = StringType(bytes_=octet_length, variable=True) + elif data_type.startswith("char"): + print(column_props) + base_type = StringType(bytes_=octet_length, variable=False) + elif data_type == "bytea" or data_type.startswith("bit varying"): + base_type = BytesType(bytes_=MAX_FIELD_SIZE, variable=True) + elif data_type.startswith("bit"): + byte_length = ceil(max_length / 8) + base_type = BytesType(bytes_=byte_length, variable=False) + elif data_type.startswith("timestamp"): + dt_precision = column_props["DATETIME_PRECISION"] + unit = self._get_time_unit([dt_precision]) or "microsecond" + base_type = IntType( + bits=64, + logical="build.recap.Timestamp", + unit=unit, + ) + elif data_type.startswith("decimal") or data_type.startswith("numeric"): + base_type = BytesType( + logical="build.recap.Decimal", + bytes_=32, + variable=False, + precision=column_props["NUMERIC_PRECISION"], + scale=column_props["NUMERIC_SCALE"], + ) + else: + raise ValueError(f"Unknown data type: {data_type}") + + return base_type diff --git a/recap/readers/snowflake.py b/recap/readers/snowflake.py new file mode 100644 index 00000000..932d3e1e --- /dev/null +++ b/recap/readers/snowflake.py @@ -0,0 +1,80 @@ +from typing import Any + +from recap.readers.dbapi import DbapiReader +from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType + + +class SnowflakeReader(DbapiReader): + def get_recap_type(self, column_props: dict[str, Any]) -> RecapType: + data_type = column_props["DATA_TYPE"].lower() + octet_length = column_props["CHARACTER_OCTET_LENGTH"] + + if data_type in [ + "float", + "float4", + "float8", + "double", + "double precision", + "real", + ]: + base_type = FloatType(bits=64) + elif data_type == "boolean": + base_type = BoolType() + elif data_type in [ + "number", + "decimal", + "numeric", + "int", + "integer", + "bigint", + "smallint", + "tinyint", + "byteint", + ] or ( + data_type.startswith("number") + or data_type.startswith("decimal") + or data_type.startswith("numeric") + ): + base_type = BytesType( + logical="build.recap.Decimal", + bytes_=32, + variable=False, + precision=column_props["NUMERIC_PRECISION"] or 38, + scale=column_props["NUMERIC_SCALE"] or 0, + ) + elif ( + data_type.startswith("varchar") + or data_type.startswith("string") + or data_type.startswith("text") + or data_type.startswith("nvarchar") + or data_type.startswith("nvarchar2") + or data_type.startswith("char varying") + or data_type.startswith("nchar varying") + ): + base_type = StringType(bytes_=octet_length, variable=True) + elif ( + data_type.startswith("char") + or data_type.startswith("nchar") + or data_type.startswith("character") + ): + base_type = StringType(bytes_=octet_length, variable=False) + elif data_type in ["binary", "varbinary"]: + base_type = BytesType(bytes_=octet_length) + elif data_type == "date": + base_type = IntType(bits=32, logical="build.recap.Date", unit="day") + elif data_type.startswith("timestamp") or data_type.startswith("datetime"): + params = self._parse_parameters(data_type) + unit = self._get_time_unit(params) or "nanosecond" + base_type = IntType( + bits=64, + logical="build.recap.Timestamp", + unit=unit, + ) + elif data_type.startswith("time"): + params = self._parse_parameters(data_type) + unit = self._get_time_unit(params) or "nanosecond" + base_type = IntType(bits=32, logical="build.recap.Time", unit=unit) + else: + raise ValueError(f"Unknown data type: {data_type}") + + return base_type diff --git a/recap/readers/sqlalchemy.py b/recap/readers/sqlalchemy.py deleted file mode 100644 index a9ce2b13..00000000 --- a/recap/readers/sqlalchemy.py +++ /dev/null @@ -1,89 +0,0 @@ -from sqlalchemy import create_engine, engine, inspect, types - -from recap.types import BoolType, BytesType, FloatType, IntType, StringType, StructType - - -class SqlAlchemyReader: - def __init__(self, connection: str | engine.Engine): - self.engine = ( - connection - if isinstance(connection, engine.Engine) - else create_engine(connection) - ) - self.inspector = inspect(self.engine) - - def struct(self, table: str) -> StructType: - columns = self.inspector.get_columns(table) - fields = [] - for column in columns: - field = None - match column["type"]: - case types.INTEGER(): - field = IntType(bits=32, signed=True, name=column["name"]) - case types.BIGINT(): - field = IntType(bits=64, signed=True, name=column["name"]) - case types.SMALLINT(): - field = IntType(bits=16, signed=True, name=column["name"]) - case types.FLOAT(): - field = FloatType(bits=64, name=column["name"]) - case types.REAL(): - field = FloatType(bits=32, name=column["name"]) - case types.BOOLEAN(): - field = BoolType(name=column["name"]) - case types.VARCHAR() | types.TEXT() | types.NVARCHAR() | types.CLOB(): - field = StringType( - bytes_=column["type"].length, name=column["name"] - ) - case types.CHAR() | types.NCHAR(): - field = StringType( - bytes_=column["type"].length, - variable=False, - name=column["name"], - ) - case types.DATE(): - field = IntType( - logical="build.recap.Date", - bits=32, - signed=True, - unit="day", - name=column["name"], - ) - case types.TIME(): - field = IntType( - logical="build.recap.Time", - bits=32, - signed=True, - unit="microsecond", - name=column["name"], - ) - case types.DATETIME() | types.TIMESTAMP(): - field = IntType( - logical="build.recap.Timestamp", - bits=64, - signed=True, - unit="microsecond", - timezone="UTC", - name=column["name"], - ) - case types.BINARY() | types.VARBINARY() | types.BLOB(): - is_variable = not isinstance(column["type"], types.BINARY) - field = BytesType( - bytes_=column["type"].length, - variable=is_variable, - name=column["name"], - ) - case types.DECIMAL() | types.NUMERIC(): - field = BytesType( - logical="build.recap.Decimal", - bytes_=16, - variable=False, - precision=column["type"].precision, - scale=column["type"].scale, - name=column["name"], - ) - case _: - raise TypeError(f"Unsupported type {column['type']}") - - fields.append(field) - - return StructType(fields=fields, name=table) diff --git a/recap/types.py b/recap/types.py index 060dfdc2..2d5401a2 100644 --- a/recap/types.py +++ b/recap/types.py @@ -46,6 +46,11 @@ def __eq__(self, other): ) == (other.type_, other.logical, other.alias, other.doc, other.extra_attrs) return False + def __repr__(self): + attrs = vars(self) + attrs_str = ", ".join(f"{k}={v}" for k, v in attrs.items()) + return f"{self.__class__.__name__}({attrs_str})" + class NullType(RecapType): """Represents a null Recap type.""" @@ -247,7 +252,6 @@ def __eq__(self, other): "duration64": IntType( logical="build.recap.Duration", bits=64, - signed=True, unit="millisecond", ), "interval128": BytesType( @@ -259,32 +263,27 @@ def __eq__(self, other): "time32": IntType( logical="build.recap.Time", bits=32, - signed=True, unit="second", ), "time64": IntType( logical="build.recap.Time", bits=64, - signed=True, unit="second", ), "timestamp64": IntType( logical="build.recap.Timestamp", bits=64, - signed=True, unit="millisecond", timezone="UTC", ), "date32": IntType( logical="build.recap.Date", bits=32, - signed=True, unit="day", ), "date64": IntType( logical="build.recap.Date", bits=64, - signed=True, unit="day", ), } diff --git a/tests/readers/test_postgresql.py b/tests/readers/test_postgresql.py new file mode 100644 index 00000000..ac0156ab --- /dev/null +++ b/tests/readers/test_postgresql.py @@ -0,0 +1,152 @@ +import os + +import psycopg2 +import pytest + +from recap.readers.postgresql import MAX_FIELD_SIZE, PostgresqlReader +from recap.types import ( + BoolType, + BytesType, + FloatType, + IntType, + NullType, + StringType, + StructType, + UnionType, +) + + +@pytest.mark.skipif( + "CI" not in os.environ, reason="Skipping PostgreSQL tests outside CI" +) +class TestPostgresqlReader: + @classmethod + def setup_class(cls): + # Connect to the PostgreSQL database + cls.connection = psycopg2.connect( + host="localhost", + port="5432", + user="postgres", + password="password", + dbname="testdb", + ) + + # Create tables + cursor = cls.connection.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS test_types ( + test_bigint BIGINT, + test_integer INTEGER, + test_smallint SMALLINT, + test_float DOUBLE PRECISION, + test_real REAL, + test_boolean BOOLEAN, + test_text TEXT, + test_char CHAR(10), + test_bytea BYTEA, + test_bit BIT(10), + test_timestamp TIMESTAMP, + test_decimal DECIMAL(10,2) + ); + """ + ) + cls.connection.commit() + + @classmethod + def teardown_class(cls): + # Delete tables + cursor = cls.connection.cursor() + cursor.execute("DROP TABLE IF EXISTS test_types;") + cls.connection.commit() + + # Close the connection + cls.connection.close() + + def test_struct_method(self): + # Initiate the PostgresqlReader class + reader = PostgresqlReader(self.connection) # type: ignore + + # Test 'test_types' table + test_types_struct = reader.struct("test_types", "public", "testdb") + + # Define the expected output for 'test_types' table + expected_fields = [ + UnionType( + default=None, + name="test_bigint", + types=[NullType(), IntType(bits=64, signed=True)], + ), + UnionType( + default=None, + name="test_integer", + types=[NullType(), IntType(bits=32, signed=True)], + ), + UnionType( + default=None, + name="test_smallint", + types=[NullType(), IntType(bits=16, signed=True)], + ), + UnionType( + default=None, + name="test_float", + types=[NullType(), FloatType(bits=64)], + ), + UnionType( + default=None, + name="test_real", + types=[NullType(), FloatType(bits=32)], + ), + UnionType( + default=None, + name="test_boolean", + types=[NullType(), BoolType()], + ), + UnionType( + default=None, + name="test_text", + types=[NullType(), StringType(bytes_=MAX_FIELD_SIZE, variable=True)], + ), + UnionType( + default=None, + name="test_char", + # 40 = max of 4 bytes in a UTF-8 encoded unicode character * 10 chars + types=[NullType(), StringType(bytes_=40, variable=False)], + ), + UnionType( + default=None, + name="test_bytea", + types=[NullType(), BytesType(bytes_=MAX_FIELD_SIZE, variable=True)], + ), + UnionType( + default=None, + name="test_bit", + types=[NullType(), BytesType(bytes_=2, variable=False)], + ), + UnionType( + default=None, + name="test_timestamp", + types=[ + NullType(), + IntType(bits=64, logical="build.recap.Timestamp", unit="microsecond"), + ], + ), + UnionType( + default=None, + name="test_decimal", + types=[ + NullType(), + BytesType( + logical="build.recap.Decimal", + bytes_=32, + variable=False, + precision=10, + scale=2, + ), + ], + ), + ] + + print(test_types_struct) + print(StructType(fields=expected_fields)) # type: ignore + assert test_types_struct == StructType(fields=expected_fields) # type: ignore diff --git a/tests/readers/test_sqlalchemy.py b/tests/readers/test_sqlalchemy.py deleted file mode 100644 index 493e6c5b..00000000 --- a/tests/readers/test_sqlalchemy.py +++ /dev/null @@ -1,170 +0,0 @@ -# pylint: disable=missing-docstring - -import pytest -from sqlalchemy import ( - CHAR, - DECIMAL, - NCHAR, - NVARCHAR, - TIMESTAMP, - VARCHAR, - BigInteger, - Boolean, - Column, - Date, - DateTime, - Float, - Integer, - LargeBinary, - MetaData, - SmallInteger, - Table, - Text, - Time, - create_engine, -) - -from recap.readers.sqlalchemy import SqlAlchemyReader -from recap.types import BoolType, BytesType, FloatType, IntType, StringType - -metadata = MetaData() - -test_table = Table( - "test_table", - metadata, - Column("column_int", Integer), - Column("column_bigint", BigInteger), - Column("column_smallint", SmallInteger), - Column("column_float", Float), - Column("column_real", Float), - Column("column_boolean", Boolean), - Column("column_varchar", VARCHAR), - Column("column_text", Text), - Column("column_nvarchar", NVARCHAR), - Column("column_char", CHAR), - Column("column_nchar", NCHAR), - Column("column_date", Date), - Column("column_time", Time), - Column("column_datetime", DateTime), - Column("column_timestamp", TIMESTAMP), - Column("column_binary", LargeBinary), - Column("column_decimal", DECIMAL), -) - - -@pytest.fixture(scope="module") -def engine(): - engine = create_engine("sqlite:///:memory:") - metadata.create_all(engine) - return engine - - -def test_sqlalchemy_reader(engine): - reader = SqlAlchemyReader(engine) - struct = reader.struct("test_table") - - # Now we validate the returned StructType - assert struct.extra_attrs["name"] == "test_table" - assert len(struct.fields) == 17 - - int_type_field = struct.fields[0] - assert isinstance(int_type_field, IntType) - assert int_type_field.extra_attrs["name"] == "column_int" - assert int_type_field.bits == 32 - assert int_type_field.signed is True - - bigint_type_field = struct.fields[1] - assert isinstance(bigint_type_field, IntType) - assert bigint_type_field.extra_attrs["name"] == "column_bigint" - assert bigint_type_field.bits == 64 - assert bigint_type_field.signed is True - - smallint_type_field = struct.fields[2] - assert isinstance(smallint_type_field, IntType) - assert smallint_type_field.extra_attrs["name"] == "column_smallint" - assert smallint_type_field.bits == 16 - assert smallint_type_field.signed is True - - float_type_field = struct.fields[3] - assert isinstance(float_type_field, FloatType) - assert float_type_field.extra_attrs["name"] == "column_float" - assert float_type_field.bits == 64 - - # SQLite doesn't have a REAL type, so it's mapped to FLOAT - float_type_field = struct.fields[4] - assert isinstance(float_type_field, FloatType) - assert float_type_field.extra_attrs["name"] == "column_real" - assert float_type_field.bits == 64 - - float_type_field = struct.fields[5] - assert isinstance(float_type_field, BoolType) - assert float_type_field.extra_attrs["name"] == "column_boolean" - - varchar_type_field = struct.fields[6] - assert isinstance(varchar_type_field, StringType) - assert varchar_type_field.extra_attrs["name"] == "column_varchar" - assert varchar_type_field.bytes_ is None - - text_type_field = struct.fields[7] - assert isinstance(text_type_field, StringType) - assert text_type_field.extra_attrs["name"] == "column_text" - assert text_type_field.bytes_ is None - - nvarchar_type_field = struct.fields[8] - assert isinstance(nvarchar_type_field, StringType) - assert nvarchar_type_field.extra_attrs["name"] == "column_nvarchar" - assert nvarchar_type_field.bytes_ is None - - char_type_field = struct.fields[9] - assert isinstance(char_type_field, StringType) - assert char_type_field.extra_attrs["name"] == "column_char" - assert char_type_field.bytes_ is None - - nchar_type_field = struct.fields[10] - assert isinstance(nchar_type_field, StringType) - assert nchar_type_field.extra_attrs["name"] == "column_nchar" - assert nchar_type_field.bytes_ is None - - date_type_field = struct.fields[11] - assert isinstance(date_type_field, IntType) - assert date_type_field.extra_attrs["name"] == "column_date" - assert date_type_field.logical == "build.recap.Date" - assert date_type_field.bits == 32 - assert date_type_field.signed is True - assert date_type_field.extra_attrs["unit"] == "day" - - time_type_field = struct.fields[12] - assert isinstance(time_type_field, IntType) - assert time_type_field.extra_attrs["name"] == "column_time" - assert time_type_field.logical == "build.recap.Time" - assert time_type_field.bits == 32 - assert time_type_field.signed is True - assert time_type_field.extra_attrs["unit"] == "microsecond" - - datetime_type_field = struct.fields[13] - assert isinstance(datetime_type_field, IntType) - assert datetime_type_field.extra_attrs["name"] == "column_datetime" - assert datetime_type_field.logical == "build.recap.Timestamp" - assert datetime_type_field.bits == 64 - assert datetime_type_field.signed is True - assert datetime_type_field.extra_attrs["unit"] == "microsecond" - assert datetime_type_field.extra_attrs["timezone"] == "UTC" - - timestamp_type_field = struct.fields[14] - assert isinstance(timestamp_type_field, IntType) - assert timestamp_type_field.extra_attrs["name"] == "column_timestamp" - assert timestamp_type_field.logical == "build.recap.Timestamp" - assert timestamp_type_field.bits == 64 - assert timestamp_type_field.signed is True - assert timestamp_type_field.extra_attrs["unit"] == "microsecond" - assert timestamp_type_field.extra_attrs["timezone"] == "UTC" - - binary_type_field = struct.fields[15] - assert isinstance(binary_type_field, BytesType) - assert binary_type_field.extra_attrs["name"] == "column_binary" - assert binary_type_field.bytes_ is None - - decimal_type_field = struct.fields[16] - assert isinstance(decimal_type_field, BytesType) - assert decimal_type_field.extra_attrs["name"] == "column_decimal" - assert decimal_type_field.bytes_ == 16