diff --git a/wren-modeling-py/Cargo.lock b/wren-modeling-py/Cargo.lock index 351d75f08..ba7f820ed 100644 --- a/wren-modeling-py/Cargo.lock +++ b/wren-modeling-py/Cargo.lock @@ -2633,6 +2633,7 @@ dependencies = [ "datafusion", "env_logger", "log", + "parking_lot", "petgraph", "petgraph-evcxr", "serde", diff --git a/wren-modeling-rs/core/Cargo.toml b/wren-modeling-rs/core/Cargo.toml index 35b044e20..f1bc208a7 100644 --- a/wren-modeling-rs/core/Cargo.toml +++ b/wren-modeling-rs/core/Cargo.toml @@ -17,6 +17,7 @@ async-trait = { workspace = true } datafusion = { workspace = true } env_logger = { workspace = true } log = { workspace = true } +parking_lot = "0.12.3" petgraph = "0.6.5" petgraph-evcxr = "*" serde = { workspace = true } diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs b/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs index 41916893d..2832af3bc 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/plan.rs @@ -27,7 +27,8 @@ use crate::mdl::utils::{ create_remote_expr_for_model, create_wren_calculated_field_expr, create_wren_expr_for_model, is_dag, }; -use crate::mdl::{AnalyzedWrenMDL, ColumnReference, Dataset}; +use crate::mdl::Dataset; +use crate::mdl::{AnalyzedWrenMDL, ColumnReference, SessionStateRef}; #[derive(Debug)] pub(crate) enum WrenPlan { @@ -66,8 +67,9 @@ impl ModelPlanNode { required_fields: Vec, original_table_scan: Option, analyzed_wren_mdl: Arc, + session_state: SessionStateRef, ) -> Result { - ModelPlanNodeBuilder::new(analyzed_wren_mdl).build( + ModelPlanNodeBuilder::new(analyzed_wren_mdl, session_state).build( model, required_fields, original_table_scan, @@ -90,10 +92,14 @@ struct ModelPlanNodeBuilder { required_calculation: Vec, fields: VecDeque<(Option, Arc)>, analyzed_wren_mdl: Arc, + session_state: SessionStateRef, } impl ModelPlanNodeBuilder { - fn new(analyzed_wren_mdl: Arc) -> Self { + fn new( + analyzed_wren_mdl: Arc, + session_state: SessionStateRef, + ) -> Self { Self { required_exprs_buffer: BTreeSet::new(), directed_graph: Graph::new(), @@ -101,6 +107,7 @@ impl ModelPlanNodeBuilder { required_calculation: vec![], fields: VecDeque::new(), analyzed_wren_mdl, + session_state, } } @@ -139,6 +146,7 @@ impl ModelPlanNodeBuilder { let expr = create_wren_calculated_field_expr( column_rf, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?; let expr_plan = expr.alias(column.name()); expr_plan @@ -177,6 +185,7 @@ impl ModelPlanNodeBuilder { if self.is_contain_calculation_source(&qualified_column) { collect_partial_model_plan( Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), &qualified_column, &mut self.model_required_fields, )?; @@ -186,6 +195,7 @@ impl ModelPlanNodeBuilder { let _ = collect_model_required_fields( &qualified_column, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), &mut self.model_required_fields, ); } @@ -194,6 +204,7 @@ impl ModelPlanNodeBuilder { &column, Arc::clone(&model), Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?; self.model_required_fields .entry(model_ref.clone()) @@ -254,6 +265,7 @@ impl ModelPlanNodeBuilder { source, source_required_fields, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )? } else { let Some(first_calculation) = calculate_iter.next() else { @@ -271,6 +283,7 @@ impl ModelPlanNodeBuilder { self.directed_graph.clone(), &self.model_required_fields.clone(), Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?; for calculation_plan in calculate_iter { @@ -360,6 +373,7 @@ impl ModelPlanNodeBuilder { if self.is_contain_calculation_source(qualified_column) { collect_partial_model_plan( Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), qualified_column, &mut partial_model_required_fields, )?; @@ -368,6 +382,7 @@ impl ModelPlanNodeBuilder { collect_model_required_fields( qualified_column, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), &mut partial_model_required_fields, )?; @@ -384,6 +399,7 @@ impl ModelPlanNodeBuilder { source, source_required_fields, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?; let partial_chain = RelationChain::with_chain( @@ -393,6 +409,7 @@ impl ModelPlanNodeBuilder { column_graph.clone(), &partial_model_required_fields, Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?; let Some(column_rf) = self .analyzed_wren_mdl @@ -405,7 +422,7 @@ impl ModelPlanNodeBuilder { column_rf, col_expr, partial_chain, - Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?))) } } @@ -421,6 +438,7 @@ fn is_required_column(expr: &Expr, name: &str) -> bool { fn collect_partial_model_plan( analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, qualified_column: &Column, required_fields: &mut HashMap>, ) -> Result<()> { @@ -446,7 +464,7 @@ fn collect_partial_model_plan( let expr = create_wren_expr_for_model( &c.name, dataset.try_as_model().unwrap(), - Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )?; required_fields .entry(relation_ref.clone()) @@ -463,6 +481,7 @@ fn collect_partial_model_plan( fn collect_model_required_fields( qualified_column: &Column, analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, model_required_fields: &mut HashMap>, ) -> Result<()> { let Some(set) = analyzed_wren_mdl @@ -488,7 +507,7 @@ fn collect_model_required_fields( let Ok(expr) = create_wren_expr_for_model( expression, Arc::clone(&m), - Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), ) else { // skip the semantic expression (e.g. calculated field or relationship column) debug!( @@ -512,6 +531,7 @@ fn collect_model_required_fields( &column, Arc::clone(&m), Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )?; debug!("Required field: {}", &expr_plan); model_required_fields @@ -530,11 +550,22 @@ fn get_remote_column_exp( column: &mdl::manifest::Column, model: Arc, analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, ) -> Result { let expr = if let Some(expression) = &column.expression { - create_remote_expr_for_model(expression, model, analyzed_wren_mdl)? + create_remote_expr_for_model( + expression, + model, + analyzed_wren_mdl, + session_state_ref, + )? } else { - create_remote_expr_for_model(&column.name, model, analyzed_wren_mdl)? + create_remote_expr_for_model( + &column.name, + model, + analyzed_wren_mdl, + session_state_ref, + )? }; Ok(expr.alias(column.name.clone())) } @@ -656,6 +687,7 @@ impl ModelSourceNode { model: Arc, required_exprs: Vec, analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, original_table_scan: Option, ) -> Result { let mut required_exprs_buffer = BTreeSet::new(); @@ -689,6 +721,7 @@ impl ModelSourceNode { &column, Arc::clone(&model), Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )?)); } } else { @@ -711,6 +744,7 @@ impl ModelSourceNode { &column, Arc::clone(&model), Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )?; required_exprs_buffer.insert(OrdExpr::new(expr_plan.clone())); } @@ -793,7 +827,7 @@ impl CalculationPlanNode { calculation: ColumnReference, calculation_expr: Expr, relation_chain: RelationChain, - analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, ) -> Result { let Some(model) = calculation.dataset.try_as_model() else { return plan_err!("Only support model as source dataset"); @@ -822,7 +856,7 @@ impl CalculationPlanNode { let dimensions = vec![create_wren_expr_for_model( &pk_column.name, Arc::clone(&model), - Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )? .alias(pk_column.name())]; let schema_ref = DFSchemaRef::new( diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/relation_chain.rs b/wren-modeling-rs/core/src/logical_plan/analyze/relation_chain.rs index bda6b0f9d..e4bca3663 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/relation_chain.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/relation_chain.rs @@ -7,7 +7,8 @@ use crate::logical_plan::utils::create_schema; use crate::mdl; use crate::mdl::lineage::DatasetLink; use crate::mdl::manifest::JoinType; -use crate::mdl::{AnalyzedWrenMDL, Dataset}; +use crate::mdl::Dataset; +use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; use datafusion::catalog::TableReference; use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::logical_expr::{ @@ -33,6 +34,7 @@ impl RelationChain { dataset: &Dataset, required_fields: Vec, analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, ) -> datafusion::common::Result { match dataset { Dataset::Model(source_model) => { @@ -41,6 +43,7 @@ impl RelationChain { Arc::clone(source_model), required_fields, analyzed_wren_mdl, + session_state_ref, None, )?), }))) @@ -58,6 +61,7 @@ impl RelationChain { directed_graph: Graph, model_required_fields: &HashMap>, analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, ) -> datafusion::common::Result { let mut relation_chain = source; @@ -88,6 +92,7 @@ impl RelationChain { fields.iter().cloned().map(|c| c.expr).collect(), None, Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), )?; let df_schema = @@ -101,6 +106,7 @@ impl RelationChain { Arc::clone(target_model), fields.iter().cloned().map(|c| c.expr).collect(), Arc::clone(&analyzed_wren_mdl), + Arc::clone(&session_state_ref), None, )?), }) diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs b/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs index 5cdabfd03..9e9713d6f 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/rule.rs @@ -18,18 +18,25 @@ use crate::logical_plan::analyze::plan::{ }; use crate::logical_plan::utils::create_remote_table_source; use crate::mdl::manifest::Model; -use crate::mdl::{AnalyzedWrenMDL, WrenMDL}; +use crate::mdl::{AnalyzedWrenMDL, SessionStateRef, WrenMDL}; /// [ModelAnalyzeRule] responsible for analyzing the model plan node. Turn TableScan from a model to a ModelPlanNode. /// We collect the required fields from the projection, filter, aggregation, and join, /// and pass them to the ModelPlanNode. pub struct ModelAnalyzeRule { analyzed_wren_mdl: Arc, + session_state: SessionStateRef, } impl ModelAnalyzeRule { - pub fn new(analyzed_wren_mdl: Arc) -> Self { - Self { analyzed_wren_mdl } + pub fn new( + analyzed_wren_mdl: Arc, + session_state: SessionStateRef, + ) -> Self { + Self { + analyzed_wren_mdl, + session_state, + } } fn analyze_model_internal( @@ -98,6 +105,7 @@ impl ModelAnalyzeRule { field, Some(LogicalPlan::TableScan(table_scan.clone())), Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), )?), }); used_columns.borrow_mut().clear(); @@ -126,6 +134,7 @@ impl ModelAnalyzeRule { let left = match unwrap_arc(join.left) { LogicalPlan::TableScan(table_scan) => analyze_table_scan( Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), table_scan, buffer.iter().cloned().collect(), )?, @@ -135,6 +144,7 @@ impl ModelAnalyzeRule { let right = match unwrap_arc(join.right) { LogicalPlan::TableScan(table_scan) => analyze_table_scan( Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), table_scan, buffer.iter().cloned().collect(), )?, @@ -171,6 +181,7 @@ fn belong_to_mdl(mdl: &WrenMDL, table_reference: TableReference) -> bool { fn analyze_table_scan( analyzed_wren_mdl: Arc, + session_state_ref: SessionStateRef, table_scan: TableScan, required_field: Vec, ) -> Result { @@ -182,7 +193,8 @@ fn analyze_table_scan( model, required_field, Some(LogicalPlan::TableScan(table_scan.clone())), - Arc::clone(&analyzed_wren_mdl), + analyzed_wren_mdl, + session_state_ref, )?), })) } else { @@ -436,10 +448,8 @@ impl AnalyzerRule for RemoveWrenPrefixRule { #[cfg(test)] mod test { - use crate::logical_plan::analyze::rule::RemoveWrenPrefixRule; - use crate::logical_plan::context_provider::WrenContextProvider; - use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; - use crate::mdl::AnalyzedWrenMDL; + use std::sync::Arc; + use datafusion::common::DataFusionError; use datafusion::config::ConfigOptions; use datafusion::error::Result; @@ -449,9 +459,15 @@ mod test { use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::unparser::plan_to_sql; use log::info; - use std::sync::Arc; + + use crate::logical_plan::analyze::rule::RemoveWrenPrefixRule; + #[allow(deprecated)] + use crate::logical_plan::context_provider::WrenContextProvider; + use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; + use crate::mdl::AnalyzedWrenMDL; #[test] + #[allow(deprecated)] fn test_remove_prefix() -> Result<(), DataFusionError> { let manifest = ManifestBuilder::new() .model( diff --git a/wren-modeling-rs/core/src/logical_plan/context_provider.rs b/wren-modeling-rs/core/src/logical_plan/context_provider.rs index 8005eaea1..f5f0e7b8f 100644 --- a/wren-modeling-rs/core/src/logical_plan/context_provider.rs +++ b/wren-modeling-rs/core/src/logical_plan/context_provider.rs @@ -1,8 +1,7 @@ use std::{collections::HashMap, sync::Arc}; -use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::datatypes::DataType; use datafusion::datasource::DefaultTableSource; -use datafusion::logical_expr::builder::LogicalTableSource; use datafusion::{ common::{plan_err, Result}, config::ConfigOptions, @@ -10,11 +9,11 @@ use datafusion::{ sql::{planner::ContextProvider, TableReference}, }; -use crate::mdl::manifest::{Column, Model}; -use crate::mdl::{utils, WrenMDL}; +use crate::mdl::WrenMDL; -use super::utils::{create_table_source, map_data_type}; +use super::utils::create_table_source; +#[deprecated(since = "0.8.0", note = "try to create plan by SessionContext instead")] /// WrenContextProvider is a ContextProvider implementation that uses the WrenMDL /// to provide table sources and other metadata. pub struct WrenContextProvider { @@ -22,6 +21,7 @@ pub struct WrenContextProvider { tables: HashMap>, } +#[allow(deprecated)] impl WrenContextProvider { pub fn new(mdl: &WrenMDL) -> Result { let mut tables = HashMap::new(); @@ -57,6 +57,7 @@ impl WrenContextProvider { } } +#[allow(deprecated)] impl ContextProvider for WrenContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.to_string(); @@ -98,172 +99,3 @@ impl ContextProvider for WrenContextProvider { Vec::new() } } - -/// RemoteContextProvider is a ContextProvider implementation that is used to provide -/// the schema for the remote column in the column expression -pub struct RemoteContextProvider { - options: ConfigOptions, - tables: HashMap>, -} - -impl RemoteContextProvider { - pub fn new(mdl: &WrenMDL) -> Result { - let tables = mdl - .manifest - .models - .iter() - .map(|model| { - let remove_provider = mdl.get_table(&model.table_reference); - let datasource = if let Some(table_provider) = remove_provider { - Arc::new(DefaultTableSource::new(table_provider)) - } else { - create_remote_table_source(model, mdl)? - }; - Ok((model.table_reference.clone(), datasource)) - }) - .collect::>>()?; - Ok(Self { - tables, - options: Default::default(), - }) - } -} - -impl ContextProvider for RemoteContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - match self.tables.get(name.to_string().as_str()) { - Some(table) => Ok(table.clone()), - _ => plan_err!("Table not found: {}", name.table()), - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} - -fn create_remote_table_source( - model: &Model, - wren_mdl: &WrenMDL, -) -> Result> { - if let Some(table_provider) = wren_mdl.get_table(&model.table_reference) { - Ok(Arc::new(DefaultTableSource::new(table_provider))) - } else { - let schema = create_schema(model.get_physical_columns()); - Ok(Arc::new(LogicalTableSource::new(schema?))) - } -} - -fn create_schema(columns: Vec>) -> Result { - let fields: Vec = columns - .iter() - .filter(|c| !c.is_calculated) - .flat_map(|column| { - if column.expression.is_none() { - let data_type = if let Ok(data_type) = map_data_type(&column.r#type) { - data_type - } else { - // TODO optimize to use Datafusion's error type - unimplemented!("Unsupported data type: {}", column.r#type) - }; - vec![Field::new(&column.name, data_type, column.no_null)] - } else if let Ok(idents) = - utils::collect_identifiers(column.expression.as_ref().unwrap()) - { - idents - .iter() - .map(|c| { - // we don't know the data type or nullable of the remote table, - // so we just mock a Int32 type and false here - Field::new(&c.name, DataType::Int8, true) - }) - .collect() - } else { - panic!( - "Failed to collect identifiers from expression: {}", - column.expression.as_ref().unwrap() - ); - } - }) - .collect(); - Ok(SchemaRef::new(Schema::new_with_metadata( - fields, - HashMap::new(), - ))) -} - -pub(crate) struct DynamicContextProvider { - delegate: Box, -} - -impl DynamicContextProvider { - pub fn new(delegate: Box) -> Self { - Self { delegate } - } -} - -impl ContextProvider for DynamicContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - self.delegate.get_table_source(name) - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn options(&self) -> &ConfigOptions { - self.delegate.options() - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} diff --git a/wren-modeling-rs/core/src/logical_plan/utils.rs b/wren-modeling-rs/core/src/logical_plan/utils.rs index ae9eb4045..d30b1eccc 100644 --- a/wren-modeling-rs/core/src/logical_plan/utils.rs +++ b/wren-modeling-rs/core/src/logical_plan/utils.rs @@ -9,9 +9,10 @@ use petgraph::dot::{Config, Dot}; use petgraph::Graph; use crate::mdl::lineage::DatasetLink; +use crate::mdl::Dataset; use crate::mdl::{ manifest::{Column, Model}, - Dataset, WrenMDL, + WrenMDL, }; pub fn map_data_type(data_type: &str) -> Result { diff --git a/wren-modeling-rs/core/src/mdl/context.rs b/wren-modeling-rs/core/src/mdl/context.rs index b34aeb5af..ecc62add3 100644 --- a/wren-modeling-rs/core/src/mdl/context.rs +++ b/wren-modeling-rs/core/src/mdl/context.rs @@ -26,7 +26,10 @@ pub async fn create_ctx_with_mdl( let new_state = ctx .state() .with_analyzer_rules(vec![ - Arc::new(ModelAnalyzeRule::new(Arc::clone(&analyzed_mdl))), + Arc::new(ModelAnalyzeRule::new( + Arc::clone(&analyzed_mdl), + ctx.state_ref(), + )), Arc::new(ModelGenerationRule::new(Arc::clone(&analyzed_mdl))), Arc::new(RemoveWrenPrefixRule::new(Arc::clone(&analyzed_mdl))), ]) diff --git a/wren-modeling-rs/core/src/mdl/dataset.rs b/wren-modeling-rs/core/src/mdl/dataset.rs new file mode 100644 index 000000000..715fca9a2 --- /dev/null +++ b/wren-modeling-rs/core/src/mdl/dataset.rs @@ -0,0 +1,168 @@ +use crate::logical_plan::utils::map_data_type; +use crate::mdl::manifest::{Column, Metric, Model}; +use crate::mdl::{RegisterTables, SessionStateRef}; +use datafusion::arrow::datatypes::DataType::Utf8; +use datafusion::arrow::datatypes::Field; +use datafusion::common::DFSchema; +use datafusion::common::Result; +use datafusion::logical_expr::sqlparser::ast::Expr::CompoundIdentifier; +use datafusion::sql::sqlparser::ast::Expr::Identifier; +use datafusion::sql::sqlparser::ast::{visit_expressions, Expr, Ident}; +use std::fmt::Display; +use std::ops::ControlFlow; +use std::sync::Arc; + +impl Model { + /// Physical columns are columns that can be selected from the model. + /// e.g. columns that are not a relationship column + pub fn get_physical_columns(&self) -> Vec> { + self.columns + .iter() + .filter(|c| c.relationship.is_none()) + .map(Arc::clone) + .collect() + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn get_column(&self, column_name: &str) -> Option> { + self.columns + .iter() + .find(|c| c.name == column_name) + .map(Arc::clone) + } + + pub fn primary_key(&self) -> Option<&str> { + self.primary_key.as_deref() + } +} + +impl Column { + pub fn name(&self) -> &str { + &self.name + } + + pub fn expression(&self) -> Option<&str> { + self.expression.as_deref() + } + + pub fn to_field(&self) -> Result { + let data_type = map_data_type(&self.r#type)?; + Ok(Field::new(&self.name, data_type, self.no_null)) + } + + pub fn to_remote_field(&self, session_state: SessionStateRef) -> Result> { + if self.expression().is_some() { + let session_state = session_state.read(); + let expr = session_state.sql_to_expr( + self.expression().unwrap(), + session_state.config_options().sql_parser.dialect.as_str(), + )?; + let columns = Self::collect_columns(expr); + Ok(columns + .into_iter() + .map(|c| Field::new(c.value, Utf8, false)) + .collect()) + } else { + Ok(vec![self.to_field()?]) + } + } + + fn collect_columns(expr: Expr) -> Vec { + let mut visited = vec![]; + visit_expressions(&expr, |e| { + if let CompoundIdentifier(ids) = e { + ids.iter().cloned().for_each(|id| visited.push(id)); + } else if let Identifier(id) = e { + visited.push(id.clone()); + } + ControlFlow::<()>::Continue(()) + }); + visited + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Dataset { + Model(Arc), + Metric(Arc), +} + +impl Dataset { + pub fn name(&self) -> &str { + match self { + Dataset::Model(model) => model.name(), + Dataset::Metric(metric) => metric.name(), + } + } + + pub fn try_as_model(&self) -> Option> { + match self { + Dataset::Model(model) => Some(Arc::clone(model)), + _ => None, + } + } + + pub fn to_qualified_schema(&self) -> Result { + match self { + Dataset::Model(model) => { + let fields = model + .get_physical_columns() + .iter() + .map(|c| c.to_field()) + .collect::>>()?; + let arrow_schema = datafusion::arrow::datatypes::Schema::new(fields); + DFSchema::try_from_qualified_schema(&model.name, &arrow_schema) + } + Dataset::Metric(_) => todo!(), + } + } + + /// Create the schema with the remote table name + pub fn to_remote_schema( + &self, + register_tables: Option<&RegisterTables>, + session_state: SessionStateRef, + ) -> Result { + match self { + Dataset::Model(model) => { + let schema = register_tables + .map(|rt| rt.get(&model.table_reference)) + .filter(|rt| rt.is_some()) + .map(|rt| rt.unwrap().schema()); + + if let Some(schema) = schema { + DFSchema::try_from_qualified_schema(&model.table_reference, &schema) + } else { + let fields: Vec = model + .get_physical_columns() + .iter() + .filter(|c| !c.is_calculated) + .map(|c| c.to_remote_field(Arc::clone(&session_state))) + .collect::>>>()? + .iter() + .flat_map(|c| c.clone()) + .collect(); + let arrow_schema = datafusion::arrow::datatypes::Schema::new(fields); + + DFSchema::try_from_qualified_schema( + &model.table_reference, + &arrow_schema, + ) + } + } + Dataset::Metric(_) => todo!(), + } + } +} + +impl Display for Dataset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Dataset::Model(model) => write!(f, "{}", model.name()), + Dataset::Metric(metric) => write!(f, "{}", metric.name()), + } + } +} diff --git a/wren-modeling-rs/core/src/mdl/lineage.rs b/wren-modeling-rs/core/src/mdl/lineage.rs index 9a3abbaae..5522f1972 100644 --- a/wren-modeling-rs/core/src/mdl/lineage.rs +++ b/wren-modeling-rs/core/src/mdl/lineage.rs @@ -12,7 +12,7 @@ use crate::mdl::{utils, WrenMDL}; use super::manifest::{JoinType, Relationship}; use super::utils::{collect_identifiers, to_expr_queue}; -use super::Dataset; +use crate::mdl::Dataset; pub struct Lineage { pub source_columns_map: HashMap>, @@ -327,7 +327,8 @@ mod test { }; use crate::mdl::lineage::Lineage; use crate::mdl::manifest::JoinType; - use crate::mdl::{Dataset, WrenMDL}; + use crate::mdl::Dataset; + use crate::mdl::WrenMDL; #[test] fn test_collect_source_columns() -> Result<()> { diff --git a/wren-modeling-rs/core/src/mdl/manifest.rs b/wren-modeling-rs/core/src/mdl/manifest.rs index 15d2ee93a..b41264a75 100644 --- a/wren-modeling-rs/core/src/mdl/manifest.rs +++ b/wren-modeling-rs/core/src/mdl/manifest.rs @@ -39,33 +39,6 @@ pub struct Model { pub properties: Vec<(String, String)>, } -impl Model { - /// Physical columns are columns that can be selected from the model. - /// e.g. columns that are not a relationship column - pub fn get_physical_columns(&self) -> Vec> { - self.columns - .iter() - .filter(|c| c.relationship.is_none()) - .map(Arc::clone) - .collect() - } - - pub fn name(&self) -> &str { - &self.name - } - - pub fn get_column(&self, column_name: &str) -> Option> { - self.columns - .iter() - .find(|c| c.name == column_name) - .map(Arc::clone) - } - - pub fn primary_key(&self) -> Option<&str> { - self.primary_key.as_deref() - } -} - mod table_reference { use serde::{self, Deserialize, Deserializer, Serialize, Serializer}; @@ -150,12 +123,6 @@ pub struct Column { pub properties: Vec<(String, String)>, } -impl Column { - pub fn name(&self) -> &str { - &self.name - } -} - #[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq)] #[serde(rename_all = "camelCase")] pub struct Relationship { diff --git a/wren-modeling-rs/core/src/mdl/mod.rs b/wren-modeling-rs/core/src/mdl/mod.rs index 2afa8b901..d9e912ec8 100644 --- a/wren-modeling-rs/core/src/mdl/mod.rs +++ b/wren-modeling-rs/core/src/mdl/mod.rs @@ -1,22 +1,26 @@ -use std::fmt::Display; -use std::{collections::HashMap, sync::Arc}; - +use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; use datafusion::{error::Result, sql::unparser::plan_to_sql}; use log::{debug, info}; - use manifest::Relationship; +use parking_lot::RwLock; +use std::{collections::HashMap, sync::Arc}; use crate::logical_plan::utils::from_qualified_name_str; use crate::mdl::context::create_ctx_with_mdl; -use crate::mdl::manifest::{Column, Manifest, Metric, Model}; +use crate::mdl::manifest::{Column, Manifest, Model}; pub mod builder; pub mod context; +pub(crate) mod dataset; pub mod lineage; pub mod manifest; pub mod utils; +pub use dataset::Dataset; + +pub type SessionStateRef = Arc>; + pub struct AnalyzedWrenMDL { pub wren_mdl: Arc, pub lineage: Arc, @@ -53,11 +57,12 @@ impl AnalyzedWrenMDL { } } +pub type RegisterTables = HashMap>; // This is the main struct that holds the manifest and provides methods to access the models pub struct WrenMDL { pub manifest: Manifest, pub qualified_references: HashMap, - pub register_tables: HashMap>, + pub register_tables: RegisterTables, } impl WrenMDL { @@ -136,9 +141,7 @@ impl WrenMDL { self.register_tables.get(name).cloned() } - pub fn get_register_tables( - &self, - ) -> &HashMap> { + pub fn get_register_tables(&self) -> &RegisterTables { &self.register_tables } @@ -227,37 +230,6 @@ impl ColumnReference { } } -#[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub enum Dataset { - Model(Arc), - Metric(Arc), -} - -impl Dataset { - pub fn name(&self) -> &str { - match self { - Dataset::Model(model) => model.name(), - Dataset::Metric(metric) => metric.name(), - } - } - - pub fn try_as_model(&self) -> Option> { - match self { - Dataset::Model(model) => Some(Arc::clone(model)), - _ => None, - } - } -} - -impl Display for Dataset { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Dataset::Model(model) => write!(f, "{}", model.name()), - Dataset::Metric(metric) => write!(f, "{}", metric.name()), - } - } -} - #[cfg(test)] mod test { use std::fs; @@ -310,9 +282,9 @@ mod test { "select orderkey + orderkey from test.test.orders", "select orderkey from test.test.orders where orders.totalprice > 10", "select orders.orderkey from test.test.orders left join test.test.customer on (orders.custkey = customer.custkey) where orders.totalprice > 10", - "select orderkey, min(totalprice) from test.test.orders group by 1", + "select orderkey, sum(totalprice) from test.test.orders group by 1", "select orderkey, count(*) from test.test.orders where orders.totalprice > 10 group by 1", - "select min_totalcost from test.test.profile", + "select totalcost from test.test.profile", // TODO: support calculated without relationship // "select orderkey_plus_custkey from orders", ]; diff --git a/wren-modeling-rs/core/src/mdl/utils.rs b/wren-modeling-rs/core/src/mdl/utils.rs index e08e10366..acd7bc9c5 100644 --- a/wren-modeling-rs/core/src/mdl/utils.rs +++ b/wren-modeling-rs/core/src/mdl/utils.rs @@ -2,24 +2,20 @@ use std::collections::{BTreeSet, VecDeque}; use std::ops::ControlFlow; use std::sync::Arc; -use datafusion::common::{internal_err, plan_err, Column}; +use datafusion::common::{plan_err, Column, DFSchema}; use datafusion::error::Result; -use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc; -use datafusion::logical_expr::{Expr, LogicalPlan}; -use datafusion::sql::planner::SqlToRel; +use datafusion::execution::session_state::SessionState; +use datafusion::logical_expr::Expr; use datafusion::sql::sqlparser::ast::Expr::{CompoundIdentifier, Identifier}; -use datafusion::sql::sqlparser::ast::{visit_expressions, visit_expressions_mut}; +use datafusion::sql::sqlparser::ast::{visit_expressions, visit_expressions_mut, Ident}; use datafusion::sql::sqlparser::dialect::GenericDialect; use datafusion::sql::sqlparser::parser::Parser; -use log::debug; use petgraph::algo::is_cyclic_directed; use petgraph::{EdgeType, Graph}; -use crate::logical_plan::context_provider::DynamicContextProvider; -use crate::logical_plan::context_provider::{RemoteContextProvider, WrenContextProvider}; use crate::logical_plan::utils::from_qualified_name; use crate::mdl::manifest::Model; -use crate::mdl::{AnalyzedWrenMDL, ColumnReference}; +use crate::mdl::{AnalyzedWrenMDL, ColumnReference, Dataset, SessionStateRef}; pub fn to_expr_queue(column: Column) -> VecDeque { column.name.split('.').map(String::from).collect() @@ -63,6 +59,7 @@ pub fn collect_identifiers(expr: &str) -> Result> { pub fn create_wren_calculated_field_expr( column_rf: ColumnReference, analyzed_wren_mdl: Arc, + session_state: SessionStateRef, ) -> Result { if !column_rf.column.is_calculated { return plan_err!("Column is not calculated: {}", column_rf.column.name); @@ -89,105 +86,97 @@ pub fn create_wren_calculated_field_expr( .collect::>() // Collect into a BTreeSet to remove duplicates .into_iter() // Convert BTreeSet back into an iterator .map(|m| m.to_string()) - .collect::>() - .join(", "); - + .collect::>(); // Remove all relationship fields from the expression. Only keep the target expression and its source table. let expr = column_rf.column.expression.clone().unwrap(); - let wrapped = format!("select {} from {}", expr, models); - let parsed = Parser::parse_sql(&GenericDialect {}, &wrapped).unwrap(); - let mut statement = parsed[0].clone(); - visit_expressions_mut(&mut statement, |expr| { - if let CompoundIdentifier(ids) = expr { + let session_state = session_state.read(); + let mut expr = session_state.sql_to_expr( + &expr, + session_state.config_options().sql_parser.dialect.as_str(), + )?; + visit_expressions_mut(&mut expr, |e| { + if let CompoundIdentifier(ids) = e { let name_size = ids.len(); if name_size > 2 { let slice = &ids[name_size - 2..name_size]; - *expr = CompoundIdentifier(slice.to_vec()); + *e = CompoundIdentifier(slice.to_vec()); } } ControlFlow::<()>::Continue(()) }); - debug!("Statement: {:?}", statement.to_string()); - // Create the expression only has the table prefix. We don't need the catalog and schema prefix when planning. - let context_provider = WrenContextProvider::new_bare(&analyzed_wren_mdl.wren_mdl)?; - let sql_to_rel = SqlToRel::new(&context_provider); - let plan = match sql_to_rel.sql_statement_to_plan(statement.clone()) { - Ok(plan) => plan, - Err(e) => return plan_err!("Error creating plan: {}", e), - }; - let result = match plan { - LogicalPlan::Projection(projection) => { - if let LogicalPlan::Aggregate(aggregation) = unwrap_arc(projection.input) { - aggregation.aggr_expr[0].clone() - } else { - projection.expr[0].clone() - } - } - _ => return internal_err!("Unexpected plan type: {:?}", plan), + let Some(schema) = models + .into_iter() + .map(|m| analyzed_wren_mdl.wren_mdl().get_model(&m)) + .filter(|m| m.is_some()) + .map(|m| Dataset::Model(m.unwrap())) + .map(|m| m.to_qualified_schema()) + .reduce(|acc, schema| acc?.join(&schema?)) + .transpose()? + else { + return plan_err!("Error for creating schemas: {}", qualified_col); }; - Ok(result) + session_state.create_logical_expr(&expr.to_string(), &schema) } /// Create the Logical Expr for the remote column. -/// Use [RemoteContextProvider] to provide the context for the remote column. pub(crate) fn create_remote_expr_for_model( - expr: &String, + expr: &str, model: Arc, analyzed_wren_mdl: Arc, + session_state: SessionStateRef, ) -> Result { - let context_provider = RemoteContextProvider::new(&analyzed_wren_mdl.wren_mdl)?; - create_expr_for_model( - expr, - model, - DynamicContextProvider::new(Box::new(context_provider)), + let dataset = Dataset::Model(model); + let schema = dataset.to_remote_schema( + Some(analyzed_wren_mdl.wren_mdl().get_register_tables()), + Arc::clone(&session_state), + )?; + let session_state = session_state.read(); + session_state.create_logical_expr( + qualified_expr(expr, &schema, &session_state)?.as_str(), + &schema, ) } /// Create the Logical Expr for the remote column. -/// Use [RemoteContextProvider] to provide the context for the remote column. pub(crate) fn create_wren_expr_for_model( - expr: &String, + expr: &str, model: Arc, - analyzed_wren_mdl: Arc, + session_state: SessionStateRef, ) -> Result { - let context_provider = WrenContextProvider::new(&analyzed_wren_mdl.wren_mdl)?; - let wrapped = format!( - "select {} from {}.{}.{}", - expr, - analyzed_wren_mdl.wren_mdl().catalog(), - analyzed_wren_mdl.wren_mdl().schema(), - &model.name - ); - let parsed = Parser::parse_sql(&GenericDialect {}, &wrapped).unwrap(); - let statement = &parsed[0]; - debug!("Statement: {:?}", statement.to_string()); - - let sql_to_rel = SqlToRel::new(&context_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; - - match plan { - LogicalPlan::Projection(projection) => Ok(projection.expr[0].clone()), - _ => internal_err!("Unexpected plan type: {:?}", plan), - } + let dataset = Dataset::Model(model); + let schema = dataset.to_qualified_schema()?; + let session_state = session_state.read(); + session_state.create_logical_expr( + qualified_expr(expr, &schema, &session_state)?.as_str(), + &schema, + ) } -/// Create the Logical Expr for the column belong to the model according to the context provider -pub(crate) fn create_expr_for_model( - expr: &String, - model: Arc, - context_provider: DynamicContextProvider, -) -> Result { - let wrapped = format!("select {} from {}", expr, &model.table_reference); - let parsed = Parser::parse_sql(&GenericDialect {}, &wrapped)?; - let statement = &parsed[0]; - - let sql_to_rel = SqlToRel::new(&context_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; - match plan { - LogicalPlan::Projection(projection) => Ok(projection.expr[0].clone()), - _ => internal_err!("Unexpected plan type: {:?}", plan), - } +/// Add the table prefix for the column expression if it can be resolved by the schema. +fn qualified_expr( + expr: &str, + schema: &DFSchema, + session_state: &SessionState, +) -> Result { + let mut expr = session_state.sql_to_expr( + expr, + session_state.config_options().sql_parser.dialect.as_str(), + )?; + visit_expressions_mut(&mut expr, |e| { + if let Identifier(id) = e { + if let Ok((Some(qualifier), _)) = + schema.qualified_field_with_unqualified_name(&id.value) + { + let mut parts: Vec<_> = + qualifier.to_vec().into_iter().map(Ident::new).collect(); + parts.push(id.clone()); + *e = CompoundIdentifier(parts); + } + } + ControlFlow::<()>::Continue(()) + }); + Ok(expr.to_string()) } #[cfg(test)] @@ -196,10 +185,12 @@ mod tests { use std::path::PathBuf; use std::sync::Arc; + use datafusion::error::Result; + use datafusion::prelude::SessionContext; + use crate::logical_plan::utils::from_qualified_name; use crate::mdl::manifest::Manifest; use crate::mdl::AnalyzedWrenMDL; - use datafusion::error::Result; #[test] fn test_create_wren_expr() -> Result<()> { @@ -210,7 +201,7 @@ mod tests { let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); - + let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl .qualified_references @@ -223,6 +214,7 @@ mod tests { let expr = super::create_wren_calculated_field_expr( column_rf.clone(), analyzed_mdl.clone(), + ctx.state_ref(), )?; assert_eq!(expr.to_string(), "customer.name"); Ok(()) @@ -237,7 +229,7 @@ mod tests { let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); - + let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl .qualified_references @@ -249,9 +241,10 @@ mod tests { .unwrap(); let expr = super::create_wren_calculated_field_expr( column_rf.clone(), - analyzed_mdl.clone(), + analyzed_mdl, + ctx.state_ref(), )?; - assert_eq!(expr.to_string(), "orders.orderkey + orders.custkey"); + assert_eq!(expr.to_string(), "orderkey + custkey"); Ok(()) } @@ -270,4 +263,45 @@ mod tests { assert!(result.contains(&super::Column::new_unqualified("order_id"))); Ok(()) } + + #[test] + fn test_create_wren_expr_for_model() -> Result<()> { + let test_data: PathBuf = + [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] + .iter() + .collect(); + let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); + let mdl = serde_json::from_str::(&mdl_json).unwrap(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let ctx = SessionContext::new(); + let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); + let expr = super::create_wren_expr_for_model( + "name", + Arc::clone(&model), + ctx.state_ref(), + )?; + assert_eq!(expr.to_string(), "customer.name"); + Ok(()) + } + + #[test] + fn test_create_remote_expr_for_model() -> Result<()> { + let test_data: PathBuf = + [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] + .iter() + .collect(); + let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); + let mdl = serde_json::from_str::(&mdl_json).unwrap(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let ctx = SessionContext::new(); + let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); + let expr = super::create_remote_expr_for_model( + "c_name", + Arc::clone(&model), + analyzed_mdl, + ctx.state_ref(), + )?; + assert_eq!(expr.to_string(), "customer.c_name"); + Ok(()) + } } diff --git a/wren-modeling-rs/core/tests/data/mdl.json b/wren-modeling-rs/core/tests/data/mdl.json index 22add5091..40b3ac664 100644 --- a/wren-modeling-rs/core/tests/data/mdl.json +++ b/wren-modeling-rs/core/tests/data/mdl.json @@ -61,10 +61,10 @@ "relationship": "CustomerProfile" }, { - "name": "min_totalcost", + "name": "totalcost", "type": "integer", "isCalculated": true, - "expression": "min(customer.orders.totalprice)" + "expression": "sum(customer.orders.totalprice)" } ], "primaryKey": "custkey"