
Commit ce3d6a4

feat(snowflake): native pyarrow support
1 parent: 7bd22af

File tree

2 files changed: 109 additions & 8 deletions


ibis/backends/snowflake/__init__.py

Lines changed: 106 additions & 7 deletions
@@ -1,21 +1,27 @@
 from __future__ import annotations
 
 import contextlib
+import itertools
 import json
 import warnings
-from typing import Any, Iterable, Mapping
+import weakref
+from typing import TYPE_CHECKING, Any, Iterable, Mapping
 
 import sqlalchemy as sa
 
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
+import ibis.expr.types as ir
 from ibis.backends.base.sql.alchemy import (
     AlchemyCompiler,
     AlchemyExprTranslator,
     BaseAlchemyBackend,
 )
 from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 @contextlib.contextmanager
 def _handle_pyarrow_warning(*, action: str):
@@ -30,7 +36,7 @@ def _handle_pyarrow_warning(*, action: str):
 
 with _handle_pyarrow_warning(action="error"):
     try:
-        import pyarrow  # noqa: F401, ICN001
+        import pyarrow  # noqa: ICN001
     except ImportError:
         _NATIVE_ARROW = False
     else:
@@ -167,11 +173,17 @@ def do_connect(
         self.database_name = dbparams["database"]
         if connect_args is None:
             connect_args = {}
-        connect_args["converter_class"] = _SnowFlakeConverter
-        connect_args["session_parameters"] = {
-            PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
-            "STRICT_JSON_OUTPUT": "TRUE",
-        }
+        connect_args.setdefault("converter_class", _SnowFlakeConverter)
+        connect_args.setdefault(
+            "session_parameters",
+            {
+                PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
+                "STRICT_JSON_OUTPUT": "TRUE",
+            },
+        )
+        self._default_connector_format = connect_args["session_parameters"].get(
+            PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, "JSON"
+        )
         engine = sa.create_engine(
             url,
             connect_args=connect_args,
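
An aside on the `setdefault` change above: user-supplied `connect_args` are no longer clobbered, and whatever format ends up in `session_parameters` is recorded in `self._default_connector_format` so it can be restored after temporarily switching the session to ARROW. A minimal sketch of the dict semantics, using a shortened placeholder key in place of the real `PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT` constant:

# Hypothetical user-supplied connect_args that already request ARROW.
connect_args = {"session_parameters": {"QUERY_RESULT_FORMAT": "ARROW"}}

# setdefault only fills in missing keys, so the user's setting survives...
connect_args.setdefault("session_parameters", {"QUERY_RESULT_FORMAT": "JSON"})
assert connect_args["session_parameters"]["QUERY_RESULT_FORMAT"] == "ARROW"

# ...and the effective default is captured for later restoration.
default_format = connect_args["session_parameters"].get("QUERY_RESULT_FORMAT", "JSON")
assert default_format == "ARROW"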
@@ -213,6 +225,93 @@ def normalize_name(name):
         self.con.dialect.normalize_name = normalize_name
         return res
 
+    def to_pyarrow(
+        self,
+        expr: ir.Expr,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+        **kwargs: Any,
+    ) -> pa.Table:
+        if not _NATIVE_ARROW:
+            return super().to_pyarrow(expr, params=params, limit=limit, **kwargs)
+
+        import pyarrow as pa
+
+        self._register_in_memory_tables(expr)
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+        sql = query_ast.compile()
+        with self.begin() as con:
+            con.exec_driver_sql(
+                f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
+            )
+            res = con.execute(sql).cursor.fetch_arrow_all()
+            con.exec_driver_sql(
+                f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
+            )
+
+        target_schema = expr.as_table().schema().to_pyarrow()
+        if res is None:
+            res = pa.Table.from_pylist([], schema=target_schema)
+
+        if not res.schema.equals(target_schema, check_metadata=False):
+            res = res.rename_columns(target_schema.names).cast(target_schema)
+
+        if isinstance(expr, ir.Column):
+            return res[expr.get_name()]
+        elif isinstance(expr, ir.Scalar):
+            return res[expr.get_name()][0]
+        return res
+
+    def to_pyarrow_batches(
+        self,
+        expr: ir.Expr,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+        chunk_size: int = 1000000,
+        **kwargs: Any,
+    ) -> pa.ipc.RecordBatchReader:
+        if not _NATIVE_ARROW:
+            return super().to_pyarrow_batches(
+                expr, params=params, limit=limit, chunk_size=chunk_size, **kwargs
+            )
+
+        import pyarrow as pa
+
+        self._register_in_memory_tables(expr)
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+        sql = query_ast.compile()
+        con = self.con.connect()
+        con.exec_driver_sql(
+            f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
+        )
+        cursor = con.execute(sql)
+        con.exec_driver_sql(
+            f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
+        )
+        raw_cursor = cursor.cursor
+        target_schema = expr.as_table().schema().to_pyarrow()
+        target_columns = target_schema.names
+        reader = pa.RecordBatchReader.from_batches(
+            target_schema,
+            itertools.chain.from_iterable(
+                (
+                    t.rename_columns(target_columns)
+                    .cast(target_schema)
+                    .to_batches(max_chunksize=chunk_size)
+                )
+                for t in raw_cursor.fetch_arrow_batches()
+            ),
+        )
+
+        def close(cursor=cursor, con=con):
+            cursor.close()
+            con.close()
+
+        weakref.finalize(reader, close)
+        return reader
+
     def _get_sqla_table(
         self,
         name: str,
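
Putting the two new methods together, a rough usage sketch; the connection parameters and table name below are placeholders, not part of this commit:

import ibis

# Placeholder credentials; any working Snowflake connection applies.
con = ibis.snowflake.connect(
    user="me", password="...", account="org-account", database="MYDB/PUBLIC"
)

t = con.table("functional_alltypes")  # hypothetical table name

# When the connector's native Arrow path is available, results are fetched
# as Arrow directly instead of round-tripping through JSON rows.
tbl = con.to_pyarrow(t.limit(10))   # pyarrow.Table
reader = con.to_pyarrow_batches(t)  # pyarrow.ipc.RecordBatchReader
for batch in reader:
    print(batch.num_rows)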

ibis/backends/tests/test_export.py

Lines changed: 3 additions & 1 deletion
@@ -171,7 +171,9 @@ def test_to_pyarrow_batches_memtable(con):
 
 def test_no_pyarrow_message(awards_players, monkeypatch):
     monkeypatch.setitem(sys.modules, "pyarrow", None)
-    with pytest.raises(ModuleNotFoundError, match="requires `pyarrow` but"):
+    with pytest.raises(
+        ModuleNotFoundError, match="requires `pyarrow` but|import of pyarrow halted"
+    ):
         awards_players.to_pyarrow()
 
 
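The broadened `match` is needed because setting a module's `sys.modules` entry to `None` (as the test's monkeypatch does) makes a subsequent import raise `ModuleNotFoundError` with a different message than a genuinely missing package. A standalone illustration of that mechanism, separate from the test suite:

import sys

sys.modules["pyarrow"] = None  # simulate pyarrow being uninstalled

try:
    import pyarrow
except ModuleNotFoundError as exc:
    # CPython reports: "import of pyarrow halted; None in sys.modules"
    print(exc)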
