Skip to content

Commit 8051dc0

Browse files
goldmedalgrieve54706
authored and committed
fix(core): fix the remote column inferring and disable simplify expression rule (#874)
1 parent 9076c4d commit 8051dc0

File tree

7 files changed

+102
-50
lines changed

7 files changed

+102
-50
lines changed

ibis-server/tests/routers/v3/connector/test_postgres.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ def test_query(manifest_str, postgres: PostgresContainer):
109109
assert len(result["data"]) == 1
110110
assert result["data"][0] == [
111111
"2024-01-01 23:59:59.000000",
112-
"2024-01-01 23:59:59.000000",
112+
"2024-01-01 23:59:59.000000 UTC",
113113
"1_370",
114114
370,
115115
"1996-01-02",

ibis-server/tools/mdl_validation.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
for model in mdl["models"]:
3535
for column in model["columns"]:
3636
# ignore hidden columns
37-
if column.get("is_hidden"):
37+
if column.get("isHidden"):
3838
continue
3939
sql = f"select \"{column['name']}\" from \"{model['name']}\""
4040
try:

wren-core/core/src/logical_plan/analyze/model_generation.rs

+19-7
Original file line number | Diff line number | Diff line change
@@ -15,17 +15,19 @@ use crate::logical_plan::analyze::plan::{
1515
use crate::logical_plan::utils::create_remote_table_source;
1616
use crate::mdl::manifest::Model;
1717
use crate::mdl::utils::quoted;
18-
use crate::mdl::AnalyzedWrenMDL;
18+
use crate::mdl::{AnalyzedWrenMDL, SessionStateRef};
1919

2020
/// [ModelGenerationRule] is responsible for generating the model plan node.
2121
pub struct ModelGenerationRule {
2222
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
23+
session_state: SessionStateRef,
2324
}
2425

2526
impl ModelGenerationRule {
26-
pub fn new(mdl: Arc<AnalyzedWrenMDL>) -> Self {
27+
pub fn new(mdl: Arc<AnalyzedWrenMDL>, session_state: SessionStateRef) -> Self {
2728
Self {
2829
analyzed_wren_mdl: mdl,
30+
session_state,
2931
}
3032
}
3133

@@ -39,7 +41,10 @@ impl ModelGenerationRule {
3941
extension.node.as_any().downcast_ref::<ModelPlanNode>()
4042
{
4143
let source_plan = model_plan.relation_chain.clone().plan(
42-
ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
44+
ModelGenerationRule::new(
45+
Arc::clone(&self.analyzed_wren_mdl),
46+
Arc::clone(&self.session_state),
47+
),
4348
)?;
4449
let result = match source_plan {
4550
Some(plan) => {
@@ -73,9 +78,10 @@ impl ModelGenerationRule {
7378
LogicalPlanBuilder::scan_with_filters(
7479
TableReference::from(model.table_reference()),
7580
create_remote_table_source(
76-
&model,
81+
Arc::clone(&model),
7782
&self.analyzed_wren_mdl.wren_mdl(),
78-
),
83+
Arc::clone(&self.session_state),
84+
)?,
7985
None,
8086
original_scan.filters.clone(),
8187
).expect("Failed to create table scan")
@@ -89,7 +95,10 @@ impl ModelGenerationRule {
8995
None => {
9096
LogicalPlanBuilder::scan(
9197
TableReference::from(model.table_reference()),
92-
create_remote_table_source(&model, &self.analyzed_wren_mdl.wren_mdl()),
98+
create_remote_table_source(
99+
Arc::clone(&model),
100+
&self.analyzed_wren_mdl.wren_mdl(),
101+
Arc::clone(&self.session_state))?,
93102
None,
94103
).expect("Failed to create table scan")
95104
.project(model_plan.required_exprs.clone())?
@@ -111,7 +120,10 @@ impl ModelGenerationRule {
111120
.downcast_ref::<CalculationPlanNode>(
112121
) {
113122
let source_plan = calculation_plan.relation_chain.clone().plan(
114-
ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
123+
ModelGenerationRule::new(
124+
Arc::clone(&self.analyzed_wren_mdl),
125+
Arc::clone(&self.session_state),
126+
),
115127
)?;
116128

117129
if let Expr::Alias(alias) = calculation_plan.measures[0].clone() {

wren-core/core/src/logical_plan/utils.rs

+19-28
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,10 @@
1+
use crate::mdl::lineage::DatasetLink;
2+
use crate::mdl::utils::quoted;
3+
use crate::mdl::{
4+
manifest::{Column, Model},
5+
WrenMDL,
6+
};
7+
use crate::mdl::{Dataset, SessionStateRef};
18
use datafusion::arrow::datatypes::{
29
DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit,
310
};
@@ -12,14 +19,6 @@ use petgraph::Graph;
1219
use std::collections::HashSet;
1320
use std::{collections::HashMap, sync::Arc};
1421

15-
use crate::mdl::lineage::DatasetLink;
16-
use crate::mdl::utils::quoted;
17-
use crate::mdl::{
18-
manifest::{Column, Model},
19-
WrenMDL,
20-
};
21-
use crate::mdl::{Dataset, SessionStateRef};
22-
2322
fn create_mock_list_type() -> DataType {
2423
let string_filed = Arc::new(Field::new("string", DataType::Utf8, false));
2524
DataType::List(string_filed)
@@ -112,28 +111,20 @@ pub fn create_schema(columns: Vec<Arc<Column>>) -> Result<SchemaRef> {
112111
)))
113112
}
114113

115-
pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc<dyn TableSource> {
114+
pub fn create_remote_table_source(
115+
model: Arc<Model>,
116+
mdl: &WrenMDL,
117+
session_state_ref: SessionStateRef,
118+
) -> Result<Arc<dyn TableSource>> {
116119
if let Some(table_provider) = mdl.get_table(model.table_reference()) {
117-
Arc::new(DefaultTableSource::new(table_provider))
120+
Ok(Arc::new(DefaultTableSource::new(table_provider)))
118121
} else {
119-
let fields: Vec<Field> = model
120-
.get_physical_columns()
121-
.iter()
122-
.map(|column| {
123-
let column = Arc::clone(column);
124-
let name = if let Some(ref expression) = column.expression {
125-
expression.clone()
126-
} else {
127-
column.name.clone()
128-
};
129-
// TODO: find a way for the remote table to provide the data type
130-
// We don't know the data type of the remote table, so we just mock a Int32 type here
131-
Field::new(name, DataType::Int8, column.not_null)
132-
})
133-
.collect();
134-
135-
let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new()));
136-
Arc::new(LogicalTableSource::new(schema))
122+
let dataset = Dataset::Model(model);
123+
let schema = dataset
124+
.to_remote_schema(Some(mdl.get_register_tables()), session_state_ref)?;
125+
Ok(Arc::new(LogicalTableSource::new(Arc::new(
126+
schema.as_arrow().clone(),
127+
))))
137128
}
138129
}
139130

wren-core/core/src/mdl/context.rs

+10-10
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ use datafusion::optimizer::push_down_filter::PushDownFilter;
3939
use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
4040
use datafusion::optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
4141
use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
42-
use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
4342
use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
4443
use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
4544
use datafusion::optimizer::OptimizerRule;
@@ -77,10 +76,13 @@ pub async fn create_ctx_with_mdl(
7776
Arc::clone(&reset_default_catalog_schema),
7877
)),
7978
Arc::new(ModelAnalyzeRule::new(
79+
Arc::clone(&analyzed_mdl),
80+
Arc::clone(&reset_default_catalog_schema),
81+
)),
82+
Arc::new(ModelGenerationRule::new(
8083
Arc::clone(&analyzed_mdl),
8184
reset_default_catalog_schema,
8285
)),
83-
Arc::new(ModelGenerationRule::new(Arc::clone(&analyzed_mdl))),
8486
Arc::new(InlineTableScan::new()),
8587
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
8688
Arc::new(ExpandWildcardRule::new()),
@@ -106,17 +108,16 @@ pub async fn create_ctx_with_mdl(
106108
fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
107109
vec![
108110
Arc::new(EliminateNestedUnion::new()),
109-
Arc::new(SimplifyExpressions::new()),
111+
// Disable SimplifyExpressions to avoid applying some functions locally
112+
// Arc::new(SimplifyExpressions::new()),
110113
Arc::new(UnwrapCastInComparison::new()),
111114
Arc::new(ReplaceDistinctWithAggregate::new()),
112115
Arc::new(EliminateJoin::new()),
113116
Arc::new(DecorrelatePredicateSubquery::new()),
114117
Arc::new(ScalarSubqueryToJoin::new()),
115118
Arc::new(ExtractEquijoinPredicate::new()),
116-
// simplify expressions does not simplify expressions in subqueries, so we
117-
// run it again after running the optimizations that potentially converted
118-
// subqueries to joins
119-
Arc::new(SimplifyExpressions::new()),
119+
// Disable SimplifyExpressions to avoid applying some functions locally
120+
// Arc::new(SimplifyExpressions::new()),
120121
Arc::new(RewriteDisjunctivePredicate::new()),
121122
Arc::new(EliminateDuplicatedExpr::new()),
122123
Arc::new(EliminateFilter::new()),
@@ -133,9 +134,8 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
133134
// Arc::new(PushDownLimit::new()),
134135
Arc::new(PushDownFilter::new()),
135136
Arc::new(SingleDistinctToGroupBy::new()),
136-
// The previous optimizations added expressions and projections,
137-
// that might benefit from the following rules
138-
Arc::new(SimplifyExpressions::new()),
137+
// Disable SimplifyExpressions to avoid applying some functions locally
138+
// Arc::new(SimplifyExpressions::new()),
139139
Arc::new(UnwrapCastInComparison::new()),
140140
Arc::new(CommonSubexprEliminate::new()),
141141
Arc::new(EliminateGroupByConstant::new()),

wren-core/core/src/mdl/dataset.rs

+1-2
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@ use crate::logical_plan::utils::map_data_type;
22
use crate::mdl::manifest::{Column, Metric, Model};
33
use crate::mdl::utils::quoted;
44
use crate::mdl::{RegisterTables, SessionStateRef};
5-
use datafusion::arrow::datatypes::DataType::Utf8;
65
use datafusion::arrow::datatypes::Field;
76
use datafusion::common::DFSchema;
87
use datafusion::common::Result;
@@ -75,7 +74,7 @@ impl Column {
7574
let columns = Self::collect_columns(expr);
7675
Ok(columns
7776
.into_iter()
78-
.map(|c| Field::new(c.value, Utf8, false))
77+
.map(|c| Field::new(c.value, map_data_type(&self.r#type), false))
7978
.collect())
8079
} else {
8180
Ok(vec![self.to_field()])

wren-core/core/src/mdl/mod.rs

+51-1
Original file line number | Diff line number | Diff line change
@@ -434,7 +434,9 @@ mod test {
434434
use crate::mdl::function::RemoteFunction;
435435
use crate::mdl::manifest::Manifest;
436436
use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
437-
use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
437+
use datafusion::arrow::array::{
438+
ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray,
439+
};
438440
use datafusion::common::not_impl_err;
439441
use datafusion::common::Result;
440442
use datafusion::prelude::SessionContext;
@@ -810,6 +812,51 @@ mod test {
810812
Ok(())
811813
}
812814

815+
#[tokio::test]
816+
async fn test_disable_simplify_expression() -> Result<()> {
817+
let sql = "select current_date";
818+
let actual = transform_sql_with_ctx(
819+
&SessionContext::new(),
820+
Arc::new(AnalyzedWrenMDL::default()),
821+
&[],
822+
sql,
823+
)
824+
.await?;
825+
assert_eq!(actual, "SELECT current_date()");
826+
Ok(())
827+
}
828+
829+
/// This test will fail if `出道時間` is not correctly inferred as a timestamp column.
830+
#[tokio::test]
831+
async fn test_infer_timestamp_column() -> Result<()> {
832+
let ctx = SessionContext::new();
833+
ctx.register_batch("artist", artist())?;
834+
let manifest = ManifestBuilder::new()
835+
.catalog("wren")
836+
.schema("test")
837+
.model(
838+
ModelBuilder::new("artist")
839+
.table_reference("artist")
840+
.column(ColumnBuilder::new("出道時間", "timestamp").build())
841+
.build(),
842+
)
843+
.build();
844+
845+
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
846+
let sql = r#"select current_date > "出道時間" from wren.test.artist"#;
847+
let actual = transform_sql_with_ctx(
848+
&SessionContext::new(),
849+
Arc::clone(&analyzed_mdl),
850+
&[],
851+
sql,
852+
)
853+
.await?;
854+
assert_eq!(actual,
855+
"SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \
856+
(SELECT artist.\"出道時間\" FROM (SELECT artist.\"出道時間\" AS \"出道時間\" FROM artist) AS artist) AS artist");
857+
Ok(())
858+
}
859+
813860
async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
814861
let ctx = SessionContext::new();
815862
// For roundtrip testing, we should register the mock table for the planned SQL.
@@ -873,10 +920,13 @@ mod test {
873920
Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"]));
874921
let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"]));
875922
let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300]));
923+
let debut_time: ArrayRef =
924+
Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3]));
876925
RecordBatch::try_from_iter(vec![
877926
("名字", name),
878927
("組別", group),
879928
("訂閱數", subscribe),
929+
("出道時間", debut_time),
880930
])
881931
.unwrap()
882932
}

0 commit comments

Comments
 (0)