Skip to content

Commit 0b4657b

Browse files
[NNPA] Add a compiler option for saturation stickify inputs (#2879)
Signed-off-by: Tung D. Le <[email protected]> Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent c40f7ec commit 0b4657b

28 files changed

+326
-225
lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,9 @@ llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic{
9191
"Much/Significantly FasterOps with stick/unstick cost")),
9292
llvm::cl::init(QualifyingOps), llvm::cl::cat(OnnxMlirOptions)};
9393

94+
llvm::cl::opt<bool> nnpaEnableSaturation("nnpa-saturation",
95+
llvm::cl::desc("Enable saturating f32 values before stickify them."
96+
"Default is false."),
97+
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
98+
9499
} // namespace onnx_mlir

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,6 @@ extern llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic;
6666
extern llvm::cl::opt<bool> profileZHighIR;
6767
extern llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile;
6868
extern llvm::cl::opt<std::string> nnpaSaveDevicePlacementFile;
69+
extern llvm::cl::opt<bool> nnpaEnableSaturation;
6970

7071
} // namespace onnx_mlir

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp"
1717
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
1818
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
19+
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
1920
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
2021
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
2122
#include "src/Dialect/ONNX/DialectBuilder.hpp"

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td

+133-128
Large diffs are not rendered by default.

src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
2828
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
2929
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
30+
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
3031
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
3132
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
3233
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td

+4-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef OP_BASE
1818
include "src/Dialect/ONNX/ONNX.td"
1919
include "src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td"
20+
include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td"
2021
include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
2122
#endif // OP_BASE
2223

@@ -67,9 +68,9 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
6768
// Calculate BatchNorm Op using $A and $B
6869
(ZHighUnstickOp
6970
(ZHighBatchNormOp
70-
(ZHighStickOp $x, (NHWCLayoutAttr)),
71-
(ZHighStickOp $A, (_1DLayoutAttr)),
72-
(ZHighStickOp $B, (_1DLayoutAttr))))
71+
(ZHighStickOp $x, (NHWCLayoutAttr), (GetDefaultSaturation)),
72+
(ZHighStickOp $A, (_1DLayoutAttr), (GetDefaultSaturation)),
73+
(ZHighStickOp $B, (_1DLayoutAttr), (GetDefaultSaturation))))
7374
]
7475
>;
7576

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td

+11-11
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
3838
// (ZHighStickOp %Y))
3939
//===----------------------------------------------------------------------===//
4040
def replaceZHighAddPattern1 : Pat<
41-
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
41+
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
4242
(ONNXAddOp $x, (ZHighUnstickOp $y)),
4343
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
4444
>;
4545

4646
def replaceZHighAddPattern2 : Pat<
47-
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
47+
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
4848
(ONNXAddOp (ZHighUnstickOp $x), $y),
4949
[(NotBlockArgument:$y), (HasOneUse:$s_y)]
5050
>;
@@ -54,14 +54,14 @@ def replaceZHighAddPattern2 : Pat<
5454
// (ZHighStickOp %Y))
5555
//===----------------------------------------------------------------------===//
5656
def replaceZHighMulPattern1 : Pat<
57-
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
57+
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
5858
(ONNXMulOp $x, (ZHighUnstickOp $y)),
5959
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
6060
(addBenefit 1)
6161
>;
6262

6363
def replaceZHighMulPattern2 : Pat<
64-
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
64+
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
6565
(ONNXMulOp (ZHighUnstickOp $x), $y),
6666
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
6767
(addBenefit 0)
@@ -72,14 +72,14 @@ def replaceZHighMulPattern2 : Pat<
7272
// (ZHighStickOp %Y))
7373
//===----------------------------------------------------------------------===//
7474
def replaceZHighSubPattern1 : Pat<
75-
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
75+
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
7676
(ONNXSubOp $x, (ZHighUnstickOp $y)),
7777
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
7878
(addBenefit 1)
7979
>;
8080

8181
def replaceZHighSubPattern2 : Pat<
82-
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
82+
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
8383
(ONNXSubOp (ZHighUnstickOp $x), $y),
8484
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
8585
(addBenefit 0)
@@ -109,14 +109,14 @@ def replaceZHighSubPattern2 : Pat<
109109
// (ZHighStickOp %Y))
110110
//===----------------------------------------------------------------------===//
111111
def replaceZHighMinPattern1 : Pat<
112-
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
112+
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
113113
(CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
114114
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
115115
(addBenefit 1)
116116
>;
117117

118118
def replaceZHighMinPattern2 : Pat<
119-
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
119+
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
120120
(CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
121121
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
122122
(addBenefit 0)
@@ -127,14 +127,14 @@ def replaceZHighMinPattern2 : Pat<
127127
// (ZHighStickOp %Y))
128128
//===----------------------------------------------------------------------===//
129129
def replaceZHighMaxPattern1 : Pat<
130-
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
130+
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
131131
(CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
132132
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
133133
(addBenefit 1)
134134
>;
135135

136136
def replaceZHighMaxPattern2 : Pat<
137-
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
137+
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
138138
(CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
139139
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
140140
(addBenefit 0)
@@ -144,7 +144,7 @@ def replaceZHighMaxPattern2 : Pat<
144144
// ONNXReluOp %X = ZHighUnstickOp (ZHighReluOp (ZHighStickOp %X))
145145
//===----------------------------------------------------------------------===//
146146
def replaceZHighReluPattern : Pat<
147-
(ZHighUnstickOp (ZHighReluOp (ZHighStickOp:$s_x $x, $_))),
147+
(ZHighUnstickOp (ZHighReluOp (ZHighStickOp:$s_x $x, $_, $_))),
148148
(ONNXReluOp $x),
149149
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
150150
>;

src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,12 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern {
500500
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
501501
ConversionPatternRewriter &rewriter) const final {
502502
Location loc = op->getLoc();
503+
ZHighStickOp stickOp = cast<ZHighStickOp>(op);
503504

504505
ZHighStickOpAdaptor operandAdaptor(operands);
505506
Value input = operandAdaptor.getIn();
506-
StringAttr layout = cast<ZHighStickOp>(op).getLayoutAttr();
507+
StringAttr layout = stickOp.getLayoutAttr();
508+
IntegerAttr saturation = stickOp.getSaturationAttr();
507509

508510
IndexExprBuilderForKrnl createKrnlIE(rewriter, loc);
509511
ZHighStickOpShapeHelper shapeHelper(op, operands, &createKrnlIE);
@@ -521,7 +523,7 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern {
521523
layout = getNCHWLayoutAttr(rewriter);
522524

523525
// Else, emit a ZLow operation.
524-
rewriter.create<ZLowStickOp>(loc, input, alloc, layout);
526+
rewriter.create<ZLowStickOp>(loc, input, alloc, layout, saturation);
525527
rewriter.replaceOp(op, alloc);
526528
return success();
527529
}

src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ add_onnx_mlir_library(OMZHighOps
4747
OMONNXOps # Use ONNXShapeHelper
4848
OMLayoutHelper
4949
OMShapeHelperOpInterface
50+
OMNNPACompilerOptions
5051
MLIRIR
5152

5253
ACCEL_INCLUDE_DIRS PRIVATE

src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td

+22-4
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,26 @@ def ZHighStickOp:ZHigh_Op<"Stick", [Pure,
142142
let summary = "ZHigh Stick operation";
143143
let description = [{
144144
ZHigh operation to perform a Stick."
145+
145146
If `layout`=`NHWC`, input must be in `NCHW` and output will be in `NHWC`.
147+
148+
Optional `saturation` indicates whether the CPU tensor is saturated before stickification
149+
or not. If it is saturated, the dlfloat16 range would be used.
150+
Saturation if off if `saturation == 0` or it is not given. Otherwise, it is on.
146151
}];
147152
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, NoneType]>:$In,
148-
OptionalAttr<StrAttr>:$layout);
153+
OptionalAttr<StrAttr>:$layout,
154+
OptionalAttr<SI64Attr>:$saturation);
149155
let results = (outs AnyTypeOf<[ZTensor_1D, ZTensor_2D, ZTensor_3D, ZTensor_4D,
150156
ZTensor_2DS, ZTensor_3DS, ZTensor_4DS,
151157
ZTensor_NHWC, ZTensor_NCHW, ZTensor_HWCK,
152158
NoneType]>:$Out);
153159
let builders = [
154-
OpBuilder<(ins "::mlir::Value":$In, "::mlir::StringAttr":$layout)>
160+
OpBuilder<(ins "::mlir::Value":$In, "::mlir::StringAttr":$layout), [{
161+
build($_builder, $_state, In, layout, IntegerAttr());
162+
}]>,
163+
OpBuilder<(ins "::mlir::Value":$In, "::mlir::StringAttr":$layout,
164+
"::mlir::IntegerAttr":$saturation)>
155165
];
156166
let hasCanonicalizer = 1;
157167
let extraClassDefinition = [{
@@ -195,11 +205,19 @@ def ZHighF32ToDLF16Op:ZHigh_Op<"F32ToDLF16", [Pure,
195205
let summary = "ZHigh F32ToDLF16 operation";
196206
let description = [{
197207
ZHigh operation to convert a tensor of f32 to a tensor of dlfloat16.
208+
209+
Optional `saturation` indicates whether the CPU tensor is saturated before stickification
210+
or not. If it is saturated, the dlfloat16 range would be used.
211+
Saturation if off if `saturation == 0` or it is not given. Otherwise, it is on.
198212
}];
199-
let arguments = (ins TensorOf<[F32]>: $In);
213+
let arguments = (ins TensorOf<[F32]>: $In,
214+
OptionalAttr<SI64Attr>:$saturation);
200215
let results = (outs TensorOf<[F16]>:$Out);
201216
let builders = [
202-
OpBuilder<(ins "::mlir::Value":$In)>,
217+
OpBuilder<(ins "::mlir::Value":$In), [{
218+
build($_builder, $_state, In, IntegerAttr());
219+
}]>,
220+
OpBuilder<(ins "::mlir::Value":$In, "::mlir::IntegerAttr":$saturation)>
203221
];
204222
let hasCanonicalizer = 1;
205223
let extraClassDefinition = [{

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/ZHighDLF16ToF32.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def GetTypeInDLF16: NativeCodeCall<
4040

4141
// zhigh.DLF16ToF32 (zhigh.F32ToDLF16(%X)) = %X
4242
def ConversionRemovalPattern : Pat<
43-
(ZHighDLF16ToF32Op (ZHighF32ToDLF16Op $arg)),
43+
(ZHighDLF16ToF32Op (ZHighF32ToDLF16Op $arg, $saturation)),
4444
(replaceWithValue $arg)
4545
>;
4646

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/F32ToDLF16.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ namespace zhigh {
2828
// Custom builders
2929
//===----------------------------------------------------------------------===//
3030

31-
void ZHighF32ToDLF16Op::build(
32-
OpBuilder &builder, OperationState &state, Value input) {
31+
void ZHighF32ToDLF16Op::build(OpBuilder &builder, OperationState &state,
32+
Value input, IntegerAttr saturation) {
3333
Type elementType = builder.getF16Type();
3434
Type resType = UnrankedTensorType::get(elementType);
3535

3636
if (auto inType = dyn_cast<RankedTensorType>(input.getType()))
3737
resType = RankedTensorType::get(inType.getShape(), elementType);
3838

39-
build(builder, state, resType, input);
39+
build(builder, state, resType, input, saturation);
4040
}
4141

4242
//===----------------------------------------------------------------------===//

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/ZHighF32ToDLF16.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td"
3636

3737
// zhigh.F32ToDLF16 (zhigh.DLF16ToF32(%X)) = %X
3838
def ConversionRemovalPattern : Pat<
39-
(ZHighF32ToDLF16Op (ZHighDLF16ToF32Op $arg)),
39+
(ZHighF32ToDLF16Op (ZHighDLF16ToF32Op $arg), $saturation),
4040
(replaceWithValue $arg)
4141
>;
4242

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
14+
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
1415
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
1516
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
1617

@@ -472,5 +473,14 @@ bool hasNNPAUse(Value v) {
472473
});
473474
}
474475

476+
/// Get default saturation setting.
477+
IntegerAttr getDefaultSaturation(PatternRewriter &rewriter) {
478+
Type si64Ty = rewriter.getIntegerType(64, true);
479+
if (nnpaEnableSaturation)
480+
return rewriter.getIntegerAttr(si64Ty, -1);
481+
else
482+
return IntegerAttr();
483+
}
484+
475485
} // namespace zhigh
476486
} // namespace onnx_mlir

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,8 @@ mlir::IntegerAttr getAxisNHWC(mlir::IntegerAttr axisNCHWAttr);
8484
/// Check if the value has NNPA users (or is consumed by an NNPA op).
8585
bool hasNNPAUse(mlir::Value v);
8686

87+
/// Get saturation settings.
88+
mlir::IntegerAttr getDefaultSaturation(mlir::PatternRewriter &rewriter);
89+
8790
} // namespace zhigh
8891
} // namespace onnx_mlir

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td

+6
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,10 @@ def GetAxisNHWC : NativeCodeCall<
195195
"::onnx_mlir::zhigh::getAxisNHWC($0)"
196196
>;
197197

198+
def NoneIntegerAttr: NativeCodeCall<"IntegerAttr()">;
199+
200+
def GetDefaultSaturation : NativeCodeCall<
201+
"::onnx_mlir::zhigh::getDefaultSaturation($_builder)"
202+
>;
203+
198204
#endif // OP_HELPER

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ namespace zhigh {
2828
// Custom builders
2929
//===----------------------------------------------------------------------===//
3030

31-
void ZHighStickOp::build(
32-
OpBuilder &builder, OperationState &state, Value input, StringAttr layout) {
31+
void ZHighStickOp::build(OpBuilder &builder, OperationState &state, Value input,
32+
StringAttr layout, IntegerAttr saturation) {
3333
Type resType = builder.getNoneType();
3434
Type resElementType = builder.getF16Type();
3535
if (!mlir::isa<NoneType>(input.getType())) {
@@ -63,7 +63,7 @@ void ZHighStickOp::build(
6363
resType = UnrankedTensorType::get(resElementType);
6464
}
6565
}
66-
build(builder, state, resType, input, layout);
66+
build(builder, state, resType, input, layout, saturation);
6767
}
6868

6969
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)