Skip to content

Commit 27c0869

Browse files
update
Signed-off-by: Alexandre Eichenberger <[email protected]>
2 parents 630a97e + 97d497f commit 27c0869

File tree

4 files changed

+109
-9
lines changed

4 files changed

+109
-9
lines changed

src/Compiler/CompilerOptions.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
4444
bool disableKrnlOpFusion; // common for both
4545
bool disableQuantZeroPoint; // common for both
46+
bool enableKrnlBufferReuse; // common for both
4647
bool disableMemRefPrefetch; // common for both
4748
EmissionTargetType emissionTarget; // onnx-mlir only
4849
bool invokeOnnxVersionConverter; // onnx-mlir only
@@ -222,6 +223,14 @@ static llvm::cl::opt<bool, true> disable_quantization_zero_point(
222223
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
223224
llvm::cl::cat(OnnxMlirCommonOptions));
224225

226+
static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
227+
"enable-krnl-buffer-reuse",
228+
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
229+
"(default=false)\n"
230+
"Set to 'true' if you want to enable buffer reuse."),
231+
llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
232+
llvm::cl::cat(OnnxMlirCommonOptions));
233+
225234
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
226235
"disable-memref-prefetch",
227236
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"

src/Compiler/CompilerOptions.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ extern std::vector<std::string> functionsToDecompose; // common for both
8888
extern std::string opsForCall; // common for both
8989
extern bool disableKrnlOpFusion; // common for both
9090
extern bool disableQuantZeroPoint; // common for both
91+
extern bool enableKrnlBufferReuse; // common for both
9192
extern bool disableMemRefPrefetch; // common for both
9293
extern EmissionTargetType emissionTarget; // onnx-mlir only
9394
extern bool invokeOnnxVersionConverter; // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,82 @@ using namespace mlir;
2828

2929
namespace onnx_mlir {
3030

31+
// Check whether the input, x, can be reused as the output buffer
32+
bool isBufferReusable(Value x, MemRefType outputType) {
33+
if (!x.hasOneUse())
34+
return false;
35+
36+
Type xType = x.getType();
37+
auto inputType = dyn_cast<ShapedType>(xType);
38+
if (!inputType)
39+
return false;
40+
// Currently, only static shapes can be reused.
41+
// ToFix: use DimAnalysis to handle dynamic shape.
42+
if (!hasStaticShape(inputType))
43+
return false;
44+
if (!hasStaticShape(outputType))
45+
return false;
46+
47+
// Currently, reuse requires that the shapes be the same.
48+
// ToFix: If the shape is not the same, memref.cast can be used.
49+
if (getRank(inputType) != getRank(outputType))
50+
return false;
51+
for (int64_t i = 0; i < getRank(inputType); i++) {
52+
if (inputType.getShape()[i] != outputType.getShape()[i])
53+
return false;
54+
}
55+
56+
// ToFix: The simd padding is not checked
57+
// We did not record whether the memref is padded or not.
58+
// The padding is added to the memref as an attribute, or is not needed.
59+
return true;
60+
}
61+
62+
// Traverse the operands to find the candidate for buffer reuse.
63+
// Return -1, if no candidate is found.
64+
int whichBufferToReuse(ValueRange values, MemRefType outputType) {
65+
for (size_t i = 0; i < values.size(); i++) {
66+
if (isBufferReusable(values[i], outputType))
67+
return i;
68+
}
69+
return -1;
70+
}
71+
72+
// Allocate memref (as before) if no input buffer can be reused.
73+
// Default VL=0 is used for non-SIMD allocation.
74+
Value allocOrReuse(MemRefBuilder &create, Operation *op,
75+
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
76+
int64_t alignment, int64_t VL = 0);
77+
78+
Value allocOrReuse(MemRefBuilder &create, Operation *op,
79+
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
80+
int64_t alignment, int64_t VL) {
81+
82+
int indexToReuse = -1;
83+
// By default, enableKrnlBufferReuse is false. Simply allocate a memref.
84+
if (enableKrnlBufferReuse) {
85+
// Be aware to use the op->getOperands() to check the number of uses.
86+
// After buffer reuse, the number of uses of the transformed Value,
87+
// generatedOperands, will increase.
88+
indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
89+
}
90+
91+
if (indexToReuse != -1) {
92+
int size = getSizeInBytes(outputMemRefType);
93+
LLVM_DEBUG({
94+
llvm::dbgs() << " malloc_size " << size << "\n";
95+
op->dump();
96+
});
97+
return generatedOperands[indexToReuse];
98+
} else {
99+
if (VL == 0)
100+
return create.alignedAlloc(outputMemRefType, dims, alignment);
101+
else
102+
return create.alignedAllocWithSimdPadding(
103+
outputMemRefType, dims, VL, alignment);
104+
}
105+
}
106+
31107
// =============================================================================
32108

33109
/// Emit post-processing for variadic element-wise ops.
@@ -1329,14 +1405,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
13291405
IndexExprScope allocScope(create.vec, shapeHelper->getScope());
13301406
DimsExpr outputDims;
13311407
getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
1332-
// Alloc memory with padding for SIMD.
1408+
// Reuse the buffer from the input, or allocate memory with padding for SIMD.
13331409
// For the moment, its ok to go here; if we truly have partial flattening of
13341410
// the simd code, then we only do it with static memref size that are
13351411
// multiples of VL * unrollVL, so there should be no padding anyway. This
13361412
// will change if we do partial flattening with non-multiple of VL *
13371413
// unrollVL.
1338-
Value alloc = create.mem.alignedAllocWithSimdPadding(
1339-
outputMemRefType, outputDims, VL, alignment);
1414+
Value alloc = allocOrReuse(
1415+
create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
13401416
// Create flat inputs in the last innerDinNum dims.
13411417
llvm::SmallVector<Value, 4> flatOperands;
13421418
for (Value oper : operands) {
@@ -1981,8 +2057,9 @@ struct ONNXElementwiseUnaryOpLowering
19812057
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
19822058

19832059
// Insert an allocation for the result of this operation.
1984-
Value alloc = create.mem.alignedAlloc(
1985-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2060+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2061+
shapeHelper.getOutputDims(), alignment);
2062+
;
19862063

19872064
// Only create krnl.iterate if one of the operands is not scalar tensor.
19882065
if (!isScalar) {
@@ -2162,8 +2239,9 @@ struct ONNXElementwiseBinaryOpLowering
21622239
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
21632240

21642241
// Insert an allocation and deallocation for the result of this operation.
2165-
Value alloc = create.mem.alignedAlloc(
2166-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2242+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2243+
shapeHelper.getOutputDims(), alignment);
2244+
;
21672245

21682246
// Only create krnl.iterate if one of the operands is not scalar tensor.
21692247
if (!isScalar) {
@@ -2337,8 +2415,9 @@ struct ONNXElementwiseVariadicOpLowering
23372415
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
23382416

23392417
// Insert an allocation and deallocation for the result of this operation.
2340-
Value alloc = create.mem.alignedAlloc(
2341-
outputMemRefType, shapeHelper.getOutputDims(), alignment);
2418+
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
2419+
shapeHelper.getOutputDims(), alignment);
2420+
;
23422421

23432422
// Only create krnl.iterate if one of the operands is not scalar tensor.
23442423
if (!isScalar) {
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
2+
3+
// -----
4+
func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
5+
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
6+
%1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
7+
%2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
8+
return %2 : tensor<1024xf32>
9+
}
10+
// CHECK-LABEL: func.func @test_reuse
11+
// CHECK-NOT: memref.alloc

0 commit comments

Comments
 (0)