Skip to content

Commit f97bb77

Browse files
authored
fix(polars): use elementwise flatten to flatten nested arrays (#10168)
Fixes #10135.
1 parent 3877d6d commit f97bb77

File tree

3 files changed: +29 -2 lines changed

ibis/backends/polars/compiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,11 @@ def array_flatten(op, **kw):
10211021
.then(None)
10221022
.when(result.list.len() == 0)
10231023
.then([])
1024-
.otherwise(result.flatten())
1024+
# polars doesn't have an efficient API (yet?) for removing one level of
1025+
# nesting from an array so we use elementwise evaluation
1026+
#
1027+
# https://github.com/ibis-project/ibis/issues/10135
1028+
.otherwise(result.list.eval(pl.element().flatten()))
10251029
)
10261030

10271031

ibis/backends/tests/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
TrinoUserError = None
104104

105105
try:
106+
from psycopg2.errors import ArraySubscriptError as PsycoPg2ArraySubscriptError
106107
from psycopg2.errors import DivisionByZero as PsycoPg2DivisionByZero
107108
from psycopg2.errors import IndeterminateDatatype as PsycoPg2IndeterminateDatatype
108109
from psycopg2.errors import InternalError_ as PsycoPg2InternalError
@@ -118,7 +119,7 @@
118119
PsycoPg2InvalidTextRepresentation
119120
) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = (
120121
PsycoPg2OperationalError
121-
) = PsycoPg2UndefinedObject = None
122+
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = None
122123

123124
try:
124125
from MySQLdb import NotSupportedError as MySQLNotSupportedError

ibis/backends/tests/test_array.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
GoogleBadRequest,
2222
MySQLOperationalError,
2323
PolarsComputeError,
24+
PsycoPg2ArraySubscriptError,
2425
PsycoPg2IndeterminateDatatype,
2526
PsycoPg2InternalError,
2627
PsycoPg2ProgrammingError,
@@ -1006,6 +1007,11 @@ def flatten_data():
10061007
reason="Arrays are never nullable",
10071008
raises=AssertionError,
10081009
),
1010+
pytest.mark.notyet(
1011+
["polars"],
1012+
reason="flattened empty arrays incorrectly insert a null",
1013+
raises=AssertionError,
1014+
),
10091015
],
10101016
),
10111017
],
@@ -1557,3 +1563,19 @@ def test_array_agg_bool(con, data, agg, baseline_func):
15571563
result = [x if pd.notna(x) else None for x in result]
15581564
expected = [baseline_func(x) for x in df.x]
15591565
assert result == expected
1566+
1567+
1568+
@pytest.mark.notyet(
1569+
["postgres"],
1570+
raises=PsycoPg2ArraySubscriptError,
1571+
reason="all dimensions must match in size",
1572+
)
1573+
@pytest.mark.notimpl(["risingwave", "flink"], raises=com.OperationNotDefinedError)
1574+
def test_flatten(con):
1575+
t = ibis.memtable(
1576+
[{"arr": [[1, 5, 7], [3, 4]]}], schema={"arr": "array<array<int64>>"}
1577+
)
1578+
expr = t.arr.flatten().name("result")
1579+
result = con.execute(expr)
1580+
expected = pd.Series([[1, 5, 7, 3, 4]], name="result")
1581+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)