
Commit 833d895

jakepenzak authored and cpcloud committed
feat(pyspark): add partitionBy argument to create_table
Adds the partitionBy argument to the create_table method in the pyspark backend to enable partitioned table creation. Fixes #8900.
1 parent 32e82c7 commit 833d895

File tree

4 files changed (+243 −8 lines)
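To show how the change is meant to be used, here is a minimal, hypothetical sketch against a local Spark session; the session setup, sample data, and table name are placeholders and not part of this commit.

```python
import ibis
from pyspark.sql import SparkSession

# Hypothetical local session; not part of this commit.
session = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(session)

t = ibis.memtable(
    {
        "epoch": [1712848119, 1712848121],
        "category1": ["A", "B"],
        "category2": ["G", "J"],
    }
)

# With this change, create_table forwards partition_by to
# pyspark's DataFrameWriter.saveAsTable(partitionBy=...).
con.create_table(
    "events_partitioned",  # hypothetical table name
    obj=t,
    overwrite=True,
    partition_by="category1",  # a single column, or a list such as ["category1", "category2"]
)
```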

ibis/backends/pyspark/__init__.py

Lines changed: 45 additions & 1 deletion
```diff
@@ -600,6 +600,7 @@ def create_table(
         temp: bool | None = None,
         overwrite: bool = False,
         format: str = "parquet",
+        partition_by: str | list[str] | None = None,
     ) -> ir.Table:
         """Create a new table in Spark.

@@ -623,6 +624,8 @@ def create_table(
             If `True`, overwrite existing data
         format
             Format of the table on disk
+        partition_by
+            Name(s) of partitioning column(s)

         Returns
         -------
@@ -651,7 +654,9 @@ def create_table(
             with self._active_catalog_database(catalog, db):
                 self._run_pre_execute_hooks(table)
                 df = self._session.sql(query)
-                df.write.saveAsTable(name, format=format, mode=mode)
+                df.write.saveAsTable(
+                    name, format=format, mode=mode, partitionBy=partition_by
+                )
         elif schema is not None:
             schema = ibis.schema(schema)
             schema = PySparkSchema.from_ibis(schema)
```
```diff
@@ -953,6 +958,45 @@ def to_delta(
         df = self._session.sql(self.compile(expr, params=params, limit=limit))
         df.write.format("delta").save(os.fspath(path), **kwargs)

+    @util.experimental
+    def to_parquet(
+        self,
+        expr: ir.Table,
+        /,
+        path: str | Path,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Write the results of executing the given expression to a Parquet file.
+
+        This method is eager and will execute the associated expression
+        immediately.
+
+        Parameters
+        ----------
+        expr
+            The ibis expression to execute and persist to a Parquet file.
+        path
+            The data source. A string or Path to the Parquet file.
+        params
+            Mapping of scalar parameter expressions to value.
+        limit
+            An integer to effect a specific row limit. A value of `None` means
+            "no limit". The default is in `ibis/config.py`.
+        **kwargs
+            Additional keyword arguments passed to
+            [pyspark.sql.DataFrameWriter](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.html).
+        """
+        if self.mode == "streaming":
+            raise NotImplementedError(
+                "Writing to a Parquet file in streaming mode is not supported."
+            )
+        self._run_pre_execute_hooks(expr)
+        df = self._session.sql(self.compile(expr, params=params, limit=limit))
+        df.write.format("parquet").save(os.fspath(path), **kwargs)
+
     def to_pyarrow(
         self,
         expr: ir.Expr,
```
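For the new `to_parquet` method, a similarly hedged sketch; the session setup and output paths are placeholders, "awards_players" is the fixture table used in the tests below, and extra keyword arguments such as `partitionBy` are forwarded to `pyspark.sql.DataFrameWriter` as the docstring describes.

```python
from pathlib import Path

import ibis
from pyspark.sql import SparkSession

# Hypothetical local session; not part of this commit.
session = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(session)
expr = con.table("awards_players")

# Plain write: Spark creates a directory of part-*.parquet files at this path.
con.to_parquet(expr, Path("/tmp/awards_players_np"))

# Partitioned write: partitionBy passes through **kwargs to DataFrameWriter.save,
# producing playerID=<value>/ subdirectories under the output path.
con.to_parquet(expr, Path("/tmp/awards_players_p"), partitionBy=["playerID"])
```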

ibis/backends/pyspark/tests/test_client.py

Lines changed: 128 additions & 0 deletions
```diff
@@ -54,3 +54,131 @@ def test_create_table_no_catalog(con):

     assert "t2" not in con.list_tables(database="default")
     assert con.current_database != "default"
+
+
+@pytest.mark.xfail_version(pyspark=["pyspark<3.4"], reason="no catalog support")
+def test_create_table_with_partition_and_catalog(con):
+    # Create a sample table with a partition column
+    data = {
+        "epoch": [1712848119, 1712848121, 1712848155, 1712848169],
+        "category1": ["A", "B", "A", "C"],
+        "category2": ["G", "J", "G", "H"],
+    }
+
+    t = ibis.memtable(data)
+
+    # 1D partition
+    table_name = "pt1"
+
+    con.create_table(
+        table_name,
+        database=("spark_catalog", "default"),
+        obj=t,
+        overwrite=True,
+        partition_by="category1",
+    )
+    assert table_name in con.list_tables(database="spark_catalog.default")
+
+    partitions = (
+        con.raw_sql(f"SHOW PARTITIONS spark_catalog.default.{table_name}")
+        .toPandas()
+        .to_dict()
+    )
+    expected_partitions = {
+        "partition": {0: "category1=A", 1: "category1=B", 2: "category1=C"}
+    }
+    assert partitions == expected_partitions
+
+    # Cleanup
+    con.drop_table(table_name, database="spark_catalog.default")
+    assert table_name not in con.list_tables(database="spark_catalog.default")
+
+    # 2D partition
+    table_name = "pt2"
+
+    con.create_table(
+        table_name,
+        database=("spark_catalog", "default"),
+        obj=t,
+        overwrite=True,
+        partition_by=["category1", "category2"],
+    )
+    assert table_name in con.list_tables(database="spark_catalog.default")
+
+    partitions = (
+        con.raw_sql(f"SHOW PARTITIONS spark_catalog.default.{table_name}")
+        .toPandas()
+        .to_dict()
+    )
+    expected_partitions = {
+        "partition": {
+            0: "category1=A/category2=G",
+            1: "category1=B/category2=J",
+            2: "category1=C/category2=H",
+        }
+    }
+    assert partitions == expected_partitions
+
+    # Cleanup
+    con.drop_table(table_name, database="spark_catalog.default")
+    assert table_name not in con.list_tables(database="spark_catalog.default")
+
+
+def test_create_table_with_partition_no_catalog(con):
+    data = {
+        "epoch": [1712848119, 1712848121, 1712848155, 1712848169],
+        "category1": ["A", "B", "A", "C"],
+        "category2": ["G", "J", "G", "H"],
+    }
+
+    t = ibis.memtable(data)
+
+    # 1D partition
+    table_name = "pt1"
+
+    con.create_table(
+        table_name,
+        obj=t,
+        overwrite=True,
+        partition_by="category1",
+    )
+    assert table_name in con.list_tables()
+
+    partitions = (
+        con.raw_sql(f"SHOW PARTITIONS ibis_testing.{table_name}").toPandas().to_dict()
+    )
+    expected_partitions = {
+        "partition": {0: "category1=A", 1: "category1=B", 2: "category1=C"}
+    }
+    assert partitions == expected_partitions
+
+    # Cleanup
+    con.drop_table(table_name)
+    assert table_name not in con.list_tables()
+
+    # 2D partition
+    table_name = "pt2"
+
+    con.create_table(
+        table_name,
+        obj=t,
+        overwrite=True,
+        partition_by=["category1", "category2"],
+    )
+    assert table_name in con.list_tables()
+
+    partitions = (
+        con.raw_sql(f"SHOW PARTITIONS ibis_testing.{table_name}").toPandas().to_dict()
+    )
+    expected_partitions = {
+        "partition": {
+            0: "category1=A/category2=G",
+            1: "category1=B/category2=J",
+            2: "category1=C/category2=H",
+        }
+    }
+    assert partitions == expected_partitions
+
+    # Cleanup
+    con.drop_table(table_name)
+    assert table_name not in con.list_tables()
```

ibis/backends/pyspark/tests/test_import_export.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -5,8 +5,10 @@

 import pandas as pd
 import pytest
+from pandas.testing import assert_frame_equal

 from ibis.backends.pyspark.datatypes import PySparkSchema
+from ibis.conftest import IS_SPARK_REMOTE


 @pytest.mark.parametrize(
@@ -73,3 +75,40 @@ def test_to_parquet_dir(con_streaming, tmp_path):
     sleep(2)
     df = pd.concat([pd.read_parquet(f) for f in path.glob("*.parquet")])
     assert len(df) == 5
+
+
+@pytest.mark.skipif(
+    IS_SPARK_REMOTE, reason="Spark remote does not support assertions about local paths"
+)
+def test_to_parquet_read_parquet(con, tmp_path):
+    # No Partitions
+    t_out = con.table("awards_players")
+
+    t_out.to_parquet(tmp_path / "out_np")
+
+    t_in = con.read_parquet(tmp_path / "out_np")
+
+    cols = list(t_out.columns)
+    expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
+    result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
+
+    assert_frame_equal(expected, result)
+
+    # Partitions
+    t_out = con.table("awards_players")
+
+    t_out.to_parquet(tmp_path / "out_p", partitionBy=["playerID"])
+
+    # Check partition paths
+    distinct_playerids = t_out.select("playerID").distinct().to_pandas()
+
+    for pid in distinct_playerids["playerID"]:
+        assert (tmp_path / "out_p" / f"playerID={pid}").exists()
+
+    t_in = con.read_parquet(tmp_path / "out_p")
+
+    cols = list(t_out.columns)
+    expected = t_out.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
+    result = t_in.to_pandas()[cols].sort_values(cols).reset_index(drop=True)
+
+    assert_frame_equal(expected, result)
```

ibis/backends/tests/test_export.py

Lines changed: 31 additions & 7 deletions
```diff
@@ -207,6 +207,8 @@ def test_to_pyarrow_batches_memtable(con):


 def test_table_to_parquet(tmp_path, backend, awards_players):
+    if backend.name() == "pyspark" and IS_SPARK_REMOTE:
+        pytest.skip("writes to remote output directory")
     outparquet = tmp_path / "out.parquet"
     awards_players.to_parquet(outparquet)

@@ -257,15 +259,32 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_players):
     outparquet = tmp_path / "out.parquet"
     awards_players.to_parquet(outparquet, version=version)

-    df = pd.read_parquet(outparquet)
+    if backend.name() == "pyspark":
+        if IS_SPARK_REMOTE:
+            pytest.skip("writes to remote output directory")
+        # Pyspark will write more than one parquet file under outparquet as directory
+        parquet_files = sorted(outparquet.glob("*.parquet"))
+        df = (
+            pd.concat(map(pd.read_parquet, parquet_files))
+            .sort_values(list(awards_players.columns))
+            .reset_index(drop=True)
+        )
+        result = (
+            awards_players.to_pandas()
+            .sort_values(list(awards_players.columns))
+            .reset_index(drop=True)
+        )
+        backend.assert_frame_equal(result, df)
+    else:
+        df = pd.read_parquet(outparquet)

-    backend.assert_frame_equal(
-        awards_players.to_pandas().fillna(pd.NA), df.fillna(pd.NA)
-    )
+        backend.assert_frame_equal(
+            awards_players.to_pandas().fillna(pd.NA), df.fillna(pd.NA)
+        )

-    md = pa.parquet.read_metadata(outparquet)
+        md = pa.parquet.read_metadata(outparquet)

-    assert md.format_version == version
+        assert md.format_version == version


 @pytest.mark.notimpl(
@@ -333,7 +352,12 @@ def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):

     getattr(con, f"to_{ftype}")(memtable, outfile)

-    assert outfile.is_file()
+    if con.name == "pyspark" and ftype == "parquet":
+        if IS_SPARK_REMOTE:
+            pytest.skip("writes to remote output directory")
+        assert outfile.is_dir()
+    else:
+        assert outfile.is_file()


 def test_table_to_csv(tmp_path, backend, awards_players):
```
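The pyspark-specific branches above exist because Spark writes `out.parquet` as a directory of part files rather than a single file. A minimal local sketch of the read-back pattern the updated test relies on (the path is hypothetical):

```python
from pathlib import Path

import pandas as pd

# Hypothetical path; with the pyspark backend this is a directory of
# part-*.parquet files (plus a _SUCCESS marker), not a single file.
outparquet = Path("/tmp/out.parquet")

# Concatenate the part files and sort so the comparison does not depend
# on Spark's file or row ordering.
parts = sorted(outparquet.glob("*.parquet"))
df = pd.concat(map(pd.read_parquet, parts), ignore_index=True)
df = df.sort_values(list(df.columns)).reset_index(drop=True)
```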
