
Commit 98ea76a

kwannoell and lmatz authored
feat(optimizer): Add StreamProjectMergeRule (risingwavelabs#8753)
Co-authored-by: lmatz <[email protected]>
1 parent 380e104 commit 98ea76a

File tree

10 files changed, +405 -286 lines changed

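The effect of the new rule is visible in the agg.yaml plan diffs below: wherever a StreamProject sits directly on top of another StreamProject, the two are collapsed into a single project by substituting the inner projection's expressions into the outer one. The sketch below is a minimal, self-contained model of that substitution, not the actual StreamProjectMergeRule implementation; `Expr`, `inline`, and `merge_projects` are illustration-only names rather than the real risingwave_frontend optimizer API.

```rust
// Illustration only: a toy expression type and a project-over-project merge.
// The real optimizer works on risingwave's ExprImpl/StreamProject nodes.
#[derive(Clone, Debug)]
enum Expr {
    /// Reference to the i-th output column of the operator's input.
    InputRef(usize),
    /// A named scalar function over sub-expressions, e.g. Multiply, Vnode.
    Call(String, Vec<Expr>),
}

/// Rewrite `expr` so every InputRef(i) is replaced by the inner projection's
/// i-th expression.
fn inline(expr: &Expr, inner: &[Expr]) -> Expr {
    match expr {
        Expr::InputRef(i) => inner[*i].clone(),
        Expr::Call(name, args) => Expr::Call(
            name.clone(),
            args.iter().map(|a| inline(a, inner)).collect(),
        ),
    }
}

/// Merge Project(outer) over Project(inner) into one projection list.
fn merge_projects(outer: &[Expr], inner: &[Expr]) -> Vec<Expr> {
    outer.iter().map(|e| inline(e, inner)).collect()
}

fn main() {
    // inner project: [t.a, t.b, t.a * t.b, t._row_id]
    let inner = vec![
        Expr::InputRef(0),
        Expr::InputRef(1),
        Expr::Call("Multiply".into(), vec![Expr::InputRef(0), Expr::InputRef(1)]),
        Expr::InputRef(2),
    ];
    // outer project: [#0, #1, #2, #3, Vnode(#3)]
    let outer = vec![
        Expr::InputRef(0),
        Expr::InputRef(1),
        Expr::InputRef(2),
        Expr::InputRef(3),
        Expr::Call("Vnode".into(), vec![Expr::InputRef(3)]),
    ];
    // merged: [t.a, t.b, t.a * t.b, t._row_id, Vnode(t._row_id)],
    // mirroring the first agg.yaml hunk below.
    println!("{:?}", merge_projects(&outer, &inner));
}
```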

src/frontend/planner_test/src/lib.rs

Lines changed: 38 additions & 3 deletions
@@ -19,7 +19,7 @@
 mod resolve_id;
 
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::path::Path;
 use std::sync::Arc;
 
@@ -33,7 +33,7 @@ use risingwave_frontend::session::SessionImpl;
 use risingwave_frontend::test_utils::{create_proto_file, get_explain_output, LocalFrontend};
 use risingwave_frontend::{
     build_graph, explain_stream_graph, Binder, Explain, FrontendOpts, OptimizerContext,
-    OptimizerContextRef, PlanRef, Planner,
+    OptimizerContextRef, PlanRef, Planner, WithOptions,
 };
 use risingwave_sqlparser::ast::{ExplainOptions, ObjectName, Statement};
 use risingwave_sqlparser::parser::Parser;
@@ -83,6 +83,10 @@ pub struct TestCase {
     /// Batch plan for local execution `.gen_batch_local_plan()`
     pub batch_local_plan: Option<String>,
 
+    /// Create sink plan (assumes blackhole sink)
+    /// TODO: Other sinks
+    pub sink_plan: Option<String>,
+
     /// Create MV plan `.gen_create_mv_plan()`
     pub stream_plan: Option<String>,
 
@@ -152,6 +156,9 @@ pub struct TestCaseResult {
     /// Batch plan for local execution `.gen_batch_local_plan()`
     pub batch_local_plan: Option<String>,
 
+    /// Generate sink plan
+    pub sink_plan: Option<String>,
+
     /// Create MV plan `.gen_create_mv_plan()`
     pub stream_plan: Option<String>,
 
@@ -176,6 +183,9 @@ pub struct TestCaseResult {
     /// Error of `.gen_stream_plan()`
     pub stream_error: Option<String>,
 
+    /// Error of `.gen_sink_plan()`
+    pub sink_error: Option<String>,
+
     /// The result of an `EXPLAIN` statement.
     ///
     /// This field is used when `sql` is an `EXPLAIN` statement.
@@ -209,6 +219,7 @@ impl TestCaseResult {
             batch_plan: self.batch_plan,
             batch_local_plan: self.batch_local_plan,
             stream_plan: self.stream_plan,
+            sink_plan: self.sink_plan,
             batch_plan_proto: self.batch_plan_proto,
             planner_error: self.planner_error,
             optimizer_error: self.optimizer_error,
@@ -640,6 +651,30 @@ impl TestCase {
             }
         }
 
+        'sink: {
+            if self.sink_plan.is_some() {
+                let sink_name = "sink_test";
+                let mut options = HashMap::new();
+                options.insert("connector".to_string(), "blackhole".to_string());
+                options.insert("type".to_string(), "append-only".to_string());
+                let options = WithOptions::new(options);
+                match logical_plan.gen_sink_plan(
+                    sink_name.to_string(),
+                    format!("CREATE SINK {sink_name} AS {}", stmt),
+                    options,
+                ) {
+                    Ok(sink_plan) => {
+                        ret.sink_plan = Some(explain_plan(&sink_plan.into()));
+                        break 'sink;
+                    }
+                    Err(err) => {
+                        ret.sink_error = Some(err.to_string());
+                        break 'sink;
+                    }
+                }
+            }
+        }
+
         Ok(ret)
     }
 }
@@ -696,7 +731,7 @@ fn check_result(expected: &TestCase, actual: &TestCaseResult) -> Result<()> {
         &expected.explain_output,
         &actual.explain_output,
     )?;
-
+    check_option_plan_eq("sink_plan", &expected.sink_plan, &actual.sink_plan)?;
     Ok(())
 }

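For context on how this harness change is exercised: a test case that sets `sink_plan` now takes the branch above, which always plans a blackhole, append-only sink and records either the explained plan or the error string. The following is a self-contained sketch of that control flow; `TestCase`, `TestCaseResult`, and `try_gen_sink_plan` are hypothetical stand-ins rather than the real risingwave_frontend types, and only the option values and the plan-or-error bookkeeping mirror the diff.

```rust
use std::collections::HashMap;

// Stand-in types for illustration only; not the real planner_test structs.
struct TestCase {
    sink_plan: Option<String>, // expected sink plan, if the case opts in
}

#[derive(Debug, Default)]
struct TestCaseResult {
    sink_plan: Option<String>,
    sink_error: Option<String>,
}

// Stand-in for plan generation: succeeds only for the blackhole connector.
fn try_gen_sink_plan(name: &str, options: &HashMap<String, String>) -> Result<String, String> {
    match options.get("connector").map(String::as_str) {
        Some("blackhole") => Ok(format!("StreamSink {{ name: {name} }}")),
        other => Err(format!("unsupported connector: {other:?}")),
    }
}

fn run_sink_check(case: &TestCase, ret: &mut TestCaseResult) {
    'sink: {
        if case.sink_plan.is_some() {
            // Mirrors the harness: fixed blackhole, append-only sink options.
            let mut options = HashMap::new();
            options.insert("connector".to_string(), "blackhole".to_string());
            options.insert("type".to_string(), "append-only".to_string());
            match try_gen_sink_plan("sink_test", &options) {
                Ok(plan) => {
                    ret.sink_plan = Some(plan);
                    break 'sink;
                }
                Err(err) => {
                    ret.sink_error = Some(err);
                    break 'sink;
                }
            }
        }
    }
}

fn main() {
    let case = TestCase { sink_plan: Some(String::new()) };
    let mut ret = TestCaseResult::default();
    run_sink_check(&case, &mut ret);
    println!("{ret:?}");
}
```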
src/frontend/planner_test/tests/testdata/agg.yaml

Lines changed: 11 additions & 14 deletions
@@ -577,9 +577,8 @@
     └─StreamGlobalSimpleAgg { aggs: [max(max($expr1) filter((t.a < t.b) AND ((t.a + t.b) < 100:Int32) AND ((t.a * t.b) <> ((t.a + t.b) - 1:Int32)))), count] }
       └─StreamExchange { dist: Single }
         └─StreamHashAgg { group_key: [$expr2], aggs: [max($expr1) filter((t.a < t.b) AND ((t.a + t.b) < 100:Int32) AND ((t.a * t.b) <> ((t.a + t.b) - 1:Int32))), count] }
-          └─StreamProject { exprs: [t.a, t.b, $expr1, t._row_id, Vnode(t._row_id) as $expr2] }
-            └─StreamProject { exprs: [t.a, t.b, (t.a * t.b) as $expr1, t._row_id] }
-              └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
+          └─StreamProject { exprs: [t.a, t.b, (t.a * t.b) as $expr1, t._row_id, Vnode(t._row_id) as $expr2] }
+            └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
 - name: avg filter clause + group by
   sql: |
     create table t(a int, b int);
@@ -1139,9 +1138,8 @@
     └─StreamExchange { dist: HashShard(lineitem.l_commitdate) }
       └─StreamHashAgg { group_key: [lineitem.l_commitdate, $expr1], aggs: [max(lineitem.l_commitdate), count] }
         └─StreamProject { exprs: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate, Vnode(lineitem.l_orderkey) as $expr1] }
-          └─StreamProject { exprs: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate] }
-            └─StreamHashAgg { group_key: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate], aggs: [count] }
-              └─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_tax, lineitem.l_commitdate, lineitem.l_shipinstruct], pk: [lineitem.l_orderkey], dist: UpstreamHashShard(lineitem.l_orderkey) }
+          └─StreamHashAgg { group_key: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate], aggs: [count] }
+            └─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_tax, lineitem.l_commitdate, lineitem.l_shipinstruct], pk: [lineitem.l_orderkey], dist: UpstreamHashShard(lineitem.l_orderkey) }
 - name: two phase agg on hop window input should use two phase agg
   sql: |
     SET QUERY_MODE TO DISTRIBUTED;
@@ -1180,14 +1178,13 @@
     └─StreamExchange { dist: HashShard(window_start) }
       └─StreamHashAgg { group_key: [window_start, $expr2], aggs: [max(sum0(count)), count] }
         └─StreamProject { exprs: [bid.auction, window_start, sum0(count), Vnode(bid.auction, window_start) as $expr2] }
-          └─StreamProject { exprs: [bid.auction, window_start, sum0(count)] }
-            └─StreamHashAgg { group_key: [bid.auction, window_start], aggs: [sum0(count), count] }
-              └─StreamExchange { dist: HashShard(bid.auction, window_start) }
-                └─StreamHashAgg { group_key: [bid.auction, window_start, $expr1], aggs: [count] }
-                  └─StreamProject { exprs: [bid.auction, window_start, bid._row_id, Vnode(bid._row_id) as $expr1] }
-                    └─StreamHopWindow { time_col: bid.date_time, slide: 00:00:02, size: 00:00:10, output: [bid.auction, window_start, bid._row_id] }
-                      └─StreamFilter { predicate: IsNotNull(bid.date_time) }
-                        └─StreamTableScan { table: bid, columns: [bid.date_time, bid.auction, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
+          └─StreamHashAgg { group_key: [bid.auction, window_start], aggs: [sum0(count), count] }
+            └─StreamExchange { dist: HashShard(bid.auction, window_start) }
+              └─StreamHashAgg { group_key: [bid.auction, window_start, $expr1], aggs: [count] }
+                └─StreamProject { exprs: [bid.auction, window_start, bid._row_id, Vnode(bid._row_id) as $expr1] }
+                  └─StreamHopWindow { time_col: bid.date_time, slide: 00:00:02, size: 00:00:10, output: [bid.auction, window_start, bid._row_id] }
+                    └─StreamFilter { predicate: IsNotNull(bid.date_time) }
+                      └─StreamTableScan { table: bid, columns: [bid.date_time, bid.auction, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
 - name: two phase agg with stream SomeShard (via index) but pk satisfies output dist should use shuffle agg
   sql: |
     SET QUERY_MODE TO DISTRIBUTED;
