From 8b862eff01418dda377128869e4532f2b494308a Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 25 Oct 2019 10:10:48 +0800 Subject: [PATCH 1/3] Use MKLDNNRun --- src/ndarray/ndarray.cc | 2 +- src/operator/leaky_relu.cc | 4 +- src/operator/nn/activation.cc | 7 +-- src/operator/nn/batch_norm.cc | 31 +---------- src/operator/nn/concat.cc | 4 +- src/operator/nn/fully_connected.cc | 4 +- src/operator/nn/lrn.cc | 10 +--- src/operator/nn/mkldnn/mkldnn_act.cc | 52 +++++++------------ src/operator/nn/mkldnn/mkldnn_base-inl.h | 15 +++++- src/operator/nn/mkldnn/mkldnn_base.cc | 19 ++++++- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 50 ++++++++++-------- src/operator/nn/mkldnn/mkldnn_copy.cc | 13 ++--- .../nn/mkldnn/mkldnn_fully_connected.cc | 6 --- src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 25 +++++---- src/operator/nn/mkldnn/mkldnn_ops-inl.h | 20 +++---- src/operator/nn/mkldnn/mkldnn_reshape.cc | 2 +- src/operator/nn/mkldnn/mkldnn_slice-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_slice.cc | 3 +- src/operator/nn/mkldnn/mkldnn_softmax.cc | 8 +-- src/operator/nn/mkldnn/mkldnn_sum.cc | 25 ++++----- src/operator/nn/softmax.cc | 2 +- .../mkldnn/mkldnn_quantized_flatten.cc | 2 +- src/operator/softmax_output.cc | 2 +- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 4 +- .../tensor/elemwise_binary_op_basic.cc | 6 +-- src/operator/tensor/elemwise_sum.cc | 2 +- .../tensor/elemwise_unary_op_basic.cc | 8 ++- src/operator/tensor/matrix_op.cc | 8 +-- 28 files changed, 150 insertions(+), 186 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6dc6bafa7288..95fc09774dda 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -685,7 +685,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { ptr_->Reorder2Default(); const mkldnn::memory *this_mem = GetMKLDNNData(); - MKLDNNCopy(mem, this_mem); + MKLDNNMemoryCopy(mem, this_mem); } mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) { diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc index 4d1c5ca10a30..49ba95d306f4 100644 --- a/src/operator/leaky_relu.cc +++ b/src/operator/leaky_relu.cc @@ -95,7 +95,7 @@ static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), expected); if (SupportMKLDNNLeakyRelu(param, inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNLeakyReluForward, attrs, ctx, inputs[0], req[0], outputs[0]); MKLDNN_OPCHECK_RUN(LeakyReLUCompute, attrs, ctx, inputs, req, outputs); return; } @@ -111,7 +111,7 @@ void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs, if (SupportMKLDNNLeakyRelu(param, inputs[0])) { std::vector in_data{inputs[0], inputs[1]}; MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNLeakyReluBackward(attrs, ctx, in_data, req[0], outputs[0]); + MKLDNNRun(MKLDNNLeakyReluBackward, attrs, ctx, in_data, req, outputs); MKLDNN_OPCHECK_RUN(LeakyReLUGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index 5abb6670c9b0..ce5fb3e45322 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -102,7 +102,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); if (SupportMKLDNNAct(param, inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNActivationForward, attrs, ctx, inputs[0], req[0], outputs[0]); MKLDNN_OPCHECK_RUN(ActivationCompute, attrs, ctx, inputs, req, outputs); return; } @@ -118,10 +118,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type)); if (SupportMKLDNNAct(param, inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - // XXX: for y = relu(x), y is passed as "in_data" to Backward() - const bool relu = param.act_type == activation::kReLU; - MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0], - outputs[0]); + MKLDNNRun(MKLDNNActivationBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ActivationGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 3214e3b9b9ac..04e45d4acfed 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -394,17 +394,11 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(inputs.size(), 5U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); - if (SupportMKLDNNBN(inputs[0], param)) { - std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); - std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); - - if (inputs[0].dtype() == mshadow::kFloat32) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNBatchNormForward(ctx, param, in_data, req, outputs, aux_states); + MKLDNNRun(MKLDNNBatchNormForward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(BatchNormCompute, attrs, ctx, inputs, req, outputs); return; - } } FallBackCompute(BatchNormCompute, attrs, ctx, inputs, req, outputs); } @@ -414,33 +408,12 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - CHECK_EQ(inputs.size(), 8U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); - - mxnet::TShape shape = inputs[0].shape(); - if (SupportMKLDNNBN(inputs[0], param)) { - std::vector out_grad(1); - std::vector out_data(3); - std::vector in_data(3); - std::vector aux_states(2); - out_grad[0] = inputs[0]; - out_data[batchnorm::kMean] = inputs[1]; - out_data[batchnorm::kVar] = inputs[2]; - in_data[batchnorm::kData] = inputs[3]; - in_data[batchnorm::kGamma] = inputs[4]; - in_data[batchnorm::kBeta] = inputs[5]; - aux_states[batchnorm::kMovingMean] = inputs[6]; - aux_states[batchnorm::kMovingVar] = inputs[7]; - const std::vector &in_grad = outputs; - - if (inputs[0].dtype() == mshadow::kFloat32) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNBatchNormBackward(ctx, param, out_grad, in_data, - out_data, req, in_grad, aux_states); + MKLDNNRun(MKLDNNBatchNormBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(BatchNormGradCompute, attrs, ctx, inputs, req, outputs); return; - } } FallBackCompute(BatchNormGradCompute, attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 4d90810915a2..1eeef7db5cb5 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -270,7 +270,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 } else if (SupportMKLDNNConcat(inputs)) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConcatCompute, attrs, op_ctx, inputs, req, outputs); } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { FallBackCompute(ConcatCompute, attrs, op_ctx, inputs, req, outputs); @@ -288,7 +288,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { if (SupportMKLDNNConcat(inputs)) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNConcatBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConcatGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 5d722581257f..1632486e0a82 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -102,7 +102,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs, common::ContainsOnlyStorage(outputs, kDefaultStorage)) { if (SupportMKLDNNFC(inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNFCForward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNFCForward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedCompute, attrs, ctx, inputs, req, outputs); } else { @@ -152,7 +152,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, bool mkldnn_fc_backward_enable = false; if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNFCBackward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req, outputs); return; diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 3a3ca59f2be1..41337352df63 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -110,11 +110,10 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - const LRNParam ¶m = nnvm::get(attrs.parsed); if (SupportMKLDNN(inputs[0])) { // We only need to test one output array. MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs); - MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]); MKLDNN_OPCHECK_RUN(LRNCompute, attrs, ctx, inputs, req, outputs); // Copy outputs[1] from opcheck reference as backward check needs it. MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector{1}); @@ -128,14 +127,9 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - const LRNParam ¶m = nnvm::get(attrs.parsed); - const NDArray &out_grad = inputs[0]; - const NDArray &in_data = inputs[1]; - const NDArray &in_grad = outputs[0]; - if (SupportMKLDNN(inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNLRNBackward(ctx, param, out_grad, in_data, req[0], in_grad); + MKLDNNRun(MKLDNNLRNBackward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(LRNGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 6ea7ac4a30b3..7cf94790ed0d 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -151,13 +151,8 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const ActivationParam& param = nnvm::get(attrs.parsed); MKLDNNActParam param_; param_.alg = GetMKLDNNActAlgo(param); - - NDArray in_buffer = in_data; + const NDArray& in_buffer = in_data; MKLDNNStream *stream = MKLDNNStream::Get(); - - if (in_data.IsView() && in_data.IsMKLDNNData()) - in_buffer = in_data.Reorder2Default(); - auto input_mem = in_buffer.GetMKLDNNData(); MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem); auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer); @@ -235,22 +230,18 @@ static inline MKLDNNActBackward &GetActBackward(const MKLDNNActParam ¶m, // For backward relu activation, it's okay to pass "out_data" as "in_data" to this // function, since the computation only involes non-zeros. -void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &out_grad, const NDArray &in_data, - const OpReqType &req, const NDArray &in_grad) { - if (req == kNullOp) { +void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + if (req[0] == kNullOp) { return; } - - NDArray out_buffer = out_grad; - if (out_grad.IsView() && out_grad.IsMKLDNNData()) - out_buffer = out_grad.Reorder2Default(); - - NDArray in_buffer = in_data; - if (in_data.IsView() && in_data.IsMKLDNNData()) - in_buffer = in_data.Reorder2Default(); - const ActivationParam& param = nnvm::get(attrs.parsed); + // XXX: for y = relu(x), y is passed as "in_data" to Backward() + const bool relu = param.act_type == activation::kReLU; + const NDArray &out_buffer = inputs[0]; + const NDArray &in_buffer = relu ? inputs[1] : inputs[2]; + const NDArray &in_grad = outputs[0]; MKLDNNActParam param_; param_.alg = GetMKLDNNActAlgo(param); TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); @@ -264,7 +255,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn_output_t diff_src_memory = - CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req); + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]); mkldnn_args_map_t args = { { MKLDNN_ARG_SRC, *input_mem }, { MKLDNN_ARG_DIFF_DST, *diff_dst_memory }, @@ -278,19 +269,16 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& inputs, - const OpReqType &req, - const NDArray &output) { - if (req == kNullOp) { + const std::vector &req, + const std::vector &outputs) { + if (req[0] == kNullOp) { return; } - CHECK_GE(inputs.size(), 2U); - NDArray out_buffer = inputs[0]; - if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) - out_buffer = inputs[0].Reorder2Default(); - - NDArray in_buffer = inputs[1]; - if (inputs[1].IsView() && inputs[1].IsMKLDNNData()) - in_buffer = inputs[1].Reorder2Default(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + const NDArray& out_buffer = inputs[0]; + const NDArray& in_buffer = inputs[1]; + const NDArray &output = outputs[0]; const LeakyReLUParam& param = nnvm::get(attrs.parsed); MKLDNNActParam param_; @@ -308,7 +296,7 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn_output_t diff_src_memory = - CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req); + CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req[0]); mkldnn_args_map_t args = { { MKLDNN_ARG_SRC, *input_mem }, { MKLDNN_ARG_DIFF_DST, *diff_dst_memory }, diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 0f371d174e40..b536f51d4a81 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -442,7 +442,7 @@ enum OutDataOp { }; typedef std::pair mkldnn_output_t; -void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem); +void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem); /* * Here we want to get MKLDNN memory whose desc is exactly the same as @@ -684,6 +684,19 @@ void MKLDNNRun(mxnet::FComputeEx fn, const std::vector &req, const std::vector &outputs_); +using FComputeExUnary = std::function; + +void MKLDNNRun(FComputeExUnary fn, + const nnvm::NodeAttrs &attrs, + const mxnet::OpContext &ctx, + const mxnet::NDArray &inputs_, + const mxnet::OpReqType &req, + const mxnet::NDArray &outputs_); + } // namespace mxnet #endif #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 1b147c69ba62..8ee9e48b6f11 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -85,7 +85,7 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) { } } -void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) { +void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) { MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn::memory::desc from_desc = mem.get_desc(); mkldnn::memory::desc this_desc = this_mem->get_desc(); @@ -227,7 +227,7 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) { auto mem = arr.GetMKLDNNData(res.second->get_desc()); if (mem == nullptr) { auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd); - MKLDNNCopy(*res_memory, tmp_memory); + MKLDNNMemoryCopy(*res_memory, tmp_memory); res_memory = tmp_memory; mem = arr.GetMKLDNNData(); } @@ -606,6 +606,21 @@ void MKLDNNRun(mxnet::FComputeEx fn, } } +void MKLDNNRun(FComputeExUnary fn, + const nnvm::NodeAttrs &attrs, + const mxnet::OpContext &ctx, + const mxnet::NDArray &input, + const mxnet::OpReqType &req, + const mxnet::NDArray &output) { + auto mkldnn_input = input; + if (input.IsView() && input.IsMKLDNNData()) { + mkldnn_input = input.Reorder2Default(); + fn(attrs, ctx, mkldnn_input, req, output); + } else { + fn(attrs, ctx, input, req, output); + } +} + } // namespace mxnet #endif diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 26637c7c0b65..23e327389f46 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -147,11 +147,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, } template -void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { +void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + const BatchNormParam ¶m = nnvm::get(attrs.parsed); + const std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + const std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); mkldnn::normalization_flags flags = _GetFlags(in_data, aux_states, @@ -159,7 +160,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags); - const NDArray &out = out_data[batchnorm::kOut]; + const NDArray &out = outputs[batchnorm::kOut]; // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -201,8 +202,8 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, net_args[MKLDNN_ARG_DST] = *out_mem; if (!ctx.is_train || param.use_global_stats) { - DType* omean = out_data[batchnorm::kMean].data().dptr(); - DType* ovar = out_data[batchnorm::kVar].data().dptr(); + DType* omean = outputs[batchnorm::kMean].data().dptr(); + DType* ovar = outputs[batchnorm::kVar].data().dptr(); DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr(); DType* invar = aux_states[batchnorm::kMovingVar].data().dptr(); // to align with origin implmentation: batch_norm.cc: L164 @@ -215,8 +216,8 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); MKLDNNStream::Get()->Submit(); } else { // training - const NDArray &outMean = out_data[batchnorm::kMean]; - const NDArray &outVar = out_data[batchnorm::kVar]; + const NDArray &outMean = outputs[batchnorm::kMean]; + const NDArray &outVar = outputs[batchnorm::kVar]; net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData()); MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); @@ -278,18 +279,25 @@ static MKLDNNBNBackward &GetBNBackward( } template -void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { +void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 8U); + const BatchNormParam ¶m = nnvm::get(attrs.parsed); + std::vector out_grad(1); + std::vector out_data(3); + std::vector in_data(3); + std::vector aux_states(2); + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + const std::vector &in_grad = outputs; TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - CHECK_EQ(out_grad.size(), 1U); - CHECK_EQ(in_data.size(), 3U); - CHECK_EQ(out_data.size(), 3U); - CHECK_EQ(in_grad.size(), 3U); mkldnn::normalization_flags flags = _GetFlags(in_data, aux_states, param, diff --git a/src/operator/nn/mkldnn/mkldnn_copy.cc b/src/operator/nn/mkldnn/mkldnn_copy.cc index cf8daa4b45df..a67847f9c882 100644 --- a/src/operator/nn/mkldnn/mkldnn_copy.cc +++ b/src/operator/nn/mkldnn/mkldnn_copy.cc @@ -33,22 +33,17 @@ namespace op { void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data) { + if (req == kNullOp || req == kWriteInplace) return; TmpMemMgr::Get()->Init(ctx.requested[0]); - - // If the input data is a view of an MKLDNN array, we should create a new - // NDArray with reordered data. - NDArray data = in_data; - if (data.IsMKLDNNData() && data.IsView()) - data = data.Reorder2Default(); - auto in_mem = data.GetMKLDNNData(); + auto in_mem = in_data.GetMKLDNNData(); if (req == kAddTo) { TmpMemMgr::Get()->Init(ctx.requested[0]); // We should try and force the input memory has the same format // as the input output. If not, we'll have to reorder memory. auto out_mem = out_data.GetMKLDNNData(); - in_mem = data.GetMKLDNNData(out_mem ->get_desc()); + in_mem = in_data.GetMKLDNNData(out_mem ->get_desc()); if (in_mem == nullptr) - in_mem = data.GetMKLDNNDataReorder(out_mem->get_desc()); + in_mem = in_data.GetMKLDNNDataReorder(out_mem->get_desc()); MKLDNNSum(*out_mem, *in_mem, *out_mem); } else { const_cast(out_data).CopyFrom(*in_mem); diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 1403cd114201..8c401a879f15 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -166,12 +166,6 @@ void MKLDNNFCFlattenData(const FullyConnectedParam ¶m, mkldnn::memory::desc *out_md) { const mxnet::TShape ishape = in_data->shape(); const mxnet::TShape oshape = out_data.shape(); - - // If the input data is a view of an MKLDNN array, we should create a new - // NDArray with reordered data. - if (in_data->IsMKLDNNData() && in_data->IsView()) - *in_data = in_data->Reorder2Default(); - if (ishape.ndim() != 2) { if (!param.flatten) { *in_data = in_data->MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1), diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index ca7095fd3f02..5ebc0e31822d 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -171,9 +171,10 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, return it->second; } -void MKLDNNLRNForward(const OpContext &ctx, const LRNParam ¶m, +void MKLDNNLRNForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType req, const NDArray &out_data) { + const LRNParam ¶m = nnvm::get(attrs.parsed); auto in_buffer = in_data; if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = in_buffer.Reorder2Default(); @@ -244,22 +245,21 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam ¶m, const NDArray &in_data, return it->second; } -void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, - const NDArray &out_grad, - const NDArray &in_data, - const OpReqType req, - const NDArray &in_grad) { - if (req == kNullOp) { +void MKLDNNLRNBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + if (req[0] == kNullOp) { return; } + const LRNParam ¶m = nnvm::get(attrs.parsed); + const NDArray &out_grad = inputs[0]; + const NDArray &in_data = inputs[1]; + const NDArray &in_grad = outputs[0]; // TODO(alex): (MXNET-846) figure out why in_grad output incorrect when in_data is nchw8c - auto in_buffer = in_data; - if (in_buffer.IsMKLDNNData()) { - in_buffer = in_data.Reorder2Default(); - } + const auto in_buffer = in_data.Reorder2Default(); MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad); mkldnn_output_t diff_src_mem = - CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req); + CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]); bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem); } @@ -267,4 +267,3 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam ¶m, } // namespace mxnet #endif // MXNET_USE_MKLDNN == 1 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__ - diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 71f3eafa8ee9..1ce36303689d 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -76,16 +76,18 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext & void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); -void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const NDArray &out_grad, const NDArray &in_data, - const OpReqType &req, const NDArray &in_grad); +void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); -void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, const OpReqType &req, - const NDArray &output); +void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); /* For softmax */ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -99,9 +101,9 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &c const std::vector &out_data); /* For sum */ -void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &inputs, const OpReqType &req, - const NDArray &out_data); +void MKLDNNSumForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs); /* For copy */ void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc index 0fc9f20703af..944e7d310d1e 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape.cc +++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc @@ -128,7 +128,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to // naive implement. if (input.shape().ndim() > 4 || !SupportMKLDNNQuantize(input.dtype())) { - if (req != kWriteInplace) { + if (req != kWriteInplace && req != kNullOp) { FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, {input}, {req}, {output}); } return; diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h index e6258c8c3f43..0bb432da9f7f 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h @@ -57,7 +57,7 @@ typedef ParamOpSign MKLDNNSliceSignature; MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, const NDArray &in_data, const NDArray &out_data); -void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, +void MKLDNNSlice(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const NDArray &in, OpReqType req, const NDArray &out); } // namespace op diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 575554a25c88..26d4f096bef1 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -90,8 +90,9 @@ MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, return it->second; } -void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, +void MKLDNNSlice(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const NDArray &in, OpReqType req, const NDArray &out) { + const SliceParam& param = nnvm::get(attrs.parsed); MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out); auto in_mem = in.GetMKLDNNData(); auto out_md = out.GetMKLDNNData()->get_desc(); diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc index 5b43cb0b0864..a3558847fba1 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -76,12 +76,7 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, CHECK_NE(req, kAddTo); const SoftmaxParam& param = nnvm::get(attrs.parsed); int axis = CheckAxis(param.axis, in_data.shape().ndim()); - NDArray data = in_data; - if (in_data.IsView() && in_data.IsMKLDNNData()) { - data = in_data.Reorder2Default(); - } - - auto data_mem = data.GetMKLDNNData(); + auto data_mem = in_data.GetMKLDNNData(); auto pd = GetSoftmaxFwdPd(ctx.is_train, axis, *data_mem); auto out_mem = CreateMKLDNNMem(out_data, pd.dst_desc(), req); MKLDNNStream *stream = MKLDNNStream::Get(); @@ -94,4 +89,3 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet #endif - diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index 5027bcbaabb1..747dde69ce13 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -46,8 +46,8 @@ void MKLDNNSum(const mkldnn::memory &arr1, if (input_pds[0] != output_pd) { auto tmp_memory1 = TmpMemMgr::Get()->Alloc(output_pd); auto tmp_memory2 = TmpMemMgr::Get()->Alloc(output_pd); - mxnet::MKLDNNCopy(arr1, tmp_memory1); - mxnet::MKLDNNCopy(arr2, tmp_memory2); + MKLDNNMemoryCopy(arr1, tmp_memory1); + MKLDNNMemoryCopy(arr2, tmp_memory2); input_pds[0] = tmp_memory1->get_desc(); input_pds[1] = tmp_memory2->get_desc(); in_mem1 = tmp_memory1; @@ -98,37 +98,30 @@ static MKLDNNSumFwd &GetSumForward( } void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &inputs, const OpReqType &req, - const NDArray &out_data) { + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { TmpMemMgr::Get()->Init(ctx.requested[0]); const int num_inputs = inputs.size(); + const NDArray &out_data = outputs[0]; std::vector data_md; std::vector data_mem; std::vector scales(num_inputs, 1); - std::vector in_bufs(num_inputs); data_md.reserve(num_inputs); data_mem.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { - const mkldnn::memory *in_mem; - if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) { - in_bufs[i] = inputs[i].Reorder2Default(); - in_mem = in_bufs[i].GetMKLDNNData(); - } else { - in_bufs[i] = inputs[i]; - in_mem = inputs[i].GetMKLDNNData(); - } + const mkldnn::memory *in_mem = inputs[i].GetMKLDNNData(); mkldnn::memory::desc tmp_md = in_mem->get_desc(); data_md.push_back(tmp_md); data_mem.push_back(in_mem); } - MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md); + MKLDNNSumFwd &fwd = GetSumForward(scales, inputs, data_md); mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), - req, - &in_bufs[0]); + req[0], + &inputs[0]); mkldnn_args_map_t net_args; net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); for (int i = 0; i < num_inputs; ++i) { diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index ce19dda23d87..57edab7037d5 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -45,7 +45,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, const SoftmaxParam& param = nnvm::get(attrs.parsed); if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]); auto fn = SoftmaxCompute; MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs); return; diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc index c059f9868ea0..11a960e3b9e0 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc @@ -42,7 +42,7 @@ static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs, const Op const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]); outputs[1].data().dptr()[0] = inputs[1].data().dptr()[0]; outputs[2].data().dptr()[0] = inputs[2].data().dptr()[0]; } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index ec2670974f49..0bf6e2a014a6 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -144,7 +144,7 @@ void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs, const SoftmaxOutputParam ¶m = nnvm::get(attrs.parsed); if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmaxOutput(param)) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs); + MKLDNNRun(MKLDNNSoftmaxOutputForward, attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(SoftmaxOutputCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index f5bbd5044446..f3a7d2c4e914 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -278,7 +278,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, CpuEngine::Get()->get_engine(), out_mkl_mem->get_data_handle())); MKLDNNStream::Get()->RegisterMem(tmp_mem); - mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get()); + MKLDNNMemoryCopy(*in_mkl_mem, tmp_mem.get()); output = NDArray(tmp_mem); } } @@ -416,7 +416,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, mkldnn_mem_ptr new_out_mem(new mkldnn::memory(data_md, CpuEngine::Get()->get_engine(), output_mem->get_data_handle())); MKLDNNStream::Get()->RegisterMem(new_out_mem); - mxnet::MKLDNNCopy(*tmp_out_mem, new_out_mem.get()); + MKLDNNMemoryCopy(*tmp_out_mem, new_out_mem.get()); output = NDArray(new_out_mem); } } diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 50772bc075d4..98cf7f067527 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -45,7 +45,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); #if MXNET_USE_MKLDNN == 1 if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) { - MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]); + MKLDNNRun(MKLDNNSumForward, attrs, ctx, inputs, req, outputs); return; } else if (inputs[0].storage_type() == kDefaultStorage && inputs[1].storage_type() == kDefaultStorage) { @@ -123,8 +123,8 @@ static void _backward_ElemwiseAddEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 2U); #if MXNET_USE_MKLDNN == 1 if (inputs[0].IsMKLDNNData()) { - MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); - MKLDNNCopy(attrs, ctx, inputs[0], req[1], outputs[1]); + MKLDNNRun(MKLDNNCopy, attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNCopy, attrs, ctx, inputs[0], req[1], outputs[1]); return; } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { FallBackCompute( diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index d1b86d161e89..b07c9590e8f5 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -125,7 +125,7 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs, mxnet::ndarray::ElementwiseSum(s, rsc, inputs, &out_nd); #if MXNET_USE_MKLDNN == 1 } else if (IsMKLDNNData(inputs)) { - MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]); + MKLDNNRun(MKLDNNSumForward, attrs, ctx, inputs, req, outputs); } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { FallBackCompute(ElementWiseSumCompute, attrs, ctx, inputs, req, outputs); #endif diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 56674409601c..71fbde75637c 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -208,13 +208,11 @@ static void CopyEx(const nnvm::NodeAttrs& attrs, const auto in_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); if (inputs[0].IsMKLDNNData()) { - MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNCopy, attrs, ctx, inputs[0], req[0], outputs[0]); return; } else if (in_stype == kDefaultStorage && out_stype == kDefaultStorage) { - // This happens if inputs are supposed to be in MKLDNN format - // but MKLDNN doesn't support the data type or the shape. We're - // forced to convert it to the default format. - FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); + if (req[0] != kNullOp && req[0] != kWriteInplace) + FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, inputs, req, outputs); return; } #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index eee5ea67f6e1..8aec0dd99da2 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -118,7 +118,7 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs, // If inputs are supposed to be in MKLDNN format and // MKLDNN support the data type or the shape. Then convert // it to the output format and shape - MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]); } inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs, @@ -211,7 +211,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs, // If inputs are supposed to be in MKLDNN format and // MKLDNN support the data type or the shape. Then convert // it to the output format and shape - MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]); } static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs, @@ -370,7 +370,7 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs, // If inputs are supposed to be in MKLDNN format and // MKLDNN support the data type or the shape. Then convert // it to the output format and shape - MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]); } inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs, @@ -430,7 +430,7 @@ void SliceExCPU(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 } else if (in_stype == kDefaultStorage) { if (SupportMKLDNN(inputs[0])) { - MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]); + MKLDNNRun(MKLDNNSlice, attrs, ctx, inputs[0], req[0], outputs[0]); } else { FallBackCompute(SliceOpForward, attrs, ctx, inputs, req, outputs); } From 927fc55be0fb71f942ec25df57f6092806d90424 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 12 Nov 2019 15:36:33 +0800 Subject: [PATCH 2/3] Fix lint --- src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index 5ebc0e31822d..6f7a1d917734 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -27,6 +27,7 @@ #if MXNET_USE_MKLDNN == 1 #include +#include #include #include "../lrn-inl.h" #include "./mkldnn_base-inl.h" From 7d4528f6f7b44c179654b51e68a7809b7fca668c Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 26 Nov 2019 13:55:30 +0800 Subject: [PATCH 3/3] Run CI