This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5fb2916

ZhennanQin authored and pengzhao-intel committed
[MKLDNN] Use MKLDNNRun (#16772)
* Use MKLDNNRun
* Fix lint
* Run CI
1 parent 32a9baa commit 5fb2916

27 files changed: +150 -185 lines
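
The change is the same across the operator files below: each FComputeEx entry point stops invoking its MKLDNN kernel directly and instead hands the kernel to the MKLDNNRun wrapper. A minimal before/after sketch of the pattern, taken from the activation diff in this commit:

// Before: the kernel is called directly; each caller (or kernel) had to
// handle view inputs held in MKLDNN layout itself.
MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);

// After: MKLDNNRun reorders any MKLDNN-layout view input to the default
// layout before invoking the kernel.
MKLDNNRun(MKLDNNActivationForward, attrs, ctx, inputs[0], req[0], outputs[0]);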

src/ndarray/ndarray.cc

Lines changed: 1 addition & 1 deletion
@@ -684,7 +684,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     ptr_->Reorder2Default();
 
   const mkldnn::memory *this_mem = GetMKLDNNData();
-  MKLDNNCopy(mem, this_mem);
+  MKLDNNMemoryCopy(mem, this_mem);
 }
 
 mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {

src/operator/leaky_relu.cc

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), expected);
   if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    MKLDNNRun(MKLDNNLeakyReluForward, attrs, ctx, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -111,7 +111,7 @@ void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
     std::vector<NDArray> in_data{inputs[0], inputs[1]};
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNLeakyReluBackward(attrs, ctx, in_data, req[0], outputs[0]);
+    MKLDNNRun(MKLDNNLeakyReluBackward, attrs, ctx, in_data, req, outputs);
     MKLDNN_OPCHECK_RUN(LeakyReLUGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }

src/operator/nn/activation.cc

Lines changed: 2 additions & 5 deletions
@@ -102,7 +102,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1U);
   if (SupportMKLDNNAct(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    MKLDNNRun(MKLDNNActivationForward, attrs, ctx, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -118,10 +118,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
   if (SupportMKLDNNAct(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    // XXX: for y = relu(x), y is passed as "in_data" to Backward()
-    const bool relu = param.act_type == activation::kReLU;
-    MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
-                             outputs[0]);
+    MKLDNNRun(MKLDNNActivationBackward, attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }

src/operator/nn/batch_norm.cc

Lines changed: 2 additions & 29 deletions
@@ -394,17 +394,11 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
                            const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-
   if (SupportMKLDNNBN(inputs[0], param)) {
-    std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
-    std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
-
-    if (inputs[0].dtype() == mshadow::kFloat32) {
       MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-      MKLDNNBatchNormForward<float>(ctx, param, in_data, req, outputs, aux_states);
+      MKLDNNRun(MKLDNNBatchNormForward<float>, attrs, ctx, inputs, req, outputs);
       MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
       return;
-    }
   }
   FallBackCompute(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
@@ -414,33 +408,12 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
-  CHECK_EQ(inputs.size(), 8U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-
-  mxnet::TShape shape = inputs[0].shape();
-
   if (SupportMKLDNNBN(inputs[0], param)) {
-    std::vector<NDArray> out_grad(1);
-    std::vector<NDArray> out_data(3);
-    std::vector<NDArray> in_data(3);
-    std::vector<NDArray> aux_states(2);
-    out_grad[0] = inputs[0];
-    out_data[batchnorm::kMean] = inputs[1];
-    out_data[batchnorm::kVar] = inputs[2];
-    in_data[batchnorm::kData] = inputs[3];
-    in_data[batchnorm::kGamma] = inputs[4];
-    in_data[batchnorm::kBeta] = inputs[5];
-    aux_states[batchnorm::kMovingMean] = inputs[6];
-    aux_states[batchnorm::kMovingVar] = inputs[7];
-    const std::vector<NDArray> &in_grad = outputs;
-
-    if (inputs[0].dtype() == mshadow::kFloat32) {
       MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-      MKLDNNBatchNormBackward<float>(ctx, param, out_grad, in_data,
-                                     out_data, req, in_grad, aux_states);
+      MKLDNNRun(MKLDNNBatchNormBackward<float>, attrs, ctx, inputs, req, outputs);
       MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
       return;
-    }
   }
   FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }

src/operator/nn/concat.cc

Lines changed: 2 additions & 2 deletions
@@ -270,7 +270,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
 #if MXNET_USE_MKLDNN == 1
   } else if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
+    MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
   } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
     FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
@@ -288,7 +288,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& outputs) {
   if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
+    MKLDNNRun(MKLDNNConcatBackward, attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }

src/operator/nn/fully_connected.cc

Lines changed: 2 additions & 2 deletions
@@ -102,7 +102,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
       common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
     if (SupportMKLDNNFC(inputs[0])) {
       MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-      MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
+      MKLDNNRun(MKLDNNFCForward, attrs, ctx, inputs, req, outputs);
       MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
                          outputs);
     } else {
@@ -152,7 +152,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   bool mkldnn_fc_backward_enable = false;
   if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
+    MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
                        outputs);
     return;

src/operator/nn/lrn.cc

Lines changed: 2 additions & 8 deletions
@@ -110,11 +110,10 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
                      const std::vector<NDArray> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<NDArray> &outputs) {
-  const LRNParam &param = nnvm::get<LRNParam>(attrs.parsed);
   if (SupportMKLDNN(inputs[0])) {
     // We only need to test one output array.
     MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
-    MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
+    MKLDNNRun(MKLDNNLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
     // Copy outputs[1] from opcheck reference as backward check needs it.
     MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
@@ -128,14 +127,9 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                          const std::vector<NDArray> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<NDArray> &outputs) {
-  const LRNParam &param = nnvm::get<LRNParam>(attrs.parsed);
-  const NDArray &out_grad = inputs[0];
-  const NDArray &in_data = inputs[1];
-  const NDArray &in_grad = outputs[0];
-
   if (SupportMKLDNN(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNLRNBackward(ctx, param, out_grad, in_data, req[0], in_grad);
+    MKLDNNRun(MKLDNNLRNBackward, attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(LRNGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }

src/operator/nn/mkldnn/mkldnn_act.cc

Lines changed: 20 additions & 32 deletions
@@ -151,13 +151,8 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   MKLDNNActParam param_;
   param_.alg = GetMKLDNNActAlgo(param);
-
-  NDArray in_buffer = in_data;
+  const NDArray& in_buffer = in_data;
   MKLDNNStream *stream = MKLDNNStream::Get();
-
-  if (in_data.IsView() && in_data.IsMKLDNNData())
-    in_buffer = in_data.Reorder2Default();
-
   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem);
   auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
@@ -235,22 +230,18 @@ static inline MKLDNNActBackward &GetActBackward(const MKLDNNActParam &param,
 
 // For backward relu activation, it's okay to pass "out_data" as "in_data" to this
 // function, since the computation only involes non-zeros.
-void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                              const NDArray &out_grad, const NDArray &in_data,
-                              const OpReqType &req, const NDArray &in_grad) {
-  if (req == kNullOp) {
+void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                              const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &outputs) {
+  if (req[0] == kNullOp) {
     return;
   }
-
-  NDArray out_buffer = out_grad;
-  if (out_grad.IsView() && out_grad.IsMKLDNNData())
-    out_buffer = out_grad.Reorder2Default();
-
-  NDArray in_buffer = in_data;
-  if (in_data.IsView() && in_data.IsMKLDNNData())
-    in_buffer = in_data.Reorder2Default();
-
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  // XXX: for y = relu(x), y is passed as "in_data" to Backward()
+  const bool relu = param.act_type == activation::kReLU;
+  const NDArray &out_buffer = inputs[0];
+  const NDArray &in_buffer = relu ? inputs[1] : inputs[2];
+  const NDArray &in_grad = outputs[0];
   MKLDNNActParam param_;
   param_.alg = GetMKLDNNActAlgo(param);
   TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
@@ -264,7 +255,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
       GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory =
-      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);
+      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
   mkldnn_args_map_t args = {
     { MKLDNN_ARG_SRC, *input_mem },
     { MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
@@ -278,19 +269,16 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
 void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
                              const OpContext &ctx,
                              const std::vector<NDArray>& inputs,
-                             const OpReqType &req,
-                             const NDArray &output) {
-  if (req == kNullOp) {
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &outputs) {
+  if (req[0] == kNullOp) {
     return;
   }
-  CHECK_GE(inputs.size(), 2U);
-  NDArray out_buffer = inputs[0];
-  if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
-    out_buffer = inputs[0].Reorder2Default();
-
-  NDArray in_buffer = inputs[1];
-  if (inputs[1].IsView() && inputs[1].IsMKLDNNData())
-    in_buffer = inputs[1].Reorder2Default();
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const NDArray& out_buffer = inputs[0];
+  const NDArray& in_buffer = inputs[1];
+  const NDArray &output = outputs[0];
 
   const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
   MKLDNNActParam param_;
@@ -308,7 +296,7 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
       GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory =
-      CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req);
+      CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
   mkldnn_args_map_t args = {
     { MKLDNN_ARG_SRC, *input_mem },
     { MKLDNN_ARG_DIFF_DST, *diff_dst_memory },

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 14 additions & 1 deletion
@@ -446,7 +446,7 @@ enum OutDataOp {
 };
 
 typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
-void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);
+void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);
 
 /*
  * Here we want to get MKLDNN memory whose desc is exactly the same as
@@ -688,6 +688,19 @@ void MKLDNNRun(mxnet::FComputeEx fn,
                const std::vector<mxnet::OpReqType> &req,
                const std::vector<mxnet::NDArray> &outputs_);
 
+using FComputeExUnary = std::function<void (const nnvm::NodeAttrs& attrs,
+                                            const OpContext& ctx,
+                                            const NDArray& input,
+                                            const OpReqType& req,
+                                            const NDArray& output)>;
+
+void MKLDNNRun(FComputeExUnary fn,
+               const nnvm::NodeAttrs &attrs,
+               const mxnet::OpContext &ctx,
+               const mxnet::NDArray &inputs_,
+               const mxnet::OpReqType &req,
+               const mxnet::NDArray &outputs_);
+
 }  // namespace mxnet
 #endif
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
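
The FComputeExUnary alias added here covers kernels that consume one NDArray and produce one, so single-tensor operators (activation, LRN, LeakyReLU forward) can share the dispatch with the existing vector-based FComputeEx overload. A hedged sketch of how an operator would use it; MyUnaryForward is an illustrative name, not part of this commit:

// Hypothetical kernel matching the FComputeExUnary signature declared above.
void MyUnaryForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                    const NDArray& input, const OpReqType& req,
                    const NDArray& output);

// Inside the operator's FComputeEx, mirroring the lrn.cc change above:
// MKLDNNRun decides whether the input needs Reorder2Default() first.
MKLDNNRun(MyUnaryForward, attrs, ctx, inputs[0], req[0], outputs[0]);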

src/operator/nn/mkldnn/mkldnn_base.cc

Lines changed: 17 additions & 2 deletions
@@ -85,7 +85,7 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) {
   }
 }
 
-void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
+void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn::memory::desc from_desc = mem.get_desc();
   mkldnn::memory::desc this_desc = this_mem->get_desc();
@@ -227,7 +227,7 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
     auto mem = arr.GetMKLDNNData(res.second->get_desc());
     if (mem == nullptr) {
       auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
-      MKLDNNCopy(*res_memory, tmp_memory);
+      MKLDNNMemoryCopy(*res_memory, tmp_memory);
       res_memory = tmp_memory;
       mem = arr.GetMKLDNNData();
     }
@@ -606,6 +606,21 @@ void MKLDNNRun(mxnet::FComputeEx fn,
   }
 }
 
+void MKLDNNRun(FComputeExUnary fn,
+               const nnvm::NodeAttrs &attrs,
+               const mxnet::OpContext &ctx,
+               const mxnet::NDArray &input,
+               const mxnet::OpReqType &req,
+               const mxnet::NDArray &output) {
+  auto mkldnn_input = input;
+  if (input.IsView() && input.IsMKLDNNData()) {
+    mkldnn_input = input.Reorder2Default();
+    fn(attrs, ctx, mkldnn_input, req, output);
+  } else {
+    fn(attrs, ctx, input, req, output);
+  }
+}
+
 }  // namespace mxnet
 
 #endif
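
The point of the wrapper: the IsView() && IsMKLDNNData() reorder check that each kernel previously repeated (see the deleted lines in mkldnn_act.cc above) now lives in one place. A toy, self-contained C++ model of that dispatch, using stand-in types rather than MXNet's, just to make the control flow concrete:

#include <functional>
#include <iostream>

// Stand-in for NDArray: only the two flags the dispatch cares about.
struct Tensor {
  bool is_view = false;
  bool is_mkldnn_layout = false;
};

// Stand-in for NDArray::Reorder2Default(): returns a copy in plain layout.
Tensor Reorder2Default(const Tensor &t) {
  Tensor copy = t;
  copy.is_mkldnn_layout = false;
  return copy;
}

using UnaryKernel = std::function<void(const Tensor &)>;

// Mirrors the unary MKLDNNRun above: normalize the input only when it is
// a view carrying MKLDNN-specific layout, then invoke the kernel.
void Run(const UnaryKernel &kernel, const Tensor &input) {
  if (input.is_view && input.is_mkldnn_layout) {
    kernel(Reorder2Default(input));
  } else {
    kernel(input);
  }
}

int main() {
  Tensor t{true, true};  // a view in MKLDNN layout -> gets reordered
  Run([](const Tensor &in) {
    std::cout << "kernel sees mkldnn layout: " << in.is_mkldnn_layout << "\n";
  }, t);
  return 0;
}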
