Skip to content

Commit 305c864

Browse files
authored
feat(frontend): support infer param in binder (risingwavelabs#8453)
1 parent 9d5ff78 commit 305c864

File tree

10 files changed

+212
-11
lines changed

10 files changed

+212
-11
lines changed

src/frontend/src/binder/expr/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use risingwave_sqlparser::ast::{
2323
};
2424

2525
use crate::binder::Binder;
26-
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, SubqueryKind};
26+
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Parameter, SubqueryKind};
2727

2828
mod binary_op;
2929
mod column;
@@ -123,6 +123,7 @@ impl Binder {
123123
start,
124124
count,
125125
} => self.bind_overlay(*expr, *new_substring, *start, count),
126+
Expr::Parameter { index } => self.bind_parameter(index),
126127
_ => Err(ErrorCode::NotImplemented(
127128
format!("unsupported expression {:?}", expr),
128129
112.into(),
@@ -297,6 +298,10 @@ impl Binder {
297298
FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
298299
}
299300

301+
fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
302+
Ok(Parameter::new(index, self.param_types.clone()).into())
303+
}
304+
300305
/// Bind `expr (not) between low and high`
301306
pub(super) fn bind_between(
302307
&mut self,

src/frontend/src/binder/mod.rs

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
// limitations under the License.
1414

1515
use std::collections::HashMap;
16-
use std::sync::Arc;
16+
use std::sync::{Arc, RwLock};
1717

18+
use itertools::Itertools;
1819
use risingwave_common::error::Result;
1920
use risingwave_common::session_config::SearchPath;
21+
use risingwave_common::types::DataType;
22+
use risingwave_common::util::iter_util::ZipEqDebug;
2023
use risingwave_sqlparser::ast::Statement;
2124

2225
mod bind_context;
@@ -90,10 +93,94 @@ pub struct Binder {
9093

9194
/// `ShareId`s identifying shared views.
9295
shared_views: HashMap<ViewId, ShareId>,
96+
97+
param_types: ParameterTypes,
98+
}
99+
100+
/// `ParameterTypes` is used to record the types of the parameters during binding. It works
101+
/// following the rules:
102+
/// 1. At the beginning, it contains the user specified parameters type.
103+
/// 2. When the binder encounters a parameter, it will record it as unknown(call `record_new_param`)
104+
/// if it didn't exist in `ParameterTypes`.
105+
/// 3. When the binder encounters a cast on parameter, if it's a unknown type, the cast function
106+
/// will record the target type as infer type for that parameter(call `record_infer_type`). If the
107+
/// parameter has been inferred, the cast function will act as a normal cast.
108+
/// 4. After bind finished:
109+
/// (a) parameter not in `ParameterTypes` means that the user didn't specify it and it didn't
110+
/// occur in the query. `export` will return error if there is a kind of
111+
/// parameter. This rule is compatible with PostgreSQL
112+
/// (b) parameter is None means that it's a unknown type. The user didn't specify it
113+
/// and we can't infer it in the query. We will treat it as VARCHAR type finally. This rule is
114+
/// compatible with PostgreSQL.
115+
/// (c) parameter is Some means that it's a known type.
116+
#[derive(Clone, Debug)]
117+
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
118+
119+
impl ParameterTypes {
120+
pub fn new(specified_param_types: Vec<DataType>) -> Self {
121+
let map = specified_param_types
122+
.into_iter()
123+
.enumerate()
124+
.map(|(index, data_type)| ((index + 1) as u64, Some(data_type)))
125+
.collect::<HashMap<u64, Option<DataType>>>();
126+
Self(Arc::new(RwLock::new(map)))
127+
}
128+
129+
pub fn has_infer(&self, index: u64) -> bool {
130+
self.0.read().unwrap().get(&index).unwrap().is_some()
131+
}
132+
133+
pub fn read_type(&self, index: u64) -> Option<DataType> {
134+
self.0.read().unwrap().get(&index).unwrap().clone()
135+
}
136+
137+
pub fn record_new_param(&mut self, index: u64) {
138+
self.0.write().unwrap().entry(index).or_insert(None);
139+
}
140+
141+
pub fn record_infer_type(&mut self, index: u64, data_type: DataType) {
142+
assert!(
143+
!self.has_infer(index),
144+
"The parameter has been inferred, should not be inferred again."
145+
);
146+
self.0
147+
.write()
148+
.unwrap()
149+
.get_mut(&index)
150+
.unwrap()
151+
.replace(data_type);
152+
}
153+
154+
pub fn export(&self) -> Result<Vec<DataType>> {
155+
let types = self
156+
.0
157+
.read()
158+
.unwrap()
159+
.clone()
160+
.into_iter()
161+
.sorted_by_key(|(index, _)| *index)
162+
.collect::<Vec<_>>();
163+
164+
// Check if all the parameters have been inferred.
165+
for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
166+
if *index != expect_index {
167+
return Err(ErrorCode::InvalidInputSyntax(format!(
168+
"Cannot infer the type of the parameter {}.",
169+
expect_index
170+
))
171+
.into());
172+
}
173+
}
174+
175+
Ok(types
176+
.into_iter()
177+
.map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
178+
.collect::<Vec<_>>())
179+
}
93180
}
94181

95182
impl Binder {
96-
fn new_inner(session: &SessionImpl, in_create_mv: bool) -> Binder {
183+
fn new_inner(session: &SessionImpl, in_create_mv: bool, param_types: Vec<DataType>) -> Binder {
97184
let now_ms = session
98185
.env()
99186
.hummock_snapshot_manager()
@@ -114,22 +201,27 @@ impl Binder {
114201
search_path: session.config().get_search_path(),
115202
in_create_mv,
116203
shared_views: HashMap::new(),
204+
param_types: ParameterTypes::new(param_types),
117205
}
118206
}
119207

120208
pub fn new(session: &SessionImpl) -> Binder {
121-
Self::new_inner(session, false)
209+
Self::new_inner(session, false, vec![])
122210
}
123211

124212
pub fn new_for_stream(session: &SessionImpl) -> Binder {
125-
Self::new_inner(session, true)
213+
Self::new_inner(session, true, vec![])
126214
}
127215

128216
/// Bind a [`Statement`].
129217
pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
130218
self.bind_statement(stmt)
131219
}
132220

221+
pub fn export_param_types(&self) -> Result<Vec<DataType>> {
222+
self.param_types.export()
223+
}
224+
133225
fn push_context(&mut self) {
134226
let new_context = std::mem::take(&mut self.context);
135227
self.context.cte_to_relation = new_context.cte_to_relation.clone();

src/frontend/src/expr/expr_mutator.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
use super::{
16-
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
16+
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
1717
TableFunction, UserDefinedFunction, WindowFunction,
1818
};
1919

@@ -30,6 +30,7 @@ pub trait ExprMutator {
3030
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
3131
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
3232
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
33+
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
3334
}
3435
}
3536
fn visit_function_call(&mut self, func_call: &mut FunctionCall) {
@@ -47,6 +48,7 @@ pub trait ExprMutator {
4748
agg_call.filter_mut().visit_expr_mut(self);
4849
}
4950
fn visit_literal(&mut self, _: &mut Literal) {}
51+
fn visit_parameter(&mut self, _: &mut Parameter) {}
5052
fn visit_input_ref(&mut self, _: &mut InputRef) {}
5153
fn visit_subquery(&mut self, _: &mut Subquery) {}
5254
fn visit_correlated_input_ref(&mut self, _: &mut CorrelatedInputRef) {}

src/frontend/src/expr/expr_rewriter.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
use super::{
16-
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
16+
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
1717
TableFunction, UserDefinedFunction, WindowFunction,
1818
};
1919

@@ -32,6 +32,7 @@ pub trait ExprRewriter {
3232
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
3333
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
3434
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
35+
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
3536
}
3637
}
3738
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
@@ -54,6 +55,9 @@ pub trait ExprRewriter {
5455
.unwrap()
5556
.into()
5657
}
58+
fn rewrite_parameter(&mut self, parameter: Parameter) -> ExprImpl {
59+
parameter.into()
60+
}
5761
fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
5862
literal.into()
5963
}

src/frontend/src/expr/expr_visitor.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
use super::{
16-
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
16+
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
1717
TableFunction, UserDefinedFunction, WindowFunction,
1818
};
1919

@@ -42,6 +42,7 @@ pub trait ExprVisitor<R: Default> {
4242
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
4343
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
4444
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
45+
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
4546
}
4647
}
4748
fn visit_function_call(&mut self, func_call: &FunctionCall) -> R {
@@ -63,6 +64,9 @@ pub trait ExprVisitor<R: Default> {
6364
r = Self::merge(r, agg_call.filter().visit_expr(self));
6465
r
6566
}
67+
fn visit_parameter(&mut self, _: &Parameter) -> R {
68+
R::default()
69+
}
6670
fn visit_literal(&mut self, _: &Literal) -> R {
6771
R::default()
6872
}

src/frontend/src/expr/function_call.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,14 @@ impl FunctionCall {
111111

112112
/// Create a cast expr over `child` to `target` type in `allows` context.
113113
pub fn new_cast(
114-
child: ExprImpl,
114+
mut child: ExprImpl,
115115
target: DataType,
116116
allows: CastContext,
117117
) -> Result<ExprImpl, CastError> {
118+
if let ExprImpl::Parameter(expr) = &mut child && !expr.has_infer() {
119+
expr.cast_infer_type(target);
120+
return Ok(child);
121+
}
118122
if is_row_function(&child) {
119123
// Row function will have empty fields in Datatype::Struct at this point. Therefore,
120124
// we will need to take some special care to generate the cast types. For normal struct

src/frontend/src/expr/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ mod correlated_input_ref;
2727
mod function_call;
2828
mod input_ref;
2929
mod literal;
30+
mod parameter;
3031
mod subquery;
3132
mod table_function;
3233
mod user_defined_function;
@@ -50,6 +51,7 @@ pub use expr_visitor::ExprVisitor;
5051
pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
5152
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
5253
pub use literal::Literal;
54+
pub use parameter::Parameter;
5355
pub use risingwave_pb::expr::expr_node::Type as ExprType;
5456
pub use session_timezone::SessionTimezone;
5557
pub use subquery::{Subquery, SubqueryKind};
@@ -96,7 +98,8 @@ impl_expr_impl!(
9698
Subquery,
9799
TableFunction,
98100
WindowFunction,
99-
UserDefinedFunction
101+
UserDefinedFunction,
102+
Parameter
100103
);
101104

102105
impl ExprImpl {
@@ -174,6 +177,7 @@ impl ExprImpl {
174177
/// Check whether self is a literal NULL or literal string.
175178
pub fn is_unknown(&self) -> bool {
176179
matches!(self, ExprImpl::Literal(literal) if literal.return_type() == DataType::Varchar)
180+
|| matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
177181
}
178182

179183
/// Shorthand to create cast expr to `target` type in implicit context.
@@ -761,6 +765,7 @@ impl Expr for ExprImpl {
761765
ExprImpl::TableFunction(expr) => expr.return_type(),
762766
ExprImpl::WindowFunction(expr) => expr.return_type(),
763767
ExprImpl::UserDefinedFunction(expr) => expr.return_type(),
768+
ExprImpl::Parameter(expr) => expr.return_type(),
764769
}
765770
}
766771

@@ -779,6 +784,7 @@ impl Expr for ExprImpl {
779784
unreachable!("Window function should not be converted to ExprNode")
780785
}
781786
ExprImpl::UserDefinedFunction(e) => e.to_expr_proto(),
787+
ExprImpl::Parameter(e) => e.to_expr_proto(),
782788
}
783789
}
784790
}
@@ -813,6 +819,7 @@ impl std::fmt::Debug for ExprImpl {
813819
Self::UserDefinedFunction(arg0) => {
814820
f.debug_tuple("UserDefinedFunction").field(arg0).finish()
815821
}
822+
Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
816823
};
817824
}
818825
match self {
@@ -825,6 +832,7 @@ impl std::fmt::Debug for ExprImpl {
825832
Self::TableFunction(x) => write!(f, "{:?}", x),
826833
Self::WindowFunction(x) => write!(f, "{:?}", x),
827834
Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
835+
Self::Parameter(x) => write!(f, "{:?}", x),
828836
}
829837
}
830838
}
@@ -867,6 +875,7 @@ impl std::fmt::Debug for ExprDisplay<'_> {
867875
write!(f, "{:?}", x)
868876
}
869877
ExprImpl::UserDefinedFunction(x) => write!(f, "{:?}", x),
878+
ExprImpl::Parameter(x) => write!(f, "{:?}", x),
870879
}
871880
}
872881
}

0 commit comments

Comments
 (0)