@@ -9,40 +9,100 @@ use datafusion::arrow::datatypes::{
9
9
DataType , Field , IntervalUnit , Schema , SchemaBuilder , SchemaRef , TimeUnit ,
10
10
} ;
11
11
use datafusion:: catalog_common:: TableReference ;
12
+ use datafusion:: common:: plan_err;
12
13
use datafusion:: common:: tree_node:: { TreeNode , TreeNodeRecursion } ;
13
14
use datafusion:: datasource:: DefaultTableSource ;
14
15
use datafusion:: error:: Result ;
16
+ use datafusion:: logical_expr:: sqlparser:: ast:: ArrayElemTypeDef ;
17
+ use datafusion:: logical_expr:: sqlparser:: dialect:: GenericDialect ;
15
18
use datafusion:: logical_expr:: { builder:: LogicalTableSource , Expr , TableSource } ;
19
+ use datafusion:: sql:: sqlparser:: ast;
20
+ use datafusion:: sql:: sqlparser:: parser:: Parser ;
16
21
use log:: debug;
17
22
use petgraph:: dot:: { Config , Dot } ;
18
23
use petgraph:: Graph ;
19
24
use std:: collections:: HashSet ;
20
25
use std:: { collections:: HashMap , sync:: Arc } ;
21
26
22
- fn create_mock_list_type ( ) -> DataType {
23
- let string_filed = Arc :: new ( Field :: new ( "string" , DataType :: Utf8 , false ) ) ;
24
- DataType :: List ( string_filed)
27
+ fn create_list_type ( array_type : & str ) -> Result < DataType > {
28
+ // Workaround for the array type without an element type
29
+ if array_type. len ( ) == "array" . len ( ) {
30
+ return create_list_type ( "array<varchar>" ) ;
31
+ }
32
+ if let ast:: DataType :: Array ( value) = parse_type ( array_type) ? {
33
+ let data_type = match value {
34
+ ArrayElemTypeDef :: None => {
35
+ return plan_err ! ( "Array type must have an element type" )
36
+ }
37
+ ArrayElemTypeDef :: AngleBracket ( data_type) => {
38
+ map_data_type ( & data_type. to_string ( ) ) ?
39
+ }
40
+ ArrayElemTypeDef :: SquareBracket ( _, _) => {
41
+ unreachable ! ( )
42
+ }
43
+ ArrayElemTypeDef :: Parenthesis ( _) => {
44
+ return plan_err ! (
45
+ "The format of the array type should be 'array<element_type>'"
46
+ )
47
+ }
48
+ } ;
49
+ return Ok ( DataType :: List ( Arc :: new ( Field :: new (
50
+ "element" , data_type, false ,
51
+ ) ) ) ) ;
52
+ }
53
+ unreachable ! ( )
25
54
}
26
55
27
- fn create_mock_struct_type ( ) -> DataType {
56
+ fn create_struct_type ( struct_type : & str ) -> Result < DataType > {
57
+ let sql_type = parse_type ( struct_type) ?;
28
58
let mut builder = SchemaBuilder :: new ( ) ;
29
- builder. push ( Field :: new ( "a" , DataType :: Boolean , false ) ) ;
59
+ let mut counter = 0 ;
60
+ match sql_type {
61
+ ast:: DataType :: Struct ( fields, ..) => {
62
+ if fields. is_empty ( ) {
63
+ return plan_err ! ( "struct must have at least one field" ) ;
64
+ }
65
+ for field in fields {
66
+ let data_type = map_data_type ( field. field_type . to_string ( ) . as_str ( ) ) ?;
67
+ let field = Field :: new (
68
+ field
69
+ . field_name
70
+ . map ( |f| f. to_string ( ) )
71
+ . unwrap_or_else ( || format ! ( "c{}" , counter) ) ,
72
+ data_type,
73
+ true ,
74
+ ) ;
75
+ counter += 1 ;
76
+ builder. push ( field) ;
77
+ }
78
+ }
79
+ _ => {
80
+ unreachable ! ( )
81
+ }
82
+ }
30
83
let fields = builder. finish ( ) . fields ;
31
- DataType :: Struct ( fields)
84
+ Ok ( DataType :: Struct ( fields) )
85
+ }
86
+
87
+ fn parse_type ( struct_type : & str ) -> Result < ast:: DataType > {
88
+ let dialect = GenericDialect { } ;
89
+ Ok ( Parser :: new ( & dialect)
90
+ . try_with_sql ( struct_type) ?
91
+ . parse_data_type ( ) ?)
32
92
}
33
93
34
- pub fn map_data_type ( data_type : & str ) -> DataType {
94
+ pub fn map_data_type ( data_type : & str ) -> Result < DataType > {
35
95
let lower = data_type. to_lowercase ( ) ;
36
96
let data_type = lower. as_str ( ) ;
37
97
// Currently, we don't care about the element type of the array or struct.
38
98
// We only care about the array or struct itself.
39
99
if data_type. starts_with ( "array" ) {
40
- return create_mock_list_type ( ) ;
100
+ return create_list_type ( data_type ) ;
41
101
}
42
102
if data_type. starts_with ( "struct" ) {
43
- return create_mock_struct_type ( ) ;
103
+ return create_struct_type ( data_type ) ;
44
104
}
45
- match data_type {
105
+ let result = match data_type {
46
106
// Wren Definition Types
47
107
"bool" | "boolean" => DataType :: Boolean ,
48
108
"tinyint" => DataType :: Int8 ,
@@ -90,7 +150,8 @@ pub fn map_data_type(data_type: &str) -> DataType {
90
150
debug ! ( "map unknown type {} to Utf8" , data_type) ;
91
151
DataType :: Utf8
92
152
}
93
- }
153
+ } ;
154
+ Ok ( result)
94
155
}
95
156
96
157
pub fn create_table_source ( model : & Model ) -> Result < Arc < dyn TableSource > > {
@@ -102,7 +163,7 @@ pub fn create_schema(columns: Vec<Arc<Column>>) -> Result<SchemaRef> {
102
163
let fields: Vec < Field > = columns
103
164
. iter ( )
104
165
. map ( |column| {
105
- let data_type = map_data_type ( & column. r#type ) ;
166
+ let data_type = map_data_type ( & column. r#type ) ? ;
106
167
Ok ( Field :: new ( & column. name , data_type, column. not_null ) )
107
168
} )
108
169
. collect :: < Result < Vec < _ > > > ( ) ?;
@@ -244,11 +305,12 @@ pub fn expr_to_columns(
244
305
245
306
#[ cfg( test) ]
246
307
mod test {
247
- use datafusion:: arrow:: datatypes:: { DataType , IntervalUnit , TimeUnit } ;
308
+ use crate :: logical_plan:: utils:: {
309
+ create_list_type, create_struct_type, map_data_type,
310
+ } ;
311
+ use datafusion:: arrow:: datatypes:: { DataType , Field , Fields , IntervalUnit , TimeUnit } ;
248
312
use datafusion:: common:: Result ;
249
313
250
- use crate :: logical_plan:: utils:: { create_mock_list_type, create_mock_struct_type} ;
251
-
252
314
#[ test]
253
315
pub fn test_map_data_type ( ) -> Result < ( ) > {
254
316
let test_cases = vec ! [
@@ -303,16 +365,80 @@ mod test {
303
365
( "null" , DataType :: Null ) ,
304
366
( "geography" , DataType :: Utf8 ) ,
305
367
( "range" , DataType :: Utf8 ) ,
306
- ( "array<int64>" , create_mock_list_type( ) ) ,
307
- ( "struct<name string, age int>" , create_mock_struct_type( ) ) ,
368
+ ( "array" , create_list_type( "array<varchar>" ) ?) ,
369
+ ( "array<int64>" , create_list_type( "array<int64>" ) ?) ,
370
+ (
371
+ "struct<name string, age int>" ,
372
+ create_struct_type( "struct<name string, age int>" ) ?,
373
+ ) ,
308
374
] ;
309
375
for ( data_type, expected) in test_cases {
310
- let result = super :: map_data_type ( data_type) ;
376
+ let result = map_data_type ( data_type) ? ;
311
377
assert_eq ! ( result, expected) ;
312
378
// test case insensitivity
313
- let result = super :: map_data_type ( & data_type. to_uppercase ( ) ) ;
379
+ let result = map_data_type ( & data_type. to_uppercase ( ) ) ? ;
314
380
assert_eq ! ( result, expected) ;
315
381
}
382
+
383
+ let _ = map_data_type ( "array" ) . map_err ( |e| {
384
+ assert_eq ! (
385
+ e. to_string( ) ,
386
+ "SQL error: ParserError(\" Expected: <, found: EOF\" )"
387
+ ) ;
388
+ } ) ;
389
+
390
+ let _ = map_data_type ( "array<>" ) . map_err ( |e| {
391
+ assert_eq ! (
392
+ e. to_string( ) ,
393
+ "SQL error: ParserError(\" Expected: <, found: <> at Line: 1, Column: 6\" )"
394
+ ) ;
395
+ } ) ;
396
+
397
+ let _ = map_data_type ( "array(int64)" ) . map_err ( |e| {
398
+ assert_eq ! (
399
+ e. to_string( ) ,
400
+ "SQL error: ParserError(\" Expected: <, found: ( at Line: 1, Column: 6\" )"
401
+ ) ;
402
+ } ) ;
403
+
404
+ let _ = map_data_type ( "struct" ) . map_err ( |e| {
405
+ assert_eq ! (
406
+ e. to_string( ) ,
407
+ "Error during planning: struct must have at least one field"
408
+ ) ;
409
+ } ) ;
410
+
411
+ Ok ( ( ) )
412
+ }
413
+
414
+ #[ test]
415
+ fn test_parse_struct ( ) -> Result < ( ) > {
416
+ let struct_string = "STRUCT<name VARCHAR, age INT>" ;
417
+ let result = create_struct_type ( struct_string) ?;
418
+ let fields: Fields = vec ! [
419
+ Field :: new( "name" , DataType :: Utf8 , true ) ,
420
+ Field :: new( "age" , DataType :: Int32 , true ) ,
421
+ ]
422
+ . into ( ) ;
423
+ let expected = DataType :: Struct ( fields) ;
424
+ assert_eq ! ( result, expected) ;
425
+
426
+ let struct_string = "STRUCT<VARCHAR, INT>" ;
427
+ let result = create_struct_type ( struct_string) ?;
428
+ let fields: Fields = vec ! [
429
+ Field :: new( "c0" , DataType :: Utf8 , true ) ,
430
+ Field :: new( "c1" , DataType :: Int32 , true ) ,
431
+ ]
432
+ . into ( ) ;
433
+ let expected = DataType :: Struct ( fields) ;
434
+ assert_eq ! ( result, expected) ;
435
+ let struct_string = "STRUCT<>" ;
436
+ let _ = create_struct_type ( struct_string) . map_err ( |e| {
437
+ assert_eq ! (
438
+ e. to_string( ) ,
439
+ "Error during planning: struct must have at least one field"
440
+ )
441
+ } ) ;
316
442
Ok ( ( ) )
317
443
}
318
444
}
0 commit comments