Skip to content

Commit

Permalink
Replace SQLAlchemyReader with dbapi readers
Browse files Browse the repository at this point in the history
After some discussion in #257, I decided to replace SQLAlchemy with `dbapi`
readers. The SQLAlchemy types weren't mapping as nicely as I'd hoped to Recap's.
Now, DB types are mapped straight to Recap types.

This PR adds a `PostgresqlReader` and `SnowflakeReader`. The PG reader has a
very basic test, but the Snowflake reader is untested. I plan to use `fakesnow`
for Snowflake in the future.
  • Loading branch information
criccomini committed Jun 14, 2023
1 parent 159e009 commit bbabb48
Show file tree
Hide file tree
Showing 10 changed files with 421 additions and 267 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@ jobs:
test:
runs-on: ubuntu-latest

services:
postgres:
image: postgres:14
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
POSTGRES_DB: testdb
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
matrix:
python-version: ['3.10']
Expand Down
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies = [
"uvicorn[standard]>=0.20.0",
"httpx>=0.23.1",
"typer>=0.7.0",
"sqlalchemy>=1.4.45",
"rich>=12.6.0",
"setuptools>=65.6.3",
"starlette>=0.22.0",
Expand Down
107 changes: 107 additions & 0 deletions recap/readers/dbapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from typing import Any, List, Protocol, Tuple

from recap.types import NullType, RecapType, StructType, UnionType


class DbapiReader(ABC):
def __init__(self, connection: Connection) -> None:
self.connection = connection

def struct(self, table: str, schema: str, catalog: str) -> StructType:
cursor = self.connection.cursor()
cursor.execute(
f"""
SELECT
*
FROM information_schema.columns
WHERE table_name = {self.param_style}
AND table_schema = {self.param_style}
AND table_catalog = {self.param_style}
ORDER BY ordinal_position ASC
""",
(table, schema, catalog),
)

names = [name[0].upper() for name in cursor.description]
fields = []
for row in cursor.fetchall():
column_props = dict(zip(names, row))
base_type = self.get_recap_type(column_props)
is_nullable = column_props["IS_NULLABLE"].upper() == "YES"

if is_nullable:
base_type = UnionType([NullType(), base_type])

if column_props["COLUMN_DEFAULT"] is not None or is_nullable:
base_type.extra_attrs["default"] = column_props["COLUMN_DEFAULT"]

base_type.extra_attrs["name"] = column_props["COLUMN_NAME"]

fields.append(base_type)

return StructType(fields=fields)

@property
def param_style(cls) -> str:
return "%s"

@abstractmethod
def get_recap_type(self, column_props: dict[str, Any]) -> RecapType:
...

def _get_time_unit(self, params: list[str] | None) -> str | None:
match params:
case [unit, _] if int(unit) == 0:
return "second"
case [unit, _] if int(unit) <= 3:
return "millisecond"
case [unit, _] if int(unit) <= 6:
return "microsecond"

def _parse_parameters(self, col_type: str) -> list[str] | None:
"""
Parse types that have parameters.
"""
match = re.search(r"\((.*?)\)", col_type)
if match:
return [p.strip() for p in match.group(1).split(",")]
return None


class Connection(Protocol):
def close(self) -> None:
...

def commit(self) -> None:
...

def rollback(self) -> None:
...

def cursor(self) -> Cursor:
...


class Cursor(Protocol):
def execute(self, query: str, parameters: Tuple = ()) -> None:
...

def executemany(self, query: str, parameter_list: List[Tuple]) -> None:
...

def fetchone(self) -> Tuple:
...

def fetchall(self) -> List[Tuple]:
...

def fetchmany(self, size: int) -> List[Tuple]:
...

@property
def description(self) -> List[Tuple]:
...
61 changes: 61 additions & 0 deletions recap/readers/postgresql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from math import ceil
from typing import Any

from recap.readers.dbapi import DbapiReader
from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType

MAX_FIELD_SIZE = 1073741824


class PostgresqlReader(DbapiReader):
def get_recap_type(self, column_props: dict[str, Any]) -> RecapType:
data_type = column_props["DATA_TYPE"].lower()
octet_length = column_props["CHARACTER_OCTET_LENGTH"]
max_length = column_props["CHARACTER_MAXIMUM_LENGTH"]

if data_type in ["bigint", "int8", "bigserial", "serial8"]:
base_type = IntType(bits=64, signed=True)
elif data_type in ["integer", "int", "int4", "serial", "serial4"]:
base_type = IntType(bits=32, signed=True)
elif data_type in ["smallint", "smallserial", "serial2"]:
base_type = IntType(bits=16, signed=True)
elif data_type in ["double precision", "float8"]:
base_type = FloatType(bits=64)
elif data_type in ["real", "float4"]:
base_type = FloatType(bits=32)
elif data_type == "boolean":
base_type = BoolType()
elif (
data_type in ["text", "json", "jsonb"]
or data_type.startswith("character varying")
or data_type.startswith("varchar")
):
base_type = StringType(bytes_=octet_length, variable=True)
elif data_type.startswith("char"):
print(column_props)
base_type = StringType(bytes_=octet_length, variable=False)
elif data_type == "bytea" or data_type.startswith("bit varying"):
base_type = BytesType(bytes_=MAX_FIELD_SIZE, variable=True)
elif data_type.startswith("bit"):
byte_length = ceil(max_length / 8)
base_type = BytesType(bytes_=byte_length, variable=False)
elif data_type.startswith("timestamp"):
dt_precision = column_props["DATETIME_PRECISION"]
unit = self._get_time_unit([dt_precision]) or "microsecond"
base_type = IntType(
bits=64,
logical="build.recap.Timestamp",
unit=unit,
)
elif data_type.startswith("decimal") or data_type.startswith("numeric"):
base_type = BytesType(
logical="build.recap.Decimal",
bytes_=32,
variable=False,
precision=column_props["NUMERIC_PRECISION"],
scale=column_props["NUMERIC_SCALE"],
)
else:
raise ValueError(f"Unknown data type: {data_type}")

return base_type
80 changes: 80 additions & 0 deletions recap/readers/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any

from recap.readers.dbapi import DbapiReader
from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType


class SnowflakeReader(DbapiReader):
def get_recap_type(self, column_props: dict[str, Any]) -> RecapType:
data_type = column_props["DATA_TYPE"].lower()
octet_length = column_props["CHARACTER_OCTET_LENGTH"]

if data_type in [
"float",
"float4",
"float8",
"double",
"double precision",
"real",
]:
base_type = FloatType(bits=64)
elif data_type == "boolean":
base_type = BoolType()
elif data_type in [
"number",
"decimal",
"numeric",
"int",
"integer",
"bigint",
"smallint",
"tinyint",
"byteint",
] or (
data_type.startswith("number")
or data_type.startswith("decimal")
or data_type.startswith("numeric")
):
base_type = BytesType(
logical="build.recap.Decimal",
bytes_=32,
variable=False,
precision=column_props["NUMERIC_PRECISION"] or 38,
scale=column_props["NUMERIC_SCALE"] or 0,
)
elif (
data_type.startswith("varchar")
or data_type.startswith("string")
or data_type.startswith("text")
or data_type.startswith("nvarchar")
or data_type.startswith("nvarchar2")
or data_type.startswith("char varying")
or data_type.startswith("nchar varying")
):
base_type = StringType(bytes_=octet_length, variable=True)
elif (
data_type.startswith("char")
or data_type.startswith("nchar")
or data_type.startswith("character")
):
base_type = StringType(bytes_=octet_length, variable=False)
elif data_type in ["binary", "varbinary"]:
base_type = BytesType(bytes_=octet_length)
elif data_type == "date":
base_type = IntType(bits=32, logical="build.recap.Date", unit="day")
elif data_type.startswith("timestamp") or data_type.startswith("datetime"):
params = self._parse_parameters(data_type)
unit = self._get_time_unit(params) or "nanosecond"
base_type = IntType(
bits=64,
logical="build.recap.Timestamp",
unit=unit,
)
elif data_type.startswith("time"):
params = self._parse_parameters(data_type)
unit = self._get_time_unit(params) or "nanosecond"
base_type = IntType(bits=32, logical="build.recap.Time", unit=unit)
else:
raise ValueError(f"Unknown data type: {data_type}")

return base_type
89 changes: 0 additions & 89 deletions recap/readers/sqlalchemy.py

This file was deleted.

Loading

0 comments on commit bbabb48

Please sign in to comment.