Skip to content

Commit 860b9ca

Browse files
fix(trino,pyspark): improve null handling in array filter (#10448)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent 85f0693 commit 860b9ca

File tree

3 files changed

+87
-16
lines changed

3 files changed

+87
-16
lines changed

ibis/backends/sql/compilers/pyspark.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,8 @@ def visit_ArrayFilter(self, op, *, arg, body, param, index):
         if index is not None:
             expressions.append(index)

-        func = sge.Lambda(this=self.if_(body, param, NULL), expressions=expressions)
-        transform = self.f.transform(arg, func)
-
-        func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=expressions)
-        return self.f.filter(transform, func)
+        lamduh = sge.Lambda(this=body, expressions=expressions)
+        return self.f.filter(arg, lamduh)

     def visit_ArrayIndex(self, op, *, arg, index):
         return self.f.element_at(arg, index + 1)

ibis/backends/sql/compilers/trino.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,21 +177,50 @@ def visit_ArrayFilter(self, op, *, arg, param, body, index):
         else:
             placeholder = sg.to_identifier("__trino_filter__")
             index = sg.to_identifier(index)
-            return self.f.filter(
-                self.f.zip_with(
-                    arg,
-                    # users are limited to 10_000 elements here because it
-                    # seems like trino won't ever actually address the limit
-                    self.f.sequence(0, self.f.cardinality(arg) - 1),
-                    sge.Lambda(
-                        # semantics are: arg if predicate(arg, index) else null
-                        this=self.if_(body, param, NULL),
-                        expressions=[param, index],
-                    ),
-                ),
-                # then, filter out elements that are null
-                sge.Lambda(
-                    this=placeholder.is_(sg.not_(NULL)), expressions=[placeholder]
-                ),
-            )
+            keep, value = map(sg.to_identifier, ("keep", "value"))
+
+            # first, zip the array with the index and call the user's function,
+            # returning a struct of {"keep": value-of-predicate, "value": array-element}
+            zipped = self.f.zip_with(
+                arg,
+                # users are limited to 10_000 elements here because it
+                # seems like trino won't ever actually address the limit
+                self.f.sequence(0, self.f.cardinality(arg) - 1),
+                sge.Lambda(
+                    this=self.cast(
+                        sge.Struct(
+                            expressions=[
+                                sge.PropertyEQ(this=keep, expression=body),
+                                sge.PropertyEQ(this=value, expression=param),
+                            ]
+                        ),
+                        dt.Struct(
+                            {
+                                "keep": dt.boolean,
+                                "value": op.arg.dtype.value_type,
+                            }
+                        ),
+                    ),
+                    expressions=[param, index],
+                ),
+            )
+
+            # second, keep only the elements whose predicate returned true
+            filtered = self.f.filter(
+                zipped,
+                sge.Lambda(
+                    this=sge.Dot(this=placeholder, expression=keep),
+                    expressions=[placeholder],
+                ),
+            )
+
+            # finally, extract the "value" field from the struct
+            return self.f.transform(
+                filtered,
+                sge.Lambda(
+                    this=sge.Dot(this=placeholder, expression=value),
+                    expressions=[placeholder],
+                ),
+            )

ibis/backends/tests/test_array.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,51 @@ def test_array_filter_with_index(con, input, output, predicate):
     )


+@builtin_array
+@pytest.mark.notimpl(
+    ["datafusion", "flink", "polars"], raises=com.OperationNotDefinedError
+)
+@pytest.mark.notimpl(
+    ["sqlite"], raises=com.UnsupportedBackendType, reason="Unsupported type: Array..."
+)
+@pytest.mark.parametrize(
+    ("input", "output"),
+    [
+        param(
+            {"a": [[1, None, None], [4]]},
+            {"a": [[1, None], [4]]},
+            id="nulls",
+            marks=[
+                pytest.mark.notyet(
+                    ["bigquery"],
+                    raises=GoogleBadRequest,
+                    reason="NULLs are not allowed as array elements",
+                )
+            ],
+        ),
+        param({"a": [[1, 2], [1]]}, {"a": [[1], [1]]}, id="no_nulls"),
+    ],
+)
+@pytest.mark.notyet(
+    "risingwave",
+    raises=PsycoPg2InternalError,
+    reason="no support for not null column constraint",
+)
+@pytest.mark.parametrize(
+    "predicate",
+    [lambda x, i: i % 2 == 0, partial(lambda x, y, i: i % 2 == 0, y=1)],
+    ids=["lambda", "partial"],
+)
+def test_array_filter_with_index_lambda(con, input, output, predicate):
+    t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
+
+    expr = t.select(a=t.a.filter(predicate))
+    result = con.to_pyarrow(expr.a)
+    assert frozenset(map(tuple, result.to_pylist())) == frozenset(
+        map(tuple, output["a"])
+    )
+
+
 @builtin_array
 @pytest.mark.parametrize(
     ("col", "value"),
0 commit comments

Comments
 (0)