Skip to content

Commit 97d497f

Browse files
authored
Reuse input buffer in lowering to krnl (#2939)
Squashed commits: first step; cpu; options; unify; simd; comments; lit test; fix test; format; response. Each commit Signed-off-by: chentong319 <[email protected]>
1 parent 02f45b0 commit 97d497f

File tree

4 files changed

+109
-9
lines changed

4 files changed

+109
-9
lines changed

src/Compiler/CompilerOptions.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
4242
std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
4444
bool disableKrnlOpFusion; // common for both
45+
bool enableKrnlBufferReuse; // common for both
4546
bool disableMemRefPrefetch; // common for both
4647
EmissionTargetType emissionTarget; // onnx-mlir only
4748
bool invokeOnnxVersionConverter; // onnx-mlir only
@@ -212,6 +213,14 @@ static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
212213
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
213214
llvm::cl::cat(OnnxMlirCommonOptions));
214215

216+
// Command-line flag backing the global `enableKrnlBufferReuse` option.
// Note the space before "(default=false)": without it the two adjacent
// string literals concatenate into "pass(default=false)" in the help text.
static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
    "enable-krnl-buffer-reuse",
    llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass "
                   "(default=false)\n"
                   "Set to 'true' if you want to enable buffer reuse."),
    llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirCommonOptions));
223+
215224
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
216225
"disable-memref-prefetch",
217226
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"

src/Compiler/CompilerOptions.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
8787
extern std::vector<std::string> functionsToDecompose; // common for both
8888
extern std::string opsForCall; // common for both
8989
extern bool disableKrnlOpFusion; // common for both
90+
extern bool enableKrnlBufferReuse; // common for both
9091
extern bool disableMemRefPrefetch; // common for both
9192
extern EmissionTargetType emissionTarget; // onnx-mlir only
9293
extern bool invokeOnnxVersionConverter; // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

+88-9
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,82 @@ using namespace mlir;
2828

2929
namespace onnx_mlir {
3030

31+
// Check the input, x, can be reused as the output buffer
32+
bool isBufferReusable(Value x, MemRefType outputType) {
33+
if (!x.hasOneUse())
34+
return false;
35+
36+
Type xType = x.getType();
37+
auto inputType = dyn_cast<ShapedType>(xType);
38+
if (!inputType)
39+
return false;
40+
// Currently, only static shape could be reused.
41+
// ToFix: use DimAnalysis to handle dynamic shape.
42+
if (!hasStaticShape(inputType))
43+
return false;
44+
if (!hasStaticShape(outputType))
45+
return false;
46+
47+
// Currently reuse requires that the shape has to be the same.
48+
// ToFix: If the shape is not the same, memref.cast can be used.
49+
if (getRank(inputType) != getRank(outputType))
50+
return false;
51+
for (int64_t i = 0; i < getRank(inputType); i++) {
52+
if (inputType.getShape()[i] != outputType.getShape()[i])
53+
return false;
54+
}
55+
56+
// ToFix: The simd padding is not checked
57+
// We did not record whether the memref is padded or not.
58+
// The padding added to the memref the as an attribute, or not needed.
59+
return true;
60+
}
61+
62+
// Traverse the operands to find the candidate for buffer reuse.
63+
// Return -1, if no candidate is found.
64+
int whichBufferToReuse(ValueRange values, MemRefType outputType) {
65+
for (size_t i = 0; i < values.size(); i++) {
66+
if (isBufferReusable(values[i], outputType))
67+
return i;
68+
}
69+
return -1;
70+
}
71+
72+
// Allocate memref (as before) if no input buffer can be reused.
73+
// Default VL=0 is used for non SIMD allocation
74+
Value allocOrReuse(MemRefBuilder &create, Operation *op,
75+
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
76+
int64_t alignment, int64_t VL = 0);
77+
78+
Value allocOrReuse(MemRefBuilder &create, Operation *op,
79+
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
80+
int64_t alignment, int64_t VL) {
81+
82+
int indexToReuse = -1;
83+
// By default, enableKrnlBufferReuse is false. Simply allocate a memref.
84+
if (enableKrnlBufferReuse) {
85+
// Be aware to use the op->getOperands() to check the number of uses.
86+
// After buffer reuse, the number of uses of the transformed Value,
87+
// generatedOperands, will increase.
88+
indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
89+
}
90+
91+
if (indexToReuse != -1) {
92+
int size = getSizeInBytes(outputMemRefType);
93+
LLVM_DEBUG({
94+
llvm::dbgs() << " malloc_size " << size << "\n";
95+
op->dump();
96+
});
97+
return generatedOperands[indexToReuse];
98+
} else {
99+
if (VL == 0)
100+
return create.alignedAlloc(outputMemRefType, dims, alignment);
101+
else
102+
return create.alignedAllocWithSimdPadding(
103+
outputMemRefType, dims, VL, alignment);
104+
}
105+
}
106+
31107
// =============================================================================
32108

33109
/// Emit post-processing for variadic element-wise ops.
@@ -1323,14 +1399,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
13231399
IndexExprScope allocScope(create.vec, shapeHelper->getScope());
13241400
DimsExpr outputDims;
13251401
getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
1326-
// Alloc memory with padding for SIMD.
1402+
// Reuse the buffer from the input, or Alloc memory with padding for SIMD.
13271403
// For the moment, its ok to go here; if we truly have partial flattening of
13281404
// the simd code, then we only do it with static memref size that are
13291405
// multiples of VL * unrollVL, so there should be no padding anyway. This
13301406
// will change if we do partial flattening with non-multiple of VL *
13311407
// unrollVL.
1332-
Value alloc = create.mem.alignedAllocWithSimdPadding(
1333-
outputMemRefType, outputDims, VL, alignment);
1408+
Value alloc = allocOrReuse(
1409+
create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
13341410
// Create flat inputs in the last innerDinNum dims.
13351411
llvm::SmallVector<Value, 4> flatOperands;
13361412
for (Value oper : operands) {
@@ -1975,8 +2051,9 @@ struct ONNXElementwiseUnaryOpLowering
19752051
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
19762052

19772053
// Insert an allocation for the result of this operation.
1978-
Value alloc = create.mem.alignedAlloc(
1979-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2054+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2055+
shapeHelper.getOutputDims(), alignment);
2056+
;
19802057

19812058
// Only create krnl.iterate if one of the operands is not scalar tensor.
19822059
if (!isScalar) {
@@ -2156,8 +2233,9 @@ struct ONNXElementwiseBinaryOpLowering
21562233
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
21572234

21582235
// Insert an allocation and deallocation for the result of this operation.
2159-
Value alloc = create.mem.alignedAlloc(
2160-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2236+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2237+
shapeHelper.getOutputDims(), alignment);
2238+
;
21612239

21622240
// Only create krnl.iterate if one of the operands is not scalar tensor.
21632241
if (!isScalar) {
@@ -2331,8 +2409,9 @@ struct ONNXElementwiseVariadicOpLowering
23312409
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
23322410

23332411
// Insert an allocation and deallocation for the result of this operation.
2334-
Value alloc = create.mem.alignedAlloc(
2335-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2412+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2413+
shapeHelper.getOutputDims(), alignment);
2414+
;
23362415

23372416
// Only create krnl.iterate if one of the operands is not scalar tensor.
23382417
if (!isScalar) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

// -----

// A chain of element-wise ops on identical static shapes where each
// intermediate value has a single use. With --enable-krnl-buffer-reuse the
// lowering should write each result into an existing buffer, so the checks
// below verify that no memref.alloc survives in the lowered function.
func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
  %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
  %2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
  return %2 : tensor<1024xf32>
}
// CHECK-LABEL: func.func @test_reuse
// CHECK-NOT: memref.alloc

0 commit comments

Comments
 (0)