Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b6972bb

Browse files
ElaineBao and pengzhao-intel
authored and committed
add int8 bn mkldnn implementation and test (#15664)
* add int8 bn mkldnn implementation and test * fix lint * fix ci * enable int8 bn test only in mkldnn backend * disable int8 bn forward test with gpu backend * update int8 bn with reference to comments * fix lint * disable int8 bn gluon forward test with gpu backend * disable uint8 bn forward test with mkldnn backend * restore support mkldnn bn condition * rm duplicate code
1 parent 79d8d86 commit b6972bb

File tree

7 files changed

+446
-28
lines changed

7 files changed

+446
-28
lines changed

cpp-package/scripts/OpWrapperGenerator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class Arg:
9393
'int (non-negative)': 'uint32_t',\
9494
'long (non-negative)': 'uint64_t',\
9595
'int or None':'dmlc::optional<int>',\
96+
'float or None':'dmlc::optional<float>',\
9697
'long':'int64_t',\
9798
'double':'double',\
9899
'double or None':'dmlc::optional<double>',\

src/operator/nn/batch_norm-inl.h

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,19 @@ enum BatchNormOpOutputs {kOut, kMean, kVar}; // req, out_data
5353
enum BatchNormOpResource {kTempSpace};
5454
enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; // aux_states
5555

56-
/*! \brief Default channel axis if none specified int he params */
56+
/*! \brief Default channel axis if none specified in the params */
5757
constexpr int DEFAULT_AXIS = 1;
5858
} // namespace batchnorm
5959

6060
/*! \brief Parameters for BatchNorm operator */
61+
namespace quantized_batchnorm {
62+
enum QuantizedBatchNormOpInputs {kData, kGamma, kBeta, kInMovingMean,
63+
kInMovingVar, kDataMin, kDataMax};
64+
enum QuantizedBatchNormOutputs {kOut, kOutMin, kOutMax};
65+
enum QuantizedBatchNormOpAuxiliary {kMovingMean, kMovingVar};
66+
} // quantized_batchnorm
67+
68+
/*! \brief Parameters for BatchNoram operator */
6169
struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
6270
double eps;
6371
float momentum;
@@ -66,6 +74,10 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
6674
bool output_mean_var;
6775
int axis;
6876
bool cudnn_off;
77+
78+
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
79+
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
80+
6981
DMLC_DECLARE_PARAMETER(BatchNormParam) {
7082
DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
7183
.describe("Epsilon to prevent div 0. "
@@ -81,19 +93,37 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
8193
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
8294
.describe("Output the mean and inverse std ");
8395
DMLC_DECLARE_FIELD(axis).set_default(mxnet::op::batchnorm::DEFAULT_AXIS)
84-
.describe("Specify which shape axis the channel is specified");
96+
.describe("Specify which shape axis the channel is specified");
8597
DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
86-
.describe("Do not select CUDNN operator, if available");
98+
.describe("Do not select CUDNN operator, if available");
99+
DMLC_DECLARE_FIELD(min_calib_range)
100+
.set_default(dmlc::optional<float>())
101+
.describe("The minimum scalar value in the form of float32 obtained "
102+
"through calibration. If present, it will be used to by "
103+
"quantized batch norm op to calculate primitive scale."
104+
"Note: this calib_range is to calib bn output.");
105+
DMLC_DECLARE_FIELD(max_calib_range)
106+
.set_default(dmlc::optional<float>())
107+
.describe("The maximum scalar value in the form of float32 obtained "
108+
"through calibration. If present, it will be used to by "
109+
"quantized batch norm op to calculate primitive scale."
110+
"Note: this calib_range is to calib bn output.");
87111
}
88112

89-
bool operator==(const BatchNormParam& other) const {
90-
return this->eps == other.eps &&
91-
this->momentum == other.momentum &&
92-
this->fix_gamma == other.fix_gamma &&
93-
this->use_global_stats == other.use_global_stats &&
94-
this->output_mean_var == other.output_mean_var &&
95-
this->axis == other.axis &&
96-
this->cudnn_off == other.cudnn_off;
113+
bool operator==(const BatchNormParam &other) const {
114+
bool flag = this->eps == other.eps && this->momentum == other.momentum &&
115+
this->fix_gamma == other.fix_gamma &&
116+
this->use_global_stats == other.use_global_stats &&
117+
this->output_mean_var == other.output_mean_var && this->axis == other.axis &&
118+
this->cudnn_off == other.cudnn_off &&
119+
this->min_calib_range.has_value() == other.min_calib_range.has_value() &&
120+
this->max_calib_range.has_value() == other.max_calib_range.has_value();
121+
if (this->min_calib_range.has_value() && other.min_calib_range.has_value() &&
122+
this->max_calib_range.has_value() && other.max_calib_range.has_value()) {
123+
flag = flag && this->min_calib_range.value() == other.min_calib_range.value() &&
124+
this->max_calib_range.value() == other.max_calib_range.value();
125+
}
126+
return flag;
97127
}
98128
};
99129

src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ class MKLDNNBNForward {
132132
return *var_m;
133133
}
134134

135-
void SetDataHandle(const NDArray &data, const NDArray &mean,
136-
const NDArray &var, const mkldnn::memory &out) {
135+
void SetDataHandle(const NDArray &data, const mkldnn::memory *mean,
136+
const mkldnn::memory *var, const mkldnn::memory *out) {
137137
auto _data = data.GetMKLDNNData();
138138
if (data_m) {
139139
data_m->set_data_handle(_data->get_data_handle());
@@ -142,24 +142,22 @@ class MKLDNNBNForward {
142142
_data->get_data_handle()));
143143
}
144144
if (out_m) {
145-
out_m->set_data_handle(out.get_data_handle());
145+
out_m->set_data_handle(out->get_data_handle());
146146
} else {
147-
out_m.reset(new mkldnn::memory(out.get_primitive_desc(),
148-
out.get_data_handle()));
147+
out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
148+
out->get_data_handle()));
149149
}
150-
auto mean_ptr = mean.data().dptr_;
151150
if (mean_m) {
152-
mean_m->set_data_handle(mean_ptr);
151+
mean_m->set_data_handle(mean->get_data_handle());
153152
} else {
154-
mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(),
155-
mean_ptr));
153+
mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
154+
mean->get_data_handle()));
156155
}
157-
auto var_ptr = var.data().dptr_;
158156
if (var_m) {
159-
var_m->set_data_handle(var_ptr);
157+
var_m->set_data_handle(var->get_data_handle());
160158
} else {
161-
var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
162-
var_ptr));
159+
var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
160+
var->get_data_handle()));
163161
}
164162

165163
if (fwd == nullptr) {
@@ -175,6 +173,11 @@ class MKLDNNBNForward {
175173
}
176174
}
177175

176+
void SetDataHandle(const NDArray &data, const NDArray &mean,
177+
const NDArray &var, const mkldnn::memory &out) {
178+
SetDataHandle(data, mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
179+
}
180+
178181
const mkldnn::batch_normalization_forward &GetFwd() const {
179182
return *fwd;
180183
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file mkldnn_quantized_batch_norm.cc
22+
* \brief
23+
* \author Yixin Bao
24+
*/
25+
26+
#if MXNET_USE_MKLDNN == 1
27+
#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
28+
#include "../quantization_utils.h"
29+
30+
namespace mxnet {
31+
namespace op {
32+
33+
// Forward pass for the MKL-DNN int8 quantized batch norm.
// Inputs (see quantized_batchnorm enum): data, gamma, beta, moving_mean,
// moving_var, data_min, data_max. Outputs: out, out_min, out_max.
// Strategy: instead of quantizing mean/var, the per-channel scale/shift
// (gamma/beta) are rescaled so the primitive runs with mean=0, var=1,
// folding the requantization (input range -> calibrated output range)
// into the scale_shift weights.
static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                            const std::vector<NDArray> &in_data,
                                            const std::vector<OpReqType> &req,
                                            const std::vector<NDArray> &outputs) {
  CHECK_EQ(in_data.size(), 7U);
  CHECK_EQ(outputs.size(), 3U);

  TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
  const NDArray &data = in_data[quantized_batchnorm::kData];
  // Normalize a negative axis (counted from the end) into an absolute axis index.
  const size_t channelAxis = static_cast<size_t>(
      param.axis < 0 ? static_cast<int>(data.shape().ndim()) + param.axis : param.axis);
  const int channel_count = data.shape()[channelAxis];
  // data_min/data_max are scalar calibration tensors holding the float range
  // the int8 input was quantized from.
  const float min_data = in_data[quantized_batchnorm::kDataMin].data().dptr<float>()[0];
  const float max_data = in_data[quantized_batchnorm::kDataMax].data().dptr<float>()[0];
  const float max_abs_data = std::max(std::abs(min_data), std::abs(max_data));

  float *min_output_ptr = outputs[quantized_batchnorm::kOutMin].data().dptr<float>();
  float *max_output_ptr = outputs[quantized_batchnorm::kOutMax].data().dptr<float>();
  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
    // The output range is fixed by calibration, not computed at runtime.
    *max_output_ptr = param.max_calib_range.value();
    *min_output_ptr = param.min_calib_range.value();
  } else {
    // Without a calibrated output range the requantization scale below cannot
    // be computed, so uncalibrated (calib_mode=None) graphs are rejected.
    LOG(FATAL) << "min_calib_range or max_calib_range is not available. Quantized BN currently "
               "don't support calib_mode=None";
  }
  const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr));

  // use_global_stats: run in inference mode with the (rescaled) supplied
  // mean/var; use_scale_shift: apply the per-channel weight_buf below.
  unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift;
  auto &fwd = GetBNForward<float>(param, ctx, data, flags);
  const mkldnn::memory &weight_mem = fwd.GetWeight();
  // scale_shift layout is [gamma[0..C-1], beta[0..C-1]] — 2*C floats.
  CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
  float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

  float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr<float>();
  float *beta_ptr = in_data[quantized_batchnorm::kBeta].data().dptr<float>();

  const NDArray &moving_mean = in_data[quantized_batchnorm::kInMovingMean];
  const NDArray &moving_var = in_data[quantized_batchnorm::kInMovingVar];
  float *moving_mean_ptr = moving_mean.data().dptr<float>();
  float *moving_var_ptr = moving_var.data().dptr<float>();

  // rescale gamma and beta, to make mean=0 and var=1
  auto rescaled_mean_mem =
      TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_primitive_desc());
  auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_primitive_desc());
  float *rescaled_mean_ptr = reinterpret_cast<float *>(rescaled_mean_mem->get_data_handle());
  float *rescaled_var_ptr = reinterpret_cast<float *>(rescaled_var_mem->get_data_handle());

#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
  for (int channel = 0; channel < channel_count; ++channel) {
    float invstd = 1.0 / std::sqrt(moving_var_ptr[channel] + param.eps);
    // Fold normalization and input->output requantization into gamma…
    weight_buf[channel] = gamma_ptr[channel] * invstd * max_abs_data / max_abs_output;
    // …and mean-subtraction plus int8 output scaling into beta.
    // NOTE(review): gamma uses max_abs_data while beta uses kInt8Range —
    // presumably because data is already in int8 units while beta is in
    // float units; confirm against quantization_utils.h.
    weight_buf[channel_count + channel] =
        (beta_ptr[channel] - moving_mean_ptr[channel] * gamma_ptr[channel] * invstd) * kInt8Range /
        max_abs_output;
    // The primitive then normalizes against identity statistics.
    rescaled_mean_ptr[channel] = 0.0f;
    rescaled_var_ptr[channel] = 1.0f;
  }

  auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut],
                                 fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data);
  fwd.SetDataHandle(data, rescaled_mean_mem, rescaled_var_mem, out_mem.second);

  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
  MKLDNNStream::Get()->Submit();
}
100+
101+
/*!
 * \brief Storage-type inference for _contrib_quantized_batch_norm.
 *
 * Unconditionally dispatches to the MKL-DNN storage-type helper (the
 * `true` argument requests MKL-DNN dispatch), which fills in the
 * input/output storage types and dispatch mode.
 *
 * \param attrs         node attributes
 * \param dev_mask      device mask of the context being inferred for
 * \param dispatch_mode output: selected dispatch mode
 * \param in_attrs      output: inferred input storage types
 * \param out_attrs     output: inferred output storage types
 * \return true if the storage types were successfully dispatched
 */
inline static bool QuantizedBatchNormStorageType(const nnvm::NodeAttrs &attrs, const int dev_mask,
                                                 DispatchMode *dispatch_mode,
                                                 std::vector<int> *in_attrs,
                                                 std::vector<int> *out_attrs) {
  // The original wrapped this single call in `bool dispatched = false;
  // if (!dispatched) {...}` — dead logic on a freshly-false local.
  // Delegating directly is equivalent and clearer.
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
111+
112+
// Register the CPU (MKL-DNN) implementation for the quantized batch-norm op:
// storage-type inference, the FComputeEx forward kernel above, a temp-space
// resource request (used by TmpMemMgr for the rescaled mean/var buffers),
// and the TIsMKLDNN marker so the executor treats it as an MKL-DNN op.
NNVM_REGISTER_OP(_contrib_quantized_batch_norm)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedBatchNormStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedBatchNormForward)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<bool>("TIsMKLDNN", true);
119+
120+
} // namespace op
121+
} // namespace mxnet
122+
123+
#endif // MXNET_USE_MKLDNN == 1

src/operator/quantization/quantize_graph_pass.cc

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ using nnvm::NodePtr;
3737
using nnvm::NodeEntry;
3838
using nnvm::Graph;
3939

40+
// Number of outputs the rest of the graph can actually see for this node:
// ops that register FNumVisibleOutputs (e.g. ops with hidden aux outputs)
// report that count; all other ops report their plain num_outputs().
inline size_t GetNumOutputs(NodePtr node) {
  static const auto& visible_outputs_attr =
      nnvm::Op::GetAttr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs");
  if (auto visible_fn = visible_outputs_attr.get(node->op(), nullptr)) {
    return visible_fn(node->attrs);
  }
  return node->num_outputs();
}
52+
4053
NodePtr CreateNode(std::string op_name, std::string node_name) {
4154
NodePtr node = Node::Create();
4255
node->attrs.name = node_name;
@@ -223,7 +236,7 @@ Graph QuantizeGraph(Graph &&src) {
223236
// calculate min/max index from mirror node) based on assumption that
224237
// there is only 1min and 1max output from mirror node (which is
225238
// currently true)
226-
size_t num_outputs = mirror_node->num_outputs() - 2;
239+
size_t num_outputs = GetNumOutputs(mirror_node) - 2;
227240
min_index = num_outputs + 2 * e.index;
228241
max_index = num_outputs + 2 * e.index + 1;
229242
} else {
@@ -276,7 +289,7 @@ Graph QuantizeGraph(Graph &&src) {
276289
// calculate min/max index from mirror node) based on assumption that
277290
// there is only 1 min and 1 max output from mirror node (which is
278291
// currently true)
279-
size_t num_outputs = mirror_node->num_outputs() - 2;
292+
size_t num_outputs = GetNumOutputs(mirror_node) - 2;
280293
uint32_t min_index = num_outputs + 2 * e.index;
281294
uint32_t max_index = num_outputs + 2 * e.index + 1;
282295
NodePtr dequantize_node = CreateNode("_contrib_dequantize",
@@ -309,7 +322,7 @@ Graph QuantizeGraph(Graph &&src) {
309322
// calculate min/max index from mirror node) based on assumption that
310323
// there is only 1 min and 1 max output from mirror node (which is
311324
// currently true)
312-
size_t num_outputs = e.node->num_outputs();
325+
size_t num_outputs = GetNumOutputs(e.node);
313326
uint32_t min_index = num_outputs + 2 * e.index;
314327
uint32_t max_index = num_outputs + 2 * e.index + 1;
315328

@@ -403,6 +416,29 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
403416
<< "` has negative input, consider use `auto` or `int8` as out_type";
404417
}
405418
}
419+
} else if (node->op() == Op::Get("_contrib_quantized_batch_norm")) {
420+
auto quantized_op_idx = node->inputs[0].index;
421+
const std::string prefix = "quantized_";
422+
std::string out_data_name = node->attrs.name.substr(prefix.size());
423+
if (node->op()) {
424+
auto list_output_names_func = flist_outputs.get(node->op(), nullptr);
425+
// We want to get the pre-calculated min_range and max_range from the calibration table for
426+
// out_data. Here we create the output data name same as its constructed in
427+
// GraphExecutor::ExecuteMonCallback.
428+
if (list_output_names_func != nullptr) {
429+
std::vector<std::string> names = list_output_names_func(node->attrs);
430+
out_data_name += "_" + names[quantized_op_idx];
431+
} else {
432+
out_data_name += "_" + std::to_string(quantized_op_idx);
433+
}
434+
}
435+
436+
const auto calib_table_iter = calib_table.find(out_data_name);
437+
if (calib_table_iter != calib_table.end()) {
438+
node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first);
439+
node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second);
440+
node->op()->attr_parser(&(node->attrs));
441+
}
406442
}
407443
});
408444
return g;

0 commit comments

Comments
 (0)