Skip to content

[GIE Compiler] Support Case When Expression in Logical and Physical Plan #2918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 3, 2023
23 changes: 23 additions & 0 deletions interactive_engine/compiler/src/main/antlr4/CypherGS.g4
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,34 @@ oC_Atom
| oC_FunctionInvocation
| oC_CountAny
| oC_Parameter
| oC_CaseExpression
;

oC_Parameter
: '$' ( oC_SymbolicName ) ;

oC_CaseExpression
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_InputExpression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_ElseExpression )? SP? END ;

oC_InputExpression
: oC_Expression ;

oC_ElseExpression
: oC_Expression ;

CASE : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

ELSE : ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

END : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ;

oC_CaseAlternative
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;

WHEN : ( 'W' | 'w' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

THEN : ( 'T' | 't' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

oC_CountAny
: ( COUNT SP? '(' SP? '*' SP? ')' )
;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2020 Alibaba Group Holding Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.graphscope.common.ir.rex.operator;

import static org.apache.calcite.util.Static.RESOURCE;

import static java.util.Objects.requireNonNull;

import com.alibaba.graphscope.common.ir.rex.RexCallBinding;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlTypeUtil;

import java.util.ArrayList;
import java.util.List;

public class CaseOperator extends SqlOperator {

public CaseOperator(SqlOperandTypeInference operandTypeInference) {
super("CASE", SqlKind.CASE, MDX_PRECEDENCE, true, null, operandTypeInference, null);
}

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
Preconditions.checkArgument(callBinding instanceof RexCallBinding);
boolean foundNotNull = false;
int operandCount = callBinding.getOperandCount();
for (int i = 0; i < operandCount - 1; ++i) {
RelDataType type = callBinding.getOperandType(i);
if ((i & 1) == 0) { // when expression, should be boolean
if (!SqlTypeUtil.inBooleanFamily(type)) {
if (throwOnFailure) {
throw new IllegalArgumentException(
"Expected a boolean type at operand idx = " + i);
}
return false;
}
} else { // then expression
if (!callBinding.isOperandNull(i, false)) {
foundNotNull = true;
}
}
}

if (operandCount > 2 && !callBinding.isOperandNull(operandCount - 1, false)) {
foundNotNull = true;
}

if (!foundNotNull) {
// according to the sql standard we can not have all of the THEN
// statements and the ELSE returning null
if (throwOnFailure && !callBinding.isTypeCoercionEnabled()) {
throw callBinding.newValidationError(RESOURCE.mustNotNullInElse());
}
return false;
}
return true;
}

@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
return inferTypeFromOperands(opBinding);
}

private static RelDataType inferTypeFromOperands(SqlOperatorBinding opBinding) {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final List<RelDataType> argTypes = opBinding.collectOperandTypes();
assert (argTypes.size() % 2) == 1 : "odd number of arguments expected: " + argTypes.size();
assert argTypes.size() > 1
: "CASE must have more than 1 argument. Given " + argTypes.size() + ", " + argTypes;
List<RelDataType> thenTypes = new ArrayList<>();
for (int j = 1; j < (argTypes.size() - 1); j += 2) {
RelDataType argType = argTypes.get(j);
thenTypes.add(argType);
}

thenTypes.add(Iterables.getLast(argTypes));
return requireNonNull(
typeFactory.leastRestrictive(thenTypes),
() -> "Can't find leastRestrictive type for " + thenTypes);
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return SqlOperandCountRanges.any();
}

@Override
public SqlSyntax getSyntax() {
return SqlSyntax.SPECIAL;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import com.alibaba.graphscope.gaia.proto.Common;
import com.alibaba.graphscope.gaia.proto.DataType;
import com.alibaba.graphscope.gaia.proto.OuterExpression;
import com.google.common.base.Preconditions;

import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;

/**
Expand All @@ -41,8 +43,38 @@ public OuterExpression.Expression visitCall(RexCall call) {
if (!this.deep) {
return null;
}
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
SqlOperator operator = call.getOperator();
if (operator.getKind() == SqlKind.CASE) {
return visitCase(call);
} else {
return visitOperator(call);
}
}

private OuterExpression.Expression visitCase(RexCall call) {
OuterExpression.Case.Builder caseBuilder = OuterExpression.Case.newBuilder();
int operandCount = call.getOperands().size();
Preconditions.checkArgument(operandCount > 2 && (operandCount & 1) == 1);
for (int i = 1; i < operandCount - 1; i += 2) {
RexNode whenNode = call.getOperands().get(i - 1);
RexNode thenNode = call.getOperands().get(i);
caseBuilder.addWhenThenExpressions(
OuterExpression.Case.WhenThen.newBuilder()
.setWhenExpression(whenNode.accept(this))
.setThenResultExpression(thenNode.accept(this)));
}
caseBuilder.setElseResultExpression(call.getOperands().get(operandCount - 1).accept(this));
return OuterExpression.Expression.newBuilder()
.addOperators(
OuterExpression.ExprOpr.newBuilder()
.setCase(caseBuilder)
.setNodeType(Utils.protoIrDataType(call.getType(), isColumnId)))
.build();
}

private OuterExpression.Expression visitOperator(RexCall call) {
SqlOperator operator = call.getOperator();
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
// left-associative
if (operator.getLeftPrec() <= operator.getRightPrec()) {
for (int i = 0; i < call.getOperands().size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
public abstract class Utils {
public static final Common.Value protoValue(RexLiteral literal) {
switch (literal.getType().getSqlTypeName()) {
case NULL:
return Common.Value.newBuilder().setNone(Common.None.newBuilder().build()).build();
case BOOLEAN:
return Common.Value.newBuilder().setBoolean((Boolean) literal.getValue()).build();
case INTEGER:
Expand Down Expand Up @@ -176,6 +178,8 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator)

public static final Common.DataType protoBasicDataType(RelDataType basicType) {
switch (basicType.getSqlTypeName()) {
case NULL:
return Common.DataType.NONE;
case BOOLEAN:
return Common.DataType.BOOLEAN;
case INTEGER:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,27 +464,29 @@ private RexNode call_(SqlOperator operator, List<RexNode> operandList) {
throw new UnsupportedOperationException(
"operator " + operator.getKind().name() + " not supported");
}
operandList = inferOperandTypes(operator, operandList);
RexCallBinding callBinding =
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
// check count of operands, if fail throw exceptions
operator.validRexOperands(callBinding.getOperandCount(), Litmus.THROW);
// check type of each operand, if fail throw exceptions
operator.checkOperandTypes(callBinding, true);
// derive type
RelDataType type = operator.inferReturnType(callBinding);
// derive return type
RelDataType returnType = operator.inferReturnType(callBinding);
// derive unknown types of operands
operandList = inferOperandTypes(operator, returnType, operandList);
final RexBuilder builder = cluster.getRexBuilder();
return builder.makeCall(type, operator, operandList);
return builder.makeCall(returnType, operator, operandList);
}

private List<RexNode> inferOperandTypes(SqlOperator operator, List<RexNode> operandList) {
private List<RexNode> inferOperandTypes(
SqlOperator operator, RelDataType returnType, List<RexNode> operandList) {
if (operator.getOperandTypeInference() != null
&& operandList.stream()
.anyMatch((t) -> t.getType().getSqlTypeName() == SqlTypeName.UNKNOWN)) {
RexCallBinding callBinding =
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
RelDataType[] newTypes = callBinding.collectOperandTypes().toArray(new RelDataType[0]);
operator.getOperandTypeInference().inferOperandTypes(callBinding, null, newTypes);
operator.getOperandTypeInference().inferOperandTypes(callBinding, returnType, newTypes);
List<RexNode> typeInferredOperands = new ArrayList<>(operandList.size());
GraphRexBuilder rexBuilder = (GraphRexBuilder) this.getRexBuilder();
for (int i = 0; i < operandList.size(); ++i) {
Expand All @@ -507,7 +509,8 @@ private boolean isCurrentSupported(SqlOperator operator) {
|| sqlKind == SqlKind.OR
|| sqlKind == SqlKind.DESCENDING
|| (sqlKind == SqlKind.OTHER_FUNCTION && operator.getName().equals("POWER"))
|| (sqlKind == SqlKind.MINUS_PREFIX);
|| (sqlKind == SqlKind.MINUS_PREFIX)
|| sqlKind == SqlKind.CASE;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.alibaba.graphscope.common.ir.tools;

import com.alibaba.graphscope.common.ir.rex.operator.CaseOperator;

import org.apache.calcite.sql.*;
import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
Expand Down Expand Up @@ -175,4 +177,6 @@ public class GraphStdOperatorTable extends SqlStdOperatorTable {
ReturnTypes.BOOLEAN_NULLABLE,
GraphInferTypes.FIRST_KNOWN,
OperandTypes.COMPARABLE_ORDERED_COMPARABLE_ORDERED);

public static final SqlOperator CASE = new CaseOperator(GraphInferTypes.RETURN_TYPE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import com.alibaba.graphscope.cypher.antlr4.visitor.type.ExprVisitorResult;
import com.alibaba.graphscope.grammar.CypherGSBaseVisitor;
import com.alibaba.graphscope.grammar.CypherGSParser;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
Expand Down Expand Up @@ -255,6 +257,35 @@ public ExprVisitorResult visitOC_CountAny(CypherGSParser.OC_CountAnyContext ctx)
RexTmpVariable.of(alias, ((GraphAggCall) aggCall).getType()));
}

@Override
public ExprVisitorResult visitOC_CaseExpression(CypherGSParser.OC_CaseExpressionContext ctx) {
ExprVisitorResult inputExpr =
ctx.oC_InputExpression() == null
? null
: visitOC_InputExpression(ctx.oC_InputExpression());
List<RexNode> operands = Lists.newArrayList();
for (CypherGSParser.OC_CaseAlternativeContext whenThen : ctx.oC_CaseAlternative()) {
Preconditions.checkArgument(
whenThen.oC_Expression().size() == 2,
"whenThen expression should have 2 parts");
ExprVisitorResult whenExpr = visitOC_Expression(whenThen.oC_Expression(0));
if (inputExpr != null) {
operands.add(builder.equals(inputExpr.getExpr(), whenExpr.getExpr()));
} else {
operands.add(whenExpr.getExpr());
}
ExprVisitorResult thenExpr = visitOC_Expression(whenThen.oC_Expression(1));
operands.add(thenExpr.getExpr());
}
// if else expression is omitted, the default value is null
ExprVisitorResult elseExpr =
ctx.oC_ElseExpression() == null
? new ExprVisitorResult(builder.literal(null))
: visitOC_ElseExpression(ctx.oC_ElseExpression());
operands.add(elseExpr.getExpr());
return new ExprVisitorResult(builder.call(GraphStdOperatorTable.CASE, operands));
}

private ExprVisitorResult binaryCall(
List<SqlOperator> operators, List<ExprVisitorResult> operands) {
ObjectUtils.requireNonEmpty(operands, "operands count should not be 0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,22 @@ private GraphInferTypes() {}
}
Arrays.fill(operandTypes, knownType);
};

/**
* Operand type-inference strategy where an unknown operand type is derived
* from the call's return type. If the return type is a record, it must have
* the same number of fields as the number of operands.
*/
public static final SqlOperandTypeInference RETURN_TYPE =
(callBinding, returnType, operandTypes) -> {
RelDataType unknownType = callBinding.getTypeFactory().createUnknownType();
for (int i = 0; i < operandTypes.length; ++i) {
if (operandTypes[i].equals(unknownType)) {
operandTypes[i] =
returnType.isStruct()
? returnType.getFieldList().get(i).getType()
: returnType;
}
}
};
}
Loading