Skip to content

Commit 80cf33b

Browse files
authored
Merge branch 'main' into 2836-api-prototype
2 parents 8476cc9 + 4694963 commit 80cf33b

File tree

18 files changed

+665
-12
lines changed

18 files changed

+665
-12
lines changed

coordinator/requirements-dev.txt

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
black>=23.3.0
22
flake8==4.0.1
33
isort==5.10.1
4+
# Avoid the `InvalidVersion` error of `setuptools`
5+
# ref: https://github.com/pypa/setuptools/issues/3772
6+
setuptools==65.7.0

coordinator/requirements.txt

-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,5 @@ PyYAML
77
vineyard==0.15.0;sys_platform!="win32"
88
vineyard-io==0.15.0;sys_platform!="win32"
99
prometheus-client>=0.14.1
10-
# Avoid the `InvalidVersion` error of `setuptools`
11-
# ref: https://github.com/pypa/setuptools/issues/3772
12-
setuptools==65.7.0
1310
packaging
1411
tqdm

interactive_engine/compiler/src/main/antlr4/CypherGS.g4

+23
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,34 @@ oC_Atom
250250
| oC_FunctionInvocation
251251
| oC_CountAny
252252
| oC_Parameter
253+
| oC_CaseExpression
253254
;
254255

255256
oC_Parameter
256257
: '$' ( oC_SymbolicName ) ;
257258

259+
oC_CaseExpression
260+
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_InputExpression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_ElseExpression )? SP? END ;
261+
262+
oC_InputExpression
263+
: oC_Expression ;
264+
265+
oC_ElseExpression
266+
: oC_Expression ;
267+
268+
CASE : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;
269+
270+
ELSE : ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;
271+
272+
END : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ;
273+
274+
oC_CaseAlternative
275+
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;
276+
277+
WHEN : ( 'W' | 'w' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;
278+
279+
THEN : ( 'T' | 't' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;
280+
258281
oC_CountAny
259282
: ( COUNT SP? '(' SP? '*' SP? ')' )
260283
;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright 2020 Alibaba Group Holding Limited.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.graphscope.common.ir.rex.operator;
18+
19+
import static org.apache.calcite.util.Static.RESOURCE;
20+
21+
import static java.util.Objects.requireNonNull;
22+
23+
import com.alibaba.graphscope.common.ir.rex.RexCallBinding;
24+
import com.google.common.base.Preconditions;
25+
import com.google.common.collect.Iterables;
26+
27+
import org.apache.calcite.rel.type.RelDataType;
28+
import org.apache.calcite.rel.type.RelDataTypeFactory;
29+
import org.apache.calcite.sql.*;
30+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
31+
import org.apache.calcite.sql.type.SqlOperandTypeInference;
32+
import org.apache.calcite.sql.type.SqlTypeUtil;
33+
34+
import java.util.ArrayList;
35+
import java.util.List;
36+
37+
public class CaseOperator extends SqlOperator {
38+
39+
public CaseOperator(SqlOperandTypeInference operandTypeInference) {
40+
super("CASE", SqlKind.CASE, MDX_PRECEDENCE, true, null, operandTypeInference, null);
41+
}
42+
43+
@Override
44+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
45+
Preconditions.checkArgument(callBinding instanceof RexCallBinding);
46+
boolean foundNotNull = false;
47+
int operandCount = callBinding.getOperandCount();
48+
for (int i = 0; i < operandCount - 1; ++i) {
49+
RelDataType type = callBinding.getOperandType(i);
50+
if ((i & 1) == 0) { // when expression, should be boolean
51+
if (!SqlTypeUtil.inBooleanFamily(type)) {
52+
if (throwOnFailure) {
53+
throw new IllegalArgumentException(
54+
"Expected a boolean type at operand idx = " + i);
55+
}
56+
return false;
57+
}
58+
} else { // then expression
59+
if (!callBinding.isOperandNull(i, false)) {
60+
foundNotNull = true;
61+
}
62+
}
63+
}
64+
65+
if (operandCount > 2 && !callBinding.isOperandNull(operandCount - 1, false)) {
66+
foundNotNull = true;
67+
}
68+
69+
if (!foundNotNull) {
70+
// according to the sql standard we can not have all of the THEN
71+
// statements and the ELSE returning null
72+
if (throwOnFailure && !callBinding.isTypeCoercionEnabled()) {
73+
throw callBinding.newValidationError(RESOURCE.mustNotNullInElse());
74+
}
75+
return false;
76+
}
77+
return true;
78+
}
79+
80+
@Override
81+
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
82+
return inferTypeFromOperands(opBinding);
83+
}
84+
85+
private static RelDataType inferTypeFromOperands(SqlOperatorBinding opBinding) {
86+
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
87+
final List<RelDataType> argTypes = opBinding.collectOperandTypes();
88+
assert (argTypes.size() % 2) == 1 : "odd number of arguments expected: " + argTypes.size();
89+
assert argTypes.size() > 1
90+
: "CASE must have more than 1 argument. Given " + argTypes.size() + ", " + argTypes;
91+
List<RelDataType> thenTypes = new ArrayList<>();
92+
for (int j = 1; j < (argTypes.size() - 1); j += 2) {
93+
RelDataType argType = argTypes.get(j);
94+
thenTypes.add(argType);
95+
}
96+
97+
thenTypes.add(Iterables.getLast(argTypes));
98+
return requireNonNull(
99+
typeFactory.leastRestrictive(thenTypes),
100+
() -> "Can't find leastRestrictive type for " + thenTypes);
101+
}
102+
103+
@Override
104+
public SqlOperandCountRange getOperandCountRange() {
105+
return SqlOperandCountRanges.any();
106+
}
107+
108+
@Override
109+
public SqlSyntax getSyntax() {
110+
return SqlSyntax.SPECIAL;
111+
}
112+
}

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java

+33-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import com.alibaba.graphscope.gaia.proto.Common;
2222
import com.alibaba.graphscope.gaia.proto.DataType;
2323
import com.alibaba.graphscope.gaia.proto.OuterExpression;
24+
import com.google.common.base.Preconditions;
2425

2526
import org.apache.calcite.rex.*;
27+
import org.apache.calcite.sql.SqlKind;
2628
import org.apache.calcite.sql.SqlOperator;
2729

2830
/**
@@ -41,8 +43,38 @@ public OuterExpression.Expression visitCall(RexCall call) {
4143
if (!this.deep) {
4244
return null;
4345
}
44-
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
4546
SqlOperator operator = call.getOperator();
47+
if (operator.getKind() == SqlKind.CASE) {
48+
return visitCase(call);
49+
} else {
50+
return visitOperator(call);
51+
}
52+
}
53+
54+
private OuterExpression.Expression visitCase(RexCall call) {
55+
OuterExpression.Case.Builder caseBuilder = OuterExpression.Case.newBuilder();
56+
int operandCount = call.getOperands().size();
57+
Preconditions.checkArgument(operandCount > 2 && (operandCount & 1) == 1);
58+
for (int i = 1; i < operandCount - 1; i += 2) {
59+
RexNode whenNode = call.getOperands().get(i - 1);
60+
RexNode thenNode = call.getOperands().get(i);
61+
caseBuilder.addWhenThenExpressions(
62+
OuterExpression.Case.WhenThen.newBuilder()
63+
.setWhenExpression(whenNode.accept(this))
64+
.setThenResultExpression(thenNode.accept(this)));
65+
}
66+
caseBuilder.setElseResultExpression(call.getOperands().get(operandCount - 1).accept(this));
67+
return OuterExpression.Expression.newBuilder()
68+
.addOperators(
69+
OuterExpression.ExprOpr.newBuilder()
70+
.setCase(caseBuilder)
71+
.setNodeType(Utils.protoIrDataType(call.getType(), isColumnId)))
72+
.build();
73+
}
74+
75+
private OuterExpression.Expression visitOperator(RexCall call) {
76+
SqlOperator operator = call.getOperator();
77+
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
4678
// left-associative
4779
if (operator.getLeftPrec() <= operator.getRightPrec()) {
4880
for (int i = 0; i < call.getOperands().size(); ++i) {

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
public abstract class Utils {
3939
public static final Common.Value protoValue(RexLiteral literal) {
4040
switch (literal.getType().getSqlTypeName()) {
41+
case NULL:
42+
return Common.Value.newBuilder().setNone(Common.None.newBuilder().build()).build();
4143
case BOOLEAN:
4244
return Common.Value.newBuilder().setBoolean((Boolean) literal.getValue()).build();
4345
case INTEGER:
@@ -176,6 +178,8 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator)
176178

177179
public static final Common.DataType protoBasicDataType(RelDataType basicType) {
178180
switch (basicType.getSqlTypeName()) {
181+
case NULL:
182+
return Common.DataType.NONE;
179183
case BOOLEAN:
180184
return Common.DataType.BOOLEAN;
181185
case INTEGER:

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java

+10-7
Original file line numberDiff line numberDiff line change
@@ -464,27 +464,29 @@ private RexNode call_(SqlOperator operator, List<RexNode> operandList) {
464464
throw new UnsupportedOperationException(
465465
"operator " + operator.getKind().name() + " not supported");
466466
}
467-
operandList = inferOperandTypes(operator, operandList);
468467
RexCallBinding callBinding =
469468
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
470469
// check count of operands, if fail throw exceptions
471470
operator.validRexOperands(callBinding.getOperandCount(), Litmus.THROW);
472471
// check type of each operand, if fail throw exceptions
473472
operator.checkOperandTypes(callBinding, true);
474-
// derive type
475-
RelDataType type = operator.inferReturnType(callBinding);
473+
// derive return type
474+
RelDataType returnType = operator.inferReturnType(callBinding);
475+
// derive unknown types of operands
476+
operandList = inferOperandTypes(operator, returnType, operandList);
476477
final RexBuilder builder = cluster.getRexBuilder();
477-
return builder.makeCall(type, operator, operandList);
478+
return builder.makeCall(returnType, operator, operandList);
478479
}
479480

480-
private List<RexNode> inferOperandTypes(SqlOperator operator, List<RexNode> operandList) {
481+
private List<RexNode> inferOperandTypes(
482+
SqlOperator operator, RelDataType returnType, List<RexNode> operandList) {
481483
if (operator.getOperandTypeInference() != null
482484
&& operandList.stream()
483485
.anyMatch((t) -> t.getType().getSqlTypeName() == SqlTypeName.UNKNOWN)) {
484486
RexCallBinding callBinding =
485487
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
486488
RelDataType[] newTypes = callBinding.collectOperandTypes().toArray(new RelDataType[0]);
487-
operator.getOperandTypeInference().inferOperandTypes(callBinding, null, newTypes);
489+
operator.getOperandTypeInference().inferOperandTypes(callBinding, returnType, newTypes);
488490
List<RexNode> typeInferredOperands = new ArrayList<>(operandList.size());
489491
GraphRexBuilder rexBuilder = (GraphRexBuilder) this.getRexBuilder();
490492
for (int i = 0; i < operandList.size(); ++i) {
@@ -507,7 +509,8 @@ private boolean isCurrentSupported(SqlOperator operator) {
507509
|| sqlKind == SqlKind.OR
508510
|| sqlKind == SqlKind.DESCENDING
509511
|| (sqlKind == SqlKind.OTHER_FUNCTION && operator.getName().equals("POWER"))
510-
|| (sqlKind == SqlKind.MINUS_PREFIX);
512+
|| (sqlKind == SqlKind.MINUS_PREFIX)
513+
|| sqlKind == SqlKind.CASE;
511514
}
512515

513516
@Override

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.alibaba.graphscope.common.ir.tools;
1818

19+
import com.alibaba.graphscope.common.ir.rex.operator.CaseOperator;
20+
1921
import org.apache.calcite.sql.*;
2022
import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator;
2123
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -175,4 +177,6 @@ public class GraphStdOperatorTable extends SqlStdOperatorTable {
175177
ReturnTypes.BOOLEAN_NULLABLE,
176178
GraphInferTypes.FIRST_KNOWN,
177179
OperandTypes.COMPARABLE_ORDERED_COMPARABLE_ORDERED);
180+
181+
public static final SqlOperator CASE = new CaseOperator(GraphInferTypes.RETURN_TYPE);
178182
}

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java

+31
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import com.alibaba.graphscope.cypher.antlr4.visitor.type.ExprVisitorResult;
2626
import com.alibaba.graphscope.grammar.CypherGSBaseVisitor;
2727
import com.alibaba.graphscope.grammar.CypherGSParser;
28+
import com.google.common.base.Preconditions;
2829
import com.google.common.collect.ImmutableList;
30+
import com.google.common.collect.Lists;
2931

3032
import org.apache.calcite.rex.RexNode;
3133
import org.apache.calcite.sql.SqlOperator;
@@ -255,6 +257,35 @@ public ExprVisitorResult visitOC_CountAny(CypherGSParser.OC_CountAnyContext ctx)
255257
RexTmpVariable.of(alias, ((GraphAggCall) aggCall).getType()));
256258
}
257259

260+
@Override
261+
public ExprVisitorResult visitOC_CaseExpression(CypherGSParser.OC_CaseExpressionContext ctx) {
262+
ExprVisitorResult inputExpr =
263+
ctx.oC_InputExpression() == null
264+
? null
265+
: visitOC_InputExpression(ctx.oC_InputExpression());
266+
List<RexNode> operands = Lists.newArrayList();
267+
for (CypherGSParser.OC_CaseAlternativeContext whenThen : ctx.oC_CaseAlternative()) {
268+
Preconditions.checkArgument(
269+
whenThen.oC_Expression().size() == 2,
270+
"whenThen expression should have 2 parts");
271+
ExprVisitorResult whenExpr = visitOC_Expression(whenThen.oC_Expression(0));
272+
if (inputExpr != null) {
273+
operands.add(builder.equals(inputExpr.getExpr(), whenExpr.getExpr()));
274+
} else {
275+
operands.add(whenExpr.getExpr());
276+
}
277+
ExprVisitorResult thenExpr = visitOC_Expression(whenThen.oC_Expression(1));
278+
operands.add(thenExpr.getExpr());
279+
}
280+
// if else expression is omitted, the default value is null
281+
ExprVisitorResult elseExpr =
282+
ctx.oC_ElseExpression() == null
283+
? new ExprVisitorResult(builder.literal(null))
284+
: visitOC_ElseExpression(ctx.oC_ElseExpression());
285+
operands.add(elseExpr.getExpr());
286+
return new ExprVisitorResult(builder.call(GraphStdOperatorTable.CASE, operands));
287+
}
288+
258289
private ExprVisitorResult binaryCall(
259290
List<SqlOperator> operators, List<ExprVisitorResult> operands) {
260291
ObjectUtils.requireNonEmpty(operands, "operands count should not be 0");

interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphInferTypes.java

+18
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,22 @@ private GraphInferTypes() {}
4444
}
4545
Arrays.fill(operandTypes, knownType);
4646
};
47+
48+
/**
49+
* Operand type-inference strategy where an unknown operand type is derived
50+
* from the call's return type. If the return type is a record, it must have
51+
* the same number of fields as the number of operands.
52+
*/
53+
public static final SqlOperandTypeInference RETURN_TYPE =
54+
(callBinding, returnType, operandTypes) -> {
55+
RelDataType unknownType = callBinding.getTypeFactory().createUnknownType();
56+
for (int i = 0; i < operandTypes.length; ++i) {
57+
if (operandTypes[i].equals(unknownType)) {
58+
operandTypes[i] =
59+
returnType.isStruct()
60+
? returnType.getFieldList().get(i).getType()
61+
: returnType;
62+
}
63+
}
64+
};
4765
}

0 commit comments

Comments
 (0)