Skip to content

Commit 5ea61d9

Browse files
authored
Merge branch 'main' into mem_reduction_stickified
2 parents 1d4ed1b + 087f069 commit 5ea61d9

File tree

4 files changed

+50
-5
lines changed

4 files changed

+50
-5
lines changed

src/Conversion/KrnlToLLVM/KrnlCall.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,27 @@ class KrnlCallOpLowering : public ConversionPattern {
6868
rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList);
6969
}
7070

71-
FlatSymbolRefAttr callRef =
72-
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
73-
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
74-
create.llvm.call({}, callRef, parameterList);
71+
ValueRange returns = op->getResults();
72+
if (returns.size() == 0) {
73+
// There is no return
74+
FlatSymbolRefAttr callRef =
75+
create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(),
76+
LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList);
77+
create.llvm.call({}, callRef, parameterList);
78+
79+
rewriter.eraseOp(op);
80+
} else {
81+
assert(returns.size() == 1 &&
82+
"Only one return value is allowed for krnl.call now");
83+
Type llvmReturnType =
84+
llvmTypeConverter->convertType(returns[0].getType());
85+
86+
FlatSymbolRefAttr callRef = create.llvm.getOrInsertSymbolRef(
87+
module, krnlCallOp.getFuncName(), llvmReturnType, parameterTypeList);
88+
auto llvmCall =
89+
create.llvm.call({llvmReturnType}, callRef, parameterList);
90+
rewriter.replaceOp(op, llvmCall.getDefiningOp()->getResults()[0]);
91+
}
7592

7693
// Destroy OMTensor wrappers of parameters.
7794
const auto &apiRegistry =
@@ -81,7 +98,6 @@ class KrnlCallOpLowering : public ConversionPattern {
8198
rewriter, loc, apiRegistry, RuntimeAPI::API::DESTROY_OMTENSOR, {omt});
8299
}
83100

84-
rewriter.eraseOp(op);
85101
return success();
86102
}
87103

src/Dialect/Krnl/Krnl.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,22 @@ def KrnlCallOp : Op<Krnl_Dialect, "call",
9090
DefaultValuedAttr<SI64Attr, "1">:$numOfOutput,
9191
Variadic<AnyType>:$parameters);
9292

93+
// Return Value for the Call.
94+
// No return if the type is NoneType (void in llvm)
95+
// Only scalar type is supported now.
96+
// In future, return of memref can be supported with pointer of OMTensor.
97+
// The returned memref will be created inside the call.
98+
let results = (outs Variadic<AnyTypeOf<[AnyFloat, AnyInteger]>>:$returnValue);
99+
93100
// builders to build KrnlCallOp from op and operands, helping conversion from
94101
// onnx to krnl.
95102
// The name of function can be determined by the op name and elemnt type of
96103
// the return, or given to builder if the simple rule does not work.
97104
// Attributes of the op will be propagated to KrnlCallOp if the copyAttrs is
98105
// true. Or the attribute names can be specified.
99106
let builders = [
107+
OpBuilder<(ins "std::string":$funcNameStr, "int64_t":$numOfOutput, "mlir::ValueRange":$operands)>,
108+
OpBuilder<(ins "mlir::StringAttr":$funcNameStr, "IntegerAttr":$numOfOutput, "mlir::ValueRange":$operands)>,
100109
OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector<std::string>":$attributeNames)>,
101110
OpBuilder<(ins "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
102111
OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector<std::string>":$attributeNames)>,

src/Dialect/Krnl/KrnlOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
156156
build(builder, odsState, funcNameStr, resultVals, op, operands, copyAttrs);
157157
}
158158

159+
void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
160+
std::string funcName, int64_t numOfOutput, ValueRange operands) {
161+
build(builder, odsState, {}, funcName, numOfOutput, operands);
162+
}
163+
164+
void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState,
165+
StringAttr funcName, IntegerAttr numOfOutput, ValueRange operands) {
166+
build(builder, odsState, {}, funcName, numOfOutput, operands);
167+
}
168+
159169
void KrnlCallOp::getEffects(
160170
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
161171
&effects) {
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: onnx-mlir-opt --convert-krnl-to-llvm %s -split-input-file | FileCheck %s
2+
3+
func.func private @test_krnl_call_with_return(%arg0: memref<2x3xi32>) -> i32 {
4+
%1 = "krnl.call"() {funcName = "get_omp_num_thread", numOfOutput = 0 : si64} : () -> (i32)
5+
func.return %1: i32
6+
// CHECK: llvm.func @get_omp_num_thread() -> i32
7+
// CHECK: llvm.func @test_krnl_call_with_return
8+
// CHECK: [[VAR_0_:%.+]] = llvm.call @get_omp_num_thread() : () -> i32
9+
// CHECK: llvm.return [[VAR_0_]] : i32
10+
}

0 commit comments

Comments
 (0)