Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a37a76c

Browse files
ZhennanQin and pengzhao-intel
authored and committed
Float64 fallback for mkldnn subgraph and rnn op (#15853)
* Fix mkldnn subgraph with float64
* Fix CI (Change-Id: I0bb4e8d7a0a534aa661601887cc633cb9c4fcadf)
* Fix test (Change-Id: I96c529abe7adb6def90a22f03b3432263ef12fda)
* Update dmlc-core (Change-Id: I472fb7bbffc16ed8c36494ab49838b08c59b2f12)
* Pin to official dmlc (Change-Id: I5a27dc83b892bf8fcb34bb089449d1d3b6e9beed)
* Fix GPU CI (Change-Id: I285947e01bdb0651c2c7830ed4eb76931a09b754)
* Fix GPU CI (Change-Id: I6f23b51d6bda44f6ae18766ebe390118740bb9c7)
1 parent 5f9a680 commit a37a76c

14 files changed

+266
-62
lines changed

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,18 @@
4747
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
4848

4949
#if MXNET_USE_MKLDNN == 1
50+
#include <algorithm>
5051
#include <iterator>
52+
#include <memory>
5153
#include <string>
5254
#include <unordered_map>
53-
#include <vector>
5455
#include <utility>
55-
#include <algorithm>
56-
#include <memory>
56+
#include <vector>
5757
#include "mkldnn.hpp"
58+
#include "mxnet/graph_attr_types.h"
5859
#include "mxnet/ndarray.h"
59-
#include "mxnet/resource.h"
6060
#include "mxnet/op_attr_types.h"
61+
#include "mxnet/resource.h"
6162
using namespace mkldnn;
6263
namespace mxnet {
6364

@@ -132,6 +133,11 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
132133
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
133134
}
134135

136+
static inline bool SupportMKLDNNRNN(const NDArray &input) {
137+
int ndim = input.shape().ndim();
138+
return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
139+
}
140+
135141
static inline bool SupportMKLDNNQuantize(int dtype) {
136142
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
137143
dtype == mshadow::kUint8;
@@ -569,7 +575,8 @@ class MKLDNNMemory {
569575
}
570576
};
571577

572-
void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
578+
template <typename Compute, typename AttrState>
579+
void FallBackCompute(Compute fn, const AttrState &attrs,
573580
const OpContext &ctx,
574581
const std::vector<NDArray> &inputs,
575582
const std::vector<OpReqType> &req,

src/operator/nn/mkldnn/mkldnn_base.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p
420420
return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
421421
}
422422

423-
void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
423+
template <typename Compute, typename AttrState>
424+
void FallBackCompute(Compute fn, const AttrState &attrs_states,
424425
const OpContext &ctx,
425426
const std::vector<NDArray> &inputs,
426427
const std::vector<OpReqType> &req,
@@ -461,7 +462,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
461462
out_blobs[i] = output.data();
462463
}
463464

464-
fn(attrs, ctx, in_blobs, req, out_blobs);
465+
fn(attrs_states, ctx, in_blobs, req, out_blobs);
465466
for (size_t i = 0; i < out_blobs.size(); i++) {
466467
if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
467468
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
@@ -518,6 +519,24 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2,
518519
return success.load();
519520
}
520521

522+
template void FallBackCompute(void (*)(nnvm::NodeAttrs const &, OpContext const &,
523+
std::vector<TBlob, std::allocator<TBlob> > const &,
524+
std::vector<OpReqType, std::allocator<OpReqType> > const &,
525+
std::vector<TBlob, std::allocator<TBlob> > const &),
526+
nnvm::NodeAttrs const &, OpContext const &,
527+
std::vector<NDArray, std::allocator<NDArray> > const &,
528+
std::vector<OpReqType, std::allocator<OpReqType> > const &,
529+
std::vector<NDArray, std::allocator<NDArray> > const &);
530+
531+
template void FallBackCompute(void (*)(OpStatePtr const &, OpContext const &,
532+
std::vector<TBlob, std::allocator<TBlob> > const &,
533+
std::vector<OpReqType, std::allocator<OpReqType> > const &,
534+
std::vector<TBlob, std::allocator<TBlob> > const &),
535+
OpStatePtr const &, OpContext const &,
536+
std::vector<NDArray, std::allocator<NDArray> > const &,
537+
std::vector<OpReqType, std::allocator<OpReqType> > const &,
538+
std::vector<NDArray, std::allocator<NDArray> > const &);
539+
521540
void OpCheck::Init(const std::vector<mxnet::NDArray> &inputs_,
522541
const std::vector<mxnet::NDArray> &outputs_) {
523542
auto ctx = inputs_[0].ctx();

src/operator/rnn.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,20 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
633633
});
634634
});
635635
}
636+
637+
static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx,
638+
const std::vector<NDArray>& inputs,
639+
const std::vector<OpReqType>& req,
640+
const std::vector<NDArray>& outputs) {
641+
if (SupportMKLDNNRNN(inputs[0])) {
642+
RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs);
643+
return;
644+
}
645+
int use_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
646+
dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0);
647+
FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
648+
dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", use_mkldnn_rnn);
649+
}
636650
#endif
637651

638652
NNVM_REGISTER_OP(RNN)
@@ -719,7 +733,7 @@ The definition of GRU here is slightly different from paper but compatible with
719733
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
720734
#if MXNET_USE_MKLDNN == 1
721735
.set_attr<bool>("TIsMKLDNN", true)
722-
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
736+
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
723737
#endif
724738
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
725739
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)

src/operator/subgraph/build_subgraph.cc

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
*/
2525
#include <nnvm/graph.h>
2626
#include <nnvm/pass.h>
27-
#include <mxnet/op_attr_types.h>
2827
#include <unordered_set>
2928
#include <stack>
3029
#include <queue>
@@ -105,6 +104,28 @@ void ResetNodeLabels(const nnvm::Graph& g,
105104
subgraph_nodes->clear();
106105
}
107106

107+
/*
108+
* \brief Prepare NodeAttr for node. NodeAttr will be used in SubgraphSelectorV2.
109+
*/
110+
static const std::shared_ptr<NodeAttr> PrepareNodeAttr(const nnvm::Graph& g,
111+
const BiDirectedNode& node) {
112+
const auto& indexed_graph = g.indexed_graph();
113+
if (g.HasAttr("dtype") && g.HasAttr("shape") && g.HasAttr("dispatch_mode")) {
114+
const auto& vdtype = g.GetAttr<nnvm::DTypeVector>("dtype");
115+
const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
116+
const auto& dispatch_modes = g.GetAttr<mxnet::DispatchModeVector>("dispatch_mode");
117+
auto ret = std::make_shared<NodeAttr>();
118+
ret->dispatch_mode = dispatch_modes[indexed_graph.node_id(node.node)];
119+
for (const auto& e : node.node->inputs) {
120+
ret->ishape.emplace_back(vshape[indexed_graph.entry_id(e)]);
121+
ret->itype.emplace_back(vdtype[indexed_graph.entry_id(e)]);
122+
}
123+
return ret;
124+
} else {
125+
return nullptr;
126+
}
127+
}
128+
108129
/*!
109130
* \brief This function traverses the nodes in a computation graph from a starting
110131
* node following the input edges and output edges, and marks all nodes that
@@ -153,7 +174,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
153174
CHECK_LT(nid, simple_nodes.size());
154175
const bool select_input =
155176
(snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
156-
subgraph_selector->SelectInput(*cur_node, *snode);
177+
subgraph_selector->SelectInput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
157178
if (select_input) {
158179
// e.node is a subgraph node
159180
snode->label = label;
@@ -170,7 +191,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector
170191
CHECK_LT(nid, simple_nodes.size());
171192
const bool select_output =
172193
(snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
173-
subgraph_selector->SelectOutput(*cur_node, *snode);
194+
subgraph_selector->SelectOutput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
174195
if (select_output) {
175196
// it->first is a subgraph node
176197
snode->label = label;
@@ -325,14 +346,16 @@ void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector
325346
std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
326347
const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
327348
const auto& indexed_graph = g->indexed_graph();
349+
328350
auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
329351
return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
330352
};
331-
if (simple_nodes[snid]->label == -1 && subgraph_selector->Select(*node)) {
353+
if ((simple_nodes[snid]->label == -1) &&
354+
subgraph_selector->Select(*node, PrepareNodeAttr(*g, *node))) {
332355
// pre-select nodes that can be grouped in a subgraph
333356
std::vector<BiDirectedNode*> preselected_nodes;
334357
PreSelectSubgraphNodes(*g, subgraph_selector, *subgraph_id, snid, simple_nodes,
335-
&preselected_nodes);
358+
&preselected_nodes);
336359

337360
// filter out unqualified pre-selected nodes
338361
std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);

src/operator/subgraph/mkldnn/mkldnn_conv_property.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
2929
#include "../../tensor/matrix_op-inl.h"
3030
#include "../common.h"
31-
#include "../subgraph_property.h"
31+
#include "mkldnn_subgraph_base-inl.h"
3232

3333
namespace mxnet {
3434
namespace op {
@@ -61,15 +61,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
6161
disable_conv_sum_(dis_conv_sum),
6262
quantize_(quantize) {}
6363

64-
bool Select(const nnvm::Node &n) override {
64+
bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
6565
if (n.op() && n.op()->name == "Convolution") {
6666
const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
67-
if (param.kernel.ndim() == 2) {
67+
if (param.kernel.ndim() == 2 && SupportMKLDNNAttr(node_attr)) {
6868
status_ = disable_all_ ? kSuccess : kStart;
6969
matched_list_.clear();
7070
matched_list_.push_back(&n);
7171
return true;
72-
}
72+
}
7373
}
7474
return false;
7575
}
@@ -161,7 +161,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
161161
CHECK_GE(matched_list_.size(), 1);
162162
auto new_selector = SgMKLDNNConvSelector(disable_all_, disable_conv_bn_, disable_conv_act_,
163163
disable_conv_sum_, quantize_);
164-
new_selector.Select(*matched_list_[0]);
164+
new_selector.Select(*matched_list_[0], nullptr);
165165
*this = new_selector;
166166
}
167167
};

src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "../../nn/fully_connected-inl.h"
3434
#include "../../quantization/requantize-inl.h"
3535
#include "../common.h"
36-
#include "../subgraph_property.h"
36+
#include "mkldnn_subgraph_base-inl.h"
3737

3838
namespace mxnet {
3939
namespace op {

src/operator/subgraph/mkldnn/mkldnn_fc_property.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#include <vector>
3333
#include "../common.h"
3434
#include "../../tensor/matrix_op-inl.h"
35-
#include "../subgraph_property.h"
35+
#include "mkldnn_subgraph_base-inl.h"
3636
#include "mkldnn_fc-inl.h"
3737

3838
namespace mxnet {
@@ -58,8 +58,8 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
5858
disable_fc_eltwise_(dis_fc_eltwise),
5959
quantized_(quantized) {}
6060

61-
bool Select(const nnvm::Node &n) override {
62-
if (n.op() == Op::Get("FullyConnected")) {
61+
bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
62+
if (n.op() == Op::Get("FullyConnected") && SupportMKLDNNAttr(node_attr)) {
6363
status_ = disable_fc_eltwise_ ? kSuccess : kStart;
6464
matched_list_.clear();
6565
matched_list_.push_back(&n);
@@ -150,7 +150,7 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
150150
void Reset() override {
151151
CHECK_GE(matched_list_.size(), 1);
152152
auto new_selector = SgMKLDNNFCSelector(disable_fc_eltwise_, quantized_);
153-
new_selector.Select(*matched_list_[0]);
153+
new_selector.Select(*matched_list_[0], nullptr);
154154
*this = new_selector;
155155
}
156156
};

src/operator/subgraph/mkldnn/mkldnn_post_quantize_align_scale_property.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include <string>
2525
#include <vector>
2626
#include "../common.h"
27-
#include "../subgraph_property.h"
27+
#include "mkldnn_subgraph_base-inl.h"
2828

2929
namespace mxnet {
3030
namespace op {

src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_POST_QUANTIZE_PROPERTY_H_
2121
#if MXNET_USE_MKLDNN == 1
2222

23+
#include <set>
2324
#include <string>
2425
#include <vector>
25-
#include <set>
26-
#include "../common.h"
27-
#include "../subgraph_property.h"
2826
#include "../../nn/mkldnn/mkldnn_convolution-inl.h"
29-
#include "mkldnn_conv-inl.h"
3027
#include "../../quantization/requantize-inl.h"
28+
#include "../common.h"
29+
#include "mkldnn_conv-inl.h"
30+
#include "mkldnn_subgraph_base-inl.h"
3131

3232
namespace mxnet {
3333
namespace op {
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
20+
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_
21+
#if MXNET_USE_MKLDNN == 1
22+
23+
#include "../subgraph_property.h"
24+
25+
namespace mxnet {
26+
namespace op {
27+
28+
static inline bool SupportMKLDNNAttr(const std::shared_ptr<NodeAttr>& node_attr) {
29+
if (node_attr) {
30+
int ndim = node_attr->ishape[0].ndim();
31+
return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
32+
(node_attr->itype[0] == mshadow::kFloat32) && (ndim == 1 || ndim == 2 || ndim == 4);
33+
} else {
34+
return true;
35+
}
36+
}
37+
38+
} // namespace op
39+
} // namespace mxnet
40+
41+
#endif // MXNET_USE_MKLDNN == 1
42+
#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_SUBGRAPH_BASE_INL_H_

0 commit comments

Comments
 (0)