
Commit 4915c66

Revert "Reuse input buffer in lowering to krnl (onnx#2939)"
This reverts commit 97d497f.
1 parent: 4ceeabc
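In terms of the lowering code, the net effect of the revert is to swap the conditional reuse helper back to a plain aligned allocation at every element-wise allocation site; schematically (both call signatures are taken from the diff below):

// Before this commit: reuse an input buffer when legal, else allocate.
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
    shapeHelper.getOutputDims(), alignment);

// After this commit: always allocate a fresh aligned buffer.
Value alloc = create.mem.alignedAlloc(
    outputMemRefType, shapeHelper.getOutputDims(), alignment);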

File tree

4 files changed: +11 -113 lines


src/Compiler/CompilerOptions.cpp (+2 -13)

@@ -42,11 +42,8 @@ bool enableONNXHybridPass; // common for both
 std::vector<std::string> functionsToDecompose; // common for both
 std::string opsForCall; // common for both
 bool disableKrnlOpFusion; // common for both
-<<<<<<< HEAD
-=======
-bool enableKrnlBufferReuse; // common for both
-bool disableMemRefPrefetch; // common for both
->>>>>>> 97d497fa09e4cfa8a570d820aa01a76b8cda8728
+bool enableKrnlBufferReuse; // common for both
+bool disableMemRefPrefetch; // common for both
 EmissionTargetType emissionTarget; // onnx-mlir only
 bool invokeOnnxVersionConverter; // onnx-mlir only
 bool preserveLocations; // onnx-mlir only

@@ -217,14 +214,6 @@ static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
     llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
     llvm::cl::cat(OnnxMlirCommonOptions));
 
-static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
-    "enable-krnl-buffer-reuse",
-    llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
-                   "(default=false)\n"
-                   "Set to 'true' if you want to enable buffer reuse."),
-    llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
-    llvm::cl::cat(OnnxMlirCommonOptions));
-
 static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
    llvm::cl::desc("Disable recomposition of ONNX operations."),
    llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
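The option deleted above follows LLVM's external-storage flag idiom: cl::opt<bool, true> parses the command-line flag and, via llvm::cl::location, writes the parsed value into a global bool (enableKrnlBufferReuse) that the lowering code can test. A minimal, self-contained sketch of the same idiom; the flag and variable names here are illustrative, not onnx-mlir's:

// flag_sketch.cpp: the llvm::cl external-storage idiom, reduced to essentials.
#include "llvm/Support/CommandLine.h"

// Global storage, analogous to onnx-mlir's enableKrnlBufferReuse.
bool enableMyFeature; // illustrative name

// The `true` template argument selects external storage; cl::location
// points the parser at the global above instead of an internal member.
static llvm::cl::opt<bool, true> enableMyFeatureOpt("enable-my-feature",
    llvm::cl::desc("Enable the illustrative feature (default=false)."),
    llvm::cl::location(enableMyFeature), llvm::cl::init(false));

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return enableMyFeature ? 0 : 1; // the flag's value lives in the global
}

Because the storage is an ordinary global, any translation unit can declare it extern and branch on it, which is how Elementwise.cpp consulted enableKrnlBufferReuse below.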

src/Compiler/CompilerOptions.hpp (-1)

@@ -87,7 +87,6 @@ extern bool enableONNXHybridPass; // common for both
 extern std::vector<std::string> functionsToDecompose; // common for both
 extern std::string opsForCall; // common for both
 extern bool disableKrnlOpFusion; // common for both
-extern bool enableKrnlBufferReuse; // common for both
 extern EmissionTargetType emissionTarget; // onnx-mlir only
 extern bool invokeOnnxVersionConverter; // onnx-mlir only
 extern bool preserveLocations; // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp (+9 -88)

@@ -28,82 +28,6 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
-// Check the input, x, can be reused as the output buffer.
-bool isBufferReusable(Value x, MemRefType outputType) {
-  if (!x.hasOneUse())
-    return false;
-
-  Type xType = x.getType();
-  auto inputType = dyn_cast<ShapedType>(xType);
-  if (!inputType)
-    return false;
-  // Currently, only static shape could be reused.
-  // ToFix: use DimAnalysis to handle dynamic shape.
-  if (!hasStaticShape(inputType))
-    return false;
-  if (!hasStaticShape(outputType))
-    return false;
-
-  // Currently reuse requires that the shape has to be the same.
-  // ToFix: If the shape is not the same, memref.cast can be used.
-  if (getRank(inputType) != getRank(outputType))
-    return false;
-  for (int64_t i = 0; i < getRank(inputType); i++) {
-    if (inputType.getShape()[i] != outputType.getShape()[i])
-      return false;
-  }
-
-  // ToFix: The simd padding is not checked.
-  // We did not record whether the memref is padded or not.
-  // The padding is added to the memref as an attribute, or not needed.
-  return true;
-}
-
-// Traverse the operands to find the candidate for buffer reuse.
-// Return -1, if no candidate is found.
-int whichBufferToReuse(ValueRange values, MemRefType outputType) {
-  for (size_t i = 0; i < values.size(); i++) {
-    if (isBufferReusable(values[i], outputType))
-      return i;
-  }
-  return -1;
-}
-
-// Allocate memref (as before) if no input buffer can be reused.
-// Default VL=0 is used for non-SIMD allocation.
-Value allocOrReuse(MemRefBuilder &create, Operation *op,
-    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
-    int64_t alignment, int64_t VL = 0);
-
-Value allocOrReuse(MemRefBuilder &create, Operation *op,
-    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
-    int64_t alignment, int64_t VL) {
-
-  int indexToReuse = -1;
-  // By default, enableKrnlBufferReuse is false. Simply allocate a memref.
-  if (enableKrnlBufferReuse) {
-    // Be aware to use op->getOperands() to check the number of uses.
-    // After buffer reuse, the number of uses of the transformed Value,
-    // generatedOperands, will increase.
-    indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
-  }
-
-  if (indexToReuse != -1) {
-    int size = getSizeInBytes(outputMemRefType);
-    LLVM_DEBUG({
-      llvm::dbgs() << " malloc_size " << size << "\n";
-      op->dump();
-    });
-    return generatedOperands[indexToReuse];
-  } else {
-    if (VL == 0)
-      return create.alignedAlloc(outputMemRefType, dims, alignment);
-    else
-      return create.alignedAllocWithSimdPadding(
-          outputMemRefType, dims, VL, alignment);
-  }
-}
-
 // =============================================================================
 
 /// Emit post-processing for variadic element-wise ops.

@@ -1399,14 +1323,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Reuse the buffer from the input, or alloc memory with padding for SIMD.
+  // Alloc memory with padding for SIMD.
   // For the moment, it's ok to go here; if we truly have partial flattening of
   // the simd code, then we only do it with static memref sizes that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiples of VL *
   // unrollVL.
-  Value alloc = allocOrReuse(
-      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
+  Value alloc = create.mem.alignedAllocWithSimdPadding(
+      outputMemRefType, outputDims, VL, alignment);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {

@@ -2051,9 +1975,8 @@ struct ONNXElementwiseUnaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation for the result of this operation.
-    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-        shapeHelper.getOutputDims(), alignment);
-    ;
+    Value alloc = create.mem.alignedAlloc(
+        outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {

@@ -2233,9 +2156,8 @@ struct ONNXElementwiseBinaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-        shapeHelper.getOutputDims(), alignment);
-    ;
+    Value alloc = create.mem.alignedAlloc(
+        outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {

@@ -2409,9 +2331,8 @@ struct ONNXElementwiseVariadicOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-        shapeHelper.getOutputDims(), alignment);
-    ;
+    Value alloc = create.mem.alignedAlloc(
+        outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
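The heart of the reverted feature is the eligibility test in isBufferReusable above: an operand's buffer may stand in for the output only when that operand has exactly one use and both memrefs have identical, fully static shapes (the ToFix comments flag dynamic shapes and memref.cast as future extensions). A self-contained sketch of that decision logic, with MLIR's Value/MemRefType replaced by plain data so it compiles on its own; all names here are illustrative:

#include <cstdint>
#include <vector>

// Stand-in for an SSA value: its shape (-1 marks a dynamic dimension,
// mirroring MLIR's ShapedType convention) and its number of uses.
struct FakeValue {
  std::vector<int64_t> shape;
  int numUses;
};

// Mirrors the reverted isBufferReusable: single use, fully static shapes,
// identical rank and extents.
bool isReusable(const FakeValue &in, const std::vector<int64_t> &outShape) {
  if (in.numUses != 1)
    return false; // the value is read elsewhere; overwriting it is unsafe
  for (int64_t d : in.shape)
    if (d < 0)
      return false; // dynamic input dimension
  for (int64_t d : outShape)
    if (d < 0)
      return false; // dynamic output dimension
  return in.shape == outShape; // same rank and same extents
}

// Mirrors whichBufferToReuse: first eligible operand wins, -1 if none.
int whichToReuse(const std::vector<FakeValue> &ins,
    const std::vector<int64_t> &outShape) {
  for (size_t i = 0; i < ins.size(); i++)
    if (isReusable(ins[i], outShape))
      return static_cast<int>(i);
  return -1;
}

int main() {
  // First operand has two uses; the second is single-use with a matching
  // static shape, so index 1 would be reused.
  std::vector<FakeValue> ins = {{{2, 3}, 2}, {{2, 3}, 1}};
  return whichToReuse(ins, {2, 3}) == 1 ? 0 : 1;
}

The single-use requirement is what makes the rewrite sound: once the output aliases an input, any later read of that input would observe the op's result instead, which is the hazard the op->getOperands() comment in allocOrReuse guards against.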

test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir (-11)

This file was deleted.
