Skip to content

Commit 6b76a35

Browse files
Lordwormsalamb
andauthored
consider volatile function in simply_expression (apache#13128)
* consider volatile function in simply_expression * refactor and fix bugs * fix clippy * refactor * refactor * format * fix clippy * Resolve logical conflict * simplify more --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent b7f4db4 commit 6b76a35

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,8 +862,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
862862
right,
863863
}) if has_common_conjunction(&left, &right) => {
864864
let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
865-
let (common, rhs): (Vec<_>, Vec<_>) =
866-
iter_conjunction_owned(*right).partition(|e| lhs.contains(e));
865+
let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
866+
.partition(|e| lhs.contains(e) && !e.is_volatile());
867867

868868
let new_rhs = rhs.into_iter().reduce(and);
869869
let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
@@ -1682,8 +1682,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
16821682
}
16831683

16841684
fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
1685-
let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect();
1686-
iter_conjunction(rhs).any(|e| lhs.contains(&e))
1685+
let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
1686+
iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
16871687
}
16881688

16891689
// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
@@ -3978,4 +3978,69 @@ mod tests {
39783978
unimplemented!("not needed for tests")
39793979
}
39803980
}
3981+
#[derive(Debug)]
3982+
struct VolatileUdf {
3983+
signature: Signature,
3984+
}
3985+
3986+
impl VolatileUdf {
3987+
pub fn new() -> Self {
3988+
Self {
3989+
signature: Signature::exact(vec![], Volatility::Volatile),
3990+
}
3991+
}
3992+
}
3993+
impl ScalarUDFImpl for VolatileUdf {
3994+
fn as_any(&self) -> &dyn std::any::Any {
3995+
self
3996+
}
3997+
3998+
fn name(&self) -> &str {
3999+
"VolatileUdf"
4000+
}
4001+
4002+
fn signature(&self) -> &Signature {
4003+
&self.signature
4004+
}
4005+
4006+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
4007+
Ok(DataType::Int16)
4008+
}
4009+
}
4010+
#[test]
4011+
fn test_optimize_volatile_conditions() {
4012+
let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
4013+
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
4014+
{
4015+
let expr = rand
4016+
.clone()
4017+
.eq(lit(0))
4018+
.or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4019+
4020+
assert_eq!(simplify(expr.clone()), expr);
4021+
}
4022+
4023+
{
4024+
let expr = col("column1")
4025+
.eq(lit(2))
4026+
.or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
4027+
4028+
assert_eq!(simplify(expr), col("column1").eq(lit(2)));
4029+
}
4030+
4031+
{
4032+
let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
4033+
"column1",
4034+
)
4035+
.eq(lit(2))
4036+
.and(rand.clone().eq(lit(0))));
4037+
4038+
assert_eq!(
4039+
simplify(expr),
4040+
col("column1")
4041+
.eq(lit(2))
4042+
.and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
4043+
);
4044+
}
4045+
}
39814046
}

datafusion/optimizer/src/simplify_expressions/utils.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [
6767

6868
/// returns true if `needle` is found in a chain of search_op
6969
/// expressions. Such as: (A AND B) AND C
70-
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
70+
fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
7171
match expr {
7272
Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
73-
expr_contains(left, needle, search_op)
74-
|| expr_contains(right, needle, search_op)
73+
expr_contains_inner(left, needle, search_op)
74+
|| expr_contains_inner(right, needle, search_op)
7575
}
7676
_ => expr == needle,
7777
}
7878
}
7979

80+
/// check volatile calls and return if expr contains needle
81+
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
82+
expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
83+
}
84+
8085
/// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor
8186
/// expressions. Such as: A ^ (A ^ (B ^ A))
8287
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
@@ -206,7 +211,7 @@ pub fn is_false(expr: &Expr) -> bool {
206211

207212
/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
208213
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
209-
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()))
214+
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile())
210215
}
211216

212217
/// returns true if `not_expr` is !`expr` (not)

0 commit comments

Comments
 (0)