@@ -172,57 +172,58 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
172
172
)
173
173
174
174
def visit_ArrayFilter (self , op , * , arg , param , body , index ):
175
+ # no index, life is simpler
175
176
if index is None :
176
177
return self .f .filter (arg , sge .Lambda (this = body , expressions = [param ]))
177
- else :
178
- placeholder = sg .to_identifier ("__trino_filter__" )
179
- index = sg .to_identifier (index )
180
- keep , value = map (sg .to_identifier , ("keep" , "value" ))
181
178
182
- # first, zip the array with the index and call the user's function,
183
- # returning a struct of {"keep": value-of-predicate, "value": array-element}
184
- zipped = self .f .zip_with (
185
- arg ,
186
- # users are limited to 10_000 elements here because it
187
- # seems like trino won't ever actually address the limit
188
- self .f .sequence (0 , self .f .cardinality (arg ) - 1 ),
189
- sge .Lambda (
190
- this = self .cast (
191
- sge .Struct (
192
- expressions = [
193
- sge .PropertyEQ (this = keep , expression = body ),
194
- sge .PropertyEQ (this = value , expression = param ),
195
- ]
196
- ),
197
- dt .Struct (
198
- {
199
- "keep" : dt .boolean ,
200
- "value" : op .arg .dtype .value_type ,
201
- }
202
- ),
179
+ placeholder = sg .to_identifier ("__trino_filter__" )
180
+ index = sg .to_identifier (index )
181
+ keep , value = map (sg .to_identifier , ("keep" , "value" ))
182
+
183
+ # first, zip the array with the index and call the user's function,
184
+ # returning a struct of {"keep": value-of-predicate, "value": array-element}
185
+ zipped = self .f .zip_with (
186
+ arg ,
187
+ # users are limited to 10_000 elements here because it
188
+ # seems like trino won't ever actually address the limit
189
+ self .f .sequence (0 , self .f .cardinality (arg ) - 1 ),
190
+ sge .Lambda (
191
+ this = self .cast (
192
+ sge .Struct (
193
+ expressions = [
194
+ sge .PropertyEQ (this = keep , expression = body ),
195
+ sge .PropertyEQ (this = value , expression = param ),
196
+ ]
197
+ ),
198
+ dt .Struct (
199
+ {
200
+ "keep" : dt .boolean ,
201
+ "value" : op .arg .dtype .value_type ,
202
+ }
203
203
),
204
- expressions = [param , index ],
205
204
),
206
- )
205
+ expressions = [param , index ],
206
+ ),
207
+ )
207
208
208
- # second, keep only the elements whose predicate returned true
209
- filtered = self .f .filter (
210
- # then, filter out elements that are null
211
- zipped ,
212
- sge .Lambda (
213
- this = sge .Dot (this = placeholder , expression = keep ),
214
- expressions = [placeholder ],
215
- ),
216
- )
209
+ # second, keep only the elements whose predicate returned true
210
+ filtered = self .f .filter (
211
+ # then, filter out elements that are null
212
+ zipped ,
213
+ sge .Lambda (
214
+ this = sge .Dot (this = placeholder , expression = keep ),
215
+ expressions = [placeholder ],
216
+ ),
217
+ )
217
218
218
- # finally, extract the "value" field from the struct
219
- return self .f .transform (
220
- filtered ,
221
- sge .Lambda (
222
- this = sge .Dot (this = placeholder , expression = value ),
223
- expressions = [placeholder ],
224
- ),
225
- )
219
+ # finally, extract the "value" field from the struct
220
+ return self .f .transform (
221
+ filtered ,
222
+ sge .Lambda (
223
+ this = sge .Dot (this = placeholder , expression = value ),
224
+ expressions = [placeholder ],
225
+ ),
226
+ )
226
227
227
228
def visit_ArrayContains (self , op , * , arg , other ):
228
229
return self .if_ (
0 commit comments