Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 23093e6

Browse files
rongzha1 authored and pengzhao-intel committed
[mkldnn-v1.0] Add MKL-DNN FC (#16221)
* add mkldnn fc; pass lint; pass mnist training
* add TODO info for future debug
1 parent a559760 commit 23093e6

File tree

5 files changed

+77
-110
lines changed

5 files changed

+77
-110
lines changed

src/imperative/imperative_utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state,
541541
// copying A to B may not happen, and will corrupt A's memory.
542542
InvalidateOutputs(outputs, req);
543543
}
544+
// add for mkldnn OP + no mkldnn OP
545+
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
546+
if (!is_mkldnn.get(attrs.op, false)) {
547+
std::vector<NDArray> inputs_fallback;
548+
CreateDefaultInputs(inputs, &inputs_fallback);
549+
fcompute_ex(state, opctx, inputs_fallback, req, outputs);
550+
} else {
551+
#endif
552+
fcompute_ex(state, opctx, inputs, req, outputs);
553+
#if MXNET_USE_MKLDNN == 100
554+
}
544555
#endif
545-
fcompute_ex(state, opctx, inputs, req, outputs);
546556
if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
547557
&& rctx.get_stream<gpu>() && !rctx.is_bulk) {
548558
rctx.get_stream<gpu>()->Wait();

src/operator/nn/fully_connected.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
9797
valid_bias = inputs[2].storage_type() == kDefaultStorage ||
9898
inputs[2].storage_type() == kRowSparseStorage;
9999
}
100-
#if MXNET_USE_MKLDNN == 1
100+
#if MXNET_USE_MKLDNN == 100
101101
if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
102102
common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
103103
if (SupportMKLDNNFC(inputs[0])) {
@@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
141141
#endif
142142
}
143143

144-
#if MXNET_USE_MKLDNN == 1
144+
#if MXNET_USE_MKLDNN == 100
145145
void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
146146
const OpContext &ctx,
147147
const std::vector<NDArray> &inputs,
@@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
199199
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
200200
dispatch_mode, DispatchMode::kFComputeEx);
201201
}
202-
#if MXNET_USE_MKLDNN == 1
202+
#if MXNET_USE_MKLDNN == 100
203203
if (!MKLDNNEnvSet())
204204
*dispatch_mode = DispatchMode::kFComputeFallback;
205205
#endif
@@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
233233
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
234234
dispatch_mode, DispatchMode::kFCompute);
235235
}
236-
#if MXNET_USE_MKLDNN == 1
236+
#if MXNET_USE_MKLDNN == 100
237237
if (!MKLDNNEnvSet())
238238
*dispatch_mode = DispatchMode::kFComputeFallback;
239239
#endif
@@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
295295
[](const NodeAttrs& attrs) {
296296
return std::vector<std::string>{"output"};
297297
})
298-
#if MXNET_USE_MKLDNN == 1
298+
#if MXNET_USE_MKLDNN == 100
299299
.set_attr<bool>("TIsMKLDNN", true)
300300
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
301301
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
326326
})
327327
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
328328
.set_attr_parser(ParamParser<FullyConnectedParam>)
329-
#if MXNET_USE_MKLDNN == 1
329+
#if MXNET_USE_MKLDNN == 100
330330
.set_attr<bool>("TIsMKLDNN", true)
331331
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
332332
#endif

src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
2828
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
2929

30-
#if MXNET_USE_MKLDNN == 1
30+
#if MXNET_USE_MKLDNN == 100
3131

3232
#include <vector>
3333
#include <string>
@@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
5050
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
5151
.describe("Whether to enable float32 output");
5252
DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
53-
.describe("Whether there's a post elemwise after FullyConnected operator");
53+
.describe("Whether there's a post with_eltwise after FullyConnected operator");
5454
DMLC_DECLARE_FIELD(min_calib_range)
5555
.set_default(dmlc::optional<float>())
5656
.describe("The minimum scalar value in the form of float32 obtained "
@@ -85,21 +85,16 @@ class MKLDNNFullyConnectedForward {
8585
const NDArray &data, const NDArray &weight,
8686
const NDArray *bias,
8787
const mkldnn::memory::desc &out_md)
88-
: fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
89-
90-
void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
91-
const mkldnn::memory *bias, const mkldnn::memory &output);
88+
: fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {
89+
fwd_ = std::make_shared<mkldnn::inner_product_forward>(fwd_pd);
90+
}
9291

9392
const mkldnn::inner_product_forward &GetFwd() const {
9493
return *fwd_;
9594
}
9695

9796
private:
9897
std::shared_ptr<mkldnn::inner_product_forward> fwd_;
99-
std::shared_ptr<mkldnn::memory> data_;
100-
std::shared_ptr<mkldnn::memory> weight_;
101-
std::shared_ptr<mkldnn::memory> bias_;
102-
std::shared_ptr<mkldnn::memory> out_;
10398
};
10499

105100
typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

src/operator/nn/mkldnn/mkldnn_fully_connected.cc

Lines changed: 45 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
* \author Da Zheng, Ciyong Chen
2525
*/
2626

27-
#if MXNET_USE_MKLDNN == 1
27+
#if MXNET_USE_MKLDNN == 100
2828
#include "mkldnn_fully_connected-inl.h"
2929

3030
namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
6767
}
6868

6969
attr.set_output_scales(mask, scales);
70-
attr.set_int_output_round_mode(round_nearest);
7170
}
7271
}
7372

@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
130129
}
131130
}
132131

133-
void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
134-
const mkldnn::memory &weight,
135-
const mkldnn::memory *bias,
136-
const mkldnn::memory &output) {
137-
if (this->data_ == nullptr)
138-
this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
139-
fwd_pd.src_primitive_desc(), data.get_data_handle()));
140-
else
141-
this->data_->set_data_handle(data.get_data_handle());
142-
143-
if (this->weight_ == nullptr)
144-
this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
145-
fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
146-
else
147-
this->weight_->set_data_handle(weight.get_data_handle());
148-
149-
if (this->out_ == nullptr)
150-
this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
151-
fwd_pd.dst_primitive_desc(), output.get_data_handle()));
152-
else
153-
this->out_->set_data_handle(output.get_data_handle());
154-
155-
if (bias != nullptr) {
156-
if (this->bias_ == nullptr)
157-
this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
158-
fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
159-
else
160-
this->bias_->set_data_handle(bias->get_data_handle());
161-
162-
if (this->fwd_ == nullptr)
163-
this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
164-
new mkldnn::inner_product_forward(
165-
fwd_pd, mkldnn::primitive::at(*this->data_),
166-
mkldnn::primitive::at(*this->weight_),
167-
mkldnn::primitive::at(*this->bias_), *this->out_));
168-
} else {
169-
if (this->fwd_ == nullptr) {
170-
this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
171-
new mkldnn::inner_product_forward(
172-
fwd_pd, mkldnn::primitive::at(*this->data_),
173-
mkldnn::primitive::at(*this->weight_), *this->out_));
174-
}
175-
}
176-
}
177-
178132
MKLDNNFullyConnectedForward &GetFCFwd(
179133
const FullyConnectedParam &param, const bool is_train,
180134
const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam &param,
223177
mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
224178
static_cast<int>(oshape[ishape.ndim()-1])};
225179
*out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
226-
mkldnn::memory::format::any);
180+
mkldnn::memory::format_tag::any);
227181
} else {
228182
*in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
229183
mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
230184
static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
231185
*out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
232-
mkldnn::memory::format::any);
186+
mkldnn::memory::format_tag::any);
233187
}
234188
}
235189
}
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
244198
NDArray weight = in_data[fullc::kWeight];
245199
NDArray data = in_data[fullc::kData];
246200

247-
auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
201+
auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
248202
const mkldnn::memory *weight_mem;
249203
if (ctx.is_train) {
250204
if (weight.IsMKLDNNData()) {
251205
weight.Reorder2DefaultAsync();
252206
}
253-
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
207+
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
254208
} else {
255-
if (weight.IsDefaultData()) {
256-
// We also need to modify the layout on the original weight array.
257-
// Don't switch below sequence because naive engine will executes
258-
// pushAsync synchronously.
259-
weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
260-
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
261-
} else {
262-
weight_mem = weight.GetMKLDNNData();
263-
CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
209+
weight_mem = weight.GetMKLDNNData();
210+
if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
211+
// TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
212+
// weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
213+
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
264214
}
265215
}
266216
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
267-
fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
217+
fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
218+
219+
std::unordered_map<int, mkldnn::memory> args = {
220+
{MKLDNN_ARG_SRC, *data_mem},
221+
{MKLDNN_ARG_WEIGHTS, *weight_mem},
222+
{MKLDNN_ARG_DST, *out_mem.second},
223+
};
268224
if (!full_param.default_param.no_bias) {
269225
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
270-
fwd->fwd_pd.bias_primitive_desc());
271-
fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
272-
} else {
273-
fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
226+
fwd->fwd_pd.bias_desc());
227+
args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
274228
}
275-
MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
229+
MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
276230
CommitOutput(out_data[fullc::kOut], out_mem);
277231
MKLDNNStream::Get()->Submit();
278232
}
@@ -339,37 +293,45 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
339293
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
340294
data, weight, out_grad, fwd_pd);
341295
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
342-
ipBwdData_pd.diff_dst_primitive_desc());
343-
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
296+
ipBwdData_pd.diff_dst_desc());
297+
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
344298
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
345-
ipBwdData_pd.diff_src_primitive_desc(),
299+
ipBwdData_pd.diff_src_desc(),
346300
req[fullc::kData]);
347-
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
348-
ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
301+
std::unordered_map<int, mkldnn::memory> args = {
302+
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
303+
{MKLDNN_ARG_WEIGHTS, *weight_mem},
304+
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
305+
};
306+
307+
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
349308
CommitOutput(in_grad[fullc::kData], in_grad_mem);
350309
}
351310
if (req[fullc::kWeight]) {
352311
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
353312
= GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
354313
out_grad, fwd_pd);
355314
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
356-
ipBwdWeights_pd.diff_dst_primitive_desc());
357-
auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
315+
ipBwdWeights_pd.diff_dst_desc());
316+
auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
358317
auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
359-
ipBwdWeights_pd.diff_weights_primitive_desc(),
318+
ipBwdWeights_pd.diff_weights_desc(),
360319
req[fullc::kWeight]);
320+
std::unordered_map<int, mkldnn::memory> args = {
321+
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
322+
{MKLDNN_ARG_SRC, *data_mem},
323+
{MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
324+
};
325+
361326
mkldnn_output_t in_grad_bias;
362-
if (param.no_bias) {
363-
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
364-
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
365-
} else {
327+
if (!param.no_bias) {
366328
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
367-
ipBwdWeights_pd.diff_bias_primitive_desc(),
329+
ipBwdWeights_pd.diff_bias_desc(),
368330
req[fullc::kBias]);
369-
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
370-
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
371-
*in_grad_bias.second));
331+
args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
372332
}
333+
MKLDNNStream::Get()->RegisterPrimArgs(
334+
mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
373335
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
374336
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
375337
}
@@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
378340

379341
} // namespace op
380342
} // namespace mxnet
381-
#endif // MXNET_USE_MKLDNN == 1
343+
#endif // MXNET_USE_MKLDNN == 100

src/operator/nn/mkldnn/mkldnn_ops-inl.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,6 @@ namespace mxnet {
4444
namespace op {
4545

4646
#if MXNET_USE_MKLDNN == 1
47-
/* For fully connected. */
48-
void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
49-
const std::vector<NDArray> &in_data,
50-
const std::vector<OpReqType> &req,
51-
const std::vector<NDArray> &out_data);
52-
void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
53-
const std::vector<NDArray> &inputs,
54-
const std::vector<OpReqType> &req,
55-
const std::vector<NDArray> &outputs);
56-
5747
/* For deconvolution */
5848
void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
5949
const std::vector<NDArray> &in_data,
@@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
10494
#endif
10595

10696
#if MXNET_USE_MKLDNN == 100
97+
/* For fully connected. */
98+
void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
99+
const std::vector<NDArray> &in_data,
100+
const std::vector<OpReqType> &req,
101+
const std::vector<NDArray> &out_data);
102+
void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
103+
const std::vector<NDArray> &inputs,
104+
const std::vector<OpReqType> &req,
105+
const std::vector<NDArray> &outputs);
106+
107107
/* For convolution. */
108108
void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
109109
const std::vector<NDArray> &in_data,

0 commit comments

Comments (0)