This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8ad8b41

rongzha1 authored and TaoLv committed
add mkldnn lrn (#16223)
1 parent 0b8805a commit 8ad8b41

2 files changed: +76 -128 lines changed

src/operator/nn/lrn.cc

Lines changed: 6 additions & 6 deletions
@@ -26,7 +26,7 @@
 
 #include "./lrn-inl.h"
 #include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_lrn-inl.h"
 #include "./mkldnn/mkldnn_base-inl.h"
 #endif
@@ -82,7 +82,7 @@ struct LRNGrad {
   }
 };
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                 const int dev_mask,
                                 DispatchMode* dispatch_mode,
@@ -169,7 +169,7 @@ number of kernels in the layer.
 .set_attr_parser(ParamParser<LRNParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", LRNShape)
 .set_attr<nnvm::FInferType>("FInferType", LRNType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", LRNForwardInferStorageType)
 #endif
 .set_attr<nnvm::FListInputNames>("FListInputNames",
@@ -181,7 +181,7 @@ number of kernels in the layer.
     return std::vector<std::string>{"output", "tmp_norm"};
 })
 .set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
 #endif
@@ -192,11 +192,11 @@ number of kernels in the layer.
 NNVM_REGISTER_OP(_backward_LRN)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<LRNParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", LRNBackwardInferStorageType)
 #endif
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
 // Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
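The change in this file is mechanical: every #if MXNET_USE_MKLDNN == 1 guard around the MKLDNN-specific attributes becomes == 100, so these registrations compile in only when the build targets the MKL-DNN 1.0 API. A minimal, compilable sketch of this gating pattern (the macro name comes from the diff; the program around it is hypothetical):

#include <iostream>

// Pretend the build system set this; in MXNet it comes from the build flags.
#ifndef MXNET_USE_MKLDNN
#define MXNET_USE_MKLDNN 100
#endif

int main() {
#if MXNET_USE_MKLDNN == 100
  std::cout << "MKL-DNN 1.0 attributes (FInferStorageType, FComputeEx) compiled in\n";
#else
  std::cout << "MKL-DNN attributes compiled out; only the native FCompute is registered\n";
#endif
  return 0;
}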

src/operator/nn/mkldnn/mkldnn_lrn-inl.h

Lines changed: 70 additions & 122 deletions
@@ -25,7 +25,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <utility>
 #include <mkldnn.hpp>
 #include "../lrn-inl.h"
@@ -34,27 +34,27 @@
 namespace mxnet {
 namespace op {
 
-inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
+inline mkldnn::algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   // TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
   // Need to confirm with MKLDNN team and fix later
-  return algorithm::lrn_across_channels;
+  return mkldnn::algorithm::lrn_across_channels;
 }
 
 inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
-    const LRNParam &param, const bool is_train, const memory::desc &src_md) {
+    const LRNParam &param, const bool is_train, const mkldnn::memory::desc &src_md) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
-  const algorithm alg = GetMKLDNNLRNAlgo(param);
+  const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
   const float k = param.knorm;
-  auto kind = prop_kind::forward_training;
+  auto kind = mkldnn::prop_kind::forward_training;
   if (is_train) {
-    kind = prop_kind::forward_training;
+    kind = mkldnn::prop_kind::forward_training;
   } else {
-    kind = prop_kind::forward_scoring;
+    kind = mkldnn::prop_kind::forward_scoring;
   }
-  lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
+  mkldnn::lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
   return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
 }
 
@@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
     const mkldnn::memory::desc &diff_md,
     const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
   mkldnn::engine &engine = CpuEngine::Get()->get_engine();
-  const algorithm alg = GetMKLDNNLRNAlgo(param);
+  const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
   const float k = param.knorm;
 
-  lrn_backward::desc lrnBwd_desc(alg, data_in_md,
+  mkldnn::lrn_backward::desc lrnBwd_desc(alg, data_in_md,
       diff_md, nsize, alpha, beta, k);
   return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
       engine, lrnFwd_desc);
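For readers unfamiliar with the 1.0 API, here is a standalone sketch of the desc-to-primitive_desc chain that GetLRNFwdDesc and GetLRNBwdDesc wrap, including the forward primitive_desc serving as the hint for the backward one. It assumes mkldnn 1.x headers; the shape and LRN hyper-parameters are invented:

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);

  // A 1x8x4x4 fp32 tensor in plain NCHW layout.
  memory::desc data_md({1, 8, 4, 4}, memory::data_type::f32,
                       memory::format_tag::nchw);

  // Forward: the op desc is engine-agnostic; the primitive_desc binds it
  // to an engine, mirroring GetLRNFwdDesc.
  lrn_forward::desc fwd_d(prop_kind::forward_training,
                          algorithm::lrn_across_channels, data_md,
                          /*local_size=*/5, /*alpha=*/1e-4f,
                          /*beta=*/0.75f, /*k=*/2.f);
  lrn_forward::primitive_desc fwd_pd(fwd_d, eng);

  // Backward: takes the forward primitive_desc as a hint, exactly as
  // GetLRNBwdDesc passes lrnFwd_desc.
  lrn_backward::desc bwd_d(algorithm::lrn_across_channels, data_md, data_md,
                           5, 1e-4f, 0.75f, 2.f);
  lrn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
  (void)bwd_pd;
  return 0;
}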
@@ -83,33 +83,24 @@ class MKLDNNLRNFwd {
  public:
   MKLDNNLRNFwd(const LRNParam& param,
                bool is_train,
-               const NDArray &in_data):
-               is_train(is_train) {
+               const NDArray &in_data) {
     _Init(param, is_train, in_data);
   }
 
   ~MKLDNNLRNFwd() {}
 
-  void SetNewMem(const NDArray &data,
-                 const NDArray &output,
-                 const OpReqType req);
-
-  void SetNewMem(const NDArray &in_data,
-                 const mkldnn::memory *out_mem);
-
-  void Execute(const NDArray &out_data);
+  void Execute(const OpContext &ctx,
+               const NDArray &in_data,
+               const OpReqType req,
+               const NDArray &out_data);
 
   mkldnn::lrn_forward &GetFwd();
-
   const mkldnn::memory *GetWs();
+  mkldnn::lrn_forward::primitive_desc &GetFwdPd();
 
  private:
   std::shared_ptr<mkldnn::lrn_forward> fwd;
-  std::shared_ptr<mkldnn::memory> in_mem;
-  std::shared_ptr<mkldnn::memory> out_mem;
-  std::shared_ptr<mkldnn::memory> ws_mem;
-  mkldnn_output_t output_mem_t;
-  bool is_train;
+  mkldnn::lrn_forward::primitive_desc fwd_pd;
 
  private:
   void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
@@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam &param,
                          bool is_train,
                          const NDArray &in_data) {
   mkldnn::memory::desc in_data_md =
-      in_data.GetMKLDNNData()->get_primitive_desc().desc();
-  mkldnn::lrn_forward::primitive_desc fwd_pd =
+      in_data.GetMKLDNNData()->get_desc();
+  this->fwd_pd =
       GetLRNFwdDesc(param, is_train, in_data_md);
 
-  this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
-                     ->get_primitive_desc()));
-  this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
-  if (is_train) {
-    // If it's training, we have to create a workspace memory. Otherwise,
-    // MKLDNN will have segmentation fault.
-    ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
-    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
-        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
-                                *this->ws_mem, *this->out_mem));
-  } else {
-    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
-        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
-                                *(this->out_mem)));
-  }
-}
-
-void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
-                             const NDArray &out_data,
-                             const OpReqType req) {
-  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
-  output_mem_t = CreateMKLDNNMem(out_data, this->out_mem->get_primitive_desc(), req);
-  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
-  this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
+  this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
 }
 
-void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
-                             const mkldnn::memory *out_mem) {
-  const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
-  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
-  this->out_mem->set_data_handle(out_mem->get_data_handle());
-}
-
-void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
-  MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
+void MKLDNNLRNFwd::Execute(const OpContext &ctx,
+                           const NDArray &in_data,
+                           const OpReqType req,
+                           const NDArray &out_data) {
+  auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);
+
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
+    { MKLDNN_ARG_DST, *output_mem_t.second },
+  };
+  std::shared_ptr<mkldnn::memory> workspace;
+  if (ctx.is_train) {
+    auto engine = CpuEngine::Get()->get_engine();
+    workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
+    args[MKLDNN_ARG_WORKSPACE] = *(workspace);
+  }
+  MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd), args);
   CommitOutput(out_data, output_mem_t);
   MKLDNNStream::Get()->Submit();
 }
 
 mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
+mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd() { return this->fwd_pd; }
 
-const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
 // End of LRN Class and its functions
 
 static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
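The rewritten _Init and Execute follow the MKL-DNN 1.0 execution model: a primitive is built once from its primitive_desc and holds no data pointers, and buffers are bound per call through an argument map, which is why both SetNewMem overloads could be deleted. A compilable sketch of that reuse pattern (mkldnn 1.x assumed; shapes and data are dummies):

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc md({1, 8, 4, 4}, memory::data_type::f32,
                  memory::format_tag::nchw);
  lrn_forward::desc d(prop_kind::forward_training,
                      algorithm::lrn_across_channels,
                      md, 5, 1e-4f, 0.75f, 2.f);
  lrn_forward::primitive_desc pd(d, eng);

  lrn_forward fwd(pd);                  // built once; holds no data pointers
  memory ws(pd.workspace_desc(), eng);  // training mode also needs a workspace

  std::vector<float> a(128, 1.f), b(128, 2.f), out(128);
  memory dst(pd.dst_desc(), eng, out.data());
  for (auto *in : {&a, &b}) {           // same primitive, different buffers
    memory src(md, eng, in->data());
    fwd.execute(strm, {{MKLDNN_ARG_SRC, src},
                       {MKLDNN_ARG_DST, dst},
                       {MKLDNN_ARG_WORKSPACE, ws}});
    strm.wait();
  }
  return 0;
}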
@@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                      OpHash> lrn_fwds;
 #endif
   auto kind_ =
-      ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
+      ctx.is_train ? mkldnn::prop_kind::forward_training
+                   : mkldnn::prop_kind::forward_scoring;
 
   MKLDNNLRNSignature key(param);
-  key.AddSign(kind_);
+  key.AddSign(static_cast<int>(kind_));
   key.AddSign(in_data);
 
   auto it = lrn_fwds.find(key);
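GetLRNFwd keys a thread-local cache on everything that affects the primitive; the new static_cast<int> is needed because prop_kind is a scoped enum in the 1.x C++ API and no longer converts to an integer implicitly. A generic sketch of the construct-on-miss lookup (MXNet's real MKLDNNLRNSignature and AddToCache are richer; FakeFwd and GetCachedFwd are stand-in names):

#include <sstream>
#include <string>
#include <unordered_map>

struct FakeFwd { /* stands in for MKLDNNLRNFwd */ };

// Construct-on-miss into a thread-local map: each worker thread keeps one
// cached object per distinct combination of inputs folded into the key.
static FakeFwd &GetCachedFwd(int kind, int nsize, float alpha) {
  static thread_local std::unordered_map<std::string, FakeFwd> cache;
  std::ostringstream key;
  key << kind << '/' << nsize << '/' << alpha;  // cf. key.AddSign(...)
  return cache[key.str()];  // default-constructs on first use
}

int main() {
  GetCachedFwd(0, 5, 1e-4f);  // miss: constructs
  GetCachedFwd(0, 5, 1e-4f);  // hit: reuses
  return 0;
}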
@@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
   if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
     in_buffer = in_buffer.Reorder2Default();
   MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
-  fwd.SetNewMem(in_buffer, out_data, req);
-  fwd.Execute(out_data);
+  fwd.Execute(ctx, in_buffer, req, out_data);
 }
 
 // LRN Backward Class
 class MKLDNNLRNBwd {
   std::shared_ptr<mkldnn::lrn_backward> bwd;
-  std::shared_ptr<mkldnn::memory> in_data_mem;
-  std::shared_ptr<mkldnn::memory> diff_dst_mem;
-  std::shared_ptr<mkldnn::memory> ws_mem;
-  std::shared_ptr<mkldnn::memory> diff_src_mem;
 
  public:
   const mkldnn::lrn_forward::primitive_desc fwd_pd;
@@ -222,40 +194,26 @@ class MKLDNNLRNBwd {
   MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
                const mkldnn::memory::desc diff_md)
       : fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
-        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}
-
-  void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
-                 const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
-    if (bwd == nullptr) {
-      this->in_data_mem.reset(
-          new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
-                             in_data.GetMKLDNNData()->get_data_handle()));
-      this->diff_dst_mem.reset(
-          new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
-                             out_grad.GetMKLDNNData()->get_data_handle()));
-      this->ws_mem.reset(
-          new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
-                             ws->get_data_handle()));
-      this->diff_src_mem.reset(
-          new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
-                             diff_src_mem->get_data_handle()));
-      this->bwd.reset(new mkldnn::lrn_backward(
-          this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
-          mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
-          *this->diff_src_mem));
-    } else {
-      this->in_data_mem->set_data_handle(
-          in_data.GetMKLDNNData()->get_data_handle());
-      this->diff_dst_mem->set_data_handle(
-          out_grad.GetMKLDNNData()->get_data_handle());
-      this->ws_mem->set_data_handle(ws->get_data_handle());
-      this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
-    }
-  }
-
-  void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
-    MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
-    CommitOutput(in_grad, diff_src_mem_);
+        bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {
+    bwd = std::make_shared<mkldnn::lrn_backward>(bwd_pd);
+  }
+
+  const mkldnn::lrn_backward &GetBwd() const { return *bwd; }
+
+  void Execute(const NDArray &out_grad,
+               const NDArray &in_data,
+               const NDArray &in_grad,
+               const mkldnn_output_t &diff_src_mem) {
+    auto engine = CpuEngine::Get()->get_engine();
+    auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
+    mkldnn_args_map_t args = {
+      { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() },
+      { MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()},
+      { MKLDNN_ARG_WORKSPACE, *workspace },
+      { MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }
+    };
+    MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
+    CommitOutput(in_grad, diff_src_mem);
     MKLDNNStream::Get()->Submit();
   }
 };  // End of LRN Class
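MKLDNNLRNBwd::Execute hands the backward primitive MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST, a workspace, and MKLDNN_ARG_DIFF_SRC. A compilable end-to-end sketch of the forward/backward pairing, reusing the setup shown earlier, in which the backward pass consumes the workspace that the training-mode forward produced (mkldnn 1.x assumed; all data is dummy):

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc md({1, 8, 4, 4}, memory::data_type::f32,
                  memory::format_tag::nchw);
  lrn_forward::desc fwd_d(prop_kind::forward_training,
                          algorithm::lrn_across_channels,
                          md, 5, 1e-4f, 0.75f, 2.f);
  lrn_forward::primitive_desc fwd_pd(fwd_d, eng);
  lrn_backward::desc bwd_d(algorithm::lrn_across_channels,
                           md, md, 5, 1e-4f, 0.75f, 2.f);
  lrn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);

  std::vector<float> x(128, 1.f), y(128), dy(128, 1.f), dx(128);
  memory src(md, eng, x.data()), dst(md, eng, y.data());
  memory diff_dst(md, eng, dy.data()), diff_src(md, eng, dx.data());
  memory ws(fwd_pd.workspace_desc(), eng);  // produced by fwd, consumed by bwd

  lrn_forward(fwd_pd).execute(strm, {{MKLDNN_ARG_SRC, src},
                                     {MKLDNN_ARG_DST, dst},
                                     {MKLDNN_ARG_WORKSPACE, ws}});
  lrn_backward(bwd_pd).execute(strm, {{MKLDNN_ARG_SRC, src},
                                      {MKLDNN_ARG_DIFF_DST, diff_dst},
                                      {MKLDNN_ARG_WORKSPACE, ws},
                                      {MKLDNN_ARG_DIFF_SRC, diff_src}});
  strm.wait();
  return 0;
}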
@@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
   auto it = lrn_bwds.find(key);
   if (it == lrn_bwds.end()) {
     const mkldnn::memory::desc in_data_md =
-        in_data.GetMKLDNNData()->get_primitive_desc().desc();
+        in_data.GetMKLDNNData()->get_desc();
     const mkldnn::memory::desc diff_md =
-        out_grad.GetMKLDNNData()->get_primitive_desc().desc();
+        out_grad.GetMKLDNNData()->get_desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
     it = AddToCache(&lrn_bwds, key, bwd);
   }
@@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
     in_buffer = in_data.Reorder2Default();
   }
   MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
-  // Repeat FW for getting workspace
-  // TODO(Patric): To keep the function stateless, we can't pass workspace
-  //               from LRN forward to backward. We have to re-compute
-  //               LRN forward to get the workspace.
-  //               Will refine this code later.
-  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
-  std::shared_ptr<const mkldnn::memory> dst_temp(
-      new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
-  fwd.SetNewMem(in_buffer, dst_temp.get());
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
-
   mkldnn_output_t diff_src_mem =
-      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
-  bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
-  bwd.Execute(in_grad, diff_src_mem);
+      CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);
+
+  bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem);
 }
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__
+