Skip to content

Commit e951cde

Browse files
kamalesh0406 and xxchan authored
feat: Add support for array_length function in psql (risingwavelabs#8636)
Co-authored-by: xxchan <[email protected]>
1 parent aeddef3 commit e951cde

File tree

6 files changed

+301
-0
lines changed

6 files changed

+301
-0
lines changed

proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ message ExprNode {
120120
ARRAY_PREPEND = 533;
121121
FORMAT_TYPE = 534;
122122
ARRAY_DISTINCT = 535;
123+
ARRAY_LENGTH = 536;
123124

124125
// Jsonb functions
125126

src/expr/src/expr/build_expr_from_prost.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ use super::expr_unary::{
5050
};
5151
use super::expr_vnode::VnodeExpression;
5252
use crate::expr::expr_array_distinct::ArrayDistinctExpression;
53+
use crate::expr::expr_array_length::ArrayLengthExpression;
5354
use crate::expr::expr_array_to_string::ArrayToStringExpression;
5455
use crate::expr::expr_binary_nonnull::new_tumble_start;
5556
use crate::expr::expr_ternary::new_tumble_start_offset;
@@ -117,6 +118,7 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
117118
}
118119
ArrayToString => ArrayToStringExpression::try_from(prost).map(Expression::boxed),
119120
ArrayDistinct => ArrayDistinctExpression::try_from(prost).map(Expression::boxed),
121+
ArrayLength => ArrayLengthExpression::try_from(prost).map(Expression::boxed),
120122
Vnode => VnodeExpression::try_from(prost).map(Expression::boxed),
121123
Now => build_now_expr(prost),
122124
Udf => UdfExpression::try_from(prost).map(Expression::boxed),
src/expr/src/expr/expr_array_length.rs

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
// Copyright 2023 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::sync::Arc;
16+
17+
use risingwave_common::array::{ArrayRef, DataChunk};
18+
use risingwave_common::row::OwnedRow;
19+
use risingwave_common::types::{DataType, Datum, DatumRef, ScalarImpl, ScalarRefImpl, ToDatumRef};
20+
use risingwave_common::util::iter_util::ZipEqFast;
21+
use risingwave_pb::expr::expr_node::{RexNode, Type};
22+
use risingwave_pb::expr::ExprNode;
23+
24+
use crate::expr::{build_from_prost, BoxedExpression, Expression};
25+
use crate::{bail, ensure, ExprError, Result};
26+
27+
/// Returns the length of an array.
28+
///
29+
/// ```sql
30+
/// array_length ( array anyarray) → int64
31+
/// ```
32+
///
33+
/// Examples:
34+
///
35+
/// ```slt
36+
/// query T
37+
/// select array_length(null::int[]);
38+
/// ----
39+
/// NULL
40+
///
41+
/// query T
42+
/// select array_length(array[1,2,3]);
43+
/// ----
44+
/// 3
45+
///
46+
/// query T
47+
/// select array_length(array[1,2,3,4,1]);
48+
/// ----
49+
/// 5
50+
///
51+
/// query T
52+
/// select array_length(null::int[]);
53+
/// ----
54+
/// NULL
55+
///
56+
/// query T
57+
/// select array_length(array[array[1, 2, 3]]);
58+
/// ----
59+
/// 1
60+
///
61+
/// query T
62+
/// select array_length(array[NULL]);
63+
/// ----
64+
/// 1
65+
///
66+
/// query error unknown type
67+
/// select array_length(null);
68+
/// ```
69+
70+
#[derive(Debug)]
71+
pub struct ArrayLengthExpression {
72+
array: BoxedExpression,
73+
return_type: DataType,
74+
}
75+
76+
impl<'a> TryFrom<&'a ExprNode> for ArrayLengthExpression {
77+
type Error = ExprError;
78+
79+
fn try_from(prost: &'a ExprNode) -> Result<Self> {
80+
ensure!(prost.get_expr_type().unwrap() == Type::ArrayLength);
81+
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
82+
bail!("Expected RexNode:FunctionCall")
83+
};
84+
let children = func_call_node.get_children();
85+
ensure!(children.len() == 1);
86+
let array = build_from_prost(&children[0])?;
87+
let return_type = DataType::Int64;
88+
Ok(Self { array, return_type })
89+
}
90+
}
91+
92+
#[async_trait::async_trait]
93+
impl Expression for ArrayLengthExpression {
94+
fn return_type(&self) -> DataType {
95+
self.return_type.clone()
96+
}
97+
98+
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
99+
let array = self.array.eval_checked(input).await?;
100+
let mut builder = self.return_type.create_array_builder(array.len());
101+
102+
for (vis, input_array) in input.vis().iter().zip_eq_fast(array.iter()) {
103+
if vis {
104+
builder.append_datum(self.evaluate(input_array));
105+
} else {
106+
builder.append_null();
107+
}
108+
}
109+
110+
Ok(Arc::new(builder.finish()))
111+
}
112+
113+
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
114+
let array_data = self.array.eval_row(input).await?;
115+
Ok(self.evaluate(array_data.to_datum_ref()))
116+
}
117+
}
118+
119+
impl ArrayLengthExpression {
120+
fn evaluate(&self, array: DatumRef<'_>) -> Datum {
121+
match array {
122+
Some(ScalarRefImpl::List(array)) => Some(ScalarImpl::Int64(
123+
array.values_ref().len().try_into().unwrap(),
124+
)),
125+
None => None,
126+
_ => {
127+
panic!("The array should be a valid array");
128+
}
129+
}
130+
}
131+
}
132+
133+
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::{DataChunk, ListValue};
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::data::Datum as ProstDatum;
    use risingwave_pb::expr::expr_node::{RexNode, Type as ProstType};
    use risingwave_pb::expr::{ExprNode, FunctionCall};

    use crate::expr::expr_array_length::ArrayLengthExpression;
    use crate::expr::{BoxedExpression, Expression, LiteralExpression};

    /// Builds a constant `Int64` literal node holding `value`.
    fn make_i64_expr_node(value: i64) -> ExprNode {
        ExprNode {
            expr_type: ProstType::ConstantValue as i32,
            return_type: Some(DataType::Int64.to_protobuf()),
            rex_node: Some(RexNode::Constant(ProstDatum {
                body: value.to_be_bytes().to_vec(),
            })),
        }
    }

    /// Builds an `array[...]` constructor node of type `List<Int64>` from `values`.
    fn make_i64_array_expr_node(values: Vec<i64>) -> ExprNode {
        ExprNode {
            expr_type: ProstType::Array as i32,
            return_type: Some(
                DataType::List {
                    datatype: Box::new(DataType::Int64),
                }
                .to_protobuf(),
            ),
            rex_node: Some(RexNode::FuncCall(FunctionCall {
                children: values.into_iter().map(make_i64_expr_node).collect(),
            })),
        }
    }

    /// Builds a nested `array[array[...], ...]` node of type `List<List<Int64>>`.
    fn make_i64_array_array_expr_node(values: Vec<Vec<i64>>) -> ExprNode {
        ExprNode {
            expr_type: ProstType::Array as i32,
            return_type: Some(
                DataType::List {
                    datatype: Box::new(DataType::List {
                        datatype: Box::new(DataType::Int64),
                    }),
                }
                .to_protobuf(),
            ),
            rex_node: Some(RexNode::FuncCall(FunctionCall {
                children: values.into_iter().map(make_i64_array_expr_node).collect(),
            })),
        }
    }

    /// Wraps `child` in an `array_length(child)` call node.
    ///
    /// The declared return type is `Int64`, matching what the expression actually produces.
    fn make_array_length_expr_node(child: ExprNode) -> ExprNode {
        ExprNode {
            expr_type: ProstType::ArrayLength as i32,
            return_type: Some(DataType::Int64.to_protobuf()),
            rex_node: Some(RexNode::FuncCall(FunctionCall {
                children: vec![child],
            })),
        }
    }

    #[test]
    fn test_array_length_try_from() {
        let children = vec![
            // A non-array child is accepted as well: `try_from` only checks the node kind and
            // the number of children, not the child's type.
            make_i64_expr_node(1),
            make_i64_array_expr_node(vec![1, 2, 3]),
            make_i64_array_array_expr_node(vec![vec![1, 2, 3]]),
        ];

        for child in children {
            let expr = make_array_length_expr_node(child);
            assert!(ArrayLengthExpression::try_from(&expr).is_ok());
        }
    }

    /// Builds a literal expression evaluating to the given `List<Int64>` value.
    fn make_i64_array_expr(values: Vec<i64>) -> BoxedExpression {
        LiteralExpression::new(
            DataType::List {
                datatype: Box::new(DataType::Int64),
            },
            Some(ListValue::new(values.into_iter().map(|x| Some(x.into())).collect()).into()),
        )
        .boxed()
    }

    #[tokio::test]
    async fn test_array_length_of_primitives() {
        let array = make_i64_array_expr(vec![1, 2, 3]);
        let expr = ArrayLengthExpression {
            array,
            return_type: DataType::Int64,
        };

        // The first row is invisible and must come back as NULL; the others yield the length.
        let chunk =
            DataChunk::new_dummy(3).with_visibility(([false, true, true]).into_iter().collect());
        let expected_length = Some(ScalarImpl::Int64(3));

        let expected = vec![None, expected_length.clone(), expected_length];

        let actual = expr
            .eval(&chunk)
            .await
            .unwrap()
            .iter()
            .map(|v| v.map(|s| s.into_scalar_impl()))
            .collect_vec();

        assert_eq!(actual, expected);
    }
}

src/expr/src/expr/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
// These modules define concrete expression structures.
3535
mod expr_array_concat;
3636
mod expr_array_distinct;
37+
mod expr_array_length;
3738
mod expr_array_to_string;
3839
mod expr_binary_bytes;
3940
mod expr_binary_nonnull;

src/frontend/src/binder/expr/function.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ impl Binder {
387387
("array_prepend", raw_call(ExprType::ArrayPrepend)),
388388
("array_to_string", raw_call(ExprType::ArrayToString)),
389389
("array_distinct", raw_call(ExprType::ArrayDistinct)),
390+
("array_length", raw_call(ExprType::ArrayLength)),
390391
// jsonb
391392
("jsonb_object_field", raw_call(ExprType::JsonbAccessInner)),
392393
("jsonb_array_element", raw_call(ExprType::JsonbAccessInner)),

src/frontend/src/expr/type_inference/func.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,24 @@ fn infer_type_for_special(
542542
_ => Ok(None),
543543
}
544544
}
545+
ExprType::ArrayLength => {
546+
ensure_arity!("array_length", | inputs | == 1);
547+
let return_type = inputs[0].return_type();
548+
549+
if inputs[0].is_unknown() {
550+
return Err(ErrorCode::BindError(
551+
"Cannot find length for unknown type".to_string(),
552+
)
553+
.into());
554+
}
555+
556+
match return_type {
557+
DataType::List {
558+
datatype: _list_elem_type,
559+
} => Ok(Some(DataType::Int64)),
560+
_ => Ok(None),
561+
}
562+
}
545563
ExprType::Vnode => {
546564
ensure_arity!("vnode", 1 <= | inputs |);
547565
Ok(Some(DataType::Int16))

0 commit comments

Comments
 (0)