@@ -28,6 +28,82 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
+// Check whether the input value x can be reused as the output buffer.
+bool isBufferReusable(Value x, MemRefType outputType) {
+  if (!x.hasOneUse())
+    return false;
+
+  Type xType = x.getType();
+  auto inputType = dyn_cast<ShapedType>(xType);
+  if (!inputType)
+    return false;
+  // Currently, only static shapes can be reused.
+  // ToFix: use DimAnalysis to handle dynamic shapes.
+  if (!hasStaticShape(inputType))
+    return false;
+  if (!hasStaticShape(outputType))
+    return false;
+
+  // Currently, reuse requires the shapes to be identical.
+  // ToFix: if the shapes differ, memref.cast could be used.
+  if (getRank(inputType) != getRank(outputType))
+    return false;
+  for (int64_t i = 0; i < getRank(inputType); i++) {
+    if (inputType.getShape()[i] != outputType.getShape()[i])
+      return false;
+  }
+
+  // ToFix: SIMD padding is not checked. We do not record whether a memref is
+  // padded or not; the padding could be recorded as an attribute on the
+  // memref, or it may turn out not to be needed.
+  return true;
+}
+
+// Traverse the operands to find a candidate for buffer reuse.
+// Return -1 if no candidate is found.
+int whichBufferToReuse(ValueRange values, MemRefType outputType) {
+  for (size_t i = 0; i < values.size(); i++) {
+    if (isBufferReusable(values[i], outputType))
+      return i;
+  }
+  return -1;
+}
+
+// Allocate a memref (as before) if no input buffer can be reused.
+// The default VL = 0 selects non-SIMD allocation.
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL = 0);
+
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL) {
+
+  int indexToReuse = -1;
+  // By default, enableKrnlBufferReuse is false; simply allocate a memref.
+  if (enableKrnlBufferReuse) {
+    // Be aware to use op->getOperands() to check the number of uses: after
+    // buffer reuse, the number of uses of the transformed values in
+    // generatedOperands will increase.
+    indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
+  }
+
+  if (indexToReuse != -1) {
+    int size = getSizeInBytes(outputMemRefType);
+    LLVM_DEBUG({
+      llvm::dbgs() << "malloc_size " << size << "\n";
+      op->dump();
+    });
+    return generatedOperands[indexToReuse];
+  } else {
+    if (VL == 0)
+      return create.alignedAlloc(outputMemRefType, dims, alignment);
+    else
+      return create.alignedAllocWithSimdPadding(
+          outputMemRefType, dims, VL, alignment);
+  }
+}
+
 // =============================================================================
 
 /// Emit post-processing for variadic element-wise ops.
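
To make the reuse criterion above easier to see in isolation, here is a minimal, self-contained sketch of the same shape check without any MLIR types. It is an illustration only: the helper name shapesAllowReuse is hypothetical, the one-use check on the SSA value is omitted, and dynamic dimensions are assumed to be encoded as negative extents, mirroring how ShapedType reports them.

// Standalone sketch (hypothetical helper, not part of the patch): reuse is
// allowed only when both shapes are fully static and identical.
#include <cassert>
#include <cstdint>
#include <vector>

static bool shapesAllowReuse(
    const std::vector<int64_t> &in, const std::vector<int64_t> &out) {
  auto isStatic = [](const std::vector<int64_t> &shape) {
    for (int64_t d : shape)
      if (d < 0) // a negative extent stands in for a dynamic dimension
        return false;
    return true;
  };
  // Dynamic shapes are rejected, as in isBufferReusable above.
  if (!isStatic(in) || !isStatic(out))
    return false;
  // Same rank and same extent in every dimension.
  return in == out;
}

int main() {
  std::vector<int64_t> a = {2, 3}, b = {2, 3};
  std::vector<int64_t> dyn = {2, -1}, swapped = {3, 2};
  assert(shapesAllowReuse(a, b));        // identical static shapes: reuse
  assert(!shapesAllowReuse(dyn, b));     // dynamic input dim: no reuse
  assert(!shapesAllowReuse(a, swapped)); // mismatched extents: no reuse
  return 0;
}
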
@@ -1329,14 +1405,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Alloc memory with padding for SIMD.
+  // Reuse the buffer from the input, or alloc memory with padding for SIMD.
   // For the moment, it's ok to go here; if we truly have partial flattening of
   // the simd code, then we only do it with static memref sizes that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiples of VL *
   // unrollVL.
-  Value alloc = create.mem.alignedAllocWithSimdPadding(
-      outputMemRefType, outputDims, VL, alignment);
+  Value alloc = allocOrReuse(
+      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {
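
Note that the reuse path at this call site (and at the ones below) only fires when enableKrnlBufferReuse is true, which the code above leaves off by default. The flag's definition is not part of the hunks shown here; assuming it is exposed as an ordinary llvm::cl option (the option name and description below are guesses, not taken from the patch), its declaration would look roughly like this:

// Hypothetical sketch of the gating flag; the real name, default, and option
// category live outside the hunks shown in this diff.
#include "llvm/Support/CommandLine.h"

llvm::cl::opt<bool> enableKrnlBufferReuse("enable-krnl-buffer-reuse",
    llvm::cl::desc("Reuse an input MemRef as the output buffer of an "
                   "elementwise op when both shapes are static and identical"),
    llvm::cl::init(false));

With a flag like this, buffer reuse stays opt-in while the SIMD padding question flagged in the ToFix comment above remains open.
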
@@ -1981,8 +2057,9 @@ struct ONNXElementwiseUnaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {
@@ -2162,8 +2239,9 @@ struct ONNXElementwiseBinaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {
@@ -2337,8 +2415,9 @@ struct ONNXElementwiseVariadicOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
 
     // Only create krnl.iterate if one of the operands is not a scalar tensor.
     if (!isScalar) {