Skip to content

Commit 02efbb2

Browse files
NickCrewscpcloud
andauthored
feat(duckdb): forward UDF configuration dict as kwargs during registration (#10358)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent 12e6057 commit 02efbb2

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

ibis/backends/duckdb/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,22 +1735,20 @@ def _register_udfs(self, expr: ir.Expr) -> None:
17351735
registration_func(con)
17361736

17371737
def _register_udf(self, udf_node: ops.ScalarUDF):
1738-
func = udf_node.__func__
1739-
name = type(udf_node).__name__
17401738
type_mapper = self.compiler.type_mapper
17411739
input_types = [
17421740
type_mapper.to_string(param.annotation.pattern.dtype)
17431741
for param in udf_node.__signature__.parameters.values()
17441742
]
1745-
output_type = type_mapper.to_string(udf_node.dtype)
17461743

17471744
def register_udf(con):
17481745
return con.create_function(
1749-
name,
1750-
func,
1751-
input_types,
1752-
output_type,
1746+
name=type(udf_node).__name__,
1747+
function=udf_node.__func__,
1748+
parameters=input_types,
1749+
return_type=type_mapper.to_string(udf_node.dtype),
17531750
type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__],
1751+
**udf_node.__config__,
17541752
)
17551753

17561754
return register_udf

ibis/backends/duckdb/tests/test_udf.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
import duckdb
34
import pytest
45
from pytest import param
56

7+
import ibis
68
from ibis import udf
79

810

@@ -103,3 +105,29 @@ def dont_intercept_null(x: int) -> int:
103105
)
104106
def test_dont_intercept_null(con, expr, expected):
105107
assert con.execute(expr) == expected
108+
109+
110+
def test_kwargs_are_forwarded(con):
111+
def nullify_two(x: int) -> int:
112+
return None if x == 2 else x
113+
114+
@udf.scalar.python
115+
def no_kwargs(x: int) -> int:
116+
return nullify_two(x)
117+
118+
@udf.scalar.python(null_handling="special")
119+
def with_kwargs(x: int) -> int:
120+
return nullify_two(x)
121+
122+
# If we return go Non-NULL -> Non-NULL, then passing null_handling="special"
123+
# will not change the result
124+
assert con.execute(no_kwargs(ibis.literal(1))) == 1
125+
assert con.execute(with_kwargs(ibis.literal(1))) == 1
126+
127+
# But, if our UDF ever goes Non-NULL -> NULL, then we NEED to pass
128+
# null_handling="special", otherwise duckdb throws an error
129+
assert con.execute(with_kwargs(ibis.literal(2))) is None
130+
131+
expr = no_kwargs(ibis.literal(2))
132+
with pytest.raises(duckdb.InvalidInputException):
133+
con.execute(expr)

0 commit comments

Comments
 (0)