Skip to content

Commit c9238bd

Browse files
authored
feat(pyarrow): support arrow PyCapsule interface in more places (#9663)
1 parent ddeecce commit c9238bd

File tree

3 files changed

+61
-36
lines changed

3 files changed

+61
-36
lines changed

ibis/backends/tests/test_client.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -927,27 +927,6 @@ def test_self_join_memory_table(backend, con, monkeypatch):
927927
param(
928928
lambda: pa.table({"a": ["a"], "b": [1]}).to_batches()[0],
929929
"df_arrow_single_batch",
930-
marks=[
931-
pytest.mark.notimpl(
932-
[
933-
"bigquery",
934-
"clickhouse",
935-
"duckdb",
936-
"exasol",
937-
"impala",
938-
"mssql",
939-
"mysql",
940-
"oracle",
941-
"postgres",
942-
"pyspark",
943-
"risingwave",
944-
"snowflake",
945-
"sqlite",
946-
"trino",
947-
"databricks",
948-
]
949-
)
950-
],
951930
id="pyarrow_single_batch",
952931
),
953932
param(

ibis/backends/tests/test_generic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,24 @@ def test_memtable_construct_from_pyarrow(backend, con, monkeypatch):
12221222
)
12231223

12241224

def test_memtable_construct_from_pyarrow_c_stream(backend, con):
    """`ibis.memtable` accepts any object exposing ``__arrow_c_stream__``."""
    pa = pytest.importorskip("pyarrow")

    class StreamOnly:
        # Wrapper that hides everything about the underlying table except
        # the Arrow PyCapsule stream protocol.
        def __init__(self, table):
            self._table = table

        def __arrow_c_stream__(self, *args, **kwargs):
            return self._table.__arrow_c_stream__(*args, **kwargs)

    expected = pa.table({"a": ["a", "b", "c"], "b": [1, 2, 3]})

    expr = ibis.memtable(StreamOnly(expected))

    result = con.to_pyarrow(expr.order_by("a"))
    assert result.equals(expected)
12251243
@pytest.mark.parametrize("lazy", [False, True])
12261244
def test_memtable_construct_from_polars(backend, con, lazy):
12271245
pl = pytest.importorskip("polars")

ibis/expr/api.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -412,42 +412,55 @@ def memtable(
412412

413413
@lazy_singledispatch
def _memtable(
    data: Any,
    *,
    columns: Iterable[str] | None = None,
    schema: SchemaLike | None = None,
    name: str | None = None,
) -> Table:
    """Fallback memtable constructor for unregistered input types.

    Converts ``data`` to a concrete container and re-dispatches on the
    converted type.
    """
    if not hasattr(data, "__arrow_c_stream__"):
        # No Arrow PyCapsule interface: fall back to a pandas conversion.
        import pandas as pd

        converted = pd.DataFrame(data, columns=columns)
    else:
        # Support objects exposing arrow's PyCapsule interface
        import pyarrow as pa

        converted = pa.table(data)
    # Dispatch again, now on the concrete pandas/pyarrow type.
    return _memtable(converted, columns=columns, schema=schema, name=name)
432+
433+
@_memtable.register("pandas.DataFrame")
434+
def _memtable_from_pandas_dataframe(
435+
data: pd.DataFrame,
436+
*,
437+
columns: Iterable[str] | None = None,
438+
schema: SchemaLike | None = None,
439+
name: str | None = None,
440+
) -> Table:
441+
from ibis.formats.pandas import PandasDataFrameProxy
429442

430-
if df.columns.inferred_type != "string":
431-
cols = df.columns
443+
if data.columns.inferred_type != "string":
444+
cols = data.columns
432445
newcols = getattr(
433446
schema,
434447
"names",
435448
(f"col{i:d}" for i in builtins.range(len(cols))),
436449
)
437-
df = df.rename(columns=dict(zip(cols, newcols)))
450+
data = data.rename(columns=dict(zip(cols, newcols)))
438451

439452
if columns is not None:
440-
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
453+
if (provided_col := len(columns)) != (exist_col := len(data.columns)):
441454
raise ValueError(
442455
"Provided `columns` must have an entry for each column in `data`.\n"
443456
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
444457
)
445458

446-
df = df.rename(columns=dict(zip(df.columns, columns)))
459+
data = data.rename(columns=dict(zip(data.columns, columns)))
447460

448461
# verify that the DataFrame has no duplicate column names because ibis
449462
# doesn't allow that
450-
cols = df.columns
463+
cols = data.columns
451464
dupes = [name for name, count in Counter(cols).items() if count > 1]
452465
if dupes:
453466
raise IbisInputError(
@@ -456,8 +469,8 @@ def _memtable(
456469

457470
op = ops.InMemoryTable(
458471
name=name if name is not None else util.gen_name("pandas_memtable"),
459-
schema=sch.infer(df) if schema is None else schema,
460-
data=PandasDataFrameProxy(df),
472+
schema=sch.infer(data) if schema is None else schema,
473+
data=PandasDataFrameProxy(data),
461474
)
462475
return op.to_expr()
463476

@@ -499,6 +512,21 @@ def _memtable_from_pyarrow_dataset(
499512
).to_expr()
500513

501514

@_memtable.register("pyarrow.RecordBatchReader")
def _memtable_from_pyarrow_RecordBatchReader(
    data: pa.RecordBatchReader,
    *,
    name: str | None = None,
    schema: SchemaLike | None = None,
    columns: Iterable[str] | None = None,
):
    """Reject `pyarrow.RecordBatchReader` inputs to `ibis.memtable`.

    A memtable holds all of its data in memory, which would silently
    exhaust a streaming reader; callers must opt in explicitly by
    materializing the reader first.

    Raises
    ------
    TypeError
        Always; use ``ibis.memtable(reader.read_all())`` instead.
    """
    # NOTE: the annotation was previously `pa.Table`, which contradicted
    # the registered dispatch type; fixed to `pa.RecordBatchReader`.
    raise TypeError(
        "Creating an `ibis.memtable` from a `pyarrow.RecordBatchReader` would "
        "load _all_ data into memory. If you want to do this, please do so "
        "explicitly like `ibis.memtable(reader.read_all())`"
    )
502530
@_memtable.register("polars.LazyFrame")
def _memtable_from_polars_lazyframe(data: pl.LazyFrame, **kwargs):
    # Materialize the lazy frame, then reuse the eager-frame path.
    eager = data.collect()
    return _memtable_from_polars_dataframe(eager, **kwargs)

0 commit comments

Comments (0)