Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d1797cf

Browse files
Kh4L authored and apeforest committed
Move MRCNNMaskTarget op to contrib (#16486)
Signed-off-by: Serge Panev <[email protected]>
1 parent c4580ae commit d1797cf

File tree

4 files changed

+108
-107
lines changed

4 files changed

+108
-107
lines changed

src/operator/contrib/mrcnn_target-inl.h renamed to src/operator/contrib/mrcnn_mask_target-inl.h

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,14 +19,14 @@
1919

2020
/*!
2121
* Copyright (c) 2019 by Contributors
22-
* \file mrcnn_target-inl.h
22+
* \file mrcnn_mask_target-inl.h
2323
* \brief Mask-RCNN target generator
2424
* \author Serge Panev
2525
*/
2626

2727

28-
#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
29-
#define MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
28+
#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
29+
#define MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
3030

3131
#include <mxnet/operator.h>
3232
#include <vector>
@@ -42,13 +42,13 @@ namespace mrcnn_index {
4242
enum ROIAlignOpOutputs {kMask, kMaskClasses};
4343
} // namespace mrcnn_index
4444

45-
struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
45+
struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
4646
int num_rois;
4747
int num_classes;
4848
int mask_size;
4949
int sample_ratio;
5050

51-
DMLC_DECLARE_PARAMETER(MRCNNTargetParam) {
51+
DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
5252
DMLC_DECLARE_FIELD(num_rois)
5353
.describe("Number of sampled RoIs.");
5454
DMLC_DECLARE_FIELD(num_classes)
@@ -60,11 +60,11 @@ struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
6060
}
6161
};
6262

63-
inline bool MRCNNTargetShape(const NodeAttrs& attrs,
64-
std::vector<mxnet::TShape>* in_shape,
65-
std::vector<mxnet::TShape>* out_shape) {
63+
inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
64+
std::vector<mxnet::TShape>* in_shape,
65+
std::vector<mxnet::TShape>* out_shape) {
6666
using namespace mshadow;
67-
const MRCNNTargetParam& param = nnvm::get<MRCNNTargetParam>(attrs.parsed);
67+
const MRCNNMaskTargetParam& param = nnvm::get<MRCNNMaskTargetParam>(attrs.parsed);
6868

6969
CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]";
7070

@@ -98,9 +98,9 @@ inline bool MRCNNTargetShape(const NodeAttrs& attrs,
9898
return true;
9999
}
100100

101-
inline bool MRCNNTargetType(const NodeAttrs& attrs,
102-
std::vector<int>* in_type,
103-
std::vector<int>* out_type) {
101+
inline bool MRCNNMaskTargetType(const NodeAttrs& attrs,
102+
std::vector<int>* in_type,
103+
std::vector<int>* out_type) {
104104
CHECK_EQ(in_type->size(), 4);
105105
int dtype = (*in_type)[1];
106106
CHECK_NE(dtype, -1) << "Input must have specified type";
@@ -112,21 +112,21 @@ inline bool MRCNNTargetType(const NodeAttrs& attrs,
112112
}
113113

114114
template<typename xpu>
115-
void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
116-
const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
115+
void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
116+
const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
117117

118118
template<typename xpu>
119-
void MRCNNTargetCompute(const nnvm::NodeAttrs& attrs,
120-
const OpContext &ctx,
121-
const std::vector<TBlob> &inputs,
122-
const std::vector<OpReqType> &req,
123-
const std::vector<TBlob> &outputs) {
119+
void MRCNNMaskTargetCompute(const nnvm::NodeAttrs& attrs,
120+
const OpContext &ctx,
121+
const std::vector<TBlob> &inputs,
122+
const std::vector<OpReqType> &req,
123+
const std::vector<TBlob> &outputs) {
124124
auto s = ctx.get_stream<xpu>();
125-
const auto& p = dmlc::get<MRCNNTargetParam>(attrs.parsed);
126-
MRCNNTargetRun<xpu>(p, inputs, outputs, s);
125+
const auto& p = dmlc::get<MRCNNMaskTargetParam>(attrs.parsed);
126+
MRCNNMaskTargetRun<xpu>(p, inputs, outputs, s);
127127
}
128128

129129
} // namespace op
130130
} // namespace mxnet
131131

132-
#endif // MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
132+
#endif // MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_

src/operator/contrib/mrcnn_target.cu renamed to src/operator/contrib/mrcnn_mask_target.cu

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,12 @@
1919

2020
/*!
2121
* Copyright (c) 2019 by Contributors
22-
* \file mrcnn_target.cu
22+
* \file mrcnn_mask_target.cu
2323
* \brief Mask-RCNN target generator
2424
* \author Serge Panev
2525
*/
2626

27-
#include "./mrcnn_target-inl.h"
27+
#include "./mrcnn_mask_target-inl.h"
2828

2929
namespace mxnet {
3030
namespace op {
@@ -183,21 +183,21 @@ __device__ void RoIAlignForward(
183183

184184

185185
template<typename DType>
186-
__global__ void MRCNNTargetKernel(const DType *rois,
187-
const DType *gt_masks,
188-
const DType *matches,
189-
const DType *cls_targets,
190-
DType* sampled_masks,
191-
DType* mask_cls,
192-
const int total_out_el,
193-
int batch_size,
194-
int num_classes,
195-
int num_rois,
196-
int num_gtmasks,
197-
int gt_height,
198-
int gt_width,
199-
int mask_size,
200-
int sample_ratio) {
186+
__global__ void MRCNNMaskTargetKernel(const DType *rois,
187+
const DType *gt_masks,
188+
const DType *matches,
189+
const DType *cls_targets,
190+
DType* sampled_masks,
191+
DType* mask_cls,
192+
const int total_out_el,
193+
int batch_size,
194+
int num_classes,
195+
int num_rois,
196+
int num_gtmasks,
197+
int gt_height,
198+
int gt_width,
199+
int mask_size,
200+
int sample_ratio) {
201201
// computing sampled_masks
202202
RoIAlignForward(gt_masks, rois, matches, total_out_el,
203203
num_classes, gt_height, gt_width, mask_size, mask_size,
@@ -221,8 +221,8 @@ __global__ void MRCNNTargetKernel(const DType *rois,
221221
}
222222

223223
template<>
224-
void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
225-
const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
224+
void MRCNNMaskTargetRun<gpu>(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
225+
const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
226226
const int block_dim_size = kMaxThreadsPerBlock;
227227
using namespace mxnet_op;
228228
using mshadow::Tensor;
@@ -248,31 +248,31 @@ void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob>
248248
dim3 dimGrid = dim3(CUDA_GET_BLOCKS(num_el));
249249
dim3 dimBlock = dim3(block_dim_size);
250250

251-
MRCNNTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
251+
MRCNNMaskTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
252252
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
253253
out_masks.dptr_, out_mask_cls.dptr_,
254254
num_el, batch_size, param.num_classes, param.num_rois,
255255
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
256-
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel);
256+
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
257257
});
258258
}
259259

260-
DMLC_REGISTER_PARAMETER(MRCNNTargetParam);
260+
DMLC_REGISTER_PARAMETER(MRCNNMaskTargetParam);
261261

262-
NNVM_REGISTER_OP(mrcnn_target)
262+
NNVM_REGISTER_OP(_contrib_mrcnn_mask_target)
263263
.describe("Generate mask targets for Mask-RCNN.")
264264
.set_num_inputs(4)
265265
.set_num_outputs(2)
266-
.set_attr_parser(ParamParser<MRCNNTargetParam>)
267-
.set_attr<mxnet::FInferShape>("FInferShape", MRCNNTargetShape)
268-
.set_attr<nnvm::FInferType>("FInferType", MRCNNTargetType)
269-
.set_attr<FCompute>("FCompute<gpu>", MRCNNTargetCompute<gpu>)
266+
.set_attr_parser(ParamParser<MRCNNMaskTargetParam>)
267+
.set_attr<mxnet::FInferShape>("FInferShape", MRCNNMaskTargetShape)
268+
.set_attr<nnvm::FInferType>("FInferType", MRCNNMaskTargetType)
269+
.set_attr<FCompute>("FCompute<gpu>", MRCNNMaskTargetCompute<gpu>)
270270
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 3D array")
271271
.add_argument("gt_masks", "NDArray-or-Symbol", "Input masks of full image size, a 4D array")
272272
.add_argument("matches", "NDArray-or-Symbol", "Index to a gt_mask, a 2D array")
273273
.add_argument("cls_targets", "NDArray-or-Symbol",
274274
"Value [0, num_class), excluding background class, a 2D array")
275-
.add_arguments(MRCNNTargetParam::__FIELDS__());
275+
.add_arguments(MRCNNMaskTargetParam::__FIELDS__());
276276

277277
} // namespace op
278278
} // namespace mxnet

tests/python/unittest/test_contrib_operator.py

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import itertools
2424
from numpy.testing import assert_allclose, assert_array_equal
2525
from mxnet.test_utils import *
26+
from common import with_seed
2627
import unittest
2728

2829
def test_box_nms_op():
@@ -351,6 +352,63 @@ def test_box_decode_op():
351352
assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \
352353
[0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5)
353354

355+
@with_seed()
356+
def test_op_mrcnn_mask_target():
357+
if default_context().device_type != 'gpu':
358+
return
359+
360+
num_rois = 2
361+
num_classes = 4
362+
mask_size = 3
363+
ctx = mx.gpu(0)
364+
# (B, N, 4)
365+
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
366+
[3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
367+
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)
368+
369+
# (B, N)
370+
matches = mx.nd.array([[2, 0]], ctx=ctx)
371+
# (B, N)
372+
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)
373+
374+
mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(rois, gt_masks, matches, cls_targets,
375+
num_rois=num_rois,
376+
num_classes=num_classes,
377+
mask_size=mask_size)
378+
379+
# Ground truth outputs were generated with GluonCV's target generator
380+
# gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
381+
gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ],
382+
[2204.0667 , 2204.4 , 2204.7334 ],
383+
[2214.7334 , 2215.0667 , 2215.4 ]],
384+
[[2193.4 , 2193.7332 , 2194.0667 ],
385+
[2204.0667 , 2204.4 , 2204.7334 ],
386+
[2214.7334 , 2215.0667 , 2215.4 ]],
387+
[[2193.4 , 2193.7332 , 2194.0667 ],
388+
[2204.0667 , 2204.4 , 2204.7334 ],
389+
[2214.7334 , 2215.0667 , 2215.4 ]],
390+
[[2193.4 , 2193.7332 , 2194.0667 ],
391+
[2204.0667 , 2204.4 , 2204.7334 ],
392+
[2214.7334 , 2215.0667 , 2215.4 ]]],
393+
[[[ 185. , 185.33334, 185.66667],
394+
[ 195.66667, 196.00002, 196.33334],
395+
[ 206.33333, 206.66666, 207. ]],
396+
[[ 185. , 185.33334, 185.66667],
397+
[ 195.66667, 196.00002, 196.33334],
398+
[ 206.33333, 206.66666, 207. ]],
399+
[[ 185. , 185.33334, 185.66667],
400+
[ 195.66667, 196.00002, 196.33334],
401+
[ 206.33333, 206.66666, 207. ]],
402+
[[ 185. , 185.33334, 185.66667],
403+
[ 195.66667, 196.00002, 196.33334],
404+
[ 206.33333, 206.66666, 207. ]]]]])
405+
406+
gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]])
407+
gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3))
408+
409+
assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
410+
assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())
411+
354412
if __name__ == '__main__':
355413
import nose
356414
nose.runmodule()

tests/python/unittest/test_operator.py

Lines changed: 0 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -8639,63 +8639,6 @@ def test_rroi_align_value(sampling_ratio=-1):
86398639
test_rroi_align_value()
86408640
test_rroi_align_value(sampling_ratio=2)
86418641

8642-
@with_seed()
8643-
def test_op_mrcnn_target():
8644-
if default_context().device_type != 'gpu':
8645-
return
8646-
8647-
num_rois = 2
8648-
num_classes = 4
8649-
mask_size = 3
8650-
ctx = mx.gpu(0)
8651-
# (B, N, 4)
8652-
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
8653-
[3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
8654-
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)
8655-
8656-
# (B, N)
8657-
matches = mx.nd.array([[2, 0]], ctx=ctx)
8658-
# (B, N)
8659-
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)
8660-
8661-
mask_targets, mask_cls = mx.nd.mrcnn_target(rois, gt_masks, matches, cls_targets,
8662-
num_rois=num_rois,
8663-
num_classes=num_classes,
8664-
mask_size=mask_size)
8665-
8666-
# Ground truth outputs were generated with GluonCV's target generator
8667-
# gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
8668-
gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ],
8669-
[2204.0667 , 2204.4 , 2204.7334 ],
8670-
[2214.7334 , 2215.0667 , 2215.4 ]],
8671-
[[2193.4 , 2193.7332 , 2194.0667 ],
8672-
[2204.0667 , 2204.4 , 2204.7334 ],
8673-
[2214.7334 , 2215.0667 , 2215.4 ]],
8674-
[[2193.4 , 2193.7332 , 2194.0667 ],
8675-
[2204.0667 , 2204.4 , 2204.7334 ],
8676-
[2214.7334 , 2215.0667 , 2215.4 ]],
8677-
[[2193.4 , 2193.7332 , 2194.0667 ],
8678-
[2204.0667 , 2204.4 , 2204.7334 ],
8679-
[2214.7334 , 2215.0667 , 2215.4 ]]],
8680-
[[[ 185. , 185.33334, 185.66667],
8681-
[ 195.66667, 196.00002, 196.33334],
8682-
[ 206.33333, 206.66666, 207. ]],
8683-
[[ 185. , 185.33334, 185.66667],
8684-
[ 195.66667, 196.00002, 196.33334],
8685-
[ 206.33333, 206.66666, 207. ]],
8686-
[[ 185. , 185.33334, 185.66667],
8687-
[ 195.66667, 196.00002, 196.33334],
8688-
[ 206.33333, 206.66666, 207. ]],
8689-
[[ 185. , 185.33334, 185.66667],
8690-
[ 195.66667, 196.00002, 196.33334],
8691-
[ 206.33333, 206.66666, 207. ]]]]])
8692-
8693-
gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]])
8694-
gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3))
8695-
8696-
assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
8697-
assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())
8698-
86998642
@with_seed()
87008643
def test_diag():
87018644

0 commit comments

Comments
 (0)