  * \author Da Zheng, Ciyong Chen
  */

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "mkldnn_fully_connected-inl.h"

 namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
       }

       attr.set_output_scales(mask, scales);
-      attr.set_int_output_round_mode(round_nearest);
     }
   }

@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
   }
 }

-void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
-                                            const mkldnn::memory &weight,
-                                            const mkldnn::memory *bias,
-                                            const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.src_primitive_desc(), data.get_data_handle()));
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  if (this->weight_ == nullptr)
-    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-  else
-    this->weight_->set_data_handle(weight.get_data_handle());
-
-  if (this->out_ == nullptr)
-    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (bias != nullptr) {
-    if (this->bias_ == nullptr)
-      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-    else
-      this->bias_->set_data_handle(bias->get_data_handle());
-
-    if (this->fwd_ == nullptr)
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_),
-              mkldnn::primitive::at(*this->bias_), *this->out_));
-  } else {
-    if (this->fwd_ == nullptr) {
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_), *this->out_));
-    }
-  }
-}
-
 MKLDNNFullyConnectedForward &GetFCFwd(
     const FullyConnectedParam &param, const bool is_train,
     const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam &param,
       mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
         static_cast<int>(oshape[ishape.ndim()-1])};
       *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-        mkldnn::memory::format::any);
+        mkldnn::memory::format_tag::any);
     } else {
       *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
       mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
         static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
       *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-        mkldnn::memory::format::any);
+        mkldnn::memory::format_tag::any);
     }
   }
 }
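Note on the two `format_tag::any` hunks above: in MKLDNN 1.0 a descriptor built with `format_tag::any` defers the layout choice to the primitive descriptor, which then reports the layouts it picked. The following standalone sketch (not MXNet code; the shapes and the helper name `make_fc_pd` are assumptions made for illustration) shows that pattern with plain MKLDNN calls.

```cpp
// Illustrative only: how MKLDNN 1.0 resolves memory::format_tag::any.
#include "mkldnn.hpp"

mkldnn::inner_product_forward::primitive_desc make_fc_pd(const mkldnn::engine &eng) {
  using tag = mkldnn::memory::format_tag;
  using dt = mkldnn::memory::data_type;
  // Plain layout for the activations, "any" for weights and output:
  // the primitive descriptor, not the caller, picks the layout.
  mkldnn::memory::desc src_md({16, 32}, dt::f32, tag::nc);
  mkldnn::memory::desc weights_md({64, 32}, dt::f32, tag::any);
  mkldnn::memory::desc dst_md({16, 64}, dt::f32, tag::any);
  mkldnn::inner_product_forward::desc desc(
      mkldnn::prop_kind::forward_inference, src_md, weights_md, dst_md);
  return mkldnn::inner_product_forward::primitive_desc(desc, eng);
  // The layouts actually chosen are queried afterwards via weights_desc(),
  // dst_desc(), etc., which is what the hunks below compare tensors against
  // before deciding whether a reorder is needed.
}
```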
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   NDArray weight = in_data[fullc::kWeight];
   NDArray data = in_data[fullc::kData];

-  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     if (weight.IsMKLDNNData()) {
       weight.Reorder2DefaultAsync();
     }
-    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
   } else {
-    if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array.
-      // Don't switch below sequence because naive engine will executes
-      // pushAsync synchronously.
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
-    } else {
-      weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+    weight_mem = weight.GetMKLDNNData();
+    if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
+      // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
+      // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
     }
   }
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+      fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
+
+  std::unordered_map<int, mkldnn::memory> args = {
+    {MKLDNN_ARG_SRC, *data_mem},
+    {MKLDNN_ARG_WEIGHTS, *weight_mem},
+    {MKLDNN_ARG_DST, *out_mem.second},
+  };
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
-        fwd->fwd_pd.bias_primitive_desc());
-    fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  } else {
-    fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+        fwd->fwd_pd.bias_desc());
+    args.insert({MKLDNN_ARG_BIAS, *bias_mem});
   }
-  MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
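Background for the hunk above: in MKLDNN 1.0 a primitive is constructed from its primitive_desc alone, and the memories are bound per execution through an index-to-memory map; `RegisterPrimArgs` records the primitive together with such a map, which the stream later hands to `primitive::execute()` on `Submit()`. The sketch below shows that execution model with plain MKLDNN 1.0 calls only (the engine/stream setup, shapes, and buffer contents are illustrative assumptions, not MXNet code).

```cpp
// Standalone sketch of the MKLDNN 1.0 execution model (illustrative, not MXNet).
#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // A 2 x 8 input mapped to 4 output channels, using plain nc/oi layouts.
  const int batch = 2, in_c = 8, out_c = 4;
  using dt = mkldnn::memory::data_type;
  using tag = mkldnn::memory::format_tag;
  mkldnn::memory::desc src_md({batch, in_c}, dt::f32, tag::nc);
  mkldnn::memory::desc weights_md({out_c, in_c}, dt::f32, tag::oi);
  mkldnn::memory::desc dst_md({batch, out_c}, dt::f32, tag::nc);

  // Operation descriptor + primitive_desc; unlike MKLDNN 0.x, the primitive
  // no longer owns its input/output memories.
  mkldnn::inner_product_forward::desc fc_desc(
      mkldnn::prop_kind::forward_inference, src_md, weights_md, dst_md);
  mkldnn::inner_product_forward::primitive_desc fc_pd(fc_desc, eng);
  mkldnn::inner_product_forward fc(fc_pd);

  std::vector<float> src(batch * in_c, 1.f), weights(out_c * in_c, 1.f), dst(batch * out_c);
  mkldnn::memory src_mem(fc_pd.src_desc(), eng, src.data());
  mkldnn::memory weights_mem(fc_pd.weights_desc(), eng, weights.data());
  mkldnn::memory dst_mem(fc_pd.dst_desc(), eng, dst.data());

  // Arguments are passed at execution time as an index -> memory map; this is
  // the map that RegisterPrimArgs stores per registered primitive.
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src_mem},
      {MKLDNN_ARG_WEIGHTS, weights_mem},
      {MKLDNN_ARG_DST, dst_mem},
  };
  fc.execute(strm, args);
  strm.wait();
  return 0;
}
```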
@@ -339,37 +293,45 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
         data, weight, out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_primitive_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
     auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-        ipBwdData_pd.diff_src_primitive_desc(),
+        ipBwdData_pd.diff_src_desc(),
         req[fullc::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
-        ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
     CommitOutput(in_grad[fullc::kData], in_grad_mem);
   }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
       = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
                         out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
+        ipBwdWeights_pd.diff_dst_desc());
+    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
-        ipBwdWeights_pd.diff_weights_primitive_desc(),
+        ipBwdWeights_pd.diff_weights_desc(),
         req[fullc::kWeight]);
+    std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
+    };
+
     mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
-    } else {
+    if (!param.no_bias) {
       in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
-          ipBwdWeights_pd.diff_bias_primitive_desc(),
+          ipBwdWeights_pd.diff_bias_desc(),
           req[fullc::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-          *in_grad_bias.second));
+      args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(
+        mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
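The backward primitives registered in the hunk above follow the same args-map pattern. Below is a minimal sketch of the backward-data case only, assuming a helper signature (`run_fc_backward_data`) invented for illustration; it is not MXNet code, but the MKLDNN 1.0 calls it uses (backward primitive_desc created with the forward primitive_desc as hint, execution with `MKLDNN_ARG_DIFF_*` arguments) are the ones the hunk relies on.

```cpp
// Illustrative backward-data sketch for MKLDNN 1.0 (not MXNet source).
#include <unordered_map>
#include "mkldnn.hpp"

void run_fc_backward_data(const mkldnn::engine &eng, mkldnn::stream &strm,
                          const mkldnn::inner_product_forward::primitive_desc &fwd_pd,
                          mkldnn::memory &diff_dst_mem, mkldnn::memory &weights_mem,
                          mkldnn::memory &diff_src_mem) {
  // The backward descriptor reuses the forward shapes; the forward
  // primitive_desc is passed as a hint when building the backward one.
  mkldnn::inner_product_backward_data::desc bwd_desc(
      fwd_pd.src_desc(), fwd_pd.weights_desc(), fwd_pd.dst_desc());
  mkldnn::inner_product_backward_data::primitive_desc bwd_pd(bwd_desc, eng, fwd_pd);

  // As in the forward pass, tensors are bound through an index -> memory map.
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_DIFF_DST, diff_dst_mem},
      {MKLDNN_ARG_WEIGHTS, weights_mem},
      {MKLDNN_ARG_DIFF_SRC, diff_src_mem},
  };
  mkldnn::inner_product_backward_data(bwd_pd).execute(strm, args);
}
```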
@@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100