diff --git a/ibis-server/tests/routers/v3/connector/test_postgres.py b/ibis-server/tests/routers/v3/connector/test_postgres.py index 098eac7d4..13960e115 100644 --- a/ibis-server/tests/routers/v3/connector/test_postgres.py +++ b/ibis-server/tests/routers/v3/connector/test_postgres.py @@ -27,19 +27,17 @@ "table": "orders", }, "columns": [ - {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, - {"name": "custkey", "expression": "o_custkey", "type": "integer"}, + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, { - "name": "orderstatus", - "expression": "o_orderstatus", + "name": "o_orderstatus", "type": "varchar", }, { - "name": "totalprice", - "expression": "o_totalprice", + "name": "o_totalprice", "type": "double", }, - {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, + {"name": "o_orderdate", "type": "date"}, { "name": "order_cust_key", "expression": "concat(o_orderkey, '_', o_custkey)", @@ -56,7 +54,7 @@ "type": "timestamp", }, ], - "primaryKey": "orderkey", + "primaryKey": "o_orderkey", }, { "name": "customer", @@ -65,10 +63,10 @@ "table": "customer", }, "columns": [ - {"name": "custkey", "expression": "c_custkey", "type": "integer"}, - {"name": "name", "expression": "c_name", "type": "varchar"}, + {"name": "c_custkey", "type": "integer"}, + {"name": "c_name", "type": "varchar"}, ], - "primaryKey": "custkey", + "primaryKey": "c_custkey", }, ], } @@ -109,17 +107,17 @@ def test_query(postgres: PostgresContainer): "2024-01-01 23:59:59.000000 UTC", "1_370", 370, - 1, "1996-01-02", + 1, "O", "172799.49", ] assert result["dtypes"] == { - "orderkey": "int32", - "custkey": "int32", - "orderstatus": "object", - "totalprice": "object", - "orderdate": "object", + "o_orderkey": "int32", + "o_custkey": "int32", + "o_orderstatus": "object", + "o_totalprice": "object", + "o_orderdate": "object", "order_cust_key": "object", "timestamp": "object", "timestamptz": "object", @@ -270,7 +268,7 @@ def test_validate_with_unknown_rule(postgres: PostgresContainer): json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "parameters": {"modelName": "orders", "columnName": "orderkey"}, + "parameters": {"modelName": "orders", "columnName": "o_orderkey"}, }, ) assert response.status_code == 422 @@ -287,7 +285,7 @@ def test_validate_rule_column_is_valid(postgres: PostgresContainer): json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "parameters": {"modelName": "orders", "columnName": "orderkey"}, + "parameters": {"modelName": "orders", "columnName": "o_orderkey"}, }, ) assert response.status_code == 204 @@ -302,7 +300,7 @@ def test_validate_rule_column_is_valid_with_invalid_parameters( json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, + "parameters": {"modelName": "X", "columnName": "o_orderkey"}, }, ) assert response.status_code == 422 @@ -352,7 +350,7 @@ def test_validate_rule_column_is_valid_without_one_parameter( json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, + "parameters": {"columnName": "o_orderkey"}, }, ) assert response.status_code == 422 @@ -364,7 +362,7 @@ def test_dry_plan(): url=f"{base_url}/dry-plan", json={ "manifestStr": manifest_str, - "sql": "SELECT orderkey, order_cust_key FROM wren.public.orders LIMIT 1", + "sql": "SELECT o_orderkey, order_cust_key FROM wren.public.orders LIMIT 1", }, ) assert response.status_code == 200 diff --git 
a/wren-modeling-py/src/lib.rs b/wren-modeling-py/src/lib.rs index 51855a2b4..ddac8934b 100644 --- a/wren-modeling-py/src/lib.rs +++ b/wren-modeling-py/src/lib.rs @@ -57,10 +57,10 @@ mod tests { "table": "customer" }, "columns": [ - {"name": "custkey", "expression": "c_custkey", "type": "integer"}, - {"name": "name", "expression": "c_name", "type": "varchar"} + {"name": "c_custkey", "type": "integer"}, + {"name": "c_name", "type": "varchar"} ], - "primaryKey": "custkey" + "primaryKey": "c_custkey" } ] }"#; @@ -71,7 +71,7 @@ mod tests { .unwrap(); assert_eq!( transformed_sql, - r#"SELECT * FROM (SELECT main.customer.c_custkey AS custkey, main.customer.c_name AS "name" FROM main.customer) AS customer"# + r#"SELECT * FROM (SELECT main.customer.c_custkey AS c_custkey, main.customer.c_name AS c_name FROM main.customer) AS customer"# ); } } diff --git a/wren-modeling-py/tests/test_modeling_core.py b/wren-modeling-py/tests/test_modeling_core.py index 9933204fa..45e943977 100644 --- a/wren-modeling-py/tests/test_modeling_core.py +++ b/wren-modeling-py/tests/test_modeling_core.py @@ -14,10 +14,10 @@ "table": "customer", }, "columns": [ - {"name": "custkey", "expression": "c_custkey", "type": "integer"}, - {"name": "name", "expression": "c_name", "type": "varchar"}, + {"name": "c_custkey", "type": "integer"}, + {"name": "c_name", "type": "varchar"}, ], - "primaryKey": "custkey", + "primaryKey": "c_custkey", }, ], } @@ -30,5 +30,5 @@ def test_transform_sql(): rewritten_sql = wren_core.transform_sql(manifest_str, sql) assert ( rewritten_sql - == 'SELECT * FROM (SELECT main.customer.c_custkey AS custkey, main.customer.c_name AS "name" FROM main.customer) AS customer' + == 'SELECT * FROM (SELECT main.customer.c_custkey AS c_custkey, main.customer.c_name AS c_name FROM main.customer) AS customer' ) diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs b/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs index b1467dc6b..d97c59dfc 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs @@ -28,7 +28,7 @@ use std::sync::Arc; /// 3. Remove the catalog and schema prefix of Wren for the column and refresh the schema. (top-down) /// /// The traverse path of step 1 and step 2 should be same. -/// The corresponding scope will be pushed to or popped from the scope_queue sequentially. +/// The corresponding scope will be pushed to or popped from the `childs` of [Scope] sequentially. pub struct ModelAnalyzeRule { analyzed_wren_mdl: Arc<AnalyzedWrenMDL>, session_state: SessionStateRef, } impl AnalyzerRule for ModelAnalyzeRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> { - let scope_queue = RefCell::new(VecDeque::new()); let root = RefCell::new(Scope::new()); - self.analyze_scope(plan, &root, &scope_queue)? - .map_data(|plan| self.analyze_model(plan, &root, &scope_queue).data())? + self.analyze_scope(plan, &root)? + .map_data(|plan| self.analyze_model(plan, &root).data())?
.map_data(|plan| { plan.transform_up_with_subqueries(&|plan| -> Result< Transformed<LogicalPlan>, @@ -79,9 +78,8 @@ impl ModelAnalyzeRule { &self, plan: LogicalPlan, root: &RefCell<Scope>, - scope_queue: &RefCell<VecDeque<RefCell<Scope>>>, ) -> Result<Transformed<LogicalPlan>> { - plan.transform_up(&|plan| -> Result<Transformed<LogicalPlan>> { + plan.transform_up(&mut |plan| -> Result<Transformed<LogicalPlan>> { let plan = self.analyze_scope_internal(plan, root)?.data; plan.map_subqueries(|plan| { if let LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, }) = &plan { outer_ref_columns.iter().try_for_each(|expr| { - let mut scope_mut = root.borrow_mut(); - self.collect_required_column(expr.clone(), &mut scope_mut) + let mut root_mut = root.borrow_mut(); + self.collect_required_column(expr.clone(), &mut root_mut) })?; let child_scope = RefCell::new(Scope::new_child(RefCell::clone(root))); self.analyze_scope( Arc::unwrap_or_clone(Arc::clone(subquery)), &child_scope, - scope_queue, )?; - let mut scope_queue = scope_queue.borrow_mut(); - scope_queue.push_back(child_scope); + let mut root_mut = root.borrow_mut(); + root_mut.push_child(child_scope); } Ok(Transformed::no(plan)) }) @@ -245,10 +242,13 @@ impl ModelAnalyzeRule { .get_view(relation.table()) .is_none() { - scope.add_required_column( + let added = scope.add_required_column( relation.clone(), - Expr::Column(Column::new(Some(relation), name)), + Expr::Column(Column::new(Some(relation.clone()), name)), )?; + if !added { + return plan_err!("Relation {} isn't visited", relation); + } } } // It is possible that the column is a rebase column from the aggregation or join @@ -274,22 +274,20 @@ impl ModelAnalyzeRule { &self, plan: LogicalPlan, root: &RefCell<Scope>, - scope_queue: &RefCell<VecDeque<RefCell<Scope>>>, ) -> Result<Transformed<LogicalPlan>> { - plan.transform_up(&|plan| -> Result<Transformed<LogicalPlan>> { - let plan = self.analyze_model_internal(plan, root, scope_queue)?.data; + plan.transform_up(&mut |plan| -> Result<Transformed<LogicalPlan>> { + let plan = self.analyze_model_internal(plan, root)?.data; // If the plan contains subquery, we should analyze the subquery recursively + let mut root = root.borrow_mut(); plan.map_subqueries(|plan| { if let LogicalPlan::Subquery(subquery) = &plan { - let mut scope_queue_mut = scope_queue.borrow_mut(); - let Some(child_scope) = scope_queue_mut.pop_front() else { + let Some(child_scope) = root.pop_child() else { return internal_err!("No child scope found for subquery"); }; let transformed = self .analyze_model( Arc::unwrap_or_clone(Arc::clone(&subquery.subquery)), &child_scope, - scope_queue, )? .data; return Ok(Transformed::yes(LogicalPlan::Subquery( @@ -306,7 +304,6 @@ impl ModelAnalyzeRule { &self, plan: LogicalPlan, scope: &RefCell<Scope>, - scope_queue: &RefCell<VecDeque<RefCell<Scope>>>, ) -> Result<Transformed<LogicalPlan>> { match plan { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { @@ -315,63 +312,8 @@ impl ModelAnalyzeRule { // SubqueryAlias -> SubqueryAlias -> Extension -> ModelPlanNode // to get the correct required columns match Arc::unwrap_or_clone(Arc::clone(&input)) { - LogicalPlan::SubqueryAlias(SubqueryAlias { input, ..
}) => { - if let LogicalPlan::Extension(Extension { node }) = - Arc::unwrap_or_clone(Arc::clone(&input)) - { - if let Some(model_node) = - node.as_any().downcast_ref::<ModelPlanNode>() - { - if let Some(model) = self - .analyzed_wren_mdl - .wren_mdl() - .get_model(model_node.plan_name()) - { - let scope = scope.borrow(); - let field: Vec<Expr> = if let Some(used_columns) = - scope.try_get_required_columns(&alias) - { - used_columns.iter().cloned().collect() - } else { - // If the required columns are not found in the current scope but the table is visited, - // it could be a count(*) query - if scope.try_get_visited_dataset(&alias).is_none() - { - return internal_err!( - "Table {} not found in the visited dataset and required columns map", - alias); - }; - vec![] - }; - let model_plan = LogicalPlan::Extension(Extension { - node: Arc::new(ModelPlanNode::new( - Arc::clone(&model), - field, - None, - Arc::clone(&self.analyzed_wren_mdl), - Arc::clone(&self.session_state), - )?), - }); - let subquery = LogicalPlanBuilder::from(model_plan) - .alias(alias)? - .build()?; - Ok(Transformed::yes(subquery)) - } else { - internal_err!( - "Model {} not found in the WrenMDL", - model_node.plan_name() - ) - } - } else { - internal_err!( - "ModelPlanNode not found in the Extension node" - ) - } - } else { - Ok(Transformed::no(LogicalPlan::SubqueryAlias( - SubqueryAlias::try_new(input, alias)?, - ))) - } + LogicalPlan::SubqueryAlias(subquery_alias) => { + self.analyze_subquery_alias_model(subquery_alias, scope, alias) } LogicalPlan::TableScan(table_scan) => { let model_plan = self @@ -438,17 +380,6 @@ impl ModelAnalyzeRule { null_equals_null: join.null_equals_null, }))) } - LogicalPlan::Subquery(Subquery { subquery, .. }) => { - let mut scope_queue_mut = scope_queue.borrow_mut(); - let Some(child_scope) = scope_queue_mut.pop_front() else { - return internal_err!("No child scope found for subquery"); - }; - self.analyze_model( - Arc::unwrap_or_clone(subquery), - &child_scope, - scope_queue, - ) - } _ => Ok(Transformed::no(plan)), } } @@ -506,6 +437,69 @@ impl ModelAnalyzeRule { } } + /// Because the bottom-up transformation is used, the table_scan is already transformed + /// to the ModelPlanNode before the SubqueryAlias. We should check the pattern of a Wren-generated model plan like: + /// SubqueryAlias -> SubqueryAlias -> Extension -> ModelPlanNode + /// to get the correct required columns + fn analyze_subquery_alias_model( + &self, + subquery_alias: SubqueryAlias, + scope: &RefCell<Scope>, + alias: TableReference, + ) -> Result<Transformed<LogicalPlan>> { + let SubqueryAlias { input, ..
} = subquery_alias; + if let LogicalPlan::Extension(Extension { node }) = + Arc::unwrap_or_clone(Arc::clone(&input)) + { + if let Some(model_node) = node.as_any().downcast_ref::<ModelPlanNode>() { + if let Some(model) = self + .analyzed_wren_mdl + .wren_mdl() + .get_model(model_node.plan_name()) + { + let scope = scope.borrow(); + let field: Vec<Expr> = if let Some(used_columns) = + scope.try_get_required_columns(&alias) + { + used_columns.iter().cloned().collect() + } else { + // If the required columns are not found in the current scope but the table is visited, + // it could be a count(*) query + if scope.try_get_visited_dataset(&alias).is_none() { + return internal_err!( + "Table {} not found in the visited dataset and required columns map", + alias); + }; + vec![] + }; + let model_plan = LogicalPlan::Extension(Extension { + node: Arc::new(ModelPlanNode::new( + Arc::clone(&model), + field, + None, + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + )?), + }); + let subquery = + LogicalPlanBuilder::from(model_plan).alias(alias)?.build()?; + Ok(Transformed::yes(subquery)) + } else { + internal_err!( + "Model {} not found in the WrenMDL", + model_node.plan_name() + ) + } + } else { + internal_err!("ModelPlanNode not found in the Extension node") + } + } else { + Ok(Transformed::no(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(input, alias)?, + ))) + } + } + /// Remove the catalog and schema prefix of Wren for the column and refresh the schema. /// The plan created by DataFusion is always with the Wren prefix for the column name. /// Something like "wrenai.public.order_items_model.price". However, the model plan will be rewritten to a subquery alias @@ -760,6 +754,7 @@ pub struct Scope { visited_tables: HashSet<TableReference>, /// The parent scope parent: Option<Box<RefCell<Scope>>>, + childs: VecDeque<RefCell<Scope>>, } impl Scope { @@ -769,6 +764,7 @@ impl Scope { visited_dataset: HashMap::new(), visited_tables: HashSet::new(), parent: None, + childs: VecDeque::new(), } } @@ -778,31 +774,49 @@ impl Scope { visited_dataset: HashMap::new(), visited_tables: HashSet::new(), parent: Some(Box::new(parent)), + childs: VecDeque::new(), } } + pub fn pop_child(&mut self) -> Option<RefCell<Scope>> { + self.childs.pop_front() + } + + pub fn push_child(&mut self, child: RefCell<Scope>) { + self.childs.push_back(child); + } + + /// Add the required column to the scope and return true if the column is added successfully. + /// If the table doesn't exist in the current scope, try to add the column to the parent scope. + /// If the table is not visited by either the parent or the current scope, return false pub fn add_required_column( &mut self, table_ref: TableReference, expr: Expr, - ) -> Result<()> { - if self.visited_dataset.contains_key(&table_ref) { + ) -> Result<bool> { + let added = if self.visited_dataset.contains_key(&table_ref) { self.required_columns - .entry(table_ref) + .entry(table_ref.clone()) .or_default() .insert(expr); - Ok(()) + true } else if let Some(ref parent) = &self.parent { parent .clone() .borrow_mut() - .add_required_column(table_ref, expr)?; - Ok(()) - } else if self.visited_tables.contains(&table_ref) { + .add_required_column(table_ref.clone(), expr)?
+ } else { + false + }; + + if added { + Ok(true) + } else if self.try_get_visited_table(&table_ref).is_some() { // If the table is visited but the dataset is not found, it could be a subquery alias - Ok(()) + Ok(true) } else { - plan_err!("Table {} not found in the visited dataset", table_ref) + // the table is not visited by either the parent or the current scope + Ok(false) } } @@ -847,16 +861,19 @@ impl Scope { } } - pub fn try_get_visited_table(&self, table_ref: &TableReference) -> bool { + pub fn try_get_visited_table( + &self, + table_ref: &TableReference, + ) -> Option<TableReference> { if self.visited_tables.contains(table_ref) { - return true; + return Some(table_ref.clone()); } if let Some(ref parent) = &self.parent { let scope = parent.borrow(); scope.try_get_visited_table(table_ref) } else { - false + None } } } diff --git a/wren-modeling-rs/core/src/mdl/context.rs index 4c9b48047..55ddf6021 100644 --- a/wren-modeling-rs/core/src/mdl/context.rs +++ b/wren-modeling-rs/core/src/mdl/context.rs @@ -94,7 +94,7 @@ pub async fn register_table_with_mdl( Ok(()) } -struct WrenDataSource { +pub struct WrenDataSource { schema: SchemaRef, } @@ -103,6 +103,10 @@ impl WrenDataSource { let schema = create_schema(model.get_physical_columns().clone())?; Ok(Self { schema }) } + + pub fn new_with_schema(schema: SchemaRef) -> Self { + Self { schema } + } } #[async_trait] diff --git a/wren-modeling-rs/core/src/mdl/function.rs new file mode 100644 index 000000000..1b96c85fd --- /dev/null +++ b/wren-modeling-rs/core/src/mdl/function.rs @@ -0,0 +1,198 @@ +use datafusion::arrow::datatypes::DataType; +use datafusion::common::internal_err; +use datafusion::common::Result; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, + Signature, TypeSignature, Volatility, WindowUDFImpl, +}; +use std::any::Any; + +/// A scalar UDF that will be bypassed when planning the logical plan. +/// This is used to register a remote function with the context. The function should not be +/// invoked by DataFusion; it's only used to generate the logical plan and unparse it to SQL. +#[derive(Debug)] +pub struct ByPassScalarUDF { + name: String, + return_type: DataType, + signature: Signature, +} + +impl ByPassScalarUDF { + pub fn new(name: &str, return_type: DataType) -> Self { + Self { + name: name.to_string(), + return_type, + signature: Signature::one_of( + vec![ + TypeSignature::VariadicAny, + TypeSignature::Uniform(0, vec![]), + ], + Volatility::Volatile, + ), + } + } +} + +impl ScalarUDFImpl for ByPassScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { + internal_err!("This function should not be called") + } +} + +/// An aggregate UDF that will be bypassed when planning the logical plan. +/// See [ByPassScalarUDF] for more details.
+#[derive(Debug)] +pub struct ByPassAggregateUDF { + name: String, + return_type: DataType, + signature: Signature, +} + +impl ByPassAggregateUDF { + pub fn new(name: &str, return_type: DataType) -> Self { + Self { + name: name.to_string(), + return_type, + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ByPassAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + internal_err!("This function should not be called") + } +} + +/// A window UDF that will be bypassed when planning the logical plan. +/// See [ByPassScalarUDF] for more details. +#[derive(Debug)] +pub struct ByPassWindowFunction { + name: String, + return_type: DataType, + signature: Signature, +} + +impl ByPassWindowFunction { + pub fn new(name: &str, return_type: DataType) -> Self { + Self { + name: name.to_string(), + return_type, + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl WindowUDFImpl for ByPassWindowFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + internal_err!("This function should not be called") + } +} + +#[cfg(test)] +mod test { + use crate::mdl::function::{ + ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, + }; + use datafusion::arrow::datatypes::DataType; + use datafusion::common::Result; + use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; + use datafusion::prelude::SessionContext; + + #[tokio::test] + async fn test_by_pass_scalar_udf() -> Result<()> { + let udf = ByPassScalarUDF::new("date_diff", DataType::Int64); + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::new_from_impl(udf)); + + let plan = ctx + .sql("SELECT date_diff(1, 2)") + .await? + .into_unoptimized_plan(); + let expected = "Projection: date_diff(Int64(1), Int64(2))\n EmptyRelation"; + assert_eq!(format!("{plan}"), expected); + Ok(()) + } + + #[tokio::test] + async fn test_by_pass_agg_udf() -> Result<()> { + let udf = ByPassAggregateUDF::new("count_self", DataType::Int64); + let ctx = SessionContext::new(); + ctx.register_udaf(AggregateUDF::new_from_impl(udf)); + + let plan = ctx.sql("SELECT c2, count_self(*) FROM (VALUES (1,2), (2,3), (3,4)) t(c1, c2) GROUP BY 1").await?.into_unoptimized_plan(); + let expected = "Projection: t.c2, count_self(*)\ + \n Aggregate: groupBy=[[t.c2]], aggr=[[count_self(*)]]\ + \n SubqueryAlias: t\ + \n Projection: column1 AS c1, column2 AS c2\ + \n Values: (Int64(1), Int64(2)), (Int64(2), Int64(3)), (Int64(3), Int64(4))"; + assert_eq!(format!("{plan}"), expected); + Ok(()) + } + + #[tokio::test] + async fn test_by_pass_window_udf() -> Result<()> { + let udf = ByPassWindowFunction::new("custom_window", DataType::Int64); + let ctx = SessionContext::new(); + ctx.register_udwf(WindowUDF::new_from_impl(udf)); + + let plan = ctx + .sql("SELECT custom_window(1, 2) OVER ()") + .await?
+ .into_unoptimized_plan(); + let expected = "Projection: custom_window(Int64(1),Int64(2)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[custom_window(Int64(1), Int64(2)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n EmptyRelation"; + assert_eq!(format!("{plan}"), expected); + Ok(()) + } +} diff --git a/wren-modeling-rs/core/src/mdl/mod.rs index 07f5a62bd..8c7c4f8b5 100644 --- a/wren-modeling-rs/core/src/mdl/mod.rs +++ b/wren-modeling-rs/core/src/mdl/mod.rs @@ -1,27 +1,28 @@ +use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule; +use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule; +use crate::logical_plan::analyze::model_generation::ModelGenerationRule; +use crate::logical_plan::utils::from_qualified_name_str; +use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl, WrenDataSource}; +use crate::mdl::manifest::{Column, Manifest, Model, View}; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS; use datafusion::prelude::SessionContext; -use datafusion::sql::unparser::dialect::Dialect; +use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle}; use datafusion::sql::unparser::Unparser; -use log::{debug, info}; -use parking_lot::RwLock; -use std::{collections::HashMap, sync::Arc}; - -use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule; -use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule; -use crate::logical_plan::analyze::model_generation::ModelGenerationRule; -use crate::logical_plan::utils::from_qualified_name_str; -use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl}; -use crate::mdl::manifest::{Column, Manifest, Model, View}; +use datafusion::sql::TableReference; pub use dataset::Dataset; +use log::{debug, info}; use manifest::Relationship; +use parking_lot::RwLock; use regex::Regex; +use std::{collections::HashMap, sync::Arc}; pub mod builder; pub mod context; pub(crate) mod dataset; +pub mod function; pub mod lineage; pub mod manifest; pub mod utils; @@ -35,7 +36,7 @@ pub struct AnalyzedWrenMDL { impl AnalyzedWrenMDL { pub fn analyze(manifest: Manifest) -> Result<Self> { - let wren_mdl = Arc::new(WrenMDL::new(manifest)); + let wren_mdl = Arc::new(WrenMDL::new_and_register_table_ref(manifest)); let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?); Ok(AnalyzedWrenMDL { wren_mdl, lineage }) } @@ -135,6 +136,36 @@ impl WrenMDL { Arc::new(WrenMDL::new(manifest)) } + /// Create a WrenMDL from a manifest and register the table reference of the model as a remote table.
+ /// Every column that is not calculated and has no expression or relationship is treated as a physical column of the remote table + pub fn new_and_register_table_ref(manifest: Manifest) -> Self { + let mut mdl = WrenMDL::new(manifest); + let sources: Vec<_> = mdl + .models() + .iter() + .map(|model| { + let name = TableReference::from(&model.table_reference); + let fields: Vec<_> = model + .columns + .iter() + .filter(|column| { + !column.is_calculated + && column.expression.is_none() + && column.relationship.is_none() + }) + .map(|column| column.to_field()) + .collect(); + let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields)); + let datasource = WrenDataSource::new_with_schema(schema); + (name.to_quoted_string(), Arc::new(datasource)) + }) + .collect(); + sources + .into_iter() + .for_each(|(name, ds_ref)| mdl.register_table(name, ds_ref)); + mdl + } + pub fn register_table(&mut self, name: String, table: Arc<dyn TableProvider>) { self.register_tables.insert(name, table); } @@ -155,6 +186,10 @@ impl WrenMDL { &self.manifest.schema } + pub fn models(&self) -> &[Arc<Model>] { + &self.manifest.models + } + pub fn get_model(&self, name: &str) -> Option<Arc<Model>> { self.manifest .models @@ -248,6 +283,10 @@ impl Dialect for WrenDialect { None } } + + fn interval_style(&self) -> IntervalStyle { + IntervalStyle::SQLStandard + } } fn non_lowercase(sql: &str) -> bool { @@ -327,7 +366,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); let _ = mdl::transform_sql( Arc::clone(&analyzed_mdl), - "select orderkey + orderkey from test.test.orders", + "select o_orderkey + o_orderkey from test.test.orders", )?; Ok(()) } @@ -346,11 +385,11 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); let tests: Vec<&str> = vec![ - "select orderkey + orderkey from test.test.orders", - "select orderkey from test.test.orders where orders.totalprice > 10", - "select orders.orderkey from test.test.orders left join test.test.customer on (orders.custkey = customer.custkey) where orders.totalprice > 10", - "select orderkey, sum(totalprice) from test.test.orders group by 1", - "select orderkey, count(*) from test.test.orders where orders.totalprice > 10 group by 1", + "select o_orderkey + o_orderkey from test.test.orders", + "select o_orderkey from test.test.orders where orders.o_totalprice > 10", + "select orders.o_orderkey from test.test.orders left join test.test.customer on (orders.o_custkey = customer.c_custkey) where orders.o_totalprice > 10", + "select o_orderkey, sum(o_totalprice) from test.test.orders group by 1", + "select o_orderkey, count(*) from test.test.orders where orders.o_totalprice > 10 group by 1", "select totalcost from test.test.profile", "select totalcost from profile", // TODO: support calculated without relationship diff --git a/wren-modeling-rs/core/src/mdl/utils.rs index 0d8cc1cc3..66013f336 100644 --- a/wren-modeling-rs/core/src/mdl/utils.rs +++ b/wren-modeling-rs/core/src/mdl/utils.rs @@ -247,7 +247,7 @@ mod tests { analyzed_mdl.clone(), ctx.state_ref(), )?; - assert_eq!(expr.to_string(), "customer.name"); + assert_eq!(expr.to_string(), "customer.c_name"); Ok(()) } @@ -275,7 +275,7 @@ mod tests { analyzed_mdl, ctx.state_ref(), )?; - assert_eq!(expr.to_string(), "orders.orderkey + orders.custkey"); + assert_eq!(expr.to_string(), "orders.o_orderkey + orders.o_custkey"); Ok(()) } @@ -312,11 +312,11 @@ mod tests { let ctx = SessionContext::new(); let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); let expr = super::create_wren_expr_for_model( - "name", + "c_name",
Arc::clone(&model), ctx.state_ref(), )?; - assert_eq!(expr.to_string(), "customer.name"); + assert_eq!(expr.to_string(), "customer.c_name"); Ok(()) } diff --git a/wren-modeling-rs/core/tests/data/mdl.json b/wren-modeling-rs/core/tests/data/mdl.json index 994b10ea2..10e775eb8 100644 --- a/wren-modeling-rs/core/tests/data/mdl.json +++ b/wren-modeling-rs/core/tests/data/mdl.json @@ -11,19 +11,17 @@ }, "columns": [ { - "name": "custkey", - "type": "integer", - "expression": "c_custkey" + "name": "c_custkey", + "type": "integer" }, { - "name": "name", - "type": "varchar", - "expression": "c_name" + "name": "c_name", + "type": "varchar" }, { "name": "custkey_plus", "type": "integer", - "expression": "custkey + 1", + "expression": "c_custkey + 1", "isCalculated": true }, { @@ -36,7 +34,7 @@ } } ], - "primaryKey": "custkey", + "primaryKey": "c_custkey", "properties": { "description": "This is a customer table", "maintainer": "test" @@ -49,19 +47,16 @@ }, "columns": [ { - "name": "custkey", - "type": "integer", - "expression": "p_custkey" + "name": "p_custkey", + "type": "integer" }, { - "name": "phone", - "type": "varchar", - "expression": "p_phone" + "name": "p_phone", + "type": "varchar" }, { - "name": "sex", - "type": "varchar", - "expression": "p_sex" + "name": "p_sex", + "type": "varchar" }, { "name": "customer", @@ -72,10 +67,10 @@ "name": "totalcost", "type": "integer", "isCalculated": true, - "expression": "sum(customer.orders.totalprice)" + "expression": "sum(customer.orders.o_totalprice)" } ], - "primaryKey": "custkey" + "primaryKey": "p_custkey" }, { "name": "orders", @@ -86,19 +81,16 @@ }, "columns": [ { - "name": "orderkey", - "type": "integer", - "expression": "o_orderkey" + "name": "o_orderkey", + "type": "integer" }, { - "name": "custkey", - "type": "integer", - "expression": "o_custkey" + "name": "o_custkey", + "type": "integer" }, { - "name": "totalprice", - "type": "integer", - "expression": "o_totalprice" + "name": "o_totalprice", + "type": "integer" }, { "name": "customer", @@ -108,23 +100,23 @@ { "name": "customer_name", "type": "varchar", - "expression": "customer.name", + "expression": "customer.c_name", "isCalculated": true }, { "name": "orderkey_plus_custkey", "type": "integer", - "expression": "orderkey + custkey", + "expression": "o_orderkey + o_custkey", "isCalculated": true }, { "name": "hash_orderkey", "type": "varchar", - "expression": "md5(orderkey)", + "expression": "md5(o_orderkey)", "isCalculated": true } ], - "primaryKey": "orderkey" + "primaryKey": "o_orderkey" } ], "relationships": [ @@ -132,13 +124,13 @@ "name": "CustomerOrders", "models": ["customer", "orders"], "joinType": "ONE_TO_MANY", - "condition": "customer.custkey = orders.custkey" + "condition": "customer.c_custkey = orders.o_custkey" }, { "name" : "CustomerProfile", "models": ["customer", "profile"], "joinType": "ONE_TO_ONE", - "condition": "customer.custkey = profile.custkey" + "condition": "customer.c_custkey = profile.p_custkey" } ], "views": [ diff --git a/wren-modeling-rs/sqllogictest/test_files/model.slt b/wren-modeling-rs/sqllogictest/test_files/model.slt index bf76b5f93..70867ef86 100644 --- a/wren-modeling-rs/sqllogictest/test_files/model.slt +++ b/wren-modeling-rs/sqllogictest/test_files/model.slt @@ -37,6 +37,18 @@ select cnt1 = cnt2 from (select count(*) as cnt1 from (select "Customer_state" f ---- true +query I rowsort +WITH w1 as (select "Id" from "Order_items" where "Price" in + (select distinct "Price" from "Order_items" order by "Price" DESC LIMIT 5)) +select * from w1; +---- +123 
+175 +178 +201 +56 +9 + # TODO: occurred fatal runtime error: stack overflow #query B #select actual = expected from (select "Totalprice" as actual from wrenai.public."Orders" where "Order_id" = '76754c0e642c8f99a8c3fcb8a14ac700'), (select sum(price) as expected from datafusion.public.order_items where order_id = '76754c0e642c8f99a8c3fcb8a14ac700') limit 1; @@ -48,9 +60,11 @@ select "Id", "Price" from "Order_items" where "Order_id" in (SELECT "Order_id" F ---- 105 287.4 -# TODO: DataFusion has some case sensitivity issue with the outer reference column name -# Test the query with outer reference column -# query I -# select "Customer_id" from wrenai.public."Orders" where not exists (select 1 from wrenai.public."Order_items" where "Orders"."Order_id" = "Order_items"."Order_id") -# ---- -# 1 +query T +select "Customer_id" from wrenai.public."Orders" where exists (select 1 from wrenai.public."Order_items" where "Orders"."Order_id" = "Order_items"."Order_id") order by 1 limit 5; +---- +0049e8442c2a3e4a8d1ff5a9549abd53 +024dad8e71332c433bc9a494565b9c49 +02d1b5b8831241174c6ef13efd35abbd +04eafb40a16989307464f27f1fed8907 +0732c0881c70ebcda536a4b14e9db106 diff --git a/wren-modeling-rs/sqllogictest/test_files/tpch/q15.slt.part b/wren-modeling-rs/sqllogictest/test_files/tpch/q15.slt.part index 39fc4a1d2..888aaaebb 100644 --- a/wren-modeling-rs/sqllogictest/test_files/tpch/q15.slt.part +++ b/wren-modeling-rs/sqllogictest/test_files/tpch/q15.slt.part @@ -16,9 +16,7 @@ # specific language governing permissions and limitations # under the License. -# TODO: External error: query failed: DataFusion error: Arrow error: Parser error: Invalid input syntax for type interval: "0 YEARS 3 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS seconds" -#query ITTTR -query error +query ITTTR with revenue0 (supplier_no, total_revenue) as ( select l_suppkey, @@ -50,5 +48,5 @@ where ) order by s_suppkey; -#---- -#21 Supplier#000000021 TZoQwNFFO i,baXpbpin02,hvuhE,GRVIKm 12-253-590-5816 1161099.4636 +---- +21 Supplier#000000021 TZoQwNFFO i,baXpbpin02,hvuhE,GRVIKm 12-253-590-5816 1161099.4636 diff --git a/wren-modeling-rs/sqllogictest/test_files/tpch/q20.slt.part b/wren-modeling-rs/sqllogictest/test_files/tpch/q20.slt.part index 525dc4c0b..e8aea1495 100644 --- a/wren-modeling-rs/sqllogictest/test_files/tpch/q20.slt.part +++ b/wren-modeling-rs/sqllogictest/test_files/tpch/q20.slt.part @@ -16,9 +16,7 @@ # specific language governing permissions and limitations # under the License. -# TODO: error: query failed: DataFusion error: Arrow error: Parser error: Invalid input syntax for type interval: "0 YEARS 12 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS seconds" -# query TT -query error +query TT select s_name, s_address @@ -56,4 +54,4 @@ where and n_name = 'CANADA' order by s_name; -#---- +---- diff --git a/wren-modeling-rs/sqllogictest/test_files/tpch/q4.slt.part b/wren-modeling-rs/sqllogictest/test_files/tpch/q4.slt.part index c823eb353..9becc0e9d 100644 --- a/wren-modeling-rs/sqllogictest/test_files/tpch/q4.slt.part +++ b/wren-modeling-rs/sqllogictest/test_files/tpch/q4.slt.part @@ -16,10 +16,7 @@ # specific language governing permissions and limitations # under the License. 
-# TODO -# DataFusion error: Arrow error: Parser error: Invalid input syntax for type interval: "0 YEARS 3 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS seconds" -#query TI -query error +query TI select o_orderpriority, count(*) as order_count @@ -41,9 +38,9 @@ group by o_orderpriority order by o_orderpriority; -#---- -#1-URGENT 93 -#2-HIGH 103 -#3-MEDIUM 109 -#4-NOT SPECIFIED 102 -#5-LOW 128 +---- +1-URGENT 93 +2-HIGH 103 +3-MEDIUM 109 +4-NOT SPECIFIED 102 +5-LOW 128
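Note on the bypass functions introduced in this patch: they exist only so that planning succeeds for functions that live in the remote engine; the resulting logical plan is unparsed back to SQL rather than executed. A minimal round-trip sketch follows, under two stated assumptions: that the core crate is importable as wren_core (adjust the path to wherever ByPassScalarUDF actually lands), and that DataFusion's datafusion::sql::unparser::plan_to_sql entry point is available in the pinned version. The date_diff name is illustrative, mirroring the test above.

use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
use datafusion::sql::unparser::plan_to_sql;
// Assumed import path; adjust to the actual location of ByPassScalarUDF in this repo.
use wren_core::mdl::function::ByPassScalarUDF;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // `date_diff` has no local implementation; the bypass UDF only supplies a
    // name, signature, and return type so the planner accepts the call.
    ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
        "date_diff",
        DataType::Int64,
    )));

    // Build the logical plan without executing it; invoking the UDF would
    // return an internal error by design.
    let plan = ctx
        .sql("SELECT date_diff(1, 2)")
        .await?
        .into_unoptimized_plan();

    // Unparse the plan back to SQL so the remote engine can run it.
    let statement = plan_to_sql(&plan)?;
    println!("{statement}");
    Ok(())
}

Unparsing through the WrenDialect changed above additionally renders interval literals in SQL-standard style (interval_style returning IntervalStyle::SQLStandard), which is what lets the previously skipped TPC-H q4, q15, and q20 cases pass.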