Skip to content
This repository was archived by the owner on Aug 2, 2022. It is now read-only.

Commit 64c7bd6

Browse files
authored
Fix CASE clause pushdown issue (#895)
* Support case when pushdown
* Add more comparison test
* Relax type check for null
* Prepare PR
* Prepare PR
* Fix Literal.toString() NPE issue
1 parent 8a44305 commit 64c7bd6

File tree

12 files changed

+260
-15
lines changed

12 files changed

+260
-15
lines changed

core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java

+29
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
2323
import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression;
2424
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator;
25+
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause;
26+
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause;
2527
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
2628
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalAggregation;
2729
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalPlan;
@@ -87,6 +89,33 @@ public Expression visitAggregator(Aggregator<?> node, AnalysisContext context) {
8789
return expressionMap.getOrDefault(node, node);
8890
}
8991

92+
/**
93+
* Implement this because CASE/WHEN is not registered in the function repository.
94+
*/
95+
@Override
96+
public Expression visitCase(CaseClause node, AnalysisContext context) {
97+
if (expressionMap.containsKey(node)) {
98+
return expressionMap.get(node);
99+
}
100+
101+
List<WhenClause> whenClauses = node.getWhenClauses()
102+
.stream()
103+
.map(expr -> (WhenClause) expr.accept(this, context))
104+
.collect(Collectors.toList());
105+
Expression defaultResult = null;
106+
if (node.getDefaultResult() != null) {
107+
defaultResult = node.getDefaultResult().accept(this, context);
108+
}
109+
return new CaseClause(whenClauses, defaultResult);
110+
}
111+
112+
@Override
113+
public Expression visitWhen(WhenClause node, AnalysisContext context) {
114+
return new WhenClause(
115+
node.getCondition().accept(this, context),
116+
node.getResult().accept(this, context));
117+
}
118+
90119

91120
/**
92121
* Expression Map Builder.

core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/Literal.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import lombok.EqualsAndHashCode;
2222
import lombok.Getter;
2323
import lombok.RequiredArgsConstructor;
24-
import lombok.ToString;
2524

2625
/**
2726
* Expression node of literal type
@@ -48,6 +47,6 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
4847

4948
@Override
5049
public String toString() {
51-
return value.toString();
50+
return String.valueOf(value);
5251
}
5352
}

core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/ExpressionNodeVisitor.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,18 @@ public T visitNamedAggregator(NamedAggregator node, C context) {
8181
return visitChildren(node, context);
8282
}
8383

84+
/**
85+
* Call visitFunction() by default rather than visitChildren().
86+
* This makes CASE/WHEN able to be handled:
87+
* 1) by visitFunction() if not overridden, e.g. FilterQueryBuilder
88+
* 2) by visitCase()/visitWhen() when overridden with special logic, e.g. ExpressionReferenceOptimizer
89+
*/
8490
public T visitCase(CaseClause node, C context) {
85-
return visitNode(node, context);
91+
return visitFunction(node, context);
8692
}
8793

8894
public T visitWhen(WhenClause node, C context) {
89-
return visitNode(node, context);
95+
return visitFunction(node, context);
9096
}
9197

9298
}

core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/conditional/cases/CaseClause.java

+36-6
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616

1717
package com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases;
1818

19+
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.UNKNOWN;
20+
1921
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprNullValue;
2022
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
2123
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
2224
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
2325
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
26+
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
2427
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
28+
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
29+
import com.google.common.collect.ImmutableList;
2530
import java.util.List;
2631
import java.util.stream.Collectors;
27-
import lombok.AllArgsConstructor;
2832
import lombok.EqualsAndHashCode;
2933
import lombok.Getter;
3034
import lombok.ToString;
@@ -33,11 +37,10 @@
3337
* A CASE clause is very different from a regular function. Functions have well-defined signature,
3438
* though CASE clause is more like a function implementation which requires type check "manually".
3539
*/
36-
@AllArgsConstructor
37-
@EqualsAndHashCode
40+
@EqualsAndHashCode(callSuper = false)
3841
@Getter
3942
@ToString
40-
public class CaseClause implements Expression {
43+
public class CaseClause extends FunctionExpression {
4144

4245
/**
4346
* List of WHEN clauses.
@@ -49,6 +52,15 @@ public class CaseClause implements Expression {
4952
*/
5053
private final Expression defaultResult;
5154

55+
/**
56+
* Initialize case clause.
57+
*/
58+
public CaseClause(List<WhenClause> whenClauses, Expression defaultResult) {
59+
super(FunctionName.of("case"), concatArgs(whenClauses, defaultResult));
60+
this.whenClauses = whenClauses;
61+
this.defaultResult = defaultResult;
62+
}
63+
5264
@Override
5365
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
5466
for (WhenClause when : whenClauses) {
@@ -61,7 +73,10 @@ public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
6173

6274
@Override
6375
public ExprType type() {
64-
return whenClauses.get(0).type();
76+
List<ExprType> types = allResultTypes();
77+
78+
// Return unknown if all WHEN/ELSE return NULL
79+
return types.isEmpty() ? UNKNOWN : types.get(0);
6580
}
6681

6782
@Override
@@ -71,7 +86,9 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
7186

7287
/**
7388
* Get types of each result in WHEN clause and ELSE clause.
74-
* @return all result types
89+
* UNKNOWN types (from NULL literals) are excluded, meaning NULL in a THEN or ELSE clause
90+
* does not contribute to the result.
91+
* @return all result types. Use list so caller can generate friendly error message.
7592
*/
7693
public List<ExprType> allResultTypes() {
7794
List<ExprType> types = whenClauses.stream()
@@ -80,7 +97,20 @@ public List<ExprType> allResultTypes() {
8097
if (defaultResult != null) {
8198
types.add(defaultResult.type());
8299
}
100+
101+
types.removeIf(type -> (type == UNKNOWN));
83102
return types;
84103
}
85104

105+
private static List<Expression> concatArgs(List<WhenClause> whenClauses,
106+
Expression defaultResult) {
107+
ImmutableList.Builder<Expression> args = ImmutableList.builder();
108+
whenClauses.forEach(args::add);
109+
110+
if (defaultResult != null) {
111+
args.add(defaultResult);
112+
}
113+
return args.build();
114+
}
115+
86116
}

core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/conditional/cases/WhenClause.java

+24-5
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
2121
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
2222
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
23+
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
2324
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
25+
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
26+
import com.google.common.collect.ImmutableList;
2427
import lombok.EqualsAndHashCode;
2528
import lombok.Getter;
26-
import lombok.RequiredArgsConstructor;
2729
import lombok.ToString;
2830

2931
/**
3032
* WHEN clause that consists of a condition and a result corresponding.
3133
*/
32-
@EqualsAndHashCode
34+
@EqualsAndHashCode(callSuper = false)
3335
@Getter
34-
@RequiredArgsConstructor
3536
@ToString
36-
public class WhenClause implements Expression {
37+
public class WhenClause extends FunctionExpression {
3738

3839
/**
3940
* Condition that must be a predicate.
@@ -45,8 +46,26 @@ public class WhenClause implements Expression {
4546
*/
4647
private final Expression result;
4748

49+
/**
50+
* Initialize when clause.
51+
*/
52+
public WhenClause(Expression condition, Expression result) {
53+
super(FunctionName.of("when"), ImmutableList.of(condition, result));
54+
this.condition = condition;
55+
this.result = result;
56+
}
57+
58+
/**
59+
* Evaluate when condition.
60+
* @param valueEnv value env
61+
* @return is condition satisfied
62+
*/
4863
public boolean isTrue(Environment<Expression, ExprValue> valueEnv) {
49-
return condition.valueOf(valueEnv).booleanValue();
64+
ExprValue result = condition.valueOf(valueEnv);
65+
if (result.isMissing() || result.isNull()) {
66+
return false;
67+
}
68+
return result.booleanValue();
5069
}
5170

5271
@Override

core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java

+19
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
3232
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
3333
import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxCheckException;
34+
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
3435
import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException;
3536
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
3637
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
38+
import com.amazon.opendistroforelasticsearch.sql.expression.LiteralExpression;
3739
import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig;
3840
import org.junit.jupiter.api.Test;
3941
import org.junit.jupiter.api.extension.ExtendWith;
@@ -185,6 +187,23 @@ public void all_fields() {
185187
AllFields.of());
186188
}
187189

190+
@Test
191+
public void case_clause() {
192+
assertAnalyzeEqual(
193+
DSL.cases(
194+
DSL.literal(ExprValueUtils.nullValue()),
195+
DSL.when(
196+
dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(30)),
197+
DSL.literal("test"))),
198+
AstDSL.caseWhen(
199+
AstDSL.nullLiteral(),
200+
AstDSL.when(
201+
AstDSL.function("=",
202+
AstDSL.qualifiedName("integer_value"),
203+
AstDSL.intLiteral(30)),
204+
AstDSL.stringLiteral("test"))));
205+
}
206+
188207
@Test
189208
public void skip_struct_data_type() {
190209
SyntaxCheckException exception =

core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizerTest.java

+71
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE;
2121
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
22+
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
2223
import static java.util.Collections.emptyList;
2324
import static org.junit.jupiter.api.Assertions.assertEquals;
2425

@@ -72,6 +73,76 @@ void aggregation_in_expression_should_be_replaced() {
7273
);
7374
}
7475

76+
@Test
77+
void case_clause_should_be_replaced() {
78+
Expression caseClause = DSL.cases(
79+
null,
80+
DSL.when(
81+
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
82+
DSL.literal("true")));
83+
84+
LogicalPlan logicalPlan =
85+
LogicalPlanDSL.aggregation(
86+
LogicalPlanDSL.relation("test"),
87+
emptyList(),
88+
ImmutableList.of(DSL.named(
89+
"CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")],"
90+
+ " defaultResult=null)",
91+
caseClause)));
92+
93+
assertEquals(
94+
DSL.ref(
95+
"CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")],"
96+
+ " defaultResult=null)", STRING),
97+
optimize(caseClause, logicalPlan));
98+
}
99+
100+
@Test
101+
void aggregation_in_case_when_clause_should_be_replaced() {
102+
Expression caseClause = DSL.cases(
103+
null,
104+
DSL.when(
105+
dsl.equal(dsl.avg(DSL.ref("age", INTEGER)), DSL.literal(30)),
106+
DSL.literal("true")));
107+
108+
LogicalPlan logicalPlan =
109+
LogicalPlanDSL.aggregation(
110+
LogicalPlanDSL.relation("test"),
111+
ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))),
112+
ImmutableList.of(DSL.named("name", DSL.ref("name", STRING))));
113+
114+
assertEquals(
115+
DSL.cases(
116+
null,
117+
DSL.when(
118+
dsl.equal(DSL.ref("AVG(age)", DOUBLE), DSL.literal(30)),
119+
DSL.literal("true"))),
120+
optimize(caseClause, logicalPlan));
121+
}
122+
123+
@Test
124+
void aggregation_in_case_else_clause_should_be_replaced() {
125+
Expression caseClause = DSL.cases(
126+
dsl.avg(DSL.ref("age", INTEGER)),
127+
DSL.when(
128+
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
129+
DSL.literal("true")));
130+
131+
LogicalPlan logicalPlan =
132+
LogicalPlanDSL.aggregation(
133+
LogicalPlanDSL.relation("test"),
134+
ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))),
135+
ImmutableList.of(DSL.named("name", DSL.ref("name", STRING))));
136+
137+
assertEquals(
138+
DSL.cases(
139+
DSL.ref("AVG(age)", DOUBLE),
140+
DSL.when(
141+
dsl.equal(DSL.ref("age", INTEGER), DSL.literal(30)),
142+
DSL.literal("true"))),
143+
optimize(caseClause, logicalPlan));
144+
}
145+
75146
@Test
76147
void window_expression_should_be_replaced() {
77148
LogicalPlan logicalPlan =

core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/conditional/cases/CaseClauseTest.java

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,26 @@ void should_use_type_of_when_clause() {
7575
assertEquals(ExprCoreType.INTEGER, caseClause.type());
7676
}
7777

78+
@Test
79+
void should_use_type_of_nonnull_when_or_else_clause() {
80+
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
81+
Expression defaultResult = mock(Expression.class);
82+
when(defaultResult.type()).thenReturn(ExprCoreType.STRING);
83+
84+
CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
85+
assertEquals(ExprCoreType.STRING, caseClause.type());
86+
}
87+
88+
@Test
89+
void should_use_unknown_type_of_if_all_when_and_else_return_null() {
90+
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
91+
Expression defaultResult = mock(Expression.class);
92+
when(defaultResult.type()).thenReturn(ExprCoreType.UNKNOWN);
93+
94+
CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
95+
assertEquals(ExprCoreType.UNKNOWN, caseClause.type());
96+
}
97+
7898
@Test
7999
void should_return_all_result_types_including_default() {
80100
when(whenClause.type()).thenReturn(ExprCoreType.INTEGER);
@@ -87,4 +107,16 @@ void should_return_all_result_types_including_default() {
87107
caseClause.allResultTypes());
88108
}
89109

110+
@Test
111+
void should_return_all_result_types_excluding_null_result() {
112+
when(whenClause.type()).thenReturn(ExprCoreType.UNKNOWN);
113+
Expression defaultResult = mock(Expression.class);
114+
when(defaultResult.type()).thenReturn(ExprCoreType.UNKNOWN);
115+
116+
CaseClause caseClause = new CaseClause(ImmutableList.of(whenClause), defaultResult);
117+
assertEquals(
118+
ImmutableList.of(),
119+
caseClause.allResultTypes());
120+
}
121+
90122
}

0 commit comments

Comments
 (0)