
Commit c572eab

refactor(datafusion): simplify execute and to_pyarrow implementations
1 parent 0b9c874 commit c572eab

File tree

1 file changed: +20 −28 lines

ibis/backends/datafusion/__init__.py

Lines changed: 20 additions & 28 deletions
@@ -468,7 +468,6 @@ def to_pyarrow_batches(
         self,
         expr: ir.Expr,
         *,
-        params: Mapping[ir.Scalar, Any] | None = None,
         chunk_size: int = 1_000_000,
         **kwargs: Any,
     ) -> pa.ipc.RecordBatchReader:
@@ -477,49 +476,42 @@ def to_pyarrow_batches(
         self._register_udfs(expr)
         self._register_in_memory_tables(expr)
 
-        sql = self.compile(expr.as_table(), params=params, **kwargs)
-        frame = self.con.sql(sql)
-        batches = frame.collect()
-        schema = expr.as_table().schema()
+        table_expr = expr.as_table()
+        raw_sql = self.compile(table_expr, **kwargs)
+
+        frame = self.con.sql(raw_sql)
+
+        schema = table_expr.schema()
+        names = schema.names
+
         struct_schema = schema.as_struct().to_pyarrow()
+
         return pa.ipc.RecordBatchReader.from_batches(
             schema.to_pyarrow(),
             (
-                # convert the renamed and casted columns batch into a record batch
+                # convert the renamed + casted columns into a record batch
                 pa.RecordBatch.from_struct_array(
                     # rename columns to match schema because datafusion lowercases things
-                    pa.RecordBatch.from_arrays(batch.columns, names=schema.names)
-                    # casting the struct array to appropriate types to work around
+                    pa.RecordBatch.from_arrays(batch.columns, names=names)
+                    # cast the struct array to the desired types to work around
                     # https://github.com/apache/arrow-datafusion-python/issues/534
                     .to_struct_array()
                     .cast(struct_schema)
                 )
-                for batch in batches
+                for batch in frame.collect()
             ),
         )
 
-    def to_pyarrow(
-        self,
-        expr: ir.Expr,
-        *,
-        params: Mapping[ir.Scalar, Any] | None = None,
-        **kwargs: Any,
-    ) -> pa.Table:
-        self._register_in_memory_tables(expr)
-
-        batch_reader = self.to_pyarrow_batches(expr, params=params, **kwargs)
+    def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
+        batch_reader = self.to_pyarrow_batches(expr, **kwargs)
         arrow_table = batch_reader.read_all()
         return expr.__pyarrow_result__(arrow_table)
 
-    def execute(
-        self,
-        expr: ir.Expr,
-        params: Mapping[ir.Expr, object] | None = None,
-        limit: int | str | None = "default",
-        **kwargs: Any,
-    ):
-        output = self.to_pyarrow(expr.as_table(), params=params, limit=limit, **kwargs)
-        return expr.__pandas_result__(output.to_pandas(timestamp_as_object=True))
+    def execute(self, expr: ir.Expr, **kwargs: Any):
+        batch_reader = self.to_pyarrow_batches(expr, **kwargs)
+        return expr.__pandas_result__(
+            batch_reader.read_pandas(timestamp_as_object=True)
+        )
 
     def _to_sqlglot(
         self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
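Note: the rename-then-cast pipeline kept inside the comprehension above is easy to miss. Below is a minimal standalone pyarrow sketch of the same workaround; the column name "mycol"/"MyCol" and the int32 target type are invented for illustration, and it assumes a pyarrow version that provides RecordBatch.to_struct_array and RecordBatch.from_struct_array (both are already used in the patched code):

    import pyarrow as pa

    # a batch the way DataFusion might hand it back: lowercased column
    # name, engine-chosen integer type
    raw = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["mycol"])

    # the schema the expression actually asked for
    target = pa.schema([("MyCol", pa.int32())])
    struct_type = pa.struct(list(target))

    fixed = pa.RecordBatch.from_struct_array(
        # rename first, then round-trip through a struct array so every
        # column is cast to the target type in one shot
        pa.RecordBatch.from_arrays(raw.columns, names=target.names)
        .to_struct_array()
        .cast(struct_type)
    )
    assert fixed.schema.names == ["MyCol"]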

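After this change both public entry points delegate to to_pyarrow_batches, so a quick smoke test only needs one expression. A sketch, assuming an ibis install with the datafusion extra from around this commit (table and column names are invented):

    import ibis

    con = ibis.datafusion.connect()
    t = ibis.memtable({"x": [1, 2, 3]})
    expr = t.mutate(y=t.x * 2)

    reader = con.to_pyarrow_batches(expr)  # the single underlying path
    table = con.to_pyarrow(expr)           # reader.read_all() under the hood
    df = con.execute(expr)                 # reader.read_pandas(...) under the hood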