Skip to content

feat(frontend): support infer param in binder #8453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use risingwave_sqlparser::ast::{
};

use crate::binder::Binder;
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, SubqueryKind};
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Parameter, SubqueryKind};

mod binary_op;
mod column;
Expand Down Expand Up @@ -123,6 +123,7 @@ impl Binder {
start,
count,
} => self.bind_overlay(*expr, *new_substring, *start, count),
Expr::Parameter { index } => self.bind_parameter(index),
_ => Err(ErrorCode::NotImplemented(
format!("unsupported expression {:?}", expr),
112.into(),
Expand Down Expand Up @@ -297,6 +298,10 @@ impl Binder {
FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
}

/// Bind a parameter placeholder with the given `index` into a [`Parameter`] expression.
///
/// The created [`Parameter`] is handed a clone of the binder's `param_types`.
fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
Ok(Parameter::new(index, self.param_types.clone()).into())
}

/// Bind `expr (not) between low and high`
pub(super) fn bind_between(
&mut self,
Expand Down
100 changes: 96 additions & 4 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};

use itertools::Itertools;
use risingwave_common::error::Result;
use risingwave_common::session_config::SearchPath;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;

mod bind_context;
Expand Down Expand Up @@ -90,10 +93,94 @@ pub struct Binder {

/// `ShareId`s identifying shared views.
shared_views: HashMap<ViewId, ShareId>,

param_types: ParameterTypes,
}

/// `ParameterTypes` records the types of the query parameters during binding. It works
/// according to the following rules:
/// 1. At the beginning, it contains the parameter types specified by the user.
/// 2. When the binder encounters a parameter, it records it as unknown (by calling
///    `record_new_param`) if it does not exist in `ParameterTypes` yet.
/// 3. When the binder encounters a cast on a parameter whose type is unknown, the cast function
///    records the target type as the inferred type of that parameter (by calling
///    `record_infer_type`). If the parameter has already been inferred, the cast function acts
///    as a normal cast.
/// 4. After binding has finished:
///    (a) A parameter that is not in `ParameterTypes` was neither specified by the user nor
///    used in the query. `export` returns an error if such a parameter leaves a gap in the
///    recorded parameter indices. This rule is compatible with PostgreSQL.
///    (b) A parameter whose entry is `None` has an unknown type: the user did not specify it
///    and it could not be inferred from the query. It is finally treated as VARCHAR. This
///    rule is compatible with PostgreSQL.
///    (c) A parameter whose entry is `Some` has a known type.
///
/// The map is stored behind an `Arc<RwLock<_>>`, so every clone of a `ParameterTypes` refers to
/// the same underlying table.
#[derive(Clone, Debug)]
pub struct ParameterTypes(Arc<RwLock<HashMap<u64, Option<DataType>>>>);
Comment on lines +116 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned before, let's add some doc comments on:

  • Why it exists (to collect inferred types when first new_cast, which does not have extra input / output)
  • What does it mean when (a) not in the HashMap, (b) contains a None, (c) contains Some(T)


impl ParameterTypes {
    /// Creates the parameter type table from the types explicitly specified by the user.
    ///
    /// Parameter indices are 1-based: the first element of `specified_param_types` becomes the
    /// known type of parameter `1`, the second of parameter `2`, and so on.
    pub fn new(specified_param_types: Vec<DataType>) -> Self {
        let map = specified_param_types
            .into_iter()
            .zip(1_u64..)
            .map(|(data_type, index)| (index, Some(data_type)))
            .collect::<HashMap<u64, Option<DataType>>>();
        Self(Arc::new(RwLock::new(map)))
    }

    /// Returns `true` if the parameter `index` has a known (specified or inferred) type.
    ///
    /// # Panics
    ///
    /// Panics if `index` has not been recorded yet, or if the lock is poisoned.
    pub fn has_infer(&self, index: u64) -> bool {
        self.0
            .read()
            .unwrap()
            .get(&index)
            .expect("parameter must be recorded before its type is queried")
            .is_some()
    }

    /// Returns the type currently recorded for the parameter `index`, or `None` if its type is
    /// still unknown.
    ///
    /// # Panics
    ///
    /// Panics if `index` has not been recorded yet, or if the lock is poisoned.
    pub fn read_type(&self, index: u64) -> Option<DataType> {
        self.0
            .read()
            .unwrap()
            .get(&index)
            .expect("parameter must be recorded before its type is read")
            .clone()
    }

    /// Records the parameter `index` with an unknown type unless it is already present.
    ///
    /// An existing entry (known or unknown) is left untouched.
    pub fn record_new_param(&mut self, index: u64) {
        self.0.write().unwrap().entry(index).or_insert(None);
    }

    /// Records `data_type` as the inferred type of the parameter `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` has not been recorded yet, if the parameter already has a known type,
    /// or if the lock is poisoned.
    pub fn record_infer_type(&mut self, index: u64, data_type: DataType) {
        assert!(
            !self.has_infer(index),
            "The parameter has been inferred, should not be inferred again."
        );
        let mut types = self.0.write().unwrap();
        *types
            .get_mut(&index)
            .expect("parameter must be recorded before its type is inferred") = Some(data_type);
    }

    /// Exports the parameter types ordered by parameter index.
    ///
    /// Parameters whose type is still unknown are exported as `DataType::Varchar`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::InvalidInputSyntax` if the recorded indices are not exactly
    /// `1..=n`, i.e. some parameter in between was neither specified nor recorded.
    pub fn export(&self) -> Result<Vec<DataType>> {
        // Copy the entries out so the read lock is only held for this statement.
        let types = self
            .0
            .read()
            .unwrap()
            .iter()
            .map(|(index, data_type)| (*index, data_type.clone()))
            .sorted_by_key(|(index, _)| *index)
            .collect::<Vec<_>>();

        // The sorted indices must be exactly `1..=n`. The first mismatch identifies a parameter
        // that is missing from the table, so its type cannot be determined.
        for ((index, _), expect_index) in types.iter().zip_eq_debug(1_u64..=types.len() as u64) {
            if *index != expect_index {
                return Err(ErrorCode::InvalidInputSyntax(format!(
                    "Cannot infer the type of the parameter {}.",
                    expect_index
                ))
                .into());
            }
        }

        // Parameters that are recorded but still unknown fall back to VARCHAR.
        Ok(types
            .into_iter()
            .map(|(_, data_type)| data_type.unwrap_or(DataType::Varchar))
            .collect::<Vec<_>>())
    }
}

impl Binder {
fn new_inner(session: &SessionImpl, in_create_mv: bool) -> Binder {
fn new_inner(session: &SessionImpl, in_create_mv: bool, param_types: Vec<DataType>) -> Binder {
let now_ms = session
.env()
.hummock_snapshot_manager()
Expand All @@ -114,22 +201,27 @@ impl Binder {
search_path: session.config().get_search_path(),
in_create_mv,
shared_views: HashMap::new(),
param_types: ParameterTypes::new(param_types),
}
}

pub fn new(session: &SessionImpl) -> Binder {
Self::new_inner(session, false)
Self::new_inner(session, false, vec![])
}

pub fn new_for_stream(session: &SessionImpl) -> Binder {
Self::new_inner(session, true)
Self::new_inner(session, true, vec![])
}

/// Bind a [`Statement`].
pub fn bind(&mut self, stmt: Statement) -> Result<BoundStatement> {
self.bind_statement(stmt)
}

/// Export the parameter types recorded in this binder's `param_types` by delegating to
/// `ParameterTypes::export`.
pub fn export_param_types(&self) -> Result<Vec<DataType>> {
self.param_types.export()
}

fn push_context(&mut self) {
let new_context = std::mem::take(&mut self.context);
self.context.cte_to_relation = new_context.cte_to_relation.clone();
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/expr/expr_mutator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand All @@ -30,6 +30,7 @@ pub trait ExprMutator {
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
}
}
fn visit_function_call(&mut self, func_call: &mut FunctionCall) {
Expand All @@ -47,6 +48,7 @@ pub trait ExprMutator {
agg_call.filter_mut().visit_expr_mut(self);
}
/// Visit a [`Literal`]; the default implementation does nothing.
fn visit_literal(&mut self, _: &mut Literal) {}
/// Visit a [`Parameter`]; the default implementation does nothing.
fn visit_parameter(&mut self, _: &mut Parameter) {}
/// Visit an [`InputRef`]; the default implementation does nothing.
fn visit_input_ref(&mut self, _: &mut InputRef) {}
/// Visit a [`Subquery`]; the default implementation does nothing.
fn visit_subquery(&mut self, _: &mut Subquery) {}
/// Visit a [`CorrelatedInputRef`]; the default implementation does nothing.
fn visit_correlated_input_ref(&mut self, _: &mut CorrelatedInputRef) {}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand All @@ -32,6 +32,7 @@ pub trait ExprRewriter {
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
}
}
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
Expand All @@ -54,6 +55,9 @@ pub trait ExprRewriter {
.unwrap()
.into()
}
/// Rewrite a [`Parameter`]; the default implementation returns it unchanged, converted into an
/// [`ExprImpl`].
fn rewrite_parameter(&mut self, parameter: Parameter) -> ExprImpl {
parameter.into()
}
/// Rewrite a [`Literal`]; the default implementation returns it unchanged, converted into an
/// [`ExprImpl`].
fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
literal.into()
}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Subquery,
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, InputRef, Literal, Parameter, Subquery,
TableFunction, UserDefinedFunction, WindowFunction,
};

Expand Down Expand Up @@ -42,6 +42,7 @@ pub trait ExprVisitor<R: Default> {
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
}
}
fn visit_function_call(&mut self, func_call: &FunctionCall) -> R {
Expand All @@ -63,6 +64,9 @@ pub trait ExprVisitor<R: Default> {
r = Self::merge(r, agg_call.filter().visit_expr(self));
r
}
/// Visit a [`Parameter`]; the default implementation ignores it and returns `R::default()`.
fn visit_parameter(&mut self, _: &Parameter) -> R {
R::default()
}
/// Visit a [`Literal`]; the default implementation ignores it and returns `R::default()`.
fn visit_literal(&mut self, _: &Literal) -> R {
R::default()
}
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,14 @@ impl FunctionCall {

/// Create a cast expr over `child` to `target` type in `allows` context.
pub fn new_cast(
child: ExprImpl,
mut child: ExprImpl,
target: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
if let ExprImpl::Parameter(expr) = &mut child && !expr.has_infer() {
expr.cast_infer_type(target);
return Ok(child);
}
if is_row_function(&child) {
// Row function will have empty fields in Datatype::Struct at this point. Therefore,
// we will need to take some special care to generate the cast types. For normal struct
Expand Down
11 changes: 10 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod correlated_input_ref;
mod function_call;
mod input_ref;
mod literal;
mod parameter;
mod subquery;
mod table_function;
mod user_defined_function;
Expand All @@ -50,6 +51,7 @@ pub use expr_visitor::ExprVisitor;
pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
pub use literal::Literal;
pub use parameter::Parameter;
pub use risingwave_pb::expr::expr_node::Type as ExprType;
pub use session_timezone::SessionTimezone;
pub use subquery::{Subquery, SubqueryKind};
Expand Down Expand Up @@ -96,7 +98,8 @@ impl_expr_impl!(
Subquery,
TableFunction,
WindowFunction,
UserDefinedFunction
UserDefinedFunction,
Parameter
);

impl ExprImpl {
Expand Down Expand Up @@ -174,6 +177,7 @@ impl ExprImpl {
/// Check whether self is a literal NULL or literal string (a literal whose return type is
/// `DataType::Varchar`), or a [`Parameter`] for which `has_infer()` returns `false`.
pub fn is_unknown(&self) -> bool {
matches!(self, ExprImpl::Literal(literal) if literal.return_type() == DataType::Varchar)
|| matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
}

/// Shorthand to create cast expr to `target` type in implicit context.
Expand Down Expand Up @@ -761,6 +765,7 @@ impl Expr for ExprImpl {
ExprImpl::TableFunction(expr) => expr.return_type(),
ExprImpl::WindowFunction(expr) => expr.return_type(),
ExprImpl::UserDefinedFunction(expr) => expr.return_type(),
ExprImpl::Parameter(expr) => expr.return_type(),
}
}

Expand All @@ -779,6 +784,7 @@ impl Expr for ExprImpl {
unreachable!("Window function should not be converted to ExprNode")
}
ExprImpl::UserDefinedFunction(e) => e.to_expr_proto(),
ExprImpl::Parameter(e) => e.to_expr_proto(),
}
}
}
Expand Down Expand Up @@ -813,6 +819,7 @@ impl std::fmt::Debug for ExprImpl {
Self::UserDefinedFunction(arg0) => {
f.debug_tuple("UserDefinedFunction").field(arg0).finish()
}
Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
};
}
match self {
Expand All @@ -825,6 +832,7 @@ impl std::fmt::Debug for ExprImpl {
Self::TableFunction(x) => write!(f, "{:?}", x),
Self::WindowFunction(x) => write!(f, "{:?}", x),
Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
Self::Parameter(x) => write!(f, "{:?}", x),
}
}
}
Expand Down Expand Up @@ -867,6 +875,7 @@ impl std::fmt::Debug for ExprDisplay<'_> {
write!(f, "{:?}", x)
}
ExprImpl::UserDefinedFunction(x) => write!(f, "{:?}", x),
ExprImpl::Parameter(x) => write!(f, "{:?}", x),
}
}
}
Expand Down
Loading