Skip to content

Commit 40558fd

Browse files
jayceslesar authored and cpcloud committed
feat(export): allow passing keyword arguments to PyArrow ParquetWriter and CSVWriter
1 parent e3b9611 commit 40558fd

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

ibis/backends/base/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def to_parquet(
547547
import pyarrow.parquet as pq
548548

549549
with expr.to_pyarrow_batches(params=params) as batch_reader:
550-
with pq.ParquetWriter(path, batch_reader.schema) as writer:
550+
with pq.ParquetWriter(path, batch_reader.schema, **kwargs) as writer:
551551
for batch in batch_reader:
552552
writer.write_batch(batch)
553553

@@ -582,7 +582,7 @@ def to_csv(
582582
import pyarrow.csv as pcsv
583583

584584
with expr.to_pyarrow_batches(params=params) as batch_reader:
585-
with pcsv.CSVWriter(path, batch_reader.schema) as writer:
585+
with pcsv.CSVWriter(path, batch_reader.schema, **kwargs) as writer:
586586
for batch in batch_reader:
587587
writer.write_batch(batch)
588588

ibis/backends/duckdb/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,24 @@ class Backend(AlchemyCrossSchemaBackend, CanCreateSchema):
9898
name = "duckdb"
9999
compiler = DuckDBSQLCompiler
100100
supports_create_or_replace = True
101+
reserved_csv_copy_args = [
102+
"COMPRESSION",
103+
"FORCE_QUOTE",
104+
"DATEFORMAT",
105+
"DELIM",
106+
"SEP",
107+
"ESCAPE",
108+
"HEADER",
109+
"NULLSTR",
110+
"QUOTE",
111+
"TIMESTAMP_FORMAT"
112+
]
113+
reserved_parquet_copy_args = [
114+
"COMPRESSION",
115+
"ROW_GROUP_SIZE",
116+
"ROW_GROUP_SIZE_BYTES",
117+
"FIELD_IDS",
118+
]
101119

102120
@property
103121
def settings(self) -> _Settings:
@@ -1089,7 +1107,7 @@ def to_parquet(
10891107
"""
10901108
self._run_pre_execute_hooks(expr)
10911109
query = self._to_sql(expr, params=params)
1092-
args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items())]
1110+
args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_parquet_copy_args)]
10931111
copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})"
10941112
with self.begin() as con:
10951113
con.exec_driver_sql(copy_cmd)
@@ -1127,7 +1145,7 @@ def to_csv(
11271145
args = [
11281146
"FORMAT 'csv'",
11291147
f"HEADER {int(header)}",
1130-
*(f"{k.upper()} {v!r}" for k, v in kwargs.items()),
1148+
*(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_csv_copy_args),
11311149
]
11321150
copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})"
11331151
with self.begin() as con:

ibis/backends/tests/test_export.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44
import pandas.testing as tm
55
import pyarrow as pa
6+
import pyarrow.csv as pcsv
67
import pytest
78
import sqlalchemy as sa
89
from pytest import param
@@ -220,6 +221,21 @@ def test_table_to_parquet(tmp_path, backend, awards_players):
220221
backend.assert_frame_equal(awards_players.to_pandas(), df)
221222

222223

224+
@pytest.mark.notimpl(["flink"])
225+
@pytest.mark.parametrize(("kwargs"), [({"version": "1.0"}), ({"version": "2.6"})])
226+
def test_table_to_parquet_writer_kwargs(kwargs, tmp_path, backend, awards_players):
227+
outparquet = tmp_path / "out.parquet"
228+
awards_players.to_parquet(outparquet, **kwargs)
229+
230+
df = pd.read_parquet(outparquet)
231+
232+
backend.assert_frame_equal(awards_players.to_pandas(), df)
233+
234+
file = pa.parquet.ParquetFile(outparquet)
235+
236+
assert file.metadata.format_version == kwargs["version"]
237+
238+
223239
@pytest.mark.notimpl(
224240
[
225241
"bigquery",
@@ -299,6 +315,17 @@ def test_table_to_csv(tmp_path, backend, awards_players):
299315
backend.assert_frame_equal(awards_players.to_pandas(), df)
300316

301317

318+
@pytest.mark.notimpl(["flink"])
319+
@pytest.mark.parametrize(("kwargs", "delimiter"), [({"write_options": pcsv.WriteOptions(delimiter=";")}, ";"), ({"write_options": pcsv.WriteOptions(delimiter="\t")}, "\t")])
320+
def test_table_to_csv_writer_kwargs(kwargs, delimiter, tmp_path, backend, awards_players):
321+
outcsv = tmp_path / "out.csv"
322+
# avoid pandas NaNonense
323+
awards_players = awards_players.select("playerID", "awardID", "yearID", "lgID")
324+
325+
awards_players.to_csv(outcsv, **kwargs)
326+
pd.read_csv(outcsv, delimiter=delimiter)
327+
328+
302329
@pytest.mark.parametrize(
303330
("dtype", "pyarrow_dtype"),
304331
[

0 commit comments

Comments (0)