Skip to content

[MLIR][LLVM] Always print variadic callee type #99293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024

Conversation

gysit
Copy link
Contributor

@gysit gysit commented Jul 17, 2024

This commit updates the LLVM dialect CallOp and InvokeOp to always print the variadic callee type (previously callee type) if present. An additional verifier checks that only variadic calls have a non-null variadic callee type, and the builders are adapted accordingly to set the variadic callee type for variadic calls only. Finally, the CallOp and InvokeOp verifiers are strengthened to check that the variadic callee type matches the call argument and result types.

The motivation of this change is that CallOp and InvokeOp don't have hidden state that is not pretty printed, but used during the export to LLVM IR. Previously, it could happen that a call looked correct in MLIR, but the return type changed after exporting to LLVM IR (since it has been taken from the hidden callee type attribute). After landing this change, this is not possible anymore since the variadic callee type is always printed if present.

@gysit gysit force-pushed the avoid-hidden-callee-type branch from 7e25f4d to 5ff39fb Compare July 17, 2024 12:42
@yoni-lavi
Copy link
Contributor

when we verify, shouldn't we still need to check the return type + all non-variadic args in the call agree with varCalleeType if it is non-null?

@gysit gysit force-pushed the avoid-hidden-callee-type branch 2 times, most recently from 860bfac to fc8bae8 Compare July 18, 2024 08:57
@gysit
Copy link
Contributor Author

gysit commented Jul 18, 2024

when we verify, shouldn't we still need to check the return type + all non-variadic args in the call agree with varCalleeType if it is non-null?

I added a verifier for the variadic callee type that compares with the call result and arguments.

@gysit gysit marked this pull request as ready for review July 18, 2024 18:09
@gysit gysit requested a review from Dinistro July 18, 2024 18:09
@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Tobias Gysi (gysit)

Changes

This commit updates the LLVM dialect CallOp and InvokeOp to always print the variadic callee type (previously callee type) if present. An additional verifier checks that only variadic calls have a non-null variadic callee type, and the builders are adapted accordingly to set the variadic callee type for variadic calls only. Finally, the CallOp and InvokeOp verifiers are strengthened to check that the variadic callee type matches the call argument and result types.

The motivation of this change is that CallOp and InvokeOp don't have hidden state that is not pretty printed, but used during the export to LLVM IR. Previously, it could happen that a call looked correct in MLIR, but the return type changed after exporting to LLVM IR (since it has been taken from the hidden callee type attribute). After landing this change, this is not possible anymore since the variadic callee type is always printed if present.


Patch is 24.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99293.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+9-8)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+92-51)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+159-6)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 06656c791c594..d2d1fbaf304b2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -560,14 +560,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       DeclareOpInterfaceMethods<BranchWeightOpInterface>,
                       Terminator]> {
   let arguments = (ins
-                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                    OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
                    DefaultValuedAttr<CConv, "CConv::C">:$CConv);
-  let results = (outs Variadic<LLVM_Type>);
+  let results = (outs Optional<LLVM_Type>:$result);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
 
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     start with a function name (`@`-prefixed) and indirect calls start with an
     SSA value (`%`-prefixed). The direct callee, if present, is stored as a
     function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
-    and is stored as the first value in `callee_operands`. If the callee is a variadic
-    function, then the `callee_type` attribute must carry the function type. The
-    trailing type list contains the optional indirect callee type and the MLIR
-    function type, which differs from the LLVM function type that uses a explicit
-    void type to model functions that do not return a value.
+    and is stored as the first value in `callee_operands`. If and only if the
+    callee is a variadic function, the `var_callee_type` attribute must carry
+    the variadic LLVM function type. The trailing type list contains the
+    optional indirect callee type and the MLIR function type, which differs from
+    the LLVM function type that uses an explicit void type to model functions
+    that do not return a value.
 
     Examples:
 
@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..b572b79d089a6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
   return results;
 }
 
+/// Gets the variadic callee type for a LLVMFunctionType.
+static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
+  return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+}
+
 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
                                         ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
-        callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+        getCallOpVarCalleeType(calleeType), callee, args,
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+        getCallOpVarCalleeType(calleeType),
+        /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
   return success();
 }
 
+/// Verify that the parameter and return types of the variadic callee type match
+/// the `callOp` argument and result types.
+template <typename OpTy>
+LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
+  std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
+  if (!varCalleeType)
+    return success();
+
+  // Verify the variadic callee type is a variadic function type.
+  if (!varCalleeType->isVarArg())
+    return callOp.emitOpError(
+        "expected var_callee_type to be a variadic function type");
+
+  // Verify the variadic callee type has at most as many parameters as the call
+  // has argument operands.
+  if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
+    return callOp.emitOpError("expected var_callee_type to have at most ")
+           << callOp.getArgOperands().size() << " parameters";
+
+  // Verify the variadic callee type matches the call argument types.
+  for (auto [paramType, operand] :
+       llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
+    if (paramType != operand.getType())
+      return callOp.emitOpError()
+             << "var_callee_type parameter type mismatch: " << paramType
+             << " != " << operand.getType();
+
+  // Verify the variadic callee type matches the call result type.
+  if (!callOp.getNumResults()) {
+    if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
+      return callOp.emitOpError("expected var_callee_type to return void");
+  } else {
+    if (callOp.getResult().getType() != varCalleeType->getReturnType())
+      return callOp.emitOpError("var_callee_type return type mismatch: ")
+             << varCalleeType->getReturnType()
+             << " != " << callOp.getResult().getType();
+  }
+  return success();
+}
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  if (getNumResults() > 1)
-    return emitOpError("must have 0 or 1 result");
+  if (failed(verifyCallOpVarCalleeType(*this)))
+    return failure();
 
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
@@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (!funcType)
     return emitOpError("callee does not have a functional type: ") << fnType;
 
-  if (funcType.isVarArg() && !getCalleeType())
-    return emitOpError() << "missing callee type attribute for vararg call";
+  if (funcType.isVarArg() && !getVarCalleeType())
+    return emitOpError() << "missing var_callee_type attribute for vararg call";
 
   // Verify that the operand and result types match the callee.
 
@@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the call is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {getCConvAttrName(), "callee", "callee_type",
-                           getTailCallKindAttrName()});
+                          {getCalleeAttrName(), getTailCallKindAttrName(),
+                           getVarCalleeTypeAttrName(), getCConvAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(
 
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
-//                             ( `vararg(` var-arg-func-type `)` )?
+//                             ( `vararg(` var-callee-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        CallOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType CallOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType;
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
@@ -1334,8 +1378,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                      Block *unwind, ValueRange unwindOps) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
+        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,8 +1387,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+        nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1352,8 +1396,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
+        nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
 }
 
 LogicalResult InvokeOp::verify() {
-  if (getNumResults() > 1)
-    return emitOpError("must have 0 or 1 result");
+  if (failed(verifyCallOpVarCalleeType(*this)))
+    return failure();
 
   Block *unwindDest = getUnwindDest();
   if (unwindDest->empty())
@@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the invoke is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict((*this)->getAttrs(),
-                          {InvokeOp::getOperandSegmentSizeAttr(), "callee",
-                           "callee_type", InvokeOp::getCConvAttrName()});
+                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
+                           getCConvAttrName(), getVarCalleeTypeAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
-//                  ( `vararg(` var-arg-func-type `)` )?
+//                  ( `vararg(` var-callee-type `)` )?
 //                  attribute-dict? `:` (type `,`)? function-type
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        InvokeOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType;
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..fe288dab973f5 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,10 +1415,163 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
 
 // -----
 
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_var_callee_type(%arg: i32)  {
+  // expected-error@below {{expected var_callee_type to be a variadic function type}}
+  llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_num_parameters(%arg: i32)  {
+  // expected-error@below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<void (i32, i64, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_num_parameters_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error@below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.call %callee(%arg) vararg(!llvm.func<void (i32, i64, ...)>) : !llvm.ptr, (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch(%arg: i32)  {
+  // expected-error@below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<void (i64, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error@below {{var_callee_type parameter type mismatch: 'i64' != 'i32'}}
+  llvm.call %callee(%arg) vararg(!llvm.func<void (i64, ...)>) : !llvm.ptr, (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_non_void(%arg: i32)  {
+  // expected-error@below {{expected var_callee_type to return void}}
+  llvm.call @variadic(%arg) vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...) -> i32
+
+llvm.func @invalid_var_callee_type_return_type_mismatch(%arg: i32)  {
+  // expected-error@below {{var_callee_type return type mismatch: 'i8' != 'i32'}}
+  %0 = llvm.call @variadic(%arg) vararg(!llvm.func<i8 (i32, ...)>) : (i32) -> (i32)
+  llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_var_callee_type(%arg: i32)  {
+  // expected-error@below {{expected var_callee_type to be a variadic function type}}
+  llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_num_parameters(%arg: i32)  {
+  // expected-error@below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.invoke @variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32, i64, ...)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_var_callee_type_num_parameters_indirect(%callee : !llvm.ptr, %arg: i32)  {
+  // expected-error@below {{expected var_callee_type to have at most 1 parameters}}
+  llvm.invoke %callee(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32, i64, ...)>) : !llvm.ptr, (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
+llvm.func @variadic(%arg: i32, ...)
+
+llvm.func @invalid_var_callee_type_parameter_type_mismatch(%arg: i32)  {
+  // expected-error@below {{va...
[truncated]

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the fix.

This commit updates the LLVM dialect CallOp and InvokeOp to always print
the variadic callee type (previously callee type) if present. An
additional verifier checks that only variadic calls have a non-null
variadic callee type, and the builders are adapted accordingly to
set the variadic callee type for variadic calls only. Finally, the
CallOp and InvokeOp verifiers are strengthened to check that the
variadic callee type matches the call argument and result types.

The motivation of this change is that CallOp and InvokeOp don't have
hidden state that is not pretty printed, but used during the export to
LLVM IR. Previously, it could happen that a call looked correct in MLIR,
but the return type changed after exporting to LLVM IR (since it has
been taken from the hidden callee type attribute). After landing this
change, this is not possible anymore since the variadic callee type is
always printed if present.
@gysit gysit force-pushed the avoid-hidden-callee-type branch from fc8bae8 to 6b774e0 Compare July 23, 2024 05:55
@gysit gysit merged commit 5da4310 into llvm:main Jul 23, 2024
7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
This commit updates the LLVM dialect CallOp and InvokeOp to always print
the variadic callee type (previously callee type) if present. An
additional verifier checks that only variadic calls have a non-null
variadic callee type, and the builders are adapted accordingly to set
the variadic callee type for variadic calls only. Finally, the CallOp
and InvokeOp verifiers are strengthened to check that the variadic
callee type matches the call argument and result types.

The motivation of this change is that CallOp and InvokeOp don't have
hidden state that is not pretty printed, but used during the export to
LLVM IR. Previously, it could happen that a call looked correct in MLIR,
but the return type changed after exporting to LLVM IR (since it has
been taken from the hidden callee type attribute). After landing this
change, this is not possible anymore since the variadic callee type is
always printed if present.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251037
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants