Skip to content

Commit 53261c5

Browse files
authored
feat: Bushy tree join ordering (risingwavelabs#8316)
Signed-off-by: Kevin Axel <[email protected]>
1 parent a3dc882 commit 53261c5

File tree

8 files changed

+1736
-14
lines changed

8 files changed

+1736
-14
lines changed

src/common/src/session_config/mod.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::util::epoch::Epoch;
3434

3535
// This is a hack, &'static str is not allowed as a const generics argument.
3636
// TODO: refine this using the adt_const_params feature.
37-
const CONFIG_KEYS: [&str; 21] = [
37+
const CONFIG_KEYS: [&str; 22] = [
3838
"RW_IMPLICIT_FLUSH",
3939
"CREATE_COMPACTION_GROUP_FOR_MV",
4040
"QUERY_MODE",
@@ -56,6 +56,7 @@ const CONFIG_KEYS: [&str; 21] = [
5656
"RW_ENABLE_SHARE_PLAN",
5757
"INTERVALSTYLE",
5858
"BATCH_PARALLELISM",
59+
"RW_STREAMING_ENABLE_BUSHY_JOIN",
5960
];
6061

6162
// MUST HAVE 1v1 relationship to CONFIG_KEYS. e.g. CONFIG_KEYS[IMPLICIT_FLUSH] =
@@ -81,6 +82,7 @@ const FORCE_TWO_PHASE_AGG: usize = 17;
8182
const RW_ENABLE_SHARE_PLAN: usize = 18;
8283
const INTERVAL_STYLE: usize = 19;
8384
const BATCH_PARALLELISM: usize = 20;
85+
const STREAMING_ENABLE_BUSHY_JOIN: usize = 21;
8486

8587
trait ConfigEntry: Default + for<'a> TryFrom<&'a [&'a str], Error = RwError> {
8688
fn entry_name() -> &'static str;
@@ -277,6 +279,7 @@ type QueryEpoch = ConfigU64<QUERY_EPOCH, 0>;
277279
type Timezone = ConfigString<TIMEZONE>;
278280
type StreamingParallelism = ConfigU64<STREAMING_PARALLELISM, 0>;
279281
type StreamingEnableDeltaJoin = ConfigBool<STREAMING_ENABLE_DELTA_JOIN, false>;
282+
type StreamingEnableBushyJoin = ConfigBool<STREAMING_ENABLE_BUSHY_JOIN, false>;
280283
type EnableTwoPhaseAgg = ConfigBool<ENABLE_TWO_PHASE_AGG, true>;
281284
type ForceTwoPhaseAgg = ConfigBool<FORCE_TWO_PHASE_AGG, false>;
282285
type EnableSharePlan = ConfigBool<RW_ENABLE_SHARE_PLAN, true>;
@@ -342,6 +345,9 @@ pub struct ConfigMap {
342345
/// Enable delta join in streaming query. Defaults to false.
343346
streaming_enable_delta_join: StreamingEnableDeltaJoin,
344347

348+
/// Enable bushy join in the streaming query. Defaults to false.
349+
streaming_enable_bushy_join: StreamingEnableBushyJoin,
350+
345351
/// Enable two phase agg optimization. Defaults to true.
346352
/// Setting this to true will always set `FORCE_TWO_PHASE_AGG` to false.
347353
enable_two_phase_agg: EnableTwoPhaseAgg,
@@ -402,6 +408,8 @@ impl ConfigMap {
402408
self.streaming_parallelism = val.as_slice().try_into()?;
403409
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
404410
self.streaming_enable_delta_join = val.as_slice().try_into()?;
411+
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
412+
self.streaming_enable_bushy_join = val.as_slice().try_into()?;
405413
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
406414
self.enable_two_phase_agg = val.as_slice().try_into()?;
407415
if !*self.enable_two_phase_agg {
@@ -458,6 +466,8 @@ impl ConfigMap {
458466
Ok(self.streaming_parallelism.to_string())
459467
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
460468
Ok(self.streaming_enable_delta_join.to_string())
469+
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
470+
Ok(self.streaming_enable_bushy_join.to_string())
461471
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
462472
Ok(self.enable_two_phase_agg.to_string())
463473
} else if key.eq_ignore_ascii_case(ForceTwoPhaseAgg::entry_name()) {
@@ -550,6 +560,11 @@ impl ConfigMap {
550560
setting : self.streaming_enable_delta_join.to_string(),
551561
description: String::from("Enable delta join in streaming query.")
552562
},
563+
VariableInfo{
564+
name : StreamingEnableBushyJoin::entry_name().to_lowercase(),
565+
setting : self.streaming_enable_bushy_join.to_string(),
566+
description: String::from("Enable bushy join in streaming query.")
567+
},
553568
VariableInfo{
554569
name : EnableTwoPhaseAgg::entry_name().to_lowercase(),
555570
setting : self.enable_two_phase_agg.to_string(),
@@ -648,6 +663,10 @@ impl ConfigMap {
648663
*self.streaming_enable_delta_join
649664
}
650665

666+
pub fn get_streaming_enable_bushy_join(&self) -> bool {
667+
*self.streaming_enable_bushy_join
668+
}
669+
651670
pub fn get_enable_two_phase_agg(&self) -> bool {
652671
*self.enable_two_phase_agg
653672
}

src/frontend/planner_test/tests/testdata/bushy_join.yaml

Lines changed: 1408 additions & 0 deletions
Large diffs are not rendered by default.

src/frontend/src/optimizer/logical_optimization.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,15 @@ lazy_static! {
159159
ApplyOrder::TopDown,
160160
);
161161

162-
static ref JOIN_REORDER: OptimizationStage = OptimizationStage::new(
162+
static ref LEFT_DEEP_JOIN_REORDER: OptimizationStage = OptimizationStage::new(
163163
"Join Reorder".to_string(),
164-
vec![ReorderMultiJoinRule::create()],
164+
vec![LeftDeepTreeJoinOrderingRule::create()],
165+
ApplyOrder::TopDown,
166+
);
167+
168+
static ref BUSHY_TREE_JOIN_REORDER: OptimizationStage = OptimizationStage::new(
169+
"Bushy tree join ordering Rule".to_string(),
170+
vec![BushyTreeJoinOrderingRule::create()],
165171
ApplyOrder::TopDown,
166172
);
167173

@@ -365,9 +371,17 @@ impl LogicalOptimizer {
365371
// their relevant joins.
366372
plan = plan.optimize_by_rules(&TO_MULTI_JOIN);
367373

368-
// Reorder multijoin into left-deep join tree.
369-
plan = plan.optimize_by_rules(&JOIN_REORDER);
370-
374+
// Reorder multijoin into join tree.
375+
if plan
376+
.ctx()
377+
.session_ctx()
378+
.config()
379+
.get_streaming_enable_bushy_join()
380+
{
381+
plan = plan.optimize_by_rules(&BUSHY_TREE_JOIN_REORDER);
382+
} else {
383+
plan = plan.optimize_by_rules(&LEFT_DEEP_JOIN_REORDER);
384+
}
371385
// Predicate Push-down: apply filter pushdown rules again since we pullup all join
372386
// conditions into a filter above the multijoin.
373387
plan = Self::predicate_pushdown(plan, explain_trace, &ctx);
@@ -438,7 +452,7 @@ impl LogicalOptimizer {
438452
plan = plan.optimize_by_rules(&TO_MULTI_JOIN);
439453

440454
// Reorder multijoin into left-deep join tree.
441-
plan = plan.optimize_by_rules(&JOIN_REORDER);
455+
plan = plan.optimize_by_rules(&LEFT_DEEP_JOIN_REORDER);
442456

443457
// Predicate Push-down: apply filter pushdown rules again since we pullup all join
444458
// conditions into a filter above the multijoin.

src/frontend/src/optimizer/plan_node/logical_multi_join.rs

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::cmp::Ordering;
16+
use std::collections::{BTreeMap, BTreeSet, VecDeque};
1517
use std::fmt;
1618

1719
use itertools::Itertools;
@@ -483,6 +485,243 @@ impl LogicalMultiJoin {
483485
Ok(join_ordering)
484486
}
485487

488+
pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
489+
// Join tree internal representation
490+
#[derive(Clone, Default, Debug)]
491+
struct JoinTreeNode {
492+
idx: Option<usize>,
493+
left: Option<Box<JoinTreeNode>>,
494+
right: Option<Box<JoinTreeNode>>,
495+
height: usize,
496+
}
497+
498+
// join graph internal representation
499+
#[derive(Clone, Debug)]
500+
struct GraphNode {
501+
id: usize,
502+
join_tree: JoinTreeNode,
503+
// use BTreeSet for deterministic
504+
relations: BTreeSet<usize>,
505+
}
506+
507+
let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
508+
.map(|idx| GraphNode {
509+
id: idx,
510+
relations: BTreeSet::new(),
511+
join_tree: JoinTreeNode {
512+
idx: Some(idx),
513+
left: None,
514+
right: None,
515+
height: 0,
516+
},
517+
})
518+
.enumerate()
519+
.collect();
520+
let (eq_join_conditions, _) = self
521+
.on
522+
.clone()
523+
.split_by_input_col_nums(&self.input_col_nums(), true);
524+
525+
for ((src, dst), _) in eq_join_conditions {
526+
nodes.get_mut(&src).unwrap().relations.insert(dst);
527+
nodes.get_mut(&dst).unwrap().relations.insert(src);
528+
}
529+
530+
// isolated nodes can be joined at any where.
531+
let iso_nodes = nodes
532+
.iter()
533+
.filter_map(|n| {
534+
if n.1.relations.is_empty() {
535+
Some(*n.0)
536+
} else {
537+
None
538+
}
539+
})
540+
.collect_vec();
541+
542+
for n in iso_nodes {
543+
for adj in 0..nodes.len() {
544+
if adj != n {
545+
nodes.get_mut(&n).unwrap().relations.insert(adj);
546+
nodes.get_mut(&adj).unwrap().relations.insert(n);
547+
}
548+
}
549+
}
550+
551+
let mut optimized_bushy_tree = None;
552+
let mut que = VecDeque::from([nodes]);
553+
let mut isolated = BTreeSet::new();
554+
555+
while let Some(mut nodes) = que.pop_front() {
556+
if nodes.len() == 1 {
557+
let node = nodes.into_values().next().unwrap();
558+
optimized_bushy_tree = Some(optimized_bushy_tree.map_or(
559+
node.clone(),
560+
|old_tree: GraphNode| {
561+
if node.join_tree.height < old_tree.join_tree.height {
562+
node
563+
} else {
564+
old_tree
565+
}
566+
},
567+
));
568+
continue;
569+
}
570+
571+
let (idx, _) = nodes
572+
.iter()
573+
.min_by(
574+
|(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
575+
Ordering::Less => Ordering::Less,
576+
Ordering::Greater => Ordering::Greater,
577+
Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
578+
},
579+
)
580+
.unwrap();
581+
let n = nodes.remove(&idx.clone()).unwrap();
582+
583+
if n.relations.is_empty() {
584+
isolated.insert(n.id);
585+
que.push_back(nodes);
586+
continue;
587+
}
588+
589+
for merge_node in &n.relations {
590+
let mut nodes = nodes.clone();
591+
for adjacent_node in &n.relations {
592+
if *adjacent_node != *merge_node {
593+
nodes
594+
.get_mut(adjacent_node)
595+
.unwrap()
596+
.relations
597+
.remove(&n.id);
598+
nodes
599+
.get_mut(adjacent_node)
600+
.unwrap()
601+
.relations
602+
.insert(*merge_node);
603+
nodes
604+
.get_mut(merge_node)
605+
.unwrap()
606+
.relations
607+
.insert(*adjacent_node);
608+
}
609+
}
610+
let mut merge_graph_node = nodes.get_mut(merge_node).unwrap();
611+
merge_graph_node.relations.remove(&n.id);
612+
let l_tree = n.join_tree.clone();
613+
let r_tree = std::mem::take(&mut merge_graph_node.join_tree);
614+
let new_height = usize::max(l_tree.height, r_tree.height) + 1;
615+
616+
if let Some(min_height) = optimized_bushy_tree.as_ref().map(|t| t.join_tree.height) && min_height < new_height {
617+
continue;
618+
}
619+
620+
merge_graph_node.join_tree = JoinTreeNode {
621+
idx: None,
622+
left: Some(Box::new(l_tree)),
623+
right: Some(Box::new(r_tree)),
624+
height: new_height,
625+
};
626+
que.push_back(nodes);
627+
}
628+
}
629+
630+
fn create_logical_join(
631+
s: &LogicalMultiJoin,
632+
mut join_tree: JoinTreeNode,
633+
join_ordering: &mut Vec<usize>,
634+
) -> Result<PlanRef> {
635+
Ok(match (join_tree.left.take(), join_tree.right.take()) {
636+
(Some(l), Some(r)) => LogicalJoin::new(
637+
create_logical_join(s, *l, join_ordering)?,
638+
create_logical_join(s, *r, join_ordering)?,
639+
JoinType::Inner,
640+
Condition::true_cond(),
641+
)
642+
.into(),
643+
(None, None) => {
644+
if let Some(idx) = join_tree.idx {
645+
join_ordering.push(idx);
646+
s.inputs[idx].clone()
647+
} else {
648+
return Err(RwError::from(ErrorCode::InternalError(
649+
"id of the leaf node not found in the join tree".into(),
650+
)));
651+
}
652+
}
653+
(_, _) => {
654+
return Err(RwError::from(ErrorCode::InternalError(
655+
"only leaf node can have None subtree".into(),
656+
)))
657+
}
658+
})
659+
}
660+
661+
let isolated = isolated.into_iter().collect_vec();
662+
let mut join_ordering = vec![];
663+
let mut output = if let Some(optimized_bushy_tree) = optimized_bushy_tree {
664+
let mut output =
665+
create_logical_join(self, optimized_bushy_tree.join_tree, &mut join_ordering)?;
666+
667+
output = isolated.into_iter().fold(output, |chain, n| {
668+
join_ordering.push(n);
669+
LogicalJoin::new(
670+
chain,
671+
self.inputs[n].clone(),
672+
JoinType::Inner,
673+
Condition::true_cond(),
674+
)
675+
.into()
676+
});
677+
output
678+
} else if !isolated.is_empty() {
679+
let base = isolated[0];
680+
join_ordering.push(isolated[0]);
681+
isolated[1..]
682+
.iter()
683+
.fold(self.inputs[base].clone(), |chain, n| {
684+
join_ordering.push(*n);
685+
LogicalJoin::new(
686+
chain,
687+
self.inputs[*n].clone(),
688+
JoinType::Inner,
689+
Condition::true_cond(),
690+
)
691+
.into()
692+
})
693+
} else {
694+
return Err(RwError::from(ErrorCode::InternalError(
695+
"no plan remain".into(),
696+
)));
697+
};
698+
let total_col_num = self.inner2output.source_size();
699+
let reorder_mapping = {
700+
let mut reorder_mapping = vec![None; total_col_num];
701+
702+
join_ordering
703+
.iter()
704+
.cloned()
705+
.flat_map(|input_idx| {
706+
(0..self.inputs[input_idx].schema().len())
707+
.map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
708+
})
709+
.enumerate()
710+
.for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
711+
reorder_mapping
712+
};
713+
output =
714+
LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
715+
.into();
716+
717+
// We will later push down all of the filters back to the individual joins via the
718+
// `FilterJoinRule`.
719+
output = LogicalFilter::create(output, self.on.clone());
720+
output =
721+
LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
722+
Ok(output)
723+
}
724+
486725
pub(crate) fn input_col_nums(&self) -> Vec<usize> {
487726
self.inputs.iter().map(|i| i.schema().len()).collect()
488727
}

0 commit comments

Comments
 (0)