-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace SQLAlchemyReader with
dbapi
readers
After some discussion in #257, I decided to replace SQLAlchemy with `dbapi` readers. The SQLAlchemy types weren't mapping to Recap's types as cleanly as I'd hoped; now, database types are mapped straight to Recap types. This PR adds a `PostgresqlReader` and a `SnowflakeReader`. The PostgreSQL reader has a very basic test, but the Snowflake reader is untested; I plan to use `fakesnow` to test Snowflake in the future.
- Loading branch information
1 parent
159e009
commit bbabb48
Showing
10 changed files
with
421 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import annotations | ||
|
||
import re | ||
from abc import ABC, abstractmethod | ||
from typing import Any, List, Protocol, Tuple | ||
|
||
from recap.types import NullType, RecapType, StructType, UnionType | ||
|
||
|
||
class DbapiReader(ABC): | ||
def __init__(self, connection: Connection) -> None: | ||
self.connection = connection | ||
|
||
def struct(self, table: str, schema: str, catalog: str) -> StructType: | ||
cursor = self.connection.cursor() | ||
cursor.execute( | ||
f""" | ||
SELECT | ||
* | ||
FROM information_schema.columns | ||
WHERE table_name = {self.param_style} | ||
AND table_schema = {self.param_style} | ||
AND table_catalog = {self.param_style} | ||
ORDER BY ordinal_position ASC | ||
""", | ||
(table, schema, catalog), | ||
) | ||
|
||
names = [name[0].upper() for name in cursor.description] | ||
fields = [] | ||
for row in cursor.fetchall(): | ||
column_props = dict(zip(names, row)) | ||
base_type = self.get_recap_type(column_props) | ||
is_nullable = column_props["IS_NULLABLE"].upper() == "YES" | ||
|
||
if is_nullable: | ||
base_type = UnionType([NullType(), base_type]) | ||
|
||
if column_props["COLUMN_DEFAULT"] is not None or is_nullable: | ||
base_type.extra_attrs["default"] = column_props["COLUMN_DEFAULT"] | ||
|
||
base_type.extra_attrs["name"] = column_props["COLUMN_NAME"] | ||
|
||
fields.append(base_type) | ||
|
||
return StructType(fields=fields) | ||
|
||
@property | ||
def param_style(cls) -> str: | ||
return "%s" | ||
|
||
@abstractmethod | ||
def get_recap_type(self, column_props: dict[str, Any]) -> RecapType: | ||
... | ||
|
||
def _get_time_unit(self, params: list[str] | None) -> str | None: | ||
match params: | ||
case [unit, _] if int(unit) == 0: | ||
return "second" | ||
case [unit, _] if int(unit) <= 3: | ||
return "millisecond" | ||
case [unit, _] if int(unit) <= 6: | ||
return "microsecond" | ||
|
||
def _parse_parameters(self, col_type: str) -> list[str] | None: | ||
""" | ||
Parse types that have parameters. | ||
""" | ||
match = re.search(r"\((.*?)\)", col_type) | ||
if match: | ||
return [p.strip() for p in match.group(1).split(",")] | ||
return None | ||
|
||
|
||
class Connection(Protocol):
    """Structural type for a PEP 249 (DBAPI 2.0) database connection.

    Any driver connection object exposing these methods satisfies the
    protocol; no explicit subclassing is required.
    """

    def close(self) -> None:
        ...

    def commit(self) -> None:
        ...

    def rollback(self) -> None:
        ...

    def cursor(self) -> Cursor:
        ...
|
||
|
||
class Cursor(Protocol):
    """Structural type for a PEP 249 (DBAPI 2.0) cursor.

    Only the subset of the DBAPI cursor interface used by the readers
    is declared here.
    """

    def execute(self, query: str, parameters: Tuple = ()) -> None:
        ...

    def executemany(self, query: str, parameter_list: List[Tuple]) -> None:
        ...

    def fetchone(self) -> Tuple:
        ...

    def fetchall(self) -> List[Tuple]:
        ...

    def fetchmany(self, size: int) -> List[Tuple]:
        ...

    @property
    def description(self) -> List[Tuple]:
        # Per PEP 249, each tuple describes one result column; index 0 is
        # the column name.
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from math import ceil | ||
from typing import Any | ||
|
||
from recap.readers.dbapi import DbapiReader | ||
from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType | ||
|
||
MAX_FIELD_SIZE = 1073741824 | ||
|
||
|
||
class PostgresqlReader(DbapiReader):
    """Maps PostgreSQL ``information_schema.columns`` rows to Recap types."""

    def get_recap_type(self, column_props: dict[str, Any]) -> RecapType:
        """Translate one information_schema.columns row into a RecapType.

        :param column_props: Column properties keyed by upper-cased
            information_schema.columns column names.
        :returns: The Recap type for the column (nullability is handled
            by the caller).
        :raises ValueError: If the column's data type is not recognized.
        """
        data_type = column_props["DATA_TYPE"].lower()
        octet_length = column_props["CHARACTER_OCTET_LENGTH"]
        max_length = column_props["CHARACTER_MAXIMUM_LENGTH"]

        if data_type in ["bigint", "int8", "bigserial", "serial8"]:
            base_type = IntType(bits=64, signed=True)
        elif data_type in ["integer", "int", "int4", "serial", "serial4"]:
            base_type = IntType(bits=32, signed=True)
        elif data_type in ["smallint", "smallserial", "serial2"]:
            base_type = IntType(bits=16, signed=True)
        elif data_type in ["double precision", "float8"]:
            base_type = FloatType(bits=64)
        elif data_type in ["real", "float4"]:
            base_type = FloatType(bits=32)
        elif data_type == "boolean":
            base_type = BoolType()
        elif (
            data_type in ["text", "json", "jsonb"]
            or data_type.startswith("character varying")
            or data_type.startswith("varchar")
        ):
            # NOTE(review): octet_length may be NULL for unbounded types
            # like text -- confirm StringType accepts bytes_=None.
            base_type = StringType(bytes_=octet_length, variable=True)
        elif data_type.startswith("char"):
            # Bug fix: removed a leftover debug print(column_props).
            base_type = StringType(bytes_=octet_length, variable=False)
        elif data_type == "bytea" or data_type.startswith("bit varying"):
            base_type = BytesType(bytes_=MAX_FIELD_SIZE, variable=True)
        elif data_type.startswith("bit"):
            # Fixed-width bit string: round bit length up to whole bytes.
            byte_length = ceil(max_length / 8)
            base_type = BytesType(bytes_=byte_length, variable=False)
        elif data_type.startswith("timestamp"):
            dt_precision = column_props["DATETIME_PRECISION"]
            # Bug fix: _get_time_unit pattern-matches two-element lists, so
            # the original single-element call never matched and the
            # precision was always ignored (every timestamp defaulted to
            # "microsecond"). Pad the list and guard a NULL precision.
            if dt_precision is None:
                unit = "microsecond"
            else:
                unit = self._get_time_unit([dt_precision, 0]) or "microsecond"
            base_type = IntType(
                bits=64,
                logical="build.recap.Timestamp",
                unit=unit,
            )
        elif data_type.startswith("decimal") or data_type.startswith("numeric"):
            base_type = BytesType(
                logical="build.recap.Decimal",
                bytes_=32,
                variable=False,
                precision=column_props["NUMERIC_PRECISION"],
                scale=column_props["NUMERIC_SCALE"],
            )
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        return base_type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Any | ||
|
||
from recap.readers.dbapi import DbapiReader | ||
from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType | ||
|
||
|
||
class SnowflakeReader(DbapiReader):
    """Maps Snowflake ``information_schema.columns`` rows to Recap types."""

    def get_recap_type(self, column_props: dict[str, Any]) -> RecapType:
        """Translate one information_schema.columns row into a RecapType.

        :param column_props: Column properties keyed by upper-cased
            information_schema.columns column names.
        :returns: The Recap type for the column (nullability is handled
            by the caller).
        :raises ValueError: If the column's data type is not recognized.
        """
        data_type = column_props["DATA_TYPE"].lower()
        octet_length = column_props["CHARACTER_OCTET_LENGTH"]

        if data_type in [
            "float",
            "float4",
            "float8",
            "double",
            "double precision",
            "real",
        ]:
            # All float aliases map to 64 bits (Snowflake stores floats as
            # double precision -- see Snowflake numeric type docs).
            base_type = FloatType(bits=64)
        elif data_type == "boolean":
            base_type = BoolType()
        elif data_type.startswith(("number", "decimal", "numeric")) or data_type in [
            # Integer aliases are NUMBER(38,0) under the hood, hence Decimal.
            "int",
            "integer",
            "bigint",
            "smallint",
            "tinyint",
            "byteint",
        ]:
            # Simplified: the original also listed "number"/"decimal"/
            # "numeric" in the membership test, which the startswith check
            # already covers.
            base_type = BytesType(
                logical="build.recap.Decimal",
                bytes_=32,
                variable=False,
                precision=column_props["NUMERIC_PRECISION"] or 38,
                scale=column_props["NUMERIC_SCALE"] or 0,
            )
        elif data_type.startswith(
            # "nvarchar" also covers the original's redundant "nvarchar2".
            ("varchar", "string", "text", "nvarchar", "char varying", "nchar varying")
        ):
            base_type = StringType(bytes_=octet_length, variable=True)
        elif data_type.startswith(("char", "nchar", "character")):
            base_type = StringType(bytes_=octet_length, variable=False)
        elif data_type in ["binary", "varbinary"]:
            base_type = BytesType(bytes_=octet_length)
        elif data_type == "date":
            base_type = IntType(bits=32, logical="build.recap.Date", unit="day")
        elif data_type.startswith(("timestamp", "datetime")):
            params = self._parse_parameters(data_type)
            unit = self._get_time_unit(params) or "nanosecond"
            base_type = IntType(
                bits=64,
                logical="build.recap.Timestamp",
                unit=unit,
            )
        elif data_type.startswith("time"):
            params = self._parse_parameters(data_type)
            unit = self._get_time_unit(params) or "nanosecond"
            base_type = IntType(bits=32, logical="build.recap.Time", unit=unit)
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        return base_type
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.