Skip to content

Commit acd5ee6

Browse files
wangrunji0408, KveinAxel, xzhseh
authored
feat(udf): support implicit cast for UDF arguments (#14338) (#14437)
Signed-off-by: Kevin Axel <[email protected]>
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: Kevin Axel <[email protected]>
Co-authored-by: Zihao Xu <[email protected]>
1 parent e03349d commit acd5ee6

File tree

9 files changed

+167
-52
lines changed

9 files changed

+167
-52
lines changed

e2e_test/udf/udf.slt

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
115115
233276425899864771438119478
116116

117117
query I
118-
select float_to_decimal('-1e-10'::float8);
118+
select float_to_decimal('-1e-10');
119119
----
120120
-0.0000000001000000000000000036
121121

@@ -138,36 +138,36 @@ NULL
138138
false
139139

140140
query T
141-
select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
141+
select jsonb_concat(ARRAY['null', '1', '"str"', '{}'::jsonb]);
142142
----
143143
[null, 1, "str", {}]
144144

145145
query T
146-
select jsonb_array_identity(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
146+
select jsonb_array_identity(ARRAY[null, '1', '"str"', '{}'::jsonb]);
147147
----
148148
{NULL,1,"\"str\"","{}"}
149149

150150
query T
151-
select jsonb_array_struct_identity(ROW(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct<v jsonb[], len int>);
151+
select jsonb_array_struct_identity(ROW(ARRAY[null, '1', '"str"', '{}'::jsonb], 4)::struct<v jsonb[], len int>);
152152
----
153153
("{NULL,1,""\\""str\\"""",""{}""}",4)
154154

155155
query T
156156
select (return_all(
157157
true,
158158
1 ::smallint,
159-
1 ::int,
160-
1 ::bigint,
161-
1 ::float4,
162-
1 ::float8,
163-
12345678901234567890.12345678 ::decimal,
159+
1,
160+
1,
161+
1,
162+
1,
163+
12345678901234567890.12345678,
164164
date '2023-06-01',
165165
time '01:02:03.456789',
166166
timestamp '2023-06-01 01:02:03.456789',
167167
interval '1 month 2 days 3 seconds',
168168
'string',
169-
'bytes'::bytea,
170-
'{"key":1}'::jsonb,
169+
'bytes',
170+
'{"key":1}',
171171
row(1, 2)::struct<f1 int, f2 int>
172172
)).*;
173173
----
@@ -177,11 +177,11 @@ query T
177177
select (return_all_arrays(
178178
array[null, true],
179179
array[null, 1 ::smallint],
180-
array[null, 1 ::int],
180+
array[null, 1],
181181
array[null, 1 ::bigint],
182182
array[null, 1 ::float4],
183183
array[null, 1 ::float8],
184-
array[null, 12345678901234567890.12345678 ::decimal],
184+
array[null, 12345678901234567890.12345678],
185185
array[null, date '2023-06-01'],
186186
array[null, time '01:02:03.456789'],
187187
array[null, timestamp '2023-06-01 01:02:03.456789'],
@@ -197,21 +197,21 @@ select (return_all_arrays(
197197
# test large string output
198198
query I
199199
select length((return_all(
200-
null::boolean,
201-
null::smallint,
202-
null::int,
203-
null::bigint,
204-
null::float4,
205-
null::float8,
206-
null::decimal,
207-
null::date,
208-
null::time,
209-
null::timestamp,
210-
null::interval,
211-
repeat('a', 100000)::varchar,
200+
null,
201+
null,
202+
null,
203+
null,
204+
null,
205+
null,
206+
null,
207+
null,
208+
null,
209+
null,
210+
null,
211+
repeat('a', 100000),
212212
repeat('a', 100000)::bytea,
213-
null::jsonb,
214-
null::struct<f1 int, f2 int>
213+
null,
214+
null
215215
)).varchar);
216216
----
217217
100000
@@ -253,16 +253,13 @@ select count(*) from series(1000000);
253253
----
254254
1000000
255255

256-
# TODO: support argument implicit cast for UDF
257-
# e.g. extract_tcp_info(E'\\x45');
258-
259256
query T
260-
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea);
257+
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4');
261258
----
262259
(192.168.0.14,192.168.0.1,861,8374)
263260

264261
query TTII
265-
select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: BYTEA)).*;
262+
select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4')).*;
266263
----
267264
192.168.0.14 192.168.0.1 861 8374
268265

src/expr/core/src/sig/mod.rs

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
//! Metadata of expressions.
1616
17+
use std::borrow::Cow;
1718
use std::collections::HashMap;
1819
use std::fmt;
1920
use std::sync::LazyLock;
@@ -47,7 +48,7 @@ pub struct FunctionRegistry(HashMap<FuncName, Vec<FuncSign>>);
4748
impl FunctionRegistry {
4849
/// Inserts a function signature.
4950
pub fn insert(&mut self, sig: FuncSign) {
50-
let list = self.0.entry(sig.name).or_default();
51+
let list = self.0.entry(sig.name.clone()).or_default();
5152
if sig.is_aggregate() {
5253
// merge retractable and append-only aggregate
5354
if let Some(existing) = list
@@ -85,6 +86,22 @@ impl FunctionRegistry {
8586
list.push(sig);
8687
}
8788

89+
/// Removes from the registry every signature stored under `sig.name` whose
/// `inputs_type` and `ret_type` both equal those of `sig`.
///
/// Matching indices are collected first and then removed from the highest
/// index down: `swap_remove` moves the list's last element into the vacated
/// slot, so removing in descending order keeps the remaining (lower) indices
/// valid.
///
/// Returns `None` if no list is registered under `sig.name` or if nothing in
/// that list matches. Otherwise returns the signature removed last, i.e. the
/// matching entry with the lowest index. A list that becomes empty is left in
/// the map under its key.
///
/// NOTE(review): `positions` / `collect_vec` are presumably the `itertools`
/// adaptors — confirm against the file's imports.
90+
pub fn remove(&mut self, sig: FuncSign) -> Option<FuncSign> {
91+
let pos = self
92+
.0
93+
.get_mut(&sig.name)?
94+
.iter()
95+
.positions(|s| s.inputs_type == sig.inputs_type && s.ret_type == sig.ret_type)
96+
.rev()
97+
.collect_vec();
98+
let mut ret = None;
99+
for p in pos {
100+
ret = Some(self.0.get_mut(&sig.name)?.swap_remove(p));
101+
}
102+
ret
103+
}
104+
88105
/// Returns a function signature with the same type, argument types and return type.
89106
/// Deprecated functions are included.
90107
pub fn get(
@@ -300,11 +317,12 @@ impl FuncSign {
300317
}
301318
}
302319

303-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
320+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
304321
pub enum FuncName {
305322
Scalar(ScalarFunctionType),
306323
Table(TableFunctionType),
307324
Aggregate(AggregateFunctionType),
325+
Udf(String),
308326
}
309327

310328
impl From<ScalarFunctionType> for FuncName {
@@ -333,11 +351,12 @@ impl fmt::Display for FuncName {
333351

334352
impl FuncName {
335353
/// Returns the name of the function in `UPPER_CASE` style.
336-
pub fn as_str_name(&self) -> &'static str {
354+
pub fn as_str_name(&self) -> Cow<'static, str> {
337355
match self {
338-
Self::Scalar(ty) => ty.as_str_name(),
339-
Self::Table(ty) => ty.as_str_name(),
340-
Self::Aggregate(ty) => ty.to_protobuf().as_str_name(),
356+
Self::Scalar(ty) => ty.as_str_name().into(),
357+
Self::Table(ty) => ty.as_str_name().into(),
358+
Self::Aggregate(ty) => ty.to_protobuf().as_str_name().into(),
359+
Self::Udf(name) => name.clone().into(),
341360
}
342361
}
343362

@@ -437,6 +456,7 @@ pub enum FuncBuilder {
437456
/// `None` means equal to the return type.
438457
append_only_state_type: Option<DataType>,
439458
},
459+
Udf,
440460
}
441461

442462
/// Register a function into global registry.

src/expr/impl/tests/sig.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ fn test_func_sig_map() {
2929
}
3030

3131
new_map
32-
.entry(sig.name)
32+
.entry(sig.name.clone())
3333
.or_default()
3434
.entry(sig.inputs_type.to_vec())
3535
.or_default()

src/frontend/src/binder/expr/function.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ impl Binder {
117117
return self.bind_array_transform(f);
118118
}
119119

120-
let inputs = f
120+
let mut inputs = f
121121
.args
122122
.into_iter()
123123
.map(|arg| self.bind_function_arg(arg))
@@ -152,10 +152,7 @@ impl Binder {
152152
// user defined function
153153
// TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422
154154
if let Ok(schema) = self.first_valid_schema()
155-
&& let Some(func) = schema.get_function_by_name_args(
156-
&function_name,
157-
&inputs.iter().map(|arg| arg.return_type()).collect_vec(),
158-
)
155+
&& let Some(func) = schema.get_function_by_name_inputs(&function_name, &mut inputs)
159156
{
160157
use crate::catalog::function_catalog::FunctionKind::*;
161158
match &func.kind {

src/frontend/src/catalog/root_catalog.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use crate::catalog::system_catalog::{
3737
};
3838
use crate::catalog::table_catalog::TableCatalog;
3939
use crate::catalog::{DatabaseId, IndexCatalog, SchemaId};
40+
use crate::expr::{Expr, ExprImpl};
4041

4142
#[derive(Copy, Clone)]
4243
pub enum SchemaPath<'a> {
@@ -729,6 +730,34 @@ impl Catalog {
729730
.ok_or_else(|| CatalogError::NotFound("connection", connection_name.to_string()))
730731
}
731732

733+
/// Finds the function called `function_name` that matches `inputs`, trying
/// each schema that `schema_path` yields within database `db_name`.
///
/// For every candidate schema the lookup is delegated to that schema's own
/// `get_function_by_name_inputs`, which receives `inputs` mutably.
/// NOTE(review): presumably the schema-level lookup may rewrite `inputs` in
/// place (e.g. to insert implicit casts) — confirm against its definition.
///
/// On success returns the function catalog entry together with a `&'a str`
/// (presumably the name of the schema it was found in — verify against
/// `SchemaPath::try_find`).
///
/// # Errors
///
/// Propagates any error from `get_schema_by_name`. If no schema yields a
/// function, returns `CatalogError::NotFound("function", ..)` whose
/// description is `function_name` followed by the comma-separated return
/// types of `inputs` in parentheses.
pub fn get_function_by_name_inputs<'a>(
734+
&self,
735+
db_name: &str,
736+
schema_path: SchemaPath<'a>,
737+
function_name: &str,
738+
inputs: &mut [ExprImpl],
739+
) -> CatalogResult<(&Arc<FunctionCatalog>, &'a str)> {
740+
schema_path
741+
.try_find(|schema_name| {
742+
Ok(self
743+
.get_schema_by_name(db_name, schema_name)?
744+
.get_function_by_name_inputs(function_name, inputs))
745+
})?
746+
.ok_or_else(|| {
747+
CatalogError::NotFound(
748+
"function",
749+
format!(
750+
"{}({})",
751+
function_name,
752+
inputs
753+
.iter()
754+
.map(|a| a.return_type().to_string())
755+
.join(", ")
756+
),
757+
)
758+
})
759+
}
760+
732761
pub fn get_function_by_name_args<'a>(
733762
&self,
734763
db_name: &str,

0 commit comments

Comments
 (0)