Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit 9896f95

Browse files
committed
code format
1 parent 9d9de99 commit 9896f95

File tree

17 files changed

+69
-93
lines changed

17 files changed

+69
-93
lines changed

DiceLoss/__init__.py

Whitespace-only changes.

FocalLoss/README.md

Lines changed: 0 additions & 12 deletions
This file was deleted.

FocalLoss/__init__.py

Whitespace-only changes.

TverskyLoss/__init__.py

Whitespace-only changes.

TverskyLoss/multitverskyloss.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

seg_loss/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .dice_loss import *
2+
from focal_loss import *
3+
from lovasz_loss import *
4+
from tverskyloss import *

DiceLoss/dice_loss.py renamed to seg_loss/dice_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ class BinaryDiceLoss(nn.Module):
4040
Exception if unexpected reduction
4141
"""
4242

43-
def __init__(self, ignore_index=None, reduction='mean',**kwargs):
43+
def __init__(self, ignore_index=None, reduction='mean', **kwargs):
4444
super(BinaryDiceLoss, self).__init__()
45-
self.smooth = 1 # suggest set a large number when target area is large,like '10|100'
45+
self.smooth = 1 # suggest set a large number when target area is large,like '10|100'
4646
self.ignore_index = ignore_index
4747
self.reduction = reduction
48-
self.batch_dice = False # treat a large map when True
48+
self.batch_dice = False # treat a large map when True
4949
if 'batch_loss' in kwargs.keys():
5050
self.batch_dice = kwargs['batch_loss']
5151

@@ -59,7 +59,7 @@ def forward(self, output, target, use_sigmoid=True):
5959
output = output.mul(validmask) # can not use inplace for bp
6060
target = target.float().mul(validmask)
6161

62-
dim0= output.shape[0]
62+
dim0 = output.shape[0]
6363
if self.batch_dice:
6464
dim0 = 1
6565

@@ -212,7 +212,7 @@ def test():
212212
model = nn.Conv3d(1, 4, 3, padding=1)
213213
target = torch.randint(0, 4, (3, 1, 32, 32, 32)).float()
214214
target = make_one_hot(target, num_classes=4)
215-
criterion = DiceLoss(ignore_index=[2,3], reduction='mean')
215+
criterion = DiceLoss(ignore_index=[2, 3], reduction='mean')
216216
loss = criterion(model(input), target)
217217
loss.backward()
218218
print(loss.item())
File renamed without changes.
File renamed without changes.

TverskyLoss/binarytverskyloss.py renamed to seg_loss/tverskyloss.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import torch
22
import torch.nn as nn
3-
import torch.nn.functional as F
4-
from torch.autograd import Function
53

6-
from torch.autograd import Variable
4+
from torch.autograd import Function
75

86

97
class FocalBinaryTverskyLoss(Function):
108

11-
129
@staticmethod
1310
def forward(ctx, input, target):
1411
_alpha = 0.5
@@ -142,3 +139,46 @@ def forward(self, output, target, mask=None):
142139
else:
143140
loss = torch.mean(loss)
144141
return loss
142+
143+
144+
class MultiTverskyLoss(nn.Module):
145+
"""
146+
Tversky Loss for segmentation adaptive with multi class segmentation
147+
"""
148+
149+
def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, weights=None):
150+
"""
151+
:param alpha (Tensor, float, optional): controls the penalty for false positives.
152+
:param beta (Tensor, float, optional): controls the penalty for false negative.
153+
:param gamma (Tensor, float, optional): focal coefficient
154+
:param weights (Tensor, optional): a manual rescaling weight given to each
155+
class. If given, it has to be a Tensor of size `C`
156+
157+
"""
158+
super(MultiTverskyLoss, self).__init__()
159+
self.alpha = alpha
160+
self.beta = beta
161+
self.gamma = gamma
162+
self.weights = weights
163+
164+
def forward(self, inputs, targets):
165+
166+
num_class = inputs.size(1)
167+
weight_losses = 0.0
168+
if self.weights is not None:
169+
assert len(self.weights) == num_class, 'number of classes should be equal to length of weights '
170+
weights = self.weights
171+
else:
172+
weights = [1.0 / num_class] * num_class
173+
input_slices = torch.split(inputs, [1] * num_class, dim=1)
174+
for idx in range(num_class):
175+
input_idx = input_slices[idx]
176+
input_idx = torch.cat((1 - input_idx, input_idx), dim=1)
177+
target_idx = (targets == idx) * 1
178+
loss_func = FocalBinaryTverskyLoss(self.alpha, self.beta, self.gamma)
179+
loss_idx = loss_func.apply(input_idx, target_idx)
180+
weight_losses += loss_idx * weights[idx]
181+
# loss = torch.Tensor(weight_losses)
182+
# loss = loss.to(inputs.device)
183+
# loss = torch.sum(loss)
184+
return weight_losses

0 commit comments

Comments
 (0)