Skip to content

fix(core): fix remote column inference and disable the simplify expression rule #874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_query(manifest_str, postgres: PostgresContainer):
assert len(result["data"]) == 1
assert result["data"][0] == [
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"1_370",
370,
"1996-01-02",
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tools/mdl_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
for model in mdl["models"]:
for column in model["columns"]:
# ignore hidden columns
if column.get("is_hidden"):
if column.get("isHidden"):
continue
sql = f"select \"{column['name']}\" from \"{model['name']}\""
try:
Expand Down
26 changes: 19 additions & 7 deletions wren-core/core/src/logical_plan/analyze/model_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ use crate::logical_plan::analyze::plan::{
use crate::logical_plan::utils::create_remote_table_source;
use crate::mdl::manifest::Model;
use crate::mdl::utils::quoted;
use crate::mdl::AnalyzedWrenMDL;
use crate::mdl::{AnalyzedWrenMDL, SessionStateRef};

/// [ModelGenerationRule] is responsible for generating the model plan node.
pub struct ModelGenerationRule {
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
session_state: SessionStateRef,
}

impl ModelGenerationRule {
pub fn new(mdl: Arc<AnalyzedWrenMDL>) -> Self {
pub fn new(mdl: Arc<AnalyzedWrenMDL>, session_state: SessionStateRef) -> Self {
Self {
analyzed_wren_mdl: mdl,
session_state,
}
}

Expand All @@ -39,7 +41,10 @@ impl ModelGenerationRule {
extension.node.as_any().downcast_ref::<ModelPlanNode>()
{
let source_plan = model_plan.relation_chain.clone().plan(
ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
ModelGenerationRule::new(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
),
)?;
let result = match source_plan {
Some(plan) => {
Expand Down Expand Up @@ -73,9 +78,10 @@ impl ModelGenerationRule {
LogicalPlanBuilder::scan_with_filters(
TableReference::from(model.table_reference()),
create_remote_table_source(
&model,
Arc::clone(&model),
&self.analyzed_wren_mdl.wren_mdl(),
),
Arc::clone(&self.session_state),
)?,
None,
original_scan.filters.clone(),
).expect("Failed to create table scan")
Expand All @@ -89,7 +95,10 @@ impl ModelGenerationRule {
None => {
LogicalPlanBuilder::scan(
TableReference::from(model.table_reference()),
create_remote_table_source(&model, &self.analyzed_wren_mdl.wren_mdl()),
create_remote_table_source(
Arc::clone(&model),
&self.analyzed_wren_mdl.wren_mdl(),
Arc::clone(&self.session_state))?,
None,
).expect("Failed to create table scan")
.project(model_plan.required_exprs.clone())?
Expand All @@ -111,7 +120,10 @@ impl ModelGenerationRule {
.downcast_ref::<CalculationPlanNode>(
) {
let source_plan = calculation_plan.relation_chain.clone().plan(
ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
ModelGenerationRule::new(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
),
)?;

if let Expr::Alias(alias) = calculation_plan.measures[0].clone() {
Expand Down
47 changes: 19 additions & 28 deletions wren-core/core/src/logical_plan/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
use crate::mdl::lineage::DatasetLink;
use crate::mdl::utils::quoted;
use crate::mdl::{
manifest::{Column, Model},
WrenMDL,
};
use crate::mdl::{Dataset, SessionStateRef};
use datafusion::arrow::datatypes::{
DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit,
};
Expand All @@ -12,14 +19,6 @@ use petgraph::Graph;
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};

use crate::mdl::lineage::DatasetLink;
use crate::mdl::utils::quoted;
use crate::mdl::{
manifest::{Column, Model},
WrenMDL,
};
use crate::mdl::{Dataset, SessionStateRef};

fn create_mock_list_type() -> DataType {
let string_filed = Arc::new(Field::new("string", DataType::Utf8, false));
DataType::List(string_filed)
Expand Down Expand Up @@ -112,28 +111,20 @@ pub fn create_schema(columns: Vec<Arc<Column>>) -> Result<SchemaRef> {
)))
}

pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc<dyn TableSource> {
pub fn create_remote_table_source(
model: Arc<Model>,
mdl: &WrenMDL,
session_state_ref: SessionStateRef,
) -> Result<Arc<dyn TableSource>> {
if let Some(table_provider) = mdl.get_table(model.table_reference()) {
Arc::new(DefaultTableSource::new(table_provider))
Ok(Arc::new(DefaultTableSource::new(table_provider)))
} else {
let fields: Vec<Field> = model
.get_physical_columns()
.iter()
.map(|column| {
let column = Arc::clone(column);
let name = if let Some(ref expression) = column.expression {
expression.clone()
} else {
column.name.clone()
};
// TODO: find a way for the remote table to provide the data type
// We don't know the data type of the remote table, so we just mock a Int32 type here
Field::new(name, DataType::Int8, column.not_null)
})
.collect();

let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new()));
Arc::new(LogicalTableSource::new(schema))
let dataset = Dataset::Model(model);
let schema = dataset
.to_remote_schema(Some(mdl.get_register_tables()), session_state_ref)?;
Ok(Arc::new(LogicalTableSource::new(Arc::new(
schema.as_arrow().clone(),
))))
}
}

Expand Down
20 changes: 10 additions & 10 deletions wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ use datafusion::optimizer::push_down_filter::PushDownFilter;
use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use datafusion::optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
use datafusion::optimizer::OptimizerRule;
Expand Down Expand Up @@ -77,10 +76,13 @@ pub async fn create_ctx_with_mdl(
Arc::clone(&reset_default_catalog_schema),
)),
Arc::new(ModelAnalyzeRule::new(
Arc::clone(&analyzed_mdl),
Arc::clone(&reset_default_catalog_schema),
)),
Arc::new(ModelGenerationRule::new(
Arc::clone(&analyzed_mdl),
reset_default_catalog_schema,
)),
Arc::new(ModelGenerationRule::new(Arc::clone(&analyzed_mdl))),
Arc::new(InlineTableScan::new()),
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
Arc::new(ExpandWildcardRule::new()),
Expand All @@ -106,17 +108,16 @@ pub async fn create_ctx_with_mdl(
fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
vec![
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
// Disable SimplifyExpressions to avoid applying some functions locally
// Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(ExtractEquijoinPredicate::new()),
// simplify expressions does not simplify expressions in subqueries, so we
// run it again after running the optimizations that potentially converted
// subqueries to joins
Arc::new(SimplifyExpressions::new()),
// Disable SimplifyExpressions to avoid applying some functions locally
// Arc::new(SimplifyExpressions::new()),
Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(EliminateDuplicatedExpr::new()),
Arc::new(EliminateFilter::new()),
Expand All @@ -133,9 +134,8 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
// Arc::new(PushDownLimit::new()),
Arc::new(PushDownFilter::new()),
Arc::new(SingleDistinctToGroupBy::new()),
// The previous optimizations added expressions and projections,
// that might benefit from the following rules
Arc::new(SimplifyExpressions::new()),
// Disable SimplifyExpressions to avoid applying some functions locally
// Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateGroupByConstant::new()),
Expand Down
3 changes: 1 addition & 2 deletions wren-core/core/src/mdl/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::logical_plan::utils::map_data_type;
use crate::mdl::manifest::{Column, Metric, Model};
use crate::mdl::utils::quoted;
use crate::mdl::{RegisterTables, SessionStateRef};
use datafusion::arrow::datatypes::DataType::Utf8;
use datafusion::arrow::datatypes::Field;
use datafusion::common::DFSchema;
use datafusion::common::Result;
Expand Down Expand Up @@ -75,7 +74,7 @@ impl Column {
let columns = Self::collect_columns(expr);
Ok(columns
.into_iter()
.map(|c| Field::new(c.value, Utf8, false))
.map(|c| Field::new(c.value, map_data_type(&self.r#type), false))
.collect())
} else {
Ok(vec![self.to_field()])
Expand Down
52 changes: 51 additions & 1 deletion wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,9 @@ mod test {
use crate::mdl::function::RemoteFunction;
use crate::mdl::manifest::Manifest;
use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use datafusion::arrow::array::{
ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray,
};
use datafusion::common::not_impl_err;
use datafusion::common::Result;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -810,6 +812,51 @@ mod test {
Ok(())
}

/// Verify that `current_date` survives unparsing as a function call instead of
/// being constant-folded locally, i.e. the SimplifyExpressions rule stays disabled.
#[tokio::test]
async fn test_disable_simplify_expression() -> Result<()> {
    let ctx = SessionContext::new();
    let mdl = Arc::new(AnalyzedWrenMDL::default());
    // With SimplifyExpressions disabled, the planner must not evaluate
    // `current_date` eagerly; it should round-trip as a function call.
    let transformed =
        transform_sql_with_ctx(&ctx, mdl, &[], "select current_date").await?;
    assert_eq!(transformed, "SELECT current_date()");
    Ok(())
}

/// This test fails if the `出道時間` column is not inferred as a timestamp
/// column correctly (the declared MDL type must reach the remote schema).
#[tokio::test]
async fn test_infer_timestamp_column() -> Result<()> {
    let ctx = SessionContext::new();
    // Register an in-memory batch named "artist" as the backing table.
    // NOTE(review): a *fresh* SessionContext is passed to
    // `transform_sql_with_ctx` below, so this registration looks unused —
    // confirm whether `&ctx` was intended there instead.
    ctx.register_batch("artist", artist())?;
    // One-model manifest whose only column is declared as "timestamp".
    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(
            ModelBuilder::new("artist")
                .table_reference("artist")
                .column(ColumnBuilder::new("出道時間", "timestamp").build())
                .build(),
        )
        .build();

    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
    let sql = r#"select current_date > "出道時間" from wren.test.artist"#;
    let actual = transform_sql_with_ctx(
        &SessionContext::new(),
        Arc::clone(&analyzed_mdl),
        &[],
        sql,
    )
    .await?;
    // The comparison must cast `current_date()` to TIMESTAMP — that only
    // happens when the column's timestamp type is inferred correctly.
    assert_eq!(actual,
        "SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \
(SELECT artist.\"出道時間\" FROM (SELECT artist.\"出道時間\" AS \"出道時間\" FROM artist) AS artist) AS artist");
    Ok(())
}

async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
let ctx = SessionContext::new();
// To roundtrip testing, we should register the mock table for the planned sql.
Expand Down Expand Up @@ -873,10 +920,13 @@ mod test {
Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"]));
let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"]));
let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300]));
let debut_time: ArrayRef =
Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3]));
RecordBatch::try_from_iter(vec![
("名字", name),
("組別", group),
("訂閱數", subscribe),
("出道時間", debut_time),
])
.unwrap()
}
Expand Down
Loading