
Commit 0107e3a

wuxun-zhang authored and Ubuntu committed
Add omp parallel optimization for _contrib_BilinearResize2D (apache#15584)
* Add omp parallel optimization for bilinear_resize op
* retrigger CI
* retrigger CI
* trigger CI
1 parent db1212c commit 0107e3a

File tree

3 files changed (+101, -63 lines)

src/operator/contrib/bilinear_resize.cc

Lines changed: 84 additions & 62 deletions
@@ -23,7 +23,6 @@
  * \author Hang Zhang
  */
 #include "bilinear_resize-inl.h"
-// #include "elemwise_op_common.h"
 #include "../elemwise_op_common.h"
 
 namespace mxnet {
@@ -44,56 +43,66 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
   int inputHeight = itensor.size(2);
   int inputWidth = itensor.size(3);
 
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+
   DType *idata = itensor.dptr_;
   DType *odata = otensor.dptr_;
   channels = nbatch * channels;
+  const int input_elems_per_channel = inputWidth * inputHeight;
+  const int output_elems_per_channel = outputWidth * outputHeight;
+
   // special case: just copy
   if (inputHeight == outputHeight && inputWidth == outputWidth) {
-    for (int h2 = 0; h2 < outputHeight; ++h2) {
+    #pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < output_elems_per_channel; index++) {
+      const int h2 = index / outputWidth;
       const int h1 = h2;
-      for (int w2 = 0; w2 < outputWidth; ++w2) {
-        const int w1 = w2;
-        const DType* pos1 = &idata[h1 * inputWidth + w1];
-        DType* pos2 = &odata[h2 * outputWidth + w2];
-        for (int c = 0; c < channels; ++c) {
-          pos2[0] = pos1[0];
-          pos1 += inputWidth * inputHeight;
-          pos2 += outputWidth * outputHeight;
-        }
+      const int w2 = index % outputWidth;
+      const int w1 = w2;
+      const DType* pos1 = &idata[h1 * inputWidth + w1];
+      DType* pos2 = &odata[index];
+      for (int c = 0; c < channels; ++c) {
+        *pos2 = *pos1;
+        pos1 += input_elems_per_channel;
+        pos2 += output_elems_per_channel;
       }
     }
     return;
   }
+
   const float rheight = (outputHeight > 1) ? static_cast<float>(inputHeight - 1) /
                         (outputHeight - 1) : 0.f;
   const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
                        (outputWidth - 1) : 0.f;
-  for (int h2 = 0; h2 < outputHeight; ++h2) {
+  #pragma omp parallel for num_threads(nthreads)
+  for (int index = 0; index < output_elems_per_channel; index++) {
+    const int h2 = index / outputWidth;
+    const int w2 = index % outputWidth;
+
     const float h1r = rheight * h2;
     const int h1 = h1r;
     const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
     const DType h1lambda = h1r - h1;
     const DType h0lambda = (DType)1. - h1lambda;
-    for (int w2 = 0; w2 < outputWidth; ++w2) {
-      const float w1r = rwidth * w2;
-      const int w1 = w1r;
-      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
-      const DType w1lambda = w1r - w1;
-      const DType w0lambda = (DType)1. - w1lambda;
-      const DType* pos1 = &idata[h1 * inputWidth + w1];
-      DType* pos2 = &odata[h2 * outputWidth + w2];
-      for (int c = 0; c < channels; ++c) {
-        pos2[0] = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])
-                  + h1lambda * (w0lambda * pos1[h1p * inputWidth]
-                  + w1lambda * pos1[h1p * inputWidth + w1p]);
-        pos1 += inputWidth * inputHeight;
-        pos2 += outputWidth * outputHeight;
-      }
+
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const DType w1lambda = w1r - w1;
+    const DType w0lambda = (DType)1. - w1lambda;
+    const DType* pos1 = &idata[h1 * inputWidth + w1];
+    DType* pos2 = &odata[index];
+
+    for (int c = 0; c < channels; ++c) {
+      *pos2 = h0lambda * (w0lambda * (*pos1) + w1lambda * *(pos1 + w1p))
+              + h1lambda * (w0lambda * *(pos1 + h1p * inputWidth)
+              + w1lambda * *(pos1 + h1p * inputWidth + w1p));
+      pos1 += input_elems_per_channel;
+      pos2 += output_elems_per_channel;
     }
   }
 }
 
-
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
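
The change above parallelizes the forward kernel by collapsing the nested (h2, w2) output loops into a single flat loop over output_elems_per_channel, so #pragma omp parallel for can split the output pixels across nthreads threads; the row and column are recovered as index / outputWidth and index % outputWidth (for example, with outputWidth = 5, index 7 maps to h2 = 1, w2 = 2). The sketch below is illustrative only and is not the MXNet source: the function name and the nearest-neighbour sampling body are assumptions standing in for the bilinear weights, but the loop flattening and per-channel pointer striding follow the pattern used in the hunk above.

    // Illustrative sketch: flat output-pixel loop parallelized with OpenMP,
    // with nearest-neighbour sampling standing in for the bilinear weights.
    void resize_nearest(const float* idata, float* odata,
                        int channels, int inputHeight, int inputWidth,
                        int outputHeight, int outputWidth, int nthreads) {
      const int input_elems_per_channel = inputHeight * inputWidth;
      const int output_elems_per_channel = outputHeight * outputWidth;
      #pragma omp parallel for num_threads(nthreads)
      for (int index = 0; index < output_elems_per_channel; index++) {
        const int h2 = index / outputWidth;              // output row
        const int w2 = index % outputWidth;              // output column
        const int h1 = h2 * inputHeight / outputHeight;  // nearest source row
        const int w1 = w2 * inputWidth / outputWidth;    // nearest source column
        const float* pos1 = idata + h1 * inputWidth + w1;
        float* pos2 = odata + index;
        for (int c = 0; c < channels; ++c) {  // walk every (batch * channel) plane
          *pos2 = *pos1;                      // threads own disjoint outputs, so no race
          pos1 += input_elems_per_channel;
          pos2 += output_elems_per_channel;
        }
      }
    }

Compared with parallelizing only the outer h2 loop, the flat index exposes outputHeight * outputWidth iterations to distribute, which keeps threads busy even for short-and-wide or tall-and-narrow outputs.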
@@ -109,23 +118,28 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
   int inputHeight = gradInput.size(2);
   int inputWidth = gradInput.size(3);
 
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+
   DType *dataInput = gradInput.dptr_;
   DType *dataOutput = gradOutput.dptr_;
   channels = nbatch * channels;
+  const int input_elems_per_channel = inputWidth * inputHeight;
+  const int output_elems_per_channel = outputWidth * outputHeight;
 
   // special case: same-size matching grids
   if (inputHeight == outputHeight && inputWidth == outputWidth) {
-    for (int h2 = 0; h2 < outputHeight; ++h2) {
+    #pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < output_elems_per_channel; index++) {
+      const int h2 = index / outputWidth;
       const int h1 = h2;
-      for (int w2 = 0; w2 < outputWidth; ++w2) {
-        const int w1 = w2;
-        DType* pos1 = &dataInput[h1 * inputWidth + w1];
-        const DType* pos2 = &dataOutput[h2 * outputWidth + w2];
-        for (int c = 0; c < channels; ++c) {
-          pos1[0] += pos2[0];
-          pos1 += inputWidth * inputHeight;
-          pos2 += outputWidth * outputHeight;
-        }
+      const int w2 = index % outputWidth;
+      const int w1 = w2;
+      DType* pos1 = &dataInput[h1 * inputWidth + w1];
+      const DType* pos2 = &dataOutput[index];
+      for (int c = 0; c < channels; ++c) {
+        *pos1 += *pos2;
+        pos1 += input_elems_per_channel;
+        pos2 += output_elems_per_channel;
       }
     }
     return;
@@ -134,28 +148,36 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                          (outputHeight - 1) : 0.f;
   const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
                        (outputWidth - 1) : 0.f;
-  for (int h2 = 0; h2 < outputHeight; ++h2) {
+
+  #pragma omp parallel for num_threads(nthreads)
+  for (int index = 0; index < output_elems_per_channel; index++) {
+    const int h2 = index / outputWidth;
+    const int w2 = index % outputWidth;
+
     const float h1r = rheight * h2;
     const int h1 = h1r;
     const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
     const DType h1lambda = h1r - h1;
     const DType h0lambda = (DType)1. - h1lambda;
-    for (int w2 = 0; w2 < outputWidth; ++w2) {
-      const float w1r = rwidth * w2;
-      const int w1 = w1r;
-      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
-      const DType w1lambda = w1r - w1;
-      const DType w0lambda = (DType)1. - w1lambda;
-      DType* posInput = &dataInput[h1 * inputWidth + w1];
-      const DType* posOutput = &dataOutput[h2 * outputWidth + w2];
-      for (int c = 0; c < channels; ++c) {
-        posInput[0] += h0lambda * w0lambda * posOutput[0];
-        posInput[w1p] += h0lambda * w1lambda * posOutput[0];
-        posInput[h1p * inputWidth] += h1lambda * w0lambda * posOutput[0];
-        posInput[h1p * inputWidth + w1p] += h1lambda * w1lambda * posOutput[0];
-        posInput += inputWidth * inputHeight;
-        posOutput += outputWidth * outputHeight;
+
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const DType w1lambda = w1r - w1;
+    const DType w0lambda = (DType)1. - w1lambda;
+
+    DType* posInput = &dataInput[h1 * inputWidth + w1];
+    const DType* posOutput = &dataOutput[index];
+    for (int c = 0; c < channels; ++c) {
+      #pragma omp critical
+      {
+        *posInput += h0lambda * w0lambda * (*posOutput);
+        *(posInput + w1p) += h0lambda * w1lambda * (*posOutput);
+        *(posInput + h1p * inputWidth) += h1lambda * w0lambda * (*posOutput);
+        *(posInput + h1p * inputWidth + w1p) += h1lambda * w1lambda * (*posOutput);
       }
+      posInput += input_elems_per_channel;
+      posOutput += output_elems_per_channel;
     }
   }
 
@@ -165,19 +187,19 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
     int inputWidthLike = gradInputLike.size(3);
     DType *dataInputLike = gradInputLike.dptr_;
     int channelsLike = nbatch * gradInputLike.size(1);
-    for (int h_like = 0; h_like < inputHeightLike; ++h_like) {
-      for (int w_like = 0; w_like < inputWidthLike; ++w_like) {
-        DType *posInput = &dataInputLike[h_like * inputWidthLike + w_like];
-        for (int c = 0; c < channelsLike; ++c) {
-          posInput[0] = 0;
-          posInput += inputWidthLike * inputHeightLike;
-        }
+
+    const int inputLike_elems_per_channel = inputHeightLike * inputWidthLike;
+    #pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < inputLike_elems_per_channel; index++) {
+      DType *posInput = &dataInputLike[index];
+      for (int c = 0; c < channelsLike; ++c) {
+        *posInput = 0;
+        posInput += inputLike_elems_per_channel;
       }
     }
   }
 }
 
-
 DMLC_REGISTER_PARAMETER(BilinearSampleParam);
 
 NNVM_REGISTER_OP(_contrib_BilinearResize2D)
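
One more note on the backward hunks above: once the output pixels are processed in parallel, neighbouring outputs scatter gradient into overlapping groups of up to four input pixels, so concurrent += updates can target the same input element. The commit therefore wraps each four-way update in #pragma omp critical; the same-size fast path needs no such guard, because there every output element maps to exactly one input element. A minimal sketch of the idea, assuming hypothetical names and a precomputed output-to-input index map (not the MXNet source):

    #include <vector>

    // Illustrative scatter-add: several outputs may accumulate into the same
    // input slot, so each update is serialized with a critical section.
    void scatter_add(const std::vector<float>& grad_out,
                     const std::vector<int>& target,   // input index fed by each output element
                     std::vector<float>* grad_in,
                     int nthreads) {
      #pragma omp parallel for num_threads(nthreads)
      for (int i = 0; i < static_cast<int>(grad_out.size()); ++i) {
        #pragma omp critical
        {
          (*grad_in)[target[i]] += grad_out[i];  // safe even when targets collide
        }
      }
    }

A single critical section serializes every accumulation, which protects correctness at some cost to scalability; #pragma omp atomic on each individual += would be a finer-grained alternative.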

tests/python/gpu/test_operator_gpu.py

Lines changed: 16 additions & 0 deletions
@@ -1143,6 +1143,22 @@ def test_flatten_slice_after_conv():
     check_consistency(slice_sym, ctx_list)
 
 
+@with_seed()
+def test_bilinear_resize_op():
+    ctx_list = [{'ctx': mx.cpu(0), 'data': (2, 2, 20, 20), 'type_dict': {'data': np.float32}},
+                {'ctx': mx.gpu(0), 'data': (2, 2, 20, 20), 'type_dict': {'data': np.float32}}]
+
+    data = mx.sym.Variable('data')
+    sym = mx.sym.contrib.BilinearResize2D(data, height=10, width=5)
+    check_consistency(sym, ctx_list)
+
+    sym = mx.sym.contrib.BilinearResize2D(data, None, scale_height=2, scale_width=0.5, mode='odd_scale')
+    check_consistency(sym, ctx_list)
+
+    sym = mx.sym.contrib.BilinearResize2D(data, None, scale_height=0.5, scale_width=2, mode='to_even_up')
+    check_consistency(sym, ctx_list)
+
+
 @with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):

tests/python/unittest/test_operator.py

Lines changed: 1 addition & 1 deletion
@@ -7663,7 +7663,7 @@ def py_bilinear_resize(x, outputHeight, outputWidth):
             w1r = 1.0 * w2 * rwidth
             w1 = int(np.floor(w1r))
             w1lambda = w1r - w1
-            w1p = 1 if w1 < (inputHeight - 1) else 0
+            w1p = 1 if w1 < (inputWidth - 1) else 0
             for b in range(batch):
                 for c in range(channel):
                     y[b][c][h2][w2] = (1-h1lambda)*((1-w1lambda)*x[b][c][h1][w1] + \
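
The one-line fix above corrects the Python reference implementation used by the unit test: w1p decides whether the interpolation may read the pixel one column to the right, so it must be clamped against the input width, not the input height. The two bounds only differ for non-square inputs; with inputHeight = 20 and inputWidth = 10, for instance, the old check let w1 = 9 (the last column) take w1p = 1 and index one element past the end of the row. Hypothetical helpers mirroring the clamping rule in bilinear_resize.cc:

    // The right-neighbour offset is bounded by the width, the down-neighbour
    // offset by the height; each becomes 0 on the last column/row.
    int right_neighbour_offset(int w1, int inputWidth) {
      return (w1 < inputWidth - 1) ? 1 : 0;
    }
    int down_neighbour_offset(int h1, int inputHeight) {
      return (h1 < inputHeight - 1) ? 1 : 0;
    }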
