This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7803155

Add mkldnn_v1.0 int8 fc
1 parent 48bfcf9 commit 7803155

File tree

5 files changed (+31, -24 lines)

src/operator/nn/mkldnn/mkldnn_fully_connected.cc (5 additions, 5 deletions)

@@ -216,15 +216,15 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
                                  fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
 
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_SRC, *data_mem},
     {MKLDNN_ARG_WEIGHTS, *weight_mem},
     {MKLDNN_ARG_DST, *out_mem.second},
   };
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
         fwd->fwd_pd.bias_desc());
-    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);

@@ -298,7 +298,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
                                      ipBwdData_pd.diff_src_desc(),
                                      req[fullc::kData]);
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
     {MKLDNN_ARG_WEIGHTS, *weight_mem},
     {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}

@@ -317,7 +317,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
                                                ipBwdWeights_pd.diff_weights_desc(),
                                                req[fullc::kWeight]);
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
     {MKLDNN_ARG_SRC, *data_mem},
     {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},

@@ -328,7 +328,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
                                    ipBwdWeights_pd.diff_bias_desc(),
                                    req[fullc::kBias]);
-    args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
+    args[MKLDNN_ARG_DIFF_BIAS] = *in_grad_bias.second;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(
       mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
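The change in this file replaces the v0.x habit of constructing the argument map as a raw std::unordered_map with the mkldnn_args_map_t type used by MXNet's MKL-DNN v1.0 integration (presumably an alias for std::unordered_map<int, mkldnn::memory>, given what it replaces), and keeps binding memory at registration time via RegisterPrimArgs. Below is a minimal standalone sketch of the underlying MKL-DNN v1.x execution pattern outside MXNet's MKLDNNStream wrapper; the function name run_fc and the direct execute() call are illustrative and not part of this commit.

// Sketch: the MKL-DNN v1.x way of supplying memory to a primitive -- a map
// from MKLDNN_ARG_* tags to mkldnn::memory objects, resolved at execution
// time rather than pre-bound on the primitive. run_fc() is a hypothetical
// helper; MXNet routes this through MKLDNNStream::RegisterPrimArgs() instead.
#include <unordered_map>
#include "mkldnn.hpp"  // MKL-DNN (oneDNN) v1.x C++ API

void run_fc(const mkldnn::inner_product_forward &fwd,
            const mkldnn::memory &src,
            const mkldnn::memory &weights,
            const mkldnn::memory *bias,  // nullptr when no_bias is set
            const mkldnn::memory &dst,
            mkldnn::stream &strm) {
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src},
      {MKLDNN_ARG_WEIGHTS, weights},
      {MKLDNN_ARG_DST, dst},
  };
  if (bias != nullptr)
    args[MKLDNN_ARG_BIAS] = *bias;  // optional bias joins the same map
  fwd.execute(strm, args);          // arguments are bound per execution
  strm.wait();
}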

src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc (19 additions, 12 deletions)

@@ -24,7 +24,7 @@
  * \author Ciyong Chen
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "../quantization_utils.h"

@@ -89,33 +89,40 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   auto &fwd = GetFCFwd(param, is_train, data, weight,
                        param.no_bias ? nullptr : &quantized_bias, out_md);
 
-  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
   const mkldnn::memory *weight_mem = nullptr;
 
   if (weight.IsDefaultData()) {
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
   } else {
     weight_mem = weight.GetMKLDNNData();
-    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+    CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(),
                                  req[fullc::kOut]);
-  const mkldnn::memory *bias_mem = nullptr;
-  if (!param.no_bias)
-    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
 
-  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  mkldnn_args_map_t args = {
+    {MKLDNN_ARG_SRC, *data_mem},
+    {MKLDNN_ARG_WEIGHTS, *weight_mem},
+    {MKLDNN_ARG_DST, *out_mem.second},
+  };
+
+  const mkldnn::memory *bias_mem = nullptr;
+  if (!param.no_bias) {
+    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
+  }
 
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
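Besides switching the quantized FC forward path from SetNewMem/RegisterPrim to an argument map with RegisterPrimArgs, this file adopts the v1.x descriptor API: src_primitive_desc(), weights_primitive_desc(), dst_primitive_desc(), and get_primitive_desc() become src_desc(), weights_desc(), dst_desc(), and get_desc(), since MKL-DNN v1.0 folds memory primitive descriptors into plain mkldnn::memory::desc. The following is a minimal sketch of creating an int8 inner-product primitive descriptor and querying those descriptors directly with the v1.x API; the shapes, data types, and variable names are illustrative and not taken from this commit.

// Sketch: build an int8 inner-product primitive_desc with MKL-DNN v1.x and
// query the memory descriptors it expects -- the same *_desc() calls the
// diff above makes on fwd.fwd_pd. Shapes and data types are illustrative.
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);

  const memory::dim N = 2, IC = 32, OC = 16;
  memory::desc src_md({N, IC}, memory::data_type::u8, memory::format_tag::nc);
  memory::desc wgt_md({OC, IC}, memory::data_type::s8, memory::format_tag::any);
  memory::desc dst_md({N, OC}, memory::data_type::s32, memory::format_tag::nc);

  inner_product_forward::desc desc(prop_kind::forward_inference,
                                   src_md, wgt_md, dst_md);
  inner_product_forward::primitive_desc pd(desc, eng);

  // v1.x: *_desc() replaces the old *_primitive_desc() queries, and two
  // memories are compatible when their memory::desc objects compare equal,
  // which is what the CHECK on weight_mem->get_desc() relies on.
  memory src(pd.src_desc(), eng);
  memory weights(pd.weights_desc(), eng);  // library-chosen (possibly blocked) layout
  memory dst(pd.dst_desc(), eng);
  return 0;
}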

src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h (1 addition, 1 deletion)

@@ -27,7 +27,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <mxnet/ndarray.h>
 #include <vector>

src/operator/quantization/quantized_fully_connected.cc (5 additions, 5 deletions)

@@ -26,7 +26,7 @@
 #include <vector>
 #include "quantization_utils.h"
 #include "../nn/fully_connected-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "mkldnn/mkldnn_quantized_ops-inl.h"
 #endif

@@ -94,7 +94,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_type->size(), num_inputs * 3);
   CHECK_EQ(out_type->size(), 3U);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
       << "QuantizedFullyConnected only supports int8/uint8 input, while "
       << in_type->at(0) << " is given.";

@@ -124,7 +124,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), num_inputs * 3);
   CHECK_EQ(out_attrs->size(), 3U);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   return MKLDNNStorageType(attrs, dev_mask, true,
                            dispatch_mode, in_attrs, out_attrs);
 #else

@@ -292,7 +292,7 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
                                          const OpContext &ctx,
                                          const std::vector<NDArray> &in_data,

@@ -341,7 +341,7 @@ and max thresholds representing the threholds for quantizing the float32 output
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
 #endif
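The edits in mkldnn_quantized_ops-inl.h and quantized_fully_connected.cc only flip the preprocessor guards from MXNET_USE_MKLDNN == 1 to MXNET_USE_MKLDNN == 100, moving the quantized FC operator onto the MKL-DNN v1.0 code paths. A tiny standalone sketch of that version-gated pattern follows; the fallback define and the printed strings are only there so the example compiles on its own and are not part of MXNet.

// Sketch of the version gate used throughout this commit. MXNET_USE_MKLDNN is
// normally defined by the build system: 100 selects the MKL-DNN v1.0 code
// paths, 1 the legacy v0.x ones, 0 (or undefined) a build without MKL-DNN.
#ifndef MXNET_USE_MKLDNN
#define MXNET_USE_MKLDNN 100  // assumed here so the sketch compiles standalone
#endif

#include <cstdio>

int main() {
#if MXNET_USE_MKLDNN == 100
  std::printf("MKL-DNN v1.0 integration: *_desc() queries, argument maps\n");
#elif MXNET_USE_MKLDNN == 1
  std::printf("legacy MKL-DNN v0.x integration: *_primitive_desc(), SetNewMem()\n");
#else
  std::printf("built without MKL-DNN\n");
#endif
  return 0;
}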

tests/python/quantization/test_quantization.py (1 addition, 1 deletion)

@@ -407,7 +407,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         if is_test_for_native_cpu():
-            hasMKL = False;
+            hasMKL = False
             for key in os.environ.keys():
                 if operator.eq(key, "BUILD_TAG"):
                     if os.environ['BUILD_TAG'].find("MKL") != -1:
