Skip to content

Commit bd070ea

Browse files
Decompose Hardswish into simpler ONNX ops (#3107)
* Decompose and lower Hardswish Signed-off-by: Kumarappan <[email protected]> * Providing the decomposition as compile time option with krnl dialect lowering as default Signed-off-by: Kumarappan <[email protected]> --------- Signed-off-by: Kumarappan <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent 25b67e6 commit bd070ea

File tree

10 files changed

+248
-7
lines changed

10 files changed

+248
-7
lines changed

docs/SupportedONNXOps-cpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 22. Limitatio
9292
| **HammingWindow** |none | | | |
9393
| **HannWindow** |none | | | |
9494
| **HardSigmoid** |6 - * | | |
95-
| **HardSwish** |none | | | |
95+
| **HardSwish** |14 - * | | | |
9696
| **Hardmax** |6 - * | | |
9797
| **Identity** |16 - * |Sequence identity not supported. Does not support int4 and uint4. | |
9898
| **If** |16 - * |Sequence and Optional outputs are not supported. Does not support int4 and uint4. | |

src/Compiler/CompilerOptions.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ bool enableKrnlBufferReuse; // common for both
4848
bool enableSafeCodeGen; // common for both
4949
bool disableMemRefPrefetch; // common for both
5050
uint64_t compilationNumThreads; // common for both
51+
std::vector<std::string> decomposeOpsInONNX; // common for both
5152
EmissionTargetType emissionTarget; // onnx-mlir only
5253
bool invokeOnnxVersionConverter; // onnx-mlir only
5354
bool preserveLocations; // onnx-mlir only
@@ -264,6 +265,15 @@ static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
264265
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
265266
llvm::cl::cat(OnnxMlirCommonOptions));
266267

268+
static llvm::cl::list<std::string, std::vector<std::string>>
269+
decomposeOpsInONNXOpt("decompose-op-in-onnx",
270+
llvm::cl::desc("Specify ONNX operations to decompose.\n"
271+
"Supported Ops - HardSwish"),
272+
llvm::cl::value_desc("ONNX operation to decompose"),
273+
llvm::cl::location(decomposeOpsInONNX),
274+
llvm::cl::cat(OnnxMlirCommonOptions), llvm::cl::CommaSeparated,
275+
llvm::cl::ZeroOrMore);
276+
267277
static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
268278
llvm::cl::desc("Disable recomposition of ONNX operations."),
269279
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),

src/Compiler/CompilerOptions.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ extern bool enableKrnlBufferReuse; // common for both
9494
extern bool enableSafeCodeGen; // common for both
9595
extern bool disableMemRefPrefetch; // common for both
9696
extern uint64_t compilationNumThreads; // common for both
97+
extern std::vector<std::string> decomposeOpsInONNX; // common for both
9798
extern EmissionTargetType emissionTarget; // onnx-mlir only
9899
extern bool invokeOnnxVersionConverter; // onnx-mlir only
99100
extern bool preserveLocations; // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,51 @@ Value emitScalarOpFor<ONNXHardSigmoidOp>(ConversionPatternRewriter &rewriter,
646646
return clipHighest;
647647
}
648648

649+
//===----------------------------------------------------------------------===//
650+
// Scalar unary ops for lowering ONNXHardSwishOp
651+
//===----------------------------------------------------------------------===//
652+
template <>
653+
struct ScalarOp<ONNXHardSwishOp> {
654+
using FOp = CustomScalarOp;
655+
using IOp = NotSuportedScalarOp;
656+
};
657+
658+
template <>
659+
GenOpMix getGenOpMix<ONNXHardSwishOp>(Type t, Operation *op) {
660+
return {{GenericOps::ArithmeticGop, 3}, {GenericOps::MulGop, 2}};
661+
}
662+
663+
template <>
664+
Value emitScalarOpFor<ONNXHardSwishOp>(ConversionPatternRewriter &rewriter,
665+
Location loc, Operation *op, Type elementType,
666+
ArrayRef<Value> scalarOperands) {
667+
// HardSwish(x) = x * max(0, min(1, (x / 6) + 0.5))
668+
CheckIfCustomScalarOpIsSupported<ONNXHardSwishOp>(elementType);
669+
Value operand = scalarOperands[0];
670+
671+
// Define constants: alpha = 1/6, beta = 0.5
672+
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
673+
Value zero = create.math.constant(elementType, 0);
674+
Value one = create.math.constant(elementType, 1);
675+
Value alpha = create.math.constant(elementType, 1.0 / 6.0);
676+
Value beta = create.math.constant(elementType, 0.5);
677+
678+
// Compute (x / 6) + 0.5
679+
Value scaledX = create.math.mul(operand, alpha);
680+
Value shiftedX = create.math.add(scaledX, beta);
681+
682+
// Apply min(1, shiftedX)
683+
Value minOp = create.math.min(shiftedX, one);
684+
685+
// Apply max(0, minOp)
686+
Value maxOp = create.math.max(minOp, zero);
687+
688+
// Compute final HardSwish: x * max(0, min(1, (x / 6) + 0.5))
689+
Value result = create.math.mul(operand, maxOp);
690+
691+
return result;
692+
}
693+
649694
//===----------------------------------------------------------------------===//
650695
// Scalar unary ops for lowering ONNXEluOp
651696
//===----------------------------------------------------------------------===//
@@ -1714,12 +1759,12 @@ bool OpFusionHelper::checkFusibleOp(Operation *useOp, Operation *defOp,
17141759
mlir::ONNXEluOp, mlir::ONNXErfOp, mlir::ONNXAcosOp, mlir::ONNXAcoshOp,
17151760
mlir::ONNXAsinOp, mlir::ONNXAsinhOp, mlir::ONNXAtanhOp, mlir::ONNXExpOp,
17161761
mlir::ONNXFloorOp, mlir::ONNXGeluOp, mlir::ONNXHardSigmoidOp,
1717-
mlir::ONNXIsInfOp, mlir::ONNXIsNaNOp, mlir::ONNXLeakyReluOp,
1718-
mlir::ONNXLogOp, mlir::ONNXNegOp, mlir::ONNXNotOp, mlir::ONNXReciprocalOp,
1719-
mlir::ONNXReluOp, mlir::ONNXRoundOp, mlir::ONNXSeluOp,
1720-
mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp, mlir::ONNXSinhOp,
1721-
mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp, mlir::ONNXSqrtOp,
1722-
mlir::ONNXTanOp, mlir::ONNXTanhOp,
1762+
mlir::ONNXHardSwishOp, mlir::ONNXIsInfOp, mlir::ONNXIsNaNOp,
1763+
mlir::ONNXLeakyReluOp, mlir::ONNXLogOp, mlir::ONNXNegOp, mlir::ONNXNotOp,
1764+
mlir::ONNXReciprocalOp, mlir::ONNXReluOp, mlir::ONNXRoundOp,
1765+
mlir::ONNXSeluOp, mlir::ONNXSigmoidOp, mlir::ONNXSignOp, mlir::ONNXSinOp,
1766+
mlir::ONNXSinhOp, mlir::ONNXSoftplusOp, mlir::ONNXSoftsignOp,
1767+
mlir::ONNXSqrtOp, mlir::ONNXTanOp, mlir::ONNXTanhOp,
17231768
// Binary Op
17241769
mlir::ONNXEqualOp, mlir::ONNXGreaterOp, mlir::ONNXGreaterOrEqualOp,
17251770
mlir::ONNXLessOp, mlir::ONNXLessOrEqualOp, mlir::ONNXModOp,
@@ -2674,6 +2719,7 @@ void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
26742719
ONNXElementwiseBinaryOpLowering<mlir::ONNXGreaterOp>,
26752720
ONNXElementwiseBinaryOpLowering<mlir::ONNXGreaterOrEqualOp>,
26762721
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
2722+
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSwishOp>,
26772723
ONNXElementwiseUnaryOpLowering<mlir::ONNXIsInfOp>,
26782724
ONNXElementwiseUnaryOpLowering<mlir::ONNXIsNaNOp>,
26792725
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ Value OnnxBuilder::constantInt64(const ArrayRef<int64_t> intVals) const {
116116
return constant(denseAttr);
117117
}
118118

119+
Value OnnxBuilder::constantFloat32(const ArrayRef<float> floatVals) const {
120+
auto shape = RankedTensorType::get(
121+
{static_cast<int64_t>(floatVals.size())}, b().getF32Type());
122+
DenseElementsAttr denseAttr = DenseElementsAttr::get(shape, floatVals);
123+
return constant(denseAttr);
124+
}
125+
119126
Value OnnxBuilder::conv(Type Y, Value X, Value W, Value B, StringRef autoPad,
120127
ArrayRef<int64_t> dilations, int64_t group, ArrayRef<int64_t> kernelShape,
121128
ArrayRef<int64_t> pads, ArrayRef<int64_t> strides) const {

src/Dialect/ONNX/DialectBuilder.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct OnnxBuilder : DialectBuilder {
7070
// ONNXConstantOp
7171
mlir::Value constant(mlir::Attribute denseAttr) const;
7272
mlir::Value constantInt64(const mlir::ArrayRef<int64_t> intVals) const;
73+
mlir::Value constantFloat32(const mlir::ArrayRef<float> floatVals) const;
7374

7475
// ONNXConvOp
7576
mlir::Value conv(mlir::Type Y, mlir::Value X, mlir::Value W, mlir::Value B,

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/IR/PatternMatch.h"
2727
#include "mlir/Pass/Pass.h"
2828
#include "mlir/Transforms/DialectConversion.h"
29+
#include "src/Compiler/CompilerOptions.hpp"
2930
#include "llvm/Support/Debug.h"
3031

3132
#include "src/Dialect/ONNX/DialectBuilder.hpp"
@@ -1288,6 +1289,63 @@ class ReplaceCastLikeByCastPattern : public OpRewritePattern<ONNXCastLikeOp> {
12881289
}
12891290
};
12901291

1292+
// =============================================================================
1293+
// Decompose Hardswish to simpler ONNX ops
1294+
// =============================================================================
1295+
// DecomposeHardSwishPattern replaces ONNXHardSwishOp with its equivalent
1296+
// mathematical decomposition using basic ONNX operations:
1297+
//
1298+
// HardSwish(x) = x * max(0, min(1, (x / 6) + 0.5))
1299+
//
1300+
// This pass:
1301+
// - Multiplies input by `1/6`
1302+
// - Adds `0.5` to the scaled input
1303+
// - Clamps the result between `0` and `1` using Min and Max ops
1304+
// - Multiplies the clamped value with the original input
1305+
1306+
struct DecomposeHardSwishPattern : public OpRewritePattern<ONNXHardSwishOp> {
1307+
using OpRewritePattern<ONNXHardSwishOp>::OpRewritePattern;
1308+
1309+
LogicalResult matchAndRewrite(
1310+
ONNXHardSwishOp hardswishOp, PatternRewriter &rewriter) const final {
1311+
1312+
// Get location and element type
1313+
Location loc = hardswishOp.getLoc();
1314+
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
1315+
rewriter, loc);
1316+
1317+
Value alphaConst = create.onnx.constantFloat32(1.0f / 6.0f);
1318+
Value betaConst = create.onnx.constantFloat32(0.5f);
1319+
Value minConst = create.onnx.constantFloat32(1.0f);
1320+
Value maxConst = create.onnx.constantFloat32(0.0f);
1321+
1322+
// Multiply input by alpha
1323+
auto scaledInput =
1324+
rewriter.create<ONNXMulOp>(loc, hardswishOp.getOperand().getType(),
1325+
hardswishOp.getOperand(), alphaConst);
1326+
1327+
// Add beta to (input * alpha)
1328+
auto shiftedInput = rewriter.create<ONNXAddOp>(
1329+
loc, scaledInput.getType(), scaledInput, betaConst);
1330+
1331+
// Compute min(1.0, shiftedInput)
1332+
auto minOp = rewriter.create<ONNXMinOp>(
1333+
loc, shiftedInput.getType(), ValueRange({shiftedInput, minConst}));
1334+
1335+
// Compute max(0, min(1, shiftedInput))
1336+
auto maxOp = rewriter.create<ONNXMaxOp>(
1337+
loc, minOp.getType(), ValueRange({minOp, maxConst}));
1338+
1339+
// Compute final HardSwish: input * max(0, min(1, add(mul(x, alpha), beta)))
1340+
auto hardswishResult = rewriter.create<ONNXMulOp>(loc,
1341+
hardswishOp.getOperand().getType(), hardswishOp.getOperand(), maxOp);
1342+
1343+
// Replace the original HardSwishOp with the new computation
1344+
rewriter.replaceOp(hardswishOp, hardswishResult.getResult());
1345+
return success();
1346+
}
1347+
};
1348+
12911349
struct DecomposeONNXToONNXPass
12921350
: public PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
12931351
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeONNXToONNXPass)
@@ -1364,6 +1422,13 @@ void DecomposeONNXToONNXPass::runOnOperation() {
13641422
target.addIllegalOp<ONNXUpsampleOp>();
13651423
target.addIllegalOp<ONNXUpsampleV7Op>();
13661424

1425+
if (!onnx_mlir::decomposeOpsInONNX.empty()) {
1426+
for (const auto &op : onnx_mlir::decomposeOpsInONNX) {
1427+
if (op == "HardSwish") {
1428+
target.addIllegalOp<ONNXHardSwishOp>();
1429+
}
1430+
}
1431+
}
13671432
target.addDynamicallyLegalOp<ONNXEinsumOp>([](ONNXEinsumOp op) {
13681433
return !onnx_mlir::DecomposeEinsumPattern::isDecomposable(op);
13691434
});
@@ -1439,6 +1504,14 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
14391504
patterns.insert<SoftmaxCrossEntropyPattern>(context);
14401505
patterns.insert<SumToAddPattern>(context);
14411506

1507+
if (!onnx_mlir::decomposeOpsInONNX.empty()) {
1508+
for (const auto &op : onnx_mlir::decomposeOpsInONNX) {
1509+
if (op == "HardSwish") {
1510+
patterns.insert<DecomposeHardSwishPattern>(context);
1511+
}
1512+
}
1513+
}
1514+
14421515
// TODO: consider whether to include SoftmaxPattern here
14431516
}
14441517

test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,40 @@ func.func private @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
12471247

12481248
// -----
12491249

1250+
func.func private @test_hardswish(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
1251+
%0 = "onnx.HardSwish"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
1252+
"func.return"(%0) : (tensor<*xf32>) -> ()
1253+
1254+
// mlir2FileCheck.py
1255+
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
1256+
// CHECK-LABEL: func.func private @test_hardswish
1257+
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
1258+
// CHECK-DAG: [[CST_HALF_:%.+]] = arith.constant 5.000000e-01 : f32
1259+
// CHECK-DAG: [[CST_ONE_SIXTH_:%.+]] = arith.constant 0.166666672 : f32
1260+
// CHECK-DAG: [[CST_ONE_:%.+]] = arith.constant 1.000000e+00 : f32
1261+
// CHECK-DAG: [[CST_ZERO_:%.+]] = arith.constant 0.000000e+00 : f32
1262+
// CHECK-DAG: [[CST_IDX0_:%.+]] = arith.constant 0 : index
1263+
// CHECK: [[VAR_DIM_:%.+]] = memref.dim [[PARAM_0_]], [[CST_IDX0_]] : memref<?x10xf32>
1264+
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_DIM_]]) {{.*}}: memref<?x10xf32>
1265+
// CHECK-DAG: [[LOOPS_:%.+]]:2 = krnl.define_loops 2
1266+
// CHECK-DAG: [[DIM_:%.+]] = memref.dim [[PARAM_0_]], [[CST_IDX0_]] : memref<?x10xf32>
1267+
// CHECK: krnl.iterate([[LOOPS_]]#0, [[LOOPS_]]#1) with ([[LOOPS_]]#0 -> [[I0_:%.+]] = 0 to [[MAP_0_]]([[DIM_]]), [[LOOPS_]]#1 -> [[I1_:%.+]] = 0 to 10){
1268+
// CHECK: [[IVS_:%.+]]:2 = krnl.get_induction_var_value([[LOOPS_]]#0, [[LOOPS_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
1269+
// CHECK: [[LOAD_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[IVS_]]#0, [[IVS_]]#1] : memref<?x10xf32>
1270+
// CHECK: [[SCALE_:%.+]] = arith.mulf [[LOAD_]], [[CST_ONE_SIXTH_]] : f32
1271+
// CHECK: [[SHIFTED_:%.+]] = arith.addf [[SCALE_]], [[CST_HALF_]] : f32
1272+
// CHECK: [[CLAMPED1_:%.+]] = arith.minnumf [[SHIFTED_]], [[CST_ONE_]] : f32
1273+
// CHECK: [[CLAMPED2_:%.+]] = arith.maxnumf [[CLAMPED1_]], [[CST_ZERO_]] : f32
1274+
// CHECK: [[MUL_FINAL_:%.+]] = arith.mulf [[LOAD_]], [[CLAMPED2_]] : f32
1275+
// CHECK: krnl.store [[MUL_FINAL_]], [[RES_]]{{.}}[[IVS_]]#0, [[IVS_]]#1] : memref<?x10xf32>
1276+
// CHECK: }
1277+
// CHECK: return [[RES_]] : memref<?x10xf32>
1278+
// CHECK: }
1279+
1280+
}
1281+
1282+
// -----
1283+
12501284
func.func private @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
12511285
%0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
12521286
"func.return"(%0) : (tensor<*xf32>) -> ()

test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,6 +1860,55 @@ func.func private @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
18601860
// -----
18611861

18621862

1863+
func.func private @test_hardswish(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
1864+
%0 = "onnx.HardSwish"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
1865+
"func.return"(%0) : (tensor<*xf32>) -> ()
1866+
1867+
// mlir2FileCheck.py
1868+
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 128)>
1869+
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
1870+
// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)>
1871+
// CHECK-LABEL: func.func private @test_hardswish
1872+
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
1873+
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<5.000000e-01> : vector<32xf32>
1874+
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.166666672> : vector<32xf32>
1875+
// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32>
1876+
// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
1877+
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
1878+
// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
1879+
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
1880+
// CHECK: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<?xi8>
1881+
// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}{{.}}[[VAR_dim_]]{{.}} : memref<?xi8> to memref<?x10xf32>
1882+
// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
1883+
// CHECK-NOT: separator of consecutive DAGs
1884+
// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_2_]]{{.}}
1885+
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
1886+
// CHECK: affine.store [[VAR_1_]], [[RES_1_]][0] : memref<1xindex>
1887+
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<?x10xf32>, memref<1xindex>) -> memref<?xf32>
1888+
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}}
1889+
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
1890+
// CHECK: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex>
1891+
// CHECK: [[VAR_reshape_5_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref<?x10xf32>, memref<1xindex>) -> memref<?xf32>
1892+
// CHECK: krnl.iterate() with (){
1893+
// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1
1894+
// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
1895+
// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_3, [[VAR_2_]]{{.}}){
1896+
// CHECK: [[IV_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
1897+
// CHECK: [[VLOAD_:%.+]] = vector.load [[VAR_reshape_]]{{\[}}[[IV_]]] : memref<?xf32>, vector<32xf32>
1898+
// CHECK: [[MUL_1_:%.+]] = arith.mulf [[VLOAD_]], [[VAR_cst_0_]] : vector<32xf32>
1899+
// CHECK: [[ADD_:%.+]] = arith.addf [[MUL_1_]], [[VAR_cst_]] : vector<32xf32>
1900+
// CHECK: [[MIN_:%.+]] = arith.minnumf [[ADD_]], [[VAR_cst_1_]] : vector<32xf32>
1901+
// CHECK: [[MAX_:%.+]] = arith.maxnumf [[MIN_]], [[VAR_cst_2_]] : vector<32xf32>
1902+
// CHECK: [[MUL_2_:%.+]] = arith.mulf [[VLOAD_]], [[MAX_]] : vector<32xf32>
1903+
// CHECK: vector.store [[MUL_2_]], [[VAR_reshape_5_]]{{\[}}[[IV_]]] : memref<?xf32>, vector<32xf32>
1904+
// CHECK: }
1905+
// CHECK: }
1906+
// CHECK: return [[VAR_view_]] : memref<?x10xf32>
1907+
// CHECK: }
1908+
}
1909+
1910+
// -----
1911+
18631912
func.func private @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
18641913
%0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
18651914
"func.return"(%0) : (tensor<*xf32>) -> ()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: onnx-mlir-opt --decompose-onnx --decompose-op-in-onnx HardSwish %s | FileCheck %s
2+
func.func @test_hardswish(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
3+
%0 = "onnx.HardSwish"(%arg0) {onnx_node_name = "/hardswish/HardSwish"} :
4+
(tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
5+
onnx.Return %0 : tensor<?x?x?xf32>
6+
7+
// CHECK-LABEL: func @test_hardswish
8+
// CHECK-NOT: "onnx.HardSwish"
9+
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
10+
// CHECK-NEXT: %[[C1:.*]] = onnx.Constant dense<0.166666672> : tensor<1xf32>
11+
// CHECK-NEXT: %[[C2:.*]] = onnx.Constant dense<5.000000e-01> : tensor<1xf32>
12+
// CHECK-NEXT: %[[C3:.*]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32>
13+
// CHECK-NEXT: %[[C4:.*]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32>
14+
// CHECK-NEXT: %[[MUL1:.*]] = "onnx.Mul"(%[[ARG0]], %[[C1]]) : (tensor<?x?x?xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
15+
// CHECK-NEXT: %[[ADD:.*]] = "onnx.Add"(%[[MUL1]], %[[C2]]) : (tensor<?x?x?xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
16+
// CHECK-NEXT: %[[MIN:.*]] = "onnx.Min"(%[[ADD]], %[[C3]]) : (tensor<?x?x?xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
17+
// CHECK-NEXT: %[[MAX:.*]] = "onnx.Max"(%[[MIN]], %[[C4]]) : (tensor<?x?x?xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
18+
// CHECK-NEXT: %[[MUL2:.*]] = "onnx.Mul"(%[[ARG0]], %[[MAX]]) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
19+
// CHECK-NEXT: onnx.Return %[[MUL2]] : tensor<?x?x?xf32>
20+
}

0 commit comments

Comments
 (0)