Skip to content

Commit cce7b77

Browse files
authored
feat(expr): implement array_flatten (#21640)
Signed-off-by: Richard Chien <[email protected]>
1 parent 4790072 commit cce7b77

File tree

7 files changed

+117
-0
lines changed

7 files changed

+117
-0
lines changed

proto/expr.proto

+1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ message ExprNode {
236236
ARRAY_SORT = 549;
237237
ARRAY_CONTAINS = 550;
238238
ARRAY_CONTAINED = 551;
239+
ARRAY_FLATTEN = 552;
239240

240241
// Int256 functions
241242
HEX_TO_INT256 = 560;

src/common/src/types/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -510,13 +510,21 @@ impl DataType {
510510
/// # Panics
511511
///
512512
/// Panics if the type is not a list type.
513+
/// TODO(rc): rename to `as_list_element_type`
513514
pub fn as_list(&self) -> &DataType {
514515
match self {
515516
// Borrow the element type carried by the list variant.
DataType::List(t) => t,
516517
// Every other variant violates the documented precondition: panic and name the
// offending type in the message.
t => panic!("expect list type, got {t}"),
517518
}
518519
}
519520

521+
/// Consumes a list type and returns its element type by value.
///
/// This is the owning counterpart of [`DataType::as_list`]: it strips exactly one level of
/// list nesting and hands back the element type instead of a reference to it.
///
/// # Panics
///
/// Panics if the type is not a list type.
pub fn into_list_element_type(self) -> DataType {
522+
match self {
523+
// Move the element type out of the list variant's payload.
DataType::List(t) => *t,
524+
// Not a list: report the type that was actually supplied.
t => panic!("expect list type, got {t}"),
525+
}
526+
}
527+
520528
/// Return a new type that removes the outer list, and get the innermost element type.
521529
///
522530
/// Use [`DataType::as_list`] if you only want the element type of a list.

src/expr/impl/src/scalar/array_flatten.rs

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright 2025 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use risingwave_common::array::{ListRef, ListValue};
16+
use risingwave_expr::expr::Context;
17+
use risingwave_expr::{ExprError, Result, function};
18+
19+
/// Flattens a nested array by concatenating the inner arrays into a single array.
20+
/// Only the outermost level of nesting is removed. For deeper nested arrays, call
21+
/// `array_flatten` multiple times.
22+
///
23+
/// Examples:
24+
///
25+
/// ```slt
26+
/// query T
27+
/// select array_flatten(array[array[1, 2], array[3, 4]]);
28+
/// ----
29+
/// {1,2,3,4}
30+
///
31+
/// query T
32+
/// select array_flatten(array[array[1, 2], array[]::int[], array[3, 4]]);
33+
/// ----
34+
/// {1,2,3,4}
35+
///
36+
/// query T
37+
/// select array_flatten(array[array[1, 2], null, array[3, 4]]);
38+
/// ----
39+
/// {1,2,3,4}
40+
///
41+
/// query T
42+
/// select array_flatten(array[array[array[1], array[2, null]], array[array[3, 4], null::int[]]]);
43+
/// ----
44+
/// {{1},{2,NULL},{3,4},NULL}
45+
///
46+
/// query T
47+
/// select array_flatten(array[[]]::int[][]);
48+
/// ----
49+
/// {}
50+
///
51+
/// query T
52+
/// select array_flatten(array[[null, 1]]::int[][]);
53+
/// ----
54+
/// {NULL,1}
55+
///
56+
/// query T
57+
/// select array_flatten(array[]::int[][]);
58+
/// ----
59+
/// {}
60+
///
61+
/// query T
62+
/// select array_flatten(null::int[][]);
63+
/// ----
64+
/// NULL
65+
/// ```
66+
///
/// NULL inner arrays are skipped, while NULL elements inside an inner array are kept in the
/// result.
///
/// # Errors
///
/// Returns [`ExprError::InvalidParam`] if the argument type is not an array whose elements
/// are themselves arrays.
#[function("array_flatten(anyarray) -> anyarray")]
67+
fn array_flatten(array: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
68+
// The elements of the array must be arrays themselves
69+
// `arg_types[0]` is the type of the single argument, i.e. the outer array type.
let outer_type = &ctx.arg_types[0];
70+
// Peel the first list level: `inner_type` is the type of each element of the outer array.
// `as_list` panics on non-list types, so it is only called after the `is_array` check.
let inner_type = if outer_type.is_array() {
71+
outer_type.as_list()
72+
} else {
73+
return Err(ExprError::InvalidParam {
74+
name: "array_flatten",
75+
reason: Box::from("expected the argument to be an array of arrays"),
76+
});
77+
};
78+
// The outer array's elements must be arrays too, otherwise there is nothing to flatten.
if !inner_type.is_array() {
79+
return Err(ExprError::InvalidParam {
80+
name: "array_flatten",
81+
reason: Box::from("expected the argument to be an array of arrays"),
82+
});
83+
}
84+
// Peel the second list level: this is the element type of the flattened result.
let inner_elem_type = inner_type.as_list();
85+
86+
// Collect all inner array elements and flatten them into a single array
87+
Ok(ListValue::from_datum_iter(
88+
inner_elem_type,
89+
array
90+
.iter()
91+
// Filter out NULL inner arrays
92+
.flatten()
93+
// Flatten all inner arrays
94+
// NOTE(review): `into_list()` presumably reinterprets the scalar as a list reference;
// the inner elements are yielded as-is, so NULL elements are not filtered here.
.flat_map(|inner_array| inner_array.into_list().iter()),
95+
))
96+
}

src/expr/impl/src/scalar/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ mod array_access;
1818
mod array_concat;
1919
mod array_contain;
2020
mod array_distinct;
21+
mod array_flatten;
2122
mod array_length;
2223
mod array_min_max;
2324
mod array_positions;

src/frontend/src/binder/expr/function/builtin_scalar.rs

+9
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,15 @@ impl Binder {
344344
("arraycontains", raw_call(ExprType::ArrayContains)),
345345
("array_contained", raw_call(ExprType::ArrayContained)),
346346
("arraycontained", raw_call(ExprType::ArrayContained)),
347+
// Binder entry for `array_flatten`: validates that the input is an array of arrays and
// derives the return type by removing one level of list nesting.
// NOTE(review): `guard_by_len(1, ..)` presumably restricts the call to exactly one
// argument — confirm against its definition.
("array_flatten", guard_by_len(1, raw(|_binder, inputs| {
348+
// The original error from `ensure_array_type` is discarded and replaced with a
// function-specific bind error.
inputs[0].ensure_array_type().map_err(|_| ErrorCode::BindError("array_flatten expects `any[][]` input".into()))?;
349+
// The result type is the element type of the outer array.
let return_type = inputs[0].return_type().into_list_element_type();
350+
// That element type must itself be an array; a one-dimensional input is rejected.
if !return_type.is_array() {
351+
return Err(ErrorCode::BindError("array_flatten expects `any[][]` input".into()).into());
352+
353+
}
354+
// NOTE(review): `new_unchecked` presumably skips further signature checking, relying
// on the validation above — confirm.
Ok(FunctionCall::new_unchecked(ExprType::ArrayFlatten, inputs, return_type).into())
355+
}))),
347356
("trim_array", raw_call(ExprType::TrimArray)),
348357
(
349358
"array_ndims",

src/frontend/src/expr/pure.rs

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ impl ExprVisitor for ImpureAnalyzer {
183183
| Type::ArrayPosition
184184
| Type::ArrayContains
185185
| Type::ArrayContained
186+
| Type::ArrayFlatten
186187
| Type::HexToInt256
187188
| Type::JsonbConcat
188189
| Type::JsonbAccess

src/frontend/src/optimizer/plan_expr_visitor/strong.rs

+1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ impl Strong {
264264
| ExprType::ArraySort
265265
| ExprType::ArrayContains
266266
| ExprType::ArrayContained
267+
| ExprType::ArrayFlatten
267268
| ExprType::HexToInt256
268269
| ExprType::JsonbAccess
269270
| ExprType::JsonbAccessStr

0 commit comments

Comments
 (0)