Skip to content

Commit 4c65ba2

Browse files
committed
lower to llvm
Signed-off-by: chentong319 <[email protected]>
1 parent 3c2ef1a commit 4c65ba2

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

src/Conversion/KrnlToLLVM/KrnlCall.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,27 @@ class KrnlCallOpLowering : public ConversionPattern {
6868
rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList);
6969
}
7070

71-
FlatSymbolRefAttr callRef =
72-
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
73-
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
74-
create.llvm.call({}, callRef, parameterList);
71+
ValueRange returns = op->getResults();
72+
if (returns.size() == 0) {
73+
// There is no return
74+
FlatSymbolRefAttr callRef =
75+
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
76+
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
77+
create.llvm.call({}, callRef, parameterList);
78+
79+
rewriter.eraseOp(op);
80+
} else {
81+
assert(returns.size() == 1 &&
82+
"Only one return value is allowed for krnl.call now");
83+
Type llvmReturnType =
84+
llvmTypeConverter->convertType(returns[0].getType());
85+
86+
FlatSymbolRefAttr callRef = create.llvm.getOrInsertSymbolRef(
87+
module, krnlCallOp.getFuncName(), llvmReturnType, parameterTypeList);
88+
auto llvmCall =
89+
create.llvm.call({llvmReturnType}, callRef, parameterList);
90+
rewriter.replaceOp(op, llvmCall.getDefiningOp()->getResults()[0]);
91+
}
7592

7693
// Destroy OMTensor wrappers of parameters.
7794
const auto &apiRegistry =
@@ -81,7 +98,6 @@ class KrnlCallOpLowering : public ConversionPattern {
8198
rewriter, loc, apiRegistry, RuntimeAPI::API::DESTROY_OMTENSOR, {omt});
8299
}
83100

84-
rewriter.eraseOp(op);
85101
return success();
86102
}
87103

src/Dialect/Krnl/Krnl.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def KrnlCallOp : Op<Krnl_Dialect, "call",
9494
// Only scalar type is supported now.
9595
// In future, return of memref can be supported with pointer of OMTensor.
9696
// The returned memref will be created inside the call.
97-
let results = (outs Variadic<AnyTypeOf<[F32, F64, I32, I64]>>:$returnValue);
97+
let results = (outs Variadic<AnyTypeOf<[AnyFloat, AnyInteger]>>:$returnValue);
9898

9999
// builders to build KrnlCallOp from op and operands, helping conversion from
100100
// onnx to krnl.

src/Dialect/Krnl/KrnlOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,12 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
156156

157157
void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
158158
std::string funcName, int64_t numOfOutput, ValueRange operands) {
159-
build(builder, odsState, {}, funcName, numOfOutput, operands);
159+
build(builder, odsState, {}, funcName, numOfOutput, operands);
160160
}
161161

162162
void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
163163
StringAttr funcName, IntegerAttr numOfOutput, ValueRange operands) {
164-
build(builder, odsState, {}, funcName, numOfOutput, operands);
164+
build(builder, odsState, {}, funcName, numOfOutput, operands);
165165
}
166166

167167
void KrnlCallOp::getEffects(

0 commit comments

Comments
 (0)