Skip to content

Commit 52fa8a9

Browse files
authored
refactor(ir): Don't bypass argument coercion and validation for user defined functions (#3001)
* Inherit slots instead of redefining them * Don't bypass argument coercion and validation for user defined functions * Update return type's attribute name * Use VectorizedUDF base * Add tests * Remove/rewrite scalar udf tests * Remove test from dask backend
1 parent 27f1c91 commit 52fa8a9

File tree

8 files changed

+131
-91
lines changed

8 files changed

+131
-91
lines changed

ibis/backends/dask/tests/test_udf.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,6 @@ def test_udf(t, df):
175175
tm.assert_series_equal(result, expected, check_names=False)
176176

177177

178-
def test_elementwise_udf_with_non_vectors(con):
179-
expr = my_add(1.0, 2.0)
180-
result = con.execute(expr)
181-
assert result == 3.0
182-
183-
184178
def test_multiple_argument_udf(con, t, df):
185179
expr = my_add(t.b, t.c)
186180

ibis/backends/dask/udf.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def make_struct_op_meta(op: ir.Expr) -> List[Tuple[str, np.dtype]]:
2626
"""Unpacks a dt.Struct into a DataFrame meta"""
2727
return list(
2828
zip(
29-
op._output_type.names,
30-
[x.to_dask() for x in op._output_type.types],
29+
op.return_type.names,
30+
[x.to_dask() for x in op.return_type.types],
3131
)
3232
)
3333

@@ -72,16 +72,14 @@ def execute_udf_node(op, *args, **kwargs):
7272
# kwargs here. This is true for all udf execution in this
7373
# file.
7474
# See ibis.udf.vectorized.UserDefinedFunction
75-
if isinstance(op._output_type, dt.Struct):
75+
if isinstance(op.return_type, dt.Struct):
7676
meta = make_struct_op_meta(op)
7777

7878
df = dd.map_partitions(op.func, *args, meta=meta)
7979
return df
8080
else:
8181
name = args[0].name if len(args) == 1 else None
82-
meta = pandas.Series(
83-
[], name=name, dtype=op._output_type.to_dask()
84-
)
82+
meta = pandas.Series([], name=name, dtype=op.return_type.to_dask())
8583
df = dd.map_partitions(op.func, *args, meta=meta)
8684

8785
return df
@@ -124,11 +122,11 @@ def lazy_agg(*series: pandas.Series):
124122
# Depending on the type of operation, lazy_result is a Delayed that
125123
# could become a dd.Series or a dd.core.Scalar
126124
if isinstance(op, ops.AnalyticVectorizedUDF):
127-
if isinstance(op._output_type, dt.Struct):
125+
if isinstance(op.return_type, dt.Struct):
128126
meta = make_struct_op_meta(op)
129127
else:
130128
meta = make_meta_series(
131-
dtype=op._output_type.to_dask(),
129+
dtype=op.return_type.to_dask(),
132130
name=args[0].name,
133131
)
134132
result = dd.from_delayed(lazy_result, meta=meta)
@@ -151,13 +149,13 @@ def lazy_agg(*series: pandas.Series):
151149
result = result.repartition(divisions=original_divisions)
152150
else:
153151
# lazy_result is a dd.core.Scalar from an ungrouped reduction
154-
if isinstance(op._output_type, (dt.Array, dt.Struct)):
152+
if isinstance(op.return_type, (dt.Array, dt.Struct)):
155153
# we're outputing a dt.Struct that will need to be destructured
156154
# or an array of an unknown size.
157155
# we compute so we can work with items inside downstream.
158156
result = lazy_result.compute()
159157
else:
160-
output_meta = safe_scalar_type(op._output_type.to_dask())
158+
output_meta = safe_scalar_type(op.return_type.to_dask())
161159
result = dd.from_delayed(
162160
lazy_result, meta=output_meta, verify_meta=False
163161
)
@@ -181,7 +179,7 @@ def execute_reduction_node_groupby(op, *args, aggcontext, **kwargs):
181179
func = op.func
182180
groupings = args[0].index
183181
parent_df = args[0].obj
184-
out_type = op._output_type.to_dask()
182+
out_type = op.return_type.to_dask()
185183

186184
grouped_df = parent_df.groupby(groupings)
187185
col_names = [col._meta._selected_obj.name for col in args]
@@ -223,7 +221,7 @@ def execute_analytic_node_groupby(op, *args, aggcontext, **kwargs):
223221
func = op.func
224222
groupings = args[0].index
225223
parent_df = args[0].obj
226-
out_type = op._output_type.to_dask()
224+
out_type = op.return_type.to_dask()
227225

228226
grouped_df = parent_df.groupby(groupings)
229227
col_names = [col._meta._selected_obj.name for col in args]
@@ -232,7 +230,7 @@ def apply_wrapper(df, apply_func, col_names):
232230
cols = (df[col] for col in col_names)
233231
return apply_func(*cols)
234232

235-
if isinstance(op._output_type, dt.Struct):
233+
if isinstance(op.return_type, dt.Struct):
236234
# with struct output we destruct to a dataframe directly
237235
meta = dd.utils.make_meta(make_struct_op_meta(op))
238236
meta.index.name = parent_df.index.name

ibis/backends/pandas/tests/execution/test_functions.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,16 @@ def test_execute_with_same_hash_value_in_scope(
216216
def my_func(x, y):
217217
return x
218218

219-
expr = my_func(left, right)
219+
df = pd.DataFrame({"left": [left], "right": [right]})
220+
table = ibis.pandas.from_dataframe(df)
221+
222+
expr = my_func(table.left, table.right)
220223
result = execute(expr)
221-
assert type(result) is expected_type
222-
assert result == expected_value
224+
assert isinstance(result, pd.Series)
225+
226+
result = result.tolist()
227+
assert result == [expected_value]
228+
assert type(result[0]) is expected_type
223229

224230

225231
def test_ifelse_returning_bool():
@@ -248,7 +254,12 @@ def test_signature_does_not_match_input_type(dtype, value):
248254
def func(x):
249255
return x
250256

251-
expr = func(value)
252-
result = execute(expr)
253-
assert type(result) == type(value)
254-
assert result == value
257+
df = pd.DataFrame({"col": [value]})
258+
table = ibis.pandas.from_dataframe(df)
259+
260+
result = execute(table.col)
261+
assert isinstance(result, pd.Series)
262+
263+
result = result.tolist()
264+
assert result == [value]
265+
assert type(result[0]) is type(value)

ibis/backends/pandas/tests/test_udf.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,6 @@ def test_udf(t, df):
111111
tm.assert_series_equal(result, expected)
112112

113113

114-
def test_elementwise_udf_with_non_vectors(con):
115-
expr = my_add(1.0, 2.0)
116-
result = con.execute(expr)
117-
assert result == 3.0
118-
119-
120114
def test_multiple_argument_udf(con, t, df):
121115
expr = my_add(t.b, t.c)
122116

ibis/backends/pyspark/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,7 +1798,7 @@ def compile_fillna_table(t, expr, scope, timecontext, **kwargs):
17981798
@compiles(ops.ElementWiseVectorizedUDF)
17991799
def compile_elementwise_udf(t, expr, scope, timecontext, **kwargs):
18001800
op = expr.op()
1801-
spark_output_type = spark_dtype(op._output_type)
1801+
spark_output_type = spark_dtype(op.return_type)
18021802
func = op.func
18031803
spark_udf = pandas_udf(func, spark_output_type, PandasUDFType.SCALAR)
18041804
func_args = (t.translate(arg, scope, timecontext) for arg in op.func_args)
@@ -1809,7 +1809,7 @@ def compile_elementwise_udf(t, expr, scope, timecontext, **kwargs):
18091809
def compile_reduction_udf(t, expr, scope, timecontext, context=None, **kwargs):
18101810
op = expr.op()
18111811

1812-
spark_output_type = spark_dtype(op._output_type)
1812+
spark_output_type = spark_dtype(op.return_type)
18131813
spark_udf = pandas_udf(
18141814
op.func, spark_output_type, PandasUDFType.GROUPED_AGG
18151815
)

ibis/expr/operations/vectorized.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from types import FunctionType, LambdaType
2+
13
from public import public
24

35
from .. import rules as rlz
@@ -7,79 +9,39 @@
79
from .reductions import Reduction
810

911

10-
@public
11-
class ElementWiseVectorizedUDF(ValueOp):
12-
"""Node for element wise UDF."""
13-
14-
func = Arg(callable)
15-
func_args = Arg(tuple)
16-
input_type = Arg(rlz.shape_like('func_args'))
17-
_output_type = Arg(rlz.noop)
18-
19-
def __init__(self, func, args, input_type, output_type):
20-
self.func = func
21-
self.func_args = args
22-
self.input_type = input_type
23-
self._output_type = output_type
12+
class VectorizedUDF(ValueOp):
13+
func = Arg(rlz.instance_of((FunctionType, LambdaType)))
14+
func_args = Arg(rlz.list_of(rlz.column(rlz.any)))
15+
input_type = Arg(rlz.list_of(rlz.datatype))
16+
return_type = Arg(rlz.datatype)
2417

2518
@property
2619
def inputs(self):
2720
return self.func_args
2821

29-
def output_type(self):
30-
return self._output_type.column_type()
31-
3222
def root_tables(self):
3323
return distinct_roots(*self.func_args)
3424

3525

3626
@public
37-
class ReductionVectorizedUDF(Reduction):
38-
"""Node for reduction UDF."""
27+
class ElementWiseVectorizedUDF(VectorizedUDF):
28+
"""Node for element wise UDF."""
3929

40-
func = Arg(callable)
41-
func_args = Arg(tuple)
42-
input_type = Arg(rlz.shape_like('func_args'))
43-
_output_type = Arg(rlz.noop)
30+
def output_type(self):
31+
return self.return_type.column_type()
4432

45-
def __init__(self, func, args, input_type, output_type):
46-
self.func = func
47-
self.func_args = args
48-
self.input_type = input_type
49-
self._output_type = output_type
5033

51-
@property
52-
def inputs(self):
53-
return self.func_args
34+
@public
35+
class ReductionVectorizedUDF(VectorizedUDF, Reduction):
36+
"""Node for reduction UDF."""
5437

5538
def output_type(self):
56-
return self._output_type.scalar_type()
57-
58-
def root_tables(self):
59-
return distinct_roots(*self.func_args)
39+
return self.return_type.scalar_type()
6040

6141

6242
@public
63-
class AnalyticVectorizedUDF(AnalyticOp):
43+
class AnalyticVectorizedUDF(VectorizedUDF, AnalyticOp):
6444
"""Node for analytics UDF."""
6545

66-
func = Arg(callable)
67-
func_args = Arg(tuple)
68-
input_type = Arg(rlz.shape_like('func_args'))
69-
_output_type = Arg(rlz.noop)
70-
71-
def __init__(self, func, args, input_type, output_type):
72-
self.func = func
73-
self.func_args = args
74-
self.input_type = input_type
75-
self._output_type = output_type
76-
77-
@property
78-
def inputs(self):
79-
return self.func_args
80-
8146
def output_type(self):
82-
return self._output_type.column_type()
83-
84-
def root_tables(self):
85-
return distinct_roots(*self.func_args)
47+
return self.return_type.column_type()

ibis/tests/expr/test_udf.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
3+
import ibis
4+
import ibis.common.exceptions as com
5+
import ibis.expr.datatypes as dt
6+
import ibis.expr.operations as ops
7+
import ibis.expr.types as ir
8+
9+
10+
@pytest.fixture
11+
def table():
12+
return ibis.table(
13+
[
14+
("a", "int8"),
15+
("b", "string"),
16+
("c", "bool"),
17+
],
18+
name="test",
19+
)
20+
21+
22+
@pytest.mark.parametrize(
23+
("klass", "output_type"),
24+
[
25+
(ops.ElementWiseVectorizedUDF, ir.IntegerColumn),
26+
(ops.ReductionVectorizedUDF, ir.IntegerScalar),
27+
(ops.AnalyticVectorizedUDF, ir.IntegerColumn),
28+
],
29+
)
30+
def test_vectorized_udf_operations(table, klass, output_type):
31+
udf = klass(
32+
func=lambda a, b, c: a,
33+
func_args=[table.a, table.b, table.c],
34+
input_type=[dt.int8(), dt.string(), dt.boolean()],
35+
return_type=dt.int8(),
36+
)
37+
assert udf.func_args[0].equals(table.a)
38+
assert udf.func_args[1].equals(table.b)
39+
assert udf.func_args[2].equals(table.c)
40+
assert udf.input_type == [dt.int8(), dt.string(), dt.boolean()]
41+
assert udf.return_type == dt.int8()
42+
43+
factory = udf.output_type()
44+
expr = factory(udf)
45+
assert isinstance(expr, output_type)
46+
47+
with pytest.raises(com.IbisTypeError):
48+
# wrong function type
49+
klass(
50+
func=1,
51+
func_args=[ibis.literal(1), table.b, table.c],
52+
input_type=[dt.int8(), dt.string(), dt.boolean()],
53+
return_type=dt.int8(),
54+
)
55+
56+
with pytest.raises(com.IbisTypeError):
57+
# scalar type instead of column type
58+
klass(
59+
func=lambda a, b, c: a,
60+
func_args=[ibis.literal(1), table.b, table.c],
61+
input_type=[dt.int8(), dt.string(), dt.boolean()],
62+
return_type=dt.int8(),
63+
)
64+
65+
with pytest.raises(com.IbisTypeError):
66+
# wrong input type
67+
klass(
68+
func=lambda a, b, c: a,
69+
func_args=[ibis.literal(1), table.b, table.c],
70+
input_type="int8",
71+
return_type=dt.int8(),
72+
)
73+
74+
with pytest.raises(com.IbisTypeError):
75+
# wrong return type
76+
klass(
77+
func=lambda a, b, c: a,
78+
func_args=[ibis.literal(1), table.b, table.c],
79+
input_type=[dt.int8(), dt.string(), dt.boolean()],
80+
return_type=table,
81+
)

ibis/udf/vectorized.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def func(*args):
8585

8686
op = self.func_type(
8787
func=func,
88-
args=args,
88+
func_args=args,
8989
input_type=self.input_type,
90-
output_type=self.output_type,
90+
return_type=self.output_type,
9191
)
9292

9393
return op.to_expr()

0 commit comments

Comments
 (0)