Skip to content

Commit 9d46d4f

Browse files
margarit-hgithub-actions[bot]
authored andcommitted
Enable concat() string function to support multiple string arguments (#1279)
* Enable `concat()` string function to support multiple string arguments (#200) Signed-off-by: Margarit Hakobyan <[email protected]> (cherry picked from commit 45fc371)
1 parent 982d366 commit 9d46d4f

File tree

11 files changed

+159
-28
lines changed

11 files changed

+159
-28
lines changed

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ private FunctionBuilder getFunctionBuilder(
179179
List<ExprType> sourceTypes = functionSignature.getParamTypeList();
180180
List<ExprType> targetTypes = resolvedSignature.getKey().getParamTypeList();
181181
FunctionBuilder funcBuilder = resolvedSignature.getValue();
182-
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
182+
if (isCastFunction(functionName)
183+
|| FunctionSignature.isVarArgFunction(targetTypes)
184+
|| sourceTypes.equals(targetTypes)) {
183185
return funcBuilder;
184186
}
185187
return castArguments(sourceTypes,

core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,15 @@ public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unreso
5050
functionSignature));
5151
}
5252
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
53-
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) {
53+
if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())
54+
&& (unresolvedSignature.getParamTypeList().isEmpty()
55+
|| unresolvedSignature.getParamTypeList().size() > 9)) {
56+
throw new ExpressionEvaluationException(
57+
String.format("%s function expected 1-9 arguments, but got %d",
58+
functionName, unresolvedSignature.getParamTypeList().size()));
59+
}
60+
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())
61+
&& !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) {
5462
throw new ExpressionEvaluationException(
5563
String.format("%s function expected %s, but get %s", functionName,
5664
formatFunctions(functionBundle.keySet()),

core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.sql.expression.function;
77

8+
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
9+
810
import java.util.List;
911
import java.util.stream.Collectors;
1012
import lombok.EqualsAndHashCode;
@@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) {
3941
|| paramTypeList.size() != functionTypeList.size()) {
4042
return NOT_MATCH;
4143
}
44+
// TODO: improve to support regular and array type mixed, ex. func(int,string,array)
45+
if (isVarArgFunction(functionTypeList)) {
46+
return EXACTLY_MATCH;
47+
}
4248

4349
int matchDegree = EXACTLY_MATCH;
4450
for (int i = 0; i < paramTypeList.size(); i++) {
@@ -62,4 +68,11 @@ public String formatTypes() {
6268
.map(ExprType::typeName)
6369
.collect(Collectors.joining(",", "[", "]"));
6470
}
71+
72+
/**
73+
* util function - returns true if function has variable arguments.
74+
*/
75+
protected static boolean isVarArgFunction(List<ExprType> argTypes) {
76+
return argTypes.size() == 1 && argTypes.get(0) == ARRAY;
77+
}
6578
}

core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,34 @@
66

77
package org.opensearch.sql.expression.text;
88

9+
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
910
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
1011
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
1112
import static org.opensearch.sql.expression.function.FunctionDSL.define;
1213
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
1314
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;
1415

16+
import java.util.Collections;
17+
import java.util.List;
18+
import java.util.stream.Collectors;
1519
import lombok.experimental.UtilityClass;
20+
import org.apache.commons.lang3.tuple.Pair;
1621
import org.opensearch.sql.data.model.ExprIntegerValue;
1722
import org.opensearch.sql.data.model.ExprStringValue;
1823
import org.opensearch.sql.data.model.ExprValue;
24+
import org.opensearch.sql.data.model.ExprValueUtils;
25+
import org.opensearch.sql.data.type.ExprType;
26+
import org.opensearch.sql.expression.Expression;
27+
import org.opensearch.sql.expression.FunctionExpression;
28+
import org.opensearch.sql.expression.env.Environment;
1929
import org.opensearch.sql.expression.function.BuiltinFunctionName;
2030
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
2131
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
2232
import org.opensearch.sql.expression.function.FunctionName;
33+
import org.opensearch.sql.expression.function.FunctionSignature;
2334
import org.opensearch.sql.expression.function.SerializableBiFunction;
2435
import org.opensearch.sql.expression.function.SerializableTriFunction;
2536

26-
2737
/**
2838
* The definition of text functions.
2939
* 1) have the clear interface for function define.
@@ -141,16 +151,37 @@ private DefaultFunctionResolver upper() {
141151
}
142152

143153
/**
144-
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
145-
* Extend to accept variable argument amounts.
146154
* Concatenates a list of Strings.
147155
* Supports following signatures:
148-
* (STRING, STRING) -> STRING
156+
* (STRING, STRING, ...., STRING) -> STRING
149157
*/
150158
private DefaultFunctionResolver concat() {
151-
return define(BuiltinFunctionName.CONCAT.getName(),
152-
impl(nullMissingHandling((str1, str2) ->
153-
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
159+
FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName();
160+
return define(concatFuncName, funcName ->
161+
Pair.of(
162+
new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)),
163+
(funcProp, args) -> new FunctionExpression(funcName, args) {
164+
@Override
165+
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
166+
List<ExprValue> exprValues = args.stream()
167+
.map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList());
168+
if (exprValues.stream().anyMatch(ExprValue::isMissing)) {
169+
return ExprValueUtils.missingValue();
170+
}
171+
if (exprValues.stream().anyMatch(ExprValue::isNull)) {
172+
return ExprValueUtils.nullValue();
173+
}
174+
return new ExprStringValue(exprValues.stream()
175+
.map(ExprValue::stringValue)
176+
.collect(Collectors.joining()));
177+
}
178+
179+
@Override
180+
public ExprType type() {
181+
return STRING;
182+
}
183+
}
184+
));
154185
}
155186

156187
/**

core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
import static org.junit.jupiter.api.Assertions.assertEquals;
1010
import static org.junit.jupiter.api.Assertions.assertThrows;
1111
import static org.mockito.Mockito.when;
12+
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
13+
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
1214

15+
import com.google.common.collect.ImmutableList;
1316
import com.google.common.collect.ImmutableMap;
17+
import java.util.Collections;
1418
import org.junit.jupiter.api.DisplayNameGeneration;
1519
import org.junit.jupiter.api.DisplayNameGenerator;
1620
import org.junit.jupiter.api.Test;
@@ -76,4 +80,53 @@ void resolve_function_not_match() {
7680
assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]",
7781
exception.getMessage());
7882
}
83+
84+
@Test
85+
void resolve_varargs_function_signature_match() {
86+
functionName = FunctionName.of("concat");
87+
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
88+
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING));
89+
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
90+
91+
DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
92+
ImmutableMap.of(bestMatchFS, bestMatchBuilder));
93+
94+
assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue());
95+
}
96+
97+
@Test
98+
void resolve_varargs_no_args_function_signature_not_match() {
99+
functionName = FunctionName.of("concat");
100+
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
101+
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
102+
// Concat function with no arguments
103+
when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList());
104+
105+
DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
106+
ImmutableMap.of(bestMatchFS, bestMatchBuilder));
107+
108+
ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
109+
() -> resolver.resolve(functionSignature));
110+
assertEquals("concat function expected 1-9 arguments, but got 0",
111+
exception.getMessage());
112+
}
113+
114+
@Test
115+
void resolve_varargs_too_many_args_function_signature_not_match() {
116+
functionName = FunctionName.of("concat");
117+
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
118+
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
119+
// Concat function with more than 9 arguments
120+
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList
121+
.of(STRING, STRING, STRING, STRING, STRING,
122+
STRING, STRING, STRING, STRING, STRING));
123+
124+
DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
125+
ImmutableMap.of(bestMatchFS, bestMatchBuilder));
126+
127+
ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
128+
() -> resolver.resolve(functionSignature));
129+
assertEquals("concat function expected 1-9 arguments, but got 10",
130+
exception.getMessage());
131+
}
79132
}

core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase {
7272
private static List<List<String>> CONCAT_STRING_LISTS = ImmutableList.of(
7373
ImmutableList.of("hello", "world"),
7474
ImmutableList.of("123", "5325"));
75+
private static List<List<String>> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of(
76+
ImmutableList.of("he", "llo", "wo", "rld", "!"),
77+
ImmutableList.of("0", "123", "53", "25", "7"));
7578

7679
interface SubstrSubstring {
7780
FunctionExpression getFunction(SubstringInfo strInfo);
@@ -228,11 +231,13 @@ public void upper() {
228231
@Test
229232
void concat() {
230233
CONCAT_STRING_LISTS.forEach(this::testConcatString);
234+
CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString);
231235

232236
when(nullRef.type()).thenReturn(STRING);
233237
when(missingRef.type()).thenReturn(STRING);
234238
assertEquals(missingValue(), eval(
235239
DSL.concat(missingRef, DSL.literal("1"))));
240+
// If any of the expressions is a NULL value, it returns NULL.
236241
assertEquals(nullValue(), eval(
237242
DSL.concat(nullRef, DSL.literal("1"))));
238243
assertEquals(missingValue(), eval(
@@ -446,6 +451,22 @@ void testConcatString(List<String> strings, String delim) {
446451
assertEquals(expected, eval(expression).stringValue());
447452
}
448453

454+
void testConcatMultipleString(List<String> strings) {
455+
String expected = null;
456+
if (strings.stream().noneMatch(Objects::isNull)) {
457+
expected = String.join("", strings);
458+
}
459+
460+
FunctionExpression expression = DSL.concat(
461+
DSL.literal(strings.get(0)),
462+
DSL.literal(strings.get(1)),
463+
DSL.literal(strings.get(2)),
464+
DSL.literal(strings.get(3)),
465+
DSL.literal(strings.get(4)));
466+
assertEquals(STRING, expression.type());
467+
assertEquals(expected, eval(expression).stringValue());
468+
}
469+
449470
void testLengthString(String str) {
450471
FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str)));
451472
assertEquals(INTEGER, expression.type());

docs/user/dql/functions.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,21 +2614,21 @@ CONCAT
26142614
Description
26152615
>>>>>>>>>>>
26162616

2617-
Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
2617+
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL.
26182618

2619-
Argument type: STRING, STRING
2619+
Argument type: STRING, STRING, ...., STRING
26202620

26212621
Return type: STRING
26222622

26232623
Example::
26242624

2625-
os> SELECT CONCAT('hello', 'world')
2625+
os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null)
26262626
fetched rows / total rows = 1/1
2627-
+----------------------------+
2628-
| CONCAT('hello', 'world') |
2629-
|----------------------------|
2630-
| helloworld |
2631-
+----------------------------+
2627+
+--------------------------------------------+----------------------------+-------------------------+
2628+
| CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) |
2629+
|--------------------------------------------+----------------------------+-------------------------|
2630+
| hello whole world! | helloworld | null |
2631+
+--------------------------------------------+----------------------------+-------------------------+
26322632

26332633

26342634
CONCAT_WS

docs/user/ppl/functions/string.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ CONCAT
1414
Description
1515
>>>>>>>>>>>
1616

17-
Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
17+
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together.
1818

19-
Argument type: STRING, STRING
19+
Argument type: STRING, STRING, ...., STRING
2020

2121
Return type: STRING
2222

2323
Example::
2424

25-
os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')`
25+
os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')`
2626
fetched rows / total rows = 1/1
27-
+----------------------------+
28-
| CONCAT('hello', 'world') |
29-
|----------------------------|
30-
| helloworld |
31-
+----------------------------+
27+
+----------------------------+--------------------------------------------+
28+
| CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') |
29+
|----------------------------+--------------------------------------------|
30+
| helloworld | hello whole world! |
31+
+----------------------------+--------------------------------------------+
3232

3333

3434
CONCAT_WS

integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ public void testLtrim() throws IOException {
9999

100100
@Test
101101
public void testConcat() throws IOException {
102-
verifyQuery("concat", "", ", 'there'",
103-
"hellothere", "worldthere", "helloworldthere");
102+
verifyQuery("concat", "", ", 'there', 'all', '!'",
103+
"hellothereall!", "worldthereall!", "helloworldthereall!");
104104
}
105105

106106
@Test

integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ public void testLtrim() throws IOException {
108108

109109
@Test
110110
public void testConcat() throws IOException {
111+
verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!");
111112
verifyQuery("concat('hello', 'world')", "keyword", "helloworld");
112113
verifyQuery("concat('', 'hello')", "keyword", "hello");
113114
}

integ-test/src/test/resources/correctness/expressions/text_functions.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column
1111
LOCATE('world', 'hello') as column
1212
LOCATE('world', 'helloworld', 7) as column
1313
REPLACE('helloworld', 'world', 'opensearch') as column
14-
REPLACE('hello', 'world', 'opensearch') as column
14+
REPLACE('hello', 'world', 'opensearch') as column
15+
CONCAT('hello', 'world') as column
16+
CONCAT('hello ', 'whole ', 'world', '!') as column

0 commit comments

Comments
 (0)