25
25
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
26
26
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
27
27
28
- #if MXNET_USE_MKLDNN == 1
28
+ #if MXNET_USE_MKLDNN == 100
29
29
#include <utility>
30
30
#include <mkldnn.hpp>
31
31
#include "../lrn-inl.h"
34
34
namespace mxnet {
35
35
namespace op {
36
36
37
- inline algorithm GetMKLDNNLRNAlgo (const LRNParam &param) {
37
+ inline mkldnn:: algorithm GetMKLDNNLRNAlgo (const LRNParam &param) {
38
38
// TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
39
39
// Need to confirm with MKLDNN team and fix later
40
- return algorithm::lrn_across_channels;
40
+ return mkldnn:: algorithm::lrn_across_channels;
41
41
}
42
42
43
43
inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc (
44
- const LRNParam &param, const bool is_train, const memory::desc &src_md) {
44
+ const LRNParam &param, const bool is_train, const mkldnn:: memory::desc &src_md) {
45
45
mkldnn::engine &engine = CpuEngine::Get ()->get_engine ();
46
- const algorithm alg = GetMKLDNNLRNAlgo (param);
46
+ const mkldnn:: algorithm alg = GetMKLDNNLRNAlgo (param);
47
47
const float alpha = param.alpha ;
48
48
const float beta = param.beta ;
49
49
const int nsize = param.nsize ;
50
50
const float k = param.knorm ;
51
- auto kind = prop_kind::forward_training;
51
+ auto kind = mkldnn:: prop_kind::forward_training;
52
52
if (is_train) {
53
- kind = prop_kind::forward_training;
53
+ kind = mkldnn:: prop_kind::forward_training;
54
54
} else {
55
- kind = prop_kind::forward_scoring;
55
+ kind = mkldnn:: prop_kind::forward_scoring;
56
56
}
57
- lrn_forward::desc fwd_desc (kind, alg, src_md, nsize, alpha, beta, k);
57
+ mkldnn:: lrn_forward::desc fwd_desc (kind, alg, src_md, nsize, alpha, beta, k);
58
58
return mkldnn::lrn_forward::primitive_desc (fwd_desc, engine);
59
59
}
60
60
@@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
63
63
const mkldnn::memory::desc &diff_md,
64
64
const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
65
65
mkldnn::engine &engine = CpuEngine::Get ()->get_engine ();
66
- const algorithm alg = GetMKLDNNLRNAlgo (param);
66
+ const mkldnn:: algorithm alg = GetMKLDNNLRNAlgo (param);
67
67
const float alpha = param.alpha ;
68
68
const float beta = param.beta ;
69
69
const int nsize = param.nsize ;
70
70
const float k = param.knorm ;
71
71
72
- lrn_backward::desc lrnBwd_desc (alg, data_in_md,
72
+ mkldnn:: lrn_backward::desc lrnBwd_desc (alg, data_in_md,
73
73
diff_md, nsize, alpha, beta, k);
74
74
return mkldnn::lrn_backward::primitive_desc (lrnBwd_desc,
75
75
engine, lrnFwd_desc);
@@ -83,33 +83,24 @@ class MKLDNNLRNFwd {
83
83
public:
84
84
MKLDNNLRNFwd (const LRNParam& param,
85
85
bool is_train,
86
- const NDArray &in_data):
87
- is_train (is_train) {
86
+ const NDArray &in_data) {
88
87
_Init (param, is_train, in_data);
89
88
}
90
89
91
90
~MKLDNNLRNFwd () {}
92
91
93
- void SetNewMem (const NDArray &data,
94
- const NDArray &output,
95
- const OpReqType req);
96
-
97
- void SetNewMem (const NDArray &in_data,
98
- const mkldnn::memory *out_mem);
99
-
100
- void Execute (const NDArray &out_data);
92
+ void Execute (const OpContext &ctx,
93
+ const NDArray &in_data,
94
+ const OpReqType req,
95
+ const NDArray &out_data);
101
96
102
97
mkldnn::lrn_forward &GetFwd ();
103
-
104
98
const mkldnn::memory *GetWs ();
99
+ mkldnn::lrn_forward::primitive_desc &GetFwdPd ();
105
100
106
101
private:
107
102
std::shared_ptr<mkldnn::lrn_forward> fwd;
108
- std::shared_ptr<mkldnn::memory> in_mem;
109
- std::shared_ptr<mkldnn::memory> out_mem;
110
- std::shared_ptr<mkldnn::memory> ws_mem;
111
- mkldnn_output_t output_mem_t ;
112
- bool is_train;
103
+ mkldnn::lrn_forward::primitive_desc fwd_pd;
113
104
114
105
private:
115
106
void _Init (const LRNParam &param, bool is_train, const NDArray &in_data);
@@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam &param,
119
110
bool is_train,
120
111
const NDArray &in_data) {
121
112
mkldnn::memory::desc in_data_md =
122
- in_data.GetMKLDNNData ()->get_primitive_desc (). desc ();
123
- mkldnn::lrn_forward::primitive_desc fwd_pd =
113
+ in_data.GetMKLDNNData ()->get_desc ();
114
+ this -> fwd_pd =
124
115
GetLRNFwdDesc (param, is_train, in_data_md);
125
116
126
- this ->in_mem .reset (new mkldnn::memory (in_data.GetMKLDNNData ()
127
- ->get_primitive_desc ()));
128
- this ->out_mem .reset (new mkldnn::memory (fwd_pd.dst_primitive_desc ()));
129
- if (is_train) {
130
- // If it's training, we have to create a workspace memory. Otherwise,
131
- // MKLDNN will have segmentation fault.
132
- ws_mem.reset (new mkldnn::memory (fwd_pd.workspace_primitive_desc ()));
133
- this ->fwd = std::shared_ptr<mkldnn::lrn_forward>(
134
- new mkldnn::lrn_forward (fwd_pd, mkldnn::primitive::at (*this ->in_mem ),
135
- *this ->ws_mem , *this ->out_mem ));
136
- } else {
137
- this ->fwd = std::shared_ptr<mkldnn::lrn_forward>(
138
- new mkldnn::lrn_forward (fwd_pd, mkldnn::primitive::at (*(this ->in_mem )),
139
- *(this ->out_mem )));
140
- }
141
- }
142
-
143
- void MKLDNNLRNFwd::SetNewMem (const NDArray &in_data,
144
- const NDArray &out_data,
145
- const OpReqType req) {
146
- const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData ();
147
- output_mem_t = CreateMKLDNNMem (out_data, this ->out_mem ->get_primitive_desc (), req);
148
- this ->in_mem ->set_data_handle (in_data_mem->get_data_handle ());
149
- this ->out_mem ->set_data_handle (output_mem_t .second ->get_data_handle ());
117
+ this ->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward (this ->fwd_pd ));
150
118
}
151
119
152
- void MKLDNNLRNFwd::SetNewMem (const NDArray &in_data,
153
- const mkldnn::memory *out_mem) {
154
- const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData ();
155
- this ->in_mem ->set_data_handle (in_data_mem->get_data_handle ());
156
- this ->out_mem ->set_data_handle (out_mem->get_data_handle ());
157
- }
158
-
159
- void MKLDNNLRNFwd::Execute (const NDArray &out_data) {
160
- MKLDNNStream::Get ()->RegisterPrim (*(this ->fwd ));
120
+ void MKLDNNLRNFwd::Execute (const OpContext &ctx,
121
+ const NDArray &in_data,
122
+ const OpReqType req,
123
+ const NDArray &out_data) {
124
+ auto output_mem_t = CreateMKLDNNMem (out_data, (this ->fwd_pd ).dst_desc (), req);
125
+
126
+ mkldnn_args_map_t args = {
127
+ { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData ()},
128
+ { MKLDNN_ARG_DST, *output_mem_t .second },
129
+ };
130
+ std::shared_ptr<mkldnn::memory> workspace;
131
+ if (ctx.is_train ) {
132
+ auto engine = CpuEngine::Get ()->get_engine ();
133
+ workspace = std::make_shared<mkldnn::memory>((this ->fwd_pd ).workspace_desc (), engine);
134
+ args[MKLDNN_ARG_WORKSPACE] = *(workspace);
135
+ }
136
+ MKLDNNStream::Get ()->RegisterPrimArgs (*(this ->fwd ), args);
161
137
CommitOutput (out_data, output_mem_t );
162
138
MKLDNNStream::Get ()->Submit ();
163
139
}
164
140
165
141
mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd () { return *this ->fwd ; }
142
+ mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd () { return this ->fwd_pd ; }
166
143
167
- const mkldnn::memory *MKLDNNLRNFwd::GetWs () { return this ->ws_mem .get (); }
168
144
// End of LRN Class and its functions
169
145
170
146
static MKLDNNLRNFwd &GetLRNFwd (const LRNParam& param,
@@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
180
156
OpHash> lrn_fwds;
181
157
#endif
182
158
auto kind_ =
183
- ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
159
+ ctx.is_train ? mkldnn::prop_kind::forward_training
160
+ : mkldnn::prop_kind::forward_scoring;
184
161
185
162
MKLDNNLRNSignature key (param);
186
- key.AddSign (kind_);
163
+ key.AddSign (static_cast < int >( kind_) );
187
164
key.AddSign (in_data);
188
165
189
166
auto it = lrn_fwds.find (key);
@@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
201
178
if (in_buffer.IsView () && in_buffer.IsMKLDNNData ())
202
179
in_buffer = in_buffer.Reorder2Default ();
203
180
MKLDNNLRNFwd fwd = GetLRNFwd (param, ctx, in_buffer);
204
- fwd.SetNewMem (in_buffer, out_data, req);
205
- fwd.Execute (out_data);
181
+ fwd.Execute (ctx, in_buffer, req, out_data);
206
182
}
207
183
208
184
// LRN Backward Class
209
185
class MKLDNNLRNBwd {
210
186
std::shared_ptr<mkldnn::lrn_backward> bwd;
211
- std::shared_ptr<mkldnn::memory> in_data_mem;
212
- std::shared_ptr<mkldnn::memory> diff_dst_mem;
213
- std::shared_ptr<mkldnn::memory> ws_mem;
214
- std::shared_ptr<mkldnn::memory> diff_src_mem;
215
187
216
188
public:
217
189
const mkldnn::lrn_forward::primitive_desc fwd_pd;
@@ -222,40 +194,26 @@ class MKLDNNLRNBwd {
222
194
MKLDNNLRNBwd (const LRNParam &param, const mkldnn::memory::desc in_data_md,
223
195
const mkldnn::memory::desc diff_md)
224
196
: fwd_pd(GetLRNFwdDesc(param, true , in_data_md)),
225
- bwd_pd (GetLRNBwdDesc(param, in_data_md, diff_md, this ->fwd_pd)) {}
226
-
227
- void SetNewMem (const NDArray &in_data, const NDArray &out_grad,
228
- const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
229
- if (bwd == nullptr ) {
230
- this ->in_data_mem .reset (
231
- new mkldnn::memory (this ->fwd_pd .src_primitive_desc (),
232
- in_data.GetMKLDNNData ()->get_data_handle ()));
233
- this ->diff_dst_mem .reset (
234
- new mkldnn::memory (this ->fwd_pd .dst_primitive_desc (),
235
- out_grad.GetMKLDNNData ()->get_data_handle ()));
236
- this ->ws_mem .reset (
237
- new mkldnn::memory (this ->fwd_pd .workspace_primitive_desc (),
238
- ws->get_data_handle ()));
239
- this ->diff_src_mem .reset (
240
- new mkldnn::memory (this ->bwd_pd .diff_src_primitive_desc (),
241
- diff_src_mem->get_data_handle ()));
242
- this ->bwd .reset (new mkldnn::lrn_backward (
243
- this ->bwd_pd , mkldnn::primitive::at (*this ->in_data_mem ),
244
- mkldnn::primitive::at (*this ->diff_dst_mem ), *this ->ws_mem ,
245
- *this ->diff_src_mem ));
246
- } else {
247
- this ->in_data_mem ->set_data_handle (
248
- in_data.GetMKLDNNData ()->get_data_handle ());
249
- this ->diff_dst_mem ->set_data_handle (
250
- out_grad.GetMKLDNNData ()->get_data_handle ());
251
- this ->ws_mem ->set_data_handle (ws->get_data_handle ());
252
- this ->diff_src_mem ->set_data_handle (diff_src_mem->get_data_handle ());
253
- }
254
- }
255
-
256
- void Execute (const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
257
- MKLDNNStream::Get ()->RegisterPrim (*(this ->bwd ));
258
- CommitOutput (in_grad, diff_src_mem_);
197
+ bwd_pd (GetLRNBwdDesc(param, in_data_md, diff_md, this ->fwd_pd)) {
198
+ bwd = std::make_shared<mkldnn::lrn_backward>(bwd_pd);
199
+ }
200
+
201
+ const mkldnn::lrn_backward &GetBwd () const { return *bwd; }
202
+
203
+ void Execute (const NDArray &out_grad,
204
+ const NDArray &in_data,
205
+ const NDArray &in_grad,
206
+ const mkldnn_output_t &diff_src_mem) {
207
+ auto engine = CpuEngine::Get ()->get_engine ();
208
+ auto workspace = std::make_shared<mkldnn::memory>((this ->fwd_pd ).workspace_desc (), engine);
209
+ mkldnn_args_map_t args = {
210
+ { MKLDNN_ARG_SRC, *in_data.GetMKLDNNData () },
211
+ { MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData ()},
212
+ { MKLDNN_ARG_WORKSPACE, *workspace },
213
+ { MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }
214
+ };
215
+ MKLDNNStream::Get ()->RegisterPrimArgs (*(this ->bwd ), args);
216
+ CommitOutput (in_grad, diff_src_mem);
259
217
MKLDNNStream::Get ()->Submit ();
260
218
}
261
219
}; // End of LRN Class
@@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
277
235
auto it = lrn_bwds.find (key);
278
236
if (it == lrn_bwds.end ()) {
279
237
const mkldnn::memory::desc in_data_md =
280
- in_data.GetMKLDNNData ()->get_primitive_desc (). desc ();
238
+ in_data.GetMKLDNNData ()->get_desc ();
281
239
const mkldnn::memory::desc diff_md =
282
- out_grad.GetMKLDNNData ()->get_primitive_desc (). desc ();
240
+ out_grad.GetMKLDNNData ()->get_desc ();
283
241
MKLDNNLRNBwd bwd (param, in_data_md, diff_md);
284
242
it = AddToCache (&lrn_bwds, key, bwd);
285
243
}
@@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
300
258
in_buffer = in_data.Reorder2Default ();
301
259
}
302
260
MKLDNNLRNBwd &bwd = GetLRNBwd (param, in_buffer, in_grad, out_grad);
303
- // Repeat FW for getting workspace
304
- // TODO(Patric): To keep the function stateless, we can't pass workspace
305
- // from LRN forward to backward. We have to re-compute
306
- // LRN forward to get the workspace.
307
- // Will refine this code later.
308
- MKLDNNLRNFwd fwd = GetLRNFwd (param, ctx, in_buffer);
309
- std::shared_ptr<const mkldnn::memory> dst_temp (
310
- new mkldnn::memory (bwd.fwd_pd .dst_primitive_desc ()));
311
- fwd.SetNewMem (in_buffer, dst_temp.get ());
312
- MKLDNNStream::Get ()->RegisterPrim (fwd.GetFwd ());
313
-
314
261
mkldnn_output_t diff_src_mem =
315
- CreateMKLDNNMem (in_grad, bwd.bwd_pd .diff_src_primitive_desc (), req);
316
- bwd. SetNewMem (in_buffer, out_grad, fwd. GetWs (), diff_src_mem. second );
317
- bwd.Execute (in_grad, diff_src_mem);
262
+ CreateMKLDNNMem (in_grad, bwd.bwd_pd .diff_src_desc (), req);
263
+
264
+ bwd.Execute (out_grad, in_buffer, in_grad, diff_src_mem);
318
265
}
319
266
} // namespace op
320
267
} // namespace mxnet
321
- #endif // MXNET_USE_MKLDNN == 1
268
+ #endif // MXNET_USE_MKLDNN == 100
322
269
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
270
+
0 commit comments