28
28
#include < nnvm/op_attr_types.h>
29
29
#include " ../elemwise_op_common.h"
30
30
#include " ../operator_common.h"
31
- #if MXNET_USE_MKLDNN == 1
31
+ #if MXNET_USE_MKLDNN == 100
32
32
#include " ./mkldnn/mkldnn_batch_norm-inl.h"
33
33
#endif
34
34
@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
379
379
return true ;
380
380
}
381
381
382
- #if MXNET_USE_MKLDNN == 1
382
+ #if MXNET_USE_MKLDNN == 100
383
383
static inline bool SupportMKLDNNBN (const NDArray &input, const BatchNormParam ¶m) {
384
384
mxnet::TShape shape = input.shape ();
385
385
return SupportMKLDNN (input) && shape.ndim () == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
454
454
const BatchNormParam ¶m = nnvm::get<BatchNormParam>(attrs.parsed );
455
455
456
456
bool dispatched = false ;
457
- #if MXNET_USE_MKLDNN == 1
457
+ #if MXNET_USE_MKLDNN == 100
458
458
if (!dispatched) {
459
459
dispatched = MKLDNNStorageType (attrs, dev_mask, true , dispatch_mode,
460
460
in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
592
592
.set_attr<nnvm::FInferType>(" FInferType" , BatchNormType)
593
593
.set_attr<FInferStorageType>(" FInferStorageType" , BatchNormStorageType)
594
594
.set_attr<FCompute>(" FCompute<cpu>" , BatchNormCompute<cpu>)
595
- #if MXNET_USE_MKLDNN == 1
595
+ #if MXNET_USE_MKLDNN == 100
596
596
.set_attr<FComputeEx>(" FComputeEx<cpu>" , BatchNormComputeExCPU)
597
597
#endif
598
598
.set_attr<nnvm::FGradient>(" FGradient" , BatchNormGrad)
599
- #if MXNET_USE_MKLDNN == 1
599
+ #if MXNET_USE_MKLDNN == 100
600
600
.set_attr<bool >(" TIsMKLDNN" , true )
601
601
.set_attr<FResourceRequest>(" FResourceRequest" , [](const NodeAttrs& n) {
602
602
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace };
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
623
623
.set_num_outputs(3 )
624
624
.set_attr<nnvm::TIsBackward>(" TIsBackward" , true )
625
625
.set_attr<FInferStorageType>(" FInferStorageType" , BatchNormStorageType)
626
- #if MXNET_USE_MKLDNN == 1
626
+ #if MXNET_USE_MKLDNN == 100
627
627
.set_attr<FResourceRequest>(" FResourceRequest" , [](const NodeAttrs& n) {
628
628
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace };
629
629
})
630
630
#endif
631
631
.set_attr_parser(ParamParser<BatchNormParam>)
632
- #if MXNET_USE_MKLDNN == 1
632
+ #if MXNET_USE_MKLDNN == 100
633
633
.set_attr<bool >(" TIsMKLDNN" , true )
634
634
.set_attr<FComputeEx>(" FComputeEx<cpu>" , BatchNormGradComputeExCPU)
635
635
#endif
0 commit comments