
Commit 57dc78d

[MKL-DNN] Integrate Conv3d and Pool3d/1d (apache#17884)
* Integrate MKL-DNN conv3d and pool3d/1d
* fix UT & address comments
* clean code
* rebase against latest master
1 parent 6d8b679 commit 57dc78d

15 files changed, +492 −279 lines

src/operator/nn/mkldnn/mkldnn_act.cc

Lines changed: 4 additions & 4 deletions
@@ -48,9 +48,9 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
 }
 
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
-  // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
   if ((input.shape().ndim() < 1) ||
-      (input.shape().ndim() > 4) ||
+      (input.shape().ndim() > 5) ||
       !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
     return false;
   return SupportMKLDNNAct(param);
@@ -63,9 +63,9 @@ bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param) {
 }
 
 bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input) {
-  // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
   if ((input.shape().ndim() < 1) ||
-      (input.shape().ndim() > 4) ||
+      (input.shape().ndim() > 5) ||
       !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
     return false;
   return SupportMKLDNNLeakyRelu(param);
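
A minimal standalone sketch (illustrative only, not the MXNet source) of the relaxed gate above: any 1d-5d float32/bfloat16 input now qualifies, so the NCDHW tensors produced by Conv3d can stay on the MKL-DNN activation path.

enum class DType { kFloat32, kBfloat16, kOther };

bool SupportsActSketch(int ndim, DType dtype) {
  if (ndim < 1 || ndim > 5)   // the commit raises the upper bound from 4 to 5
    return false;
  return dtype == DType::kFloat32 || dtype == DType::kBfloat16;
}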

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 26 additions & 15 deletions
@@ -153,9 +153,8 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
     // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
     return false;
   }
-
   return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
-         (ndim == 1 || ndim == 2 || ndim == 4);
+         (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
 static inline bool SupportMKLDNNRnn(const NDArray &input) {
@@ -332,20 +331,32 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
   if (num_groups == 1) {
     return GetMemDesc(arr, dtype);
   } else {
-    auto ndim = arr.shape().ndim();
-    CHECK((ndim == 3) || (ndim == 4))
-        << "MKL-DNN weight currectly supports 3d and 4d layout";
+    const auto ndim = arr.shape().ndim();
+    CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
+        << "MKL-DNN weight currently supports 3d or 4d or 5d layout";
     auto tz = mkldnn::memory::dims{0};
-    const int N = 0, H = 2, W = 3, C = 1;
-    if (ndim == 3) {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
-    } else {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
-          static_cast<int>(arr.shape()[W])};
+    int N = 0, C = 1, H = 2, W = 3;
+    int D = -1;
+    if (ndim == 5) {
+      D = 2;
+      H = 3;
+      W = 4;
+    }
+    switch (ndim) {
+      case 3:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[H]};
+        break;
+      case 4:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[H], arr.shape()[W]};
+        break;
+      case 5:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[D], arr.shape()[H], arr.shape()[W]};
     }
     return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
   }
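
An illustrative helper (hypothetical, not the MXNet function) showing the dims ordering the new `case 5` builds for grouped 3d-conv weights: an OIDHW weight tensor becomes the 6d GOIDHW shape {groups, O/groups, I, D, H, W}.

#include <cstdint>
#include <vector>

std::vector<int64_t> GroupedConv3dWeightDims(const std::vector<int64_t> &oidhw,
                                             int64_t groups) {
  // oidhw = {O, I, D, H, W}; only the output-channel dimension is split by groups.
  return {groups, oidhw[0] / groups, oidhw[1], oidhw[2], oidhw[3], oidhw[4]};
}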

src/operator/nn/mkldnn/mkldnn_base.cc

Lines changed: 30 additions & 17 deletions
@@ -240,31 +240,44 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   auto tz = mkldnn::memory::dims{0};
   auto format_tag = mkldnn::memory::format_tag::undef;
   auto engine = CpuEngine::Get()->get_engine();
-  const int O = 0, I = 1, H = 2, W = 3;
-  if (arr.shape().ndim() == 2) {
-    tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I])};
+  const int ndim = arr.shape().ndim();
+  int O = 0, I = 1, H = 2, W = 3;
+  int D = -1;
+  if (ndim == 5) {
+    D = 2;
+    H = 3;
+    W = 4;
+  }
+  if (ndim == 2) {
+    tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]};
     format_tag = mkldnn::memory::format_tag::oi;
-  } else if (arr.shape().ndim() == 3) {
+  } else if (ndim == 3) {
     tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])}
-             : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])};
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H]}
+             : mkldnn::memory::dims{arr.shape()[O],
+                                    arr.shape()[I], arr.shape()[H]};
     format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
                                 : mkldnn::memory::format_tag::oiw;
-  } else if (arr.shape().ndim() == 4) {
+  } else if (ndim == 4) {
     tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H]),
-                                    static_cast<int>(arr.shape()[W])}
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H],
+                                    arr.shape()[W]}
              : mkldnn::memory::dims{
-                   static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
-                   static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
+                   arr.shape()[O], arr.shape()[I], arr.shape()[H], arr.shape()[W]};
     format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
                                 : mkldnn::memory::format_tag::oihw;
+  } else if (ndim == 5) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[D],
+                                    arr.shape()[H], arr.shape()[W]}
+             : mkldnn::memory::dims{
+                   arr.shape()[O], arr.shape()[I], arr.shape()[D],
+                   arr.shape()[H], arr.shape()[W]};
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goidhw
+                                : mkldnn::memory::format_tag::oidhw;
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
  }
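
A rough sketch (names are illustrative, not MXNet code) of the format-tag selection the new 5d branch adds: grouped 3d-conv weights use goidhw and ungrouped ones use oidhw, mirroring the existing goiw/oiw and goihw/oihw pairs.

#include <string>

std::string WeightFormatTagSketch(int ndim, bool grouped) {
  switch (ndim) {
    case 3: return grouped ? "goiw"   : "oiw";     // 1d convolution weights
    case 4: return grouped ? "goihw"  : "oihw";    // 2d convolution weights
    case 5: return grouped ? "goidhw" : "oidhw";   // 3d convolution weights (this commit)
    default: return "undef";                       // unsupported rank
  }
}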

src/operator/nn/mkldnn/mkldnn_convolution.cc

Lines changed: 52 additions & 8 deletions
@@ -37,11 +37,13 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam);
 
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
   if ((params.kernel.ndim() != 1) &&
-      (params.kernel.ndim() != 2))
+      (params.kernel.ndim() != 2) &&
+      (params.kernel.ndim() != 3))
     return false;
   return SupportMKLDNNQuantize(input.dtype()) &&
          ((input.shape().ndim() == 3) ||
-          (input.shape().ndim() == 4));
+          (input.shape().ndim() == 4) ||
+          (input.shape().ndim() == 5));
 }
 
 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
@@ -77,9 +79,19 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
     strides[1] = param.conv_param.stride[1];
     padding[0] = param.conv_param.pad[0];
     padding[1] = param.conv_param.pad[1];
+  } else if (param.conv_param.kernel.ndim() == 3) {
+    CHECK_GE(param.conv_param.stride.ndim(), 3);
+    CHECK_GE(param.conv_param.pad.ndim(), 3);
+    CHECK_GE(param.conv_param.dilate.ndim(), 3);
+    strides[0] = param.conv_param.stride[0];
+    strides[1] = param.conv_param.stride[1];
+    strides[2] = param.conv_param.stride[2];
+    padding[0] = param.conv_param.pad[0];
+    padding[1] = param.conv_param.pad[1];
+    padding[2] = param.conv_param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size "
-               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2.";
+               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2 or 3.";
   }
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
@@ -141,9 +153,13 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
   } else if (param.conv_param.dilate.ndim() == 2) {
     dilates[0] = param.conv_param.dilate[0] - 1;
     dilates[1] = param.conv_param.dilate[1] - 1;
+  } else if (param.conv_param.dilate.ndim() == 3) {
+    dilates[0] = param.conv_param.dilate[0] - 1;
+    dilates[1] = param.conv_param.dilate[1] - 1;
+    dilates[2] = param.conv_param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.conv_param.dilate.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
   if (bias_md_ptr == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
@@ -181,9 +197,19 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   auto GetConvBwdDataPd = [&data, &weight, &output,
@@ -216,9 +242,13 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
   } else if (param.dilate.ndim() == 2) {
     dilates[0] = param.dilate[0] - 1;
     dilates[1] = param.dilate[1] - 1;
+  } else if (param.dilate.ndim() == 3) {
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
+    dilates[2] = param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-               << param.dilate.ndim() << ", supporting only 1 or 2.";
+               << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
   }
   mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md,
                                                weight_md, out_md, strides, dilates, padding,
@@ -250,9 +280,19 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   auto GetConvBwdWeightsPd = [&data, &weight, &output,
@@ -291,9 +331,13 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
   } else if (param.dilate.ndim() == 2) {
     dilates[0] = param.dilate[0] - 1;
     dilates[1] = param.dilate[1] - 1;
+  } else if (param.dilate.ndim() == 3) {
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
+    dilates[2] = param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-               << param.dilate.ndim() << ", supporting only 1 or 2.";
+               << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
   }
   if (bias == nullptr) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
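
A simplified standalone sketch (assumed names, not the MXNet code) of the parameter plumbing each new 3d branch repeats: stride and pad are copied per spatial axis, and dilation is stored MKL-DNN style as the number of inserted gaps, hence the `- 1` on every dilate value.

#include <array>
#include <cstdint>

struct Conv3dGeometrySketch {
  std::array<int64_t, 3> strides, padding, dilates;
};

Conv3dGeometrySketch MakeConv3dGeometry(const std::array<int64_t, 3> &stride,
                                        const std::array<int64_t, 3> &pad,
                                        const std::array<int64_t, 3> &dilate) {
  Conv3dGeometrySketch g{};
  for (int i = 0; i < 3; ++i) {
    g.strides[i] = stride[i];
    g.padding[i] = pad[i];
    g.dilates[i] = dilate[i] - 1;  // MKL-DNN convention: 0 means no dilation
  }
  return g;
}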

src/operator/nn/mkldnn/mkldnn_pooling-inl.h

Lines changed: 36 additions & 23 deletions
@@ -38,17 +38,15 @@ class MKLDNNPoolingFwd {
  public:
   MKLDNNPoolingFwd(const mxnet::NDArray &input,
                    const mxnet::NDArray &output,
-                   const int kernel_h, const int kernel_w,
-                   const int stride_h, const int stride_w,
-                   const int padding_t, const int padding_b,
-                   const int padding_l, const int padding_r,
+                   const mkldnn::memory::dims &kernel,
+                   const mkldnn::memory::dims &strides,
+                   const mkldnn::memory::dims &pad_l,
+                   const mkldnn::memory::dims &pad_r,
                    const mkldnn::algorithm alg_kind,
                    const bool with_workspace, const bool is_train):
                    with_workspace_(with_workspace),
                    fwd_(nullptr) {
-    Init(input, output,
-         kernel_h, kernel_w, stride_h, stride_w,
-         padding_t, padding_b, padding_l, padding_r,
+    Init(input, output, kernel, strides, pad_l, pad_r,
          is_train, alg_kind);
   }
 
@@ -67,10 +65,10 @@ class MKLDNNPoolingFwd {
  private:
   void Init(const mxnet::NDArray &input,
             const mxnet::NDArray &output,
-            const int kernel_h, const int kernel_w,
-            const int stride_h, const int stride_w,
-            const int padding_t, const int padding_b,
-            const int padding_l, const int padding_r,
+            const mkldnn::memory::dims &kernel,
+            const mkldnn::memory::dims &strides,
+            const mkldnn::memory::dims &pad_l,
+            const mkldnn::memory::dims &pad_r,
             const bool is_train, const mkldnn::algorithm alg_kind);
 };
 
@@ -98,31 +96,46 @@ inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
 }
 
 inline bool SupportMKLDNNPooling(const PoolingParam &param) {
-  return param.kernel.ndim() == 2 &&
+  return (param.kernel.ndim() == 1 || param.kernel.ndim() == 2 ||
+          param.kernel.ndim() == 3) &&
         (param.pool_type == pool_enum::kMaxPooling ||
          param.pool_type == pool_enum::kAvgPooling) &&
-         (!param.layout.has_value() || param.layout.value() == mshadow::kNCHW);
+         (!param.layout.has_value() ||
+          (param.layout.value() == mshadow::kNCW || param.layout.value() == mshadow::kNCHW ||
+           param.layout.value() == mshadow::kNCDHW));
 }
 
 inline bool SupportMKLDNNPooling(const PoolingParam &param,
-                                 const mxnet::TShape &dshape) {
-  bool ret = SupportMKLDNNPooling(param);
-  if (!ret)
+                                 const NDArray &input) {
+  const auto dshape = input.shape();
+  const auto ndim = dshape.ndim();
+  const auto dtype = input.dtype();
+
+  if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
+        (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
+    return false;
+
+  if (!SupportMKLDNNPooling(param))
     return false;
 
   if (param.pooling_convention == pool_enum::kValid) {
     return true;
   } else {
    if (param.pool_type == pool_enum::kAvgPooling) {
-      CHECK_EQ(dshape.ndim(), 4);
       // mkldnn works differently when padding is asymmetric, so let's skip this case.
-      if (param.pad[0] == GetPaddingSizeFull(dshape[2], param.pad[0], param.pad[0], param.kernel[0],
-                                             param.stride[0]) &&
-          param.pad[1] == GetPaddingSizeFull(dshape[3], param.pad[1], param.pad[1], param.kernel[1],
-                                             param.stride[1])) {
-        return true;
+      bool is_symmetric = true;
+      switch (ndim) {
+        case 5:
+          is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4],
+              param.pad[2], param.pad[2], param.kernel[2], param.stride[2]));
+        case 4:
+          is_symmetric = is_symmetric && (param.pad[1] == GetPaddingSizeFull(dshape[3],
+              param.pad[1], param.pad[1], param.kernel[1], param.stride[1]));
+        case 3:
+          is_symmetric = is_symmetric && (param.pad[0] == GetPaddingSizeFull(dshape[2],
+              param.pad[0], param.pad[0], param.kernel[0], param.stride[0]));
      }
-      return false;
+      return is_symmetric;
    }
    return param.pool_type == pool_enum::kMaxPooling;
   }
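
A condensed sketch (illustrative only) of the symmetric-padding test above: the switch falls through on purpose, so a 5d input checks the D, H and W axes, a 4d input checks H and W, and a 3d (1-d pooling) input checks only W.

#include <vector>

bool AllAxesSymmetricSketch(int ndim, const std::vector<int> &pad,
                            const std::vector<int> &full_pad) {
  // full_pad[i] stands in for GetPaddingSizeFull(...) on spatial axis i.
  bool symmetric = true;
  switch (ndim) {
    case 5: symmetric = symmetric && (pad[2] == full_pad[2]);  // depth; falls through
    case 4: symmetric = symmetric && (pad[1] == full_pad[1]);  // height; falls through
    case 3: symmetric = symmetric && (pad[0] == full_pad[0]);  // width
  }
  return symmetric;
}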
