Skip to content

[MLIR][LLVM] Always print variadic callee type #99293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -560,14 +560,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
let arguments = (ins
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
let results = (outs Variadic<LLVM_Type>);
let results = (outs Optional<LLVM_Type>:$result);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);

Expand Down Expand Up @@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
start with a function name (`@`-prefixed) and indirect calls start with an
SSA value (`%`-prefixed). The direct callee, if present, is stored as a
function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
and is stored as the first value in `callee_operands`. If the callee is a variadic
function, then the `callee_type` attribute must carry the function type. The
trailing type list contains the optional indirect callee type and the MLIR
function type, which differs from the LLVM function type that uses a explicit
void type to model functions that do not return a value.
and is stored as the first value in `callee_operands`. If and only if the
callee is a variadic function, the `var_callee_type` attribute must carry
the variadic LLVM function type. The trailing type list contains the
optional indirect callee type and the MLIR function type, which differs from
the LLVM function type that uses an explicit void type to model functions
that do not return a value.

Examples:

Expand All @@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
```
}];

dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
Expand Down
143 changes: 92 additions & 51 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
return results;
}

/// Gets the variadic callee type for a LLVMFunctionType.
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
}

/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
ValueRange args) {
Expand All @@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
Expand All @@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
getCallOpVarCalleeType(calleeType), callee, args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
Expand All @@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
getCallOpVarCalleeType(calleeType),
/*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
Expand All @@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
Expand Down Expand Up @@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
return success();
}

/// Verify that the parameter and return types of the variadic callee type match
/// the `callOp` argument and result types.
template <typename OpTy>
LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
if (!varCalleeType)
return success();

// Verify the variadic callee type is a variadic function type.
if (!varCalleeType->isVarArg())
return callOp.emitOpError(
"expected var_callee_type to be a variadic function type");

// Verify the variadic callee type has at most as many parameters as the call
// has argument operands.
if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
return callOp.emitOpError("expected var_callee_type to have at most ")
<< callOp.getArgOperands().size() << " parameters";

// Verify the variadic callee type matches the call argument types.
for (auto [paramType, operand] :
llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
if (paramType != operand.getType())
return callOp.emitOpError()
<< "var_callee_type parameter type mismatch: " << paramType
<< " != " << operand.getType();

// Verify the variadic callee type matches the call result type.
if (!callOp.getNumResults()) {
if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
return callOp.emitOpError("expected var_callee_type to return void");
} else {
if (callOp.getResult().getType() != varCalleeType->getReturnType())
return callOp.emitOpError("var_callee_type return type mismatch: ")
<< varCalleeType->getReturnType()
<< " != " << callOp.getResult().getType();
}
return success();
}

LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
if (failed(verifyCallOpVarCalleeType(*this)))
return failure();

// Type for the callee, we'll get it differently depending if it is a direct
// or indirect call.
Expand Down Expand Up @@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (!funcType)
return emitOpError("callee does not have a functional type: ") << fnType;

if (funcType.isVarArg() && !getCalleeType())
return emitOpError() << "missing callee type attribute for vararg call";
if (funcType.isVarArg() && !getVarCalleeType())
return emitOpError() << "missing var_callee_type attribute for vararg call";

// Verify that the operand and result types match the callee.

Expand Down Expand Up @@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();

LLVMFunctionType calleeType;
bool isVarArg = false;

if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
calleeType = *optionalCalleeType;
isVarArg = calleeType.isVarArg();
}

p << ' ';

// Print calling convention.
Expand All @@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
auto args = getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';

if (isVarArg)
p << " vararg(" << calleeType << ")";
// Print the variadic callee type if the call is variadic.
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
p << " vararg(" << *varCalleeType << ")";

p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
{getCConvAttrName(), "callee", "callee_type",
getTailCallKindAttrName()});
{getCalleeAttrName(), getTailCallKindAttrName(),
getVarCalleeTypeAttrName(), getCConvAttrName()});

p << " : ";
if (!isDirect)
Expand Down Expand Up @@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(

// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
// `(` ssa-use-list `)`
// ( `vararg(` var-arg-func-type `)` )?
// ( `vararg(` var-callee-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
SymbolRefAttr funcAttr;
TypeAttr calleeType;
TypeAttr varCalleeType;
SmallVector<OpAsmParser::UnresolvedOperand> operands;

// Default to C Calling Convention if no keyword is provided.
Expand Down Expand Up @@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {

bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
if (isVarArg) {
StringAttr varCalleeTypeAttrName =
CallOp::getVarCalleeTypeAttrName(result.name);
if (parser.parseLParen().failed() ||
parser.parseAttribute(calleeType, "callee_type", result.attributes)
parser
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
result.attributes)
.failed() ||
parser.parseRParen().failed())
return failure();
Expand All @@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
}

LLVMFunctionType CallOp::getCalleeFunctionType() {
if (getCalleeType())
return *getCalleeType();
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
return *varCalleeType;
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
}

Expand All @@ -1334,26 +1378,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
Block *unwind, ValueRange unwindOps) {
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
unwindOps, nullptr, nullptr, normal, unwind);
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
normalOps, unwindOps, nullptr, nullptr, normal, unwind);
}

void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
ValueRange normalOps, Block *unwind,
ValueRange unwindOps) {
build(builder, state, tys,
TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
/*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
nullptr, normal, unwind);
}

void InvokeOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
ValueRange ops, Block *normal, ValueRange normalOps,
Block *unwind, ValueRange unwindOps) {
build(builder, state, getCallOpResultTypes(calleeType),
TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
nullptr, normal, unwind);
getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
nullptr, nullptr, normal, unwind);
}

SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
Expand Down Expand Up @@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
}

LogicalResult InvokeOp::verify() {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
if (failed(verifyCallOpVarCalleeType(*this)))
return failure();

Block *unwindDest = getUnwindDest();
if (unwindDest->empty())
Expand All @@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
auto callee = getCallee();
bool isDirect = callee.has_value();

LLVMFunctionType calleeType;
bool isVarArg = false;

if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
calleeType = *optionalCalleeType;
isVarArg = calleeType.isVarArg();
}

p << ' ';

// Print calling convention.
Expand All @@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
p << " unwind ";
p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());

if (isVarArg)
p << " vararg(" << calleeType << ")";
// Print the variadic callee type if the invoke is variadic.
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
p << " vararg(" << *varCalleeType << ")";

p.printOptionalAttrDict((*this)->getAttrs(),
{InvokeOp::getOperandSegmentSizeAttr(), "callee",
"callee_type", InvokeOp::getCConvAttrName()});
{getCalleeAttrName(), getOperandSegmentSizeAttr(),
getCConvAttrName(), getVarCalleeTypeAttrName()});

p << " : ";
if (!isDirect)
Expand All @@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
// `(` ssa-use-list `)`
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
// ( `vararg(` var-arg-func-type `)` )?
// ( `vararg(` var-callee-type `)` )?
// attribute-dict? `:` (type `,`)? function-type
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
SymbolRefAttr funcAttr;
TypeAttr calleeType;
TypeAttr varCalleeType;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();
Expand Down Expand Up @@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {

bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
if (isVarArg) {
StringAttr varCalleeTypeAttrName =
InvokeOp::getVarCalleeTypeAttrName(result.name);
if (parser.parseLParen().failed() ||
parser.parseAttribute(calleeType, "callee_type", result.attributes)
parser
.parseAttribute(varCalleeType, varCalleeTypeAttrName,
result.attributes)
.failed() ||
parser.parseRParen().failed())
return failure();
Expand All @@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
}

LLVMFunctionType InvokeOp::getCalleeFunctionType() {
if (getCalleeType())
return *getCalleeType();
if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
return *varCalleeType;
return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
}

Expand Down
Loading
Loading