Commit

feat(ux): include basename of path in generated table names in read_*()
NickCrews committed Nov 22, 2024
1 parent b32a6f3 commit 3919f2f
Showing 8 changed files with 66 additions and 59 deletions.
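Every call site in the diffs below swaps util.gen_name("read_<format>") for a new util.gen_name_from_path(path, kind) helper, so that the generated table name carries the basename of the path being read. The helper itself is defined in ibis/util.py, one of the changed files not rendered on this page; the sketch below is only an illustration of the intended behavior, and its sanitization rule and exact name layout are assumptions rather than the actual implementation.

# Hypothetical sketch of the new helper; the real definition lives in
# ibis/util.py (not shown here) and may differ in detail.
import re
from pathlib import Path

from ibis import util


def gen_name_from_path(path, kind: str) -> str:
    """Build a unique table name that includes the file's basename."""
    # e.g. "s3://bucket/sales_2024.parquet" -> "sales_2024"
    stem = Path(str(path)).stem
    # keep only identifier-safe characters (assumed sanitization rule)
    stem = re.sub(r"[^0-9A-Za-z_]", "_", stem)
    # delegate to the existing unique-name generator used throughout the diff
    return util.gen_name(f"read_{kind}_{stem}")

Under a layout like this, read_parquet("data/sales_2024.parquet") with no explicit table_name would produce a name along the lines of ibis_read_parquet_sales_2024_<uid> rather than an opaque ibis_read_parquet_<uid>, which is the UX change the commit title describes.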
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/__init__.py
@@ -581,7 +581,7 @@ def read_parquet(
paths = list(glob.glob(str(path)))
schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="parquet").schema)

name = table_name or util.gen_name("read_parquet")
name = table_name or util.gen_name_from_path(paths[0], "parquet")
table = self.create_table(name, engine=engine, schema=schema, temp=True)

for file_path in paths:
@@ -609,7 +609,7 @@ def read_csv(
paths = list(glob.glob(str(path)))
schema = PyArrowSchema.to_ibis(ds.dataset(paths, format="csv").schema)

name = table_name or util.gen_name("read_csv")
name = table_name or util.gen_name_from_path(paths[0], "csv")
table = self.create_table(name, engine=engine, schema=schema, temp=True)

for file_path in paths:
30 changes: 15 additions & 15 deletions ibis/backends/datafusion/__init__.py
@@ -28,7 +28,7 @@
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames
from ibis.util import deprecated, normalize_filename, normalize_filenames

try:
from datafusion import ExecutionContext as SessionContext
@@ -160,7 +160,7 @@ def _safe_raw_sql(self, sql: sge.Statement) -> Any:
yield self.raw_sql(sql).collect()

def _get_schema_using_query(self, query: str) -> sch.Schema:
name = gen_name("datafusion_metadata_view")
name = util.gen_name("datafusion_metadata_view")
table = sg.table(name, quoted=self.compiler.quoted)
src = sge.Create(
this=table,
@@ -437,11 +437,11 @@ def read_csv(
The just-registered table
"""
path = normalize_filenames(source_list)
table_name = table_name or gen_name("read_csv")
paths = normalize_filenames(source_list)
table_name = table_name or util.gen_name_from_path(paths[0], "csv")
# Our other backends support overwriting views / tables when re-registering
self.con.deregister_table(table_name)
self.con.register_csv(table_name, path, **kwargs)
self.con.register_csv(table_name, paths, **kwargs)
return self.table(table_name)

def read_parquet(
@@ -466,7 +466,7 @@ def read_parquet(
"""
path = normalize_filename(path)
table_name = table_name or gen_name("read_parquet")
table_name = table_name or util.gen_name_from_path(path, "parquet")
# Our other backends support overwriting views / tables when reregistering
self.con.deregister_table(table_name)
self.con.register_parquet(table_name, path, **kwargs)
@@ -496,7 +496,7 @@ def read_delta(
"""
source_table = normalize_filename(source_table)

table_name = table_name or gen_name("read_delta")
table_name = table_name or util.gen_name_from_path(source_table, "delta")

# Our other backends support overwriting views / tables when reregistering
self.con.deregister_table(table_name)
@@ -730,55 +730,55 @@ def _read_in_memory(

@_read_in_memory.register(dict)
def _pydict(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pydict")
tmp_name = util.gen_name("pydict")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pydict(source, name=tmp_name)


@_read_in_memory.register("polars.DataFrame")
def _polars(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
tmp_name = util.gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source, name=tmp_name)


@_read_in_memory.register("polars.LazyFrame")
def _polars(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
tmp_name = util.gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source.collect(), name=tmp_name)


@_read_in_memory.register("pyarrow.Table")
def _pyarrow_table(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
tmp_name = util.gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow(source, name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatchReader")
def _pyarrow_rbr(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
tmp_name = util.gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow(source.read_all(), name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatch")
def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
tmp_name = util.gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_record_batches(tmp_name, [[source]])


@_read_in_memory.register("pyarrow.dataset.Dataset")
def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
tmp_name = util.gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_dataset(tmp_name, source)


@_read_in_memory.register("pandas.DataFrame")
def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pandas")
tmp_name = util.gen_name("pandas")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pandas(source, name=tmp_name)
22 changes: 9 additions & 13 deletions ibis/backends/duckdb/__init__.py
@@ -588,8 +588,9 @@ def read_json(
An ibis table expression
"""
filenames = util.normalize_filenames(source_list)
if not table_name:
table_name = util.gen_name("read_json")
table_name = util.gen_name_from_path(filenames[0], "json")

options = [
sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items()
@@ -612,11 +613,7 @@ def read_json(

self._create_temp_view(
table_name,
sg.select(STAR).from_(
self.compiler.f.read_json_auto(
util.normalize_filenames(source_list), *options
)
),
sg.select(STAR).from_(self.compiler.f.read_json_auto(filenames, *options)),
)

return self.table(table_name)
@@ -703,7 +700,7 @@ def read_csv(
source_list = util.normalize_filenames(source_list)

if not table_name:
table_name = util.gen_name("read_csv")
table_name = util.gen_name_from_path(source_list[0], "csv")

# auto_detect and columns collide, so we set auto_detect=True
# unless COLUMNS has been specified
@@ -779,17 +776,16 @@ def read_geo(
The just-registered table
"""

if not table_name:
table_name = util.gen_name("read_geo")

# load geospatial extension
self.load_extension("spatial")

source = util.normalize_filename(source)
if source.startswith(("http://", "https://", "s3://")):
self._load_extensions(["httpfs"])

if not table_name:
table_name = util.gen_name_from_path(source, "geo")

source_expr = sg.select(STAR).from_(
self.compiler.f.st_read(
source,
@@ -835,7 +831,7 @@ def read_parquet(
"""
source_list = util.normalize_filenames(source_list)

table_name = table_name or util.gen_name("read_parquet")
table_name = table_name or util.gen_name_from_path(source_list[0], "parquet")

# Default to using the native duckdb parquet reader
# If that fails because of auth issues, fall back to ingesting via
@@ -944,7 +940,7 @@ def read_delta(
"""
source_table = util.normalize_filenames(source_table)[0]

table_name = table_name or util.gen_name("read_delta")
table_name = table_name or util.gen_name_from_path(source_table, "delta")

try:
from deltalake import DeltaTable
3 changes: 1 addition & 2 deletions ibis/backends/flink/__init__.py
@@ -27,7 +27,6 @@
from ibis.backends.sql import SQLBackend
from ibis.backends.tests.errors import Py4JJavaError
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -767,7 +766,7 @@ def _read_file(
f"`schema` must be explicitly provided when calling `read_{file_type}`"
)

table_name = table_name or gen_name(f"read_{file_type}")
table_name = table_name or util.gen_name_from_path(path, file_type)
tbl_properties = {
"connector": "filesystem",
"path": path,
31 changes: 14 additions & 17 deletions ibis/backends/polars/__init__.py
@@ -12,14 +12,14 @@
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends import BaseBackend, NoUrl
from ibis.backends.polars.compiler import translate
from ibis.backends.polars.rewrites import bind_unbound_table, rewrite_join
from ibis.backends.sql.dialects import Polars
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.rewrites import lower_stringslice, replace_parameter
from ibis.formats.polars import PolarsSchema
from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -100,7 +100,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
def _finalize_memtable(self, name: str) -> None:
self.drop_table(name, force=True)

@deprecated(
@util.deprecated(
as_of="9.1",
instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.",
)
@@ -209,12 +209,12 @@ def read_csv(
The just-registered table
"""
source_list = normalize_filenames(path)
source_list = util.normalize_filenames(path)
table_name = table_name or util.gen_name_from_path(source_list[0], "csv")
# Flatten the list if there's only one element because Polars
# can't handle glob strings, or compressed CSVs in a single-element list
if len(source_list) == 1:
source_list = source_list[0]
table_name = table_name or gen_name("read_csv")
try:
table = pl.scan_csv(source_list, **kwargs)
# triggers a schema computation to handle compressed csv inference
@@ -250,8 +250,8 @@ def read_json(
The just-registered table
"""
path = normalize_filename(path)
table_name = table_name or gen_name("read_json")
path = util.normalize_filename(path)
table_name = table_name or util.gen_name_from_path(path, "json")
try:
self._add_table(table_name, pl.scan_ndjson(path, **kwargs))
except pl.exceptions.ComputeError:
@@ -290,8 +290,8 @@ def read_delta(
"read_delta method. You can install it using pip:\n\n"
"pip install 'ibis-framework[polars,deltalake]'\n"
)
path = normalize_filename(path)
table_name = table_name or gen_name("read_delta")
path = util.normalize_filename(path)
table_name = table_name or util.gen_name_from_path(path, "delta")
self._add_table(table_name, pl.scan_delta(path, **kwargs))
return self.table(table_name)

@@ -318,7 +318,7 @@ def read_pandas(
The just-registered table
"""
table_name = table_name or gen_name("read_in_memory")
table_name = table_name or util.gen_name("read_in_memory")

self._add_table(table_name, pl.from_pandas(source, **kwargs).lazy())
return self.table(table_name)
@@ -351,24 +351,21 @@ def read_parquet(
The just-registered table
"""
table_name = table_name or gen_name("read_parquet")
if not isinstance(path, (str, Path)) and len(path) == 1:
path = path[0]
paths = util.normalize_filenames(path)
table_name = table_name or util.gen_name_from_path(paths[0], "parquet")

if not isinstance(path, (str, Path)) and len(path) > 1:
if len(paths) > 1:
self._import_pyarrow()
import pyarrow.dataset as ds

paths = [normalize_filename(p) for p in path]
obj = pl.scan_pyarrow_dataset(
source=ds.dataset(paths, format="parquet"),
**kwargs,
)
self._add_table(table_name, obj)
else:
path = normalize_filename(path)
self._add_table(table_name, pl.scan_parquet(path, **kwargs))
obj = pl.scan_parquet(paths[0], **kwargs)

self._add_table(table_name, obj)
return self.table(table_name)

def create_table(
14 changes: 7 additions & 7 deletions ibis/backends/pyspark/__init__.py
@@ -790,7 +790,7 @@ def read_delta(
)
path = util.normalize_filename(path)
spark_df = self._session.read.format("delta").load(path, **kwargs)
table_name = table_name or util.gen_name("read_delta")
table_name = table_name or util.gen_name_from_path(path, "delta")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -827,7 +827,7 @@ def read_parquet(
)
path = util.normalize_filename(path)
spark_df = self._session.read.parquet(path, **kwargs)
table_name = table_name or util.gen_name("read_parquet")
table_name = table_name or util.gen_name_from_path(path, "parquet")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -869,7 +869,7 @@ def read_csv(
spark_df = self._session.read.csv(
source_list, inferSchema=inferSchema, header=header, **kwargs
)
table_name = table_name or util.gen_name("read_csv")
table_name = table_name or util.gen_name_from_path(source_list[0], "csv")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -907,7 +907,7 @@ def read_json(
)
source_list = util.normalize_filenames(source_list)
spark_df = self._session.read.json(source_list, **kwargs)
table_name = table_name or util.gen_name("read_json")
table_name = table_name or util.gen_name_from_path(source_list[0], "json")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -1217,7 +1217,7 @@ def read_csv_dir(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_csv_dir")
table_name = table_name or util.gen_name_from_path(path, "csv_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -1272,7 +1272,7 @@ def read_parquet_dir(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_parquet_dir")
table_name = table_name or util.gen_name_from_path(path, "parquet_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
@@ -1318,7 +1318,7 @@ def read_json_dir(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_json_dir")
table_name = table_name or util.gen_name_from_path(path, "json_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
(2 of the 8 changed files are not rendered on this page.)

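As a usage illustration (not part of the commit), any of the touched read_* methods now derives its default table name from the file it reads. The exact generated name depends on the real gen_name_from_path implementation; the duckdb backend and the name shown in the comments below are just one assumed example.

# Illustrative only; the exact generated name depends on the helper's implementation.
import ibis

con = ibis.duckdb.connect()  # any backend touched by this commit behaves similarly
t = con.read_csv("measurements_2024.csv")  # no explicit table_name passed
# Before this commit the temporary view was named like "ibis_read_csv_<uid>";
# afterwards the basename "measurements_2024" should appear in the name as well.
print(t.get_name())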