Reuse input buffer in lowering to krnl #2939
@@ -28,6 +28,86 @@ using namespace mlir;

namespace onnx_mlir {

// Check whether the input value x can be reused as the output buffer.
bool isBufferReusable(Value x, MemRefType outputType) {
  if (!x.hasOneUse())
    return false;

  Type xType = x.getType();
  auto inputType = dyn_cast<ShapedType>(xType);
  if (!inputType)
    return false;
  // Currently, only static shapes can be reused.
  // ToFix: use DimAnalysis to handle dynamic shapes.
  if (!hasStaticShape(inputType))
    return false;
  if (!hasStaticShape(outputType))
    return false;

  // Currently, reuse requires that the shapes are the same.
  // ToFix: if the shapes differ, memref.cast could be used.
  if (getRank(inputType) != getRank(outputType))
    return false;
  for (int64_t i = 0; i < getRank(inputType); i++) {
    if (inputType.getShape()[i] != outputType.getShape()[i])
      return false;
  }

  // ToFix: SIMD padding is not checked. We do not record whether the memref is
  // padded or not; the padding could be recorded as an attribute on the
  // memref, or may turn out not to be needed.
  return true;
}

// Traverse the operands to find a candidate for buffer reuse.
// Return -1 if no candidate is found.
int whichBufferToReuse(ValueRange values, MemRefType outputType) {
  for (size_t i = 0; i < values.size(); i++) {
    if (isBufferReusable(values[i], outputType))
      return i;
  }
  return -1;
}

// Allocate a memref (as before) if no input buffer can be reused.
// The default VL = 0 is used for non-SIMD allocation.
Value allocOrReuse(MemRefBuilder &create, Operation *op,
    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
    int64_t alignment, int64_t VL = 0);

Value allocOrReuse(MemRefBuilder &create, Operation *op,
    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
    int64_t alignment, int64_t VL) {
Review comment: This code could be simplified as follows:

    if (!disableKrnlBufferReuse) {
      int indexToReuse = xxx
      if (indexToReuse != -1)
        return xxx
    }
    // no reuse, alloc

Reply: Fixed. Thanks for the suggestion.
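Filled in with the helpers introduced above (whichBufferToReuse, generatedOperands, disableKrnlBufferReuse), the suggested restructuring might look roughly like the sketch below. This is only an illustration of the reviewer's pseudocode, not the author's committed fix, and it drops the LLVM_DEBUG print; the code as of this commit, which continues below, still uses the longer form.

    // Sketch only: reuse an operand buffer when possible, otherwise fall back
    // to the original allocation paths.
    Value allocOrReuse(MemRefBuilder &create, Operation *op,
        ValueRange generatedOperands, MemRefType outputMemRefType,
        DimsExprRef dims, int64_t alignment, int64_t VL) {
      if (!disableKrnlBufferReuse) {
        int indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
        if (indexToReuse != -1)
          return generatedOperands[indexToReuse];
      }
      // No reuse: allocate as before, with SIMD padding when VL is nonzero.
      if (VL == 0)
        return create.alignedAlloc(outputMemRefType, dims, alignment);
      return create.alignedAllocWithSimdPadding(
          outputMemRefType, dims, VL, alignment);
    }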
  // By default, disableKrnlBufferReuse is true. Simply allocate a memref.
  if (disableKrnlBufferReuse) {
    if (VL == 0)
      return create.alignedAlloc(outputMemRefType, dims, alignment);
    else
      return create.alignedAllocWithSimdPadding(
          outputMemRefType, dims, VL, alignment);
  }

  // Be aware: use op->getOperands() to check the number of uses. After buffer
  // reuse, the number of uses of the transformed Value, generatedOperands,
  // will increase.
  int indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
  if (indexToReuse != -1) {
    int size = getSizeInBytes(outputMemRefType);
    LLVM_DEBUG({
      llvm::dbgs() << " malloc_size " << size << "\n";
      op->dump();
    });
    return generatedOperands[indexToReuse];
  } else {
    if (VL == 0)
      return create.alignedAlloc(outputMemRefType, dims, alignment);
    else
      return create.alignedAllocWithSimdPadding(
          outputMemRefType, dims, VL, alignment);
  }
}

// =============================================================================

/// Emit post-processing for variadic element-wise ops.
@@ -1323,14 +1403,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Alloc memory with padding for SIMD.
+  // Reuse the buffer from the input, or alloc memory with padding for SIMD.
   // For the moment, it's ok to go here; if we truly have partial flattening of
   // the simd code, then we only do it with static memref size that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiple of VL *
   // unrollVL.
-  Value alloc = create.mem.alignedAllocWithSimdPadding(
-      outputMemRefType, outputDims, VL, alignment);
+  Value alloc = allocOrReuse(
+      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {
@@ -1975,8 +2055,9 @@ struct ONNXElementwiseUnaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
@@ -2156,8 +2237,9 @@ struct ONNXElementwiseBinaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
@@ -2331,8 +2413,9 @@ struct ONNXElementwiseVariadicOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
@@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --disable-krnl-buffer-reuse=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

// -----

func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
Review comment: Nice concise test, we should all aspire to do that!
  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
  %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
  %2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
  return %2 : tensor<1024xf32>
}
// CHECK-LABEL: func.func @test_reuse
// CHECK-NOT: memref.alloc
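For contrast, the reuse is gated by the hasOneUse check in isBufferReusable, so a value that feeds two consumers keeps its own buffer. The function below is a hypothetical counter-example, not part of this PR, and its name and expectation are illustrative only; under the same pipeline one would expect the two Sqrt results to get their own allocations, since %0 has two uses.

    // Hypothetical negative case, not part of the patch: %0 is consumed twice,
    // so neither Sqrt can take over its buffer.
    func.func @test_no_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
      %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
      %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
      %2 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
      return %1, %2 : tensor<1024xf32>, tensor<1024xf32>
    }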
Review comment: NIT: generally "disable" is for a feature that is on by default, and "enable" is for one that is off by default. I think you want "enable" here.

Reply: fixed
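The convention the reviewer is pointing at is usually expressed directly in the option declaration. The snippet below is only a sketch of that convention using LLVM's command-line library; the actual option name, category, and registration used by onnx-mlir after the rename may differ (the RUN line above still shows the older --disable-krnl-buffer-reuse spelling at this commit).

    #include "llvm/Support/CommandLine.h"

    // Default-off feature: spell the flag "enable-..." and initialize it to
    // false, so passing the flag turns the feature on.
    static llvm::cl::opt<bool> enableKrnlBufferReuse("enable-krnl-buffer-reuse",
        llvm::cl::desc("Reuse an input buffer as the output buffer when "
                       "lowering elementwise ops to Krnl (default off)"),
        llvm::cl::init(false));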