Skip to content

Commit 705be19

Browse files
feat(udf): add initial support for JavaScript UDF (#14513)
Signed-off-by: Runji Wang <[email protected]> Co-authored-by: wangrunji0408 <[email protected]>
1 parent 3b8c942 commit 705be19

File tree

21 files changed

+538
-158
lines changed

21 files changed

+538
-158
lines changed

Cargo.lock

Lines changed: 139 additions & 97 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,17 @@ prost = { version = "0.12" }
124124
icelake = { git = "https://github.com/icelake-io/icelake", rev = "32c0bbf242f5c47b1e743f10577012fe7436c770", features = [
125125
"prometheus",
126126
] }
127-
arrow-array = "49"
128-
arrow-arith = "49"
129-
arrow-cast = "49"
130-
arrow-schema = "49"
131-
arrow-buffer = "49"
132-
arrow-flight = "49"
133-
arrow-select = "49"
134-
arrow-ord = "49"
135-
arrow-row = "49"
136-
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
127+
arrow-array = "50"
128+
arrow-arith = "50"
129+
arrow-cast = "50"
130+
arrow-schema = "50"
131+
arrow-buffer = "50"
132+
arrow-flight = "50"
133+
arrow-select = "50"
134+
arrow-ord = "50"
135+
arrow-row = "50"
136+
arrow-udf-js = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "7ba1c22" }
137+
arrow-udf-wasm = "0.1"
137138
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
138139
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
139140
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
@@ -143,7 +144,7 @@ arrow-schema-deltalake = { package = "arrow-schema", version = "48.0.1" }
143144
deltalake = { git = "https://github.com/risingwavelabs/delta-rs", rev = "5c2dccd4640490202ffe98adbd13b09cef8e007b", features = [
144145
"s3-no-concurrent-write",
145146
] }
146-
parquet = "49"
147+
parquet = "50"
147148
thiserror-ext = "0.0.11"
148149
tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" }
149150
tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [

e2e_test/udf/js_udf.slt

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
statement ok
2+
create function int_42() returns int language javascript as $$
3+
return 42;
4+
$$;
5+
6+
query I
7+
select int_42();
8+
----
9+
42
10+
11+
statement ok
12+
drop function int_42;
13+
14+
15+
statement ok
16+
create function gcd(a int, b int) returns int language javascript as $$
17+
// required before we support `RETURNS NULL ON NULL INPUT`
18+
if(a == null || b == null) {
19+
return null;
20+
}
21+
while (b != 0) {
22+
let t = b;
23+
b = a % b;
24+
a = t;
25+
}
26+
return a;
27+
$$;
28+
29+
query I
30+
select gcd(25, 15);
31+
----
32+
5
33+
34+
statement ok
35+
drop function gcd;
36+
37+
38+
statement ok
39+
create function decimal_add(a decimal, b decimal) returns decimal language javascript as $$
40+
return a + b;
41+
$$;
42+
43+
query R
44+
select decimal_add(1.11, 2.22);
45+
----
46+
3.33
47+
48+
statement ok
49+
drop function decimal_add;
50+
51+
52+
statement ok
53+
create function to_string(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns varchar language javascript as $$
54+
return a.toString() + b.toString() + c.toString() + d.toString() + e.toString() + f.toString() + g.toString() + h.toString() + i.toString() + JSON.stringify(j);
55+
$$;
56+
57+
query T
58+
select to_string(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
59+
----
60+
false1234.56.78.9abc1,2,3{"key":1}
61+
62+
statement ok
63+
drop function to_string;
64+
65+
66+
# show data types in javascript
67+
statement ok
68+
create function js_typeof(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns jsonb language javascript as $$
69+
return {
70+
boolean: typeof a,
71+
smallint: typeof b,
72+
int: typeof c,
73+
bigint: typeof d,
74+
real: typeof e,
75+
float: typeof f,
76+
decimal: typeof g,
77+
varchar: typeof h,
78+
bytea: typeof i,
79+
jsonb: typeof j,
80+
};
81+
$$;
82+
83+
query T
84+
select js_typeof(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
85+
----
86+
{"bigint": "number", "boolean": "boolean", "bytea": "object", "decimal": "bigdecimal", "float": "number", "int": "number", "jsonb": "object", "real": "number", "smallint": "number", "varchar": "string"}
87+
88+
statement ok
89+
drop function js_typeof;
90+
91+
92+
statement ok
93+
create function return_all(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>)
94+
returns struct<a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>>
95+
language javascript as $$
96+
return {a,b,c,d,e,f,g,h,i,j,s};
97+
$$;
98+
99+
query T
100+
select (return_all(
101+
true,
102+
1 ::smallint,
103+
1,
104+
1,
105+
1,
106+
1,
107+
12345678901234567890.12345678,
108+
'string',
109+
'bytes',
110+
'{"key":1}',
111+
row(1, 2)::struct<f1 int, f2 int>
112+
)).*;
113+
----
114+
t 1 1 1 1 1 12345678901234567890.12345678 string \x6279746573 {"key": 1} (1,2)
115+
116+
statement ok
117+
drop function return_all;
118+
119+
120+
statement ok
121+
create function series(n int) returns table (x int) language javascript as $$
122+
for(let i = 0; i < n; i++) {
123+
yield i;
124+
}
125+
$$;
126+
127+
query I
128+
select series(5);
129+
----
130+
0
131+
1
132+
2
133+
3
134+
4
135+
136+
statement ok
137+
drop function series;
138+
139+
140+
statement ok
141+
create function split(s varchar) returns table (word varchar, length int) language javascript as $$
142+
for(let word of s.split(' ')) {
143+
yield { word: word, length: word.length };
144+
}
145+
$$;
146+
147+
query IT
148+
select * from split('rising wave');
149+
----
150+
rising 6
151+
wave 4
152+
153+
statement ok
154+
drop function split;

e2e_test/udf/wasm/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ edition = "2021"
88
crate-type = ["cdylib"]
99

1010
[dependencies]
11-
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
11+
arrow-udf = "0.1"
1212
genawaiter = "0.99"
1313
rust_decimal = "1"
1414
serde_json = "1"

proto/catalog.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,12 @@ message Function {
213213
uint32 database_id = 3;
214214
string name = 4;
215215
uint32 owner = 9;
216+
repeated string arg_names = 15;
216217
repeated data.DataType arg_types = 5;
217218
data.DataType return_type = 6;
218219
string language = 7;
219-
string link = 8;
220-
string identifier = 10;
220+
optional string link = 8;
221+
optional string identifier = 10;
221222
optional string body = 14;
222223

223224
oneof kind {

proto/expr.proto

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,21 +471,27 @@ message WindowFunction {
471471
message UserDefinedFunction {
472472
repeated ExprNode children = 1;
473473
string name = 2;
474+
repeated string arg_names = 8;
474475
repeated data.DataType arg_types = 3;
475476
string language = 4;
476477
// For external UDF: the link to the external function service.
477478
// For WASM UDF: the link to the wasm binary file.
478-
string link = 5;
479+
optional string link = 5;
479480
// A unique identifier for the function.
480481
// For external UDF, it's the name of the function in the external function service.
481482
// For WASM UDF, it's the name of the function in the wasm binary file.
482-
string identifier = 6;
483+
// For JavaScript UDF, it's the name of the function.
484+
optional string identifier = 6;
485+
// For JavaScript UDF, it's the body of the function.
486+
optional string body = 7;
483487
}
484488

485489
// Additional information for user defined table functions.
486490
message UserDefinedTableFunction {
491+
repeated string arg_names = 8;
487492
repeated data.DataType arg_types = 3;
488493
string language = 4;
489-
string link = 5;
490-
string identifier = 6;
494+
optional string link = 5;
495+
optional string identifier = 6;
496+
optional string body = 7;
491497
}

src/expr/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
1919
anyhow = "1"
2020
arrow-array = { workspace = true }
2121
arrow-schema = { workspace = true }
22+
arrow-udf-js = { workspace = true }
2223
arrow-udf-wasm = { workspace = true }
2324
async-trait = "0.1"
2425
auto_impl = "1"

src/expr/core/src/expr/expr_udf.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::time::Duration;
2020

2121
use anyhow::Context;
2222
use arrow_schema::{Field, Fields, Schema};
23+
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
2324
use arrow_udf_wasm::Runtime as WasmRuntime;
2425
use await_tree::InstrumentAwait;
2526
use cfg_or_panic::cfg_or_panic;
@@ -61,6 +62,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
6162
enum UdfImpl {
6263
External(Arc<ArrowFlightUdfClient>),
6364
Wasm(Arc<WasmRuntime>),
65+
JavaScript(JsRuntime),
6466
}
6567

6668
#[async_trait::async_trait]
@@ -123,6 +125,7 @@ impl UserDefinedFunction {
123125

124126
let output: arrow_array::RecordBatch = match &self.imp {
125127
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
128+
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
126129
UdfImpl::External(client) => {
127130
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
128131
let result = if disable_retry_count != 0 {
@@ -189,16 +192,36 @@ impl Build for UserDefinedFunction {
189192
let return_type = DataType::from(prost.get_return_type().unwrap());
190193
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
191194

195+
let identifier = udf.get_identifier()?;
192196
let imp = match udf.language.as_str() {
193197
"wasm" => {
198+
let link = udf.get_link()?;
194199
// Use `block_in_place` as an escape hatch to run async code here in a sync context.
195200
// Calling `block_on` directly will panic.
196201
UdfImpl::Wasm(tokio::task::block_in_place(|| {
197-
tokio::runtime::Handle::current()
198-
.block_on(get_or_create_wasm_runtime(&udf.link))
202+
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
199203
})?)
200204
}
201-
_ => UdfImpl::External(get_or_create_flight_client(&udf.link)?),
205+
"javascript" => {
206+
let mut rt = JsRuntime::new()?;
207+
let body = format!(
208+
"export function {}({}) {{ {} }}",
209+
identifier,
210+
udf.arg_names.join(","),
211+
udf.get_body()?
212+
);
213+
rt.add_function(
214+
identifier,
215+
arrow_schema::DataType::try_from(&return_type)?,
216+
CallMode::CalledOnNullInput,
217+
&body,
218+
)?;
219+
UdfImpl::JavaScript(rt)
220+
}
221+
_ => {
222+
let link = udf.get_link()?;
223+
UdfImpl::External(get_or_create_flight_client(link)?)
224+
}
202225
};
203226

204227
let arg_schema = Arc::new(Schema::new(
@@ -222,8 +245,8 @@ impl Build for UserDefinedFunction {
222245
return_type,
223246
arg_schema,
224247
imp,
225-
identifier: udf.identifier.clone(),
226-
span: format!("udf_call({})", udf.identifier).into(),
248+
identifier: identifier.clone(),
249+
span: format!("udf_call({})", identifier).into(),
227250
disable_retry_count: AtomicU8::new(0),
228251
})
229252
}

0 commit comments

Comments
 (0)