Skip to content

Commit 79b499c

Browse files
authored
refactor(optimizer): move some methods into core struct && refactor the join's predicate push down (risingwavelabs#8455)
1 parent 64d80d2 commit 79b499c

File tree

6 files changed

+302
-209
lines changed

6 files changed

+302
-209
lines changed

src/frontend/src/optimizer/plan_node/generic/hop_window.rs

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
// limitations under the License.
1414

1515
use std::fmt;
16+
use std::num::NonZeroUsize;
1617

1718
use itertools::Itertools;
1819
use risingwave_common::catalog::{Field, Schema};
20+
use risingwave_common::error::Result;
1921
use risingwave_common::types::{DataType, IntervalUnit, IntervalUnitDisplay};
22+
use risingwave_common::util::column_index_mapping::ColIndexMapping;
23+
use risingwave_expr::ExprError;
2024

2125
use super::super::utils::IndicesDisplay;
2226
use super::{GenericPlanNode, GenericPlanRef};
23-
use crate::expr::{InputRef, InputRefDisplay};
27+
use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, InputRefDisplay, Literal};
2428
use crate::optimizer::optimizer_context::OptimizerContextRef;
2529

2630
/// [`HopWindow`] implements Hop Table Function.
@@ -104,6 +108,141 @@ impl<PlanRef: GenericPlanRef> HopWindow<PlanRef> {
104108
)
105109
}
106110

111+
pub fn internal_window_start_col_idx(&self) -> usize {
112+
self.input.schema().len()
113+
}
114+
115+
pub fn internal_window_end_col_idx(&self) -> usize {
116+
self.internal_window_start_col_idx() + 1
117+
}
118+
119+
/// Output-to-input column mapping.
///
/// Obtained by composing the output-to-internal mapping with the internal-to-input mapping.
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
    let output_to_internal = self.output2internal_col_mapping();
    let internal_to_input = self.internal2input_col_mapping();
    output_to_internal.composite(&internal_to_input)
}
123+
124+
/// Input-to-output column mapping.
///
/// Obtained by composing the input-to-internal mapping with the internal-to-output mapping.
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
    let input_to_internal = self.input2internal_col_mapping();
    let internal_to_output = self.internal2output_col_mapping();
    input_to_internal.composite(&internal_to_output)
}
128+
129+
pub fn internal_column_num(&self) -> usize {
130+
self.internal_window_start_col_idx() + 2
131+
}
132+
133+
/// Output-to-internal column mapping, computed as the inverse of
/// [`Self::internal2output_col_mapping`].
pub fn output2internal_col_mapping(&self) -> ColIndexMapping {
    let internal_to_output = self.internal2output_col_mapping();
    internal_to_output.inverse()
}
136+
137+
pub fn internal2output_col_mapping(&self) -> ColIndexMapping {
138+
ColIndexMapping::with_remaining_columns(&self.output_indices, self.internal_column_num())
139+
}
140+
141+
pub fn input2internal_col_mapping(&self) -> ColIndexMapping {
142+
ColIndexMapping::identity_or_none(
143+
self.internal_window_start_col_idx(),
144+
self.internal_column_num(),
145+
)
146+
}
147+
148+
pub fn internal2input_col_mapping(&self) -> ColIndexMapping {
149+
ColIndexMapping::identity_or_none(
150+
self.internal_column_num(),
151+
self.internal_window_start_col_idx(),
152+
)
153+
}
154+
155+
pub fn derive_window_start_and_end_exprs(&self) -> Result<(Vec<ExprImpl>, Vec<ExprImpl>)> {
156+
let Self {
157+
window_size,
158+
window_slide,
159+
time_col,
160+
..
161+
} = &self;
162+
let units = window_size
163+
.exact_div(window_slide)
164+
.and_then(|x| NonZeroUsize::new(usize::try_from(x).ok()?))
165+
.ok_or_else(|| ExprError::InvalidParam {
166+
name: "window",
167+
reason: format!(
168+
"window_size {} cannot be divided by window_slide {}",
169+
window_size, window_slide
170+
),
171+
})?
172+
.get();
173+
let window_size_expr = Literal::new(Some((*window_size).into()), DataType::Interval).into();
174+
let window_slide_expr: ExprImpl =
175+
Literal::new(Some((*window_slide).into()), DataType::Interval).into();
176+
let window_size_sub_slide = FunctionCall::new(
177+
ExprType::Subtract,
178+
vec![window_size_expr, window_slide_expr.clone()],
179+
)?
180+
.into();
181+
182+
let time_col_shifted = FunctionCall::new(
183+
ExprType::Subtract,
184+
vec![
185+
ExprImpl::InputRef(Box::new(time_col.clone())),
186+
window_size_sub_slide,
187+
],
188+
)?
189+
.into();
190+
191+
let hop_start: ExprImpl = FunctionCall::new(
192+
ExprType::TumbleStart,
193+
vec![time_col_shifted, window_slide_expr],
194+
)?
195+
.into();
196+
197+
let mut window_start_exprs = Vec::with_capacity(units);
198+
let mut window_end_exprs = Vec::with_capacity(units);
199+
for i in 0..units {
200+
{
201+
let window_start_offset =
202+
window_slide
203+
.checked_mul_int(i)
204+
.ok_or_else(|| ExprError::InvalidParam {
205+
name: "window",
206+
reason: format!(
207+
"window_slide {} cannot be multiplied by {}",
208+
window_slide, i
209+
),
210+
})?;
211+
let window_start_offset_expr =
212+
Literal::new(Some(window_start_offset.into()), DataType::Interval).into();
213+
let window_start_expr = FunctionCall::new(
214+
ExprType::Add,
215+
vec![hop_start.clone(), window_start_offset_expr],
216+
)?
217+
.into();
218+
window_start_exprs.push(window_start_expr);
219+
}
220+
{
221+
let window_end_offset =
222+
window_slide.checked_mul_int(i + units).ok_or_else(|| {
223+
ExprError::InvalidParam {
224+
name: "window",
225+
reason: format!(
226+
"window_slide {} cannot be multiplied by {}",
227+
window_slide,
228+
i + units
229+
),
230+
}
231+
})?;
232+
let window_end_offset_expr =
233+
Literal::new(Some(window_end_offset.into()), DataType::Interval).into();
234+
let window_end_expr = FunctionCall::new(
235+
ExprType::Add,
236+
vec![hop_start.clone(), window_end_offset_expr],
237+
)?
238+
.into();
239+
window_end_exprs.push(window_end_expr);
240+
}
241+
}
242+
assert_eq!(window_start_exprs.len(), window_end_exprs.len());
243+
Ok((window_start_exprs, window_end_exprs))
244+
}
245+
107246
pub fn fmt_fields_with_builder(&self, builder: &mut fmt::DebugStruct<'_, '_>) {
108247
let output_type = DataType::window_of(&self.time_col.data_type).unwrap();
109248
builder.field(

src/frontend/src/optimizer/plan_node/generic/join.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,132 @@ impl<PlanRef: GenericPlanRef> Join<PlanRef> {
252252
}
253253
}
254254
}
255+
256+
/// Try to split and pushdown `predicate` into a into a join condition and into the inputs of the
257+
/// join. Returns the pushed predicates. The pushed part will be removed from the original
258+
/// predicate.
259+
///
260+
/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
261+
262+
pub fn push_down_into_join(
263+
predicate: &mut Condition,
264+
left_col_num: usize,
265+
right_col_num: usize,
266+
ty: JoinType,
267+
) -> (Condition, Condition, Condition) {
268+
let (left, right) = push_down_to_inputs(
269+
predicate,
270+
left_col_num,
271+
right_col_num,
272+
can_push_left_from_filter(ty),
273+
can_push_right_from_filter(ty),
274+
);
275+
276+
let on = if can_push_on_from_filter(ty) {
277+
let mut conjunctions = std::mem::take(&mut predicate.conjunctions);
278+
279+
// Do not push now on to the on, it will be pulled up into a filter instead.
280+
let on = Condition {
281+
conjunctions: conjunctions
282+
.drain_filter(|expr| expr.count_nows() == 0)
283+
.collect(),
284+
};
285+
predicate.conjunctions = conjunctions;
286+
on
287+
} else {
288+
Condition::true_cond()
289+
};
290+
(left, right, on)
291+
}
292+
293+
/// Try to pushes parts of the join condition to its inputs. Returns the pushed predicates. The
294+
/// pushed part will be removed from the original join predicate.
295+
///
296+
/// `InputRef`s in the right pushed condition are indexed by the right child's output schema.
297+
298+
pub fn push_down_join_condition(
299+
on_condition: &mut Condition,
300+
left_col_num: usize,
301+
right_col_num: usize,
302+
ty: JoinType,
303+
) -> (Condition, Condition) {
304+
push_down_to_inputs(
305+
on_condition,
306+
left_col_num,
307+
right_col_num,
308+
can_push_left_from_on(ty),
309+
can_push_right_from_on(ty),
310+
)
311+
}
312+
313+
/// Try to split and pushdown `predicate` into a join's left/right child.
314+
/// Returns the pushed predicates. The pushed part will be removed from the original predicate.
315+
///
316+
/// `InputRef`s in the right `Condition` are shifted by `-left_col_num`.
317+
fn push_down_to_inputs(
318+
predicate: &mut Condition,
319+
left_col_num: usize,
320+
right_col_num: usize,
321+
push_left: bool,
322+
push_right: bool,
323+
) -> (Condition, Condition) {
324+
let conjunctions = std::mem::take(&mut predicate.conjunctions);
325+
326+
let (mut left, right, mut others) =
327+
Condition { conjunctions }.split(left_col_num, right_col_num);
328+
329+
if !push_left {
330+
others.conjunctions.extend(left);
331+
left = Condition::true_cond();
332+
};
333+
334+
let right = if push_right {
335+
let mut mapping = ColIndexMapping::with_shift_offset(
336+
left_col_num + right_col_num,
337+
-(left_col_num as isize),
338+
);
339+
right.rewrite_expr(&mut mapping)
340+
} else {
341+
others.conjunctions.extend(right);
342+
Condition::true_cond()
343+
};
344+
345+
predicate.conjunctions = others.conjunctions;
346+
347+
(left, right)
348+
}
349+
350+
pub fn can_push_left_from_filter(ty: JoinType) -> bool {
351+
matches!(
352+
ty,
353+
JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti
354+
)
355+
}
356+
357+
pub fn can_push_right_from_filter(ty: JoinType) -> bool {
358+
matches!(
359+
ty,
360+
JoinType::Inner | JoinType::RightOuter | JoinType::RightSemi | JoinType::RightAnti
361+
)
362+
}
363+
364+
pub fn can_push_on_from_filter(ty: JoinType) -> bool {
365+
matches!(
366+
ty,
367+
JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi
368+
)
369+
}
370+
371+
pub fn can_push_left_from_on(ty: JoinType) -> bool {
372+
matches!(
373+
ty,
374+
JoinType::Inner | JoinType::RightOuter | JoinType::LeftSemi
375+
)
376+
}
377+
378+
pub fn can_push_right_from_on(ty: JoinType) -> bool {
379+
matches!(
380+
ty,
381+
JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi
382+
)
383+
}

src/frontend/src/optimizer/plan_node/logical_apply.rs

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use risingwave_common::catalog::Schema;
1818
use risingwave_common::error::{ErrorCode, Result, RwError};
1919
use risingwave_pb::plan_common::JoinType;
2020

21-
use super::generic::{self, GenericPlanNode};
21+
use super::generic::{self, push_down_into_join, push_down_join_condition, GenericPlanNode};
2222
use super::{
2323
ColPrunable, LogicalJoin, LogicalProject, PlanBase, PlanRef, PlanTreeNodeBinary,
2424
PredicatePushdown, ToBatch, ToStream,
@@ -318,28 +318,12 @@ impl PredicatePushdown for LogicalApply {
318318
let right_col_num = self.right().schema().len();
319319
let join_type = self.join_type();
320320

321-
let (left_from_filter, right_from_filter, on) = LogicalJoin::push_down(
322-
&mut predicate,
323-
left_col_num,
324-
right_col_num,
325-
LogicalJoin::can_push_left_from_filter(join_type),
326-
LogicalJoin::can_push_right_from_filter(join_type),
327-
LogicalJoin::can_push_on_from_filter(join_type),
328-
);
321+
let (left_from_filter, right_from_filter, on) =
322+
push_down_into_join(&mut predicate, left_col_num, right_col_num, join_type);
329323

330324
let mut new_on = self.on.clone().and(on);
331-
let (left_from_on, right_from_on, on) = LogicalJoin::push_down(
332-
&mut new_on,
333-
left_col_num,
334-
right_col_num,
335-
LogicalJoin::can_push_left_from_on(join_type),
336-
LogicalJoin::can_push_right_from_on(join_type),
337-
false,
338-
);
339-
assert!(
340-
on.always_true(),
341-
"On-clause should not be pushed to on-clause."
342-
);
325+
let (left_from_on, right_from_on) =
326+
push_down_join_condition(&mut new_on, left_col_num, right_col_num, join_type);
343327

344328
let left_predicate = left_from_filter.and(left_from_on);
345329
let right_predicate = right_from_filter.and(right_from_on);

0 commit comments

Comments
 (0)