Skip to content

Commit e491713

Browse files
yhcao6hellock
authored andcommitted
add reduction_override to BoundedIoULoss (open-mmlab#850)
1 parent 4a0d7ad commit e491713

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

mmdet/models/losses/iou_loss.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,25 @@ def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
111111
self.reduction = reduction
112112
self.loss_weight = loss_weight
113113

114-
def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
114+
def forward(self,
115+
pred,
116+
target,
117+
weight=None,
118+
avg_factor=None,
119+
reduction_override=None,
120+
**kwargs):
115121
if weight is not None and not torch.any(weight > 0):
116122
return (pred * weight).sum() # 0
123+
assert reduction_override in (None, 'none', 'mean', 'sum')
124+
reduction = (
125+
reduction_override if reduction_override else self.reduction)
117126
loss = self.loss_weight * bounded_iou_loss(
118127
pred,
119128
target,
120129
weight,
121130
beta=self.beta,
122131
eps=self.eps,
123-
reduction=self.reduction,
132+
reduction=reduction,
124133
avg_factor=avg_factor,
125134
**kwargs)
126135
return loss

0 commit comments

Comments
 (0)