Skip to content

Fix RemoveWrenPrefixRule for being executed by DataFusion #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions wren-modeling-rs/core/src/logical_plan/analyze/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc;
use datafusion::logical_expr::LogicalPlan::Projection;
use datafusion::logical_expr::{
col, ident, utils, Extension, UserDefinedLogicalNodeCore,
};
Expand Down Expand Up @@ -396,7 +397,7 @@ impl RemoveWrenPrefixRule {
impl AnalyzerRule for RemoveWrenPrefixRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down(&|plan: LogicalPlan| -> Result<Transformed<LogicalPlan>> {
plan.map_expressions(&|expr: Expr| {
let transformed = plan.clone().map_expressions(&|expr: Expr| {
expr.transform_down(&|expr: Expr| -> Result<Transformed<Expr>> {
if let Expr::Column(ref column) = expr {
if let Some(relation) = &column.relation {
Expand Down Expand Up @@ -436,7 +437,23 @@ impl AnalyzerRule for RemoveWrenPrefixRule {
}
Ok(Transformed::no(expr.clone()))
})
})
})?;

if transformed.transformed {
// The schema of logical plan is static. Because we changed the expression, we should
// also recreate the plan.
if let Projection(_) = transformed.data {
let new_projection = datafusion::logical_expr::Projection::try_new(
transformed.data.expressions(),
Arc::new(plan.inputs()[0].clone()),
)?;
Ok(Transformed::yes(Projection(new_projection)))
} else {
Ok(Transformed::yes(transformed.data))
}
} else {
Ok(transformed)
}
})
.data()
}
Expand Down
3 changes: 3 additions & 0 deletions wren-modeling-rs/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use crate::logical_plan::utils::create_schema;
use crate::mdl::manifest::Model;
use crate::mdl::{AnalyzedWrenMDL, WrenMDL};

/// Apply Wren Rules to the context for sql generation.
/// TODO: There're some issue for unparsing the datafusion optimized plans.
/// Disable all the optimize rule for sql generation temporarily.
Comment on lines +22 to +24
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an issue about unparsing the optimized plans to the SQL string. I'll file the issue in Datafusion.

pub async fn create_ctx_with_mdl(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
Expand Down
24 changes: 20 additions & 4 deletions wren-modeling-rs/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::create_ctx_with_mdl;
use crate::mdl::manifest::{Column, Manifest, Model};
use datafusion::execution::context::SessionState;
use datafusion::prelude::SessionContext;
use datafusion::{error::Result, sql::unparser::plan_to_sql};
Expand All @@ -6,17 +9,16 @@ use manifest::Relationship;
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};

use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::create_ctx_with_mdl;
use crate::mdl::manifest::{Column, Manifest, Model};

pub mod builder;
pub mod context;
pub(crate) mod dataset;
pub mod lineage;
pub mod manifest;
pub mod utils;

use crate::logical_plan::analyze::rule::{
ModelAnalyzeRule, ModelGenerationRule, RemoveWrenPrefixRule,
};
pub use dataset::Dataset;

pub type SessionStateRef = Arc<RwLock<SessionState>>;
Expand Down Expand Up @@ -210,6 +212,20 @@ pub async fn transform_sql_with_ctx(
}
}

/// Apply Wren Rules to a given session context with a WrenMDL
pub fn apply_wren_rules(ctx: &SessionContext, analyzed_wren_mdl: Arc<AnalyzedWrenMDL>) {
ctx.add_analyzer_rule(Arc::new(ModelAnalyzeRule::new(
Arc::clone(&analyzed_wren_mdl),
ctx.state_ref(),
)));
ctx.add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
ctx.add_analyzer_rule(Arc::new(RemoveWrenPrefixRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
}

/// Analyze the decision point. It's same as the /v1/analysis/sql API in wren engine
pub fn decision_point_analyze(_wren_mdl: Arc<WrenMDL>, _sql: &str) {}

Expand Down
7 changes: 0 additions & 7 deletions wren-modeling-rs/sqllogictest/src/engine/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::prelude::SessionContext;
use log::info;
use sqllogictest::DBOutput;
use wren_core::mdl::transform_sql_with_ctx;

use super::{
error::Result,
Expand Down Expand Up @@ -55,12 +54,6 @@ impl sqllogictest::AsyncDB for DataFusion {
self.relative_path.display(),
sql
);
let sql = transform_sql_with_ctx(
self.ctx.session_ctx(),
Arc::clone(self.ctx.analyzed_wren_mdl()),
sql,
)
.await?;
run_query(self.ctx.session_ctx(), sql).await
}

Expand Down
16 changes: 2 additions & 14 deletions wren-modeling-rs/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ use datafusion::prelude::SessionConfig;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use log::info;
use tempfile::TempDir;

use wren_core::mdl::builder::{
ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder,
};
use wren_core::mdl::manifest::JoinType;
use wren_core::mdl::AnalyzedWrenMDL;
use wren_core::mdl::{apply_wren_rules, AnalyzedWrenMDL};

use crate::engine::utils::read_dir_recursive;

Expand Down Expand Up @@ -297,17 +296,6 @@ async fn register_ecommerce_mdl(
manifest,
register_tables,
)?);
// let new_state = ctx
// .state()
// .add_analyzer_rule(Arc::new(ModelAnalyzeRule::
//
// new(Arc::clone(&analyzed_mdl))))
// .add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
// &analyzed_mdl,
// ))))
// // TODO: disable optimize_projections rule
// // There are some conflict with the optimize rule, [datafusion::optimizer::optimize_projections::OptimizeProjections]
// .with_optimizer_rules(vec![]);
// let ctx = SessionContext::new_with_state(new_state);
apply_wren_rules(ctx, Arc::clone(&analyzed_mdl));
Ok((ctx.to_owned(), analyzed_mdl))
}
22 changes: 6 additions & 16 deletions wren-modeling-rs/wren-example/examples/to-many-calculation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use datafusion::prelude::{CsvReadOptions, SessionContext};
use wren_core::mdl::builder::{
ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder,
};
use wren_core::mdl::context::create_ctx_with_mdl;
use wren_core::mdl::manifest::{JoinType, Manifest};
use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL};
use wren_core::mdl::AnalyzedWrenMDL;

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -75,22 +76,11 @@ async fn main() -> Result<()> {
]);
let analyzed_mdl =
Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?);

let transformed = match transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
"select totalprice from wrenai.public.customers",
)
.await
let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl).await?;
let df = match ctx
.sql("select totalprice from wrenai.public.customers")
.await
{
Ok(sql) => sql,
Err(e) => {
eprintln!("Error transforming SQL: {}", e);
return Ok(());
}
};
println!("Transformed SQL: {}", transformed);
let df = match ctx.sql(&transformed).await {
Ok(df) => df,
Err(e) => {
eprintln!("Error executing SQL: {}", e);
Expand Down
Loading