@@ -28,82 +28,6 @@ using namespace mlir;
 
 namespace onnx_mlir {
 
-// Check whether the input, x, can be reused as the output buffer.
-bool isBufferReusable(Value x, MemRefType outputType) {
-  if (!x.hasOneUse())
-    return false;
-
-  Type xType = x.getType();
-  auto inputType = dyn_cast<ShapedType>(xType);
-  if (!inputType)
-    return false;
-  // Currently, only static shapes can be reused.
-  // ToFix: use DimAnalysis to handle dynamic shapes.
-  if (!hasStaticShape(inputType))
-    return false;
-  if (!hasStaticShape(outputType))
-    return false;
-
-  // Currently, reuse requires that the shapes be the same.
-  // ToFix: if the shapes differ, memref.cast can be used.
-  if (getRank(inputType) != getRank(outputType))
-    return false;
-  for (int64_t i = 0; i < getRank(inputType); i++) {
-    if (inputType.getShape()[i] != outputType.getShape()[i])
-      return false;
-  }
-
-  // ToFix: the SIMD padding is not checked. We did not record whether the
-  // memref is padded or not; the padding could be recorded as an attribute
-  // on the memref, or may not be needed at all.
-  return true;
-}
-
-// Traverse the operands to find a candidate for buffer reuse.
-// Return -1 if no candidate is found.
-int whichBufferToReuse(ValueRange values, MemRefType outputType) {
-  for (size_t i = 0; i < values.size(); i++) {
-    if (isBufferReusable(values[i], outputType))
-      return i;
-  }
-  return -1;
-}
-
-// Allocate a memref (as before) if no input buffer can be reused.
-// The default VL = 0 is used for non-SIMD allocation.
-Value allocOrReuse(MemRefBuilder &create, Operation *op,
-    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
-    int64_t alignment, int64_t VL = 0);
-
-Value allocOrReuse(MemRefBuilder &create, Operation *op,
-    ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
-    int64_t alignment, int64_t VL) {
-
-  int indexToReuse = -1;
-  // By default, enableKrnlBufferReuse is false. Simply allocate a memref.
-  if (enableKrnlBufferReuse) {
-    // Take care to use op->getOperands() to check the number of uses:
-    // after buffer reuse, the number of uses of the transformed values,
-    // generatedOperands, will increase.
-    indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
-  }
-
-  if (indexToReuse != -1) {
-    int size = getSizeInBytes(outputMemRefType);
-    LLVM_DEBUG({
-      llvm::dbgs() << "malloc_size " << size << "\n";
-      op->dump();
-    });
-    return generatedOperands[indexToReuse];
-  } else {
-    if (VL == 0)
-      return create.alignedAlloc(outputMemRefType, dims, alignment);
-    else
-      return create.alignedAllocWithSimdPadding(
-          outputMemRefType, dims, VL, alignment);
-  }
-}
-
 // =============================================================================
 /// Emit post-processing for variadic element-wise ops.
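The deleted reuse test boils down to three conditions: the candidate value has exactly one use, both shapes are static, and rank and extents match exactly. A minimal self-contained sketch of that decision, using a plain vector of extents as a stand-in for MLIR's ShapedType (all names below are illustrative, not onnx-mlir or MLIR APIs):

#include <cstdint>
#include <vector>

// Illustrative stand-in for an MLIR shaped value (hypothetical type,
// not part of onnx-mlir).
constexpr int64_t kDynamicDim = -1; // marks a dynamic extent
struct ShapeInfo {
  std::vector<int64_t> dims;
  bool hasOneUse = true; // mimics mlir::Value::hasOneUse()
};

static bool hasStaticShape(const ShapeInfo &s) {
  for (int64_t d : s.dims)
    if (d == kDynamicDim)
      return false;
  return true;
}

// Mirrors the deleted isBufferReusable: single use, static shapes,
// identical rank and extents.
static bool isBufferReusable(const ShapeInfo &in, const ShapeInfo &out) {
  if (!in.hasOneUse)
    return false;
  if (!hasStaticShape(in) || !hasStaticShape(out))
    return false;
  return in.dims == out.dims; // same rank and same extent per dimension
}

The one-use requirement is what makes reuse safe: once the producer's buffer is handed to the consumer as its output, any remaining reader of the original value would observe overwritten data.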
@@ -1399,14 +1323,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
   IndexExprScope allocScope(create.vec, shapeHelper->getScope());
   DimsExpr outputDims;
   getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
-  // Reuse the buffer from the input, or alloc memory with padding for SIMD.
+  // Alloc memory with padding for SIMD.
   // For the moment, it's ok to go here; if we truly have partial flattening
   // of the simd code, then we only do it with static memref sizes that are
   // multiples of VL * unrollVL, so there should be no padding anyway. This
   // will change if we do partial flattening with non-multiples of VL *
   // unrollVL.
-  Value alloc = allocOrReuse(
-      create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
+  Value alloc = create.mem.alignedAllocWithSimdPadding(
+      outputMemRefType, outputDims, VL, alignment);
   // Create flat inputs in the last innerDinNum dims.
   llvm::SmallVector<Value, 4> flatOperands;
   for (Value oper : operands) {
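The comment in this hunk argues that static sizes which are multiples of VL * unrollVL incur no padding. The arithmetic behind that claim, as a rough model of the padding policy (an assumption for illustration, not alignedAllocWithSimdPadding's actual implementation):

#include <cstdint>

// Rough model (assumption): round the flattened element count up to a
// multiple of the vector length VL, so that a full vector access at the
// tail of the buffer stays within the allocation.
static int64_t paddedElementCount(int64_t totalElements, int64_t VL) {
  return ((totalElements + VL - 1) / VL) * VL;
}

For example, 100 elements at VL = 8 pad to 104, while 96 elements stay at 96; any size already a multiple of VL (and a fortiori of VL * unrollVL) is unchanged, which is why the static case described above needs no padding.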
@@ -2051,9 +1975,8 @@ struct ONNXElementwiseUnaryOpLowering
   outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
   // Insert an allocation for the result of this operation.
-  Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-      shapeHelper.getOutputDims(), alignment);
-  ;
+  Value alloc = create.mem.alignedAlloc(
+      outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
   // Only create krnl.iterate if one of the operands is not a scalar tensor.
   if (!isScalar) {
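This hunk and the two identical ones below revert to the plain aligned allocation that allocOrReuse previously used as its fallback. As a standard-C++ model of what an aligned allocation provides (illustrative only; the real lowering emits a memref allocation carrying an alignment attribute):

#include <cstdlib>

// Standard-C++ sketch of an aligned allocation (not the Krnl lowering
// itself). std::aligned_alloc requires the byte size to be a multiple of
// the alignment, so round it up first.
static void *allocAligned(std::size_t bytes, std::size_t alignment) {
  std::size_t rounded = ((bytes + alignment - 1) / alignment) * alignment;
  return std::aligned_alloc(alignment, rounded);
}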
@@ -2233,9 +2156,8 @@ struct ONNXElementwiseBinaryOpLowering
   outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
   // Insert an allocation and deallocation for the result of this operation.
-  Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-      shapeHelper.getOutputDims(), alignment);
-  ;
+  Value alloc = create.mem.alignedAlloc(
+      outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
   // Only create krnl.iterate if one of the operands is not a scalar tensor.
   if (!isScalar) {
@@ -2409,9 +2331,8 @@ struct ONNXElementwiseVariadicOpLowering
   outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
   // Insert an allocation and deallocation for the result of this operation.
-  Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
-      shapeHelper.getOutputDims(), alignment);
-  ;
+  Value alloc = create.mem.alignedAlloc(
+      outputMemRefType, shapeHelper.getOutputDims(), alignment);
 
   // Only create krnl.iterate if one of the operands is not a scalar tensor.
   if (!isScalar) {