Skip to content

Commit 32df9cb

Browse files
authored
feat(frontend): fearless recursion on deep plans (#16279)
Signed-off-by: Bugen Zhao <[email protected]>
1 parent 4ba80c9 commit 32df9cb

File tree

6 files changed

+448
-216
lines changed

6 files changed

+448
-216
lines changed

Cargo.lock

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/common/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ serde_json = "1"
9393
serde_with = "3"
9494
smallbitset = "0.7.1"
9595
speedate = "0.14.0"
96+
stacker = "0.1"
9697
static_assertions = "1"
9798
strum = "0.26"
9899
strum_macros = "0.26"

src/common/src/util/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub mod pretty_bytes;
3030
pub mod prost;
3131
pub mod query_log;
3232
pub use rw_resource_util as resource_util;
33+
pub mod recursive;
3334
pub mod row_id;
3435
pub mod row_serde;
3536
pub mod runtime;

src/common/src/util/recursive.rs

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// Copyright 2024 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//! Track the recursion and grow the stack when necessary to enable fearless recursion.
16+
17+
use std::cell::RefCell;
18+
19+
// See documentation of `stacker` for the meaning of these constants.
20+
// TODO: determine good values or make them configurable
21+
const RED_ZONE: usize = 128 * 1024; // 128KiB
22+
const STACK_SIZE: usize = 16 * RED_ZONE; // 2MiB
23+
24+
/// Recursion depth.
25+
struct Depth {
26+
/// The current depth.
27+
current: usize,
28+
/// The max depth reached so far, not considering the current depth.
29+
last_max: usize,
30+
}
31+
32+
impl Depth {
33+
const fn new() -> Self {
34+
Self {
35+
current: 0,
36+
last_max: 0,
37+
}
38+
}
39+
40+
fn reset(&mut self) {
41+
*self = Self::new();
42+
}
43+
}
44+
45+
/// The tracker for a recursive function.
46+
pub struct Tracker {
47+
depth: RefCell<Depth>,
48+
}
49+
50+
impl Tracker {
51+
/// Create a new tracker.
52+
pub const fn new() -> Self {
53+
Self {
54+
depth: RefCell::new(Depth::new()),
55+
}
56+
}
57+
58+
/// Retrieve the current depth of the recursion. Starts from 1 once the
59+
/// recursive function is called.
60+
pub fn depth(&self) -> usize {
61+
self.depth.borrow().current
62+
}
63+
64+
/// Check if the current depth reaches the given depth **for the first time**.
65+
///
66+
/// This is useful for logging without any duplication.
67+
pub fn depth_reaches(&self, depth: usize) -> bool {
68+
let d = self.depth.borrow();
69+
d.current == depth && d.current > d.last_max
70+
}
71+
72+
/// Run a recursive function. Grow the stack if necessary.
73+
fn recurse<T>(&self, f: impl FnOnce() -> T) -> T {
74+
struct DepthGuard<'a> {
75+
depth: &'a RefCell<Depth>,
76+
}
77+
78+
impl<'a> DepthGuard<'a> {
79+
fn new(depth: &'a RefCell<Depth>) -> Self {
80+
depth.borrow_mut().current += 1;
81+
Self { depth }
82+
}
83+
}
84+
85+
impl<'a> Drop for DepthGuard<'a> {
86+
fn drop(&mut self) {
87+
let mut d = self.depth.borrow_mut();
88+
d.last_max = d.last_max.max(d.current); // update the last max depth
89+
d.current -= 1; // restore the current depth
90+
if d.current == 0 {
91+
d.reset(); // reset state if the recursion is finished
92+
}
93+
}
94+
}
95+
96+
let _guard = DepthGuard::new(&self.depth);
97+
98+
if cfg!(madsim) {
99+
f() // madsim does not support stack growth
100+
} else {
101+
stacker::maybe_grow(RED_ZONE, STACK_SIZE, f)
102+
}
103+
}
104+
}
105+
106+
/// The extension trait for a thread-local tracker to run a recursive function.
107+
#[easy_ext::ext(Recurse)]
108+
impl std::thread::LocalKey<Tracker> {
109+
/// Run the given recursive function. Grow the stack if necessary.
110+
///
111+
/// # Fearless Recursion
112+
///
113+
/// This enables fearless recursion in most cases as long as a single frame
114+
/// does not exceed the [`RED_ZONE`] size. That is, the caller can recurse
115+
/// as much as it wants without worrying about stack overflow.
116+
///
117+
/// # Tracker
118+
///
119+
/// The caller can retrieve the [`Tracker`] of the current recursion from
120+
/// the closure argument. This can be useful for checking the depth of the
121+
/// recursion, logging or throwing an error gracefully if it's too deep.
122+
///
123+
/// Note that different trackers defined in different functions are
124+
/// independent of each other. If there's a cross-function recursion, the
125+
/// tracker retrieved from the closure argument only represents the current
126+
/// function's state.
127+
///
128+
/// # Example
129+
///
130+
/// Define the tracker with [`tracker!`] and call this method on it to run
131+
/// a recursive function.
132+
///
133+
/// ```ignore
134+
/// #[inline(never)]
135+
/// fn sum(x: u64) -> u64 {
136+
/// tracker!().recurse(|t| {
137+
/// if t.depth() % 100000 == 0 {
138+
/// eprintln!("too deep!");
139+
/// }
140+
/// if x == 0 {
141+
/// return 0;
142+
/// }
143+
/// x + sum(x - 1)
144+
/// })
145+
/// }
146+
/// ```
147+
pub fn recurse<T>(&'static self, f: impl FnOnce(&Tracker) -> T) -> T {
148+
self.with(|t| t.recurse(|| f(t)))
149+
}
150+
}
151+
152+
/// Define the tracker for recursion and return it.
153+
///
154+
/// Call [`Recurse::recurse`] on it to run a recursive function. See
155+
/// documentation there for usage.
156+
#[macro_export]
157+
macro_rules! __recursive_tracker {
158+
() => {{
159+
use $crate::util::recursive::Tracker;
160+
std::thread_local! {
161+
static __TRACKER: Tracker = const { Tracker::new() };
162+
}
163+
__TRACKER
164+
}};
165+
}
166+
pub use __recursive_tracker as tracker;
167+
168+
#[cfg(all(test, not(madsim)))]
169+
mod tests {
170+
use super::*;
171+
172+
#[test]
173+
fn test_fearless_recursion() {
174+
const X: u64 = 1919810;
175+
const EXPECTED: u64 = 1842836177955;
176+
177+
#[inline(never)]
178+
fn sum(x: u64) -> u64 {
179+
tracker!().recurse(|t| {
180+
if x == 0 {
181+
assert_eq!(t.depth(), X as usize + 1);
182+
return 0;
183+
}
184+
x + sum(x - 1)
185+
})
186+
}
187+
188+
assert_eq!(sum(X), EXPECTED);
189+
}
190+
}

src/frontend/src/optimizer/plan_node/mod.rs

+65-47
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use itertools::Itertools;
3939
use paste::paste;
4040
use pretty_xmlish::{Pretty, PrettyConfig};
4141
use risingwave_common::catalog::Schema;
42+
use risingwave_common::util::recursive::{self, Recurse};
4243
use risingwave_pb::batch_plan::PlanNode as BatchPlanPb;
4344
use risingwave_pb::stream_plan::StreamNode as StreamPlanPb;
4445
use serde::Serialize;
@@ -51,6 +52,7 @@ use self::utils::Distill;
5152
use super::property::{Distribution, FunctionalDependencySet, Order};
5253
use crate::error::{ErrorCode, Result};
5354
use crate::optimizer::ExpressionSimplifyRewriter;
55+
use crate::session::current::notice_to_user;
5456

5557
/// A marker trait for different conventions, used for enforcing type safety.
5658
///
@@ -694,6 +696,10 @@ impl dyn PlanNode {
694696
}
695697
}
696698

699+
const PLAN_DEPTH_THRESHOLD: usize = 30;
700+
const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
701+
Consider simplifying or splitting the query if you encounter any issues.";
702+
697703
impl dyn PlanNode {
698704
/// Serialize the plan node and its children to a stream plan proto.
699705
///
@@ -703,41 +709,47 @@ impl dyn PlanNode {
703709
&self,
704710
state: &mut BuildFragmentGraphState,
705711
) -> SchedulerResult<StreamPlanPb> {
706-
use stream::prelude::*;
712+
recursive::tracker!().recurse(|t| {
713+
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
714+
notice_to_user(PLAN_TOO_DEEP_NOTICE);
715+
}
707716

708-
if let Some(stream_table_scan) = self.as_stream_table_scan() {
709-
return stream_table_scan.adhoc_to_stream_prost(state);
710-
}
711-
if let Some(stream_cdc_table_scan) = self.as_stream_cdc_table_scan() {
712-
return stream_cdc_table_scan.adhoc_to_stream_prost(state);
713-
}
714-
if let Some(stream_source_scan) = self.as_stream_source_scan() {
715-
return stream_source_scan.adhoc_to_stream_prost(state);
716-
}
717-
if let Some(stream_share) = self.as_stream_share() {
718-
return stream_share.adhoc_to_stream_prost(state);
719-
}
717+
use stream::prelude::*;
720718

721-
let node = Some(self.try_to_stream_prost_body(state)?);
722-
let input = self
723-
.inputs()
724-
.into_iter()
725-
.map(|plan| plan.to_stream_prost(state))
726-
.try_collect()?;
727-
// TODO: support pk_indices and operator_id
728-
Ok(StreamPlanPb {
729-
input,
730-
identity: self.explain_myself_to_string(),
731-
node_body: node,
732-
operator_id: self.id().0 as _,
733-
stream_key: self
734-
.stream_key()
735-
.unwrap_or_default()
736-
.iter()
737-
.map(|x| *x as u32)
738-
.collect(),
739-
fields: self.schema().to_prost(),
740-
append_only: self.plan_base().append_only(),
719+
if let Some(stream_table_scan) = self.as_stream_table_scan() {
720+
return stream_table_scan.adhoc_to_stream_prost(state);
721+
}
722+
if let Some(stream_cdc_table_scan) = self.as_stream_cdc_table_scan() {
723+
return stream_cdc_table_scan.adhoc_to_stream_prost(state);
724+
}
725+
if let Some(stream_source_scan) = self.as_stream_source_scan() {
726+
return stream_source_scan.adhoc_to_stream_prost(state);
727+
}
728+
if let Some(stream_share) = self.as_stream_share() {
729+
return stream_share.adhoc_to_stream_prost(state);
730+
}
731+
732+
let node = Some(self.try_to_stream_prost_body(state)?);
733+
let input = self
734+
.inputs()
735+
.into_iter()
736+
.map(|plan| plan.to_stream_prost(state))
737+
.try_collect()?;
738+
// TODO: support pk_indices and operator_id
739+
Ok(StreamPlanPb {
740+
input,
741+
identity: self.explain_myself_to_string(),
742+
node_body: node,
743+
operator_id: self.id().0 as _,
744+
stream_key: self
745+
.stream_key()
746+
.unwrap_or_default()
747+
.iter()
748+
.map(|x| *x as u32)
749+
.collect(),
750+
fields: self.schema().to_prost(),
751+
append_only: self.plan_base().append_only(),
752+
})
741753
})
742754
}
743755

@@ -749,20 +761,26 @@ impl dyn PlanNode {
749761
/// Serialize the plan node and its children to a batch plan proto without the identity field
750762
/// (for testing).
751763
pub fn to_batch_prost_identity(&self, identity: bool) -> SchedulerResult<BatchPlanPb> {
752-
let node_body = Some(self.try_to_batch_prost_body()?);
753-
let children = self
754-
.inputs()
755-
.into_iter()
756-
.map(|plan| plan.to_batch_prost_identity(identity))
757-
.try_collect()?;
758-
Ok(BatchPlanPb {
759-
children,
760-
identity: if identity {
761-
self.explain_myself_to_string()
762-
} else {
763-
"".into()
764-
},
765-
node_body,
764+
recursive::tracker!().recurse(|t| {
765+
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
766+
notice_to_user(PLAN_TOO_DEEP_NOTICE);
767+
}
768+
769+
let node_body = Some(self.try_to_batch_prost_body()?);
770+
let children = self
771+
.inputs()
772+
.into_iter()
773+
.map(|plan| plan.to_batch_prost_identity(identity))
774+
.try_collect()?;
775+
Ok(BatchPlanPb {
776+
children,
777+
identity: if identity {
778+
self.explain_myself_to_string()
779+
} else {
780+
"".into()
781+
},
782+
node_body,
783+
})
766784
})
767785
}
768786

0 commit comments

Comments
 (0)