Skip to content

Commit 6d90f18

Browse files
committed
fix(snowflake): make semantics of array filtering match everything else
1 parent 983cd5d commit 6d90f18

File tree

1 file changed

+38
-19
lines changed

1 file changed

+38
-19
lines changed

ibis/backends/sql/compilers/snowflake.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -800,31 +800,50 @@ def visit_ArrayMap(self, op, *, arg, param, body, index):
800800
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
801801

802802
def visit_ArrayFilter(self, op, *, arg, param, body, index):
803-
if index is not None:
804-
arg = self.f.arrays_zip(
805-
arg, self.f.array_generate_range(0, self.f.array_size(arg))
806-
)
807-
null_filter_arg = self.f.get(param, "$1")
808-
# extract the field we care about
809-
placeholder = sg.to_identifier("__ibis_snowflake_arg__")
810-
post_process = lambda arg: self.f.transform(
803+
if index is None:
804+
return self.f.filter(
811805
arg,
806+
# nulls are considered false when they are returned from a
807+
# `filter` predicate
808+
#
809+
# we're using is_null_value here because snowflake
810+
# automatically converts embedded SQL NULLs to JSON nulls in
811+
# higher order functions
812812
sge.Lambda(
813-
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
813+
this=sg.and_(sg.not_(self.f.is_null_value(param)), body),
814+
expressions=[param],
814815
),
815816
)
816-
else:
817-
null_filter_arg = param
818-
post_process = lambda arg: arg
819817

820-
# null_filter is necessary otherwise null values are treated as JSON
821-
# nulls instead of SQL NULLs
822-
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))
818+
zipped = self.f.arrays_zip(
819+
arg, self.f.array_generate_range(0, self.f.array_size(arg))
820+
)
821+
# extract the field we care about
822+
keeps = self.f.transform(
823+
zipped,
824+
sge.Lambda(
825+
this=self.f.object_construct_keep_null(
826+
"keep", body, "value", self.f.get(param, "$1")
827+
),
828+
expressions=[param],
829+
),
830+
)
823831

824-
return post_process(
825-
self.f.filter(
826-
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
827-
)
832+
# then, filter out elements that are null
833+
placeholder1 = sg.to_identifier("__f1__")
834+
placeholder2 = sg.to_identifier("__f2__")
835+
filtered = self.f.filter(
836+
keeps,
837+
sge.Lambda(
838+
this=self.cast(self.f.get(placeholder1, "keep"), dt.boolean),
839+
expressions=[placeholder1],
840+
),
841+
)
842+
return self.f.transform(
843+
filtered,
844+
sge.Lambda(
845+
this=self.f.get(placeholder2, "value"), expressions=[placeholder2]
846+
),
828847
)
829848

830849
def visit_JoinLink(self, op, *, how, table, predicates):

0 commit comments

Comments
 (0)