
Commit ca85ae2

cpcloud authored and jcrist committed
feat(flink): array sort
1 parent eb857e6 commit ca85ae2

File tree

5 files changed: +39 −17 lines changed

ibis/backends/sql/compilers/flink.py

Lines changed: 12 additions & 2 deletions
@@ -72,7 +72,6 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArgMax,
         ops.ArgMin,
         ops.ArrayFlatten,
-        ops.ArraySort,
         ops.ArrayStringJoin,
         ops.Correlation,
         ops.CountDistinctStar,
@@ -102,6 +101,7 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArrayLength: "cardinality",
         ops.ArrayPosition: "array_position",
         ops.ArrayRemove: "array_remove",
+        ops.ArraySort: "array_sort",
         ops.ArrayUnion: "array_union",
         ops.ExtractDayOfYear: "dayofyear",
         ops.MapKeys: "map_keys",
@@ -576,10 +576,20 @@ def visit_StructColumn(self, op, *, names, values):
         return self.cast(sge.Struct(expressions=list(values)), op.dtype)

     def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
+        if order_by:
+            raise com.UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+        # the only way to get filtering *and* respecting nulls is to use
+        # `FILTER` syntax, but it's broken in various ways for other aggregates
+        out = self.f.array_agg(arg)
         if not include_null:
             cond = arg.is_(sg.not_(NULL, copy=False))
             where = cond if where is None else sge.And(this=cond, expression=where)
-        return self.agg.array_agg(arg, where=where, order_by=order_by)
+        if where is not None:
+            out = sge.Filter(this=out, expression=sge.Where(this=where))
+        return out


 compiler = FlinkCompiler()
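
For context, a minimal sketch of the behavior this rework produces, using ibis's public `ibis.to_sql` helper; the table and the exact rendered SQL below are illustrative assumptions, not taken from the commit:

    import ibis

    t = ibis.table({"id": "int64", "s": "string"}, name="t")

    # A filtered collect: with this change the Flink compiler builds
    # ARRAY_AGG(s) FILTER (WHERE ...) directly instead of routing through
    # the generic aggregate-filter helper (self.agg.array_agg).
    expr = t.s.collect(where=t.id % 13 == 0)
    print(ibis.to_sql(expr, dialect="flink"))

    # Passing order_by to collect now raises UnsupportedOperationError
    # when the expression is compiled for Flink.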

ibis/backends/sql/dialects.py

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@ class Generator(Hive.Generator):
             sge.ArrayConcat: rename_func("array_concat"),
             sge.ArraySize: rename_func("cardinality"),
             sge.ArrayAgg: rename_func("array_agg"),
+            sge.ArraySort: rename_func("array_sort"),
             sge.Length: rename_func("char_length"),
             sge.TryCast: lambda self,
             e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})",
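
A quick sketch of what this transform does, assuming ibis's custom Flink sqlglot dialect is importable as shown (module paths and rendered casing may vary across versions):

    import sqlglot.expressions as sge
    from ibis.backends.sql.dialects import Flink

    # sge.ArraySort nodes now render through rename_func("array_sort"),
    # so the Flink generator emits Flink's ARRAY_SORT function.
    node = sge.ArraySort(this=sge.column("a"))
    print(node.sql(dialect=Flink))  # ARRAY_SORT(a)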

ibis/backends/tests/test_aggregation.py

Lines changed: 3 additions & 3 deletions
@@ -1480,13 +1480,13 @@ def test_collect_ordered(alltypes, df, filtered):
 def test_collect(alltypes, df, filtered, include_null):
     ibis_cond = (_.id % 13 == 0) if filtered else None
     pd_cond = (df.id % 13 == 0) if filtered else slice(None)
-    res = (
+    expr = (
         alltypes.string_col.nullif("3")
         .collect(where=ibis_cond, include_null=include_null)
         .length()
-        .execute()
     )
-    vals = df.string_col if include_null else df.string_col[(df.string_col != "3")]
+    res = expr.execute()
+    vals = df.string_col if include_null else df.string_col[df.string_col != "3"]
     sol = len(vals[pd_cond])
     assert res == sol

14921492

ibis/backends/tests/test_array.py

Lines changed: 23 additions & 9 deletions
@@ -316,13 +316,13 @@ def test_unnest_idempotent(backend):
             ["scalar_column", array_types.x.cast("!array<int64>").unnest().name("x")]
         )
         .group_by("scalar_column")
-        .aggregate(x=lambda t: t.x.collect())
+        .aggregate(x=lambda t: t.x.collect().sort())
         .order_by("scalar_column")
     )
     result = expr.execute().reset_index(drop=True)
     expected = (
         df[["scalar_column", "x"]]
-        .assign(x=df.x.map(lambda arr: [i for i in arr if not pd.isna(i)]))
+        .assign(x=df.x.map(lambda arr: sorted(i for i in arr if not pd.isna(i))))
         .sort_values("scalar_column")
         .reset_index(drop=True)
     )
@@ -718,20 +718,34 @@ def test_array_unique(con, input, expected):


 @builtin_array
-@pytest.mark.notimpl(
-    ["flink", "polars"],
-    raises=com.OperationNotDefinedError,
-)
+@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
 @pytest.mark.notyet(
     ["risingwave"],
     raises=AssertionError,
     reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735",
 )
-def test_array_sort(con):
-    t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)})
+@pytest.mark.parametrize(
+    "data",
+    (
+        param(
+            [[3, 2], [], [42, 42], []],
+            marks=[
+                pytest.mark.notyet(
+                    ["flink"],
+                    raises=Py4JJavaError,
+                    reason="flink cannot handle empty arrays",
+                )
+            ],
+        ),
+        [[3, 2], [42, 42]],
+    ),
+    ids=["empty", "nonempty"],
+)
+def test_array_sort(con, data):
+    t = ibis.memtable({"a": data, "id": range(len(data))})
     expr = t.mutate(a=t.a.sort()).order_by("id")
     result = con.execute(expr)
-    expected = pd.Series([[2, 3], [], [42, 42], []], dtype="object")
+    expected = pd.Series(list(map(sorted, data)), dtype="object")

     assert frozenset(map(tuple, result["a"].values)) == frozenset(
         map(tuple, expected.values)
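
For reference, the user-facing behavior these tests exercise; a minimal sketch with a hypothetical table (rendering details vary by version, and the empty-array Py4JJavaError caveat flagged above still applies on Flink):

    import ibis

    t = ibis.table({"a": "array<int64>", "id": "int64"}, name="t")
    # ArraySort now compiles on Flink instead of raising
    # OperationNotDefinedError, lowering .sort() to array_sort.
    expr = t.mutate(a=t.a.sort()).order_by("id")
    print(ibis.to_sql(expr, dialect="flink"))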

ibis/backends/tests/test_struct.py

Lines changed: 0 additions & 3 deletions
@@ -116,9 +116,6 @@ def test_struct_column(alltypes, df):

 @pytest.mark.notimpl(["postgres", "risingwave", "polars"])
 @pytest.mark.notyet(["datafusion"], raises=Exception, reason="unsupported syntax")
-@pytest.mark.notyet(
-    ["flink"], reason="flink doesn't support creating struct columns from collect"
-)
 def test_collect_into_struct(alltypes):
     from ibis import _

124121
