Skip to content

Commit 1a52a99

Browse files
committed
add description to by pass function
1 parent 37ed412 commit 1a52a99

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

wren-core/core/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub mod mdl;
33

44
pub use datafusion::error::DataFusionError;
55
pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
6-
pub use datafusion::prelude::SessionContext;
6+
pub use datafusion::prelude::*;
77
pub use datafusion::sql::sqlparser::*;
8+
pub use datafusion::arrow::*;
89
pub use mdl::AnalyzedWrenMDL;

wren-core/core/src/mdl/function.rs

+32-9
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ use datafusion::logical_expr::function::{
55
AccumulatorArgs, PartitionEvaluatorArgs, WindowUDFFieldArgs,
66
};
77
use datafusion::logical_expr::{
8-
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
9-
Signature, TypeSignature, Volatility, WindowUDFImpl,
8+
Accumulator, AggregateUDFImpl, ColumnarValue, DocSection, Documentation, DocumentationBuilder, PartitionEvaluator, ScalarUDFImpl, Signature, TypeSignature, Volatility, WindowUDFImpl
109
};
1110
use serde::{Deserialize, Serialize};
1211
use std::any::Any;
@@ -46,7 +45,7 @@ impl FromStr for FunctionType {
4645
type Err = String;
4746

4847
fn from_str(s: &str) -> Result<Self, Self::Err> {
49-
match s {
48+
match s.to_lowercase().as_str() {
5049
"scalar" => Ok(FunctionType::Scalar),
5150
"aggregate" => Ok(FunctionType::Aggregate),
5251
"window" => Ok(FunctionType::Window),
@@ -63,17 +62,20 @@ pub struct ByPassScalarUDF {
6362
name: String,
6463
return_type: DataType,
6564
signature: Signature,
65+
doc: Documentation,
6666
}
6767

6868
impl ByPassScalarUDF {
69-
pub fn new(name: &str, return_type: DataType) -> Self {
69+
pub fn new(name: &str, return_type: DataType, description: Option<String>) -> Self {
70+
let doc= DocumentationBuilder::new_with_details(DocSection::default(), description.unwrap_or("".to_string()), "").build();
7071
Self {
7172
name: name.to_string(),
7273
return_type,
7374
signature: Signature::one_of(
7475
vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
7576
Volatility::Volatile,
7677
),
78+
doc,
7779
}
7880
}
7981
}
@@ -98,6 +100,10 @@ impl ScalarUDFImpl for ByPassScalarUDF {
98100
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
99101
internal_err!("This function should not be called")
100102
}
103+
104+
fn documentation(&self) -> Option<&Documentation> {
105+
Some(&self.doc)
106+
}
101107
}
102108

103109
/// An aggregate UDF that will be bypassed when planning logical plan.
@@ -107,17 +113,20 @@ pub struct ByPassAggregateUDF {
107113
name: String,
108114
return_type: DataType,
109115
signature: Signature,
116+
doc: Documentation,
110117
}
111118

112119
impl ByPassAggregateUDF {
113-
pub fn new(name: &str, return_type: DataType) -> Self {
120+
pub fn new(name: &str, return_type: DataType, description: Option<String>) -> Self {
121+
let doc= DocumentationBuilder::new_with_details(DocSection::default(), description.unwrap_or("".to_string()), "").build();
114122
Self {
115123
name: name.to_string(),
116124
return_type,
117125
signature: Signature::one_of(
118126
vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
119127
Volatility::Volatile,
120128
),
129+
doc,
121130
}
122131
}
123132
}
@@ -142,6 +151,10 @@ impl AggregateUDFImpl for ByPassAggregateUDF {
142151
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
143152
internal_err!("This function should not be called")
144153
}
154+
155+
fn documentation(&self) -> Option<&Documentation> {
156+
Some(&self.doc)
157+
}
145158
}
146159

147160
/// A window UDF that will be bypassed when planning logical plan.
@@ -151,17 +164,20 @@ pub struct ByPassWindowFunction {
151164
name: String,
152165
return_type: DataType,
153166
signature: Signature,
167+
doc: Documentation,
154168
}
155169

156170
impl ByPassWindowFunction {
157-
pub fn new(name: &str, return_type: DataType) -> Self {
171+
pub fn new(name: &str, return_type: DataType, description: Option<String>) -> Self {
172+
let doc= DocumentationBuilder::new_with_details(DocSection::default(), description.unwrap_or("".to_string()), "").build();
158173
Self {
159174
name: name.to_string(),
160175
return_type,
161176
signature: Signature::one_of(
162177
vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
163178
Volatility::Volatile,
164179
),
180+
doc,
165181
}
166182
}
167183
}
@@ -193,6 +209,10 @@ impl WindowUDFImpl for ByPassWindowFunction {
193209
false,
194210
))
195211
}
212+
213+
fn documentation(&self) -> Option<&Documentation> {
214+
Some(&self.doc)
215+
}
196216
}
197217

198218
#[cfg(test)]
@@ -207,7 +227,7 @@ mod test {
207227

208228
#[tokio::test]
209229
async fn test_by_pass_scalar_udf() -> Result<()> {
210-
let udf = ByPassScalarUDF::new("date_diff", DataType::Int64);
230+
let udf = ByPassScalarUDF::new("date_diff", DataType::Int64, None);
211231
let ctx = SessionContext::new();
212232
ctx.register_udf(ScalarUDF::new_from_impl(udf));
213233

@@ -221,6 +241,7 @@ mod test {
221241
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
222242
"today",
223243
DataType::Utf8,
244+
None,
224245
)));
225246
let plan_2 = ctx.sql("SELECT today()").await?.into_unoptimized_plan();
226247
assert_eq!(format!("{plan_2}"), "Projection: today()\n EmptyRelation");
@@ -230,7 +251,7 @@ mod test {
230251

231252
#[tokio::test]
232253
async fn test_by_pass_agg_udf() -> Result<()> {
233-
let udf = ByPassAggregateUDF::new("count_self", DataType::Int64);
254+
let udf = ByPassAggregateUDF::new("count_self", DataType::Int64, None);
234255
let ctx = SessionContext::new();
235256
ctx.register_udaf(AggregateUDF::new_from_impl(udf));
236257

@@ -245,6 +266,7 @@ mod test {
245266
ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
246267
"total_count",
247268
DataType::Int64,
269+
None,
248270
)));
249271
let plan_2 = ctx
250272
.sql("SELECT total_count() AS total_count FROM (VALUES (1), (2), (3)) AS val(x)")
@@ -263,7 +285,7 @@ mod test {
263285

264286
#[tokio::test]
265287
async fn test_by_pass_window_udf() -> Result<()> {
266-
let udf = ByPassWindowFunction::new("custom_window", DataType::Int64);
288+
let udf = ByPassWindowFunction::new("custom_window", DataType::Int64, None);
267289
let ctx = SessionContext::new();
268290
ctx.register_udwf(WindowUDF::new_from_impl(udf));
269291

@@ -279,6 +301,7 @@ mod test {
279301
ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
280302
"cume_dist",
281303
DataType::Int64,
304+
None,
282305
)));
283306
let plan_2 = ctx
284307
.sql("SELECT cume_dist() OVER ()")

wren-core/core/src/mdl/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -389,18 +389,21 @@ fn register_remote_function(
389389
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
390390
&remote_function.name,
391391
map_data_type(&remote_function.return_type)?,
392+
remote_function.description.clone(),
392393
)))
393394
}
394395
FunctionType::Aggregate => {
395396
ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
396397
&remote_function.name,
397398
map_data_type(&remote_function.return_type)?,
399+
remote_function.description.clone(),
398400
)))
399401
}
400402
FunctionType::Window => {
401403
ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
402404
&remote_function.name,
403405
map_data_type(&remote_function.return_type)?,
406+
remote_function.description.clone(),
404407
)))
405408
}
406409
};

0 commit comments

Comments
 (0)