Skip to content

feat(udf): add initial support for JavaScript UDF #14513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 139 additions & 97 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ prost = { version = "0.12" }
icelake = { git = "https://github.com/icelake-io/icelake", rev = "32c0bbf242f5c47b1e743f10577012fe7436c770", features = [
"prometheus",
] }
arrow-array = "49"
arrow-arith = "49"
arrow-cast = "49"
arrow-schema = "49"
arrow-buffer = "49"
arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-array = "50"
arrow-arith = "50"
arrow-cast = "50"
arrow-schema = "50"
arrow-buffer = "50"
arrow-flight = "50"
arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "7ba1c22" }
arrow-udf-wasm = "0.1"
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand All @@ -143,7 +144,7 @@ arrow-schema-deltalake = { package = "arrow-schema", version = "48.0.1" }
deltalake = { git = "https://github.com/risingwavelabs/delta-rs", rev = "5c2dccd4640490202ffe98adbd13b09cef8e007b", features = [
"s3-no-concurrent-write",
] }
parquet = "49"
parquet = "50"
thiserror-ext = "0.0.11"
tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" }
tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [
Expand Down
154 changes: 154 additions & 0 deletions e2e_test/udf/js_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
statement ok
create function int_42() returns int language javascript as $$
return 42;
$$;

query I
select int_42();
----
42

statement ok
drop function int_42;


statement ok
create function gcd(a int, b int) returns int language javascript as $$
// explicit null check is required until `RETURNS NULL ON NULL INPUT` is supported
if(a == null || b == null) {
return null;
}
while (b != 0) {
let t = b;
b = a % b;
a = t;
}
return a;
$$;

query I
select gcd(25, 15);
----
5

statement ok
drop function gcd;


statement ok
create function decimal_add(a decimal, b decimal) returns decimal language javascript as $$
return a + b;
$$;

query R
select decimal_add(1.11, 2.22);
----
3.33

statement ok
drop function decimal_add;


statement ok
create function to_string(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns varchar language javascript as $$
return a.toString() + b.toString() + c.toString() + d.toString() + e.toString() + f.toString() + g.toString() + h.toString() + i.toString() + JSON.stringify(j);
$$;

query T
select to_string(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
----
false1234.56.78.9abc1,2,3{"key":1}

statement ok
drop function to_string;


# show data types in javascript
statement ok
create function js_typeof(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns jsonb language javascript as $$
return {
boolean: typeof a,
smallint: typeof b,
int: typeof c,
bigint: typeof d,
real: typeof e,
float: typeof f,
decimal: typeof g,
varchar: typeof h,
bytea: typeof i,
jsonb: typeof j,
};
$$;

query T
select js_typeof(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
----
{"bigint": "number", "boolean": "boolean", "bytea": "object", "decimal": "bigdecimal", "float": "number", "int": "number", "jsonb": "object", "real": "number", "smallint": "number", "varchar": "string"}

statement ok
drop function js_typeof;


# Passes each argument back unchanged as a field of the returned struct; the
# query below expands that struct into individual columns with `.*`.
statement ok
create function return_all(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>)
returns struct<a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>>
language javascript as $$
return {a,b,c,d,e,f,g,h,i,j,s};
$$;

# One type character per expanded column (11 in total):
# boolean=T, smallint/int/bigint=I, real/float/decimal=R, varchar/bytea/jsonb/struct=T.
query TIIIRRRTTTT
select (return_all(
true,
1 ::smallint,
1,
1,
1,
1,
12345678901234567890.12345678,
'string',
'bytes',
'{"key":1}',
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 12345678901234567890.12345678 string \x6279746573 {"key": 1} (1,2)

statement ok
drop function return_all;


statement ok
create function series(n int) returns table (x int) language javascript as $$
for(let i = 0; i < n; i++) {
yield i;
}
$$;

query I
select series(5);
----
0
1
2
3
4

statement ok
drop function series;


# Table function with a multi-column result: each yielded object produces one
# output row, carrying the declared `word` and `length` columns.
statement ok
create function split(s varchar) returns table (word varchar, length int) language javascript as $$
for(let word of s.split(' ')) {
yield { word: word, length: word.length };
}
$$;

# Type string follows the declared column order: `word` is text (T), `length` is integer (I).
query TI
select * from split('rising wave');
----
rising 6
wave 4

statement ok
drop function split;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-udf = "0.1"
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
5 changes: 3 additions & 2 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ message Function {
uint32 database_id = 3;
string name = 4;
uint32 owner = 9;
repeated string arg_names = 15;
repeated data.DataType arg_types = 5;
data.DataType return_type = 6;
string language = 7;
string link = 8;
string identifier = 10;
optional string link = 8;
optional string identifier = 10;
optional string body = 14;

oneof kind {
Expand Down
14 changes: 10 additions & 4 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,21 +471,27 @@ message WindowFunction {
message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
optional string link = 5;
// A unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
// For JavaScript UDF, it's the name of the function.
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
string link = 5;
string identifier = 6;
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
}
1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-js = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
Expand Down
33 changes: 28 additions & 5 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::time::Duration;

use anyhow::Context;
use arrow_schema::{Field, Fields, Schema};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
Expand Down Expand Up @@ -61,6 +62,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
enum UdfImpl {
External(Arc<ArrowFlightUdfClient>),
Wasm(Arc<WasmRuntime>),
JavaScript(JsRuntime),
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -123,6 +125,7 @@ impl UserDefinedFunction {

let output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if disable_retry_count != 0 {
Expand Down Expand Up @@ -189,16 +192,36 @@ impl Build for UserDefinedFunction {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

let identifier = udf.get_identifier()?;
let imp = match udf.language.as_str() {
"wasm" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(get_or_create_wasm_runtime(&udf.link))
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
}
_ => UdfImpl::External(get_or_create_flight_client(&udf.link)?),
"javascript" => {
let mut rt = JsRuntime::new()?;
let body = format!(
"export function {}({}) {{ {} }}",
identifier,
udf.arg_names.join(","),
udf.get_body()?
);
rt.add_function(
identifier,
arrow_schema::DataType::try_from(&return_type)?,
CallMode::CalledOnNullInput,
&body,
)?;
UdfImpl::JavaScript(rt)
}
_ => {
let link = udf.get_link()?;
UdfImpl::External(get_or_create_flight_client(link)?)
}
};

let arg_schema = Arc::new(Schema::new(
Expand All @@ -222,8 +245,8 @@ impl Build for UserDefinedFunction {
return_type,
arg_schema,
imp,
identifier: udf.identifier.clone(),
span: format!("udf_call({})", udf.identifier).into(),
identifier: identifier.clone(),
span: format!("udf_call({})", identifier).into(),
disable_retry_count: AtomicU8::new(0),
})
}
Expand Down
Loading