@@ -532,16 +532,15 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
532
532
return GetMKLDNNExact (mem, new_desc);
533
533
}
534
534
535
- mkldnn::memory::desc desc1 = mem->get_desc ();
536
- mkldnn::memory::desc desc2 = new_desc;
535
+ mkldnn::memory::desc old_desc = mem->get_desc ();
537
536
// Now we need to determine if we should reorder the memory.
538
537
// If both use the default formats, we think we don't need to reorder.
539
- if ((!mxnet::IsMKLDNN (desc1 )) && (!mxnet::IsMKLDNN (desc2 ))) {
538
+ if ((!mxnet::IsMKLDNN (old_desc )) && (!mxnet::IsMKLDNN (new_desc ))) {
540
539
mkldnn_mem_ptr ret (new mkldnn::memory (new_desc,
541
540
CpuEngine::Get ()->get_engine (), mem->get_data_handle ()));
542
541
stream->RegisterMem (ret);
543
542
return ret.get ();
544
- } else if (same_shape (desc1, desc2 )) {
543
+ } else if (same_shape (old_desc, new_desc )) {
545
544
// If they have the same shape, we can reorder data directly.
546
545
mkldnn::memory *ret = TmpMemMgr::Get ()->Alloc (new_desc);
547
546
std::unordered_map<int , mkldnn::memory> args ({{MKLDNN_ARG_FROM, *mem }, {MKLDNN_ARG_TO, *ret}});
@@ -551,9 +550,9 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
551
550
// If they have different shapes, we need to reshape the array first.
552
551
// Since this method will only be used inside an operator, we can call
553
552
// MKLDNNDataReshape to reshape an array.
554
- mxnet::TShape required_shape (desc2 .data .ndims , -1 );
555
- for (int i = 0 ; i < desc2 .data .ndims ; i++)
556
- required_shape[i] = desc2 .data .dims [i];
553
+ mxnet::TShape required_shape (new_desc .data .ndims , -1 );
554
+ for (int i = 0 ; i < new_desc .data .ndims ; i++)
555
+ required_shape[i] = new_desc .data .dims [i];
557
556
NDArray reshaped = MKLDNNDataReshape (required_shape);
558
557
const mkldnn::memory *ret = reshaped.GetMKLDNNData ();
559
558
if (ret->get_desc () == new_desc) {
@@ -684,7 +683,9 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
684
683
685
684
mkldnn::memory *NDArray::CreateMKLDNNData (const mkldnn::memory::desc &desc) {
686
685
if (desc.get_size () != shape ().Size () * GetTypeSize (dtype_)) {
687
- LOG (FATAL) << " The size of NDArray doesn't match the requested MKLDNN memory desc " ;
686
+ LOG (FATAL) << " The size of NDArray doesn't match the requested MKLDNN memory desc. "
687
+ << " MKLDNN memory requests for " << desc.get_size () << " bytes, but got "
688
+ << shape ().Size () * GetTypeSize (dtype_) << " bytes from NDArray" ;
688
689
return nullptr ;
689
690
}
690
691
bool isDefaultFormat = IsDefaultFormat (desc);
0 commit comments