
Commit 35f66a0

Support onnx.Softplus op on NNPA. (#2792)
* Support onnx.Softplus op on NNPA.

Signed-off-by: Yasushi Negishi <[email protected]>
1 parent b5ffe74 commit 35f66a0

File tree: 11 files changed (+296, -9 lines)

docs/SupportedONNXOps-NNPA.md

+1 line

@@ -39,6 +39,7 @@ NNPA has hardware limitations in dimension index size and tensor size, which are
 | **Relu** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |
 | **Sigmoid** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |
 | **Softmax** |6 - * |- `axis` must be the last dimension, i.e. `rank - 1` or -1. | |
+| **Softplus** |6 - * |The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. | |
 | **Sub** |6 - * |- Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
 | **Sum** |6 - * |- All inputs must have the same static shape (Broadcasting not supported.)<br>- Single input not supported. | |
 | **Tanh** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

+12 lines

@@ -637,6 +637,18 @@ bool isSuitableForZDNN<ONNXReduceMeanV13Op>(
   return true;
 }

+/// Check legality for ONNXSoftplus.
+template <>
+bool isSuitableForZDNN<ONNXSoftplusOp>(
+    ONNXSoftplusOp op, const DimAnalysis *dimAnalysis) {
+  // Check NNPA level.
+  if (!isCompatibleWithNNPALevel(NNPA_Z16))
+    return false;
+  if (!isValidElementTypeAndRank(op.getX()))
+    return false;
+  return true;
+}
+
 /// Check legality for ONNXLSTM.
 /// TODO: current ONNX-to-zhigh conversion does not support bi-direction
 template <>

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp

+1 line

@@ -137,6 +137,7 @@ void ZHighStickOp::getCanonicalizationPatterns(
   results.insert<StickUnstickSameLayoutRemovalPattern>(context);
   results.insert<StickUnstickDiffLayoutRemovalPattern>(context);
   results.insert<ReplaceONNXLeakyReluPattern>(context);
+  results.insert<ReplaceONNXSoftplusPattern>(context);
   results.insert<ReplaceONNXReciprocalSqrtPattern>(context);
   results.insert<ReshapeTransposeReshape2DTo3DSPattern>(context);
   results.insert<ReshapeTransposeReshape3DSTo2DPattern>(context);

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td

+36 lines

@@ -83,6 +83,42 @@ def ReplaceONNXLeakyReluPattern: Pat<
    (SameLayout $X, $stickout)]
 >;

+// The pattern
+//   zhigh.Stick (onnx.Softplus (zhigh.Unstick (%X)))
+// can be replaced by
+//   %minusOne = zhigh.Stick(GetConstantOfType<"-1.0">, %X)
+//   %minusX = zhigh.Mul(%X, %minusOne)
+//   zhigh.Add(
+//     zhigh.Relu(%X),
+//     zhigh.Log(zhigh.Sub(zhigh.Exp(zhigh.Min(%X, %minusX)), %minusOne)))
+// References:
+//   http://www.beam2d.net/blog/2014/03/02/softplus/ (Japanese)
+//   https://www-beam2d-net.translate.goog/blog/2014/03/02/softplus/?_x_tr_sch=http&_x_tr_sl=ja&_x_tr_tl=en&_x_tr_hl=ja&_x_tr_pto=wapp (English translation)
+// Note:
+//   -|x| is replaced by min(x, -x), since NNPA does not have an abs(x) function.
+// Constraints:
+//   - %X should have a static shape.
+//
+def ReplaceONNXSoftplusPattern: Pattern<
+  (ZHighStickOp:$stickout (ONNXSoftplusOp:$out (ZHighUnstickOp $X)), $layout),
+  [
+    // Get stickified constant of minus one with input shape.
+    (ZHighStickOp:$minusOne (GetConstantOfType<"-1.0"> $out), $layout),
+    // Get minus X with input shape.
+    (ZHighMulOp:$minusX $X, $minusOne, (returnType $X)),
+
+    // Get Softplus.
+    (ZHighAddOp
+       (ZHighReluOp $X, (returnType $X)),
+       (ZHighLogOp (ZHighSubOp (ZHighExpOp (ZHighMinOp $X, $minusX,
+                                   (returnType $X)),
+                       (returnType $X)),
+                       $minusOne, (returnType $X)),
+           (returnType $X))),
+  ],
+  [(IsStaticShapeTensor $X), (SameLayout $X, $stickout)]
+>;
+
 // Calculation of `1/sqrt(X)` or reciprocal square root is often found in
 // deep learning models, but zDNN does not support it. Thus, we rewrite it into
 // zDNN-supported operations.
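
For reference, ReplaceONNXSoftplusPattern relies on the identity softplus(x) = log(1 + e^x) = relu(x) + log(1 + e^(-|x|)), with -|x| written as min(x, -x) because zDNN has no abs operation. The following is a minimal, standalone C++ sketch (not part of the commit; plain standard library only, helper names illustrative) that numerically checks the rewritten form against the direct definition:

// Sanity check: relu(x) + log(exp(min(x, -x)) + 1) == log(1 + exp(x)).
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

static float softplusDirect(float x) { return std::log1p(std::exp(x)); }

static float softplusRewritten(float x) {
  // Multiply by -1, as the pattern does with its stickified constant.
  float minusX = x * -1.0f;
  // min(x, -x) stands in for -|x| since NNPA has no abs operation.
  float relu = std::max(x, 0.0f);
  return relu + std::log(std::exp(std::min(x, minusX)) + 1.0f);
}

int main() {
  for (float x = -20.0f; x <= 20.0f; x += 0.25f)
    assert(std::fabs(softplusDirect(x) - softplusRewritten(x)) < 1e-5f);
  std::printf("rewrite matches direct softplus\n");
  return 0;
}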

test/accelerators/NNPA/backend/CMakeLists.txt

+7 lines

@@ -332,6 +332,13 @@ set(NNPA_TEST_LIST
 test_softmax_example_cpu,zdnn_softmax_ext,NO_DYNAMIC_SHAPE_TEST
 # test_softmax_large_number_cpu # accuracy error

+# ==OP== Softplus
+# ==MIN== 1
+# ==LIM== The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation.
+# The Softplus ops in the following test cases don't run on NNPA because each test contains only a single Softplus op. Softplus is therefore covered by the TestSoftplus numerical test rather than by the backend tests.
+# test_softplus_cpu,zdnn_log
+# test_softplus_example_cpu,zdnn_log
+
 # ==OP== Sub
 # ==MIN== 6
 # ==LIM== - Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions.

test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir

+43 -9 lines

@@ -137,16 +137,50 @@ func.func @replace_leakyrelu_2(%arg0 : tensor<1x104x128x104xf16, #zhigh.layout<{

 // -----

-// Do not replace onnx.LeakyRelu if alpha < 0
-func.func @donot_replace_leakyrelu(%arg0 : tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>> {
-  %0 = "zhigh.Unstick"(%arg0) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf32>
-  %1 = "onnx.LeakyRelu"(%0) {alpha = -1.000000e-01 : f32} : (tensor<1x104x128x104xf32>) -> tensor<1x104x128x104xf32>
-  %2 = "zhigh.Stick"(%1) {layout = "NHWC"} : (tensor<1x104x128x104xf32>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// Replace onnx.Softplus
+func.func @replace_softplus_1(%arg0 : tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>> {
+  %0 = "zhigh.Unstick"(%arg0) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf32>
+  %1 = "onnx.Softplus"(%0) {alpha = 1.000000e-01 : f32} : (tensor<1x104x104x128xf32>) -> tensor<1x104x104x128xf32>
+  %2 = "zhigh.Stick"(%1) {layout = "NHWC"} : (tensor<1x104x104x128xf32>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
   return %2 : tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
-// CHECK-LABEL: donot_replace_leakyrelu
-// CHECK: zhigh.Unstick
-// CHECK: onnx.LeakyRelu
-// CHECK: zhigh.Stick
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @replace_softplus_1
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>> {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<-1.000000e+00> : tensor<1x104x104x128xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "NHWC"} : (tensor<1x104x104x128xf32>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Relu"([[PARAM_0_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Min"([[PARAM_0_]], [[VAR_2_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.Exp"([[VAR_4_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Sub"([[VAR_5_]], [[VAR_1_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_7_:%.+]] = "zhigh.Log"([[VAR_6_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_8_:%.+]] = "zhigh.Add"([[VAR_3_]], [[VAR_7_]]) : (tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: return [[VAR_8_]] : tensor<1x104x104x128xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: }
+}
+
+// -----
+
+// Replace onnx.Softplus
+func.func @replace_softplus_2(%arg0 : tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>> {
+  %0 = "zhigh.Unstick"(%arg0) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x104x128xf32>
+  %1 = "onnx.Softplus"(%0) {alpha = 1.000000e-01 : f32} : (tensor<1x104x104x128xf32>) -> tensor<1x104x104x128xf32>
+  %2 = "zhigh.Stick"(%1) {layout = "NHWC"} : (tensor<1x104x104x128xf32>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+  return %2 : tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @replace_softplus_2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>> {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<-1.000000e+00> : tensor<1x104x104x128xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "NHWC"} : (tensor<1x104x104x128xf32>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Relu"([[PARAM_0_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Min"([[PARAM_0_]], [[VAR_2_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.Exp"([[VAR_4_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Sub"([[VAR_5_]], [[VAR_1_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_7_:%.+]] = "zhigh.Log"([[VAR_6_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_8_:%.+]] = "zhigh.Add"([[VAR_3_]], [[VAR_7_]]) : (tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: return [[VAR_8_]] : tensor<1x104x128x104xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: }
 }

 // -----

test/modellib/CMakeLists.txt

+1 line

@@ -12,6 +12,7 @@ add_onnx_mlir_library(ModelLib
 ModelLib.cpp
 RNNModel.cpp
 ScanModel.cpp
+SoftplusModel.cpp

 EXCLUDE_FROM_OM_LIBS

test/modellib/ModelLib.hpp

+19 lines

@@ -510,5 +510,24 @@ class Elementwise2DLibBuilder : public ModelLibBuilder {
   const int inputNum;
 };

+class SoftplusLibBuilder : public ModelLibBuilder {
+public:
+  SoftplusLibBuilder(
+      const std::string &modelName, const int N);
+  bool build() final;
+  bool prepareInputs() final;
+  bool prepareInputs(float dataRangeLB, float dataRangeUB);
+  bool prepareInputsFromEnv(const std::string envDataRange);
+  bool verifyOutputs() final;
+
+private:
+  // Data that defines model.
+  const int N;
+  // Derived data that defines model.
+  llvm::SmallVector<int64_t, 2> xShape, yShape;
+  // Model definition as a std::string.
+  std::string moduleIR;
+};
+
 } // namespace test
 } // namespace onnx_mlir

test/modellib/SoftplusModel.cpp

+104 lines (new file)

/*
 * SPDX-License-Identifier: Apache-2.0
 */

//==========-- SoftplusModel.cpp - Building Softplus Models for tests -=====//
//
// Copyright 2022,2023 The IBM Research Authors.
//
// =============================================================================
//
// This file contains a function that builds a model consisting of onnx.Add,
// onnx.Softplus and onnx.Sub ops, and compiles it to check if the second
// op, onnx.Softplus, is executed on NNPA.
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinOps.h"

#include "include/OnnxMlirRuntime.h"
#include "src/Compiler/CompilerUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Runtime/OMTensorHelper.hpp"
#include "test/modellib/ModelLib.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace test {

// =============================================================================
// Model consisting of onnx.Add, onnx.Softplus and onnx.Sub ops

SoftplusLibBuilder::SoftplusLibBuilder(
    const std::string &modelName, const int N)
    : ModelLibBuilder(modelName), N(N) {}

bool SoftplusLibBuilder::build() {
  llvm::SmallVector<int64_t, 1> xShape = {N};
  llvm::SmallVector<int64_t, 1> yShape = {N};
  auto xType = RankedTensorType::get(xShape, builder.getF32Type());
  auto yType = RankedTensorType::get(yShape, builder.getF32Type());

  llvm::SmallVector<Type, 1> inputsType{xType};
  llvm::SmallVector<Type, 1> outputsType{yType};

  func::FuncOp funcOp = createEmptyTestFunction(inputsType, outputsType);
  Block &entryBlock = funcOp.getBody().front();
  auto xVal = entryBlock.getArgument(0);

  auto addOp = builder.create<ONNXAddOp>(loc,
      /*Y=*/yType, /*X=*/xVal, /*X=*/xVal);
  auto softPlusOp = builder.create<ONNXSoftplusOp>(loc,
      /*Y=*/yType, /*X=*/addOp);
  auto subOp = builder.create<ONNXSubOp>(loc,
      /*Y=*/yType, /*X=*/softPlusOp, /*X=*/xVal);

  llvm::SmallVector<Value, 1> results = {subOp.getResult()};
  builder.create<func::ReturnOp>(loc, results);
  module.push_back(funcOp);

  createEntryPoint(funcOp);
  return true;
}

bool SoftplusLibBuilder::prepareInputs(float dataRangeLB, float dataRangeUB) {
  constexpr int num = 1;
  OMTensor *list[num];
  list[0] = omTensorCreateWithRandomData<float>({N}, dataRangeLB, dataRangeUB);
  inputs = omTensorListCreate(list, num);
  return inputs && list[0];
}

bool SoftplusLibBuilder::prepareInputs() {
  return SoftplusLibBuilder::prepareInputs(
      -omDefaultRangeBound, omDefaultRangeBound);
}

bool SoftplusLibBuilder::prepareInputsFromEnv(const std::string envDataRange) {
  std::vector<float> range = ModelLibBuilder::getDataRangeFromEnv(envDataRange);
  return range.size() == 2 ? prepareInputs(range[0], range[1])
                           : prepareInputs();
}

bool SoftplusLibBuilder::verifyOutputs() {
  // Get inputs and outputs.
  if (!inputs || !outputs)
    return false;
  OMTensor *x = omTensorListGetOmtByIndex(inputs, 0);
  OMTensor *res = omTensorListGetOmtByIndex(outputs, 0);
  OMTensor *ref = omTensorCreateWithShape<float>({N});
  if (!x || !res || !ref)
    return false;
  for (int64_t i = 0; i < N; ++i) {
    float val1 = omTensorGetElem<float>(x, {i}) * 2;
    float val2 = log(exp(val1) + 1.0);
    float val3 = val2 - omTensorGetElem<float>(x, {i});
    omTensorGetElem<float>(ref, {i}) = val3;
  }
  bool ok = areCloseFloat(res, ref);
  omTensorDestroy(ref);
  return ok;
}

} // namespace test
} // namespace onnx_mlir
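
Restated in plain terms, the model that SoftplusLibBuilder::build constructs computes softplus(x + x) - x elementwise, which is exactly what verifyOutputs recomputes as a reference. Below is a minimal, standalone C++ sketch of that reference computation (not part of the commit; the helper name expectedOutput is illustrative only):

// Reference function matching SoftplusLibBuilder::verifyOutputs.
#include <cmath>
#include <cstdio>

static float expectedOutput(float x) {
  float doubled = x + x;                          // onnx.Add(x, x)
  float sp = std::log(std::exp(doubled) + 1.0f);  // onnx.Softplus
  return sp - x;                                  // onnx.Sub(sp, x)
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
    std::printf("x = %5.2f -> expected %8.5f\n", x, expectedOutput(x));
  return 0;
}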

test/numerical/CMakeLists.txt

+5 lines

@@ -92,3 +92,8 @@ add_numerical_test(TestScan
 TestElementwise.cpp
 LINK_LIBS PRIVATE ${TEST_LINK_LIBS}
 )
+
+add_numerical_test(TestSoftplus
+  TestSoftplus.cpp
+  LINK_LIBS PRIVATE ${TEST_LINK_LIBS}
+  )

test/numerical/TestSoftplus.cpp

+67 lines (new file)

/*
 * SPDX-License-Identifier: Apache-2.0
 */

//====-- TestSoftplus.cpp - test Softplus code -==============================//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the code to test the Softplus numerical model.
//
//===----------------------------------------------------------------------===//

// Common.hpp needs to be included first to correctly suppress the rapidcheck.h
// warnings.
#include "Common.hpp"

#include "src/Runtime/OMTensorHelper.hpp"

static const llvm::StringRef SHARED_LIB_BASE("./TestSoftplus_main_graph");

using namespace mlir;

namespace onnx_mlir {
namespace test {

static bool isOMSoftplusTheSameAsNaiveImplFor(const int N) {
  static int testNum = 0;
  printf("attempt %d with N %d\n", ++testNum, N);

  SoftplusLibBuilder softplus(SHARED_LIB_BASE.str(), N);
  return softplus.build() && softplus.compileAndLoad() &&
         softplus.checkInstructionFromEnv("TEST_INSTRUCTION") &&
         softplus.prepareInputsFromEnv("TEST_DATARANGE") && softplus.run() &&
         softplus.verifyOutputs();
}

} // namespace test
} // namespace onnx_mlir

int main(int argc, char *argv[]) {
  using namespace onnx_mlir;
  using namespace onnx_mlir::test;

  llvm::FileRemover remover(
      onnx_mlir::getTargetFilename(SHARED_LIB_BASE.str(), onnx_mlir::EmitLib));

  ModelLibBuilder::setRandomNumberGeneratorSeed("TEST_SEED");
  removeUnrelatedOptions({&OnnxMlirCommonOptions, &OnnxMlirOptions});
  llvm::cl::ParseCommandLineOptions(
      argc, argv, "TestSoftplus\n", nullptr, "TEST_ARGS");
  initCompilerConfig();
  std::string target = getCompilerOption(OptionKind::TargetAccel);
  std::cout << "Target options: \"" << target << "\"\n";
  if (true) {
    printf("RapidCheck test case generation.\n");
    bool success = rc::check("Softplus implementation correctness", [&]() {
      const int maxRange = 50;
      const int N = *rc::gen::inRange(1, maxRange);
      RC_ASSERT(isOMSoftplusTheSameAsNaiveImplFor(N));
    });
    if (!success)
      return 1;
  }
  return 0;
}
