Skip to content

Commit 657b8ab

Browse files
authored
feat(sql-udf): add semantic check when creating sql udf (#14549)
1 parent 6414832 commit 657b8ab

File tree

3 files changed

+110
-23
lines changed

3 files changed

+110
-23
lines changed

e2e_test/udf/sql_udf.slt

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ create function add_return(INT, INT) returns int language sql return $1 + $2;
2828
statement ok
2929
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);
3030

31-
# Recursive definition can be accepted, but will be eventually rejected during runtime
32-
statement ok
31+
# Recursive definitions can NOT be accepted at present, because the semantic check rejects them
32+
statement error failed to conduct semantic check, please see if you are calling non-existence functions
3333
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';
3434

3535
# Complex but error-prone definition, recursive & normal sql udfs interleaving
36-
statement ok
36+
statement error failed to conduct semantic check, please see if you are calling non-existence functions
3737
create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)';
3838

3939
# Recursive corner case
@@ -46,7 +46,7 @@ create function add_sub_wrapper(INT, INT) returns int language sql as 'select ad
4646

4747
# Create a valid recursive function
4848
# Please note that we do NOT support actually running a recursive sql udf at present
49-
statement ok
49+
statement error failed to conduct semantic check, please see if you are calling non-existence functions
5050
create function fib(INT) returns int
5151
language sql as 'select case
5252
when $1 = 0 then 0
@@ -57,12 +57,12 @@ create function fib(INT) returns int
5757
end;';
5858

5959
# The execution will eventually exceed the pre-defined max stack depth
60-
statement error function fib calling stack depth limit exceeded
61-
select fib(100);
60+
# statement error function fib calling stack depth limit exceeded
61+
# select fib(100);
6262

6363
# Currently, creating a materialized view with a recursive sql udf will be rejected
64-
statement error function fib calling stack depth limit exceeded
65-
create materialized view foo_mv as select fib(100);
64+
# statement error function fib calling stack depth limit exceeded
65+
# create materialized view foo_mv as select fib(100);
6666

6767
statement ok
6868
create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$;
@@ -77,6 +77,10 @@ create function print_add_one(INT) returns int language sql as 'select print($1
7777
statement ok
7878
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';
7979

80+
# Calling a non-existent function
81+
statement error failed to conduct semantic check, please see if you are calling non-existence functions
82+
create function non_exist(INT) returns int language sql as 'select yo(114514)';
83+
8084
# Call the defined sql udf
8185
query I
8286
select add(1, -1);
@@ -124,12 +128,12 @@ select foo(114514);
124128
foo(INT)
125129

126130
# Rejected deep calling stack
127-
statement error function recursive calling stack depth limit exceeded
128-
select recursive(1, 1);
131+
# statement error function recursive calling stack depth limit exceeded
132+
# select recursive(1, 1);
129133

130134
# Same as above
131-
statement error function recursive calling stack depth limit exceeded
132-
select recursive_non_recursive(1, 1);
135+
# statement error function recursive calling stack depth limit exceeded
136+
# select recursive_non_recursive(1, 1);
133137

134138
query I
135139
select add_sub_wrapper(1, 1);
@@ -168,12 +172,12 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
168172
5 5 10
169173

170174
# Recursive sql udf with normal table
171-
statement error function fib calling stack depth limit exceeded
172-
select fib(c1) from t1;
175+
# statement error function fib calling stack depth limit exceeded
176+
# select fib(c1) from t1;
173177

174178
# Recursive sql udf with materialized view
175-
statement error function fib calling stack depth limit exceeded
176-
create materialized view bar_mv as select fib(c1) from t1;
179+
# statement error function fib calling stack depth limit exceeded
180+
# create materialized view bar_mv as select fib(c1) from t1;
177181

178182
# Invalid function body syntax
179183
statement error Expected an expression:, found: EOF at the end
@@ -259,20 +263,20 @@ drop function call_regexp_replace;
259263
statement ok
260264
drop function add_sub_wrapper;
261265

262-
statement ok
263-
drop function recursive;
266+
# statement ok
267+
# drop function recursive;
264268

265269
statement ok
266270
drop function foo;
267271

268-
statement ok
269-
drop function recursive_non_recursive;
272+
# statement ok
273+
# drop function recursive_non_recursive;
270274

271275
statement ok
272276
drop function add_sub_types;
273277

274-
statement ok
275-
drop function fib;
278+
# statement ok
279+
# drop function fib;
276280

277281
statement ok
278282
drop function print;

src/frontend/src/binder/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,10 @@ impl Binder {
415415
/// Overwrites the clause stored in this binder's context with `clause`.
pub fn set_clause(&mut self, clause: Option<Clause>) {
416416
self.context.clause = clause;
417417
}
418+
419+
/// Returns a mutable reference to this binder's `udf_context`.
pub fn udf_context_mut(&mut self) -> &mut UdfContext {
420+
&mut self.udf_context
421+
}
418422
}
419423

420424
/// The column name stored in [`BindContext`] for a column without an alias.

src/frontend/src/handler/create_sql_function.rs

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashMap;
16+
1517
use itertools::Itertools;
1618
use pgwire::pg_response::StatementType;
1719
use risingwave_common::catalog::FunctionId;
@@ -25,8 +27,60 @@ use risingwave_sqlparser::parser::{Parser, ParserError};
2527

2628
use super::*;
2729
use crate::catalog::CatalogError;
30+
use crate::expr::{ExprImpl, Literal};
2831
use crate::{bind_data_type, Binder};
2932

33+
/// Create a mock `udf_context`, which is used for semantic check
34+
fn create_mock_udf_context(arg_types: Vec<DataType>) -> HashMap<String, ExprImpl> {
35+
(1..=arg_types.len())
36+
.map(|i| {
37+
let mock_expr =
38+
ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone())));
39+
(format!("${i}"), mock_expr.clone())
40+
})
41+
.collect()
42+
}
43+
44+
fn extract_udf_expression(ast: Vec<Statement>) -> Result<Expr> {
45+
if ast.len() != 1 {
46+
return Err(ErrorCode::InvalidInputSyntax(
47+
"the query for sql udf should contain only one statement".to_string(),
48+
)
49+
.into());
50+
}
51+
52+
// Extract the expression out
53+
let Statement::Query(query) = ast[0].clone() else {
54+
return Err(ErrorCode::InvalidInputSyntax(
55+
"invalid function definition, please recheck the syntax".to_string(),
56+
)
57+
.into());
58+
};
59+
60+
let SetExpr::Select(select) = query.body else {
61+
return Err(ErrorCode::InvalidInputSyntax(
62+
"missing `select` body for sql udf expression, please recheck the syntax".to_string(),
63+
)
64+
.into());
65+
};
66+
67+
if select.projection.len() != 1 {
68+
return Err(ErrorCode::InvalidInputSyntax(
69+
"`projection` should contain only one `SelectItem`".to_string(),
70+
)
71+
.into());
72+
}
73+
74+
let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
75+
return Err(ErrorCode::InvalidInputSyntax(
76+
"expect `UnnamedExpr` for `projection`".to_string(),
77+
)
78+
.into());
79+
};
80+
81+
Ok(expr)
82+
}
83+
3084
pub async fn handle_create_sql_function(
3185
handler_args: HandlerArgs,
3286
or_replace: bool,
@@ -45,7 +99,8 @@ pub async fn handle_create_sql_function(
4599
}
46100

47101
let language = "sql".to_string();
48-
// Just a basic sanity check for language
102+
103+
// Just a basic sanity check for `language`
49104
if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") {
50105
return Err(ErrorCode::InvalidParameterValue(
51106
"`language` for sql udf must be `sql`".to_string(),
@@ -149,6 +204,30 @@ pub async fn handle_create_sql_function(
149204
return Err(ErrorCode::InvalidInputSyntax(err).into());
150205
} else {
151206
debug_assert!(parse_result.is_ok());
207+
208+
// Conduct semantic check (e.g., see if the inner calling functions exist, etc.)
209+
let ast = parse_result.unwrap();
210+
let mut binder = Binder::new_for_system(session);
211+
212+
binder
213+
.udf_context_mut()
214+
.update_context(create_mock_udf_context(arg_types.clone()));
215+
216+
if let Ok(expr) = extract_udf_expression(ast) {
217+
if let Err(e) = binder.bind_expr(expr) {
218+
return Err(ErrorCode::InvalidInputSyntax(
219+
format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}")
220+
)
221+
.into());
222+
}
223+
} else {
224+
return Err(ErrorCode::InvalidInputSyntax(
225+
"failed to parse the input query and extract the udf expression,
226+
please recheck the syntax"
227+
.to_string(),
228+
)
229+
.into());
230+
}
152231
}
153232

154233
// Create the actual function, will be stored in function catalog

0 commit comments

Comments
 (0)