Skip to content

Commit c162750

Browse files
committed
refactor(duckdb): use sqlalchemy-views to reduce string hacking
1 parent 3874a8e commit c162750

File tree

3 files changed

+68
-46
lines changed

3 files changed

+68
-46
lines changed

ibis/backends/duckdb/__init__.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
2727
from ibis.backends.duckdb.datatypes import parse
2828

29-
_dialect = sa.dialects.postgresql.dialect()
30-
3129
# counters for in-memory, parquet, and csv reads
3230
# used if no table name is specified
3331
pd_n = itertools.count(0)
@@ -42,10 +40,18 @@ def normalize_filenames(source_list):
4240
return list(map(util.normalize_filename, source_list))
4341

4442

45-
def _format_kwargs(kwargs):
46-
return (
47-
f"{k}='{v}'" if isinstance(v, str) else f"{k}={v!r}" for k, v in kwargs.items()
48-
)
43+
def _create_view(*args, **kwargs):
44+
import sqlalchemy_views as sav
45+
46+
return sav.CreateView(*args, **kwargs)
47+
48+
49+
def _format_kwargs(kwargs: Mapping[str, Any]):
50+
bindparams, pieces = [], []
51+
for name, value in kwargs.items():
52+
bindparams.append(sa.bindparam(name, value))
53+
pieces.append(f"{name} = :{name}")
54+
return sa.text(", ".join(pieces)).bindparams(*bindparams)
4955

5056

5157
class Backend(BaseAlchemyBackend):
@@ -189,14 +195,13 @@ def register(
189195
elif first.startswith(("postgres://", "postgresql://")):
190196
return self.read_postgres(source, table_name=table_name, **kwargs)
191197
else:
192-
self._register_failure()
193-
return None
198+
self._register_failure() # noqa: RET503
194199

195200
def _register_failure(self):
196201
import inspect
197202

198203
msg = ", ".join(
199-
m[0] for m in inspect.getmembers(self) if m[0].startswith("read_")
204+
name for name, _ in inspect.getmembers(self) if name.startswith("read_")
200205
)
201206
raise ValueError(
202207
f"Cannot infer appropriate read function for input, "
@@ -219,7 +224,7 @@ def read_csv(
219224
table_name
220225
An optional name to use for the created table. This defaults to
221226
a sequentially generated name.
222-
**kwargs
227+
kwargs
223228
Additional keyword arguments passed to DuckDB loading function.
224229
See https://duckdb.org/docs/data/csv for more information.
225230
@@ -233,24 +238,19 @@ def read_csv(
233238
if not table_name:
234239
table_name = f"ibis_read_csv_{next(csv_n)}"
235240

236-
quoted_table_name = self._quote(table_name)
237-
238241
# auto_detect and columns collide, so we set auto_detect=True
239242
# unless COLUMNS has been specified
240-
args = [
241-
str(source_list),
242-
f"auto_detect={kwargs.pop('auto_detect', 'columns' not in kwargs)}",
243-
*_format_kwargs(kwargs),
244-
]
245-
sql = f"""CREATE OR REPLACE VIEW {quoted_table_name} AS
246-
SELECT * FROM read_csv({', '.join(args)})"""
247-
248243
if any(source.startswith(("http://", "https://")) for source in source_list):
249244
self._load_extensions(["httpfs"])
250245

246+
kwargs["auto_detect"] = kwargs.pop("auto_detect", "columns" not in kwargs)
247+
source = sa.select(sa.literal_column("*")).select_from(
248+
sa.func.read_csv(sa.func.list_value(*source_list), _format_kwargs(kwargs))
249+
)
250+
view = _create_view(sa.table(table_name), source, or_replace=True)
251251
with self.begin() as con:
252-
con.execute(sa.text(sql))
253-
return self._read(table_name)
252+
con.execute(view)
253+
return self.table(table_name)
254254

255255
def read_parquet(
256256
self,
@@ -268,7 +268,7 @@ def read_parquet(
268268
table_name
269269
An optional name to use for the created table. This defaults to
270270
a sequentially generated name.
271-
**kwargs
271+
kwargs
272272
Additional keyword arguments passed to DuckDB loading function.
273273
See https://duckdb.org/docs/data/parquet for more information.
274274
@@ -302,17 +302,20 @@ def read_parquet(
302302
source.startswith(("http://", "https://")) for source in source_list
303303
):
304304
self._load_extensions(["httpfs"])
305-
dataset = str(source_list)
306-
table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"
307305

308-
quoted_table_name = self._quote(table_name)
309-
sql = f"""CREATE OR REPLACE VIEW {quoted_table_name} AS
310-
SELECT * FROM read_parquet({dataset})"""
306+
if table_name is None:
307+
table_name = f"ibis_read_parquet_{next(pa_n)}"
311308

309+
source = sa.select(sa.literal_column("*")).select_from(
310+
sa.func.read_parquet(
311+
sa.func.list_value(*source_list), _format_kwargs(kwargs)
312+
)
313+
)
314+
view = _create_view(sa.table(table_name), source, or_replace=True)
312315
with self.begin() as con:
313-
con.execute(sa.text(sql))
316+
con.execute(view)
314317

315-
return self._read(table_name)
318+
return self.table(table_name)
316319

317320
def read_in_memory(
318321
self, dataframe: pd.DataFrame | pa.Table, table_name: str | None = None
@@ -336,9 +339,9 @@ def read_in_memory(
336339
with self.begin() as con:
337340
con.connection.register(table_name, dataframe)
338341

339-
return self._read(table_name)
342+
return self.table(table_name)
340343

341-
def read_postgres(self, uri, table_name=None):
344+
def read_postgres(self, uri, table_name: str | None = None, schema: str = "public"):
342345
"""Register a table from a postgres instance into a DuckDB table.
343346
344347
Parameters
@@ -347,6 +350,8 @@ def read_postgres(self, uri, table_name=None):
347350
The postgres URI in form 'postgres://user:password@host:port'
348351
table_name
349352
The table to read
353+
schema
354+
PostgreSQL schema where `table_name` resides
350355
351356
Returns
352357
-------
@@ -358,18 +363,13 @@ def read_postgres(self, uri, table_name=None):
358363
"`table_name` is required when registering a postgres table"
359364
)
360365
self._load_extensions(["postgres_scanner"])
361-
quoted_table_name = self._quote(table_name)
362-
sql = (
363-
f"CREATE OR REPLACE VIEW {quoted_table_name} AS "
364-
f"SELECT * FROM postgres_scan_pushdown('{uri}', 'public', '{table_name}')"
366+
source = sa.select(sa.literal_column("*")).select_from(
367+
sa.func.postgres_scan_pushdown(uri, schema, table_name)
365368
)
369+
view = _create_view(sa.table(table_name), source, or_replace=True)
366370
with self.begin() as con:
367-
con.execute(sa.text(sql))
368-
369-
return self._read(table_name)
371+
con.execute(view)
370372

371-
def _read(self, table_name):
372-
_table = self.table(table_name)
373373
return self.table(table_name)
374374

375375
def to_pyarrow_batches(
@@ -379,7 +379,7 @@ def to_pyarrow_batches(
379379
params: Mapping[ir.Scalar, Any] | None = None,
380380
limit: int | str | None = None,
381381
chunk_size: int = 1_000_000,
382-
**kwargs: Any,
382+
**_: Any,
383383
) -> pa.ipc.RecordBatchReader:
384384
# TODO: duckdb seems to not care about the `chunk_size` argument
385385
# and returns batches in 1024 row chunks
@@ -401,7 +401,7 @@ def to_pyarrow(
401401
*,
402402
params: Mapping[ir.Scalar, Any] | None = None,
403403
limit: int | str | None = None,
404-
**kwargs: Any,
404+
**_: Any,
405405
) -> pa.Table:
406406
pa = self._import_pyarrow()
407407
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)

poetry.lock

Lines changed: 17 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ snowflake-sqlalchemy = { version = ">=1.4.1,<2", optional = true, extras = [
8282
"pandas"
8383
] }
8484
sqlalchemy = { version = ">=1.4,<2", optional = true }
85+
sqlalchemy-views = { version = ">=0.3.1,<1", optional = true }
8586
trino = { version = ">=0.319,<1", optional = true, extras = ["sqlalchemy"] }
8687

8788
[tool.poetry.group.dev.dependencies]
@@ -169,7 +170,13 @@ bigquery = [
169170
clickhouse = ["clickhouse-driver", "clickhouse-cityhash", "lz4"]
170171
dask = ["dask", "pyarrow", "regex"]
171172
datafusion = ["datafusion"]
172-
duckdb = ["duckdb", "duckdb-engine", "pyarrow", "sqlalchemy"]
173+
duckdb = [
174+
"duckdb",
175+
"duckdb-engine",
176+
"pyarrow",
177+
"sqlalchemy",
178+
"sqlalchemy-views"
179+
]
173180
geospatial = ["geoalchemy2", "geopandas", "shapely"]
174181
impala = ["fsspec", "impyla", "requests", "sqlalchemy"]
175182
mssql = ["sqlalchemy", "pymssql"]

0 commit comments

Comments
 (0)