Skip to content

Commit acff839

Browse files
authored
[Feature] Support Tversky Loss (open-mmlab#1986)
1 parent c5259a0 commit acff839

File tree

3 files changed

+215
-1
lines changed

3 files changed

+215
-1
lines changed

mmseg/models/losses/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from .dice_loss import DiceLoss
66
from .focal_loss import FocalLoss
77
from .lovasz_loss import LovaszLoss
8+
from .tversky_loss import TverskyLoss
89
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
910

1011
__all__ = [
1112
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
1213
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
1314
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
14-
'FocalLoss'
15+
'FocalLoss', 'TverskyLoss'
1516
]
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Modified from
3+
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
4+
(Apache-2.0 License)"""
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
9+
from ..builder import LOSSES
10+
from .utils import get_class_weight, weighted_loss
11+
12+
13+
@weighted_loss
14+
def tversky_loss(pred,
15+
target,
16+
valid_mask,
17+
alpha=0.3,
18+
beta=0.7,
19+
smooth=1,
20+
class_weight=None,
21+
ignore_index=255):
22+
assert pred.shape[0] == target.shape[0]
23+
total_loss = 0
24+
num_classes = pred.shape[1]
25+
for i in range(num_classes):
26+
if i != ignore_index:
27+
tversky_loss = binary_tversky_loss(
28+
pred[:, i],
29+
target[..., i],
30+
valid_mask=valid_mask,
31+
alpha=alpha,
32+
beta=beta,
33+
smooth=smooth)
34+
if class_weight is not None:
35+
tversky_loss *= class_weight[i]
36+
total_loss += tversky_loss
37+
return total_loss / num_classes
38+
39+
40+
@weighted_loss
41+
def binary_tversky_loss(pred,
42+
target,
43+
valid_mask,
44+
alpha=0.3,
45+
beta=0.7,
46+
smooth=1):
47+
assert pred.shape[0] == target.shape[0]
48+
pred = pred.reshape(pred.shape[0], -1)
49+
target = target.reshape(target.shape[0], -1)
50+
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
51+
52+
TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1)
53+
FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1)
54+
FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1)
55+
tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
56+
57+
return 1 - tversky
58+
59+
60+
@LOSSES.register_module()
61+
class TverskyLoss(nn.Module):
62+
"""TverskyLoss. This loss is proposed in `Tversky loss function for image
63+
segmentation using 3D fully convolutional deep networks.
64+
65+
<https://arxiv.org/abs/1706.05721>`_.
66+
Args:
67+
smooth (float): A float number to smooth loss, and avoid NaN error.
68+
Default: 1.
69+
class_weight (list[float] | str, optional): Weight of each class. If in
70+
str format, read them from a file. Defaults to None.
71+
loss_weight (float, optional): Weight of the loss. Default to 1.0.
72+
ignore_index (int | None): The label index to be ignored. Default: 255.
73+
alpha(float, in [0, 1]):
74+
The coefficient of false positives. Default: 0.3.
75+
beta (float, in [0, 1]):
76+
The coefficient of false negatives. Default: 0.7.
77+
Note: alpha + beta = 1.
78+
loss_name (str, optional): Name of the loss item. If you want this loss
79+
item to be included into the backward graph, `loss_` must be the
80+
prefix of the name. Defaults to 'loss_tversky'.
81+
"""
82+
83+
def __init__(self,
84+
smooth=1,
85+
class_weight=None,
86+
loss_weight=1.0,
87+
ignore_index=255,
88+
alpha=0.3,
89+
beta=0.7,
90+
loss_name='loss_tversky'):
91+
super(TverskyLoss, self).__init__()
92+
self.smooth = smooth
93+
self.class_weight = get_class_weight(class_weight)
94+
self.loss_weight = loss_weight
95+
self.ignore_index = ignore_index
96+
assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!'
97+
self.alpha = alpha
98+
self.beta = beta
99+
self._loss_name = loss_name
100+
101+
def forward(self, pred, target, **kwargs):
102+
if self.class_weight is not None:
103+
class_weight = pred.new_tensor(self.class_weight)
104+
else:
105+
class_weight = None
106+
107+
pred = F.softmax(pred, dim=1)
108+
num_classes = pred.shape[1]
109+
one_hot_target = F.one_hot(
110+
torch.clamp(target.long(), 0, num_classes - 1),
111+
num_classes=num_classes)
112+
valid_mask = (target != self.ignore_index).long()
113+
114+
loss = self.loss_weight * tversky_loss(
115+
pred,
116+
one_hot_target,
117+
valid_mask=valid_mask,
118+
alpha=self.alpha,
119+
beta=self.beta,
120+
smooth=self.smooth,
121+
class_weight=class_weight,
122+
ignore_index=self.ignore_index)
123+
return loss
124+
125+
@property
126+
def loss_name(self):
127+
"""Loss Name.
128+
129+
This function must be implemented and will return the name of this
130+
loss function. This name will be used to combine different loss items
131+
by simple sum operation. In addition, if you want this loss item to be
132+
included into the backward graph, `loss_` must be the prefix of the
133+
name.
134+
Returns:
135+
str: The name of this loss item.
136+
"""
137+
return self._loss_name
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
4+
5+
6+
def test_tversky_lose():
7+
from mmseg.models import build_loss
8+
9+
# test alpha + beta != 1
10+
with pytest.raises(AssertionError):
11+
loss_cfg = dict(
12+
type='TverskyLoss',
13+
class_weight=[1.0, 2.0, 3.0],
14+
loss_weight=1.0,
15+
alpha=0.4,
16+
beta=0.7,
17+
loss_name='loss_tversky')
18+
tversky_loss = build_loss(loss_cfg)
19+
logits = torch.rand(8, 3, 4, 4)
20+
labels = (torch.rand(8, 4, 4) * 3).long()
21+
tversky_loss(logits, labels, ignore_index=1)
22+
23+
# test tversky loss
24+
loss_cfg = dict(
25+
type='TverskyLoss',
26+
class_weight=[1.0, 2.0, 3.0],
27+
loss_weight=1.0,
28+
ignore_index=1,
29+
loss_name='loss_tversky')
30+
tversky_loss = build_loss(loss_cfg)
31+
logits = torch.rand(8, 3, 4, 4)
32+
labels = (torch.rand(8, 4, 4) * 3).long()
33+
tversky_loss(logits, labels)
34+
35+
# test loss with class weights from file
36+
import os
37+
import tempfile
38+
39+
import mmcv
40+
import numpy as np
41+
tmp_file = tempfile.NamedTemporaryFile()
42+
43+
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
44+
loss_cfg = dict(
45+
type='TverskyLoss',
46+
class_weight=f'{tmp_file.name}.pkl',
47+
loss_weight=1.0,
48+
ignore_index=1,
49+
loss_name='loss_tversky')
50+
tversky_loss = build_loss(loss_cfg)
51+
tversky_loss(logits, labels)
52+
53+
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
54+
loss_cfg = dict(
55+
type='TverskyLoss',
56+
class_weight=f'{tmp_file.name}.pkl',
57+
loss_weight=1.0,
58+
ignore_index=1,
59+
loss_name='loss_tversky')
60+
tversky_loss = build_loss(loss_cfg)
61+
tversky_loss(logits, labels)
62+
tmp_file.close()
63+
os.remove(f'{tmp_file.name}.pkl')
64+
os.remove(f'{tmp_file.name}.npy')
65+
66+
# test tversky loss has name `loss_tversky`
67+
loss_cfg = dict(
68+
type='TverskyLoss',
69+
smooth=2,
70+
loss_weight=1.0,
71+
ignore_index=1,
72+
alpha=0.3,
73+
beta=0.7,
74+
loss_name='loss_tversky')
75+
tversky_loss = build_loss(loss_cfg)
76+
assert tversky_loss.loss_name == 'loss_tversky'

0 commit comments

Comments
 (0)