Skip to content

Commit 1a11c3f

Browse files
authored
feat(frontend): support bind paramater (risingwavelabs#8543)
1 parent f92d7f6 commit 1a11c3f

16 files changed

+432
-1
lines changed

src/frontend/src/binder/bind_param.rs

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// Copyright 2023 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use bytes::Bytes;
16+
use pgwire::types::Format;
17+
use risingwave_common::error::{Result, RwError};
18+
use risingwave_common::types::ScalarImpl;
19+
20+
use super::statement::RewriteExprsRecursive;
21+
use super::BoundStatement;
22+
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
23+
24+
/// Rewrites parameter expressions to literals.
25+
pub(crate) struct ParamRewriter {
26+
pub(crate) params: Vec<Bytes>,
27+
pub(crate) param_formats: Vec<Format>,
28+
pub(crate) error: Option<RwError>,
29+
}
30+
31+
impl ExprRewriter for ParamRewriter {
32+
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
33+
if self.error.is_some() {
34+
return expr;
35+
}
36+
match expr {
37+
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
38+
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
39+
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
40+
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
41+
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
42+
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
43+
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
44+
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
45+
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
46+
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
47+
}
48+
}
49+
50+
fn rewrite_parameter(&mut self, parameter: crate::expr::Parameter) -> ExprImpl {
51+
let data_type = parameter.return_type();
52+
53+
// original parameter.index is 1-based.
54+
let parameter_index = (parameter.index - 1) as usize;
55+
56+
let format = self.param_formats[parameter_index];
57+
let scalar = {
58+
let res = match format {
59+
Format::Text => {
60+
let value = self.params[parameter_index].clone();
61+
ScalarImpl::from_text(&value, &data_type)
62+
}
63+
Format::Binary => {
64+
let value = self.params[parameter_index].clone();
65+
ScalarImpl::from_binary(&value, &data_type)
66+
}
67+
};
68+
69+
match res {
70+
Ok(datum) => datum,
71+
Err(e) => {
72+
self.error = Some(e);
73+
return parameter.into();
74+
}
75+
}
76+
};
77+
Literal::new(Some(scalar), data_type).into()
78+
}
79+
}
80+
81+
impl BoundStatement {
82+
pub fn bind_parameter(
83+
mut self,
84+
params: Vec<Bytes>,
85+
param_formats: Vec<Format>,
86+
) -> Result<BoundStatement> {
87+
let mut rewriter = ParamRewriter {
88+
params,
89+
param_formats,
90+
error: None,
91+
};
92+
93+
self.rewrite_exprs_recursive(&mut rewriter);
94+
95+
if let Some(err) = rewriter.error {
96+
return Err(err);
97+
}
98+
99+
Ok(self)
100+
}
101+
}
102+
103+
#[cfg(test)]
104+
mod test {
105+
use bytes::Bytes;
106+
use pgwire::types::Format;
107+
use risingwave_common::types::DataType;
108+
use risingwave_sqlparser::test_utils::parse_sql_statements;
109+
110+
use crate::binder::test_utils::{mock_binder, mock_binder_with_param_types};
111+
use crate::binder::BoundStatement;
112+
113+
fn create_expect_bound(sql: &str) -> BoundStatement {
114+
let mut binder = mock_binder();
115+
let stmt = parse_sql_statements(sql).unwrap().remove(0);
116+
binder.bind(stmt).unwrap()
117+
}
118+
119+
fn create_actual_bound(
120+
sql: &str,
121+
param_types: Vec<DataType>,
122+
params: Vec<Bytes>,
123+
param_formats: Vec<Format>,
124+
) -> BoundStatement {
125+
let mut binder = mock_binder_with_param_types(param_types);
126+
let stmt = parse_sql_statements(sql).unwrap().remove(0);
127+
let bound = binder.bind(stmt).unwrap();
128+
bound.bind_parameter(params, param_formats).unwrap()
129+
}
130+
131+
fn expect_actual_eq(expect: BoundStatement, actual: BoundStatement) {
132+
// Use debug format to compare. May modify in future.
133+
assert!(format!("{:?}", expect) == format!("{:?}", actual));
134+
}
135+
136+
#[tokio::test]
137+
async fn basic_select() {
138+
expect_actual_eq(
139+
create_expect_bound("select 1::int4"),
140+
create_actual_bound(
141+
"select $1::int4",
142+
vec![],
143+
vec!["1".into()],
144+
vec![Format::Text],
145+
),
146+
);
147+
}
148+
149+
#[tokio::test]
150+
async fn basic_value() {
151+
expect_actual_eq(
152+
create_expect_bound("values(1::int4)"),
153+
create_actual_bound(
154+
"values($1::int4)",
155+
vec![],
156+
vec!["1".into()],
157+
vec![Format::Text],
158+
),
159+
);
160+
}
161+
162+
#[tokio::test]
163+
async fn default_type() {
164+
expect_actual_eq(
165+
create_expect_bound("select '1'"),
166+
create_actual_bound("select $1", vec![], vec!["1".into()], vec![Format::Text]),
167+
);
168+
}
169+
170+
#[tokio::test]
171+
async fn cast_after_specific() {
172+
expect_actual_eq(
173+
create_expect_bound("select 1::varchar"),
174+
create_actual_bound(
175+
"select $1::varchar",
176+
vec![DataType::Int32],
177+
vec!["1".into()],
178+
vec![Format::Text],
179+
),
180+
);
181+
}
182+
183+
#[tokio::test]
184+
async fn infer_case() {
185+
expect_actual_eq(
186+
create_expect_bound("select 1,1::INT4"),
187+
create_actual_bound(
188+
"select $1,$1::INT4",
189+
vec![],
190+
vec!["1".into()],
191+
vec![Format::Text],
192+
),
193+
);
194+
}
195+
}

src/frontend/src/binder/delete.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use risingwave_common::catalog::{Schema, TableVersionId};
1616
use risingwave_common::error::Result;
1717
use risingwave_sqlparser::ast::{Expr, ObjectName, SelectItem};
1818

19+
use super::statement::RewriteExprsRecursive;
1920
use super::{Binder, BoundBaseTable};
2021
use crate::catalog::TableId;
2122
use crate::expr::ExprImpl;
@@ -48,6 +49,19 @@ pub struct BoundDelete {
4849
pub returning_schema: Option<Schema>,
4950
}
5051

52+
impl RewriteExprsRecursive for BoundDelete {
53+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
54+
self.selection =
55+
std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
56+
57+
let new_returning_list = std::mem::take(&mut self.returning_list)
58+
.into_iter()
59+
.map(|expr| rewriter.rewrite_expr(expr))
60+
.collect::<Vec<_>>();
61+
self.returning_list = new_returning_list;
62+
}
63+
}
64+
5165
impl Binder {
5266
pub(super) fn bind_delete(
5367
&mut self,

src/frontend/src/binder/insert.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use risingwave_common::types::DataType;
2121
use risingwave_common::util::iter_util::ZipEqFast;
2222
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr};
2323

24+
use super::statement::RewriteExprsRecursive;
2425
use super::{BoundQuery, BoundSetExpr};
2526
use crate::binder::Binder;
2627
use crate::catalog::TableId;
@@ -66,6 +67,24 @@ pub struct BoundInsert {
6667
pub returning_schema: Option<Schema>,
6768
}
6869

70+
impl RewriteExprsRecursive for BoundInsert {
71+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
72+
self.source.rewrite_exprs_recursive(rewriter);
73+
74+
let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
75+
.into_iter()
76+
.map(|expr| rewriter.rewrite_expr(expr))
77+
.collect::<Vec<_>>();
78+
self.cast_exprs = new_cast_exprs;
79+
80+
let new_returning_list = std::mem::take(&mut self.returning_list)
81+
.into_iter()
82+
.map(|expr| rewriter.rewrite_expr(expr))
83+
.collect::<Vec<_>>();
84+
self.returning_list = new_returning_list;
85+
}
86+
}
87+
6988
impl Binder {
7089
pub(super) fn bind_insert(
7190
&mut self,

src/frontend/src/binder/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use risingwave_common::util::iter_util::ZipEqDebug;
2323
use risingwave_sqlparser::ast::Statement;
2424

2525
mod bind_context;
26+
mod bind_param;
2627
mod create;
2728
mod delete;
2829
mod expr;
@@ -209,6 +210,10 @@ impl Binder {
209210
Self::new_inner(session, false, vec![])
210211
}
211212

213+
pub fn new_with_param_types(session: &SessionImpl, param_types: Vec<DataType>) -> Binder {
214+
Self::new_inner(session, false, param_types)
215+
}
216+
212217
pub fn new_for_stream(session: &SessionImpl) -> Binder {
213218
Self::new_inner(session, true, vec![])
214219
}
@@ -295,13 +300,20 @@ impl Binder {
295300

296301
#[cfg(test)]
297302
pub mod test_utils {
303+
use risingwave_common::types::DataType;
304+
298305
use super::Binder;
299306
use crate::session::SessionImpl;
300307

301308
#[cfg(test)]
302309
pub fn mock_binder() -> Binder {
303310
Binder::new(&SessionImpl::mock())
304311
}
312+
313+
#[cfg(test)]
314+
pub fn mock_binder_with_param_types(param_types: Vec<DataType>) -> Binder {
315+
Binder::new_with_param_types(&SessionImpl::mock(), param_types)
316+
}
305317
}
306318

307319
/// The column name stored in [`BindContext`] for a column without an alias.

src/frontend/src/binder/query.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ use risingwave_common::types::DataType;
2121
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
2222
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};
2323

24+
use super::statement::RewriteExprsRecursive;
2425
use crate::binder::{Binder, BoundSetExpr};
25-
use crate::expr::{CorrelatedId, Depth, ExprImpl};
26+
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};
2627

2728
/// A validated sql query, including order and union.
2829
/// An example of its relationship with `BoundSetExpr` and `BoundSelect` can be found here: <https://bit.ly/3GQwgPz>
@@ -96,6 +97,18 @@ impl BoundQuery {
9697
}
9798
}
9899

100+
impl RewriteExprsRecursive for BoundQuery {
101+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
102+
let new_extra_order_exprs = std::mem::take(&mut self.extra_order_exprs)
103+
.into_iter()
104+
.map(|expr| rewriter.rewrite_expr(expr))
105+
.collect::<Vec<_>>();
106+
self.extra_order_exprs = new_extra_order_exprs;
107+
108+
self.body.rewrite_exprs_recursive(rewriter);
109+
}
110+
}
111+
99112
impl Binder {
100113
/// Bind a [`Query`].
101114
///

src/frontend/src/binder/relation/join.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use risingwave_sqlparser::ast::{
1919
};
2020

2121
use crate::binder::bind_context::BindContext;
22+
use crate::binder::statement::RewriteExprsRecursive;
2223
use crate::binder::{Binder, Relation, COLUMN_GROUP_PREFIX};
2324
use crate::expr::ExprImpl;
2425

@@ -30,6 +31,15 @@ pub struct BoundJoin {
3031
pub cond: ExprImpl,
3132
}
3233

34+
impl RewriteExprsRecursive for BoundJoin {
35+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
36+
self.left.rewrite_exprs_recursive(rewriter);
37+
self.right.rewrite_exprs_recursive(rewriter);
38+
let dummy = ExprImpl::literal_bool(true);
39+
self.cond = rewriter.rewrite_expr(std::mem::replace(&mut self.cond, dummy));
40+
}
41+
}
42+
3343
impl Binder {
3444
pub(crate) fn bind_vec_table_with_joins(
3545
&mut self,

src/frontend/src/binder/relation/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use risingwave_sqlparser::ast::{
2727

2828
use self::watermark::is_watermark_func;
2929
use super::bind_context::ColumnBinding;
30+
use super::statement::RewriteExprsRecursive;
3031
use crate::binder::{Binder, BoundSetExpr};
3132
use crate::catalog::system_catalog::pg_catalog::{
3233
PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME,
@@ -64,6 +65,26 @@ pub enum Relation {
6465
Share(Box<BoundShare>),
6566
}
6667

68+
impl RewriteExprsRecursive for Relation {
69+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
70+
match self {
71+
Relation::Subquery(inner) => inner.rewrite_exprs_recursive(rewriter),
72+
Relation::Join(inner) => inner.rewrite_exprs_recursive(rewriter),
73+
Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter),
74+
Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter),
75+
Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter),
76+
Relation::TableFunction(inner) => {
77+
let new_args = std::mem::take(&mut inner.args)
78+
.into_iter()
79+
.map(|expr| rewriter.rewrite_expr(expr))
80+
.collect();
81+
inner.args = new_args;
82+
}
83+
_ => {}
84+
}
85+
}
86+
}
87+
6788
impl Relation {
6889
pub fn contains_sys_table(&self) -> bool {
6990
match self {

0 commit comments

Comments
 (0)