Skip to content

Commit 70383d4

Browse files
authored
Merge pull request open-mmlab#327 from hellock/bboxes-ignore
Allow gt_bboxes_ignore for RPN and single-stage detectors
2 parents 85c30cc + f1d06cd commit 70383d4

File tree

8 files changed

+64
-18
lines changed

8 files changed

+64
-18
lines changed

mmdet/core/anchor/anchor_target.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def anchor_target(anchor_list,
1111
target_means,
1212
target_stds,
1313
cfg,
14+
gt_bboxes_ignore_list=None,
1415
gt_labels_list=None,
1516
label_channels=1,
1617
sampling=True,
@@ -41,6 +42,8 @@ def anchor_target(anchor_list,
4142
valid_flag_list[i] = torch.cat(valid_flag_list[i])
4243

4344
# compute targets for each image
45+
if gt_bboxes_ignore_list is None:
46+
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
4447
if gt_labels_list is None:
4548
gt_labels_list = [None for _ in range(num_imgs)]
4649
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
@@ -49,6 +52,7 @@ def anchor_target(anchor_list,
4952
anchor_list,
5053
valid_flag_list,
5154
gt_bboxes_list,
55+
gt_bboxes_ignore_list,
5256
gt_labels_list,
5357
img_metas,
5458
target_means=target_means,
@@ -90,6 +94,7 @@ def images_to_levels(target, num_level_anchors):
9094
def anchor_target_single(flat_anchors,
9195
valid_flags,
9296
gt_bboxes,
97+
gt_bboxes_ignore,
9398
gt_labels,
9499
img_meta,
95100
target_means,
@@ -108,11 +113,11 @@ def anchor_target_single(flat_anchors,
108113

109114
if sampling:
110115
assign_result, sampling_result = assign_and_sample(
111-
anchors, gt_bboxes, None, None, cfg)
116+
anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
112117
else:
113118
bbox_assigner = build_assigner(cfg.assigner)
114-
assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
115-
gt_labels)
119+
assign_result = bbox_assigner.assign(anchors, gt_bboxes,
120+
gt_bboxes_ignore, gt_labels)
116121
bbox_sampler = PseudoSampler()
117122
sampling_result = bbox_sampler.sample(assign_result, anchors,
118123
gt_bboxes)

mmdet/models/anchor_heads/anchor_head.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,14 @@ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
169169
avg_factor=num_total_samples)
170170
return loss_cls, loss_reg
171171

172-
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
173-
cfg):
172+
def loss(self,
173+
cls_scores,
174+
bbox_preds,
175+
gt_bboxes,
176+
gt_labels,
177+
img_metas,
178+
cfg,
179+
gt_bboxes_ignore=None):
174180
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
175181
assert len(featmap_sizes) == len(self.anchor_generators)
176182

@@ -186,6 +192,7 @@ def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
186192
self.target_means,
187193
self.target_stds,
188194
cfg,
195+
gt_bboxes_ignore_list=gt_bboxes_ignore,
189196
gt_labels_list=gt_labels,
190197
label_channels=label_channels,
191198
sampling=sampling)

mmdet/models/anchor_heads/rpn_head.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,21 @@ def forward_single(self, x):
3434
rpn_bbox_pred = self.rpn_reg(x)
3535
return rpn_cls_score, rpn_bbox_pred
3636

37-
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
38-
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
39-
None, img_metas, cfg)
37+
def loss(self,
38+
cls_scores,
39+
bbox_preds,
40+
gt_bboxes,
41+
img_metas,
42+
cfg,
43+
gt_bboxes_ignore=None):
44+
losses = super(RPNHead, self).loss(
45+
cls_scores,
46+
bbox_preds,
47+
gt_bboxes,
48+
None,
49+
img_metas,
50+
cfg,
51+
gt_bboxes_ignore=gt_bboxes_ignore)
4052
return dict(
4153
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
4254

mmdet/models/anchor_heads/ssd_head.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,14 @@ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
130130
avg_factor=num_total_samples)
131131
return loss_cls[None], loss_reg
132132

133-
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
134-
cfg):
133+
def loss(self,
134+
cls_scores,
135+
bbox_preds,
136+
gt_bboxes,
137+
gt_labels,
138+
img_metas,
139+
cfg,
140+
gt_bboxes_ignore=None):
135141
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
136142
assert len(featmap_sizes) == len(self.anchor_generators)
137143

@@ -145,6 +151,7 @@ def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
145151
self.target_means,
146152
self.target_stds,
147153
cfg,
154+
gt_bboxes_ignore_list=gt_bboxes_ignore,
148155
gt_labels_list=gt_labels,
149156
label_channels=1,
150157
sampling=False,

mmdet/models/detectors/cascade_rcnn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def forward_train(self,
109109
img,
110110
img_meta,
111111
gt_bboxes,
112-
gt_bboxes_ignore,
113112
gt_labels,
113+
gt_bboxes_ignore=None,
114114
gt_masks=None,
115115
proposals=None):
116116
x = self.extract_feat(img)
@@ -121,7 +121,8 @@ def forward_train(self,
121121
rpn_outs = self.rpn_head(x)
122122
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
123123
self.train_cfg.rpn)
124-
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
124+
rpn_losses = self.rpn_head.loss(
125+
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
125126
losses.update(rpn_losses)
126127

127128
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)

mmdet/models/detectors/rpn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,20 @@ def extract_feat(self, img):
3838
x = self.neck(x)
3939
return x
4040

41-
def forward_train(self, img, img_meta, gt_bboxes=None):
41+
def forward_train(self,
42+
img,
43+
img_meta,
44+
gt_bboxes=None,
45+
gt_bboxes_ignore=None):
4246
if self.train_cfg.rpn.get('debug', False):
4347
self.rpn_head.debug_imgs = tensor2imgs(img)
4448

4549
x = self.extract_feat(img)
4650
rpn_outs = self.rpn_head(x)
4751

4852
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
49-
losses = self.rpn_head.loss(*rpn_loss_inputs)
53+
losses = self.rpn_head.loss(
54+
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
5055
return losses
5156

5257
def simple_test(self, img, img_meta, rescale=False):

mmdet/models/detectors/single_stage.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,17 @@ def extract_feat(self, img):
4242
x = self.neck(x)
4343
return x
4444

45-
def forward_train(self, img, img_metas, gt_bboxes, gt_labels):
45+
def forward_train(self,
46+
img,
47+
img_metas,
48+
gt_bboxes,
49+
gt_labels,
50+
gt_bboxes_ignore=None):
4651
x = self.extract_feat(img)
4752
outs = self.bbox_head(x)
4853
loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
49-
losses = self.bbox_head.loss(*loss_inputs)
54+
losses = self.bbox_head.loss(
55+
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
5056
return losses
5157

5258
def simple_test(self, img, img_meta, rescale=False):

mmdet/models/detectors/two_stage.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def forward_train(self,
8181
img,
8282
img_meta,
8383
gt_bboxes,
84-
gt_bboxes_ignore,
8584
gt_labels,
85+
gt_bboxes_ignore=None,
8686
gt_masks=None,
8787
proposals=None):
8888
x = self.extract_feat(img)
@@ -94,7 +94,8 @@ def forward_train(self,
9494
rpn_outs = self.rpn_head(x)
9595
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
9696
self.train_cfg.rpn)
97-
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
97+
rpn_losses = self.rpn_head.loss(
98+
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
9899
losses.update(rpn_losses)
99100

100101
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
@@ -108,6 +109,8 @@ def forward_train(self,
108109
bbox_sampler = build_sampler(
109110
self.train_cfg.rcnn.sampler, context=self)
110111
num_imgs = img.size(0)
112+
if gt_bboxes_ignore is None:
113+
gt_bboxes_ignore = [None for _ in range(num_imgs)]
111114
sampling_results = []
112115
for i in range(num_imgs):
113116
assign_result = bbox_assigner.assign(

0 commit comments

Comments
 (0)