Skip to content

Commit bad487b

Browse files
authored
fix(snowflake): make semantics of array filtering match everything else (#10469)
1 parent 23c0e81 commit bad487b

File tree

2 files changed

+83
-63
lines changed

2 files changed

+83
-63
lines changed

ibis/backends/sql/compilers/snowflake.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -800,31 +800,50 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
800800
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
801801

802802
def visit_ArrayFilter(self, op, *, arg, param, body, index):
803-
if index is not None:
804-
arg = self.f.arrays_zip(
805-
arg, self.f.array_generate_range(0, self.f.array_size(arg))
806-
)
807-
null_filter_arg = self.f.get(param, "$1")
808-
# extract the field we care about
809-
placeholder = sg.to_identifier("__ibis_snowflake_arg__")
810-
post_process = lambda arg: self.f.transform(
803+
if index is None:
804+
return self.f.filter(
811805
arg,
806+
# nulls are considered false when they are returned from a
807+
# `filter` predicate
808+
#
809+
# we're using is_null_value here because snowflake
810+
# automatically converts embedded SQL NULLs to JSON nulls in
811+
# higher order functions
812812
sge.Lambda(
813-
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
813+
this=sg.and_(sg.not_(self.f.is_null_value(param)), body),
814+
expressions=[param],
814815
),
815816
)
816-
else:
817-
null_filter_arg = param
818-
post_process = lambda arg: arg
819817

820-
# null_filter is necessary otherwise null values are treated as JSON
821-
# nulls instead of SQL NULLs
822-
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))
818+
zipped = self.f.arrays_zip(
819+
arg, self.f.array_generate_range(0, self.f.array_size(arg))
820+
)
821+
# extract the field we care about
822+
keeps = self.f.transform(
823+
zipped,
824+
sge.Lambda(
825+
this=self.f.object_construct_keep_null(
826+
"keep", body, "value", self.f.get(param, "$1")
827+
),
828+
expressions=[param],
829+
),
830+
)
823831

824-
return post_process(
825-
self.f.filter(
826-
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
827-
)
832+
# then, filter out elements that are null
833+
placeholder1 = sg.to_identifier("__f1__")
834+
placeholder2 = sg.to_identifier("__f2__")
835+
filtered = self.f.filter(
836+
keeps,
837+
sge.Lambda(
838+
this=self.cast(self.f.get(placeholder1, "keep"), dt.boolean),
839+
expressions=[placeholder1],
840+
),
841+
)
842+
return self.f.transform(
843+
filtered,
844+
sge.Lambda(
845+
this=self.f.get(placeholder2, "value"), expressions=[placeholder2]
846+
),
828847
)
829848

830849
def visit_JoinLink(self, op, *, how, table, predicates):

ibis/backends/sql/compilers/trino.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -172,57 +172,58 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
172172
)
173173

174174
def visit_ArrayFilter(self, op, *, arg, param, body, index):
175+
# no index, life is simpler
175176
if index is None:
176177
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
177-
else:
178-
placeholder = sg.to_identifier("__trino_filter__")
179-
index = sg.to_identifier(index)
180-
keep, value = map(sg.to_identifier, ("keep", "value"))
181178

182-
# first, zip the array with the index and call the user's function,
183-
# returning a struct of {"keep": value-of-predicate, "value": array-element}
184-
zipped = self.f.zip_with(
185-
arg,
186-
# users are limited to 10_000 elements here because it
187-
# seems like trino won't ever actually address the limit
188-
self.f.sequence(0, self.f.cardinality(arg) - 1),
189-
sge.Lambda(
190-
this=self.cast(
191-
sge.Struct(
192-
expressions=[
193-
sge.PropertyEQ(this=keep, expression=body),
194-
sge.PropertyEQ(this=value, expression=param),
195-
]
196-
),
197-
dt.Struct(
198-
{
199-
"keep": dt.boolean,
200-
"value": op.arg.dtype.value_type,
201-
}
202-
),
179+
placeholder = sg.to_identifier("__trino_filter__")
180+
index = sg.to_identifier(index)
181+
keep, value = map(sg.to_identifier, ("keep", "value"))
182+
183+
# first, zip the array with the index and call the user's function,
184+
# returning a struct of {"keep": value-of-predicate, "value": array-element}
185+
zipped = self.f.zip_with(
186+
arg,
187+
# users are limited to 10_000 elements here because it
188+
# seems like trino won't ever actually address the limit
189+
self.f.sequence(0, self.f.cardinality(arg) - 1),
190+
sge.Lambda(
191+
this=self.cast(
192+
sge.Struct(
193+
expressions=[
194+
sge.PropertyEQ(this=keep, expression=body),
195+
sge.PropertyEQ(this=value, expression=param),
196+
]
197+
),
198+
dt.Struct(
199+
{
200+
"keep": dt.boolean,
201+
"value": op.arg.dtype.value_type,
202+
}
203203
),
204-
expressions=[param, index],
205204
),
206-
)
205+
expressions=[param, index],
206+
),
207+
)
207208

208-
# second, keep only the elements whose predicate returned true
209-
filtered = self.f.filter(
210-
# then, filter out elements that are null
211-
zipped,
212-
sge.Lambda(
213-
this=sge.Dot(this=placeholder, expression=keep),
214-
expressions=[placeholder],
215-
),
216-
)
209+
# second, keep only the elements whose predicate returned true
210+
filtered = self.f.filter(
211+
# then, filter out elements that are null
212+
zipped,
213+
sge.Lambda(
214+
this=sge.Dot(this=placeholder, expression=keep),
215+
expressions=[placeholder],
216+
),
217+
)
217218

218-
# finally, extract the "value" field from the struct
219-
return self.f.transform(
220-
filtered,
221-
sge.Lambda(
222-
this=sge.Dot(this=placeholder, expression=value),
223-
expressions=[placeholder],
224-
),
225-
)
219+
# finally, extract the "value" field from the struct
220+
return self.f.transform(
221+
filtered,
222+
sge.Lambda(
223+
this=sge.Dot(this=placeholder, expression=value),
224+
expressions=[placeholder],
225+
),
226+
)
226227

227228
def visit_ArrayContains(self, op, *, arg, other):
228229
return self.if_(

0 commit comments

Comments
 (0)