diff --git a/ibis-server/tests/routers/v3/connector/test_postgres.py b/ibis-server/tests/routers/v3/connector/test_postgres.py index 8961b3483..e05bf3d44 100644 --- a/ibis-server/tests/routers/v3/connector/test_postgres.py +++ b/ibis-server/tests/routers/v3/connector/test_postgres.py @@ -109,7 +109,7 @@ def test_query(manifest_str, postgres: PostgresContainer): assert len(result["data"]) == 1 assert result["data"][0] == [ "2024-01-01 23:59:59.000000", - "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", "1_370", 370, "1996-01-02", diff --git a/ibis-server/tools/mdl_validation.py b/ibis-server/tools/mdl_validation.py index d8d4dc5a0..5b27e0409 100644 --- a/ibis-server/tools/mdl_validation.py +++ b/ibis-server/tools/mdl_validation.py @@ -34,7 +34,7 @@ for model in mdl["models"]: for column in model["columns"]: # ignore hidden columns - if column.get("is_hidden"): + if column.get("isHidden"): continue sql = f"select \"{column['name']}\" from \"{model['name']}\"" try: diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 6a6291d45..603e0137a 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -15,17 +15,19 @@ use crate::logical_plan::analyze::plan::{ use crate::logical_plan::utils::create_remote_table_source; use crate::mdl::manifest::Model; use crate::mdl::utils::quoted; -use crate::mdl::AnalyzedWrenMDL; +use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; /// [ModelGenerationRule] is responsible for generating the model plan node. 
pub struct ModelGenerationRule { analyzed_wren_mdl: Arc, + session_state: SessionStateRef, } impl ModelGenerationRule { - pub fn new(mdl: Arc) -> Self { + pub fn new(mdl: Arc, session_state: SessionStateRef) -> Self { Self { analyzed_wren_mdl: mdl, + session_state, } } @@ -39,7 +41,10 @@ impl ModelGenerationRule { extension.node.as_any().downcast_ref::() { let source_plan = model_plan.relation_chain.clone().plan( - ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + ModelGenerationRule::new( + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + ), )?; let result = match source_plan { Some(plan) => { @@ -73,9 +78,10 @@ impl ModelGenerationRule { LogicalPlanBuilder::scan_with_filters( TableReference::from(model.table_reference()), create_remote_table_source( - &model, + Arc::clone(&model), &self.analyzed_wren_mdl.wren_mdl(), - ), + Arc::clone(&self.session_state), + )?, None, original_scan.filters.clone(), ).expect("Failed to create table scan") @@ -89,7 +95,10 @@ impl ModelGenerationRule { None => { LogicalPlanBuilder::scan( TableReference::from(model.table_reference()), - create_remote_table_source(&model, &self.analyzed_wren_mdl.wren_mdl()), + create_remote_table_source( + Arc::clone(&model), + &self.analyzed_wren_mdl.wren_mdl(), + Arc::clone(&self.session_state))?, None, ).expect("Failed to create table scan") .project(model_plan.required_exprs.clone())? 
@@ -111,7 +120,10 @@ impl ModelGenerationRule { .downcast_ref::( ) { let source_plan = calculation_plan.relation_chain.clone().plan( - ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + ModelGenerationRule::new( + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + ), )?; if let Expr::Alias(alias) = calculation_plan.measures[0].clone() { diff --git a/wren-core/core/src/logical_plan/utils.rs b/wren-core/core/src/logical_plan/utils.rs index 667e737e6..36c241d7e 100644 --- a/wren-core/core/src/logical_plan/utils.rs +++ b/wren-core/core/src/logical_plan/utils.rs @@ -1,3 +1,10 @@ +use crate::mdl::lineage::DatasetLink; +use crate::mdl::utils::quoted; +use crate::mdl::{ + manifest::{Column, Model}, + WrenMDL, +}; +use crate::mdl::{Dataset, SessionStateRef}; use datafusion::arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; @@ -12,14 +19,6 @@ use petgraph::Graph; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; -use crate::mdl::lineage::DatasetLink; -use crate::mdl::utils::quoted; -use crate::mdl::{ - manifest::{Column, Model}, - WrenMDL, -}; -use crate::mdl::{Dataset, SessionStateRef}; - fn create_mock_list_type() -> DataType { let string_filed = Arc::new(Field::new("string", DataType::Utf8, false)); DataType::List(string_filed) @@ -112,28 +111,20 @@ pub fn create_schema(columns: Vec>) -> Result { ))) } -pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc { +pub fn create_remote_table_source( + model: Arc, + mdl: &WrenMDL, + session_state_ref: SessionStateRef, +) -> Result> { if let Some(table_provider) = mdl.get_table(model.table_reference()) { - Arc::new(DefaultTableSource::new(table_provider)) + Ok(Arc::new(DefaultTableSource::new(table_provider))) } else { - let fields: Vec = model - .get_physical_columns() - .iter() - .map(|column| { - let column = Arc::clone(column); - let name = if let Some(ref expression) = column.expression { - 
expression.clone() - } else { - column.name.clone() - }; - // TODO: find a way for the remote table to provide the data type - // We don't know the data type of the remote table, so we just mock a Int32 type here - Field::new(name, DataType::Int8, column.not_null) - }) - .collect(); - - let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); - Arc::new(LogicalTableSource::new(schema)) + let dataset = Dataset::Model(model); + let schema = dataset + .to_remote_schema(Some(mdl.get_register_tables()), session_state_ref)?; + Ok(Arc::new(LogicalTableSource::new(Arc::new( + schema.as_arrow().clone(), + )))) } } diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 92800f191..b87d0690f 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -39,7 +39,6 @@ use datafusion::optimizer::push_down_filter::PushDownFilter; use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use datafusion::optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion::optimizer::simplify_expressions::SimplifyExpressions; use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion::optimizer::OptimizerRule; @@ -77,10 +76,13 @@ pub async fn create_ctx_with_mdl( Arc::clone(&reset_default_catalog_schema), )), Arc::new(ModelAnalyzeRule::new( + Arc::clone(&analyzed_mdl), + Arc::clone(&reset_default_catalog_schema), + )), + Arc::new(ModelGenerationRule::new( Arc::clone(&analyzed_mdl), reset_default_catalog_schema, )), - Arc::new(ModelGenerationRule::new(Arc::clone(&analyzed_mdl))), Arc::new(InlineTableScan::new()), // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. 
Arc::new(ExpandWildcardRule::new()), @@ -106,17 +108,16 @@ pub async fn create_ctx_with_mdl( fn optimize_rule_for_unparsing() -> Vec> { vec![ Arc::new(EliminateNestedUnion::new()), - Arc::new(SimplifyExpressions::new()), + // Disable SimplifyExpressions to avoid applying some functions locally + // Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - // simplify expressions does not simplify expressions in subqueries, so we - // run it again after running the optimizations that potentially converted - // subqueries to joins - Arc::new(SimplifyExpressions::new()), + // Disable SimplifyExpressions to avoid applying some functions locally + // Arc::new(SimplifyExpressions::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -133,9 +134,8 @@ fn optimize_rule_for_unparsing() -> Vec> { // Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), Arc::new(SingleDistinctToGroupBy::new()), - // The previous optimizations added expressions and projections, - // that might benefit from the following rules - Arc::new(SimplifyExpressions::new()), + // Disable SimplifyExpressions to avoid applying some functions locally + // Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateGroupByConstant::new()), diff --git a/wren-core/core/src/mdl/dataset.rs b/wren-core/core/src/mdl/dataset.rs index a46b9d64c..6a0eee477 100644 --- a/wren-core/core/src/mdl/dataset.rs +++ b/wren-core/core/src/mdl/dataset.rs @@ -2,7 +2,6 @@ use crate::logical_plan::utils::map_data_type; use crate::mdl::manifest::{Column, Metric, Model}; use crate::mdl::utils::quoted; use crate::mdl::{RegisterTables, SessionStateRef} 
-use datafusion::arrow::datatypes::DataType::Utf8; use datafusion::arrow::datatypes::Field; use datafusion::common::DFSchema; use datafusion::common::Result; @@ -75,7 +74,7 @@ impl Column { let columns = Self::collect_columns(expr); Ok(columns .into_iter() - .map(|c| Field::new(c.value, Utf8, false)) + .map(|c| Field::new(c.value, map_data_type(&self.r#type), false)) .collect()) } else { Ok(vec![self.to_field()]) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index cb682e360..d768caa81 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -434,7 +434,9 @@ mod test { use crate::mdl::function::RemoteFunction; use crate::mdl::manifest::Manifest; use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL}; - use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray}; + use datafusion::arrow::array::{ ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray, }; use datafusion::common::not_impl_err; use datafusion::common::Result; use datafusion::prelude::SessionContext; @@ -810,6 +812,51 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_disable_simplify_expression() -> Result<()> { + let sql = "select current_date"; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::new(AnalyzedWrenMDL::default()), + &[], + sql, + ) + .await?; + assert_eq!(actual, "SELECT current_date()"); + Ok(()) + } + + /// This test will fail if the `出道時間` is not inferred as a timestamp column correctly. 
+ #[tokio::test] + async fn test_infer_timestamp_column() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_batch("artist", artist())?; + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("artist") + .table_reference("artist") + .column(ColumnBuilder::new("出道時間", "timestamp").build()) + .build(), + ) + .build(); + + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = r#"select current_date > "出道時間" from wren.test.artist"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \ + (SELECT artist.\"出道時間\" FROM (SELECT artist.\"出道時間\" AS \"出道時間\" FROM artist) AS artist) AS artist"); + Ok(()) + } + async fn assert_sql_valid_executable(sql: &str) -> Result<()> { let ctx = SessionContext::new(); // To roundtrip testing, we should register the mock table for the planned sql. @@ -873,10 +920,13 @@ mod test { Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"])); let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"])); let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300])); + let debut_time: ArrayRef = + Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3])); RecordBatch::try_from_iter(vec![ ("名字", name), ("組別", group), ("訂閱數", subscribe), + ("出道時間", debut_time), ]) .unwrap() }