Skip to content

Commit e7041d6

Browse files
authored
Mapping WrenType to ArrowType (#720)
1 parent 8449442 commit e7041d6

File tree

1 file changed

+131
-4
lines changed
  • wren-modeling-rs/core/src/logical_plan

1 file changed

+131
-4
lines changed

wren-modeling-rs/core/src/logical_plan/utils.rs

+131-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
2-
use datafusion::common::not_impl_err;
31
use std::{collections::HashMap, sync::Arc};
42

3+
use datafusion::arrow::datatypes::{
4+
DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit,
5+
};
56
use datafusion::datasource::DefaultTableSource;
67
use datafusion::error::Result;
78
use datafusion::logical_expr::{builder::LogicalTableSource, TableSource};
9+
use log::debug;
810
use petgraph::dot::{Config, Dot};
911
use petgraph::Graph;
1012

@@ -15,15 +17,74 @@ use crate::mdl::{
1517
WrenMDL,
1618
};
1719

20+
/// Builds the placeholder Arrow list type used for every Wren `array<...>` type.
///
/// The element is always a non-nullable `Utf8` field named `"string"`; the real
/// element type of the Wren array is not reflected here.
fn create_mock_list_type() -> DataType {
    DataType::List(Arc::new(Field::new("string", DataType::Utf8, false)))
}
24+
25+
/// Builds the placeholder Arrow struct type used for every Wren `struct<...>` type.
///
/// The struct always has a single non-nullable `Boolean` field named `"a"`; the
/// real member fields of the Wren struct are not reflected here.
fn create_mock_struct_type() -> DataType {
    let mut schema = SchemaBuilder::new();
    schema.push(Field::new("a", DataType::Boolean, false));
    // Only the field list of the finished schema is needed for the struct type.
    DataType::Struct(schema.finish().fields)
}
31+
1832
pub fn map_data_type(data_type: &str) -> Result<DataType> {
33+
let lower = data_type.to_lowercase();
34+
let data_type = lower.as_str();
35+
// Currently, we don't care about the element type of the array or struct.
36+
// We only care about the array or struct itself.
37+
if data_type.starts_with("array") {
38+
return Ok(create_mock_list_type());
39+
}
40+
if data_type.starts_with("struct") {
41+
return Ok(create_mock_struct_type());
42+
}
1943
let result = match data_type {
44+
// Wren Definition Types
45+
"bool" => DataType::Boolean,
46+
"tinyint" => DataType::Int8,
47+
"int2" => DataType::Int16,
48+
"smallint" => DataType::Int16,
49+
"int4" => DataType::Int32,
2050
"integer" => DataType::Int32,
51+
"int8" => DataType::Int64,
2152
"bigint" => DataType::Int64,
53+
"numeric" => DataType::Decimal128(38, 10), // set the default precision and scale
54+
"decimal" => DataType::Decimal128(38, 10),
2255
"varchar" => DataType::Utf8,
56+
"char" => DataType::Utf8,
57+
"bpchar" => DataType::Utf8, // we don't have a BPCHAR type, so we map it to Utf8
58+
"text" => DataType::Utf8,
59+
"string" => DataType::Utf8,
60+
"name" => DataType::Utf8,
61+
"float4" => DataType::Float32,
62+
"real" => DataType::Float32,
63+
"float8" => DataType::Float64,
2364
"double" => DataType::Float64,
24-
"timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None),
65+
"timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
66+
"timestamptz" => DataType::Timestamp(TimeUnit::Nanosecond, None), // don't care about the time zone
2567
"date" => DataType::Date32,
26-
_ => return not_impl_err!("Unsupported data type: {}", &data_type),
68+
"interval" => DataType::Interval(IntervalUnit::DayTime),
69+
"json" => DataType::Utf8, // we don't have a JSON type, so we map it to Utf8
70+
"oid" => DataType::Int32,
71+
"bytea" => DataType::Binary,
72+
"uuid" => DataType::Utf8, // we don't have a UUID type, so we map it to Utf8
73+
"inet" => DataType::Utf8, // we don't have a INET type, so we map it to Utf8
74+
"unknown" => DataType::Utf8, // we don't have a UNKNOWN type, so we map it to Utf8
75+
// BigQuery Compatible Types
76+
"bignumeric" => DataType::Decimal128(38, 10), // set the default precision and scale
77+
"bytes" => DataType::Binary,
78+
"datetime" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
79+
"float64" => DataType::Float64,
80+
"int64" => DataType::Int64,
81+
"time" => DataType::Time32(TimeUnit::Nanosecond), // chose the smallest time unit
82+
"null" => DataType::Null,
83+
_ => {
84+
// default to string
85+
debug!("map unknown type {} to Utf8", data_type);
86+
DataType::Utf8
87+
}
2788
};
2889
Ok(result)
2990
}
@@ -104,3 +165,69 @@ pub fn print_graph(graph: &Graph<Dataset, DatasetLink>) {
104165
let dot = Dot::with_config(graph, &[Config::EdgeNoLabel]);
105166
println!("graph: {:?}", dot);
106167
}
168+
169+
#[cfg(test)]
170+
mod test {
171+
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
172+
use datafusion::common::Result;
173+
174+
use crate::logical_plan::utils::{create_mock_list_type, create_mock_struct_type};
175+
176+
#[test]
177+
pub fn test_map_data_type() -> Result<()> {
178+
let test_cases = vec![
179+
("bool", DataType::Boolean),
180+
("tinyint", DataType::Int8),
181+
("int2", DataType::Int16),
182+
("smallint", DataType::Int16),
183+
("int4", DataType::Int32),
184+
("integer", DataType::Int32),
185+
("int8", DataType::Int64),
186+
("bigint", DataType::Int64),
187+
("numeric", DataType::Decimal128(38, 10)),
188+
("decimal", DataType::Decimal128(38, 10)),
189+
("varchar", DataType::Utf8),
190+
("char", DataType::Utf8),
191+
("bpchar", DataType::Utf8),
192+
("text", DataType::Utf8),
193+
("string", DataType::Utf8),
194+
("name", DataType::Utf8),
195+
("float4", DataType::Float32),
196+
("real", DataType::Float32),
197+
("float8", DataType::Float64),
198+
("double", DataType::Float64),
199+
("timestamp", DataType::Timestamp(TimeUnit::Nanosecond, None)),
200+
(
201+
"timestamptz",
202+
DataType::Timestamp(TimeUnit::Nanosecond, None),
203+
),
204+
("date", DataType::Date32),
205+
("interval", DataType::Interval(IntervalUnit::DayTime)),
206+
("json", DataType::Utf8),
207+
("oid", DataType::Int32),
208+
("bytea", DataType::Binary),
209+
("uuid", DataType::Utf8),
210+
("inet", DataType::Utf8),
211+
("unknown", DataType::Utf8),
212+
("bignumeric", DataType::Decimal128(38, 10)),
213+
("bytes", DataType::Binary),
214+
("datetime", DataType::Timestamp(TimeUnit::Nanosecond, None)),
215+
("float64", DataType::Float64),
216+
("int64", DataType::Int64),
217+
("time", DataType::Time32(TimeUnit::Nanosecond)),
218+
("null", DataType::Null),
219+
("geography", DataType::Utf8),
220+
("range", DataType::Utf8),
221+
("array<int64>", create_mock_list_type()),
222+
("struct<name string, age int>", create_mock_struct_type()),
223+
];
224+
for (data_type, expected) in test_cases {
225+
let result = super::map_data_type(data_type)?;
226+
assert_eq!(result, expected);
227+
// test case insensitivity
228+
let result = super::map_data_type(&data_type.to_uppercase())?;
229+
assert_eq!(result, expected);
230+
}
231+
Ok(())
232+
}
233+
}

0 commit comments

Comments
 (0)