@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948
948
return results;
949
949
}
950
950
951
+ // / Gets the variadic callee type for a LLVMFunctionType.
952
+ static TypeAttr getCallOpVarCalleeType (LLVMFunctionType calleeType) {
953
+ return calleeType.isVarArg () ? TypeAttr::get (calleeType) : nullptr ;
954
+ }
955
+
951
956
// / Constructs a LLVMFunctionType from MLIR `results` and `args`.
952
957
static LLVMFunctionType getLLVMFuncType (MLIRContext *context, TypeRange results,
953
958
ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974
979
FlatSymbolRefAttr callee, ValueRange args) {
975
980
assert (callee && " expected non-null callee in direct call builder" );
976
981
build (builder, state, results,
977
- TypeAttr::get ( getLLVMFuncType (builder. getContext (), results , args)) ,
978
- callee, args, /* fastmathFlags= */ nullptr , /* branch_weights=*/ nullptr ,
982
+ /* var_callee_type= */ nullptr , callee , args, /* fastmathFlags= */ nullptr ,
983
+ /* branch_weights=*/ nullptr ,
979
984
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
980
985
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
981
986
/* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
997
1002
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
998
1003
ValueRange args) {
999
1004
build (builder, state, getCallOpResultTypes (calleeType),
1000
- TypeAttr::get (calleeType), callee, args, /* fastmathFlags=*/ nullptr ,
1005
+ getCallOpVarCalleeType (calleeType), callee, args,
1006
+ /* fastmathFlags=*/ nullptr ,
1001
1007
/* branch_weights=*/ nullptr , /* CConv=*/ nullptr ,
1002
1008
/* TailCallKind=*/ nullptr , /* access_groups=*/ nullptr ,
1003
1009
/* alias_scopes=*/ nullptr , /* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
1006
1012
void CallOp::build (OpBuilder &builder, OperationState &state,
1007
1013
LLVMFunctionType calleeType, ValueRange args) {
1008
1014
build (builder, state, getCallOpResultTypes (calleeType),
1009
- TypeAttr::get (calleeType), /* callee=*/ nullptr , args,
1015
+ getCallOpVarCalleeType (calleeType),
1016
+ /* callee=*/ nullptr , args,
1010
1017
/* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
1011
1018
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
1012
1019
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1017
1024
ValueRange args) {
1018
1025
auto calleeType = func.getFunctionType ();
1019
1026
build (builder, state, getCallOpResultTypes (calleeType),
1020
- TypeAttr::get (calleeType), SymbolRefAttr::get (func), args,
1027
+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), args,
1021
1028
/* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
1022
1029
/* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
1023
1030
/* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1080,6 +1087,12 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1080
1087
if (getNumResults () > 1 )
1081
1088
return emitOpError (" must have 0 or 1 result" );
1082
1089
1090
+ // Verify the variadic callee type is a variadic function type.
1091
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1092
+ if (!varCalleeType->isVarArg ())
1093
+ return emitOpError (
1094
+ " expected var_callee_type to be a variadic function type" );
1095
+
1083
1096
// Type for the callee, we'll get it differently depending if it is a direct
1084
1097
// or indirect call.
1085
1098
Type fnType;
@@ -1120,7 +1133,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1120
1133
if (!funcType)
1121
1134
return emitOpError (" callee does not have a functional type: " ) << fnType;
1122
1135
1123
- if (funcType.isVarArg () && !getCalleeType ())
1136
+ if (funcType.isVarArg () && !getVarCalleeType ())
1124
1137
return emitOpError () << " missing callee type attribute for vararg call" ;
1125
1138
1126
1139
// Verify that the operand and result types match the callee.
@@ -1168,14 +1181,6 @@ void CallOp::print(OpAsmPrinter &p) {
1168
1181
auto callee = getCallee ();
1169
1182
bool isDirect = callee.has_value ();
1170
1183
1171
- LLVMFunctionType calleeType;
1172
- bool isVarArg = false ;
1173
-
1174
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1175
- calleeType = *optionalCalleeType;
1176
- isVarArg = calleeType.isVarArg ();
1177
- }
1178
-
1179
1184
p << ' ' ;
1180
1185
1181
1186
// Print calling convention.
@@ -1195,11 +1200,13 @@ void CallOp::print(OpAsmPrinter &p) {
1195
1200
auto args = getOperands ().drop_front (isDirect ? 0 : 1 );
1196
1201
p << ' (' << args << ' )' ;
1197
1202
1198
- if (isVarArg)
1199
- p << " vararg(" << calleeType << " )" ;
1203
+ // Print the variadic callee type if the call is variadic.
1204
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1205
+ p << " vararg(" << *varCalleeType << " )" ;
1200
1206
1201
1207
p.printOptionalAttrDict (processFMFAttr ((*this )->getAttrs ()),
1202
- {getCConvAttrName (), " callee" , " callee_type" ,
1208
+ {getCConvAttrName (), " callee" ,
1209
+ getVarCalleeTypeAttrName (),
1203
1210
getTailCallKindAttrName ()});
1204
1211
1205
1212
p << " : " ;
@@ -1270,11 +1277,11 @@ static ParseResult parseOptionalCallFuncPtr(
1270
1277
1271
1278
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1272
1279
// `(` ssa-use-list `)`
1273
- // ( `vararg(` var-arg-func -type `)` )?
1280
+ // ( `vararg(` var-callee -type `)` )?
1274
1281
// attribute-dict? `:` (type `,`)? function-type
1275
1282
ParseResult CallOp::parse (OpAsmParser &parser, OperationState &result) {
1276
1283
SymbolRefAttr funcAttr;
1277
- TypeAttr calleeType ;
1284
+ TypeAttr varCalleeType ;
1278
1285
SmallVector<OpAsmParser::UnresolvedOperand> operands;
1279
1286
1280
1287
// Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1312,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1305
1312
1306
1313
bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
1307
1314
if (isVarArg) {
1315
+ StringAttr varCalleeTypeAttrName =
1316
+ CallOp::getVarCalleeTypeAttrName (result.name );
1308
1317
if (parser.parseLParen ().failed () ||
1309
- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1318
+ parser
1319
+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1320
+ result.attributes )
1310
1321
.failed () ||
1311
1322
parser.parseRParen ().failed ())
1312
1323
return failure ();
@@ -1320,8 +1331,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1320
1331
}
1321
1332
1322
1333
LLVMFunctionType CallOp::getCalleeFunctionType () {
1323
- if (getCalleeType ())
1324
- return *getCalleeType () ;
1334
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1335
+ return *varCalleeType ;
1325
1336
return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
1326
1337
}
1327
1338
@@ -1334,26 +1345,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1334
1345
Block *unwind, ValueRange unwindOps) {
1335
1346
auto calleeType = func.getFunctionType ();
1336
1347
build (builder, state, getCallOpResultTypes (calleeType),
1337
- TypeAttr::get (calleeType), SymbolRefAttr::get (func), ops, normalOps ,
1338
- unwindOps, nullptr , nullptr , normal , unwind);
1348
+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), ops,
1349
+ normalOps, unwindOps, nullptr , nullptr , normal , unwind);
1339
1350
}
1340
1351
1341
1352
void InvokeOp::build (OpBuilder &builder, OperationState &state, TypeRange tys,
1342
1353
FlatSymbolRefAttr callee, ValueRange ops, Block *normal ,
1343
1354
ValueRange normalOps, Block *unwind,
1344
1355
ValueRange unwindOps) {
1345
1356
build (builder, state, tys,
1346
- TypeAttr::get ( getLLVMFuncType (builder. getContext (), tys , ops)), callee ,
1347
- ops, normalOps, unwindOps, nullptr , nullptr , normal , unwind);
1357
+ /* var_callee_type= */ nullptr , callee , ops, normalOps, unwindOps, nullptr ,
1358
+ nullptr , normal , unwind);
1348
1359
}
1349
1360
1350
1361
void InvokeOp::build (OpBuilder &builder, OperationState &state,
1351
1362
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1352
1363
ValueRange ops, Block *normal , ValueRange normalOps,
1353
1364
Block *unwind, ValueRange unwindOps) {
1354
1365
build (builder, state, getCallOpResultTypes (calleeType),
1355
- TypeAttr::get (calleeType), callee, ops, normalOps, unwindOps, nullptr ,
1356
- nullptr , normal , unwind);
1366
+ getCallOpVarCalleeType (calleeType), callee, ops, normalOps, unwindOps,
1367
+ nullptr , nullptr , normal , unwind);
1357
1368
}
1358
1369
1359
1370
SuccessorOperands InvokeOp::getSuccessorOperands (unsigned index) {
@@ -1393,6 +1404,12 @@ LogicalResult InvokeOp::verify() {
1393
1404
if (getNumResults () > 1 )
1394
1405
return emitOpError (" must have 0 or 1 result" );
1395
1406
1407
+ // Verify the variadic callee type is a variadic function type.
1408
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1409
+ if (!varCalleeType->isVarArg ())
1410
+ return emitOpError (
1411
+ " expected var_callee_type to be a variadic function type" );
1412
+
1396
1413
Block *unwindDest = getUnwindDest ();
1397
1414
if (unwindDest->empty ())
1398
1415
return emitError (" must have at least one operation in unwind destination" );
@@ -1409,14 +1426,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
1409
1426
auto callee = getCallee ();
1410
1427
bool isDirect = callee.has_value ();
1411
1428
1412
- LLVMFunctionType calleeType;
1413
- bool isVarArg = false ;
1414
-
1415
- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1416
- calleeType = *optionalCalleeType;
1417
- isVarArg = calleeType.isVarArg ();
1418
- }
1419
-
1420
1429
p << ' ' ;
1421
1430
1422
1431
// Print calling convention.
@@ -1435,12 +1444,14 @@ void InvokeOp::print(OpAsmPrinter &p) {
1435
1444
p << " unwind " ;
1436
1445
p.printSuccessorAndUseList (getUnwindDest (), getUnwindDestOperands ());
1437
1446
1438
- if (isVarArg)
1439
- p << " vararg(" << calleeType << " )" ;
1447
+ // Print the variadic callee type if the invoke is variadic.
1448
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1449
+ p << " vararg(" << *varCalleeType << " )" ;
1440
1450
1441
1451
p.printOptionalAttrDict ((*this )->getAttrs (),
1442
1452
{InvokeOp::getOperandSegmentSizeAttr (), " callee" ,
1443
- " callee_type" , InvokeOp::getCConvAttrName ()});
1453
+ InvokeOp::getVarCalleeTypeAttrName (),
1454
+ InvokeOp::getCConvAttrName ()});
1444
1455
1445
1456
p << " : " ;
1446
1457
if (!isDirect)
@@ -1453,12 +1464,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
1453
1464
// `(` ssa-use-list `)`
1454
1465
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1455
1466
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1456
- // ( `vararg(` var-arg-func -type `)` )?
1467
+ // ( `vararg(` var-callee -type `)` )?
1457
1468
// attribute-dict? `:` (type `,`)? function-type
1458
1469
ParseResult InvokeOp::parse (OpAsmParser &parser, OperationState &result) {
1459
1470
SmallVector<OpAsmParser::UnresolvedOperand, 8 > operands;
1460
1471
SymbolRefAttr funcAttr;
1461
- TypeAttr calleeType ;
1472
+ TypeAttr varCalleeType ;
1462
1473
Block *normalDest, *unwindDest;
1463
1474
SmallVector<Value, 4 > normalOperands, unwindOperands;
1464
1475
Builder &builder = parser.getBuilder ();
@@ -1488,8 +1499,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1488
1499
1489
1500
bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
1490
1501
if (isVarArg) {
1502
+ StringAttr varCalleeTypeAttrName =
1503
+ InvokeOp::getVarCalleeTypeAttrName (result.name );
1491
1504
if (parser.parseLParen ().failed () ||
1492
- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1505
+ parser
1506
+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1507
+ result.attributes )
1493
1508
.failed () ||
1494
1509
parser.parseRParen ().failed ())
1495
1510
return failure ();
@@ -1515,8 +1530,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1515
1530
}
1516
1531
1517
1532
LLVMFunctionType InvokeOp::getCalleeFunctionType () {
1518
- if (getCalleeType ())
1519
- return *getCalleeType () ;
1533
+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1534
+ return *varCalleeType ;
1520
1535
return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
1521
1536
}
1522
1537
0 commit comments