Skip to content

Commit 9d4fbbd

Browse files
cpcloudkszucs
authored andcommitted
feat(api): implement pyarrow memtables
1 parent 1a1892c commit 9d4fbbd

File tree

5 files changed

+101
-4
lines changed

5 files changed

+101
-4
lines changed

ibis/backends/base/sql/compiler/select_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,11 @@ def _collect_PandasInMemoryTable(self, node, toplevel=False):
306306
self.select_set = [node]
307307
self.table_set = node
308308

309+
def _collect_PyArrowInMemoryTable(self, node, toplevel=False):
310+
if toplevel:
311+
self.select_set = [node]
312+
self.table_set = node
313+
309314
def _convert_group_by(self, nodes):
310315
return list(range(len(nodes)))
311316

ibis/backends/pyarrow/__init__.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import pyarrow as pa
6+
7+
import ibis.expr.operations as ops
8+
import ibis.expr.rules as rlz
9+
import ibis.expr.schema as sch
10+
from ibis import util
11+
from ibis.common.grounds import Immutable
12+
13+
if TYPE_CHECKING:
14+
import pandas as pd
15+
16+
17+
class PyArrowTableProxy(Immutable, util.ToFrame):
18+
__slots__ = ('_t', '_hash')
19+
20+
def __init__(self, t: pa.Table) -> None:
21+
object.__setattr__(self, "_t", t)
22+
object.__setattr__(self, "_hash", hash((type(t), id(t))))
23+
24+
def __hash__(self) -> int:
25+
return self._hash
26+
27+
def __repr__(self) -> str:
28+
df_repr = util.indent(repr(self._t), spaces=2)
29+
return f"{self.__class__.__name__}:\n{df_repr}"
30+
31+
def to_frame(self) -> pd.DataFrame:
32+
return self._t.to_pandas()
33+
34+
def to_pyarrow(self, _: sch.Schema) -> pa.Table:
35+
return self._t
36+
37+
38+
class PyArrowInMemoryTable(ops.InMemoryTable):
39+
data = rlz.instance_of(PyArrowTableProxy)
40+
41+
42+
@sch.infer.register(pa.Table)
43+
def infer_pyarrow_table_schema(t: pa.Table, schema=None):
44+
import ibis.backends.pyarrow.datatypes # noqa: F401
45+
46+
return sch.schema(schema if schema is not None else t.schema)

ibis/backends/pyspark/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ibis.backends.base.df.timecontext import adjust_context
2121
from ibis.backends.pandas.client import PandasInMemoryTable
2222
from ibis.backends.pandas.execution import execute
23+
from ibis.backends.pyarrow import PyArrowInMemoryTable
2324
from ibis.backends.pyspark.datatypes import spark_dtype
2425
from ibis.backends.pyspark.timecontext import (
2526
combine_time_context,
@@ -1862,8 +1863,8 @@ def compile_random(*args, **kwargs):
18621863
return F.rand()
18631864

18641865

1865-
@compiles(ops.InMemoryTable)
18661866
@compiles(PandasInMemoryTable)
1867+
@compiles(PyArrowInMemoryTable)
18671868
def compile_in_memory_table(t, op, session, **kwargs):
18681869
fields = [
18691870
pt.StructField(name, spark_dtype(dtype), dtype.nullable)

ibis/backends/tests/test_generic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,29 @@ def test_memtable_bool_column(backend, con, monkeypatch):
920920
backend.assert_series_equal(t.a.execute(), pd.Series([True, False, True], name="a"))
921921

922922

923+
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
924+
@pytest.mark.notimpl(["dask", "pandas"], raises=com.UnboundExpressionError)
925+
@pytest.mark.broken(
926+
["druid"],
927+
raises=AssertionError,
928+
reason="result contains empty strings instead of None",
929+
)
930+
def test_memtable_construct(backend, con, monkeypatch):
931+
pa = pytest.importorskip("pyarrow")
932+
monkeypatch.setattr(ibis.options, "default_backend", con)
933+
934+
pa_t = pa.Table.from_pydict(
935+
{
936+
"a": list("abc"),
937+
"b": [1, 2, 3],
938+
"c": [1.0, 2.0, 3.0],
939+
"d": [None, "b", None],
940+
}
941+
)
942+
t = ibis.memtable(pa_t)
943+
backend.assert_frame_equal(t.execute(), pa_t.to_pandas())
944+
945+
923946
@pytest.mark.notimpl(
924947
["dask", "datafusion", "pandas", "polars"],
925948
raises=NotImplementedError,

ibis/expr/api.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
if TYPE_CHECKING:
4444
import pandas as pd
45+
import pyarrow as pa
4546

4647
from ibis.common.typing import SupportsSchema
4748

@@ -324,10 +325,10 @@ def memtable(
324325
Parameters
325326
----------
326327
data
327-
Any data accepted by the `pandas.DataFrame` constructor.
328+
Any data accepted by the `pandas.DataFrame` constructor or a `pyarrow.Table`.
328329
329-
The use of `DataFrame` underneath should **not** be relied upon and is
330-
free to change across non-major releases.
330+
Do not depend on the underlying storage type (e.g., pyarrow.Table), it's subject
331+
to change across non-major releases.
331332
columns
332333
Optional [`Iterable`][typing.Iterable] of [`str`][str] column names.
333334
schema
@@ -393,6 +394,15 @@ def memtable(
393394
"passing `columns` and schema` is ambiguous; "
394395
"pass one or the other but not both"
395396
)
397+
398+
try:
399+
import pyarrow as pa
400+
except ImportError:
401+
pass
402+
else:
403+
if isinstance(data, pa.Table):
404+
return _memtable_from_pyarrow_table(data, name=name, schema=schema)
405+
396406
df = pd.DataFrame(data, columns=columns)
397407
if df.columns.inferred_type != "string":
398408
cols = df.columns
@@ -421,6 +431,18 @@ def _memtable_from_dataframe(
421431
return op.to_expr()
422432

423433

434+
def _memtable_from_pyarrow_table(
435+
data: pa.Table, *, name: str | None = None, schema: SupportsSchema | None = None
436+
):
437+
from ibis.backends.pyarrow import PyArrowInMemoryTable, PyArrowTableProxy
438+
439+
return PyArrowInMemoryTable(
440+
name=name if name is not None else util.generate_unique_table_name("memtable"),
441+
schema=sch.infer(data) if schema is None else schema,
442+
data=PyArrowTableProxy(data),
443+
).to_expr()
444+
445+
424446
def _deferred_method_call(expr, method_name):
425447
method = operator.methodcaller(method_name)
426448
if isinstance(expr, str):

0 commit comments

Comments
 (0)