Skip to content

Commit 168d50d

Browse files
updated tests
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 3a62034 commit 168d50d

6 files changed

+512
-475
lines changed

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

Lines changed: 6 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -71,48 +71,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
7171
DimsExpr outputAF;
7272
outputAF.emplace_back(zero);
7373

74-
#if 1
75-
// Allocate output buffers.
76-
MemRefType inputBufferType =
77-
MemRefType::get({totVL}, inputType.getElementType());
78-
Value inputBuffer = create.mem.alignedAlloc(inputBufferType);
79-
MemRefType outputBufferType = MemRefType::get({totVL}, quantizedElementType);
80-
VectorType outputVectorType = VectorType::get({totVL}, quantizedElementType);
81-
Value outputBuffer = create.mem.alignedAlloc(outputBufferType);
82-
83-
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
84-
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
85-
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
86-
MultiDialectBuilder<KrnlBuilder, MathBuilder, VectorBuilder> create(kb);
87-
Value x = inputVals[0];
88-
// Scale
89-
Value scaleX = create.math.div(x, scale);
90-
// Round
91-
Value roundX = create.math.round(scaleX);
92-
// Adjust
93-
Value adjustX;
94-
if (hasZeroPoint)
95-
adjustX = create.math.add(roundX, zeroPoint);
96-
else
97-
adjustX = roundX;
98-
// Saturate: use max into a min.
99-
Value saturateX = create.math.clip(adjustX, qMin, qMax);
100-
if (VL == 1)
101-
return create.math.cast(quantizedElementType, saturateX);
102-
// Has VL values; first save all VL into buffer.
103-
create.vec.storeIE(saturateX, inputBuffer, {zero});
104-
// Now process each value in turn
105-
for (int64_t v = 0; v < VL; ++v) {
106-
IndexExpr vv = LitIE(v);
107-
Value scalarSaturateX = create.krnl.loadIE(inputBuffer, {vv});
108-
Value scalarRes =
109-
create.math.cast(quantizedElementType, scalarSaturateX);
110-
create.krnl.storeIE(scalarRes, outputBuffer, {vv});
111-
}
112-
// Reload the output buffer as one vector.
113-
return create.vec.loadIE(outputVectorType, outputBuffer, {zero});
114-
}});
115-
#else
11674
// faster than original loop on z16, takes 124us for 64k vals
11775
// Allocate output buffers.
11876
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
@@ -146,20 +104,19 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
146104
// compiler's attempt to generate SIMD conversion code. This might not hold
147105
// with all data types, but is definitely noticeable with uint8.
148106
//
149-
// Todo: we might save the vector to a buffer on the fly (avoiding a second
150-
// loop as below), and then reload each value as scalar and then saved them as
151-
// scalar (thus avoiding the insert/extract SIMD operations that also do not
152-
// perform well). The problem is that the current SIMD scheme expect a return
153-
// value, either SIMD in SIMD mode or scalar in scalar mode. Thus that
154-
// alternative scheme is not easy to pull off here.
107+
// Investigate further: we might save the vector to a buffer on the fly
108+
// (avoiding a second loop as below), and then reload each value as scalar and
109+
// then saved them as scalar (thus avoiding the insert/extract SIMD operations
110+
// that also do not perform well). We can have a SIMD buffer in memory for the
111+
// non-quantized and quantized simd values, but then we also need to privatize
112+
// it, which is also not easy in this scheme. So ignore this for now.
155113
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
156114
[&](KrnlBuilder &kb, ValueRange loopInd) {
157115
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
158116
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
159117
Value res = create.math.cast(quantizedElementType, buffVal);
160118
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
161119
});
162-
#endif
163120

164121
if (totVL > 1)
165122
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,

test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
3131
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
3232
// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref<?x2xf32>
3333
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){
34-
// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
35-
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref<?x2xf32>
34+
// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
35+
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref<?x2xf32>
3636
// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref<f32>
37-
// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
38-
// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref<f32>
37+
// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
38+
// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref<f32>
3939
// CHECK: }
4040
// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref<f32>
4141
// CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref<f32>
4242
// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2
4343
// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref<?x2xf32>
4444
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){
45-
// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
46-
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref<?x2xf32>
45+
// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
46+
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref<?x2xf32>
4747
// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref<f32>
48-
// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32
49-
// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref<f32>
48+
// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32
49+
// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref<f32>
5050
// CHECK: }
5151
// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref<f32>
5252
// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref<f32>
@@ -87,33 +87,40 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
8787
// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
8888
// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex>
8989
// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref<?x2xui8>, memref<1xindex>) -> memref<?xui8>
90+
// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref<?xf32>
9091
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
9192
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
92-
// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
93-
// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref<?xf32>
93+
// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
94+
// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
9495
// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32
95-
// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32
96-
// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32
97-
// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32
98-
// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32
96+
// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32
97+
// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32
98+
// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32
99+
// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32
99100
// CHECK-NOT: separator of consecutive DAGs
100-
// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32
101-
// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32
102-
// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32
103-
// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32
104-
// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32
105-
// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32
106-
// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32
101+
// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32
102+
// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32
103+
// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32
104+
// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32
105+
// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32
106+
// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32
107+
// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32
107108
// CHECK-NOT: separator of consecutive DAGs
108-
// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32
109-
// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32
110-
// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32
111-
// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32
112-
// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32
113-
// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32
114-
// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i8
115-
// CHECK: [[VAR_52_:%.+]] = builtin.unrealized_conversion_cast [[VAR_51_]] : i8 to ui8
116-
// CHECK: krnl.store [[VAR_52_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref<?xui8>
109+
// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32
110+
// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32
111+
// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32
112+
// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32
113+
// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32
114+
// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32
115+
// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
116+
// CHECK: }
117+
// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1
118+
// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
119+
// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index
120+
// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xf32>
121+
// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8
122+
// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8
123+
// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xui8>
117124
// CHECK: }
118125
// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref<?x2xui8>, memref<f32>, memref<ui8>
119126
// CHECK: }

0 commit comments

Comments
 (0)