Skip to content

Commit 7d8fe5a

Browse files
committed
fix(snowflake): fix array printing by using a pyarrow extension type
1 parent 8bac145 commit 7d8fe5a

File tree

6 files changed

+100
-13
lines changed

6 files changed

+100
-13
lines changed

ibis/backends/snowflake/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ def to_pyarrow(
364364
limit: int | str | None = None,
365365
**_: Any,
366366
) -> pa.Table:
367+
from ibis.backends.snowflake.converter import SnowflakePyArrowData
368+
367369
self._run_pre_execute_hooks(expr)
368370

369371
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
@@ -375,9 +377,7 @@ def to_pyarrow(
375377
if res is None:
376378
res = target_schema.empty_table()
377379

378-
res = res.rename_columns(target_schema.names).cast(target_schema)
379-
380-
return expr.__pyarrow_result__(res)
380+
return expr.__pyarrow_result__(res, data_mapper=SnowflakePyArrowData)
381381

382382
def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
383383
if (table := cursor.cursor.fetch_arrow_all()) is None:

ibis/backends/snowflake/converter.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from ibis.formats.pandas import PandasData
6+
from ibis.formats.pyarrow import PYARROW_JSON_TYPE, PyArrowData
7+
8+
if TYPE_CHECKING:
9+
import pyarrow as pa
10+
11+
import ibis.expr.datatypes as dt
12+
from ibis.expr.schema import Schema
413

514

615
class SnowflakePandasData(PandasData):
@@ -10,3 +19,23 @@ def convert_JSON(s, dtype, pandas_type):
1019
return s.map(converter, na_action="ignore").astype("object")
1120

1221
convert_Struct = convert_Array = convert_Map = convert_JSON
22+
23+
24+
class SnowflakePyArrowData(PyArrowData):
    """PyArrow result converter used for Snowflake query results.

    JSON, array, map and struct columns are not cast to a native Arrow
    nested type; instead the column is used as the storage of the
    ``PYARROW_JSON_TYPE`` extension type. Every other dtype is delegated to
    the base ``PyArrowData`` converter.
    """

    @classmethod
    def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
        """Convert ``table`` column by column according to ``schema``.

        Columns are looked up in ``table`` by the names in ``schema`` and the
        resulting table carries ``schema.names`` in schema order.
        """
        import pyarrow as pa

        converted = []
        for name, typ in schema.items():
            converted.append(cls.convert_column(table[name], typ))
        return pa.Table.from_arrays(converted, names=schema.names)

    @classmethod
    def convert_column(cls, column: pa.Array, dtype: dt.DataType) -> pa.Array:
        """Convert a single column to the Arrow representation of ``dtype``.

        For JSON/array/map/struct dtypes the column is wrapped in an
        extension array backed by ``PYARROW_JSON_TYPE``; anything else goes
        through the parent implementation unchanged.
        """
        wrap_as_json = (
            dtype.is_json() or dtype.is_array() or dtype.is_map() or dtype.is_struct()
        )
        if not wrap_as_json:
            return super().convert_column(column, dtype)

        import pyarrow as pa

        # Chunked input is flattened into one contiguous array before it is
        # handed to ``from_storage`` as the extension array's storage.
        if isinstance(column, pa.ChunkedArray):
            column = column.combine_chunks()

        return pa.ExtensionArray.from_storage(PYARROW_JSON_TYPE, column)

ibis/backends/snowflake/tests/test_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,10 @@ def test_read_parquet(con, data_dir):
219219
t = con.read_parquet(path)
220220

221221
assert t.timestamp_col.type().is_timestamp()
222+
223+
224+
def test_array_repr(con, monkeypatch):
    """The interactive repr of an array column renders to a non-empty string."""
    # Interactive mode makes ``repr`` execute the expression and render it.
    monkeypatch.setattr(ibis.options, "interactive", True)
    array_column = con.tables.ARRAY_TYPES.x
    assert repr(array_column)

ibis/expr/types/generic.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import ibis.expr.builders as bl
2020
import ibis.expr.types as ir
21+
from ibis.formats.pyarrow import PyArrowData
2122

2223

2324
@public
@@ -1204,10 +1205,13 @@ class Scalar(Value):
12041205
def __interactive_rich_console__(self, console, options):
12051206
return console.render(repr(self.execute()), options=options)
12061207

1207-
def __pyarrow_result__(self, table: pa.Table) -> pa.Scalar:
1208-
from ibis.formats.pyarrow import PyArrowData
1208+
def __pyarrow_result__(
1209+
self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
1210+
) -> pa.Scalar:
1211+
if data_mapper is None:
1212+
from ibis.formats.pyarrow import PyArrowData as data_mapper
12091213

1210-
return PyArrowData.convert_scalar(table[0][0], self.type())
1214+
return data_mapper.convert_scalar(table[0][0], self.type())
12111215

12121216
def __pandas_result__(self, df: pd.DataFrame) -> Any:
12131217
return df.iat[0, 0]
@@ -1275,10 +1279,13 @@ def __interactive_rich_console__(self, console, options):
12751279
projection = named.as_table()
12761280
return console.render(projection, options=options)
12771281

1278-
def __pyarrow_result__(self, table: pa.Table) -> pa.Array | pa.ChunkedArray:
1279-
from ibis.formats.pyarrow import PyArrowData
1282+
def __pyarrow_result__(
1283+
self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
1284+
) -> pa.Array | pa.ChunkedArray:
1285+
if data_mapper is None:
1286+
from ibis.formats.pyarrow import PyArrowData as data_mapper
12801287

1281-
return PyArrowData.convert_column(table[0], self.type())
1288+
return data_mapper.convert_column(table[0], self.type())
12821289

12831290
def __pandas_result__(self, df: pd.DataFrame) -> pd.Series:
12841291
from ibis.formats.pandas import PandasData

ibis/expr/types/relations.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ibis.expr.types.groupby import GroupedTable
3333
from ibis.expr.types.tvf import WindowedTable
3434
from ibis.selectors import IfAnyAll, Selector
35+
from ibis.formats.pyarrow import PyArrowData
3536

3637
_ALIASES = (f"_ibis_view_{n:d}" for n in itertools.count())
3738

@@ -158,10 +159,13 @@ def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
158159

159160
return IbisDataFrame(self, nan_as_null=nan_as_null, allow_copy=allow_copy)
160161

161-
def __pyarrow_result__(self, table: pa.Table) -> pa.Table:
162-
from ibis.formats.pyarrow import PyArrowData
162+
def __pyarrow_result__(
163+
self, table: pa.Table, data_mapper: type[PyArrowData] | None = None
164+
) -> pa.Table:
165+
if data_mapper is None:
166+
from ibis.formats.pyarrow import PyArrowData as data_mapper
163167

164-
return PyArrowData.convert_table(table, self.schema())
168+
return data_mapper.convert_table(table, self.schema())
165169

166170
def __pandas_result__(self, df: pd.DataFrame) -> pd.DataFrame:
167171
from ibis.formats.pandas import PandasData

ibis/formats/pyarrow.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import TYPE_CHECKING, Any
45

56
import pyarrow as pa
@@ -12,6 +13,38 @@
1213
if TYPE_CHECKING:
1314
from collections.abc import Sequence
1415

16+
17+
class JSONScalar(pa.ExtensionScalar):
    """Scalar of the JSON extension type whose Python value is decoded JSON."""

    def as_py(self):
        """Return ``None`` for a null scalar, else the ``json.loads``-decoded storage value."""
        storage = self.value
        return None if storage is None else json.loads(storage.as_py())
24+
25+
26+
class JSONArray(pa.ExtensionArray):
    """Array class returned by ``JSONType.__arrow_ext_class__``; adds no behavior of its own."""
28+
29+
30+
class JSONType(pa.ExtensionType):
    """PyArrow extension type named ``"ibis.json"`` with string storage.

    The type has no parameters, so serialization produces an empty payload
    and deserialization simply builds a fresh instance.
    """

    def __init__(self):
        storage_type = pa.string()
        super().__init__(storage_type, "ibis.json")

    def __arrow_ext_serialize__(self):
        # Parameter-free type: there is nothing to encode.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Both arguments are ignored; every instance is equivalent.
        return cls()

    def __arrow_ext_class__(self):
        # Array class used for arrays of this type.
        return JSONArray

    def __arrow_ext_scalar_class__(self):
        # Scalar class used for elements of this type (decodes JSON in as_py).
        return JSONScalar
46+
47+
1548
_from_pyarrow_types = {
1649
pa.int8(): dt.Int8,
1750
pa.int16(): dt.Int16,
@@ -57,7 +90,6 @@
5790
dt.Unknown: pa.string(),
5891
dt.MACADDR: pa.string(),
5992
dt.INET: pa.string(),
60-
dt.JSON: pa.string(),
6193
}
6294

6395

@@ -95,6 +127,8 @@ def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
95127
key_dtype = cls.to_ibis(typ.key_type, typ.key_field.nullable)
96128
value_dtype = cls.to_ibis(typ.item_type, typ.item_field.nullable)
97129
return dt.Map(key_dtype, value_dtype, nullable=nullable)
130+
elif isinstance(typ, JSONType):
131+
return dt.JSON()
98132
else:
99133
return _from_pyarrow_types[typ](nullable=nullable)
100134

@@ -154,6 +188,8 @@ def from_ibis(cls, dtype: dt.DataType) -> pa.DataType:
154188
nullable=dtype.value_type.nullable,
155189
)
156190
return pa.map_(key_field, value_field, keys_sorted=False)
191+
elif dtype.is_json():
192+
return PYARROW_JSON_TYPE
157193
else:
158194
try:
159195
return _to_pyarrow_types[type(dtype)]
@@ -254,3 +290,7 @@ def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
254290
return table.cast(desired_schema)
255291
else:
256292
return table
293+
294+
295+
# Shared instance of the JSON extension type; other modules import this name
# when wrapping string storage as JSON.
PYARROW_JSON_TYPE = JSONType()

# Registration is process-global in pyarrow. Registering a name that already
# exists raises ArrowKeyError, which would make this module fail on a second
# import/reload, so an "already registered" error is tolerated here.
# NOTE(review): this also hides a clash with an unrelated type that happens to
# use the same "ibis.json" name — confirm that is acceptable.
try:
    pa.register_extension_type(PYARROW_JSON_TYPE)
except pa.ArrowKeyError:
    pass

0 commit comments

Comments
 (0)