Skip to content

Commit cb68807

Browse files
committed
remove expand loop in bbox head to speed up
1 parent d5a6b5d commit cb68807

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

mmdet/core/bbox/bbox_target.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes,
5757
bbox_weights[:num_pos, :] = 1
5858
if num_neg > 0:
5959
label_weights[-num_neg:] = 1.0
60-
if reg_classes > 1:
61-
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
62-
labels, reg_classes)
6360

6461
return labels, label_weights, bbox_targets, bbox_weights
6562

mmdet/models/bbox_heads/bbox_head.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4-
54
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
65
weighted_cross_entropy, weighted_smoothl1, accuracy)
6+
77
from ..registry import HEADS
88

99

@@ -94,10 +94,16 @@ def loss(self,
9494
cls_score, labels, label_weights, reduce=reduce)
9595
losses['acc'] = accuracy(cls_score, labels)
9696
if bbox_pred is not None:
97+
pos_mask = labels > 0
98+
if self.reg_class_agnostic:
99+
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask]
100+
else:
101+
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
102+
4)[pos_mask, labels[pos_mask]]
97103
losses['loss_reg'] = weighted_smoothl1(
98-
bbox_pred,
99-
bbox_targets,
100-
bbox_weights,
104+
pos_bbox_pred,
105+
bbox_targets[pos_mask],
106+
bbox_weights[pos_mask],
101107
avg_factor=bbox_targets.size(0))
102108
return losses
103109

0 commit comments

Comments
 (0)