Skip to content

Fix RemoveWrenPrefixRule for being executed by DataFusion #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions wren-modeling-rs/core/src/logical_plan/analyze/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc;
use datafusion::logical_expr::LogicalPlan::Projection;
use datafusion::logical_expr::{
col, ident, utils, Extension, UserDefinedLogicalNodeCore,
};
Expand Down Expand Up @@ -396,7 +397,7 @@ impl RemoveWrenPrefixRule {
impl AnalyzerRule for RemoveWrenPrefixRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down(&|plan: LogicalPlan| -> Result<Transformed<LogicalPlan>> {
plan.map_expressions(&|expr: Expr| {
let transformed = plan.clone().map_expressions(&|expr: Expr| {
expr.transform_down(&|expr: Expr| -> Result<Transformed<Expr>> {
if let Expr::Column(ref column) = expr {
if let Some(relation) = &column.relation {
Expand Down Expand Up @@ -436,7 +437,23 @@ impl AnalyzerRule for RemoveWrenPrefixRule {
}
Ok(Transformed::no(expr.clone()))
})
})
})?;

if transformed.transformed {
// The schema of logical plan is static. Because we changed the expression, we should
// also recreate the plan.
if let Projection(_) = transformed.data {
let new_projection = datafusion::logical_expr::Projection::try_new(
transformed.data.expressions(),
Arc::new(plan.inputs()[0].clone()),
)?;
Ok(Transformed::yes(Projection(new_projection)))
} else {
Ok(Transformed::yes(transformed.data))
}
} else {
Ok(transformed)
}
})
.data()
}
Expand Down
3 changes: 3 additions & 0 deletions wren-modeling-rs/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use crate::logical_plan::utils::create_schema;
use crate::mdl::manifest::Model;
use crate::mdl::{AnalyzedWrenMDL, WrenMDL};

/// Apply Wren Rules to the context for sql generation.
/// TODO: There're some issue for unparsing the datafusion optimized plans.
/// Disable all the optimize rule for sql generation temporarily.
Comment on lines +22 to +24
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an issue about unparsing the optimized plans to the SQL string. I'll file the issue in Datafusion.

pub async fn create_ctx_with_mdl(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
Expand Down
28 changes: 24 additions & 4 deletions wren-modeling-rs/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl};
use crate::mdl::manifest::{Column, Manifest, Model};
use datafusion::execution::context::SessionState;
use datafusion::prelude::SessionContext;
use datafusion::{error::Result, sql::unparser::plan_to_sql};
Expand All @@ -6,17 +9,16 @@ use manifest::Relationship;
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};

use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::create_ctx_with_mdl;
use crate::mdl::manifest::{Column, Manifest, Model};

pub mod builder;
pub mod context;
pub(crate) mod dataset;
pub mod lineage;
pub mod manifest;
pub mod utils;

use crate::logical_plan::analyze::rule::{
ModelAnalyzeRule, ModelGenerationRule, RemoveWrenPrefixRule,
};
pub use dataset::Dataset;

pub type SessionStateRef = Arc<RwLock<SessionState>>;
Expand Down Expand Up @@ -210,6 +212,24 @@ pub async fn transform_sql_with_ctx(
}
}

/// Apply Wren Rules to a given session context with a WrenMDL
pub async fn apply_wren_rules(
ctx: &SessionContext,
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
) -> Result<()> {
ctx.add_analyzer_rule(Arc::new(ModelAnalyzeRule::new(
Arc::clone(&analyzed_wren_mdl),
ctx.state_ref(),
)));
ctx.add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
ctx.add_analyzer_rule(Arc::new(RemoveWrenPrefixRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
register_table_with_mdl(ctx, analyzed_wren_mdl.wren_mdl()).await
}

/// Analyze the decision point. It's same as the /v1/analysis/sql API in wren engine
pub fn decision_point_analyze(_wren_mdl: Arc<WrenMDL>, _sql: &str) {}

Expand Down
7 changes: 0 additions & 7 deletions wren-modeling-rs/sqllogictest/src/engine/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::prelude::SessionContext;
use log::info;
use sqllogictest::DBOutput;
use wren_core::mdl::transform_sql_with_ctx;

use super::{
error::Result,
Expand Down Expand Up @@ -55,12 +54,6 @@ impl sqllogictest::AsyncDB for DataFusion {
self.relative_path.display(),
sql
);
let sql = transform_sql_with_ctx(
self.ctx.session_ctx(),
Arc::clone(self.ctx.analyzed_wren_mdl()),
sql,
)
.await?;
run_query(self.ctx.session_ctx(), sql).await
}

Expand Down
63 changes: 26 additions & 37 deletions wren-modeling-rs/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ use datafusion::prelude::SessionConfig;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use log::info;
use tempfile::TempDir;

use wren_core::mdl::builder::{
ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder,
};
use wren_core::mdl::manifest::JoinType;
use wren_core::mdl::AnalyzedWrenMDL;
use wren_core::mdl::{apply_wren_rules, AnalyzedWrenMDL};

use crate::engine::utils::read_dir_recursive;

Expand Down Expand Up @@ -69,9 +68,8 @@ impl TestContext {
match file_name {
"view.slt" | "model.slt" => {
info!("Registering local temporary table");
Some(register_ecommerce_table(&ctx).await.ok()?)
Some(register_ecommerce_table(&ctx).await.unwrap())
}

_ => {
info!("Using default SessionContext");
None
Expand Down Expand Up @@ -161,20 +159,15 @@ async fn register_ecommerce_mdl(
)
.column(
ColumnBuilder::new_calculated("customer_state_cf", "varchar")
.expression("orders.customers_state")
.build(),
)
.column(
ColumnBuilder::new_calculated("customer_state_cf_concat", "varchar")
.expression("orders.customers_state || '-test'")
.build(),
)
.column(
ColumnBuilder::new("totalprice", "double")
.expression("sum(orders.totalprice)")
.calculated(true)
.expression("orders.customer_state")
.build(),
)
// TODO: duplicate `orders.customer_state`
// .column(
// ColumnBuilder::new_calculated("customer_state_cf_concat", "varchar")
// .expression("orders.customer_state || '-test'")
// .build(),
// )
// TODO: allow multiple calculation in an expression
// .column(
// ColumnBuilder::new("customer_location", "varchar")
Expand Down Expand Up @@ -213,11 +206,12 @@ async fn register_ecommerce_mdl(
.expression("customers.state")
.build(),
)
.column(
ColumnBuilder::new_calculated("customer_state_order_id", "varchar")
.expression("customers.state || ' ' || order_id")
.build(),
)
// TODO: fix calcaultion with non-relationship column
// .column(
// ColumnBuilder::new_calculated("customer_state_order_id", "varchar")
// .expression("customers.state || ' ' || order_id")
// .build(),
// )
.column(
ColumnBuilder::new_relationship(
"order_items",
Expand Down Expand Up @@ -255,12 +249,18 @@ async fn register_ecommerce_mdl(
.condition("orders.order_id = order_items.order_id")
.build(),
)
.view(ViewBuilder::new("orders_view")
.statement("select * from wrenai.public.orders")
.build())
.view(
ViewBuilder::new("customer_view")
.statement("select * from wrenai.public.customers")
.build(),
)
// TODO: support expression without alias inside view
// .view(ViewBuilder::new("revenue_orders").statement("select order_id, sum(price) from wrenai.public.order_items group by order_id").build())
.view(ViewBuilder::new("revenue_orders").statement("select order_id, sum(price) as totalprice from wrenai.public.order_items group by order_id").build())
// TODO: fix view with calculation
// .view(
// ViewBuilder::new("revenue_orders")
// .statement("select order_id, sum(price) as totalprice from wrenai.public.order_items group by order_id")
// .build())
.build();
let mut register_tables = HashMap::new();
register_tables.insert(
Expand Down Expand Up @@ -297,17 +297,6 @@ async fn register_ecommerce_mdl(
manifest,
register_tables,
)?);
// let new_state = ctx
// .state()
// .add_analyzer_rule(Arc::new(ModelAnalyzeRule::
//
// new(Arc::clone(&analyzed_mdl))))
// .add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
// &analyzed_mdl,
// ))))
// // TODO: disable optimize_projections rule
// // There are some conflict with the optimize rule, [datafusion::optimizer::optimize_projections::OptimizeProjections]
// .with_optimizer_rules(vec![]);
// let ctx = SessionContext::new_with_state(new_state);
apply_wren_rules(ctx, Arc::clone(&analyzed_mdl)).await?;
Ok((ctx.to_owned(), analyzed_mdl))
}
10 changes: 5 additions & 5 deletions wren-modeling-rs/sqllogictest/test_sql_files/view.slt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
statement ok
SELECT * FROM wrenai.public.orders_view
SELECT * FROM wrenai.public.customer_view

query TR
SELECT * FROM wrenai.public.revenue_orders where order_id = '76754c0e642c8f99a8c3fcb8a14ac700'
----
76754c0e642c8f99a8c3fcb8a14ac700 287.4
#query TR
#SELECT totalprice FROM wrenai.public.revenue_orders where order_id = '76754c0e642c8f99a8c3fcb8a14ac700'
#----
#76754c0e642c8f99a8c3fcb8a14ac700 287.4
21 changes: 5 additions & 16 deletions wren-modeling-rs/wren-example/examples/to-many-calculation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use wren_core::mdl::builder::{
ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder,
};
use wren_core::mdl::manifest::{JoinType, Manifest};
use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL};
use wren_core::mdl::{apply_wren_rules, AnalyzedWrenMDL};

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -75,22 +75,11 @@ async fn main() -> Result<()> {
]);
let analyzed_mdl =
Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?);

let transformed = match transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
"select totalprice from wrenai.public.customers",
)
.await
apply_wren_rules(&ctx, analyzed_mdl).await?;
let df = match ctx
.sql("select totalprice from wrenai.public.customers")
.await
{
Ok(sql) => sql,
Err(e) => {
eprintln!("Error transforming SQL: {}", e);
return Ok(());
}
};
println!("Transformed SQL: {}", transformed);
let df = match ctx.sql(&transformed).await {
Ok(df) => df,
Err(e) => {
eprintln!("Error executing SQL: {}", e);
Expand Down
Loading