@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948
948
return results;
949
949
}
950
950
951
+ // / Gets the variadic callee type for a LLVMFunctionType.
952
+ static TypeAttr getCallOpVarCalleeType (LLVMFunctionType calleeType) {
953
+ return calleeType.isVarArg () ? TypeAttr::get (calleeType) : nullptr ;
954
+ }
955
+
951
956
// / Constructs a LLVMFunctionType from MLIR `results` and `args`.
952
957
static LLVMFunctionType getLLVMFuncType (MLIRContext *context, TypeRange results,
953
958
ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974
979
FlatSymbolRefAttr callee, ValueRange args) {
975
980
assert (callee && " expected non-null callee in direct call builder" );
976
981
build (builder, state, results,
977
- TypeAttr::get ( getLLVMFuncType (builder. getContext (), results , args)) ,
978
- callee, args, /* fastmathFlags= */ nullptr , /* branch_weights=*/ nullptr ,
982
+ /* var_callee_type= */ nullptr , callee , args, /* fastmathFlags= */ nullptr ,
983
+ /* branch_weights=*/ nullptr ,
979
984
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
980
985
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
981
986
/* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
997
1002
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
998
1003
ValueRange args) {
999
1004
build (builder, state, getCallOpResultTypes (calleeType),
1000
- TypeAttr::get (calleeType), callee, args, /* fastmathFlags=*/ nullptr ,
1005
+ getCallOpVarCalleeType (calleeType), callee, args,
1006
+ /* fastmathFlags=*/ nullptr ,
1001
1007
/* branch_weights=*/ nullptr , /* CConv=*/ nullptr ,
1002
1008
/* TailCallKind=*/ nullptr , /* access_groups=*/ nullptr ,
1003
1009
/* alias_scopes=*/ nullptr , /* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
1006
1012
void CallOp::build (OpBuilder &builder, OperationState &state,
1007
1013
LLVMFunctionType calleeType, ValueRange args) {
1008
1014
build (builder, state, getCallOpResultTypes (calleeType),
1009
- TypeAttr::get (calleeType), /* callee=*/ nullptr , args,
1015
+ getCallOpVarCalleeType (calleeType),
1016
+ /* callee=*/ nullptr , args,
1010
1017
/* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
1011
1018
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
1012
1019
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1017
1024
ValueRange args) {
1018
1025
auto calleeType = func.getFunctionType ();
1019
1026
build (builder, state, getCallOpResultTypes (calleeType),
1020
- TypeAttr::get (calleeType), SymbolRefAttr::get (func), args,
1027
+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), args,
1021
1028
/* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
1022
1029
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
1023
1030
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1076
1083
return success ();
1077
1084
}
1078
1085
1086
+ // / Verify that the parameter and return types of the variadic callee type match
1087
+ // / the `callOp` argument and result types.
1088
+ template <typename OpTy>
1089
+ LogicalResult verifyCallOpVarCalleeType (OpTy callOp) {
1090
+ std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType ();
1091
+ if (!varCalleeType)
1092
+ return success ();
1093
+
1094
+ // Verify the variadic callee type is a variadic function type.
1095
+ if (!varCalleeType->isVarArg ())
1096
+ return callOp.emitOpError (
1097
+ " expected var_callee_type to be a variadic function type" );
1098
+
1099
+ // Verify the variadic callee type has at most as many parameters as the call
1100
+ // has argument operands.
1101
+ if (varCalleeType->getNumParams () > callOp.getArgOperands ().size ())
1102
+ return callOp.emitOpError (" expected var_callee_type to have at most " )
1103
+ << callOp.getArgOperands ().size () << " parameters" ;
1104
+
1105
+ // Verify the variadic callee type matches the call argument types.
1106
+ for (auto [paramType, operand] :
1107
+ llvm::zip (varCalleeType->getParams (), callOp.getArgOperands ()))
1108
+ if (paramType != operand.getType ())
1109
+ return callOp.emitOpError ()
1110
+ << " var_callee_type parameter type mismatch: " << paramType
1111
+ << " != " << operand.getType ();
1112
+
1113
+ // Verify the variadic callee type matches the call result type.
1114
+ if (!callOp.getNumResults ()) {
1115
+ if (!isa<LLVMVoidType>(varCalleeType->getReturnType ()))
1116
+ return callOp.emitOpError (" expected var_callee_type to return void" );
1117
+ } else {
1118
+ if (callOp.getResult ().getType () != varCalleeType->getReturnType ())
1119
+ return callOp.emitOpError (" var_callee_type return type mismatch: " )
1120
+ << varCalleeType->getReturnType ()
1121
+ << " != " << callOp.getResult ().getType ();
1122
+ }
1123
+ return success ();
1124
+ }
1125
+
1079
1126
LogicalResult CallOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
1080
- if (getNumResults () > 1 )
1081
- return emitOpError ( " must have 0 or 1 result " );
1127
+ if (failed ( verifyCallOpVarCalleeType (* this )) )
1128
+ return failure ( );
1082
1129
1083
1130
// Type for the callee, we'll get it differently depending if it is a direct
1084
1131
// or indirect call.
@@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1120
1167
if (!funcType)
1121
1168
return emitOpError (" callee does not have a functional type: " ) << fnType;
1122
1169
1123
- if (funcType.isVarArg () && !getCalleeType ())
1124
- return emitOpError () << " missing callee type attribute for vararg call" ;
1170
+ if (funcType.isVarArg () && !getVarCalleeType ())
1171
+ return emitOpError () << " missing var_callee_type attribute for vararg call" ;
1125
1172
1126
1173
// Verify that the operand and result types match the callee.
1127
1174
@@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
1168
1215
auto callee = getCallee ();
1169
1216
bool isDirect = callee.has_value ();
1170
1217
1171
- LLVMFunctionType calleeType;
1172
- bool isVarArg = false ;
1173
-
1174
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1175
- calleeType = *optionalCalleeType;
1176
- isVarArg = calleeType.isVarArg ();
1177
- }
1178
-
1179
1218
p << ' ' ;
1180
1219
1181
1220
// Print calling convention.
@@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
1195
1234
auto args = getOperands ().drop_front (isDirect ? 0 : 1 );
1196
1235
p << ' (' << args << ' )' ;
1197
1236
1198
- if (isVarArg)
1199
- p << " vararg(" << calleeType << " )" ;
1237
+ // Print the variadic callee type if the call is variadic.
1238
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1239
+ p << " vararg(" << *varCalleeType << " )" ;
1200
1240
1201
1241
p.printOptionalAttrDict (processFMFAttr ((*this )->getAttrs ()),
1202
- {getCConvAttrName (), " callee " , " callee_type " ,
1203
- getTailCallKindAttrName ()});
1242
+ {getCalleeAttrName (), getTailCallKindAttrName () ,
1243
+ getVarCalleeTypeAttrName (), getCConvAttrName ()});
1204
1244
1205
1245
p << " : " ;
1206
1246
if (!isDirect)
@@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(
1270
1310
1271
1311
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1272
1312
// `(` ssa-use-list `)`
1273
- // ( `vararg(` var-arg-func -type `)` )?
1313
+ // ( `vararg(` var-callee -type `)` )?
1274
1314
// attribute-dict? `:` (type `,`)? function-type
1275
1315
ParseResult CallOp::parse (OpAsmParser &parser, OperationState &result) {
1276
1316
SymbolRefAttr funcAttr;
1277
- TypeAttr calleeType ;
1317
+ TypeAttr varCalleeType ;
1278
1318
SmallVector<OpAsmParser::UnresolvedOperand> operands;
1279
1319
1280
1320
// Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1305
1345
1306
1346
bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
1307
1347
if (isVarArg) {
1348
+ StringAttr varCalleeTypeAttrName =
1349
+ CallOp::getVarCalleeTypeAttrName (result.name );
1308
1350
if (parser.parseLParen ().failed () ||
1309
- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1351
+ parser
1352
+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1353
+ result.attributes )
1310
1354
.failed () ||
1311
1355
parser.parseRParen ().failed ())
1312
1356
return failure ();
@@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1320
1364
}
1321
1365
1322
1366
LLVMFunctionType CallOp::getCalleeFunctionType () {
1323
- if (getCalleeType ())
1324
- return *getCalleeType () ;
1367
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1368
+ return *varCalleeType ;
1325
1369
return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
1326
1370
}
1327
1371
@@ -1334,26 +1378,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1334
1378
Block *unwind, ValueRange unwindOps) {
1335
1379
auto calleeType = func.getFunctionType ();
1336
1380
build (builder, state, getCallOpResultTypes (calleeType),
1337
- TypeAttr::get (calleeType), SymbolRefAttr::get (func), ops, normalOps ,
1338
- unwindOps, nullptr , nullptr , normal , unwind);
1381
+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), ops,
1382
+ normalOps, unwindOps, nullptr , nullptr , normal , unwind);
1339
1383
}
1340
1384
1341
1385
void InvokeOp::build (OpBuilder &builder, OperationState &state, TypeRange tys,
1342
1386
FlatSymbolRefAttr callee, ValueRange ops, Block *normal ,
1343
1387
ValueRange normalOps, Block *unwind,
1344
1388
ValueRange unwindOps) {
1345
1389
build (builder, state, tys,
1346
- TypeAttr::get ( getLLVMFuncType (builder. getContext (), tys , ops)), callee ,
1347
- ops, normalOps, unwindOps, nullptr , nullptr , normal , unwind);
1390
+ /* var_callee_type= */ nullptr , callee , ops, normalOps, unwindOps, nullptr ,
1391
+ nullptr , normal , unwind);
1348
1392
}
1349
1393
1350
1394
void InvokeOp::build (OpBuilder &builder, OperationState &state,
1351
1395
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1352
1396
ValueRange ops, Block *normal , ValueRange normalOps,
1353
1397
Block *unwind, ValueRange unwindOps) {
1354
1398
build (builder, state, getCallOpResultTypes (calleeType),
1355
- TypeAttr::get (calleeType), callee, ops, normalOps, unwindOps, nullptr ,
1356
- nullptr , normal , unwind);
1399
+ getCallOpVarCalleeType (calleeType), callee, ops, normalOps, unwindOps,
1400
+ nullptr , nullptr , normal , unwind);
1357
1401
}
1358
1402
1359
1403
SuccessorOperands InvokeOp::getSuccessorOperands (unsigned index) {
@@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
1390
1434
}
1391
1435
1392
1436
LogicalResult InvokeOp::verify () {
1393
- if (getNumResults () > 1 )
1394
- return emitOpError ( " must have 0 or 1 result " );
1437
+ if (failed ( verifyCallOpVarCalleeType (* this )) )
1438
+ return failure ( );
1395
1439
1396
1440
Block *unwindDest = getUnwindDest ();
1397
1441
if (unwindDest->empty ())
@@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
1409
1453
auto callee = getCallee ();
1410
1454
bool isDirect = callee.has_value ();
1411
1455
1412
- LLVMFunctionType calleeType;
1413
- bool isVarArg = false ;
1414
-
1415
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1416
- calleeType = *optionalCalleeType;
1417
- isVarArg = calleeType.isVarArg ();
1418
- }
1419
-
1420
1456
p << ' ' ;
1421
1457
1422
1458
// Print calling convention.
@@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
1435
1471
p << " unwind " ;
1436
1472
p.printSuccessorAndUseList (getUnwindDest (), getUnwindDestOperands ());
1437
1473
1438
- if (isVarArg)
1439
- p << " vararg(" << calleeType << " )" ;
1474
+ // Print the variadic callee type if the invoke is variadic.
1475
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1476
+ p << " vararg(" << *varCalleeType << " )" ;
1440
1477
1441
1478
p.printOptionalAttrDict ((*this )->getAttrs (),
1442
- {InvokeOp::getOperandSegmentSizeAttr (), " callee " ,
1443
- " callee_type " , InvokeOp::getCConvAttrName ()});
1479
+ {getCalleeAttrName (), getOperandSegmentSizeAttr () ,
1480
+ getCConvAttrName (), getVarCalleeTypeAttrName ()});
1444
1481
1445
1482
p << " : " ;
1446
1483
if (!isDirect)
@@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
1453
1490
// `(` ssa-use-list `)`
1454
1491
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1455
1492
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1456
- // ( `vararg(` var-arg-func -type `)` )?
1493
+ // ( `vararg(` var-callee -type `)` )?
1457
1494
// attribute-dict? `:` (type `,`)? function-type
1458
1495
ParseResult InvokeOp::parse (OpAsmParser &parser, OperationState &result) {
1459
1496
SmallVector<OpAsmParser::UnresolvedOperand, 8 > operands;
1460
1497
SymbolRefAttr funcAttr;
1461
- TypeAttr calleeType ;
1498
+ TypeAttr varCalleeType ;
1462
1499
Block *normalDest, *unwindDest;
1463
1500
SmallVector<Value, 4 > normalOperands, unwindOperands;
1464
1501
Builder &builder = parser.getBuilder ();
@@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1488
1525
1489
1526
bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
1490
1527
if (isVarArg) {
1528
+ StringAttr varCalleeTypeAttrName =
1529
+ InvokeOp::getVarCalleeTypeAttrName (result.name );
1491
1530
if (parser.parseLParen ().failed () ||
1492
- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1531
+ parser
1532
+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1533
+ result.attributes )
1493
1534
.failed () ||
1494
1535
parser.parseRParen ().failed ())
1495
1536
return failure ();
@@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1515
1556
}
1516
1557
1517
1558
LLVMFunctionType InvokeOp::getCalleeFunctionType () {
1518
- if (getCalleeType ())
1519
- return *getCalleeType () ;
1559
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1560
+ return *varCalleeType ;
1520
1561
return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
1521
1562
}
1522
1563
0 commit comments