Skip to content

Commit 6186aaf

Browse files
Jerryzcnaaronmarkham
authored andcommitted
New ops for RCNN + old ops improvements for RCNN (apache#16215)
* box encode and box decode seems to work now bug fix use template to get rid of if statement * roi align ignore batchid < 0 * amp multicast support casting to narrowest type cpplint cpplint * add unittest * address comments * fix amp_multicast
1 parent 02ed060 commit 6186aaf

File tree

7 files changed

+368
-4
lines changed

7 files changed

+368
-4
lines changed

src/operator/contrib/bounding_box-inl.h

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,284 @@ void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs,
787787
});
788788
}
789789

790+
791+
inline bool BoxEncodeShape(const nnvm::NodeAttrs& attrs,
792+
mxnet::ShapeVector *in_attrs,
793+
mxnet::ShapeVector *out_attrs) {
794+
CHECK_EQ(in_attrs->size(), 6U);
795+
CHECK_EQ(out_attrs->size(), 2U);
796+
mxnet::TShape& sshape = (*in_attrs)[0];
797+
mxnet::TShape& mshape = (*in_attrs)[1];
798+
mxnet::TShape& ashape = (*in_attrs)[2];
799+
mxnet::TShape& rshape = (*in_attrs)[3];
800+
801+
CHECK_EQ(sshape.ndim(), 2)
802+
<< "samples shape must have dim == 2, "
803+
<< sshape.ndim() << " provided";
804+
805+
CHECK_GE(mshape.ndim(), 2)
806+
<< "matches shape must have dim == 2, "
807+
<< mshape.ndim() << " provided";
808+
809+
CHECK_GE(ashape.ndim(), 3)
810+
<< "matches shape must have dim == 3, "
811+
<< ashape.ndim() << " provided";
812+
int ldim = ashape[ashape.ndim() - 1];
813+
CHECK_EQ(ldim, 4)
814+
<< "last dimension of anchors must be 4, "
815+
<< ldim << " provided";
816+
817+
CHECK_GE(rshape.ndim(), 3)
818+
<< "refs shape must have dim == 3, "
819+
<< ashape.ndim() << " provided";
820+
ldim = rshape[rshape.ndim() - 1];
821+
CHECK_EQ(ldim, 4)
822+
<< "last dimension of anchors must be 4, "
823+
<< ldim << " provided";
824+
825+
// asign input shape
826+
SHAPE_ASSIGN_CHECK(*in_attrs, 4, mshadow::Shape1(4));
827+
SHAPE_ASSIGN_CHECK(*in_attrs, 5, mshadow::Shape1(4));
828+
829+
// assign output shape
830+
mxnet::TShape oshape = ashape;
831+
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
832+
SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
833+
return shape_is_known(oshape);
834+
}
835+
836+
struct box_encode {
837+
template<typename DType>
838+
MSHADOW_XINLINE static void Map(index_t i, DType *out_targets, DType *out_masks,
839+
const DType *samples, const DType *matches,
840+
const DType *anchors, const DType *refs,
841+
const DType *means, const DType *stds,
842+
const int m, const int n) {
843+
index_t j = i / n;
844+
index_t match = matches[i];
845+
// xmin: 0, ymin:1, xmax: 2, ymax: 3
846+
// x:0, y:1, w:2, h:3
847+
index_t ref_index = (j * m + match) * 4;
848+
DType ref_xmin = refs[ref_index + 0];
849+
DType ref_ymin = refs[ref_index + 1];
850+
DType ref_width = refs[ref_index + 2] - ref_xmin;
851+
DType ref_height = refs[ref_index + 3] - ref_ymin;
852+
DType ref_x = ref_xmin + ref_width * 0.5;
853+
DType ref_y = ref_ymin + ref_height * 0.5;
854+
index_t a_index = i * 4;
855+
DType a_xmin = anchors[a_index + 0];
856+
DType a_ymin = anchors[a_index + 1];
857+
DType a_width = anchors[a_index + 2] - a_xmin;
858+
DType a_height = anchors[a_index + 3] - a_ymin;
859+
DType a_x = a_xmin + a_width * 0.5;
860+
DType a_y = a_ymin + a_height * 0.5;
861+
DType valid = samples[i] > 0.5 ? 1.0 : 0.0;
862+
out_masks[a_index + 0] = valid;
863+
out_masks[a_index + 1] = valid;
864+
out_masks[a_index + 2] = valid;
865+
out_masks[a_index + 3] = valid;
866+
out_targets[a_index + 0] = valid > static_cast<DType>(0.5) ?
867+
((ref_x - a_x) / a_width - static_cast<DType>(means[0])) /
868+
static_cast<DType>(stds[0]) : static_cast<DType>(0.0);
869+
out_targets[a_index + 1] = valid > static_cast<DType>(0.5) ?
870+
((ref_y - a_y) / a_height - static_cast<DType>(means[1])) /
871+
static_cast<DType>(stds[1]) : static_cast<DType>(0.0);
872+
out_targets[a_index + 2] = valid > static_cast<DType>(0.5) ?
873+
(log(ref_width / a_width) - static_cast<DType>(means[2])) /
874+
static_cast<DType>(stds[2]) : static_cast<DType>(0.0);
875+
out_targets[a_index + 3] = valid > static_cast<DType>(0.5) ?
876+
(log(ref_height / a_height) - static_cast<DType>(means[3])) /
877+
static_cast<DType>(stds[3]) : static_cast<DType>(0.0);
878+
}
879+
};
880+
881+
template<typename xpu>
882+
void BoxEncodeForward(const nnvm::NodeAttrs& attrs,
883+
const OpContext& ctx,
884+
const std::vector<TBlob>& inputs,
885+
const std::vector<OpReqType>& req,
886+
const std::vector<TBlob>& outputs) {
887+
using namespace mshadow;
888+
using namespace mshadow::expr;
889+
using namespace mxnet_op;
890+
CHECK_EQ(inputs.size(), 6U);
891+
CHECK_EQ(outputs.size(), 2U);
892+
Stream<xpu> *s = ctx.get_stream<xpu>();
893+
// samples, matches, anchors, refs, means, stds
894+
mxnet::TShape anchor_shape = inputs[2].shape_;
895+
int loop_size = anchor_shape.ProdShape(0, 2);
896+
int b = anchor_shape[0];
897+
int n = anchor_shape[1];
898+
int m = inputs[3].shape_[1];
899+
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
900+
Tensor<xpu, 2, DType> samples = inputs[0]
901+
.get_with_shape<xpu, 2, DType>(Shape2(b, n), s);
902+
Tensor<xpu, 2, DType> matches = inputs[1]
903+
.get_with_shape<xpu, 2, DType>(Shape2(b, n), s);
904+
Tensor<xpu, 3, DType> anchors = inputs[2]
905+
.get_with_shape<xpu, 3, DType>(Shape3(b, n, 4), s);
906+
Tensor<xpu, 3, DType> refs = inputs[3]
907+
.get_with_shape<xpu, 3, DType>(Shape3(b, m, 4), s);
908+
Tensor<xpu, 1, DType> means = inputs[4]
909+
.get_with_shape<xpu, 1, DType>(Shape1(4), s);
910+
Tensor<xpu, 1, DType> stds = inputs[5]
911+
.get_with_shape<xpu, 1, DType>(Shape1(4), s);
912+
Tensor<xpu, 3, DType> out_targets = outputs[0]
913+
.get_with_shape<xpu, 3, DType>(Shape3(b, n, 4), s);
914+
Tensor<xpu, 3, DType> out_masks = outputs[1]
915+
.get_with_shape<xpu, 3, DType>(Shape3(b, n, 4), s);
916+
917+
Kernel<box_encode, xpu>::Launch(s, loop_size, out_targets.dptr_,
918+
out_masks.dptr_, samples.dptr_, matches.dptr_, anchors.dptr_,
919+
refs.dptr_, means.dptr_, stds.dptr_, m, n);
920+
});
921+
}
922+
923+
struct BoxDecodeParam : public dmlc::Parameter<BoxDecodeParam> {
924+
float std0;
925+
float std1;
926+
float std2;
927+
float std3;
928+
float clip;
929+
int format;
930+
DMLC_DECLARE_PARAMETER(BoxDecodeParam) {
931+
DMLC_DECLARE_FIELD(std0).set_default(1.0)
932+
.describe("value to be divided from the 1st encoded values");
933+
DMLC_DECLARE_FIELD(std1).set_default(1.0)
934+
.describe("value to be divided from the 2nd encoded values");
935+
DMLC_DECLARE_FIELD(std2).set_default(1.0)
936+
.describe("value to be divided from the 3rd encoded values");
937+
DMLC_DECLARE_FIELD(std3).set_default(1.0)
938+
.describe("value to be divided from the 4th encoded values");
939+
DMLC_DECLARE_FIELD(clip).set_default(-1.0)
940+
.describe("If larger than 0, bounding box target will be clipped to this value.");
941+
DMLC_DECLARE_FIELD(format).set_default(box_common_enum::kCenter)
942+
.add_enum("corner", box_common_enum::kCorner)
943+
.add_enum("center", box_common_enum::kCenter)
944+
.describe("The box encoding type. \n"
945+
" \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax],"
946+
" \"center\" means boxes are encodes as [x, y, width, height].");
947+
}
948+
}; // BoxDecodeParam
949+
950+
inline bool BoxDecodeShape(const nnvm::NodeAttrs& attrs,
951+
mxnet::ShapeVector *in_attrs,
952+
mxnet::ShapeVector *out_attrs) {
953+
CHECK_EQ(in_attrs->size(), 2U);
954+
CHECK_EQ(out_attrs->size(), 1U);
955+
mxnet::TShape& dshape = (*in_attrs)[0];
956+
mxnet::TShape& ashape = (*in_attrs)[1];
957+
958+
CHECK_EQ(dshape.ndim(), 3)
959+
<< "data shape must have dim == 3, "
960+
<< dshape.ndim() << " provided";
961+
int ldim = dshape[dshape.ndim() - 1];
962+
CHECK_EQ(ldim, 4)
963+
<< "last dimension of data must be 4, "
964+
<< ldim << " provided";
965+
966+
CHECK_GE(ashape.ndim(), 3)
967+
<< "anchors shape must have dim == 3, "
968+
<< ashape.ndim() << " provided";
969+
ldim = ashape[ashape.ndim() - 1];
970+
CHECK_EQ(ldim, 4)
971+
<< "last dimension of anchors must be 4, "
972+
<< ldim << " provided";
973+
974+
// assign output shape
975+
mxnet::TShape oshape = dshape;
976+
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
977+
return shape_is_known(oshape);
978+
}
979+
980+
template<int anchor_encode, bool has_clip>
981+
struct box_decode {
982+
template<typename DType>
983+
MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *x,
984+
const DType *anchors, const DType std0,
985+
const DType std1, const DType std2,
986+
const DType std3, const DType clip,
987+
const int n) {
988+
index_t index = i * 4;
989+
index_t a_index = (i % n) * 4;
990+
DType a_x = anchors[a_index + 0];
991+
DType a_y = anchors[a_index + 1];
992+
DType a_width = anchors[a_index + 2];
993+
DType a_height = anchors[a_index + 3];
994+
if (box_common_enum::kCorner == anchor_encode) {
995+
// a_x = xmin, a_y = ymin, a_width = xmax, a_height = ymax
996+
a_width = a_width - a_x;
997+
a_height = a_height - a_y;
998+
a_x = a_x + a_width * 0.5;
999+
a_y = a_y + a_height * 0.5;
1000+
}
1001+
DType ox = x[index + 0] * std0 * a_width + a_x;
1002+
DType oy = x[index + 1] * std1 * a_height + a_y;
1003+
DType dw = x[index + 2] * std2;
1004+
DType dh = x[index + 3] * std3;
1005+
if (has_clip) {
1006+
dw = dw < clip ? dw : clip;
1007+
dh = dh < clip ? dh : clip;
1008+
}
1009+
dw = exp(dw);
1010+
dh = exp(dh);
1011+
DType ow = dw * a_width * 0.5;
1012+
DType oh = dh * a_height * 0.5;
1013+
out[index + 0] = ox - ow;
1014+
out[index + 1] = oy - oh;
1015+
out[index + 2] = ox + ow;
1016+
out[index + 3] = oy + oh;
1017+
}
1018+
};
1019+
1020+
template<typename xpu>
1021+
void BoxDecodeForward(const nnvm::NodeAttrs& attrs,
1022+
const OpContext& ctx,
1023+
const std::vector<TBlob>& inputs,
1024+
const std::vector<OpReqType>& req,
1025+
const std::vector<TBlob>& outputs) {
1026+
using namespace mshadow;
1027+
using namespace mshadow::expr;
1028+
using namespace mxnet_op;
1029+
CHECK_EQ(inputs.size(), 2U);
1030+
CHECK_EQ(outputs.size(), 1U);
1031+
Stream<xpu> *s = ctx.get_stream<xpu>();
1032+
mxnet::TShape x_shape = inputs[0].shape_;
1033+
int b = x_shape[0];
1034+
int n = x_shape[1];
1035+
int loop_size = b * n;
1036+
const BoxDecodeParam& param = nnvm::get<BoxDecodeParam>(attrs.parsed);
1037+
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
1038+
Tensor<xpu, 3, DType> data = inputs[0]
1039+
.get_with_shape<xpu, 3, DType>(Shape3(b, n, 4), s);
1040+
Tensor<xpu, 3, DType> anchors = inputs[1]
1041+
.get_with_shape<xpu, 3, DType>(Shape3(1, n, 4), s);
1042+
Tensor<xpu, 3, DType> out = outputs[0]
1043+
.get_with_shape<xpu, 3, DType>(Shape3(b, n, 4), s);
1044+
if (box_common_enum::kCorner == param.format && param.clip > 0.0) {
1045+
Kernel<box_decode<box_common_enum::kCorner, true>, xpu>::Launch(s, loop_size,
1046+
out.dptr_, data.dptr_, anchors.dptr_, static_cast<DType>(param.std0),
1047+
static_cast<DType>(param.std1), static_cast<DType>(param.std2),
1048+
static_cast<DType>(param.std3), static_cast<DType>(param.clip), n);
1049+
} else if (box_common_enum::kCenter == param.format && param.clip > 0.0) {
1050+
Kernel<box_decode<box_common_enum::kCenter, true>, xpu>::Launch(s, loop_size,
1051+
out.dptr_, data.dptr_, anchors.dptr_, static_cast<DType>(param.std0),
1052+
static_cast<DType>(param.std1), static_cast<DType>(param.std2),
1053+
static_cast<DType>(param.std3), static_cast<DType>(param.clip), n);
1054+
} else if (box_common_enum::kCorner == param.format && param.clip <= 0.0) {
1055+
Kernel<box_decode<box_common_enum::kCorner, false>, xpu>::Launch(s, loop_size,
1056+
out.dptr_, data.dptr_, anchors.dptr_, static_cast<DType>(param.std0),
1057+
static_cast<DType>(param.std1), static_cast<DType>(param.std2),
1058+
static_cast<DType>(param.std3), static_cast<DType>(param.clip), n);
1059+
} else {
1060+
Kernel<box_decode<box_common_enum::kCenter, false>, xpu>::Launch(s, loop_size,
1061+
out.dptr_, data.dptr_, anchors.dptr_, static_cast<DType>(param.std0),
1062+
static_cast<DType>(param.std1), static_cast<DType>(param.std2),
1063+
static_cast<DType>(param.std3), static_cast<DType>(param.clip), n);
1064+
}
1065+
});
1066+
}
1067+
7901068
} // namespace op
7911069
} // namespace mxnet
7921070

src/operator/contrib/bounding_box.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace op {
3232
DMLC_REGISTER_PARAMETER(BoxNMSParam);
3333
DMLC_REGISTER_PARAMETER(BoxOverlapParam);
3434
DMLC_REGISTER_PARAMETER(BipartiteMatchingParam);
35+
DMLC_REGISTER_PARAMETER(BoxDecodeParam);
3536

3637
NNVM_REGISTER_OP(_contrib_box_nms)
3738
.add_alias("_contrib_box_non_maximum_suppression")
@@ -201,5 +202,47 @@ NNVM_REGISTER_OP(_backward_contrib_bipartite_matching)
201202
.set_attr<FCompute>("FCompute<cpu>", BipartiteMatchingBackward<cpu>)
202203
.add_arguments(BipartiteMatchingParam::__FIELDS__());
203204

205+
NNVM_REGISTER_OP(_contrib_box_encode)
206+
.describe(R"doc(Encode bounding boxes training target with normalized center offsets.
207+
Input bounding boxes are using corner type: `x_{min}, y_{min}, x_{max}, y_{max}`.) array
208+
)doc" ADD_FILELINE)
209+
.set_num_inputs(6)
210+
.set_num_outputs(2)
211+
.set_attr<nnvm::FListInputNames>("FListInputNames",
212+
[](const NodeAttrs& attrs) {
213+
return std::vector<std::string>{"samples", "matches", "anchors", "refs", "means", "stds"};
214+
})
215+
.set_attr<mxnet::FInferShape>("FInferShape", BoxEncodeShape)
216+
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<6, 2>)
217+
.set_attr<FCompute>("FCompute<cpu>", BoxEncodeForward<cpu>)
218+
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
219+
.add_argument("samples", "NDArray-or-Symbol", "(B, N) value +1 (positive), -1 (negative), "
220+
"0 (ignore)")
221+
.add_argument("matches", "NDArray-or-Symbol", "(B, N) value range [0, M)")
222+
.add_argument("anchors", "NDArray-or-Symbol", "(B, N, 4) encoded in corner")
223+
.add_argument("refs", "NDArray-or-Symbol", "(B, M, 4) encoded in corner")
224+
.add_argument("means", "NDArray-or-Symbol", "(4,) Mean value to be subtracted from encoded values")
225+
.add_argument("stds", "NDArray-or-Symbol", "(4,) Std value to be divided from encoded values");
226+
227+
NNVM_REGISTER_OP(_contrib_box_decode)
228+
.describe(R"doc(Decode bounding boxes training target with normalized center offsets.
229+
Input bounding boxes are using corner type: `x_{min}, y_{min}, x_{max}, y_{max}`
230+
or center type: `x, y, width, height.) array
231+
)doc" ADD_FILELINE)
232+
.set_num_inputs(2)
233+
.set_num_outputs(1)
234+
.set_attr_parser(ParamParser<BoxDecodeParam>)
235+
.set_attr<nnvm::FListInputNames>("FListInputNames",
236+
[](const NodeAttrs& attrs) {
237+
return std::vector<std::string>{"data", "anchors"};
238+
})
239+
.set_attr<mxnet::FInferShape>("FInferShape", BoxDecodeShape)
240+
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
241+
.set_attr<FCompute>("FCompute<cpu>", BoxDecodeForward<cpu>)
242+
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
243+
.add_argument("data", "NDArray-or-Symbol", "(B, N, 4) predicted bbox offset")
244+
.add_argument("anchors", "NDArray-or-Symbol", "(1, N, 4) encoded in corner or center")
245+
.add_arguments(BoxDecodeParam::__FIELDS__());
246+
204247
} // namespace op
205248
} // namespace mxnet

src/operator/contrib/bounding_box.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,12 @@ NNVM_REGISTER_OP(_contrib_bipartite_matching)
4747

4848
NNVM_REGISTER_OP(_backward_contrib_bipartite_matching)
4949
.set_attr<FCompute>("FCompute<gpu>", BipartiteMatchingBackward<gpu>);
50+
51+
NNVM_REGISTER_OP(_contrib_box_encode)
52+
.set_attr<FCompute>("FCompute<gpu>", BoxEncodeForward<gpu>);
53+
54+
NNVM_REGISTER_OP(_contrib_box_decode)
55+
.set_attr<FCompute>("FCompute<gpu>", BoxDecodeForward<gpu>);
56+
5057
} // namespace op
5158
} // namespace mxnet

src/operator/contrib/roi_align.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
167167
int roi_batch_ind = 0;
168168
if (roi_cols == 5) {
169169
roi_batch_ind = offset_bottom_rois[0];
170+
if (roi_batch_ind < 0) {
171+
top_data[n] = 0;
172+
continue;
173+
}
170174
offset_bottom_rois++;
171175
}
172176

@@ -340,6 +344,7 @@ void ROIAlignBackward(
340344
int roi_batch_ind = 0;
341345
if (rois_cols == 5) {
342346
roi_batch_ind = offset_bottom_rois[0];
347+
if (roi_batch_ind < 0) continue;
343348
offset_bottom_rois++;
344349
}
345350

@@ -520,7 +525,8 @@ NNVM_REGISTER_OP(_contrib_ROIAlign)
520525
.describe(R"code(
521526
This operator takes a 4D feature map as an input array and region proposals as `rois`,
522527
then align the feature map over sub-regions of input and produces a fixed-sized output array.
523-
This operator is typically used in Faster R-CNN & Mask R-CNN networks.
528+
This operator is typically used in Faster R-CNN & Mask R-CNN networks. If roi batchid is less
529+
than 0, it will be ignored, and the corresponding output will be set to 0.
524530
525531
Different from ROI pooling, ROI Align removes the harsh quantization, properly aligning
526532
the extracted features with the input. RoIAlign computes the value of each sampling point
@@ -594,7 +600,8 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
594600
return MakeGradNode("_backward_ROIAlign", n, heads, n->attrs.dict);
595601
})
596602
.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps")
597-
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array")
603+
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array, "
604+
"if batchid is less than 0, it will be ignored.")
598605
.add_arguments(ROIAlignParam::__FIELDS__());
599606

600607

0 commit comments

Comments
 (0)