from __future__ import annotations

import contextlib
+ import itertools
import json
import warnings
- from typing import Any, Iterable, Mapping
+ import weakref
+ from typing import TYPE_CHECKING, Any, Iterable, Mapping

import sqlalchemy as sa

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
+ import ibis.expr.types as ir
from ibis.backends.base.sql.alchemy import (
    AlchemyCompiler,
    AlchemyExprTranslator,
    BaseAlchemyBackend,
)
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter

+ if TYPE_CHECKING:
+     import pyarrow as pa
+

@contextlib.contextmanager
def _handle_pyarrow_warning(*, action: str):
@@ -30,7 +36,7 @@ def _handle_pyarrow_warning(*, action: str):

with _handle_pyarrow_warning(action="error"):
    try:
-         import pyarrow  # noqa: F401, ICN001
+         import pyarrow  # noqa: ICN001
    except ImportError:
        _NATIVE_ARROW = False
    else:
@@ -167,11 +173,17 @@ def do_connect(
        self.database_name = dbparams["database"]
        if connect_args is None:
            connect_args = {}
-         connect_args["converter_class"] = _SnowFlakeConverter
-         connect_args["session_parameters"] = {
-             PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
-             "STRICT_JSON_OUTPUT": "TRUE",
-         }
+         connect_args.setdefault("converter_class", _SnowFlakeConverter)
+         connect_args.setdefault(
+             "session_parameters",
+             {
+                 PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
+                 "STRICT_JSON_OUTPUT": "TRUE",
+             },
+         )
+         self._default_connector_format = connect_args["session_parameters"].get(
+             PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, "JSON"
+         )
        engine = sa.create_engine(
            url,
            connect_args=connect_args,
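Switching to `setdefault` means caller-supplied `connect_args` are no longer clobbered: user values win, and the backend only fills in missing keys, remembering the connect-time result format for later restoration. A minimal sketch of that dict behavior (the `QUERY_TAG` value is a hypothetical caller input, and `object` merely stands in for `_SnowFlakeConverter`):

```python
# Sketch of the setdefault behavior: user-provided values win, defaults fill gaps.
user_args = {"session_parameters": {"QUERY_TAG": "etl-job"}}  # hypothetical caller input

user_args.setdefault("converter_class", object)  # stand-in for _SnowFlakeConverter
user_args.setdefault("session_parameters", {"STRICT_JSON_OUTPUT": "TRUE"})

# The caller's session_parameters dict is untouched; only missing keys get defaults,
# which is why do_connect falls back to "JSON" via .get() when the format key is absent.
assert user_args["session_parameters"] == {"QUERY_TAG": "etl-job"}
assert user_args["converter_class"] is object
```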
@@ -213,6 +225,93 @@ def normalize_name(name):
        self.con.dialect.normalize_name = normalize_name
        return res

+     def to_pyarrow(
+         self,
+         expr: ir.Expr,
+         *,
+         params: Mapping[ir.Scalar, Any] | None = None,
+         limit: int | str | None = None,
+         **kwargs: Any,
+     ) -> pa.Table:
+         if not _NATIVE_ARROW:
+             return super().to_pyarrow(expr, params=params, limit=limit, **kwargs)
+
+         import pyarrow as pa
+
+         self._register_in_memory_tables(expr)
+         query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+         sql = query_ast.compile()
+         with self.begin() as con:
+             con.exec_driver_sql(
+                 f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
+             )
+             res = con.execute(sql).cursor.fetch_arrow_all()
+             con.exec_driver_sql(
+                 f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
+             )
+
+         target_schema = expr.as_table().schema().to_pyarrow()
+         if res is None:
+             res = pa.Table.from_pylist([], schema=target_schema)
+
+         if not res.schema.equals(target_schema, check_metadata=False):
+             res = res.rename_columns(target_schema.names).cast(target_schema)
+
+         if isinstance(expr, ir.Column):
+             return res[expr.get_name()]
+         elif isinstance(expr, ir.Scalar):
+             return res[expr.get_name()][0]
+         return res
+
+     def to_pyarrow_batches(
+         self,
+         expr: ir.Expr,
+         *,
+         params: Mapping[ir.Scalar, Any] | None = None,
+         limit: int | str | None = None,
+         chunk_size: int = 1000000,
+         **kwargs: Any,
+     ) -> pa.ipc.RecordBatchReader:
+         if not _NATIVE_ARROW:
+             return super().to_pyarrow_batches(
+                 expr, params=params, limit=limit, chunk_size=chunk_size, **kwargs
+             )
+
+         import pyarrow as pa
+
+         self._register_in_memory_tables(expr)
+         query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
+         sql = query_ast.compile()
+         con = self.con.connect()
+         con.exec_driver_sql(
+             f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
+         )
+         cursor = con.execute(sql)
+         con.exec_driver_sql(
+             f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
+         )
+         raw_cursor = cursor.cursor
+         target_schema = expr.as_table().schema().to_pyarrow()
+         target_columns = target_schema.names
+         reader = pa.RecordBatchReader.from_batches(
+             target_schema,
+             itertools.chain.from_iterable(
+                 (
+                     t.rename_columns(target_columns)
+                     .cast(target_schema)
+                     .to_batches(max_chunksize=chunk_size)
+                 )
+                 for t in raw_cursor.fetch_arrow_batches()
+             ),
+         )
+
+         def close(cursor=cursor, con=con):
+             cursor.close()
+             con.close()
+
+         weakref.finalize(reader, close)
+         return reader
+
    def _get_sqla_table(
        self,
        name: str,
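The new `to_pyarrow` and `to_pyarrow_batches` overrides flip the connector's result format to ARROW only for the duration of the query, restore the connect-time default afterwards, and fall back to the base-class implementation when the snowflake connector lacks native Arrow support (`_NATIVE_ARROW` is False). A hedged usage sketch; the connection arguments and table name below are placeholders, not part of this change:

```python
import ibis

# Placeholder credentials; substitute real Snowflake connection parameters.
con = ibis.snowflake.connect(
    user="user", password="password", account="account", database="DB/SCHEMA"
)

expr = con.table("LINEITEM").limit(100_000)  # hypothetical table

# Materialize the whole result as a pyarrow.Table via the native Arrow path.
tbl = con.to_pyarrow(expr)

# Or stream it: a pyarrow.ipc.RecordBatchReader yielding RecordBatches lazily.
reader = con.to_pyarrow_batches(expr, chunk_size=100_000)
for batch in reader:
    ...  # process each pyarrow.RecordBatch
```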