Skip to content

feat(udf): support implicit cast for UDF arguments #14338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions e2e_test/udf/udf.slt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might also update sql_udf.slt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do this, no worries. 😄

Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90
233276425899864771438119478

query I
select float_to_decimal('-1e-10'::float8);
select float_to_decimal('-1e-10');
----
-0.0000000001000000000000000036

Expand All @@ -138,36 +138,36 @@ NULL
false

query T
select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
select jsonb_concat(ARRAY['null', '1', '"str"', '{}'::jsonb]);
----
[null, 1, "str", {}]

query T
select jsonb_array_identity(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
select jsonb_array_identity(ARRAY[null, '1', '"str"', '{}'::jsonb]);
----
{NULL,1,"\"str\"","{}"}

query T
select jsonb_array_struct_identity(ROW(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct<v jsonb[], len int>);
select jsonb_array_struct_identity(ROW(ARRAY[null, '1', '"str"', '{}'::jsonb], 4)::struct<v jsonb[], len int>);
----
("{NULL,1,""\\""str\\"""",""{}""}",4)

query T
select (return_all(
true,
1 ::smallint,
1 ::int,
1 ::bigint,
1 ::float4,
1 ::float8,
12345678901234567890.12345678 ::decimal,
1,
1,
1,
1,
12345678901234567890.12345678,
date '2023-06-01',
time '01:02:03.456789',
timestamp '2023-06-01 01:02:03.456789',
interval '1 month 2 days 3 seconds',
'string',
'bytes'::bytea,
'{"key":1}'::jsonb,
'bytes',
'{"key":1}',
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
Expand All @@ -177,11 +177,11 @@ query T
select (return_all_arrays(
array[null, true],
array[null, 1 ::smallint],
array[null, 1 ::int],
array[null, 1],
array[null, 1 ::bigint],
array[null, 1 ::float4],
array[null, 1 ::float8],
array[null, 12345678901234567890.12345678 ::decimal],
array[null, 12345678901234567890.12345678],
array[null, date '2023-06-01'],
array[null, time '01:02:03.456789'],
array[null, timestamp '2023-06-01 01:02:03.456789'],
Expand All @@ -197,21 +197,21 @@ select (return_all_arrays(
# test large string output
query I
select length((return_all(
null::boolean,
null::smallint,
null::int,
null::bigint,
null::float4,
null::float8,
null::decimal,
null::date,
null::time,
null::timestamp,
null::interval,
repeat('a', 100000)::varchar,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
repeat('a', 100000),
repeat('a', 100000)::bytea,
null::jsonb,
null::struct<f1 int, f2 int>
null,
null
)).varchar);
----
100000
Expand Down Expand Up @@ -253,16 +253,13 @@ select count(*) from series(1000000);
----
1000000

# TODO: support argument implicit cast for UDF
# e.g. extract_tcp_info(E'\\x45');

query T
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea);
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4');
----
(192.168.0.14,192.168.0.1,861,8374)

query TTII
select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: BYTEA)).*;
select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4')).*;
----
192.168.0.14 192.168.0.1 861 8374

Expand Down
32 changes: 26 additions & 6 deletions src/expr/core/src/sig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! Metadata of expressions.

use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::sync::LazyLock;
Expand Down Expand Up @@ -47,7 +48,7 @@ pub struct FunctionRegistry(HashMap<FuncName, Vec<FuncSign>>);
impl FunctionRegistry {
/// Inserts a function signature.
pub fn insert(&mut self, sig: FuncSign) {
let list = self.0.entry(sig.name).or_default();
let list = self.0.entry(sig.name.clone()).or_default();
if sig.is_aggregate() {
// merge retractable and append-only aggregate
if let Some(existing) = list
Expand Down Expand Up @@ -85,6 +86,22 @@ impl FunctionRegistry {
list.push(sig);
}

/// Removes every signature registered under `sig.name` whose input types and
/// return type both equal those of `sig`.
///
/// Returns the signature removed last, or `None` when the name is not
/// registered or nothing matched. Entries are taken out with `swap_remove`,
/// so the relative order of the remaining signatures is not preserved.
pub fn remove(&mut self, sig: FuncSign) -> Option<FuncSign> {
    let list = self.0.get_mut(&sig.name)?;
    let mut removed = None;
    // Walk from the back: `swap_remove` only moves an element that was
    // already inspected into the vacated slot, so no index is skipped.
    for idx in (0..list.len()).rev() {
        let is_match =
            list[idx].inputs_type == sig.inputs_type && list[idx].ret_type == sig.ret_type;
        if is_match {
            removed = Some(list.swap_remove(idx));
        }
    }
    removed
}

/// Returns a function signature with the same type, argument types and return type.
/// Deprecated functions are included.
pub fn get(
Expand Down Expand Up @@ -300,11 +317,12 @@ impl FuncSign {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Name of a function, used as the key of the function registry.
///
/// Built-in functions are named by their function-type enum; user-defined
/// functions carry an owned `String`.
pub enum FuncName {
/// A scalar function, named by its `ScalarFunctionType`.
Scalar(ScalarFunctionType),
/// A table function, named by its `TableFunctionType`.
Table(TableFunctionType),
/// An aggregate function, named by its `AggregateFunctionType`.
Aggregate(AggregateFunctionType),
/// A user-defined function, named by the contained string.
Udf(String),
}

impl From<ScalarFunctionType> for FuncName {
Expand Down Expand Up @@ -333,11 +351,12 @@ impl fmt::Display for FuncName {

impl FuncName {
/// Returns the name of the function in `UPPER_CASE` style.
pub fn as_str_name(&self) -> &'static str {
pub fn as_str_name(&self) -> Cow<'static, str> {
match self {
Self::Scalar(ty) => ty.as_str_name(),
Self::Table(ty) => ty.as_str_name(),
Self::Aggregate(ty) => ty.to_protobuf().as_str_name(),
Self::Scalar(ty) => ty.as_str_name().into(),
Self::Table(ty) => ty.as_str_name().into(),
Self::Aggregate(ty) => ty.to_protobuf().as_str_name().into(),
Self::Udf(name) => name.clone().into(),
}
}

Expand Down Expand Up @@ -437,6 +456,7 @@ pub enum FuncBuilder {
/// `None` means equal to the return type.
append_only_state_type: Option<DataType>,
},
Udf,
}

/// Register a function into global registry.
Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/tests/sig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn test_func_sig_map() {
}

new_map
.entry(sig.name)
.entry(sig.name.clone())
.or_default()
.entry(sig.inputs_type.to_vec())
.or_default()
Expand Down
8 changes: 3 additions & 5 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Binder {
// Used later in sql udf expression evaluation
let args = f.args.clone();

let inputs = f
let mut inputs = f
.args
.into_iter()
.map(|arg| self.bind_function_arg(arg))
Expand Down Expand Up @@ -224,12 +224,10 @@ impl Binder {
// user defined function
// TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422
if let Ok(schema) = self.first_valid_schema()
&& let Some(func) = schema.get_function_by_name_args(
&function_name,
&inputs.iter().map(|arg| arg.return_type()).collect_vec(),
)
&& let Some(func) = schema.get_function_by_name_inputs(&function_name, &mut inputs)
{
use crate::catalog::function_catalog::FunctionKind::*;

if func.language == "sql" {
if func.body.is_none() {
return Err(ErrorCode::InvalidInputSyntax(
Expand Down
29 changes: 29 additions & 0 deletions src/frontend/src/catalog/root_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::catalog::system_catalog::{
};
use crate::catalog::table_catalog::TableCatalog;
use crate::catalog::{DatabaseId, IndexCatalog, SchemaId};
use crate::expr::{Expr, ExprImpl};

#[derive(Copy, Clone)]
pub enum SchemaPath<'a> {
Expand Down Expand Up @@ -753,6 +754,34 @@ impl Catalog {
.ok_or_else(|| CatalogError::NotFound("connection", connection_name.to_string()))
}

/// Looks up a function named `function_name` that accepts `inputs`, trying
/// each schema yielded by `schema_path` within database `db_name`.
///
/// `inputs` is handed mutably to the per-schema lookup, which may rewrite the
/// argument expressions. On success returns the function catalog together
/// with the name of the schema it was found in; otherwise returns
/// `CatalogError::NotFound` describing the function as `name(type, ...)`,
/// built from the return types of `inputs`.
pub fn get_function_by_name_inputs<'a>(
    &self,
    db_name: &str,
    schema_path: SchemaPath<'a>,
    function_name: &str,
    inputs: &mut [ExprImpl],
) -> CatalogResult<(&Arc<FunctionCatalog>, &'a str)> {
    let found = schema_path.try_find(|schema_name| {
        let schema = self.get_schema_by_name(db_name, schema_name)?;
        Ok(schema.get_function_by_name_inputs(function_name, inputs))
    })?;

    match found {
        Some(hit) => Ok(hit),
        None => {
            // Render the argument types only on the failure path.
            let arg_types = inputs
                .iter()
                .map(|a| a.return_type().to_string())
                .join(", ");
            Err(CatalogError::NotFound(
                "function",
                format!("{}({})", function_name, arg_types),
            ))
        }
    }
}

pub fn get_function_by_name_args<'a>(
&self,
db_name: &str,
Expand Down
Loading