This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit dc9d27d

add dice_loss
1 parent 89cfd0a commit dc9d27d

File tree (5 files changed: +180 additions, -4 deletions)

.gitignore
DiceLoss/__init__.py
DiceLoss/dice_loss.py
FocalLoss/FocalLoss_test.py
README.md


.gitignore

Lines changed: 1 addition & 2 deletions
@@ -1,8 +1,7 @@
+ProstateX-0002/
 # Created by .ignore support plugin (hsz.mobi)
 ### Python template
 # Byte-compiled / optimized / DLL files
-ProstateX-0002/
-DiceLoss/
 __pycache__/
 *.py[cod]
 *$py.class

DiceLoss/__init__.py

Whitespace-only changes.

DiceLoss/dice_loss.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def make_one_hot(input, num_classes=None):
    """Convert a class index tensor to a one-hot encoded tensor.

    Args:
        input: A LongTensor of shape [N, 1, *] containing class indices
        num_classes: Number of classes; inferred from input if None
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    if num_classes is None:
        num_classes = int(input.max()) + 1
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, input.cpu(), 1)

    return result


class BinaryDiceLoss(nn.Module):
    """Dice loss for binary segmentation
    Args:
        ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
        reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
    Shapes:
        output: A tensor of shape [N, *] without sigmoid activation function applied
        target: A tensor of the same shape as output
    Returns:
        Loss tensor according to arg reduction
    Raises:
        Exception if an unexpected reduction is given
    """

    def __init__(self, ignore_index=None, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = 1
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, output, target):
        assert output.shape[0] == target.shape[0], "output & target batch size don't match"
        output = torch.sigmoid(output)

        if self.ignore_index is not None:
            validmask = (target != self.ignore_index).float()
            output = output.mul(validmask)  # can not use inplace for bp
            target = target.float().mul(validmask)

        output = output.contiguous().view(output.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1).float()

        # Squared-denominator Dice: 2*sum(p*t) / (sum(p^2) + sum(t^2)), with smoothing
        num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
        den = torch.sum(output.pow(2) + target.pow(2), dim=1) + self.smooth

        loss = 1 - (num / den)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Dice loss; requires a one-hot encoded target
    Args:
        weight: An array of shape [num_classes,]
        ignore_index: class index or list of class indices to ignore
        output: A tensor of shape [N, C, *]
        target: A tensor of the same shape as output
        other args are passed to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        if isinstance(ignore_index, int):
            self.ignore_index = [ignore_index]
        elif ignore_index is None:
            self.ignore_index = []
        else:
            self.ignore_index = ignore_index

    def forward(self, output, target):
        assert output.shape == target.shape, 'output & target shape do not match'
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        output = F.softmax(output, dim=1)
        for i in range(target.shape[1]):
            if i not in self.ignore_index:
                dice_loss = dice(output[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], got [{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weight[i]
                total_loss += dice_loss
        # Average the per-class dice losses over the non-ignored classes
        loss = total_loss / (target.size(1) - len(self.ignore_index))
        return loss


class BCE_DiceLoss(nn.Module):
    def __init__(self, alpha=0.5, ignore_index=None, reduction='mean'):
        """
        Combination of Binary Cross Entropy and Binary Dice Loss
        Args:
            @param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
            @param reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
            @param alpha: weight between BCE ('Binary Cross Entropy') and binary dice
        Shapes:
            output: A tensor of shape [N, *] without sigmoid activation function applied
            target: A tensor of the same shape as output
        """
        super(BCE_DiceLoss, self).__init__()
        assert reduction in ['none', 'mean', 'sum']
        assert 0 <= alpha <= 1, '`alpha` should be in [0, 1]'
        self.alpha = alpha
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction)
        self.bce = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, output, target):
        dice_loss = self.dice(output, target)

        if self.ignore_index is not None:
            mask = (target != self.ignore_index).float()
            output = output.mul(mask)  # can not use inplace for bp
            target = target.float().mul(mask)
        bce_loss = self.bce(output, target)
        loss = self.alpha * bce_loss + (1.0 - self.alpha) * dice_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


def test():
    input = torch.rand((1, 1, 32, 32, 32))
    model = nn.Conv3d(1, 1, 3, padding=1)
    target = torch.randint(0, 3, (1, 1, 32, 32, 32)).float()
    criterion = BCE_DiceLoss(ignore_index=2, reduction='none')
    loss = criterion(model(input), target)
    loss = loss.mean()  # 'none' returns a per-element loss; reduce to a scalar before backward
    loss.backward()
    print(loss.item())

    # input = torch.zeros((1, 2, 32, 32, 32))
    # input[:, 0, ...] = 1
    # target = torch.ones((1, 1, 32, 32, 32)).long()
    # target_one_hot = make_one_hot(target, num_classes=2)
    # # print(target_one_hot.size())
    # criterion = DiceLoss()
    # loss = criterion(input, target_one_hot)
    # print(loss.item())


if __name__ == '__main__':
    test()
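
For reference, here is a minimal multi-class usage sketch, mirroring the commented-out block in test() above. The import path DiceLoss.dice_loss, the tensor shapes, and num_classes=2 are illustrative assumptions, not part of this commit:

import torch
from DiceLoss.dice_loss import DiceLoss, make_one_hot  # assumed import path (repo root)

# Fake 3D logits for a 2-class problem, shape [N, C, D, H, W] (illustrative)
logits = torch.randn((1, 2, 32, 32, 32))
# Integer class labels of shape [N, 1, D, H, W]
labels = torch.randint(0, 2, (1, 1, 32, 32, 32)).long()

# DiceLoss expects a one-hot target with the same shape as the logits
target_one_hot = make_one_hot(labels, num_classes=2)

criterion = DiceLoss()
loss = criterion(logits, target_one_hot)  # per-class dice losses averaged over non-ignored classes
print(loss.item())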

FocalLoss/FocalLoss_test.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from FocalLoss import FocalLoss
+from focalloss import FocalLoss
 
 
 # os.environ['CUDA_VISIBLE_DEVICES'] = '2'

README.md

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@ Loss_ToolBox
 This repository includes several losses for 3D image segmentation.
 1. [Focal Loss](https://arxiv.org/abs/1708.02002) (PS: Borrows some code from [c0nn3r/RetinaNet](https://github.com/c0nn3r/RetinaNet))
 2. [Lovasz-Softmax Loss](https://arxiv.org/abs/1705.08790) (Modified from the original implementation [LovaszSoftmax](https://github.com/bermanmaxim/LovaszSoftmax))
-3. [DiceLoss](https://arxiv.org/abs/1606.04797) (To Be Released)
+3. [DiceLoss](https://arxiv.org/abs/1606.04797)
