@@ -3,7 +3,7 @@ use itertools::Itertools;
3
3
use sqlparser:: ast:: { Expr , OrderByExpr } ;
4
4
use std:: collections:: HashSet ;
5
5
6
- use crate :: binder:: { BindError , InputRefType } ;
6
+ use crate :: binder:: BindError ;
7
7
use crate :: planner:: LogicalPlan ;
8
8
use crate :: storage:: Transaction ;
9
9
use crate :: {
@@ -28,7 +28,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
28
28
select_items : & mut [ ScalarExpression ] ,
29
29
) -> Result < ( ) , BindError > {
30
30
for column in select_items {
31
- self . visit_column_agg_expr ( column, true ) ?;
31
+ self . visit_column_agg_expr ( column) ?;
32
32
}
33
33
Ok ( ( ) )
34
34
}
@@ -55,7 +55,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
55
55
// Extract having expression.
56
56
let return_having = if let Some ( having) = having {
57
57
let mut having = self . bind_expr ( having) ?;
58
- self . visit_column_agg_expr ( & mut having, false ) ?;
58
+ self . visit_column_agg_expr ( & mut having) ?;
59
59
60
60
Some ( having)
61
61
} else {
@@ -72,7 +72,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
72
72
nulls_first,
73
73
} = orderby;
74
74
let mut expr = self . bind_expr ( expr) ?;
75
- self . visit_column_agg_expr ( & mut expr, false ) ?;
75
+ self . visit_column_agg_expr ( & mut expr) ?;
76
76
77
77
return_orderby. push ( SortField :: new (
78
78
expr,
@@ -87,77 +87,30 @@ impl<'a, T: Transaction> Binder<'a, T> {
87
87
Ok ( ( return_having, return_orderby) )
88
88
}
89
89
90
- fn visit_column_agg_expr (
91
- & mut self ,
92
- expr : & mut ScalarExpression ,
93
- is_select : bool ,
94
- ) -> Result < ( ) , BindError > {
95
- let ref_columns = expr. referenced_columns ( ) ;
96
-
90
+ fn visit_column_agg_expr ( & mut self , expr : & mut ScalarExpression ) -> Result < ( ) , BindError > {
97
91
match expr {
98
- ScalarExpression :: AggCall {
99
- ty : return_type, ..
100
- } => {
101
- let ty = return_type. clone ( ) ;
102
- if is_select {
103
- let index = self . context . input_ref_index ( InputRefType :: AggCall ) ;
104
- let input_ref = ScalarExpression :: InputRef {
105
- index,
106
- ty,
107
- ref_columns,
108
- } ;
109
- match std:: mem:: replace ( expr, input_ref) {
110
- ScalarExpression :: AggCall {
111
- kind,
112
- args,
113
- ty,
114
- distinct,
115
- } => {
116
- self . context . agg_calls . push ( ScalarExpression :: AggCall {
117
- distinct,
118
- kind,
119
- args,
120
- ty,
121
- } ) ;
122
- }
123
- _ => unreachable ! ( ) ,
124
- }
125
- } else {
126
- let ( index, _) = self
127
- . context
128
- . agg_calls
129
- . iter ( )
130
- . find_position ( |agg_expr| agg_expr == & expr)
131
- . ok_or_else ( || BindError :: AggMiss ( format ! ( "{:?}" , expr) ) ) ?;
132
-
133
- let _ = std:: mem:: replace (
134
- expr,
135
- ScalarExpression :: InputRef {
136
- index,
137
- ty,
138
- ref_columns,
139
- } ,
140
- ) ;
141
- }
142
- }
143
-
144
- ScalarExpression :: TypeCast { expr, .. } => {
145
- self . visit_column_agg_expr ( expr, is_select) ?
92
+ ScalarExpression :: AggCall { .. } => {
93
+ self . context . agg_calls . push ( expr. clone ( ) ) ;
146
94
}
147
- ScalarExpression :: IsNull { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
148
- ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
149
- ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr, is_select) ?,
95
+ ScalarExpression :: TypeCast { expr, .. } => self . visit_column_agg_expr ( expr) ?,
96
+ ScalarExpression :: IsNull { expr, .. } => self . visit_column_agg_expr ( expr) ?,
97
+ ScalarExpression :: Unary { expr, .. } => self . visit_column_agg_expr ( expr) ?,
98
+ ScalarExpression :: Alias { expr, .. } => self . visit_column_agg_expr ( expr) ?,
150
99
ScalarExpression :: Binary {
151
100
left_expr,
152
101
right_expr,
153
102
..
154
103
} => {
155
- self . visit_column_agg_expr ( left_expr, is_select ) ?;
156
- self . visit_column_agg_expr ( right_expr, is_select ) ?;
104
+ self . visit_column_agg_expr ( left_expr) ?;
105
+ self . visit_column_agg_expr ( right_expr) ?;
157
106
}
158
- ScalarExpression :: Constant ( _)
159
- | ScalarExpression :: ColumnRef { .. }
160
- | ScalarExpression :: InputRef { .. } => { }
107
+ ScalarExpression :: In { expr, args, .. } => {
108
+ self . visit_column_agg_expr ( expr) ?;
109
+ for arg in args {
110
+ self . visit_column_agg_expr ( arg) ?;
111
+ }
112
+ }
113
+ ScalarExpression :: Constant ( _) | ScalarExpression :: ColumnRef { .. } => { }
161
114
}
162
115
163
116
Ok ( ( ) )
@@ -239,44 +192,13 @@ impl<'a, T: Transaction> Binder<'a, T> {
239
192
false
240
193
}
241
194
} ) {
242
- let index = self . context . input_ref_index ( InputRefType :: GroupBy ) ;
243
- let mut select_item = & mut select_list[ i] ;
244
- let ref_columns = select_item. referenced_columns ( ) ;
245
- let return_type = select_item. return_type ( ) ;
246
-
247
- self . context . group_by_exprs . push ( std:: mem:: replace (
248
- & mut select_item,
249
- ScalarExpression :: InputRef {
250
- index,
251
- ty : return_type,
252
- ref_columns,
253
- } ,
254
- ) ) ;
195
+ self . context . group_by_exprs . push ( select_list[ i] . clone ( ) ) ;
255
196
return ;
256
197
}
257
198
}
258
199
259
200
if let Some ( i) = select_list. iter ( ) . position ( |column| column == expr) {
260
- let expr = & mut select_list[ i] ;
261
- let ref_columns = expr. referenced_columns ( ) ;
262
-
263
- match expr {
264
- ScalarExpression :: Constant ( _) | ScalarExpression :: ColumnRef { .. } => {
265
- self . context . group_by_exprs . push ( expr. clone ( ) )
266
- }
267
- _ => {
268
- let index = self . context . input_ref_index ( InputRefType :: GroupBy ) ;
269
-
270
- self . context . group_by_exprs . push ( std:: mem:: replace (
271
- expr,
272
- ScalarExpression :: InputRef {
273
- index,
274
- ty : expr. return_type ( ) ,
275
- ref_columns,
276
- } ,
277
- ) )
278
- }
279
- }
201
+ self . context . group_by_exprs . push ( select_list[ i] . clone ( ) )
280
202
}
281
203
}
282
204
@@ -320,6 +242,13 @@ impl<'a, T: Transaction> Binder<'a, T> {
320
242
ScalarExpression :: TypeCast { expr, .. } => self . validate_having_orderby ( expr) ,
321
243
ScalarExpression :: IsNull { expr, .. } => self . validate_having_orderby ( expr) ,
322
244
ScalarExpression :: Unary { expr, .. } => self . validate_having_orderby ( expr) ,
245
+ ScalarExpression :: In { expr, args, .. } => {
246
+ self . validate_having_orderby ( expr) ?;
247
+ for arg in args {
248
+ self . validate_having_orderby ( arg) ?;
249
+ }
250
+ Ok ( ( ) )
251
+ }
323
252
ScalarExpression :: Binary {
324
253
left_expr,
325
254
right_expr,
@@ -330,7 +259,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
330
259
Ok ( ( ) )
331
260
}
332
261
333
- ScalarExpression :: Constant ( _) | ScalarExpression :: InputRef { .. } => Ok ( ( ) ) ,
262
+ ScalarExpression :: Constant ( _) => Ok ( ( ) ) ,
334
263
}
335
264
}
336
265
}
0 commit comments