Commit 3fbd9fe

Merge branch 'main' into simd-framwork-v1
2 parents: 9cdfc55 + 97d497f

4 files changed: +109 −9 lines

src/Compiler/CompilerOptions.cpp

Lines changed: 9 additions & 0 deletions
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
 std::vector<std::string> functionsToDecompose; // common for both
 std::string opsForCall;                        // common for both
 bool disableKrnlOpFusion;                      // common for both
+bool enableKrnlBufferReuse;                    // common for both
 bool disableMemRefPrefetch;                    // common for both
 EmissionTargetType emissionTarget;             // onnx-mlir only
 bool invokeOnnxVersionConverter;               // onnx-mlir only
@@ -212,6 +213,14 @@ static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
     llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
     llvm::cl::cat(OnnxMlirCommonOptions));
 
+static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
+    "enable-krnl-buffer-reuse",
+    llvm::cl::desc("enable buffer reuse within an op in the onnx-to-krnl pass "
+                   "(default=false)\n"
+                   "Set to 'true' if you want to enable buffer reuse."),
+    llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
+    llvm::cl::cat(OnnxMlirCommonOptions));
+
 static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
     "disable-memref-prefetch",
     llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
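For readers unfamiliar with the pattern: llvm::cl::opt<bool, true> uses external storage, writing the parsed value into the variable named by llvm::cl::location; that is what lets CompilerOptions.hpp expose the flag as a plain extern bool. Below is a minimal, self-contained sketch of the same mechanism; the flag and variable names are invented for illustration and are not part of onnx-mlir.

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

// Storage lives outside the option object, so other translation units can
// declare it `extern` and read it directly.
bool enableDemoReuse;

static llvm::cl::opt<bool, true> enableDemoReuseOpt("enable-demo-reuse",
    llvm::cl::desc("enable demo buffer reuse (default=false)"),
    llvm::cl::location(enableDemoReuse), llvm::cl::init(false));

int main(int argc, char **argv) {
  // Accepts e.g. `./demo --enable-demo-reuse=true`.
  llvm::cl::ParseCommandLineOptions(argc, argv);
  llvm::outs() << "reuse enabled: " << enableDemoReuse << "\n";
  return 0;
}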

src/Compiler/CompilerOptions.hpp

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
 extern std::vector<std::string> functionsToDecompose; // common for both
 extern std::string opsForCall;                        // common for both
 extern bool disableKrnlOpFusion;                      // common for both
+extern bool enableKrnlBufferReuse;                    // common for both
 extern bool disableMemRefPrefetch;                    // common for both
 extern EmissionTargetType emissionTarget;             // onnx-mlir only
 extern bool invokeOnnxVersionConverter;               // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 88 additions & 9 deletions
@@ -28,6 +28,82 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
+// Check whether the input, x, can be reused as the output buffer.
+bool isBufferReusable(Value x, MemRefType outputType) {
+  if (!x.hasOneUse())
+    return false;
+
+  Type xType = x.getType();
+  auto inputType = dyn_cast<ShapedType>(xType);
+  if (!inputType)
+    return false;
+  // Currently, only static shapes can be reused.
+  // ToFix: use DimAnalysis to handle dynamic shapes.
+  if (!hasStaticShape(inputType))
+    return false;
+  if (!hasStaticShape(outputType))
+    return false;
+
+  // Currently, reuse requires the input and output shapes to be identical.
+  // ToFix: if the shapes differ, memref.cast could be used.
+  if (getRank(inputType) != getRank(outputType))
+    return false;
+  for (int64_t i = 0; i < getRank(inputType); i++) {
+    if (inputType.getShape()[i] != outputType.getShape()[i])
+      return false;
+  }
+
+  // ToFix: SIMD padding is not checked.
+  // We did not record whether the memref is padded or not;
+  // the padding could be recorded as an attribute on the memref, or may not be needed.
+  return true;
+}
+
+// Traverse the operands to find a candidate for buffer reuse.
+// Return -1 if no candidate is found.
+int whichBufferToReuse(ValueRange values, MemRefType outputType) {
+  for (size_t i = 0; i < values.size(); i++) {
+    if (isBufferReusable(values[i], outputType))
+      return i;
+  }
+  return -1;
+}
+
+// Allocate a memref (as before) if no input buffer can be reused.
+// The default VL=0 requests a non-SIMD allocation.
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL = 0);
+
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL) {
+
+  int indexToReuse = -1;
+  // By default, enableKrnlBufferReuse is false and a memref is simply allocated.
+  if (enableKrnlBufferReuse) {
+    // Be aware to check the number of uses on op->getOperands(), not on
+    // generatedOperands: after buffer reuse, the number of uses of the
+    // transformed Values in generatedOperands will have increased.
+    indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
+  }
+
+  if (indexToReuse != -1) {
+    int size = getSizeInBytes(outputMemRefType);
+    LLVM_DEBUG({
+      llvm::dbgs() << "  malloc_size " << size << "\n";
+      op->dump();
+    });
+    return generatedOperands[indexToReuse];
+  } else {
+    if (VL == 0)
+      return create.alignedAlloc(outputMemRefType, dims, alignment);
+    else
+      return create.alignedAllocWithSimdPadding(
+          outputMemRefType, dims, VL, alignment);
+  }
+}
+
 // =============================================================================
 
 /// Emit post-processing for variadic element-wise ops.
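In plain terms, allocOrReuse returns an operand's buffer as the result buffer only when that operand has exactly one use (the current op) and its static shape matches the output shape exactly; for elementwise ops this in-place reuse is safe because output element i depends only on the operands' elements at i. A standalone sketch of the same predicate, with hypothetical TensorDesc/useCount stand-ins rather than real MLIR types:

#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MLIR's Value/MemRefType, for illustration only.
struct TensorDesc {
  std::vector<int64_t> shape; // -1 marks a dynamic dimension
  int useCount;               // remaining uses of this SSA value
};

static bool isStatic(const std::vector<int64_t> &shape) {
  for (int64_t d : shape)
    if (d < 0)
      return false;
  return true;
}

// Same rule as isBufferReusable: single use, both shapes static and identical
// (equal rank and equal extent in every dimension).
bool canReuse(const TensorDesc &in, const std::vector<int64_t> &outShape) {
  return in.useCount == 1 && isStatic(in.shape) && isStatic(outShape) &&
         in.shape == outShape;
}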
@@ -1323,14 +1399,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Alloc memory with padding for SIMD.
+  // Reuse the buffer from an input, or alloc memory with padding for SIMD.
   // For the moment, it's ok to go here; if we truly have partial flattening of
   // the simd code, then we only do it with static memref sizes that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiples of VL *
   // unrollVL.
-  Value alloc = create.mem.alignedAllocWithSimdPadding(
-      outputMemRefType, outputDims, VL, alignment);
+  Value alloc = allocOrReuse(
+      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {
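To make the comment concrete with illustrative numbers (not values taken from this patch): for f32 on a 256-bit SIMD unit, VL = 8, and with unrollVL = 4 the flattened static size must be a multiple of VL * unrollVL = 32; a 1024-element memref qualifies (1024 = 32 * 32), so alignedAllocWithSimdPadding adds no padding and reusing an unpadded input buffer is equivalent.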
@@ -1974,8 +2050,9 @@ struct ONNXElementwiseUnaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {
@@ -2155,8 +2232,9 @@ struct ONNXElementwiseBinaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {
@@ -2330,8 +2408,9 @@ struct ONNXElementwiseVariadicOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
+
+// -----
+func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
+  %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
+  %2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
+  return %2 : tensor<1024xf32>
+}
+// CHECK-LABEL: func.func @test_reuse
+// CHECK-NOT: memref.alloc
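Reading the test: each value in the chain (%arg0 into onnx.Add, %0 into the first onnx.Sqrt, %1 into the second) has exactly one use and the same static tensor<1024xf32> shape as the result it feeds, so with --enable-krnl-buffer-reuse=true every lowered op can write into its input's buffer; the CHECK-NOT line then verifies that no memref.alloc remains in the lowered function.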
