Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8fd5ab5

Browse files
committed
image crop gpu
1 parent 32bb374 commit 8fd5ab5

File tree

6 files changed

+76
-87
lines changed

6 files changed

+76
-87
lines changed

src/operator/contrib/bilinear_resize-inl.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ static unsigned getNumThreads(int nElem, const bool smaller) {
6262

6363
// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 3, DType>
6464
template<typename xpu, typename Dtype, typename Acctype>
65-
__global__ void caffe_gpu_interp2_kernel(const int n,
65+
__global__ void
66+
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
67+
caffe_gpu_interp2_kernel(const int n,
6668
const Acctype rheight, const Acctype rwidth,
6769
const Tensor<xpu, 3, Dtype> data1,
6870
Tensor<xpu, 3, Dtype> data2,
@@ -111,7 +113,9 @@ __global__ void caffe_gpu_interp2_kernel(const int n,
111113

112114
// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 4, DType>
113115
template<typename xpu, typename Dtype, typename Acctype>
114-
__global__ void caffe_gpu_interp2_kernel(const int n,
116+
__global__ void
117+
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
118+
caffe_gpu_interp2_kernel(const int n,
115119
const Acctype rheight, const Acctype rwidth,
116120
const Tensor<xpu, 4, Dtype> data1,
117121
Tensor<xpu, 4, Dtype> data2,

src/operator/image/crop-inl.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs,
9494
return true;
9595
}
9696

97+
template<typename xpu>
9798
inline void CropImpl(int x,
9899
int y,
99100
int width,
@@ -106,7 +107,7 @@ inline void CropImpl(int x,
106107
const TBlob& data = inputs[0];
107108
const TBlob& out = outputs[0];
108109
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
109-
Stream<cpu>* s = ctx.get_stream<cpu>();
110+
Stream<xpu>* s = ctx.get_stream<xpu>();
110111
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
111112
if (ndim == 3) {
112113
begin[0] = y;
@@ -118,14 +119,18 @@ inline void CropImpl(int x,
118119
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
119120
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
120121
size_t num_threads = out.shape_.FlatTo2D()[0];
121-
mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
122+
if (std::is_same<xpu, gpu>::value) {
123+
num_threads *= out.shape_.get<ndim>()[ndim - 1];
124+
}
125+
mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
122126
out.dptr<DType>(), data.dptr<DType>(),
123127
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
124128
})
125129
})
126130
})
127131
}
128132

133+
template<typename xpu>
129134
inline void CropBackwardImpl(int x,
130135
int y,
131136
int width,
@@ -138,7 +143,7 @@ inline void CropBackwardImpl(int x,
138143
if (req[0] == kNullOp) return;
139144
const TBlob& output_grad = inputs[0];
140145
const TBlob& input_grad = outputs[0];
141-
Stream<cpu>* s = ctx.get_stream<cpu>();
146+
Stream<xpu>* s = ctx.get_stream<xpu>();
142147
if (req[0] == kWriteTo) {
143148
Fill(s, input_grad, req[0], 0);
144149
} else if (req[0] == kWriteInplace) {
@@ -156,32 +161,37 @@ inline void CropBackwardImpl(int x,
156161
MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
157162
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
158163
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
159-
mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
164+
if (std::is_same<xpu, gpu>::value) {
165+
num_threads *= output_grad.shape_.get<ndim>()[ndim - 1];
166+
}
167+
mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
160168
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
161169
input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), begin, step);
162170
})
163171
})
164172
})
165173
}
166174

175+
template<typename xpu>
167176
inline void CropOpForward(const nnvm::NodeAttrs &attrs,
168177
const OpContext &ctx,
169178
const std::vector<TBlob> &inputs,
170179
const std::vector<OpReqType> &req,
171180
const std::vector<TBlob> &outputs) {
172181
CHECK_EQ(outputs.size(), 1U);
173182
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
174-
CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
183+
CropImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
175184
}
176185

186+
template<typename xpu>
177187
inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
178188
const OpContext &ctx,
179189
const std::vector<TBlob> &inputs,
180190
const std::vector<OpReqType> &req,
181191
const std::vector<TBlob> &outputs) {
182192
CHECK_EQ(outputs.size(), 1U);
183193
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
184-
CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
194+
CropBackwardImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
185195
}
186196
} // namespace image
187197
} // namespace op

src/operator/image/crop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ to the given size.
6969
.set_attr_parser(ParamParser<CropParam>)
7070
.set_attr<mxnet::FInferShape>("FInferShape", CropShape)
7171
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
72-
.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
72+
.set_attr<FCompute>("FCompute<cpu>", CropOpForward<cpu>)
7373
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" })
7474
.add_argument("data", "NDArray-or-Symbol", "The input.")
7575
.add_arguments(CropParam::__FIELDS__());
@@ -79,7 +79,7 @@ NNVM_REGISTER_OP(_backward_image_crop)
7979
.set_num_inputs(1)
8080
.set_num_outputs(1)
8181
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
82-
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);
82+
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward<cpu>);
8383

8484
} // namespace image
8585
} // namespace op

src/operator/image/crop.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "crop-inl.h"
21+
22+
namespace mxnet {
23+
namespace op {
24+
namespace image {
25+
26+
NNVM_REGISTER_OP(_image_crop)
27+
.set_attr<FCompute>("FCompute<gpu>", CropOpForward<gpu>);
28+
29+
NNVM_REGISTER_OP(_backward_image_crop)
30+
.set_attr<FCompute>("FCompute<gpu>", CropOpBackward<gpu>);
31+
32+
} // namespace image
33+
} // namespace op
34+
} // namespace mxnet

tests/python/gpu/test_gluon_transforms.py

Lines changed: 11 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -28,79 +28,22 @@
2828
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
2929
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
3030
from common import assertRaises, setup_module, with_seed, teardown
31-
31+
from test_gluon_data_vision import test_to_tensor, test_normalize, test_crop_resize
3232

3333
set_default_context(mx.gpu(0))
3434

3535
@with_seed()
36-
def test_normalize():
37-
# 3D Input
38-
data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
39-
out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
40-
data_expected_3d = data_in_3d.asnumpy()
41-
data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
42-
data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
43-
data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
44-
assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())
45-
46-
# 4D Input
47-
data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
48-
out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
49-
data_expected_4d = data_in_4d.asnumpy()
50-
data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
51-
data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
52-
data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
53-
data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
54-
data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
55-
data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
56-
assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
57-
58-
# Default normalize values i.e., mean=0, std=1
59-
data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
60-
out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
61-
data_expected_3d_def = data_in_3d_def.asnumpy()
62-
assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
36+
def test_normalize_gpu():
37+
test_normalize()
6338

64-
# Invalid Input - Neither 3D or 4D input
65-
invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
66-
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
67-
assertRaises(MXNetError, normalize_transformer, invalid_data_in)
68-
69-
# Invalid Input - Channel neither 1 or 3
70-
invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
71-
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
72-
assertRaises(MXNetError, normalize_transformer, invalid_data_in)
7339

7440
@with_seed()
75-
def test_to_tensor():
76-
# 3D Input
77-
data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
78-
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
79-
assert_almost_equal(out_nd.asnumpy(), np.transpose(
80-
data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
81-
82-
# 4D Input
83-
data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
84-
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
85-
assert_almost_equal(out_nd.asnumpy(), np.transpose(
86-
data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
41+
def test_to_tensor_gpu():
42+
test_to_tensor()
8743

88-
# Invalid Input
89-
invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
90-
transformer = transforms.ToTensor()
91-
assertRaises(MXNetError, transformer, invalid_data_in)
92-
93-
# Bounds (0->0, 255->1)
94-
data_in = np.zeros((10, 20, 3)).astype(dtype=np.uint8)
95-
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
96-
assert same(out_nd.asnumpy(), np.transpose(np.zeros(data_in.shape, dtype=np.float32), (2, 0, 1)))
97-
98-
data_in = np.full((10, 20, 3), 255).astype(dtype=np.uint8)
99-
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
100-
assert same(out_nd.asnumpy(), np.transpose(np.ones(data_in.shape, dtype=np.float32), (2, 0, 1)))
10144

10245
@with_seed()
103-
def test_resize():
46+
def test_resize_gpu():
10447
# Test with normal case 3D input float type
10548
data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
10649
out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
@@ -155,3 +98,8 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth):
15598
data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
15699
out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
157100
assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)
101+
102+
103+
@with_seed()
104+
def test_crop_resize_gpu():
105+
test_crop_resize()

tests/python/unittest/test_gluon_data_vision.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,18 @@ def _test_crop_resize_with_diff_type(dtype):
146146
assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all())
147147
# test normal case with resize
148148
data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
149-
out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
150-
data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2)
149+
out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_in)
150+
data_expected = transforms.Resize(size=25, interpolation=1)(nd.slice(data_in, (0, 0, 0), (50, 100, 3)))
151151
assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
152152
# test 4D input with resize
153153
data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
154-
out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in)
154+
out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_bath_in)
155155
for i in range(len(out_batch_nd)):
156-
assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
157-
out_batch_nd[i].asnumpy())
156+
actual = transforms.Resize(size=25, interpolation=1)(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3))).asnumpy()
157+
expected = out_batch_nd[i].asnumpy()
158+
assert_almost_equal(expected, actual)
158159
# test with resize height and width should be greater than 0
159-
transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
160+
transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 1)
160161
assertRaises(MXNetError, transformer, data_in)
161162
# test height and width should be greater than 0
162163
transformer = transforms.CropResize(0, 0, -100, -50)
@@ -188,14 +189,6 @@ def test_crop_backward(test_nd_arr, TestCase):
188189
data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
189190
for test_case in test_list:
190191
test_crop_backward(data_in, test_case)
191-
192-
193-
194-
# check numeric gradient of nd.image.crop
195-
# in_data = np.arange(36).reshape(3, 4, 3)
196-
# data = mx.sym.Variable('data')
197-
# image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2)
198-
# check_numeric_gradient(image_crop_sym, [in_data])
199192

200193

201194
@with_seed()

0 commit comments

Comments
 (0)