
Reuse input buffer in lowering to krnl #2939

Merged
merged 11 commits into from
Sep 13, 2024
Changes from 10 commits
11 changes: 11 additions & 0 deletions src/Compiler/CompilerOptions.cpp
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
std::vector<std::string> functionsToDecompose; // common for both
std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableKrnlBufferReuse; // common for both
bool disableMemRefPrefetch; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
@@ -212,6 +213,16 @@ static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableKrnlBufferReuseOpt(
Collaborator:
NIT: generally "disable" is for a function that is default on, "enable" is for one that is default off. I think you want "enable" here.

Collaborator Author:
fixed

"disable-krnl-buffer-reuse",
llvm::cl::desc("disable buffer reuse within an op in onnx-to-krnl pass"
"(default=true)\n"
"Set to 'false' if you want to enable buffer reuse."
"Default value will be false when the functionality becomes"
"stable."),
llvm::cl::location(disableKrnlBufferReuse), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
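Since the review above asks for an "enable"-style flag, here is a hypothetical sketch of what that form could look like, mirroring the option pattern used in this file; the variable name enableKrnlBufferReuse and the wording are assumptions, not code from the PR's later commits (a matching extern bool declaration would also be needed in CompilerOptions.hpp):

// Hypothetical "enable" variant of the option above; enableKrnlBufferReuse is
// an assumed name, not something defined in this PR as shown here.
static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
    "enable-krnl-buffer-reuse",
    llvm::cl::desc("Enable buffer reuse within an op in the onnx-to-krnl pass "
                   "(default=false)\n"
                   "Set to 'true' to reuse an input buffer for the output."),
    llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirCommonOptions));

As the option stands in this commit, reuse is off by default; passing --disable-krnl-buffer-reuse=false (as the new lit test's RUN line does) turns it on.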
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
extern std::vector<std::string> functionsToDecompose; // common for both
extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableKrnlBufferReuse; // common for both
extern bool disableMemRefPrefetch; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
extern bool invokeOnnxVersionConverter; // onnx-mlir only
101 changes: 92 additions & 9 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -28,6 +28,86 @@ using namespace mlir;

namespace onnx_mlir {

// Check whether the input, x, can be reused as the output buffer.
bool isBufferReusable(Value x, MemRefType outputType) {
if (!x.hasOneUse())
return false;

Type xType = x.getType();
auto inputType = dyn_cast<ShapedType>(xType);
if (!inputType)
return false;
// Currently, only static shapes can be reused.
// ToFix: use DimAnalysis to handle dynamic shape.
if (!hasStaticShape(inputType))
return false;
if (!hasStaticShape(outputType))
return false;

// Currently, reuse requires that the shapes be identical.
// ToFix: If the shape is not the same, memref.cast can be used.
if (getRank(inputType) != getRank(outputType))
return false;
for (int64_t i = 0; i < getRank(inputType); i++) {
if (inputType.getShape()[i] != outputType.getShape()[i])
return false;
}

// ToFix: SIMD padding is not checked.
// We did not record whether the memref is padded or not. The padding could be
// recorded as an attribute on the memref, or it may turn out not to be needed.
return true;
}

// Traverse the operands to find the candidate for buffer reuse.
// Return -1 if no candidate is found.
int whichBufferToReuse(ValueRange values, MemRefType outputType) {
for (size_t i = 0; i < values.size(); i++) {
if (isBufferReusable(values[i], outputType))
return i;
}
return -1;
}

// Allocate a memref (as before) if no input buffer can be reused.
// The default VL=0 is used for non-SIMD allocation.
Value allocOrReuse(MemRefBuilder &create, Operation *op,
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
int64_t alignment, int64_t VL = 0);

Value allocOrReuse(MemRefBuilder &create, Operation *op,
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
int64_t alignment, int64_t VL) {

// By default, disableKrnlBufferReuse is true. Simply allocate a memref.
Collaborator:
This code could be simplified as follows

if (!disableKrnlBufferReuse) {
  int indexToReuse = xxx
  if (indexToReuse != -1) return xxx
}
// no reuse, alloc

Collaborator Author:
Fixed. Thanks for the suggestion.

if (disableKrnlBufferReuse) {
if (VL == 0)
return create.alignedAlloc(outputMemRefType, dims, alignment);
else
return create.alignedAllocWithSimdPadding(
outputMemRefType, dims, VL, alignment);
}

// Note: check the number of uses on op->getOperands(), not on the transformed
// values in generatedOperands, whose use counts increase as buffers are reused.
int indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
if (indexToReuse != -1) {
int size = getSizeInBytes(outputMemRefType);
LLVM_DEBUG({
llvm::dbgs() << " malloc_size " << size << "\n";
op->dump();
});
return generatedOperands[indexToReuse];
} else {
if (VL == 0)
return create.alignedAlloc(outputMemRefType, dims, alignment);
else
return create.alignedAllocWithSimdPadding(
outputMemRefType, dims, VL, alignment);
}
}
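For reference, the simplification the reviewer suggests might look like the following; this is only a sketch built from the helpers shown above, since the actual fix landed in a later commit that is not part of this 10-commit view (the LLVM_DEBUG reporting is omitted):

Value allocOrReuse(MemRefBuilder &create, Operation *op,
    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
    int64_t alignment, int64_t VL) {
  // Try to reuse an operand's buffer when the feature is enabled.
  if (!disableKrnlBufferReuse) {
    int indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
    if (indexToReuse != -1)
      return generatedOperands[indexToReuse];
  }
  // No reuse: allocate as before, with SIMD padding when VL != 0.
  if (VL == 0)
    return create.alignedAlloc(outputMemRefType, dims, alignment);
  return create.alignedAllocWithSimdPadding(outputMemRefType, dims, VL, alignment);
}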

// =============================================================================

/// Emit post-processing for variadic element-wise ops.
@@ -1323,14 +1403,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
IndexExprScope allocScope(create.vec, shapeHelper->getScope());
DimsExpr outputDims;
getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
// Alloc memory with padding for SIMD.
// Reuse a buffer from the inputs, or alloc memory with padding for SIMD.
// For the moment, it's OK to go here; if we truly have partial flattening of
// the simd code, then we only do it with static memref sizes that are
// multiples of VL * unrollVL, so there should be no padding anyway. This
// will change if we do partial flattening with non-multiple of VL *
// unrollVL.
Value alloc = create.mem.alignedAllocWithSimdPadding(
outputMemRefType, outputDims, VL, alignment);
Value alloc = allocOrReuse(
create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
// Create flat inputs in the last innerDinNum dims.
llvm::SmallVector<Value, 4> flatOperands;
for (Value oper : operands) {
@@ -1975,8 +2055,9 @@ struct ONNXElementwiseUnaryOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
@@ -2156,8 +2237,9 @@ struct ONNXElementwiseBinaryOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
@@ -2331,8 +2413,9 @@ struct ONNXElementwiseVariadicOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
11 changes: 11 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir
@@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --disable-krnl-buffer-reuse=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

// -----
func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
Collaborator:
Nice concise test, we should all aspire to do that!

%0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
%2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
return %2 : tensor<1024xf32>
}
// CHECK-LABEL: func.func @test_reuse
// CHECK-NOT: memref.alloc
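In @test_reuse above, every operand has a single use and a static shape identical to its output, so each op reuses an input buffer and the CHECK-NOT confirms that no memref.alloc survives lowering. A hypothetical companion check, not part of this PR, could verify that reuse is skipped for dynamic shapes (the function name and checks below are assumptions, not the PR's code):

// -----
func.func @test_no_reuse_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "onnx.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %1 = "onnx.Sqrt"(%0) : (tensor<?xf32>) -> tensor<?xf32>
  return %1 : tensor<?xf32>
}
// CHECK-LABEL: func.func @test_no_reuse_dynamic
// CHECK: memref.alloc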