@@ -290,24 +290,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
290
290
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias ], GetMemDesc (out_grad));
291
291
292
292
CHECK_NE (req[fullc::kWeight ], kWriteInplace ) << " cannot write weight inplace" ;
293
- if (req[fullc::kData ]) {
294
- mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData (
295
- data, weight, out_grad, fwd_pd);
296
- auto out_grad_mem = out_grad.GetMKLDNNDataReorder (
297
- ipBwdData_pd.diff_dst_desc ());
298
- auto weight_mem = weight.GetMKLDNNDataReorder (ipBwdData_pd.weights_desc ());
299
- auto in_grad_mem = CreateMKLDNNMem (in_grad[fullc::kData ],
300
- ipBwdData_pd.diff_src_desc (),
301
- req[fullc::kData ]);
302
- mkldnn_args_map_t args = {
303
- {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
304
- {MKLDNN_ARG_WEIGHTS, *weight_mem},
305
- {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second }
306
- };
307
-
308
- MKLDNNStream::Get ()->RegisterPrimArgs (mkldnn::inner_product_backward_data (ipBwdData_pd), args);
309
- CommitOutput (in_grad[fullc::kData ], in_grad_mem);
310
- }
311
293
if (req[fullc::kWeight ]) {
312
294
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
313
295
= GetFCBwdWeights (data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias ],
@@ -336,6 +318,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
336
318
CommitOutput (in_grad[fullc::kWeight ], in_grad_weight);
337
319
CommitOutput (in_grad[fullc::kBias ], in_grad_bias);
338
320
}
321
+ if (req[fullc::kData ]) {
322
+ mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData (
323
+ data, weight, out_grad, fwd_pd);
324
+ auto out_grad_mem = out_grad.GetMKLDNNDataReorder (
325
+ ipBwdData_pd.diff_dst_desc ());
326
+ auto weight_mem = weight.GetMKLDNNDataReorder (ipBwdData_pd.weights_desc ());
327
+ auto in_grad_mem = CreateMKLDNNMem (in_grad[fullc::kData ],
328
+ ipBwdData_pd.diff_src_desc (),
329
+ req[fullc::kData ]);
330
+ mkldnn_args_map_t args = {
331
+ {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
332
+ {MKLDNN_ARG_WEIGHTS, *weight_mem},
333
+ {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second }
334
+ };
335
+
336
+ MKLDNNStream::Get ()->RegisterPrimArgs (mkldnn::inner_product_backward_data (ipBwdData_pd), args);
337
+ CommitOutput (in_grad[fullc::kData ], in_grad_mem);
338
+ }
339
339
MKLDNNStream::Get ()->Submit ();
340
340
}
341
341
0 commit comments