Commit 1b41989

Authored by RockeyCoss, Junjun2016 and xvjiarui
[Feature] Add focal loss (open-mmlab#1024)
* [Feature] add focal loss
* fix the bug of 'non' reduction type
* refine the implementation
* add class_weight and ignore_index; support different alpha values for different classes
* fixed some bugs
* fix bugs
* add comments
* modify test
* Update mmseg/models/losses/focal_loss.py
  Co-authored-by: Junjun2016 <[email protected]>
* update test_focal_loss.py
* modified the implementation
* Update mmseg/models/losses/focal_loss.py
  Co-authored-by: Jerry Jiarui XU <[email protected]>
* update focal_loss.py

Co-authored-by: Junjun2016 <[email protected]>
Co-authored-by: Jerry Jiarui XU <[email protected]>
1 parent f0e6201 commit 1b41989

File tree

3 files changed: +546 -1 lines changed

mmseg/models/losses/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,11 +3,13 @@
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .dice_loss import DiceLoss
+from .focal_loss import FocalLoss
 from .lovasz_loss import LovaszLoss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss

 __all__ = [
     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
     'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
+    'FocalLoss'
 ]
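With the re-export in place, the new loss is available both by direct import and through the losses registry. A minimal check (a sketch; it assumes an mmseg installation that includes this commit and the usual builder entry points):

    from mmseg.models import build_loss
    from mmseg.models.losses import FocalLoss  # direct import via the updated __init__.py

    # Building from a config dict goes through the registry entry added by the
    # @LOSSES.register_module() decorator in focal_loss.py.
    criterion = build_loss(dict(type='FocalLoss', use_sigmoid=True))
    assert isinstance(criterion, FocalLoss)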

mmseg/models/losses/focal_loss.py

Lines changed: 327 additions & 0 deletions
@@ -0,0 +1,327 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/open-mmlab/mmdetection
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from ..builder import LOSSES
from .utils import weight_reduce_loss


# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction with
            shape (N, C).
        one_hot_target (None): Placeholder. It should be None.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
            valid samples and 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)

    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
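For intuition about this pure-PyTorch path: with gamma set to 0 and a scalar alpha of 0.5, the modulating factor collapses to a constant 0.5, so the function should reduce to half of the mean binary cross-entropy. A quick sanity-check sketch (it assumes this module is importable as mmseg.models.losses.focal_loss):

    import torch
    import torch.nn.functional as F
    from mmseg.models.losses.focal_loss import py_sigmoid_focal_loss

    pred = torch.randn(8, 4)                       # (N, C) logits
    target = torch.randint(0, 2, (8, 4)).float()   # (N, C) binary labels
    focal = py_sigmoid_focal_loss(pred, target, gamma=0.0, alpha=0.5)
    bce = 0.5 * F.binary_cross_entropy_with_logits(pred, target)
    assert torch.allclose(focal, bce, atol=1e-6)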
def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction. Its shape
            should be (N, ).
        one_hot_target (torch.Tensor): The learning label with shape (N, C).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
            valid samples and 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
        # a list is given, we set the input alpha as 0.5. This means setting
        # equal weight for the foreground class and the background class. By
        # multiplying the loss by 2, the effect of setting alpha as 0.5 is
        # undone. The alpha of type list is used to regulate the loss in the
        # post-processing process.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
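Since the float-alpha branch above simply forwards to the mmcv kernel, the CUDA wrapper and the pure-PyTorch fallback should agree closely on the same inputs. A consistency sketch (it assumes an mmcv build that ships the sigmoid_focal_loss CUDA op and a visible GPU; the tolerance is illustrative):

    import torch
    import torch.nn.functional as F
    from mmseg.models.losses.focal_loss import (py_sigmoid_focal_loss,
                                                sigmoid_focal_loss)

    if torch.cuda.is_available():
        pred = torch.randn(16, 5, device='cuda')
        labels = torch.randint(0, 5, (16, ), device='cuda')   # (N, ) class indices
        one_hot = F.one_hot(labels, num_classes=5)             # (N, C) one-hot labels
        cuda_loss = sigmoid_focal_loss(pred, labels, one_hot)
        py_loss = py_sigmoid_focal_loss(pred, one_hot)
        assert torch.allclose(cuda_loss, py_loss, atol=1e-4)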
@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is used for
                sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background; each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use the `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float or list'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
            (pred.size(0) == target.size(0) and
             pred.shape[2:] == target.shape[1:]), \
            "The shape of pred doesn't match the shape of target"

        original_shape = pred.shape

        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...]
            # transform its shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...]
            # transform its shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot()
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                if target.dim() == 1:
                    one_hot_target = F.one_hot(target, num_classes=num_classes)
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes)
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
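Once registered, the loss can be selected in a training config by its registry name, or called directly as a module. A usage sketch (it assumes the usual loss_decode convention for mmseg decode heads and a hypothetical 19-class dataset):

    # In a config file, e.g. replacing the default cross-entropy loss of a decode head:
    loss_decode = dict(
        type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.5,
        loss_weight=1.0)

    # Or calling the module directly on dummy data (on CPU this takes the
    # pure-PyTorch py_sigmoid_focal_loss path):
    import torch
    from mmseg.models.losses import FocalLoss

    criterion = FocalLoss(gamma=2.0, alpha=0.5)
    logits = torch.randn(2, 19, 64, 64)           # (B, C, H, W) predictions
    labels = torch.randint(0, 19, (2, 64, 64))    # (B, H, W) class indices
    labels[0, :8, :8] = 255                       # this region is dropped via ignore_index
    loss = criterion(logits, labels, ignore_index=255)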
