Skip to content

Commit 7e25f4d

Browse files
committed
[MLIR][LLVM] Always print variadic callee type
This commit updates the LLVM dialect CallOp and InvokeOp to always print the calleeType if present. An additional verifier check ensures that only variadic calls have a non-null calleeType, and the builders are adapted accordingly to set the calleeType only for variadic calls. The motivation for this change is to prevent CallOp and InvokeOp from carrying hidden state that is not pretty-printed but is still used, for example, during the export to LLVM IR. Such hidden state triggered downstream bugs where a call looked correct in MLIR but had a completely different result type after exporting to LLVM IR. This change ensures that the calleeType is present only when necessary, reducing the amount of redundant state, and that it is always printed if present, avoiding any kind of hidden state.
1 parent b330d80 commit 7e25f4d

File tree

3 files changed

+74
-45
lines changed

3 files changed

+74
-45
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

+6-5
Original file line numberDiff line numberDiff line change
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
617617
start with a function name (`@`-prefixed) and indirect calls start with an
618618
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
619619
function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
620-
and is stored as the first value in `callee_operands`. If the callee is a variadic
621-
function, then the `callee_type` attribute must carry the function type. The
622-
trailing type list contains the optional indirect callee type and the MLIR
623-
function type, which differs from the LLVM function type that uses a explicit
624-
void type to model functions that do not return a value.
620+
and is stored as the first value in `callee_operands`. If and only if the
621+
callee is a variadic function, then the `callee_type` attribute must carry
622+
the variadic LLVM function type. The trailing type list contains the
623+
optional indirect callee type and the MLIR function type, which differs from
624+
the LLVM function type that uses an explicit void type to model functions
625+
that do not return a value.
625626

626627
Examples:
627628

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+41-36
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974974
FlatSymbolRefAttr callee, ValueRange args) {
975975
assert(callee && "expected non-null callee in direct call builder");
976976
build(builder, state, results,
977-
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
978-
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
977+
/*callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
978+
/*branch_weights=*/nullptr,
979979
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
980980
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
981981
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -996,17 +996,21 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
996996
void CallOp::build(OpBuilder &builder, OperationState &state,
997997
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
998998
ValueRange args) {
999-
build(builder, state, getCallOpResultTypes(calleeType),
1000-
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
999+
auto varArgCalleeType =
1000+
calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1001+
build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
1002+
callee, args, /*fastmathFlags=*/nullptr,
10011003
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
10021004
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
10031005
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10041006
}
10051007

10061008
void CallOp::build(OpBuilder &builder, OperationState &state,
10071009
LLVMFunctionType calleeType, ValueRange args) {
1008-
build(builder, state, getCallOpResultTypes(calleeType),
1009-
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
1010+
auto varArgCalleeType =
1011+
calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1012+
build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
1013+
/*callee=*/nullptr, args,
10101014
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10111015
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10121016
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1016,8 +1020,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10161020
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10171021
ValueRange args) {
10181022
auto calleeType = func.getFunctionType();
1019-
build(builder, state, getCallOpResultTypes(calleeType),
1020-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
1023+
auto varArgCalleeType =
1024+
calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1025+
build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
1026+
SymbolRefAttr::get(func), args,
10211027
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
10221028
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10231029
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1086,11 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10801086
if (getNumResults() > 1)
10811087
return emitOpError("must have 0 or 1 result");
10821088

1089+
// If the callee type attribute is present, it must be variadic.
1090+
if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
1091+
if (!calleeType->isVarArg())
1092+
return emitOpError("expected variadic callee type attribute");
1093+
10831094
// Type for the callee, we'll get it differently depending if it is a direct
10841095
// or indirect call.
10851096
Type fnType;
@@ -1168,14 +1179,6 @@ void CallOp::print(OpAsmPrinter &p) {
11681179
auto callee = getCallee();
11691180
bool isDirect = callee.has_value();
11701181

1171-
LLVMFunctionType calleeType;
1172-
bool isVarArg = false;
1173-
1174-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1175-
calleeType = *optionalCalleeType;
1176-
isVarArg = calleeType.isVarArg();
1177-
}
1178-
11791182
p << ' ';
11801183

11811184
// Print calling convention.
@@ -1195,8 +1198,9 @@ void CallOp::print(OpAsmPrinter &p) {
11951198
auto args = getOperands().drop_front(isDirect ? 0 : 1);
11961199
p << '(' << args << ')';
11971200

1198-
if (isVarArg)
1199-
p << " vararg(" << calleeType << ")";
1201+
// Print the callee type if the call is variadic.
1202+
if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
1203+
p << " vararg(" << *calleeType << ")";
12001204

12011205
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
12021206
{getCConvAttrName(), "callee", "callee_type",
@@ -1333,27 +1337,30 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
13331337
ValueRange ops, Block *normal, ValueRange normalOps,
13341338
Block *unwind, ValueRange unwindOps) {
13351339
auto calleeType = func.getFunctionType();
1336-
build(builder, state, getCallOpResultTypes(calleeType),
1337-
TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1338-
unwindOps, nullptr, nullptr, normal, unwind);
1340+
auto varArgCalleeType =
1341+
calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1342+
build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
1343+
SymbolRefAttr::get(func), ops, normalOps, unwindOps, nullptr, nullptr,
1344+
normal, unwind);
13391345
}
13401346

13411347
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
13421348
FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
13431349
ValueRange normalOps, Block *unwind,
13441350
ValueRange unwindOps) {
13451351
build(builder, state, tys,
1346-
TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1347-
ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
1352+
/*callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
1353+
nullptr, normal, unwind);
13481354
}
13491355

13501356
void InvokeOp::build(OpBuilder &builder, OperationState &state,
13511357
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
13521358
ValueRange ops, Block *normal, ValueRange normalOps,
13531359
Block *unwind, ValueRange unwindOps) {
1354-
build(builder, state, getCallOpResultTypes(calleeType),
1355-
TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1356-
nullptr, normal, unwind);
1360+
auto varArgCalleeType =
1361+
calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
1362+
build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
1363+
callee, ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
13571364
}
13581365

13591366
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1400,11 @@ LogicalResult InvokeOp::verify() {
13931400
if (getNumResults() > 1)
13941401
return emitOpError("must have 0 or 1 result");
13951402

1403+
// If the callee type attribute is present, it must be variadic.
1404+
if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
1405+
if (!calleeType->isVarArg())
1406+
return emitOpError("expected variadic callee type attribute");
1407+
13961408
Block *unwindDest = getUnwindDest();
13971409
if (unwindDest->empty())
13981410
return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1421,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
14091421
auto callee = getCallee();
14101422
bool isDirect = callee.has_value();
14111423

1412-
LLVMFunctionType calleeType;
1413-
bool isVarArg = false;
1414-
1415-
if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1416-
calleeType = *optionalCalleeType;
1417-
isVarArg = calleeType.isVarArg();
1418-
}
1419-
14201424
p << ' ';
14211425

14221426
// Print calling convention.
@@ -1435,8 +1439,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
14351439
p << " unwind ";
14361440
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
14371441

1438-
if (isVarArg)
1439-
p << " vararg(" << calleeType << ")";
1442+
// Print the callee type if the invoke is variadic.
1443+
if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
1444+
p << " vararg(" << *calleeType << ")";
14401445

14411446
p.printOptionalAttrDict((*this)->getAttrs(),
14421447
{InvokeOp::getOperandSegmentSizeAttr(), "callee",

mlir/test/Dialect/LLVMIR/invalid.mlir

+27-4
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>) {
14151415

14161416
// -----
14171417

1418+
llvm.func @non_variadic(%arg: i32)
1419+
1420+
llvm.func @invalid_callee_type(%arg: i32) {
1421+
// expected-error@below {{expected variadic callee type attribute}}
1422+
llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
1423+
llvm.return
1424+
}
1425+
1426+
// -----
1427+
1428+
llvm.func @non_variadic(%arg: i32)
1429+
1430+
llvm.func @invalid_callee_type(%arg: i32) {
1431+
// expected-error@below {{expected variadic callee type attribute}}
1432+
llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
1433+
^bb1:
1434+
llvm.return
1435+
^bb2:
1436+
llvm.return
1437+
}
1438+
1439+
// -----
1440+
14181441
llvm.func @variadic(...)
14191442

14201443
llvm.func @invalid_variadic_call(%arg: i32) {
@@ -1445,22 +1468,22 @@ llvm.func @foo(%arg: !llvm.ptr) {
14451468

14461469
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14471470
// expected-error@+1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
1448-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
1471+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
14491472
return
14501473
}
14511474
// -----
14521475

14531476
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14541477
// expected-error@+1 {{im2col offsets must be 2 less than number of coordinates}}
1455-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
1478+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
14561479
return
14571480
}
14581481

14591482
// -----
14601483

14611484
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14621485
// expected-error@+1 {{expects coordinates between 1 to 5 dimension}}
1463-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
1486+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
14641487
return
14651488
}
14661489

@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
14691492

14701493
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
14711494
// expected-error@+1 {{expects coordinates between 1 to 5 dimension}}
1472-
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
1495+
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
14731496
return
14741497
}
14751498

0 commit comments

Comments
 (0)