Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0b8805a

Browse files
rongzha1TaoLv
authored andcommitted
[mkldnn-v1.0] Add MKL-DNN BN (#16199)
* add mkldnn bn * add static_cast to transform data type * change mkldnn_args_map_t * retrigger CI
1 parent f930baa commit 0b8805a

File tree

2 files changed

+81
-165
lines changed

2 files changed

+81
-165
lines changed

src/operator/nn/batch_norm.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include <nnvm/op_attr_types.h>
2929
#include "../elemwise_op_common.h"
3030
#include "../operator_common.h"
31-
#if MXNET_USE_MKLDNN == 1
31+
#if MXNET_USE_MKLDNN == 100
3232
#include "./mkldnn/mkldnn_batch_norm-inl.h"
3333
#endif
3434

@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
379379
return true;
380380
}
381381

382-
#if MXNET_USE_MKLDNN == 1
382+
#if MXNET_USE_MKLDNN == 100
383383
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
384384
mxnet::TShape shape = input.shape();
385385
return SupportMKLDNN(input) && shape.ndim() == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
454454
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
455455

456456
bool dispatched = false;
457-
#if MXNET_USE_MKLDNN == 1
457+
#if MXNET_USE_MKLDNN == 100
458458
if (!dispatched) {
459459
dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
460460
in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
592592
.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
593593
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
594594
.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
595-
#if MXNET_USE_MKLDNN == 1
595+
#if MXNET_USE_MKLDNN == 100
596596
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
597597
#endif
598598
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
599-
#if MXNET_USE_MKLDNN == 1
599+
#if MXNET_USE_MKLDNN == 100
600600
.set_attr<bool>("TIsMKLDNN", true)
601601
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
602602
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
623623
.set_num_outputs(3)
624624
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
625625
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
626-
#if MXNET_USE_MKLDNN == 1
626+
#if MXNET_USE_MKLDNN == 100
627627
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
628628
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
629629
})
630630
#endif
631631
.set_attr_parser(ParamParser<BatchNormParam>)
632-
#if MXNET_USE_MKLDNN == 1
632+
#if MXNET_USE_MKLDNN == 100
633633
.set_attr<bool>("TIsMKLDNN", true)
634634
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
635635
#endif

0 commit comments

Comments
 (0)