Skip to content

Commit 02a1d48

Browse files
jstammers and cpcloud authored
feat(pyspark): add support for pyarrow and python UDFs (#9753)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent f8bea7e commit 02a1d48

File tree

3 files changed

+47
-16
lines changed

3 files changed

+47
-16
lines changed

ibis/backends/pyspark/__init__.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from ibis.expr.api import Watermark
4747

4848
PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")
49-
49+
PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")
5050
ConnectionMode = Literal["streaming", "batch"]
5151

5252

@@ -359,18 +359,26 @@ def wrapper(*args):
359359
def _register_udfs(self, expr: ir.Expr) -> None:
360360
node = expr.op()
361361
for udf in node.find(ops.ScalarUDF):
362-
if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):
363-
raise NotImplementedError(
364-
"Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend"
365-
)
366-
# register pandas UDFs
362+
udf_name = self.compiler.__sql_name__(udf)
363+
udf_return = PySparkType.from_ibis(udf.dtype)
367364
if udf.__input_type__ == InputType.PANDAS:
368-
udf_name = self.compiler.__sql_name__(udf)
369365
udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
370-
udf_return = PySparkType.from_ibis(udf.dtype)
371366
spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
372-
self._session.udf.register(udf_name, spark_udf)
373-
367+
elif udf.__input_type__ == InputType.PYTHON:
368+
udf_func = udf.__func__
369+
spark_udf = F.udf(udf_func, udf_return)
370+
elif udf.__input_type__ == InputType.PYARROW:
371+
# raise not implemented error if running on pyspark < 3.5
372+
if PYSPARK_LT_35:
373+
raise NotImplementedError(
374+
"pyarrow UDFs are only supported in pyspark >= 3.5"
375+
)
376+
udf_func = udf.__func__
377+
spark_udf = F.udf(udf_func, udf_return, useArrow=True)
378+
else:
379+
# Builtin functions don't need to be registered
380+
continue
381+
self._session.udf.register(udf_name, spark_udf)
374382
for udf in node.find(ops.ElementWiseVectorizedUDF):
375383
udf_name = self.compiler.__sql_name__(udf)
376384
udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)

ibis/backends/pyspark/tests/test_udf.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import ibis
7+
from ibis.backends.pyspark import PYSPARK_LT_35
78

89
pytest.importorskip("pyspark")
910

@@ -22,12 +23,36 @@ def df(con):
2223
def repeat(x, n) -> str: ...
2324

2425

26+
@ibis.udf.scalar.python
27+
def py_repeat(x: str, n: int) -> str:
28+
return x * n
29+
30+
31+
@ibis.udf.scalar.pyarrow
32+
def pyarrow_repeat(x: str, n: int) -> str:
33+
return x * n
34+
35+
2536
def test_builtin_udf(t, df):
2637
result = t.mutate(repeated=repeat(t.str_col, 2)).execute()
2738
expected = df.assign(repeated=df.str_col * 2)
2839
tm.assert_frame_equal(result, expected)
2940

3041

42+
def test_python_udf(t, df):
43+
result = t.mutate(repeated=py_repeat(t.str_col, 2)).execute()
44+
expected = df.assign(repeated=df.str_col * 2)
45+
tm.assert_frame_equal(result, expected)
46+
47+
48+
@pytest.mark.xfail(PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
49+
def test_pyarrow_udf(t, df):
50+
result = t.mutate(repeated=pyarrow_repeat(t.str_col, 2)).execute()
51+
expected = df.assign(repeated=df.str_col * 2)
52+
tm.assert_frame_equal(result, expected)
53+
54+
55+
@pytest.mark.xfail(not PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
3156
def test_illegal_udf_type(t):
3257
@ibis.udf.scalar.pyarrow
3358
def my_add_one(x) -> str:
@@ -39,6 +64,6 @@ def my_add_one(x) -> str:
3964

4065
with pytest.raises(
4166
NotImplementedError,
42-
match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend",
67+
match="pyarrow UDFs are only supported in pyspark >= 3.5",
4368
):
4469
expr.execute()

ibis/backends/tests/test_udf.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
@no_python_udfs
3636
@cloudpickle_version_mismatch
37-
@mark.notimpl(["pyspark"])
3837
@mark.notyet(["datafusion"], raises=NotImplementedError)
3938
def test_udf(batting):
4039
@udf.scalar.python
@@ -59,7 +58,6 @@ def num_vowels(s: str, include_y: bool = False) -> int:
5958

6059
@no_python_udfs
6160
@cloudpickle_version_mismatch
62-
@mark.notimpl(["pyspark"])
6361
@mark.notyet(
6462
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
6563
)
@@ -89,7 +87,6 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
8987

9088
@no_python_udfs
9189
@cloudpickle_version_mismatch
92-
@mark.notimpl(["pyspark"])
9390
@mark.notyet(
9491
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
9592
)
@@ -174,10 +171,11 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
174171
add_one_pyarrow,
175172
marks=[
176173
mark.notyet(
177-
["snowflake", "sqlite", "pyspark", "flink"],
174+
["snowflake", "sqlite", "flink"],
178175
raises=NotImplementedError,
179176
reason="backend doesn't support pyarrow UDFs",
180-
)
177+
),
178+
mark.xfail_version(pyspark=["pyspark<3.5"]),
181179
],
182180
),
183181
],

0 commit comments

Comments (0)