@@ -28,6 +28,82 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
+// Check whether the input value x can be reused as the output buffer.
+bool isBufferReusable(Value x, MemRefType outputType) {
+  if (!x.hasOneUse())
+    return false;
+
+  Type xType = x.getType();
+  auto inputType = dyn_cast<ShapedType>(xType);
+  if (!inputType)
+    return false;
+  // Currently, only static shapes can be reused.
+  // ToFix: use DimAnalysis to handle dynamic shapes.
+  if (!hasStaticShape(inputType))
+    return false;
+  if (!hasStaticShape(outputType))
+    return false;
+
+  // Currently, reuse requires that the shapes be identical.
+  // ToFix: if the shapes differ, memref.cast could be used.
+  if (getRank(inputType) != getRank(outputType))
+    return false;
+  for (int64_t i = 0; i < getRank(inputType); i++) {
+    if (inputType.getShape()[i] != outputType.getShape()[i])
+      return false;
+  }
+
+  // ToFix: SIMD padding is not checked.
+  // We do not record whether a memref has been padded or not.
+  // The padding could be an attribute on the memref, or may not be needed.
+  return true;
+}
+
+// Traverse the operands to find a candidate for buffer reuse.
+// Return -1 if no candidate is found.
+int whichBufferToReuse(ValueRange values, MemRefType outputType) {
+  for (size_t i = 0; i < values.size(); i++) {
+    if (isBufferReusable(values[i], outputType))
+      return i;
+  }
+  return -1;
+}
+
+// Allocate a memref (as before) if no input buffer can be reused.
+// The default VL = 0 is used for non-SIMD allocation.
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL = 0);
+
+Value allocOrReuse(MemRefBuilder &create, Operation *op,
+    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
+    int64_t alignment, int64_t VL) {
+
+  int indexToReuse = -1;
+  // By default, enableKrnlBufferReuse is false. Simply allocate a memref.
+  if (enableKrnlBufferReuse) {
+    // Be sure to check the number of uses on op->getOperands(): after buffer
+    // reuse, the number of uses of the transformed values in generatedOperands
+    // will have increased.
+    indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
+  }
+
+  if (indexToReuse != -1) {
+    int64_t size = getSizeInBytes(outputMemRefType);
+    LLVM_DEBUG({
+      llvm::dbgs() << "malloc_size " << size << "\n";
+      op->dump();
+    });
+    return generatedOperands[indexToReuse];
+  } else {
+    if (VL == 0)
+      return create.alignedAlloc(outputMemRefType, dims, alignment);
+    else
+      return create.alignedAllocWithSimdPadding(
+          outputMemRefType, dims, VL, alignment);
+  }
+}
+
 
 // =============================================================================

 /// Emit post-processing for variadic element-wise ops.
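Note on the second ToFix in isBufferReusable: when the input and output types differ only in their static/dynamic dimension annotations (for example, once DimAnalysis proves the runtime shapes equal), the reused buffer could be reconciled with the expected output type by a memref.cast at the reuse site. A minimal sketch, not part of this patch, assuming an OpBuilder named builder and a Location named loc are available where the reused value is returned:

// Hypothetical follow-up, not in this patch: cast the reused input to the
// expected output type when the two memref types are cast-compatible
// (same rank and element type, each dim equal or dynamic on one side).
// Requires mlir/Dialect/MemRef/IR/MemRef.h for memref::CastOp.
Value reused = generatedOperands[indexToReuse];
if (reused.getType() != outputMemRefType)
  reused = builder.create<memref::CastOp>(loc, outputMemRefType, reused);
return reused;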
@@ -1323,14 +1399,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Alloc memory with padding for SIMD.
+  // Reuse a buffer from the inputs, or alloc memory with padding for SIMD.
   // For the moment, its ok to go here; if we truly have partial flattening of
   // the simd code, then we only do it with static memref size that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiple of VL *
   // unrollVL.
-  Value alloc = create.mem.alignedAllocWithSimdPadding(
-      outputMemRefType, outputDims, VL, alignment);
+  Value alloc = allocOrReuse(
+      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {
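The SIMD path above relies on the assumption stated in the comment: with partial flattening restricted to static sizes that are multiples of VL * unrollVL, no SIMD padding is needed, so reusing an unpadded input buffer is safe. If that restriction is ever relaxed, allocOrReuse would need a padding check before reusing a buffer. A rough sketch of such a guard follows, under the simplified assumption that padding is only required when the statically known innermost extent is not a multiple of VL; the helper name and the criterion are illustrative only and may not match the exact alignedAllocWithSimdPadding policy:

// Illustrative only: a conservative guard that could precede buffer reuse on
// the SIMD path. It assumes padding is needed exactly when the innermost
// static extent is not a multiple of VL; the real padding policy may differ.
static bool mayNeedSimdPadding(MemRefType type, int64_t VL) {
  if (VL <= 1 || type.getRank() == 0)
    return false;
  int64_t innermost = type.getShape()[type.getRank() - 1];
  return ShapedType::isDynamic(innermost) || innermost % VL != 0;
}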
@@ -1975,8 +2051,9 @@ struct ONNXElementwiseUnaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
@@ -2156,8 +2233,9 @@ struct ONNXElementwiseBinaryOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
@@ -2331,8 +2409,9 @@ struct ONNXElementwiseVariadicOpLowering
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

     // Insert an allocation and deallocation for the result of this operation.
-    Value alloc = create.mem.alignedAlloc(
-        outputMemRefType, shapeHelper.getOutputDims(), alignment);
+    Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
+        shapeHelper.getOutputDims(), alignment);
+    ;

     // Only create krnl.iterate if one of the operands is not scalar tensor.
     if (!isScalar) {
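All of the element-wise lowerings above now share the same allocation idiom: let allocOrReuse decide between handing back a reusable input buffer (only when enableKrnlBufferReuse is set and an operand qualifies) and performing the usual aligned allocation. For reference, a sketch of the two call shapes, where create, op, operands, outputMemRefType, outputDims, shapeHelper, alignment, and VL are assumed to come from the surrounding lowering pattern as in the hunks above:

// Scalar/non-SIMD path: VL defaults to 0, so the fallback is alignedAlloc.
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
    shapeHelper.getOutputDims(), alignment);
// SIMD path: pass VL so the fallback uses alignedAllocWithSimdPadding.
Value simdAlloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
    outputDims, alignment, VL);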