Skip to content

Commit 572780b

Browse files
authored
feat(expr): support array_position and array_replace for 1d scenario (#10166)
1 parent ac2085d commit 572780b

File tree

8 files changed

+266
-23
lines changed

8 files changed

+266
-23
lines changed

proto/expr.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ message ExprNode {
179179
ARRAY_POSITIONS = 539;
180180
TRIM_ARRAY = 540;
181181
STRING_TO_ARRAY = 541;
182+
ARRAY_POSITION = 542;
183+
ARRAY_REPLACE = 543;
182184

183185
// Int256 functions
184186
HEX_TO_INT256 = 560;

src/expr/src/vector_op/array_positions.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,121 @@ use risingwave_expr_macro::function;
1919
use crate::error::ExprError;
2020
use crate::Result;
2121

22+
/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
23+
/// it's not present.
24+
///
25+
/// Examples:
26+
///
27+
/// ```slt
28+
/// query I
29+
/// select array_position(array[1, null, 2, null], null);
30+
/// ----
31+
/// 2
32+
///
33+
/// query I
34+
/// select array_position(array[3, 4, 5], 2);
35+
/// ----
36+
/// NULL
37+
///
38+
/// query I
39+
/// select array_position(null, 4);
40+
/// ----
41+
/// NULL
42+
///
43+
/// query I
44+
/// select array_position(null, null);
45+
/// ----
46+
/// NULL
47+
///
48+
/// query I
49+
/// select array_position('{yes}', true);
50+
/// ----
51+
/// 1
52+
///
53+
/// # Like in PostgreSQL, searching `int` in multidimensional array is disallowed.
54+
/// statement error
55+
/// select array_position(array[array[1, 2], array[3, 4]], 1);
56+
///
57+
/// # Unlike in PostgreSQL, it is okay to search `int[]` inside `int[][]`.
58+
/// query I
59+
/// select array_position(array[array[1, 2], array[3, 4]], array[3, 4]);
60+
/// ----
61+
/// 2
62+
///
63+
/// statement error
64+
/// select array_position(array[3, 4], true);
65+
///
66+
/// query I
67+
/// select array_position(array[3, 4], 4.0);
68+
/// ----
69+
/// 2
70+
/// ```
71+
#[function("array_position(list, *) -> int32")]
72+
fn array_position<'a, T: ScalarRef<'a>>(
73+
array: Option<ListRef<'_>>,
74+
element: Option<T>,
75+
) -> Result<Option<i32>> {
76+
// No start position in this overload: skip zero items so the scan starts at the first one.
array_position_common(array, element, 0)
77+
}
78+
79+
/// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if
80+
/// it's not present. The search begins at the subscript given by the third argument.
81+
///
82+
/// Examples:
83+
///
84+
/// ```slt
85+
/// statement error
86+
/// select array_position(array[1, null, 2, null], null, false);
87+
///
88+
/// statement error
89+
/// select array_position(array[1, null, 2, null], null, null::int);
90+
///
91+
/// query II
92+
/// select v, array_position(array[1, null, 2, null], null, v) from generate_series(-1, 5) as t(v);
93+
/// ----
94+
/// -1 2
95+
/// 0 2
96+
/// 1 2
97+
/// 2 2
98+
/// 3 4
99+
/// 4 4
100+
/// 5 NULL
101+
/// ```
102+
#[function("array_position(list, *, int32) -> int32")]
103+
fn array_position_start<'a, T: ScalarRef<'a>>(
104+
array: Option<ListRef<'_>>,
105+
element: Option<T>,
106+
start: Option<i32>,
107+
) -> Result<Option<i32>> {
108+
// Turn the 1-based `start` subscript into the number of leading items to skip.
let start = match start {
109+
// A NULL start is reported as an `InvalidParam` error instead of producing NULL.
None => {
110+
return Err(ExprError::InvalidParam {
111+
name: "start",
112+
reason: "initial position must not be null".into(),
113+
})
114+
}
115+
// Values below 1 are clamped to 1, so the subtraction never goes negative and the
// cast to `usize` is lossless.
Some(start) => (start.max(1) - 1) as usize,
116+
};
117+
array_position_common(array, element, start)
118+
}
119+
120+
/// Shared implementation behind `array_position` and `array_position_start`.
///
/// Ignores the first `skip` items of `array` and returns the 1-based subscript (counted from
/// the start of the whole array, not from the skipped offset) of the first remaining item that
/// equals `element`.
///
/// Returns `Ok(None)` when `array` is `None` or when no item matches, and
/// `ExprError::CastOutOfRange` when the array length does not fit in an `i32`.
fn array_position_common<'a, T: ScalarRef<'a>>(
121+
array: Option<ListRef<'_>>,
122+
element: Option<T>,
123+
skip: usize,
124+
) -> Result<Option<i32>> {
125+
let Some(left) = array else { return Ok(None) };
126+
// Guard the cast at the end: once the length fits in `i32`, any subscript within the
// array fits as well.
if i32::try_from(left.len()).is_err() {
127+
return Err(ExprError::CastOutOfRange("invalid array length"));
128+
}
129+
130+
Ok(left
131+
.iter()
132+
.skip(skip)
133+
// `element` stays wrapped in its `Option` for the comparison, so a `None` element is
// compared too; e.g. `array_position(array[1, null, 2, null], null)` yields 2.
.position(|item| item == element.map(Into::into))
134+
// `idx` is relative to the first non-skipped item: add `skip` back, plus 1 to make
// the result 1-based.
.map(|idx| (idx + 1 + skip) as _))
135+
}
136+
22137
/// Returns an array of the subscripts of all occurrences of the second argument in the array
23138
/// given as first argument. Note the behavior is slightly different from PG.
24139
///
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Copyright 2023 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use risingwave_common::array::{ListRef, ListValue};
16+
use risingwave_common::types::{ScalarRef, ToOwnedDatum};
17+
use risingwave_expr_macro::function;
18+
19+
/// Replaces each array element equal to the second argument with the third argument.
20+
///
21+
/// Examples:
22+
///
23+
/// ```slt
24+
/// query T
25+
/// select array_replace(array[7, null, 8, null], null, 0.5);
26+
/// ----
27+
/// {7,0.5,8,0.5}
28+
///
29+
/// query T
30+
/// select array_replace(null, 1, 5);
31+
/// ----
32+
/// NULL
33+
///
34+
/// query T
35+
/// select array_replace(null, null, null);
36+
/// ----
37+
/// NULL
38+
///
39+
/// statement error
40+
/// select array_replace(array[3, null, 4], true, false);
41+
///
42+
/// # Replacing `int` in multidimensional array is not supported yet. (OK in PostgreSQL)
43+
/// statement error
44+
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], 3, 9);
45+
///
46+
/// # Unlike PostgreSQL, it is okay to replace `int[][]` inside `int[][][]`.
47+
/// query T
48+
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[array[4, 5], array[6, 7]], array[array[2, 3], array[4, 5]]);
49+
/// ----
50+
/// {{{0,1},{2,3}},{{2,3},{4,5}}}
51+
///
52+
/// # Replacing `int[]` inside `int[][][]` is not supported by either PostgreSQL or RisingWave.
53+
/// # This may or may not be supported later, whichever makes the `int` support above simpler.
54+
/// statement error
55+
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[4, 5], array[8, 9]);
56+
/// ```
57+
#[function("array_replace(list, boolean, boolean) -> list")]
58+
#[function("array_replace(list, int16, int16) -> list")]
59+
#[function("array_replace(list, int32, int32) -> list")]
60+
#[function("array_replace(list, int64, int64) -> list")]
61+
#[function("array_replace(list, decimal, decimal) -> list")]
62+
#[function("array_replace(list, float32, float32) -> list")]
63+
#[function("array_replace(list, float64, float64) -> list")]
64+
#[function("array_replace(list, varchar, varchar) -> list")]
65+
#[function("array_replace(list, bytea, bytea) -> list")]
66+
#[function("array_replace(list, time, time) -> list")]
67+
#[function("array_replace(list, interval, interval) -> list")]
68+
#[function("array_replace(list, date, date) -> list")]
69+
#[function("array_replace(list, timestamp, timestamp) -> list")]
70+
#[function("array_replace(list, timestamptz, timestamptz) -> list")]
71+
#[function("array_replace(list, list, list) -> list")]
72+
#[function("array_replace(list, struct, struct) -> list")]
73+
#[function("array_replace(list, jsonb, jsonb) -> list")]
74+
#[function("array_replace(list, int256, int256) -> list")]
75+
fn array_replace<'a, T: ScalarRef<'a>>(
76+
arr: Option<ListRef<'_>>,
77+
elem_from: Option<T>,
78+
elem_to: Option<T>,
79+
) -> Option<ListValue> {
80+
// A `None` input array stays `None`: `Option::map` only runs the closure on `Some`.
arr.map(|arr| {
81+
ListValue::new(
82+
arr.iter()
83+
// `elem_from` stays wrapped in its `Option` for the comparison, so a `None`
// `elem_from` can target items too; e.g.
// `array_replace(array[7, null, 8, null], null, 0.5)` yields `{7,0.5,8,0.5}`.
.map(|x| match x == elem_from.map(Into::into) {
84+
// Matching item: substitute an owned copy of `elem_to` (which may be `None`).
true => elem_to.map(Into::into).to_owned_datum(),
85+
// Non-matching item: keep it, converted to an owned datum.
false => x.to_owned_datum(),
86+
})
87+
.collect(),
88+
)
89+
})
90+
}

src/expr/src/vector_op/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod array_length;
1919
pub mod array_positions;
2020
pub mod array_range_access;
2121
pub mod array_remove;
22+
pub mod array_replace;
2223
pub mod ascii;
2324
pub mod bitwise_op;
2425
pub mod cardinality;

src/frontend/src/binder/expr/function.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ impl Binder {
524524
("array_length", raw_call(ExprType::ArrayLength)),
525525
("cardinality", raw_call(ExprType::Cardinality)),
526526
("array_remove", raw_call(ExprType::ArrayRemove)),
527+
("array_replace", raw_call(ExprType::ArrayReplace)),
528+
("array_position", raw_call(ExprType::ArrayPosition)),
527529
("array_positions", raw_call(ExprType::ArrayPositions)),
528530
("trim_array", raw_call(ExprType::TrimArray)),
529531
// int256

src/frontend/src/expr/pure.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ impl ExprVisitor<bool> for ImpureAnalyzer {
148148
| expr_node::Type::Cardinality
149149
| expr_node::Type::TrimArray
150150
| expr_node::Type::ArrayRemove
151+
| expr_node::Type::ArrayReplace
152+
| expr_node::Type::ArrayPosition
151153
| expr_node::Type::HexToInt256
152154
| expr_node::Type::JsonbAccessInner
153155
| expr_node::Type::JsonbAccessStr

src/frontend/src/expr/type_inference/func.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,37 @@ fn infer_type_for_special(
544544
.into()),
545545
}
546546
}
547+
ExprType::ArrayReplace => {
548+
ensure_arity!("array_replace", | inputs | == 3);
549+
let common_type = align_array_and_element(0, &[1, 2], inputs);
550+
match common_type {
551+
Ok(casted) => Ok(Some(casted)),
552+
Err(_) => Err(ErrorCode::BindError(format!(
553+
"Cannot replace {} with {} in {}",
554+
inputs[1].return_type(),
555+
inputs[2].return_type(),
556+
inputs[0].return_type(),
557+
))
558+
.into()),
559+
}
560+
}
561+
ExprType::ArrayPosition => {
562+
ensure_arity!("array_position", 2 <= | inputs | <= 3);
563+
if let Some(start) = inputs.get_mut(2) {
564+
let owned = std::mem::replace(start, ExprImpl::literal_bool(false));
565+
*start = owned.cast_implicit(DataType::Int32)?;
566+
}
567+
let common_type = align_array_and_element(0, &[1], inputs);
568+
match common_type {
569+
Ok(_) => Ok(Some(DataType::Int32)),
570+
Err(_) => Err(ErrorCode::BindError(format!(
571+
"Cannot get position of {} in {}",
572+
inputs[1].return_type(),
573+
inputs[0].return_type()
574+
))
575+
.into()),
576+
}
577+
}
547578
ExprType::ArrayPositions => {
548579
ensure_arity!("array_positions", | inputs | == 2);
549580
let common_type = align_array_and_element(0, &[1], inputs);

src/tests/regress/data/sql/arrays.sql

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -261,14 +261,14 @@ SELECT array_cat(ARRAY[1,2], ARRAY[3,4]) AS "{1,2,3,4}";
261261
SELECT array_cat(ARRAY[1,2], ARRAY[[3,4],[5,6]]) AS "{{1,2},{3,4},{5,6}}";
262262
SELECT array_cat(ARRAY[[3,4],[5,6]], ARRAY[1,2]) AS "{{3,4},{5,6},{1,2}}";
263263

264-
--@ SELECT array_position(ARRAY[1,2,3,4,5], 4);
265-
--@ SELECT array_position(ARRAY[5,3,4,2,1], 4);
266-
--@ SELECT array_position(ARRAY[[1,2],[3,4]], 3);
267-
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon');
268-
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat');
269-
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL);
270-
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL);
271-
--@ SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat');
264+
SELECT array_position(ARRAY[1,2,3,4,5], 4);
265+
SELECT array_position(ARRAY[5,3,4,2,1], 4);
266+
SELECT array_position(ARRAY[[1,2],[3,4]], 3);
267+
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'mon');
268+
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], 'sat');
269+
SELECT array_position(ARRAY['sun','mon','tue','wed','thu','fri','sat'], NULL);
270+
SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], NULL);
271+
SELECT array_position(ARRAY['sun','mon','tue','wed','thu',NULL,'fri','sat'], 'sat');
272272

273273
SELECT array_positions(NULL, 10);
274274
SELECT array_positions(NULL, NULL::int);
@@ -296,15 +296,15 @@ SELECT array_positions(ARRAY[1,2,3,NULL,5,6,1,2,3,NULL,5,6], NULL);
296296

297297
--@ SELECT array_position('[2:4]={1,2,3}'::int[], 1);
298298
--@ SELECT array_positions('[2:4]={1,2,3}'::int[], 1);
299-
--@
300-
--@ SELECT
301-
--@ array_position(ids, (1, 1)),
302-
--@ array_positions(ids, (1, 1))
303-
--@ FROM
304-
--@ (VALUES
305-
--@ (ARRAY[(0, 0), (1, 1)]),
306-
--@ (ARRAY[(1, 1)])
307-
--@ ) AS f (ids);
299+
300+
SELECT
301+
array_position(ids, (1, 1)),
302+
array_positions(ids, (1, 1))
303+
FROM
304+
(VALUES
305+
(ARRAY[(0, 0), (1, 1)]),
306+
(ARRAY[(1, 1)])
307+
) AS f (ids);
308308

309309
-- operators
310310
--@ SELECT a FROM arrtest WHERE b = ARRAY[[[113,142],[1,147]]];
@@ -625,12 +625,12 @@ select array_remove(array['A','CC','D','C','RR'], 'RR');
625625
select array_remove(array[1.0, 2.1, 3.3], 1);
626626
select array_remove('{{1,2,2},{1,4,3}}', 2); -- not allowed
627627
select array_remove(array['X','X','X'], 'X') = '{}';
628-
--@ select array_replace(array[1,2,5,4],5,3);
629-
--@ select array_replace(array[1,2,5,4],5,NULL);
630-
--@ select array_replace(array[1,2,NULL,4,NULL],NULL,5);
631-
--@ select array_replace(array['A','B','DD','B'],'B','CC');
632-
--@ select array_replace(array[1,NULL,3],NULL,NULL);
633-
--@ select array_replace(array['AB',NULL,'CDE'],NULL,'12');
628+
select array_replace(array[1,2,5,4],5,3);
629+
select array_replace(array[1,2,5,4],5,NULL);
630+
select array_replace(array[1,2,NULL,4,NULL],NULL,5);
631+
select array_replace(array['A','B','DD','B'],'B','CC');
632+
select array_replace(array[1,NULL,3],NULL,NULL);
633+
select array_replace(array['AB',NULL,'CDE'],NULL,'12');
634634

635635
-- array(select array-value ...)
636636
--@ select array(select array[i,i/2] from generate_series(1,5) i);

0 commit comments

Comments
 (0)