Skip to content

Commit 9bfffc1

Browse files
authored
feat(core): support to Array and Struct type (#914)
1 parent 54be39d commit 9bfffc1

File tree

6 files changed

+453
-91
lines changed

6 files changed

+453
-91
lines changed

wren-core-py/src/context.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,11 @@ impl PySessionContext {
106106
.block_on(create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), false))
107107
.map_err(CoreError::from)?;
108108

109-
remote_functions.iter().for_each(|remote_function| {
109+
remote_functions.iter().try_for_each(|remote_function| {
110110
debug!("Registering remote function: {:?}", remote_function);
111-
Self::register_remote_function(&ctx, remote_function);
112-
});
111+
Self::register_remote_function(&ctx, remote_function)?;
112+
Ok::<(), CoreError>(())
113+
})?;
113114

114115
Ok(Self {
115116
ctx,
@@ -199,27 +200,31 @@ impl PySessionContext {
199200
fn register_remote_function(
200201
ctx: &wren_core::SessionContext,
201202
remote_function: &RemoteFunction,
202-
) {
203+
) -> PyResult<()> {
203204
match &remote_function.function_type {
204205
FunctionType::Scalar => {
205206
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
206207
&remote_function.name,
207-
map_data_type(&remote_function.return_type),
208+
map_data_type(&remote_function.return_type)
209+
.map_err(CoreError::from)?,
208210
)))
209211
}
210212
FunctionType::Aggregate => {
211213
ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
212214
&remote_function.name,
213-
map_data_type(&remote_function.return_type),
215+
map_data_type(&remote_function.return_type)
216+
.map_err(CoreError::from)?,
214217
)))
215218
}
216219
FunctionType::Window => {
217220
ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
218221
&remote_function.name,
219-
map_data_type(&remote_function.return_type),
222+
map_data_type(&remote_function.return_type)
223+
.map_err(CoreError::from)?,
220224
)))
221225
}
222226
}
227+
Ok(())
223228
}
224229

225230
fn read_remote_function_list(path: Option<&str>) -> PyResult<Vec<PyRemoteFunction>> {

wren-core/core/src/logical_plan/analyze/plan.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ impl ModelPlanNodeBuilder {
227227
Some(TableReference::bare(quoted(model.name()))),
228228
Arc::new(Field::new(
229229
column.name(),
230-
map_data_type(&column.r#type),
230+
map_data_type(&column.r#type)?,
231231
column.not_null,
232232
)),
233233
));
@@ -735,7 +735,7 @@ impl ModelSourceNode {
735735
Some(TableReference::bare(quoted(model.name()))),
736736
Arc::new(Field::new(
737737
column.name(),
738-
map_data_type(&column.r#type),
738+
map_data_type(&column.r#type)?,
739739
column.not_null,
740740
)),
741741
));
@@ -775,7 +775,7 @@ impl ModelSourceNode {
775775
Some(TableReference::bare(quoted(model.name()))),
776776
Arc::new(Field::new(
777777
column.name(),
778-
map_data_type(&column.r#type),
778+
map_data_type(&column.r#type)?,
779779
column.not_null,
780780
)),
781781
));
@@ -869,12 +869,12 @@ impl CalculationPlanNode {
869869
let output_field = vec![
870870
Arc::new(Field::new(
871871
calculation.column.name(),
872-
map_data_type(&calculation.column.r#type),
872+
map_data_type(&calculation.column.r#type)?,
873873
calculation.column.not_null,
874874
)),
875875
Arc::new(Field::new(
876876
pk_column.name(),
877-
map_data_type(&pk_column.r#type),
877+
map_data_type(&pk_column.r#type)?,
878878
pk_column.not_null,
879879
)),
880880
]

wren-core/core/src/logical_plan/utils.rs

+145-19
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,100 @@ use datafusion::arrow::datatypes::{
99
DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit,
1010
};
1111
use datafusion::catalog_common::TableReference;
12+
use datafusion::common::plan_err;
1213
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
1314
use datafusion::datasource::DefaultTableSource;
1415
use datafusion::error::Result;
16+
use datafusion::logical_expr::sqlparser::ast::ArrayElemTypeDef;
17+
use datafusion::logical_expr::sqlparser::dialect::GenericDialect;
1518
use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource};
19+
use datafusion::sql::sqlparser::ast;
20+
use datafusion::sql::sqlparser::parser::Parser;
1621
use log::debug;
1722
use petgraph::dot::{Config, Dot};
1823
use petgraph::Graph;
1924
use std::collections::HashSet;
2025
use std::{collections::HashMap, sync::Arc};
2126

22-
fn create_mock_list_type() -> DataType {
23-
let string_filed = Arc::new(Field::new("string", DataType::Utf8, false));
24-
DataType::List(string_filed)
27+
fn create_list_type(array_type: &str) -> Result<DataType> {
28+
// Workaround for the array type without an element type
29+
if array_type.len() == "array".len() {
30+
return create_list_type("array<varchar>");
31+
}
32+
if let ast::DataType::Array(value) = parse_type(array_type)? {
33+
let data_type = match value {
34+
ArrayElemTypeDef::None => {
35+
return plan_err!("Array type must have an element type")
36+
}
37+
ArrayElemTypeDef::AngleBracket(data_type) => {
38+
map_data_type(&data_type.to_string())?
39+
}
40+
ArrayElemTypeDef::SquareBracket(_, _) => {
41+
unreachable!()
42+
}
43+
ArrayElemTypeDef::Parenthesis(_) => {
44+
return plan_err!(
45+
"The format of the array type should be 'array<element_type>'"
46+
)
47+
}
48+
};
49+
return Ok(DataType::List(Arc::new(Field::new(
50+
"element", data_type, false,
51+
))));
52+
}
53+
unreachable!()
2554
}
2655

27-
fn create_mock_struct_type() -> DataType {
56+
fn create_struct_type(struct_type: &str) -> Result<DataType> {
57+
let sql_type = parse_type(struct_type)?;
2858
let mut builder = SchemaBuilder::new();
29-
builder.push(Field::new("a", DataType::Boolean, false));
59+
let mut counter = 0;
60+
match sql_type {
61+
ast::DataType::Struct(fields, ..) => {
62+
if fields.is_empty() {
63+
return plan_err!("struct must have at least one field");
64+
}
65+
for field in fields {
66+
let data_type = map_data_type(field.field_type.to_string().as_str())?;
67+
let field = Field::new(
68+
field
69+
.field_name
70+
.map(|f| f.to_string())
71+
.unwrap_or_else(|| format!("c{}", counter)),
72+
data_type,
73+
true,
74+
);
75+
counter += 1;
76+
builder.push(field);
77+
}
78+
}
79+
_ => {
80+
unreachable!()
81+
}
82+
}
3083
let fields = builder.finish().fields;
31-
DataType::Struct(fields)
84+
Ok(DataType::Struct(fields))
85+
}
86+
87+
fn parse_type(struct_type: &str) -> Result<ast::DataType> {
88+
let dialect = GenericDialect {};
89+
Ok(Parser::new(&dialect)
90+
.try_with_sql(struct_type)?
91+
.parse_data_type()?)
3292
}
3393

34-
pub fn map_data_type(data_type: &str) -> DataType {
94+
pub fn map_data_type(data_type: &str) -> Result<DataType> {
3595
let lower = data_type.to_lowercase();
3696
let data_type = lower.as_str();
3797
// Currently, we don't care about the element type of the array or struct.
3898
// We only care about the array or struct itself.
3999
if data_type.starts_with("array") {
40-
return create_mock_list_type();
100+
return create_list_type(data_type);
41101
}
42102
if data_type.starts_with("struct") {
43-
return create_mock_struct_type();
103+
return create_struct_type(data_type);
44104
}
45-
match data_type {
105+
let result = match data_type {
46106
// Wren Definition Types
47107
"bool" | "boolean" => DataType::Boolean,
48108
"tinyint" => DataType::Int8,
@@ -90,7 +150,8 @@ pub fn map_data_type(data_type: &str) -> DataType {
90150
debug!("map unknown type {} to Utf8", data_type);
91151
DataType::Utf8
92152
}
93-
}
153+
};
154+
Ok(result)
94155
}
95156

96157
pub fn create_table_source(model: &Model) -> Result<Arc<dyn TableSource>> {
@@ -102,7 +163,7 @@ pub fn create_schema(columns: Vec<Arc<Column>>) -> Result<SchemaRef> {
102163
let fields: Vec<Field> = columns
103164
.iter()
104165
.map(|column| {
105-
let data_type = map_data_type(&column.r#type);
166+
let data_type = map_data_type(&column.r#type)?;
106167
Ok(Field::new(&column.name, data_type, column.not_null))
107168
})
108169
.collect::<Result<Vec<_>>>()?;
@@ -244,11 +305,12 @@ pub fn expr_to_columns(
244305

245306
#[cfg(test)]
246307
mod test {
247-
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
308+
use crate::logical_plan::utils::{
309+
create_list_type, create_struct_type, map_data_type,
310+
};
311+
use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
248312
use datafusion::common::Result;
249313

250-
use crate::logical_plan::utils::{create_mock_list_type, create_mock_struct_type};
251-
252314
#[test]
253315
pub fn test_map_data_type() -> Result<()> {
254316
let test_cases = vec![
@@ -303,16 +365,80 @@ mod test {
303365
("null", DataType::Null),
304366
("geography", DataType::Utf8),
305367
("range", DataType::Utf8),
306-
("array<int64>", create_mock_list_type()),
307-
("struct<name string, age int>", create_mock_struct_type()),
368+
("array", create_list_type("array<varchar>")?),
369+
("array<int64>", create_list_type("array<int64>")?),
370+
(
371+
"struct<name string, age int>",
372+
create_struct_type("struct<name string, age int>")?,
373+
),
308374
];
309375
for (data_type, expected) in test_cases {
310-
let result = super::map_data_type(data_type);
376+
let result = map_data_type(data_type)?;
311377
assert_eq!(result, expected);
312378
// test case insensitivity
313-
let result = super::map_data_type(&data_type.to_uppercase());
379+
let result = map_data_type(&data_type.to_uppercase())?;
314380
assert_eq!(result, expected);
315381
}
382+
383+
let _ = map_data_type("array").map_err(|e| {
384+
assert_eq!(
385+
e.to_string(),
386+
"SQL error: ParserError(\"Expected: <, found: EOF\")"
387+
);
388+
});
389+
390+
let _ = map_data_type("array<>").map_err(|e| {
391+
assert_eq!(
392+
e.to_string(),
393+
"SQL error: ParserError(\"Expected: <, found: <> at Line: 1, Column: 6\")"
394+
);
395+
});
396+
397+
let _ = map_data_type("array(int64)").map_err(|e| {
398+
assert_eq!(
399+
e.to_string(),
400+
"SQL error: ParserError(\"Expected: <, found: ( at Line: 1, Column: 6\")"
401+
);
402+
});
403+
404+
let _ = map_data_type("struct").map_err(|e| {
405+
assert_eq!(
406+
e.to_string(),
407+
"Error during planning: struct must have at least one field"
408+
);
409+
});
410+
411+
Ok(())
412+
}
413+
414+
#[test]
415+
fn test_parse_struct() -> Result<()> {
416+
let struct_string = "STRUCT<name VARCHAR, age INT>";
417+
let result = create_struct_type(struct_string)?;
418+
let fields: Fields = vec![
419+
Field::new("name", DataType::Utf8, true),
420+
Field::new("age", DataType::Int32, true),
421+
]
422+
.into();
423+
let expected = DataType::Struct(fields);
424+
assert_eq!(result, expected);
425+
426+
let struct_string = "STRUCT<VARCHAR, INT>";
427+
let result = create_struct_type(struct_string)?;
428+
let fields: Fields = vec![
429+
Field::new("c0", DataType::Utf8, true),
430+
Field::new("c1", DataType::Int32, true),
431+
]
432+
.into();
433+
let expected = DataType::Struct(fields);
434+
assert_eq!(result, expected);
435+
let struct_string = "STRUCT<>";
436+
let _ = create_struct_type(struct_string).map_err(|e| {
437+
assert_eq!(
438+
e.to_string(),
439+
"Error during planning: struct must have at least one field"
440+
)
441+
});
316442
Ok(())
317443
}
318444
}

wren-core/core/src/mdl/dataset.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ impl Column {
5858
}
5959

6060
/// Transform the column to a datafusion field
61-
pub fn to_field(&self) -> Field {
62-
let data_type = map_data_type(&self.r#type);
63-
Field::new(&self.name, data_type, self.not_null)
61+
pub fn to_field(&self) -> Result<Field> {
62+
let data_type = map_data_type(&self.r#type)?;
63+
Ok(Field::new(&self.name, data_type, self.not_null))
6464
}
6565

6666
/// Transform the column to a datafusion field for a remote table
@@ -72,12 +72,12 @@ impl Column {
7272
session_state.config_options().sql_parser.dialect.as_str(),
7373
)?;
7474
let columns = Self::collect_columns(expr);
75-
Ok(columns
75+
columns
7676
.into_iter()
77-
.map(|c| Field::new(c.value, map_data_type(&self.r#type), false))
78-
.collect())
77+
.map(|c| Ok(Field::new(c.value, map_data_type(&self.r#type)?, false)))
78+
.collect::<Result<_>>()
7979
} else {
80-
Ok(vec![self.to_field()])
80+
Ok(vec![self.to_field()?])
8181
}
8282
}
8383

@@ -123,7 +123,7 @@ impl Dataset {
123123
.get_physical_columns()
124124
.iter()
125125
.map(|c| c.to_field())
126-
.collect();
126+
.collect::<Result<_>>()?;
127127
let arrow_schema = datafusion::arrow::datatypes::Schema::new(fields);
128128
DFSchema::try_from_qualified_schema(quoted(&model.name), &arrow_schema)
129129
}

0 commit comments

Comments
 (0)