
Commit ec5bcec

remove negative samples during RCNN box and mask heads training (#981)
* use new ops in apache/mxnet#16215
* sampler wrap around for last part
* reduce mask head num samples
* rm reshape, fix bugs, rm redundant comment
* revert rpn_channel, revert some change, fix typos
* fix docs
* fix tutorial
* fix log
* fix learning rate
1 parent e1680e3 commit ec5bcec

17 files changed: +284 lines, -160 lines

README.md
Lines changed: 4 additions & 4 deletions

@@ -53,8 +53,8 @@ The following commands install the stable version of GluonCV and MXNet:
 ```bash
 pip install gluoncv --upgrade
 pip install mxnet-mkl --upgrade
-# if cuda 9.2 is installed
-pip install mxnet-cu92mkl --upgrade
+# if cuda 10.1 is installed
+pip install mxnet-cu101mkl --upgrade
 ```

 **The latest stable version of GluonCV is 0.4 and depends on mxnet >= 1.4.0**
@@ -66,8 +66,8 @@ You may get access to latest features and bug fixes with the following commands
 ```bash
 pip install gluoncv --pre --upgrade
 pip install mxnet-mkl --pre --upgrade
-# if cuda 9.2 is installed
-pip install mxnet-cu92mkl --pre --upgrade
+# if cuda 10.1 is installed
+pip install mxnet-cu101mkl --pre --upgrade
 ```

 There are multiple versions of MXNet pre-built package available. Please refer to [mxnet packages](https://gluon-crash-course.mxnet.io/mxnet_packages.html) if you need more details about MXNet versions.

docs/tutorials/detection/train_faster_rcnn_voc.py
Lines changed: 7 additions & 11 deletions

@@ -194,7 +194,9 @@
 with autograd.train_mode():
     # this time we need ground-truth to generate high quality roi proposals during training
     gt_box = mx.nd.zeros(shape=(1, 1, 4))
-    cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(x, gt_box)
+    gt_label = mx.nd.zeros(shape=(1, 1, 1))
+    cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
+        box_targets, box_masks, _ = net(x, gt_box, gt_label)

 ##############################################################################
 # In training mode, Faster-RCNN returns a lot of intermediate values, which we require to train in an end-to-end flavor,
@@ -272,11 +274,8 @@
     gt_label = label[:, :, 4:5]
     gt_box = label[:, :, :4]
     # network forward
-    cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(
-        data.expand_dims(0), gt_box)
-    # generate targets for rcnn
-    cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
-                                                               gt_label, gt_box)
+    cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
+        box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

     print('data:', data.shape)
     # box and class labels
@@ -302,11 +301,8 @@
     gt_label = label[:, :, 4:5]
     gt_box = label[:, :, :4]
     # network forward
-    cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(
-        data.expand_dims(0), gt_box)
-    # generate targets for rcnn
-    cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
-                                                               gt_label, gt_box)
+    cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
+        box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

     # losses of rpn
     rpn_score = rpn_score.squeeze(axis=-1)
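With this change the tutorial no longer calls `net.target_generator` separately: passing `gt_label` alongside `gt_box` makes the training-mode forward return the RCNN classification and box-regression targets (plus the sampled-proposal indices) directly. The sketch below recaps that contract; the model name follows the tutorial, while the zero-filled dummy tensors and `net.initialize()` setup are assumptions for illustration, not part of the diff.

```python
# Minimal sketch of the new training-mode forward contract, assuming the VOC
# Faster R-CNN model used in this tutorial and zero-filled dummy inputs.
import mxnet as mx
from mxnet import autograd
from gluoncv import model_zoo

net = model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', pretrained_base=False)
net.initialize()

x = mx.nd.zeros(shape=(1, 3, 600, 800))   # dummy image batch (B, 3, H, W)
gt_box = mx.nd.zeros(shape=(1, 1, 4))     # (B, N, 4) ground-truth boxes
gt_label = mx.nd.zeros(shape=(1, 1, 1))   # (B, N, 1) ground-truth class ids

with autograd.train_mode():
    # targets are now produced inside the forward pass; the last output is the
    # per-image indices of the proposals kept for box-head training
    (cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors,
     cls_targets, box_targets, box_masks, indices) = net(x, gt_box, gt_label)

print(cls_pred.shape, box_pred.shape, cls_targets.shape, box_targets.shape)
```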

docs/tutorials/instance/train_mask_rcnn_coco.py
Lines changed: 32 additions & 14 deletions

@@ -214,8 +214,9 @@
 with autograd.train_mode():
     # this time we need ground-truth to generate high quality roi proposals during training
     gt_box = mx.nd.zeros(shape=(1, 1, 4))
-    cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(x,
-                                                                                               gt_box)
+    gt_label = mx.nd.zeros(shape=(1, 1, 1))
+    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
+        cls_targets, box_targets, box_masks, indices = net(x, gt_box, gt_label)

 ##########################################################
 # Training losses
@@ -260,14 +261,23 @@
     gt_label = label[:, :, 4:5]
     gt_box = label[:, :, :4]
     # network forward
-    cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = \
-        net(data.expand_dims(0), gt_box)
-    # generate targets for rcnn
-    cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
-                                                                gt_label, gt_box)
+    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
+        cls_targets, box_targets, box_masks, indices = \
+        net(data.expand_dims(0), gt_box, gt_label)
+
     # generate targets for mask head
+    roi = mx.nd.concat(
+        *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1, 4))
+    m_cls_targets = mx.nd.concat(
+        *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1))
+    matches = mx.nd.concat(
+        *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1))
     mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
-                                               cls_targets)
+                                               m_cls_targets)
+
     print('data:', data.shape)
     # box and class labels
     print('box:', gt_box.shape)
@@ -299,14 +309,22 @@
     gt_label = label[:, :, 4:5]
     gt_box = label[:, :, :4]
     # network forward
-    cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = \
-        net(data.expand_dims(0), gt_box)
-    # generate targets for rcnn
-    cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
-                                                                gt_label, gt_box)
+    cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors, \
+        cls_targets, box_targets, box_masks, indices = \
+        net(data.expand_dims(0), gt_box, gt_label)
+
     # generate targets for mask head
+    roi = mx.nd.concat(
+        *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1, 4))
+    m_cls_targets = mx.nd.concat(
+        *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1))
+    matches = mx.nd.concat(
+        *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
+        .reshape((indices.shape[0], -1))
     mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
-                                               cls_targets)
+                                               m_cls_targets)

     # losses of rpn
     rpn_score = rpn_score.squeeze(axis=-1)
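The new `indices` output is what lets the mask head skip negative samples: for each image, only the rows selected by `indices` are kept from `roi`, `cls_targets` and `matches` before `net.mask_target` is called. The toy sketch below isolates that take/concat/reshape pattern; `gather_by_indices` and all shapes are hypothetical names and values for illustration only.

```python
# Standalone sketch of the per-image gather used above to keep only the
# proposals selected by `indices` before computing mask targets.
import mxnet as mx

def gather_by_indices(x, indices, trailing=None):
    """Gather x[i, indices[i]] for every image i in the batch."""
    batch = indices.shape[0]
    out = mx.nd.concat(
        *[mx.nd.take(x[i], indices[i]) for i in range(batch)], dim=0)
    shape = (batch, -1) if trailing is None else (batch, -1, trailing)
    return out.reshape(shape)

roi = mx.nd.random.uniform(shape=(2, 8, 4))                     # (B, num_sample, 4) sampled rois
cls_targets = mx.nd.random.randint(0, 5, shape=(2, 8)).astype('float32')
indices = mx.nd.array([[0, 2, 5], [1, 3, 7]])                   # (B, num_pos) rows kept per image

pos_roi = gather_by_indices(roi, indices, trailing=4)           # (2, 3, 4)
pos_cls = gather_by_indices(cls_targets, indices)               # (2, 3)
print(pos_roi.shape, pos_cls.shape)
```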

gluoncv/model_zoo/faster_rcnn/faster_rcnn.py
Lines changed: 41 additions & 19 deletions

@@ -198,7 +198,8 @@ def __init__(self, features, top_features, classes, box_features=None,
         self._batch_size = per_device_batch_size
         self._num_sample = num_sample
         self._rpn_test_post_nms = rpn_test_post_nms
-        self._target_generator = RCNNTargetGenerator(self.num_class)
+        self._target_generator = RCNNTargetGenerator(self.num_class, int(num_sample * pos_ratio),
+                                                     self._batch_size)
         self._additional_output = additional_output
         with self.name_scope():
             self.rpn = RPN(
@@ -207,7 +208,7 @@ def __init__(self, features, top_features, classes, box_features=None,
                 clip=clip, nms_thresh=rpn_nms_thresh, train_pre_nms=rpn_train_pre_nms,
                 train_post_nms=rpn_train_post_nms, test_pre_nms=rpn_test_pre_nms,
                 test_post_nms=rpn_test_post_nms, min_size=rpn_min_size,
-                multi_level=self.num_stages > 1)
+                multi_level=self.num_stages > 1, per_level_nms=False)
             self.sampler = RCNNTargetSampler(num_image=self._batch_size,
                                              num_proposal=rpn_train_post_nms, num_sample=num_sample,
                                              pos_iou_thresh=pos_iou_thresh, pos_ratio=pos_ratio,
@@ -252,7 +253,8 @@ def reset_class(self, classes, reuse_weights=None):

         """
         super(FasterRCNN, self).reset_class(classes, reuse_weights)
-        self._target_generator = RCNNTargetGenerator(self.num_class)
+        self._target_generator = RCNNTargetGenerator(self.num_class, self.sampler._max_pos,
+                                                     self._batch_size)

     def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode='align',
                            roi_canonical_scale=224.0, eps=1e-6):
@@ -292,16 +294,25 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
         # rpn_rois = F.take(rpn_rois, roi_level_sorted_args, axis=0)
         pooled_roi_feats = []
         for i, l in enumerate(range(self._min_stage, max_stage + 1)):
-            # Pool features with all rois first, and then set invalid pooled features to zero,
-            # at last ele-wise add together to aggregate all features.
             if roi_mode == 'pool':
+                # Pool features with all rois first, and then set invalid pooled features to zero,
+                # at last ele-wise add together to aggregate all features.
                 pooled_feature = F.ROIPooling(features[i], rpn_rois, roi_size, 1. / strides[i])
+                pooled_feature = F.where(roi_level == l, pooled_feature,
+                                         F.zeros_like(pooled_feature))
             elif roi_mode == 'align':
-                pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size,
-                                                    1. / strides[i], sample_ratio=2)
+                if 'box_encode' in F.contrib.__dict__ and 'box_decode' in F.contrib.__dict__:
+                    # TODO(jerryzcn): clean this up for once mx 1.6 is released.
+                    masked_rpn_rois = F.where(roi_level == l, rpn_rois, F.ones_like(rpn_rois) * -1.)
+                    pooled_feature = F.contrib.ROIAlign(features[i], masked_rpn_rois, roi_size,
+                                                        1. / strides[i], sample_ratio=2)
+                else:
+                    pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size,
+                                                        1. / strides[i], sample_ratio=2)
+                    pooled_feature = F.where(roi_level == l, pooled_feature,
+                                             F.zeros_like(pooled_feature))
             else:
                 raise ValueError("Invalid roi mode: {}".format(roi_mode))
-            pooled_feature = F.where(roi_level == l, pooled_feature, F.zeros_like(pooled_feature))
             pooled_roi_feats.append(pooled_feature)
         # Ele-wise add to aggregate all pooled features
         pooled_roi_feats = F.ElementWiseSum(*pooled_roi_feats)
@@ -312,7 +323,7 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
         return pooled_roi_feats

     # pylint: disable=arguments-differ
-    def hybrid_forward(self, F, x, gt_box=None):
+    def hybrid_forward(self, F, x, gt_box=None, gt_label=None):
         """Forward Faster-RCNN network.

         The behavior during training and inference is different.
@@ -322,7 +333,9 @@ def hybrid_forward(self, F, x, gt_box=None):
         x : mxnet.nd.NDArray or mxnet.symbol
             The network input tensor.
         gt_box : type, only required during training
-            The ground-truth bbox tensor with shape (1, N, 4).
+            The ground-truth bbox tensor with shape (B, N, 4).
+        gt_label : type, only required during training
+            The ground-truth label tensor with shape (B, 1, 4).

         Returns
         -------
@@ -385,20 +398,29 @@ def _split(x, axis, num_outputs, squeeze_axis):
         else:
             box_feat = self.box_features(top_feat)
         cls_pred = self.class_predictor(box_feat)
-        box_pred = self.box_predictor(box_feat)
         # cls_pred (B * N, C) -> (B, N, C)
         cls_pred = cls_pred.reshape((batch_size, num_roi, self.num_class + 1))
-        # box_pred (B * N, C * 4) -> (B, N, C, 4)
-        box_pred = box_pred.reshape((batch_size, num_roi, self.num_class, 4))

         # no need to convert bounding boxes in training, just return
         if autograd.is_training():
+            cls_targets, box_targets, box_masks, indices = \
+                self._target_generator(rpn_box, samples, matches, gt_label, gt_box)
+            box_feat = F.reshape(box_feat.expand_dims(0), (batch_size, -1, 0))
+            box_pred = self.box_predictor(F.concat(
+                *[F.take(F.slice_axis(box_feat, axis=0, begin=i, end=i + 1).squeeze(),
+                         F.slice_axis(indices, axis=0, begin=i, end=i + 1).squeeze())
+                  for i in range(batch_size)], dim=0))
+            # box_pred (B * N, C * 4) -> (B, N, C, 4)
+            box_pred = box_pred.reshape((batch_size, -1, self.num_class, 4))
             if self._additional_output:
-                return (cls_pred, box_pred, rpn_box, samples, matches,
-                        raw_rpn_score, raw_rpn_box, anchors, top_feat)
-            return (cls_pred, box_pred, rpn_box, samples, matches,
-                    raw_rpn_score, raw_rpn_box, anchors)
+                return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
+                        anchors, cls_targets, box_targets, box_masks, top_feat, indices)
+            return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
+                    anchors, cls_targets, box_targets, box_masks, indices)

+        box_pred = self.box_predictor(box_feat)
+        # box_pred (B * N, C * 4) -> (B, N, C, 4)
+        box_pred = box_pred.reshape((batch_size, num_roi, self.num_class, 4))
         # cls_ids (B, N, C), scores (B, N, C)
         cls_ids, scores = self.cls_decoder(F.softmax(cls_pred, axis=-1))
         # cls_ids, scores (B, N, C) -> (B, C, N) -> (B, C, N, 1)
@@ -419,7 +441,7 @@ def _split(x, axis, num_outputs, squeeze_axis):
         results = []
         for rpn_box, cls_id, score, box_pred in zip(rpn_boxes, cls_ids, scores, box_preds):
             # box_pred (C, N, 4) rpn_box (1, N, 4) -> bbox (C, N, 4)
-            bbox = self.box_decoder(box_pred, self.box_to_center(rpn_box))
+            bbox = self.box_decoder(box_pred, rpn_box)
             # res (C, N, 6)
             res = F.concat(*[cls_id, score, bbox], dim=-1)
             if self.force_nms:
@@ -683,7 +705,7 @@ def faster_rcnn_fpn_bn_resnet50_v1b_coco(pretrained=False, pretrained_base=True,
     top_features = None
     # 1 Conv 1 FC layer before RCNN cls and reg
     box_features = nn.HybridSequential()
-    box_features.add(nn.Conv2D(256, 3, padding=1),
+    box_features.add(nn.Conv2D(256, 3, padding=1, use_bias=False),
                      SyncBatchNorm(**gluon_norm_kwargs),
                      nn.Activation('relu'),
                      nn.Dense(1024, weight_initializer=mx.init.Normal(0.01)),
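Taken together, these changes move RCNN target generation into the training-time forward pass and apply the box predictor only to the proposals picked by the sampler (`indices`); since box-regression loss is only counted on positive proposals anyway, predicting boxes for negatives was wasted work. The toy sketch below mirrors that gather-then-predict step with plain NDArrays; the dense layer, shapes and index values are placeholders, not the real box head.

```python
# Toy illustration of restricting box predictions to the sampled proposals,
# using imperative reshape instead of the symbolic F.slice_axis/F.take form above.
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

batch_size, num_roi, num_pos, feat_dim, num_class = 2, 16, 4, 32, 3

box_predictor = nn.Dense(num_class * 4, weight_initializer=mx.init.Normal(0.001))
box_predictor.initialize()

box_feat = nd.random.uniform(shape=(batch_size * num_roi, feat_dim))  # (B*N, D) per-roi features
indices = nd.array([[0, 3, 5, 9], [1, 2, 8, 15]])                     # (B, num_pos) kept positions

# (B*N, D) -> (B, N, D), then keep only the sampled rows of every image
box_feat = box_feat.reshape((batch_size, num_roi, feat_dim))
sampled = nd.concat(
    *[nd.take(box_feat[i], indices[i]) for i in range(batch_size)], dim=0)  # (B*num_pos, D)

box_pred = box_predictor(sampled).reshape((batch_size, -1, num_class, 4))   # (B, num_pos, C, 4)
print(box_pred.shape)  # (2, 4, 3, 4)
```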

gluoncv/model_zoo/faster_rcnn/rcnn_target.py
Lines changed: 22 additions & 11 deletions

@@ -45,8 +45,8 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):

         Parameters
         ----------
-        rois: (B, self._num_input, 4) encoded in (x1, y1, x2, y2).
-        scores: (B, self._num_input, 1), value range [0, 1] with ignore value -1.
+        rois: (B, self._num_proposal, 4) encoded in (x1, y1, x2, y2).
+        scores: (B, self._num_proposal, 1), value range [0, 1] with ignore value -1.
         gt_boxes: (B, M, 4) encoded in (x1, y1, x2, y2), invalid box should have area of 0.

         Returns
@@ -65,7 +65,7 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
             roi = F.squeeze(F.slice_axis(rois, axis=0, begin=i, end=i + 1), axis=0)
             score = F.squeeze(F.slice_axis(scores, axis=0, begin=i, end=i + 1), axis=0)
             gt_box = F.squeeze(F.slice_axis(gt_boxes, axis=0, begin=i, end=i + 1), axis=0)
-            gt_score = F.ones_like(F.sum(gt_box, axis=-1, keepdims=True))
+            gt_score = F.sign(F.sum(gt_box, axis=-1, keepdims=True) + 1)

             # concat rpn roi with ground truth. mix gt with generated boxes.
             all_roi = F.concat(roi, gt_box, dim=0)
@@ -126,9 +126,13 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
             samples = F.concat(topk_samples, bottomk_samples, dim=0)
             matches = F.concat(topk_matches, bottomk_matches, dim=0)

-            new_rois.append(all_roi.take(indices))
-            new_samples.append(samples)
-            new_matches.append(matches)
+            sampled_rois = all_roi.take(indices)
+            x1, y1, x2, y2 = F.split(sampled_rois, axis=-1, num_outputs=4, squeeze_axis=True)
+            rois_area = (x2 - x1) * (y2 - y1)
+            ind = F.argsort(rois_area)
+            new_rois.append(sampled_rois.take(ind))
+            new_samples.append(samples.take(ind))
+            new_matches.append(matches.take(ind))
         # stack all samples together
         new_rois = F.stack(*new_rois, axis=0)
         new_samples = F.stack(*new_samples, axis=0)
@@ -143,18 +147,24 @@ class RCNNTargetGenerator(gluon.HybridBlock):
     ----------
     num_class : int
         Number of total number of positive classes.
+    max_pos : int, default is 128
+        Upper bound of Number of positive samples.
+    per_device_batch_size : int, default is 1
+        Per device batch size
     means : iterable of float, default is (0., 0., 0., 0.)
         Mean values to be subtracted from regression targets.
     stds : iterable of float, default is (.1, .1, .2, .2)
         Standard deviations to be divided from regression targets.

     """

-    def __init__(self, num_class, means=(0., 0., 0., 0.), stds=(.1, .1, .2, .2)):
+    def __init__(self, num_class, max_pos=128, per_device_batch_size=1, means=(0., 0., 0., 0.),
+                 stds=(.1, .1, .2, .2)):
         super(RCNNTargetGenerator, self).__init__()
         self._cls_encoder = MultiClassEncoder()
         self._box_encoder = NormalizedPerClassBoxCenterEncoder(
-            num_class=num_class, means=means, stds=stds)
+            num_class=num_class, max_pos=max_pos, per_device_batch_size=per_device_batch_size,
+            means=means, stds=stds)

     # pylint: disable=arguments-differ, unused-argument
     def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box):
@@ -179,6 +189,7 @@ def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box):
         # cls_target (B, N)
         cls_target = self._cls_encoder(samples, matches, gt_label)
         # box_target, box_weight (C, B, N, 4)
-        box_target, box_mask = self._box_encoder(
-            samples, matches, roi, gt_label, gt_box)
-        return cls_target, box_target, box_mask
+        box_target, box_mask, indices = self._box_encoder(samples, matches, roi, gt_label,
+                                                          gt_box)
+
+        return cls_target, box_target, box_mask, indices
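Two details above are easy to miss: the sampler now sorts each image's sampled RoIs by box area (reordering `samples` and `matches` with the same permutation), and the ground-truth score is derived with `F.sign` instead of `F.ones_like`. The snippet below checks the effect of the latter, assuming padded ground-truth slots are filled with -1 (the convention used by the tutorials' batchify pipeline); the box values are toy numbers.

```python
# Quick numeric check of the gt_score change: the old F.ones_like scored every
# ground-truth row 1, while F.sign(sum + 1) gives 1 for real boxes and -1
# (the sampler's ignore value) for rows padded with -1.
import mxnet as mx
from mxnet import nd

gt_box = nd.array([[10., 20., 50., 80.],    # a real box
                   [-1., -1., -1., -1.]])   # a padded (invalid) slot

old_score = nd.ones_like(nd.sum(gt_box, axis=-1, keepdims=True))
new_score = nd.sign(nd.sum(gt_box, axis=-1, keepdims=True) + 1)

print(old_score.asnumpy().ravel())  # [ 1.  1.]  -> padded slot treated as a valid gt
print(new_score.asnumpy().ravel())  # [ 1. -1.]  -> padded slot now gets the ignore score
```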
