Skip to content

Commit 860bfac

Browse files
committed
[MLIR][LLVM] Always print variadic callee type
This commit updates the LLVM dialect CallOp and InvokeOp to always print the variadic callee type (previously callee type) if present. An additional verifier checks that only variadic calls have a non-null variadic callee type, and the builders are adapted accordingly to only set the variadic callee type for variadic calls. The motivation of this change is that CallOp and InvokeOp don't have hidden state that is not pretty printed, but used during the export to LLVM IR. Previously, it could happen that a call looked correct in MLIR, but the return type changed after exporting to LLVM IR (since it has been taken from the hidden callee type attribute). After landing this change, this is not possible anymore since the variadic callee type is always printed if present.
1 parent b330d80 commit 860bfac

File tree

3 files changed

+94
-55
lines changed

3 files changed

+94
-55
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

+8-7
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
560560
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
561561
Terminator]> {
562562
let arguments = (ins
563-
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
563+
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
564564
OptionalAttr<FlatSymbolRefAttr>:$callee,
565565
Variadic<LLVM_Type>:$callee_operands,
566566
Variadic<LLVM_Type>:$normalDestOperands,
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
617617
start with a function name (`@`-prefixed) and indirect calls start with an
618618
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
619619
function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
620-
and is stored as the first value in `callee_operands`. If the callee is a variadic
621-
function, then the `callee_type` attribute must carry the function type. The
622-
trailing type list contains the optional indirect callee type and the MLIR
623-
function type, which differs from the LLVM function type that uses a explicit
624-
void type to model functions that do not return a value.
620+
and is stored as the first value in `callee_operands`. If and only if the
621+
callee is a variadic function, then the `var_callee_type` attribute must
622+
carry the variadic LLVM function type. The trailing type list contains the
623+
optional indirect callee type and the MLIR function type, which differs from
624+
the LLVM function type that uses an explicit void type to model functions
625+
that do not return a value.
625626

626627
Examples:
627628

@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
644645
```
645646
}];
646647

647-
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
648+
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
648649
OptionalAttr<FlatSymbolRefAttr>:$callee,
649650
Variadic<LLVM_Type>:$callee_operands,
650651
DefaultValuedAttr<LLVM_FastmathFlagsAttr,

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+59-44
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948948
return results;
949949
}
950950

951+
/// Gets the variadic callee type for a LLVMFunctionType.
952+
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
953+
return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
954+
}
955+
951956
/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
952957
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
953958
ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974979
FlatSymbolRefAttr callee, ValueRange args) {
975980
assert(callee && "expected non-null callee in direct call builder");
976981
build(builder, state, results,
977-
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
978-
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
982+
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
983+
/*branch_weights=*/nullptr,
979984
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
980985
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
981986
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
9971002
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
9981003
ValueRange args) {
9991004
build(builder, state, getCallOpResultTypes(calleeType),
1000-
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
1005+
getCallOpVarCalleeType(calleeType), callee, args,
1006+
/*fastmathFlags=*/nullptr,
10011007
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
10021008
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
10031009
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10061012
void CallOp::build(OpBuilder &builder, OperationState &state,
10071013
LLVMFunctionType calleeType, ValueRange args) {
10081014
build(builder, state, getCallOpResultTypes(calleeType),
1009-
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
1015+
getCallOpVarCalleeType(calleeType),
1016+
/*callee=*/nullptr, args,
10101017
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10111018
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10121019
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10171024
ValueRange args) {
10181025
auto calleeType = func.getFunctionType();
10191026
build(builder, state, getCallOpResultTypes(calleeType),
1020-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
1027+
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
10211028
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10221029
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10231030
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1087,12 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10801087
if (getNumResults() > 1)
10811088
return emitOpError("must have 0 or 1 result");
10821089

1090+
// Verify the variadic callee type is a variadic function type.
1091+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1092+
if (!varCalleeType->isVarArg())
1093+
return emitOpError(
1094+
"expected var_callee_type to be a variadic function type");
1095+
10831096
// Type for the callee, we'll get it differently depending if it is a direct
10841097
// or indirect call.
10851098
Type fnType;
@@ -1120,7 +1133,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
11201133
if (!funcType)
11211134
return emitOpError("callee does not have a functional type: ") << fnType;
11221135

1123-
if (funcType.isVarArg() && !getCalleeType())
1136+
if (funcType.isVarArg() && !getVarCalleeType())
11241137
return emitOpError() << "missing callee type attribute for vararg call";
11251138

11261139
// Verify that the operand and result types match the callee.
@@ -1168,14 +1181,6 @@ void CallOp::print(OpAsmPrinter &p) {
11681181
auto callee = getCallee();
11691182
bool isDirect = callee.has_value();
11701183

1171-
LLVMFunctionType calleeType;
1172-
bool isVarArg = false;
1173-
1174-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1175-
calleeType = *optionalCalleeType;
1176-
isVarArg = calleeType.isVarArg();
1177-
}
1178-
11791184
p << ' ';
11801185

11811186
// Print calling convention.
@@ -1195,11 +1200,13 @@ void CallOp::print(OpAsmPrinter &p) {
11951200
auto args = getOperands().drop_front(isDirect ? 0 : 1);
11961201
p << '(' << args << ')';
11971202

1198-
if (isVarArg)
1199-
p << " vararg(" << calleeType << ")";
1203+
// Print the variadic callee type if the call is variadic.
1204+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1205+
p << " vararg(" << *varCalleeType << ")";
12001206

12011207
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1202-
{getCConvAttrName(), "callee", "callee_type",
1208+
{getCConvAttrName(), "callee",
1209+
getVarCalleeTypeAttrName(),
12031210
getTailCallKindAttrName()});
12041211

12051212
p << " : ";
@@ -1270,11 +1277,11 @@ static ParseResult parseOptionalCallFuncPtr(
12701277

12711278
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
12721279
// `(` ssa-use-list `)`
1273-
// ( `vararg(` var-arg-func-type `)` )?
1280+
// ( `vararg(` var-callee-type `)` )?
12741281
// attribute-dict? `:` (type `,`)? function-type
12751282
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
12761283
SymbolRefAttr funcAttr;
1277-
TypeAttr calleeType;
1284+
TypeAttr varCalleeType;
12781285
SmallVector<OpAsmParser::UnresolvedOperand> operands;
12791286

12801287
// Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1312,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13051312

13061313
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
13071314
if (isVarArg) {
1315+
StringAttr varCalleeTypeAttrName =
1316+
CallOp::getVarCalleeTypeAttrName(result.name);
13081317
if (parser.parseLParen().failed() ||
1309-
parser.parseAttribute(calleeType, "callee_type", result.attributes)
1318+
parser
1319+
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
1320+
result.attributes)
13101321
.failed() ||
13111322
parser.parseRParen().failed())
13121323
return failure();
@@ -1320,8 +1331,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13201331
}
13211332

13221333
LLVMFunctionType CallOp::getCalleeFunctionType() {
1323-
if (getCalleeType())
1324-
return *getCalleeType();
1334+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1335+
return *varCalleeType;
13251336
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
13261337
}
13271338

@@ -1334,26 +1345,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
13341345
Block *unwind, ValueRange unwindOps) {
13351346
auto calleeType = func.getFunctionType();
13361347
build(builder, state, getCallOpResultTypes(calleeType),
1337-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1338-
unwindOps, nullptr, nullptr, normal, unwind);
1348+
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1349+
normalOps, unwindOps, nullptr, nullptr, normal, unwind);
13391350
}
13401351

13411352
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
13421353
FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
13431354
ValueRange normalOps, Block *unwind,
13441355
ValueRange unwindOps) {
13451356
build(builder, state, tys,
1346-
TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1347-
ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
1357+
/*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
1358+
nullptr, normal, unwind);
13481359
}
13491360

13501361
void InvokeOp::build(OpBuilder &builder, OperationState &state,
13511362
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
13521363
ValueRange ops, Block *normal, ValueRange normalOps,
13531364
Block *unwind, ValueRange unwindOps) {
13541365
build(builder, state, getCallOpResultTypes(calleeType),
1355-
TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1356-
nullptr, normal, unwind);
1366+
getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
1367+
nullptr, nullptr, normal, unwind);
13571368
}
13581369

13591370
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1404,12 @@ LogicalResult InvokeOp::verify() {
13931404
if (getNumResults() > 1)
13941405
return emitOpError("must have 0 or 1 result");
13951406

1407+
// Verify the variadic callee type is a variadic function type.
1408+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1409+
if (!varCalleeType->isVarArg())
1410+
return emitOpError(
1411+
"expected var_callee_type to be a variadic function type");
1412+
13961413
Block *unwindDest = getUnwindDest();
13971414
if (unwindDest->empty())
13981415
return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1426,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
14091426
auto callee = getCallee();
14101427
bool isDirect = callee.has_value();
14111428

1412-
LLVMFunctionType calleeType;
1413-
bool isVarArg = false;
1414-
1415-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1416-
calleeType = *optionalCalleeType;
1417-
isVarArg = calleeType.isVarArg();
1418-
}
1419-
14201429
p << ' ';
14211430

14221431
// Print calling convention.
@@ -1435,12 +1444,14 @@ void InvokeOp::print(OpAsmPrinter &p) {
14351444
p << " unwind ";
14361445
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
14371446

1438-
if (isVarArg)
1439-
p << " vararg(" << calleeType << ")";
1447+
// Print the variadic callee type if the invoke is variadic.
1448+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1449+
p << " vararg(" << *varCalleeType << ")";
14401450

14411451
p.printOptionalAttrDict((*this)->getAttrs(),
14421452
{InvokeOp::getOperandSegmentSizeAttr(), "callee",
1443-
"callee_type", InvokeOp::getCConvAttrName()});
1453+
InvokeOp::getVarCalleeTypeAttrName(),
1454+
InvokeOp::getCConvAttrName()});
14441455

14451456
p << " : ";
14461457
if (!isDirect)
@@ -1453,12 +1464,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
14531464
// `(` ssa-use-list `)`
14541465
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
14551466
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1456-
// ( `vararg(` var-arg-func-type `)` )?
1467+
// ( `vararg(` var-callee-type `)` )?
14571468
// attribute-dict? `:` (type `,`)? function-type
14581469
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
14591470
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
14601471
SymbolRefAttr funcAttr;
1461-
TypeAttr calleeType;
1472+
TypeAttr varCalleeType;
14621473
Block *normalDest, *unwindDest;
14631474
SmallVector<Value, 4> normalOperands, unwindOperands;
14641475
Builder &builder = parser.getBuilder();
@@ -1488,8 +1499,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
14881499

14891500
bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
14901501
if (isVarArg) {
1502+
StringAttr varCalleeTypeAttrName =
1503+
InvokeOp::getVarCalleeTypeAttrName(result.name);
14911504
if (parser.parseLParen().failed() ||
1492-
parser.parseAttribute(calleeType, "callee_type", result.attributes)
1505+
parser
1506+
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
1507+
result.attributes)
14931508
.failed() ||
14941509
parser.parseRParen().failed())
14951510
return failure();
@@ -1515,8 +1530,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
15151530
}
15161531

15171532
LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1518-
if (getCalleeType())
1519-
return *getCalleeType();
1533+
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1534+
return *varCalleeType;
15201535
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
15211536
}
15221537

mlir/test/Dialect/LLVMIR/invalid.mlir

+27-4
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>) {
14151415

14161416
// -----
14171417

1418+
llvm.func @non_variadic(%arg: i32)
1419+
1420+
llvm.func @invalid_callee_type(%arg: i32) {
1421+
// expected-error@below {{expected var_callee_type to be a variadic function type}}
1422+
llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
1423+
llvm.return
1424+
}
1425+
1426+
// -----
1427+
1428+
llvm.func @non_variadic(%arg: i32)
1429+
1430+
llvm.func @invalid_callee_type(%arg: i32) {
1431+
// expected-error@below {{expected var_callee_type to be a variadic function type}}
1432+
llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
1433+
^bb1:
1434+
llvm.return
1435+
^bb2:
1436+
llvm.return
1437+
}
1438+
1439+
// -----
1440+
14181441
llvm.func @variadic(...)
14191442

14201443
llvm.func @invalid_variadic_call(%arg: i32) {
@@ -1445,22 +1468,22 @@ llvm.func @foo(%arg: !llvm.ptr) {
14451468

14461469
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14471470
// expected-error@+1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
1448-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
1471+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
14491472
return
14501473
}
14511474
// -----
14521475

14531476
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14541477
// expected-error@+1 {{im2col offsets must be 2 less than number of coordinates}}
1455-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
1478+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
14561479
return
14571480
}
14581481

14591482
// -----
14601483

14611484
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14621485
// expected-error@+1 {{expects coordinates between 1 to 5 dimension}}
1463-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
1486+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
14641487
return
14651488
}
14661489

@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
14691492

14701493
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14711494
// expected-error@+1 {{expects coordinates between 1 to 5 dimension}}
1472-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
1495+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
14731496
return
14741497
}
14751498

0 commit comments

Comments
 (0)