Skip to content

Commit b6e2402

Browse files
committed
* add Parameter expr in binder
* infer parameter type in bind * add interface to bind param for BoundStatement * add related test
1 parent 953e4b2 commit b6e2402

24 files changed

+535
-11
lines changed

src/frontend/src/binder/bind_param.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright 2023 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use bytes::Bytes;
16+
use pgwire::types::Format;
17+
use risingwave_common::error::{Result, RwError};
18+
use risingwave_common::types::ScalarImpl;
19+
20+
use super::statement::RewriteExprsRecursive;
21+
use super::BoundStatement;
22+
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
23+
24+
/// Rewrites parameter expressions to literals.
25+
pub(crate) struct ParamRewriter {
26+
pub(crate) params: Vec<Bytes>,
27+
pub(crate) param_formats: Vec<Format>,
28+
pub(crate) error: Option<RwError>,
29+
}
30+
31+
impl ExprRewriter for ParamRewriter {
32+
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
33+
if self.error.is_some() {
34+
return expr;
35+
}
36+
match expr {
37+
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
38+
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
39+
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
40+
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
41+
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
42+
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
43+
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
44+
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
45+
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
46+
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
47+
}
48+
}
49+
50+
fn rewrite_parameter(&mut self, parameter: crate::expr::Parameter) -> ExprImpl {
51+
let data_type = parameter.return_type();
52+
53+
// original parameter.index is 1-based.
54+
let parameter_index = (parameter.index - 1) as usize;
55+
56+
let format = self.param_formats[parameter_index];
57+
let scalar = {
58+
let res = match format {
59+
Format::Text => {
60+
let value = self.params[parameter_index].clone();
61+
ScalarImpl::from_text(&value, &data_type)
62+
}
63+
Format::Binary => {
64+
let value = self.params[parameter_index].clone();
65+
ScalarImpl::from_binary(&value, &data_type)
66+
}
67+
};
68+
69+
match res {
70+
Ok(datum) => datum,
71+
Err(e) => {
72+
self.error = Some(e);
73+
return parameter.into();
74+
}
75+
}
76+
};
77+
Literal::new(Some(scalar), data_type).into()
78+
}
79+
}
80+
81+
impl BoundStatement {
82+
pub fn bind_parameter(
83+
mut self,
84+
params: Vec<Bytes>,
85+
param_formats: Vec<Format>,
86+
) -> Result<BoundStatement> {
87+
let mut rewriter = ParamRewriter {
88+
params,
89+
param_formats,
90+
error: None,
91+
};
92+
93+
self.rewrite_exprs_recursive(&mut rewriter);
94+
95+
if let Some(err) = rewriter.error {
96+
return Err(err);
97+
}
98+
99+
Ok(self)
100+
}
101+
}
102+
103+
#[cfg(test)]
104+
mod test {
105+
use bytes::Bytes;
106+
use pgwire::types::Format;
107+
use risingwave_sqlparser::test_utils::parse_sql_statements;
108+
109+
use crate::binder::test_utils::mock_binder;
110+
use crate::binder::BoundStatement;
111+
112+
fn create_expect_bound(sql: &str) -> BoundStatement {
113+
let mut binder = mock_binder();
114+
let stmt = parse_sql_statements(sql).unwrap().remove(0);
115+
binder.bind(stmt).unwrap()
116+
}
117+
118+
fn create_actual_bound(
119+
sql: &str,
120+
params: Vec<Bytes>,
121+
param_formats: Vec<Format>,
122+
) -> BoundStatement {
123+
let mut binder = mock_binder();
124+
let stmt = parse_sql_statements(sql).unwrap().remove(0);
125+
let bound = binder.bind(stmt).unwrap();
126+
bound.bind_parameter(params, param_formats).unwrap()
127+
}
128+
129+
fn expect_actual_eq(expect: BoundStatement, actual: BoundStatement) {
130+
println!("expect: {:?}", expect);
131+
println!("actual: {:?}", actual);
132+
// Use debug format to compare. May modify in future.
133+
assert!(format!("{:?}", expect) == format!("{:?}", actual));
134+
}
135+
136+
#[tokio::test]
137+
async fn bind_basic_select() {
138+
expect_actual_eq(
139+
create_expect_bound("select 1::int4"),
140+
create_actual_bound("select $1::int4", vec!["1".into()], vec![Format::Text]),
141+
);
142+
}
143+
144+
#[tokio::test]
145+
async fn bind_basic_value() {
146+
expect_actual_eq(
147+
create_expect_bound("values(1::int4)"),
148+
create_actual_bound("values($1::int4)", vec!["1".into()], vec![Format::Text]),
149+
);
150+
}
151+
}

src/frontend/src/binder/delete.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use risingwave_common::catalog::{Schema, TableVersionId};
1616
use risingwave_common::error::Result;
1717
use risingwave_sqlparser::ast::{Expr, ObjectName, SelectItem};
1818

19+
use super::statement::RewriteExprsRecursive;
1920
use super::{Binder, BoundBaseTable};
2021
use crate::catalog::TableId;
2122
use crate::expr::ExprImpl;
@@ -48,6 +49,19 @@ pub struct BoundDelete {
4849
pub returning_schema: Option<Schema>,
4950
}
5051

52+
impl RewriteExprsRecursive for BoundDelete {
53+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
54+
self.selection =
55+
std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));
56+
57+
let new_returning_list = std::mem::take(&mut self.returning_list)
58+
.into_iter()
59+
.map(|expr| rewriter.rewrite_expr(expr))
60+
.collect::<Vec<_>>();
61+
self.returning_list = new_returning_list;
62+
}
63+
}
64+
5165
impl Binder {
5266
pub(super) fn bind_delete(
5367
&mut self,

src/frontend/src/binder/expr/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use risingwave_sqlparser::ast::{
2323
};
2424

2525
use crate::binder::Binder;
26-
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, SubqueryKind};
26+
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Parameter, SubqueryKind};
2727

2828
mod binary_op;
2929
mod column;
@@ -123,6 +123,7 @@ impl Binder {
123123
start,
124124
count,
125125
} => self.bind_overlay(*expr, *new_substring, *start, count),
126+
Expr::Parameter { index } => self.bind_parameter(index),
126127
_ => Err(ErrorCode::NotImplemented(
127128
format!("unsupported expression {:?}", expr),
128129
112.into(),
@@ -297,6 +298,10 @@ impl Binder {
297298
FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
298299
}
299300

301+
fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
302+
Ok(Parameter::new(index, self.param_types.clone()).into())
303+
}
304+
300305
/// Bind `expr (not) between low and high`
301306
pub(super) fn bind_between(
302307
&mut self,

src/frontend/src/binder/insert.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use risingwave_common::types::DataType;
2121
use risingwave_common::util::iter_util::ZipEqFast;
2222
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr};
2323

24+
use super::statement::RewriteExprsRecursive;
2425
use super::{BoundQuery, BoundSetExpr};
2526
use crate::binder::Binder;
2627
use crate::catalog::TableId;
@@ -66,6 +67,24 @@ pub struct BoundInsert {
6667
pub returning_schema: Option<Schema>,
6768
}
6869

70+
impl RewriteExprsRecursive for BoundInsert {
71+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
72+
self.source.rewrite_exprs_recursive(rewriter);
73+
74+
let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
75+
.into_iter()
76+
.map(|expr| rewriter.rewrite_expr(expr))
77+
.collect::<Vec<_>>();
78+
self.cast_exprs = new_cast_exprs;
79+
80+
let new_returning_list = std::mem::take(&mut self.returning_list)
81+
.into_iter()
82+
.map(|expr| rewriter.rewrite_expr(expr))
83+
.collect::<Vec<_>>();
84+
self.returning_list = new_returning_list;
85+
}
86+
}
87+
6988
impl Binder {
7089
pub(super) fn bind_insert(
7190
&mut self,

src/frontend/src/binder/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
// limitations under the License.
1414

1515
use std::collections::HashMap;
16-
use std::sync::Arc;
16+
use std::sync::{Arc, RwLock};
1717

1818
use risingwave_common::error::Result;
1919
use risingwave_common::session_config::SearchPath;
20+
use risingwave_common::types::DataType;
2021
use risingwave_sqlparser::ast::Statement;
2122

2223
mod bind_context;
24+
mod bind_param;
2325
mod create;
2426
mod delete;
2527
mod expr;
@@ -90,6 +92,8 @@ pub struct Binder {
9092

9193
/// `ShareId`s identifying shared views.
9294
shared_views: HashMap<ViewId, ShareId>,
95+
96+
param_types: Arc<RwLock<HashMap<u64, Option<DataType>>>>,
9397
}
9498

9599
impl Binder {
@@ -114,6 +118,7 @@ impl Binder {
114118
search_path: session.config().get_search_path(),
115119
in_create_mv,
116120
shared_views: HashMap::new(),
121+
param_types: Arc::new(RwLock::new(HashMap::new())),
117122
}
118123
}
119124

@@ -130,6 +135,10 @@ impl Binder {
130135
self.bind_statement(stmt)
131136
}
132137

138+
pub fn export_param_types(&self) -> HashMap<u64, Option<DataType>> {
139+
self.param_types.read().unwrap().clone()
140+
}
141+
133142
fn push_context(&mut self) {
134143
let new_context = std::mem::take(&mut self.context);
135144
self.context.cte_to_relation = new_context.cte_to_relation.clone();

src/frontend/src/binder/query.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ use risingwave_common::error::{ErrorCode, Result};
2020
use risingwave_common::types::DataType;
2121
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};
2222

23+
use super::statement::RewriteExprsRecursive;
2324
use crate::binder::{Binder, BoundSetExpr};
24-
use crate::expr::{CorrelatedId, Depth, ExprImpl};
25+
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};
2526
use crate::optimizer::property::{Direction, FieldOrder};
2627

2728
/// A validated sql query, including order and union.
@@ -96,6 +97,18 @@ impl BoundQuery {
9697
}
9798
}
9899

100+
impl RewriteExprsRecursive for BoundQuery {
101+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
102+
let new_extra_order_exprs = std::mem::take(&mut self.extra_order_exprs)
103+
.into_iter()
104+
.map(|expr| rewriter.rewrite_expr(expr))
105+
.collect::<Vec<_>>();
106+
self.extra_order_exprs = new_extra_order_exprs;
107+
108+
self.body.rewrite_exprs_recursive(rewriter);
109+
}
110+
}
111+
99112
impl Binder {
100113
/// Bind a [`Query`].
101114
///

src/frontend/src/binder/relation/join.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use risingwave_sqlparser::ast::{
1919
};
2020

2121
use crate::binder::bind_context::BindContext;
22+
use crate::binder::statement::RewriteExprsRecursive;
2223
use crate::binder::{Binder, Relation, COLUMN_GROUP_PREFIX};
2324
use crate::expr::ExprImpl;
2425

@@ -30,6 +31,14 @@ pub struct BoundJoin {
3031
pub cond: ExprImpl,
3132
}
3233

34+
impl RewriteExprsRecursive for BoundJoin {
35+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
36+
self.left.rewrite_exprs_recursive(rewriter);
37+
self.right.rewrite_exprs_recursive(rewriter);
38+
self.cond = rewriter.rewrite_expr(std::mem::take(&mut self.cond));
39+
}
40+
}
41+
3342
impl Binder {
3443
pub(crate) fn bind_vec_table_with_joins(
3544
&mut self,

src/frontend/src/binder/relation/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use risingwave_sqlparser::ast::{
2727

2828
use self::watermark::is_watermark_func;
2929
use super::bind_context::ColumnBinding;
30+
use super::statement::RewriteExprsRecursive;
3031
use crate::binder::{Binder, BoundSetExpr};
3132
use crate::catalog::system_catalog::pg_catalog::{
3233
PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME,
@@ -64,6 +65,26 @@ pub enum Relation {
6465
Share(Box<BoundShare>),
6566
}
6667

68+
impl RewriteExprsRecursive for Relation {
69+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
70+
match self {
71+
Relation::Subquery(inner) => inner.rewrite_exprs_recursive(rewriter),
72+
Relation::Join(inner) => inner.rewrite_exprs_recursive(rewriter),
73+
Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter),
74+
Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter),
75+
Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter),
76+
Relation::TableFunction(inner) => {
77+
let new_args = std::mem::take(&mut inner.args)
78+
.into_iter()
79+
.map(|expr| rewriter.rewrite_expr(expr))
80+
.collect();
81+
inner.args = new_args;
82+
}
83+
_ => {}
84+
}
85+
}
86+
}
87+
6788
impl Relation {
6889
pub fn contains_sys_table(&self) -> bool {
6990
match self {

src/frontend/src/binder/relation/share.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use crate::binder::statement::RewriteExprsRecursive;
1516
use crate::binder::{Relation, ShareId};
1617

1718
/// Share a relation during binding and planning.
@@ -21,3 +22,9 @@ pub struct BoundShare {
2122
pub(crate) share_id: ShareId,
2223
pub(crate) input: Relation,
2324
}
25+
26+
impl RewriteExprsRecursive for BoundShare {
27+
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
28+
self.input.rewrite_exprs_recursive(rewriter);
29+
}
30+
}

0 commit comments

Comments
 (0)