Skip to content

Commit 436967b

Browse files
rongzha1pengzhao-intel
authored andcommitted
Mkldnn fullyConnect bwd bug fix (apache#16890)
* fix mkldnn fc bwd bug due to data inplace * enable mkldnn fc bwd
1 parent 9b49cfe commit 436967b

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

src/operator/nn/fully_connected.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
147147
const std::vector<NDArray> &inputs,
148148
const std::vector<OpReqType> &req,
149149
const std::vector<NDArray> &outputs) {
150-
// TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
151-
// Will be fixed when we decide to enable the backward of FC.
152-
bool mkldnn_fc_backward_enable = false;
153-
if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
150+
if (SupportMKLDNNFC(inputs[0])) {
154151
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
155152
MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
156153
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
@@ -232,12 +229,10 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
232229
uint32_t out_expected = param.no_bias ? 2 : 3;
233230
CHECK_EQ(in_attrs->size(), 3U);
234231
CHECK_EQ(out_attrs->size(), out_expected);
235-
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
236-
// It seems there is a bug.
237232
bool dispatched = false;
238233
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
239234
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
240-
dispatch_mode, DispatchMode::kFCompute);
235+
dispatch_mode, DispatchMode::kFComputeEx);
241236
}
242237
if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
243238
dispatched = dispatch_fallback(out_attrs, dispatch_mode);

src/operator/nn/mkldnn/mkldnn_fully_connected.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -290,24 +290,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
290290
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
291291

292292
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
293-
if (req[fullc::kData]) {
294-
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
295-
data, weight, out_grad, fwd_pd);
296-
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
297-
ipBwdData_pd.diff_dst_desc());
298-
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
299-
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
300-
ipBwdData_pd.diff_src_desc(),
301-
req[fullc::kData]);
302-
mkldnn_args_map_t args = {
303-
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
304-
{MKLDNN_ARG_WEIGHTS, *weight_mem},
305-
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
306-
};
307-
308-
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
309-
CommitOutput(in_grad[fullc::kData], in_grad_mem);
310-
}
311293
if (req[fullc::kWeight]) {
312294
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
313295
= GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
@@ -336,6 +318,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
336318
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
337319
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
338320
}
321+
if (req[fullc::kData]) {
322+
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
323+
data, weight, out_grad, fwd_pd);
324+
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
325+
ipBwdData_pd.diff_dst_desc());
326+
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
327+
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
328+
ipBwdData_pd.diff_src_desc(),
329+
req[fullc::kData]);
330+
mkldnn_args_map_t args = {
331+
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
332+
{MKLDNN_ARG_WEIGHTS, *weight_mem},
333+
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
334+
};
335+
336+
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
337+
CommitOutput(in_grad[fullc::kData], in_grad_mem);
338+
}
339339
MKLDNNStream::Get()->Submit();
340340
}
341341

0 commit comments

Comments
 (0)