Skip to content

Commit b640f0b

Browse files
authored
SQL: Fix unecessary evaluation for CASE/IIF (#57159) (#57264)
Previously, `CASE` and `IIF` when translated to painless scripts (used in GROUP BY, HAVING, WHERE) a custom `caseFunction` registered in the `InternalSqlScriptUtils` was used. This function received and array of arbitrary length: ```[condition1, result1, condition2, result2, ... elseResult]``` Painless doesn't know of the context and therefore is evaluating all conditions and results before invoking the `caseFunction` on them. As a consequence, erroneous result expressions (i.e. division by 0) where always evaluated despite of the guarding condition. Replace the `caseFunction` with painless `<cond> ? <res1> : <res2>` expressions to properly guard the result expressions and only evaluate the one for which its guarding condition evaluates to true (or of course the elseResult). As a bonus, this approach includes performance benefits since we avoid unnecessary evaluations of both conditions and result expressions. Fixes: #49672 (cherry picked from commit 9584b34)
1 parent 17e19ef commit b640f0b

File tree

6 files changed

+66
-32
lines changed

6 files changed

+66
-32
lines changed

x-pack/plugin/sql/qa/server/src/main/resources/conditionals.csv-spec

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,24 @@ END as lang_skills FROM test_emp GROUP BY lang_skills ORDER BY 2;
197197
10 | zero
198198
;
199199

200+
caseGroupByProtectedDivisionByZero
201+
schema::x:i
202+
SELECT CASE WHEN languages = 1 THEN NULL ELSE ( salary / (languages - 1) ) END AS x FROM test_emp GROUP BY 1 ORDER BY 1 LIMIT 10;
203+
204+
x
205+
---------------
206+
null
207+
6331
208+
6486
209+
7780
210+
7974
211+
8068
212+
8489
213+
8935
214+
9043
215+
9071
216+
;
217+
200218
caseGroupByAndHaving
201219
schema::count:l|gender:s|languages:byte
202220
SELECT count(*) AS count, gender, languages FROM test_emp
@@ -353,6 +371,28 @@ IIF(NVL(languages, 0) = 0, 'zero',
353371
10 |zero
354372
;
355373

374+
iifGroupByProtectedDivisionByZero
375+
schema::count:l|x:i
376+
SELECT count(*) AS count,
377+
IIF(languages - 1 = 0, 0,
378+
IIF(languages - 1 = 1, (salary / 10000) / (languages - 1),
379+
IIF(languages - 1 = 2, (salary / 10000) / languages,
380+
IIF(languages - 1 = 3, (salary / 10000) / (languages + 1),
381+
(salary / 10000) / (languages + 2))))) as x FROM test_emp GROUP BY x ORDER BY 2;
382+
383+
count | x
384+
---------------+---------------
385+
10 |null
386+
50 |0
387+
14 |1
388+
8 |2
389+
4 |3
390+
5 |4
391+
6 |5
392+
2 |6
393+
1 |7
394+
;
395+
356396
iifGroupByAndHaving
357397
schema::count:l|gender:s|languages:byte
358398
SELECT count(*) AS count, gender, languages FROM test_emp

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
3737
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalDayTime;
3838
import org.elasticsearch.xpack.sql.expression.literal.interval.IntervalYearMonth;
39-
import org.elasticsearch.xpack.sql.expression.predicate.conditional.CaseProcessor;
4039
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
4140
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIfProcessor;
4241
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.SqlBinaryArithmeticOperation;
@@ -65,10 +64,6 @@ public class InternalSqlScriptUtils extends InternalQlScriptUtils {
6564
//
6665
// Conditional
6766
//
68-
public static Object caseFunction(List<Object> expressions) {
69-
return CaseProcessor.apply(expressions);
70-
}
71-
7267
public static Object coalesce(List<Object> expressions) {
7368
return ConditionalOperation.COALESCE.apply(expressions);
7469
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/Case.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
1111
import org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder;
1212
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
13+
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
1314
import org.elasticsearch.xpack.ql.tree.NodeInfo;
1415
import org.elasticsearch.xpack.ql.tree.Source;
1516
import org.elasticsearch.xpack.ql.type.DataType;
@@ -19,7 +20,6 @@
1920

2021
import java.util.ArrayList;
2122
import java.util.List;
22-
import java.util.StringJoiner;
2323

2424
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2525
import static org.elasticsearch.xpack.ql.expression.gen.script.ParamsBuilder.paramsBuilder;
@@ -161,14 +161,26 @@ public ScriptTemplate asScript() {
161161
}
162162
templates.add(asScript(elseResult));
163163

164-
StringJoiner template = new StringJoiner(",", "{sql}.caseFunction([", "])");
164+
// Use painless ?: expressions to prevent evaluation of return expression
165+
// if the condition which guards it evaluates to false (e.g. division by 0)
166+
StringBuilder sb = new StringBuilder();
165167
ParamsBuilder params = paramsBuilder();
166-
167-
for (ScriptTemplate scriptTemplate : templates) {
168-
template.add(scriptTemplate.template());
168+
for (int i = 0; i < templates.size(); i++) {
169+
ScriptTemplate scriptTemplate = templates.get(i);
170+
if (i < templates.size() - 1) {
171+
if (i % 2 == 0) {
172+
// painless ? : operator expects primitive boolean, thus we use nullSafeFilter
173+
// to convert object Boolean to primitive boolean (null => false)
174+
sb.append(Scripts.nullSafeFilter(scriptTemplate).template()).append(" ? ");
175+
} else {
176+
sb.append(scriptTemplate.template()).append(" : ");
177+
}
178+
} else {
179+
sb.append(scriptTemplate.template());
180+
}
169181
params.script(scriptTemplate.params());
170182
}
171183

172-
return new ScriptTemplate(formatTemplate(template.toString()), params.build(), dataType());
184+
return new ScriptTemplate(formatTemplate(sb.toString()), params.build(), dataType());
173185
}
174186
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/CaseProcessor.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,6 @@ public Object process(Object input) {
5050
return processors.get(processors.size() - 1).process(input);
5151
}
5252

53-
public static Object apply(List<Object> objects) {
54-
// Check every condition in sequence and if it evaluates to TRUE,
55-
// evaluate and return the result associated with that condition.
56-
for (int i = 0; i < objects.size() - 2; i += 2) {
57-
if (objects.get(i) == Boolean.TRUE) {
58-
return objects.get(i + 1);
59-
}
60-
}
61-
// resort to default value
62-
return objects.get(objects.size() - 1);
63-
}
64-
6553
@Override
6654
public boolean equals(Object o) {
6755
if (this == o) {

x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
6666
#
6767
# Conditional
6868
#
69-
def caseFunction(java.util.List)
7069
def coalesce(java.util.List)
7170
def greatest(java.util.List)
7271
def least(java.util.List)

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -626,9 +626,9 @@ public void testLikeRLikeAsPainlessScripts() {
626626
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
627627
assertNotNull(groupingContext);
628628
ScriptTemplate scriptTemplate = groupingContext.tail.script();
629-
assertEquals("InternalSqlScriptUtils.caseFunction([InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue("
630-
+ "doc,params.v0),params.v1),params.v2,InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(" +
631-
"doc,params.v3),params.v4),params.v5,params.v6])",
629+
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils.docValue(doc,params.v0)," +
630+
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.regex(InternalQlScriptUtils." +
631+
"docValue(doc,params.v3),params.v4)) ? params.v5 : params.v6",
632632
scriptTemplate.toString());
633633
assertEquals("[{v=keyword}, {v=^.*foo.*$}, {v=1}, {v=keyword}, {v=.*bar.*}, {v=2}, {v=3}]",
634634
scriptTemplate.params().toString());
@@ -1086,9 +1086,9 @@ public void testTranslateCase_GroupBy_Painless() {
10861086
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
10871087
assertNotNull(groupingContext);
10881088
ScriptTemplate scriptTemplate = groupingContext.tail.script();
1089-
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" + ""
1090-
+ "doc,params.v0),params.v1),params.v2,InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v3)," +
1091-
"params.v4),params.v5,params.v6])",
1089+
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
1090+
"params.v1)) ? params.v2 : InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(" +
1091+
"doc,params.v3),params.v4)) ? params.v5 : params.v6",
10921092
scriptTemplate.toString());
10931093
assertEquals("[{v=int}, {v=10}, {v=foo}, {v=int}, {v=20}, {v=bar}, {v=default}]", scriptTemplate.params().toString());
10941094
}
@@ -1101,8 +1101,8 @@ public void testTranslateIif_GroupBy_Painless() {
11011101
GroupingContext groupingContext = QueryFolder.FoldAggregate.groupBy(((Aggregate) p).groupings());
11021102
assertNotNull(groupingContext);
11031103
ScriptTemplate scriptTemplate = groupingContext.tail.script();
1104-
assertEquals("InternalSqlScriptUtils.caseFunction([InternalQlScriptUtils.gt(" +
1105-
"InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2,params.v3])",
1104+
assertEquals("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(InternalQlScriptUtils.docValue(doc,params.v0)," +
1105+
"params.v1)) ? params.v2 : params.v3",
11061106
scriptTemplate.toString());
11071107
assertEquals("[{v=int}, {v=20}, {v=foo}, {v=bar}]", scriptTemplate.params().toString());
11081108
}

0 commit comments

Comments
 (0)