Skip to content

Commit df21b19

Browse files
gysityuxuanchen1997
authored andcommitted
[MLIR][LLVM] Always print variadic callee type (#99293)
Summary: This commit updates the LLVM dialect CallOp and InvokeOp to always print the variadic callee type (previously callee type) if present. An additional verifier checks that only variadic calls have a non-null variadic callee type, and the builders are adapted accordingly to set the variadic callee type for variadic calls only. Finally, the CallOp and InvokeOp verifiers are strengthened to check that the variadic callee type matches the call argument and result types. The motivation of this change is that CallOp and InvokeOp don't have hidden state that is not pretty printed, but used during the export to LLVM IR. Previously, it could happen that a call looked correct in MLIR, but the return type changed after exporting to LLVM IR (since it has been taken from the hidden callee type attribute). After landing this change, this is not possible anymore since the variadic callee type is always printed if present. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251037
1 parent 73e9e3a commit df21b19

File tree

3 files changed

+260
-65
lines changed

3 files changed

+260
-65
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

+9-8
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
560560
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
561561
Terminator]> {
562562
let arguments = (ins
563-
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
563+
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
564564
OptionalAttr<FlatSymbolRefAttr>:$callee,
565565
Variadic<LLVM_Type>:$callee_operands,
566566
Variadic<LLVM_Type>:$normalDestOperands,
567567
Variadic<LLVM_Type>:$unwindDestOperands,
568568
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
569569
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
570-
let results = (outs Variadic<LLVM_Type>);
570+
let results = (outs Optional<LLVM_Type>:$result);
571571
let successors = (successor AnySuccessor:$normalDest,
572572
AnySuccessor:$unwindDest);
573573

@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
617617
start with a function name (`@`-prefixed) and indirect calls start with an
618618
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
619619
function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
620-
and is stored as the first value in `callee_operands`. If the callee is a variadic
621-
function, then the `callee_type` attribute must carry the function type. The
622-
trailing type list contains the optional indirect callee type and the MLIR
623-
function type, which differs from the LLVM function type that uses a explicit
624-
void type to model functions that do not return a value.
620+
and is stored as the first value in `callee_operands`. If and only if the
621+
callee is a variadic function, the `var_callee_type` attribute must carry
622+
the variadic LLVM function type. The trailing type list contains the
623+
optional indirect callee type and the MLIR function type, which differs from
624+
the LLVM function type that uses an explicit void type to model functions
625+
that do not return a value.
625626

626627
Examples:
627628

@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
644645
```
645646
}];
646647

647-
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
648+
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
648649
OptionalAttr<FlatSymbolRefAttr>:$callee,
649650
Variadic<LLVM_Type>:$callee_operands,
650651
DefaultValuedAttr<LLVM_FastmathFlagsAttr,

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+92-51
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948948
return results;
949949
}
950950

951+
/// Gets the variadic callee type for a LLVMFunctionType.
952+
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
953+
return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
954+
}
955+
951956
/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
952957
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
953958
ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974979
FlatSymbolRefAttr callee, ValueRange args) {
975980
assert(callee && "expected non-null callee in direct call builder");
976981
build(builder, state, results,
977-
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
978-
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
982+
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
983+
/*branch_weights=*/nullptr,
979984
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
980985
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
981986
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
9971002
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
9981003
ValueRange args) {
9991004
build(builder, state, getCallOpResultTypes(calleeType),
1000-
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
1005+
getCallOpVarCalleeType(calleeType), callee, args,
1006+
/*fastmathFlags=*/nullptr,
10011007
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
10021008
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
10031009
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10061012
void CallOp::build(OpBuilder &builder, OperationState &state,
10071013
LLVMFunctionType calleeType, ValueRange args) {
10081014
build(builder, state, getCallOpResultTypes(calleeType),
1009-
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
1015+
getCallOpVarCalleeType(calleeType),
1016+
/*callee=*/nullptr, args,
10101017
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10111018
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10121019
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10171024
ValueRange args) {
10181025
auto calleeType = func.getFunctionType();
10191026
build(builder, state, getCallOpResultTypes(calleeType),
1020-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
1027+
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
10211028
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10221029
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10231030
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
10761083
return success();
10771084
}
10781085

1086+
/// Verify that the parameter and return types of the variadic callee type match
1087+
/// the `callOp` argument and result types.
1088+
template <typename OpTy>
1089+
LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1090+
std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1091+
if (!varCalleeType)
1092+
return success();
1093+
1094+
// Verify the variadic callee type is a variadic function type.
1095+
if (!varCalleeType->isVarArg())
1096+
return callOp.emitOpError(
1097+
"expected var_callee_type to be a variadic function type");
1098+
1099+
// Verify the variadic callee type has at most as many parameters as the call
1100+
// has argument operands.
1101+
if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1102+
return callOp.emitOpError("expected var_callee_type to have at most ")
1103+
<< callOp.getArgOperands().size() << " parameters";
1104+
1105+
// Verify the variadic callee type matches the call argument types.
1106+
for (auto [paramType, operand] :
1107+
llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1108+
if (paramType != operand.getType())
1109+
return callOp.emitOpError()
1110+
<< "var_callee_type parameter type mismatch: " << paramType
1111+
<< " != " << operand.getType();
1112+
1113+
// Verify the variadic callee type matches the call result type.
1114+
if (!callOp.getNumResults()) {
1115+
if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1116+
return callOp.emitOpError("expected var_callee_type to return void");
1117+
} else {
1118+
if (callOp.getResult().getType() != varCalleeType->getReturnType())
1119+
return callOp.emitOpError("var_callee_type return type mismatch: ")
1120+
<< varCalleeType->getReturnType()
1121+
<< " != " << callOp.getResult().getType();
1122+
}
1123+
return success();
1124+
}
1125+
10791126
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1080-
if (getNumResults() > 1)
1081-
return emitOpError("must have 0 or 1 result");
1127+
if (failed(verifyCallOpVarCalleeType(*this)))
1128+
return failure();
10821129

10831130
// Type for the callee, we'll get it differently depending if it is a direct
10841131
// or indirect call.
@@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
11201167
if (!funcType)
11211168
return emitOpError("callee does not have a functional type: ") << fnType;
11221169

1123-
if (funcType.isVarArg() && !getCalleeType())
1124-
return emitOpError() << "missing callee type attribute for vararg call";
1170+
if (funcType.isVarArg() && !getVarCalleeType())
1171+
return emitOpError() << "missing var_callee_type attribute for vararg call";
11251172

11261173
// Verify that the operand and result types match the callee.
11271174

@@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
11681215
auto callee = getCallee();
11691216
bool isDirect = callee.has_value();
11701217

1171-
LLVMFunctionType calleeType;
1172-
bool isVarArg = false;
1173-
1174-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1175-
calleeType = *optionalCalleeType;
1176-
isVarArg = calleeType.isVarArg();
1177-
}
1178-
11791218
p << ' ';
11801219

11811220
// Print calling convention.
@@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
11951234
auto args = getOperands().drop_front(isDirect ? 0 : 1);
11961235
p << '(' << args << ')';
11971236

1198-
if (isVarArg)
1199-
p << " vararg(" << calleeType << ")";
1237+
// Print the variadic callee type if the call is variadic.
1238+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1239+
p << " vararg(" << *varCalleeType << ")";
12001240

12011241
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1202-
{getCConvAttrName(), "callee", "callee_type",
1203-
getTailCallKindAttrName()});
1242+
{getCalleeAttrName(), getTailCallKindAttrName(),
1243+
getVarCalleeTypeAttrName(), getCConvAttrName()});
12041244

12051245
p << " : ";
12061246
if (!isDirect)
@@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(
12701310

12711311
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
12721312
// `(` ssa-use-list `)`
1273-
// ( `vararg(` var-arg-func-type `)` )?
1313+
// ( `vararg(` var-callee-type `)` )?
12741314
// attribute-dict? `:` (type `,`)? function-type
12751315
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
12761316
SymbolRefAttr funcAttr;
1277-
TypeAttr calleeType;
1317+
TypeAttr varCalleeType;
12781318
SmallVector<OpAsmParser::UnresolvedOperand> operands;
12791319

12801320
// Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13051345

13061346
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
13071347
if (isVarArg) {
1348+
StringAttr varCalleeTypeAttrName =
1349+
CallOp::getVarCalleeTypeAttrName(result.name);
13081350
if (parser.parseLParen().failed() ||
1309-
parser.parseAttribute(calleeType, "callee_type", result.attributes)
1351+
parser
1352+
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
1353+
result.attributes)
13101354
.failed() ||
13111355
parser.parseRParen().failed())
13121356
return failure();
@@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13201364
}
13211365

13221366
LLVMFunctionType CallOp::getCalleeFunctionType() {
1323-
if (getCalleeType())
1324-
return *getCalleeType();
1367+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1368+
return *varCalleeType;
13251369
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
13261370
}
13271371

@@ -1334,26 +1378,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
13341378
Block *unwind, ValueRange unwindOps) {
13351379
auto calleeType = func.getFunctionType();
13361380
build(builder, state, getCallOpResultTypes(calleeType),
1337-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1338-
unwindOps, nullptr, nullptr, normal, unwind);
1381+
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1382+
normalOps, unwindOps, nullptr, nullptr, normal, unwind);
13391383
}
13401384

13411385
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
13421386
FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
13431387
ValueRange normalOps, Block *unwind,
13441388
ValueRange unwindOps) {
13451389
build(builder, state, tys,
1346-
TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1347-
ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
1390+
/*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
1391+
nullptr, normal, unwind);
13481392
}
13491393

13501394
void InvokeOp::build(OpBuilder &builder, OperationState &state,
13511395
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
13521396
ValueRange ops, Block *normal, ValueRange normalOps,
13531397
Block *unwind, ValueRange unwindOps) {
13541398
build(builder, state, getCallOpResultTypes(calleeType),
1355-
TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1356-
nullptr, normal, unwind);
1399+
getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
1400+
nullptr, nullptr, normal, unwind);
13571401
}
13581402

13591403
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
13901434
}
13911435

13921436
LogicalResult InvokeOp::verify() {
1393-
if (getNumResults() > 1)
1394-
return emitOpError("must have 0 or 1 result");
1437+
if (failed(verifyCallOpVarCalleeType(*this)))
1438+
return failure();
13951439

13961440
Block *unwindDest = getUnwindDest();
13971441
if (unwindDest->empty())
@@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
14091453
auto callee = getCallee();
14101454
bool isDirect = callee.has_value();
14111455

1412-
LLVMFunctionType calleeType;
1413-
bool isVarArg = false;
1414-
1415-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1416-
calleeType = *optionalCalleeType;
1417-
isVarArg = calleeType.isVarArg();
1418-
}
1419-
14201456
p << ' ';
14211457

14221458
// Print calling convention.
@@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
14351471
p << " unwind ";
14361472
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
14371473

1438-
if (isVarArg)
1439-
p << " vararg(" << calleeType << ")";
1474+
// Print the variadic callee type if the invoke is variadic.
1475+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1476+
p << " vararg(" << *varCalleeType << ")";
14401477

14411478
p.printOptionalAttrDict((*this)->getAttrs(),
1442-
{InvokeOp::getOperandSegmentSizeAttr(), "callee",
1443-
"callee_type", InvokeOp::getCConvAttrName()});
1479+
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
1480+
getCConvAttrName(), getVarCalleeTypeAttrName()});
14441481

14451482
p << " : ";
14461483
if (!isDirect)
@@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
14531490
// `(` ssa-use-list `)`
14541491
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
14551492
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1456-
// ( `vararg(` var-arg-func-type `)` )?
1493+
// ( `vararg(` var-callee-type `)` )?
14571494
// attribute-dict? `:` (type `,`)? function-type
14581495
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
14591496
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
14601497
SymbolRefAttr funcAttr;
1461-
TypeAttr calleeType;
1498+
TypeAttr varCalleeType;
14621499
Block *normalDest, *unwindDest;
14631500
SmallVector<Value, 4> normalOperands, unwindOperands;
14641501
Builder &builder = parser.getBuilder();
@@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
14881525

14891526
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
14901527
if (isVarArg) {
1528+
StringAttr varCalleeTypeAttrName =
1529+
InvokeOp::getVarCalleeTypeAttrName(result.name);
14911530
if (parser.parseLParen().failed() ||
1492-
parser.parseAttribute(calleeType, "callee_type", result.attributes)
1531+
parser
1532+
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
1533+
result.attributes)
14931534
.failed() ||
14941535
parser.parseRParen().failed())
14951536
return failure();
@@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
15151556
}
15161557

15171558
LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1518-
if (getCalleeType())
1519-
return *getCalleeType();
1559+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1560+
return *varCalleeType;
15201561
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
15211562
}
15221563

0 commit comments

Comments
 (0)