@@ -862,8 +862,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
862
862
right,
863
863
} ) if has_common_conjunction ( & left, & right) => {
864
864
let lhs: IndexSet < Expr > = iter_conjunction_owned ( * left) . collect ( ) ;
865
- let ( common, rhs) : ( Vec < _ > , Vec < _ > ) =
866
- iter_conjunction_owned ( * right ) . partition ( |e| lhs. contains ( e) ) ;
865
+ let ( common, rhs) : ( Vec < _ > , Vec < _ > ) = iter_conjunction_owned ( * right )
866
+ . partition ( |e| lhs. contains ( e) && !e . is_volatile ( ) ) ;
867
867
868
868
let new_rhs = rhs. into_iter ( ) . reduce ( and) ;
869
869
let new_lhs = lhs. into_iter ( ) . filter ( |e| !common. contains ( e) ) . reduce ( and) ;
@@ -1682,8 +1682,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
1682
1682
}
1683
1683
1684
1684
fn has_common_conjunction ( lhs : & Expr , rhs : & Expr ) -> bool {
1685
- let lhs : HashSet < & Expr > = iter_conjunction ( lhs) . collect ( ) ;
1686
- iter_conjunction ( rhs) . any ( |e| lhs . contains ( & e) )
1685
+ let lhs_set : HashSet < & Expr > = iter_conjunction ( lhs) . collect ( ) ;
1686
+ iter_conjunction ( rhs) . any ( |e| lhs_set . contains ( & e) && !e . is_volatile ( ) )
1687
1687
}
1688
1688
1689
1689
// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
@@ -3978,4 +3978,69 @@ mod tests {
3978
3978
unimplemented ! ( "not needed for tests" )
3979
3979
}
3980
3980
}
3981
+ #[ derive( Debug ) ]
3982
+ struct VolatileUdf {
3983
+ signature : Signature ,
3984
+ }
3985
+
3986
+ impl VolatileUdf {
3987
+ pub fn new ( ) -> Self {
3988
+ Self {
3989
+ signature : Signature :: exact ( vec ! [ ] , Volatility :: Volatile ) ,
3990
+ }
3991
+ }
3992
+ }
3993
+ impl ScalarUDFImpl for VolatileUdf {
3994
+ fn as_any ( & self ) -> & dyn std:: any:: Any {
3995
+ self
3996
+ }
3997
+
3998
+ fn name ( & self ) -> & str {
3999
+ "VolatileUdf"
4000
+ }
4001
+
4002
+ fn signature ( & self ) -> & Signature {
4003
+ & self . signature
4004
+ }
4005
+
4006
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
4007
+ Ok ( DataType :: Int16 )
4008
+ }
4009
+ }
4010
+ #[ test]
4011
+ fn test_optimize_volatile_conditions ( ) {
4012
+ let fun = Arc :: new ( ScalarUDF :: new_from_impl ( VolatileUdf :: new ( ) ) ) ;
4013
+ let rand = Expr :: ScalarFunction ( ScalarFunction :: new_udf ( fun, vec ! [ ] ) ) ;
4014
+ {
4015
+ let expr = rand
4016
+ . clone ( )
4017
+ . eq ( lit ( 0 ) )
4018
+ . or ( col ( "column1" ) . eq ( lit ( 2 ) ) . and ( rand. clone ( ) . eq ( lit ( 0 ) ) ) ) ;
4019
+
4020
+ assert_eq ! ( simplify( expr. clone( ) ) , expr) ;
4021
+ }
4022
+
4023
+ {
4024
+ let expr = col ( "column1" )
4025
+ . eq ( lit ( 2 ) )
4026
+ . or ( col ( "column1" ) . eq ( lit ( 2 ) ) . and ( rand. clone ( ) . eq ( lit ( 0 ) ) ) ) ;
4027
+
4028
+ assert_eq ! ( simplify( expr) , col( "column1" ) . eq( lit( 2 ) ) ) ;
4029
+ }
4030
+
4031
+ {
4032
+ let expr = ( col ( "column1" ) . eq ( lit ( 2 ) ) . and ( rand. clone ( ) . eq ( lit ( 0 ) ) ) ) . or ( col (
4033
+ "column1" ,
4034
+ )
4035
+ . eq ( lit ( 2 ) )
4036
+ . and ( rand. clone ( ) . eq ( lit ( 0 ) ) ) ) ;
4037
+
4038
+ assert_eq ! (
4039
+ simplify( expr) ,
4040
+ col( "column1" )
4041
+ . eq( lit( 2 ) )
4042
+ . and( ( rand. clone( ) . eq( lit( 0 ) ) ) . or( rand. clone( ) . eq( lit( 0 ) ) ) )
4043
+ ) ;
4044
+ }
4045
+ }
3981
4046
}
0 commit comments