Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit fd6f752

Browse files
committed
update binarytvserskylossv2
1 parent dc9d27d commit fd6f752

File tree

1 file changed

+63
-6
lines changed

1 file changed

+63
-6
lines changed

TverskyLoss/binarytverskyloss.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean'):
1616
:param gamma : focal coefficient range[1,3]
1717
:param reduction: return mode
1818
Notes:
19-
alpha = beta = 0.5 => dice coeff
20-
alpha = beta = 1 => tanimoto coeff
21-
alpha + beta = 1 => F beta coeff
19+
alpha = beta = 0.5 -> dice coeff
20+
alpha = beta = 1 -> tanimoto coeff
21+
alpha + beta = 1 -> F beta coeff
2222
add focal index -> loss=(1-T_index)**(1/gamma)
2323
2424
"""
@@ -27,10 +27,10 @@ def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean'):
2727
ctx.epsilon = 1e-6
2828
ctx.reduction = reduction
2929
ctx.gamma = gamma
30-
sum = ctx.beta + ctx.alpha
30+
s = ctx.beta + ctx.alpha
3131
if sum != 1:
32-
ctx.beta = ctx.beta / sum
33-
ctx.alpha = ctx.alpha / sum
32+
ctx.beta = ctx.beta / s
33+
ctx.alpha = ctx.alpha / s
3434

3535
# @staticmethod
3636
def forward(ctx, input, target):
@@ -94,3 +94,60 @@ def backward(ctx, grad_out):
9494
grad_input = torch.cat((dL_dp1, dL_dp0), dim=1)
9595
# grad_input = torch.cat((grad_out.item() * dL_dp0, dL_dp0 * grad_out.item()), dim=1)
9696
return grad_input, None
97+
98+
99+
class BinaryTverskyLossV2(nn.Module):
100+
101+
def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
102+
"""Dice loss of binary class
103+
Args:
104+
alpha: controls the penalty for false positives.
105+
beta: penalty for false negative.
106+
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
107+
reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
108+
Shapes:
109+
output: A tensor of shape [N, 1,(d,) h, w] without sigmoid activation function applied
110+
target: A tensor of shape same with output
111+
Returns:
112+
Loss tensor according to arg reduction
113+
Raise:
114+
Exception if unexpected reduction
115+
"""
116+
super(BinaryTverskyLossV2, self).__init__()
117+
self.alpha = alpha
118+
self.beta = beta
119+
self.ignore_index = ignore_index
120+
self.epsilon = 1e-6
121+
self.reduction = reduction
122+
s = self.beta + self.alpha
123+
if sum != 1:
124+
self.beta = self.beta / s
125+
self.alpha = self.alpha / s
126+
127+
def forward(self, output, target):
128+
batch_size = output.size(0)
129+
130+
if self.ignore_index is not None:
131+
valid_mask = (target != self.ignore_index).float()
132+
output = output.float().mul(valid_mask) # can not use inplace for bp
133+
target = target.float().mul(valid_mask)
134+
135+
output = torch.sigmoid(output).view(batch_size, -1)
136+
target = target.view(batch_size, -1)
137+
138+
P_G = torch.sum(output * target, 1) # TP
139+
P_NG = torch.sum(output * (1 - target), 1) # FP
140+
NP_G = torch.sum((1 - output) * target, 1) # FN
141+
142+
tversky_index = P_G / (P_G + self.alpha * P_NG + self.beta * NP_G + self.epsilon)
143+
144+
loss = 1. - tversky_index
145+
# target_area = torch.sum(target_label, 1)
146+
# loss[target_area == 0] = 0
147+
if self.reduction == 'none':
148+
loss = loss
149+
elif self.reduction == 'sum':
150+
loss = torch.sum(loss)
151+
else:
152+
loss = torch.mean(loss)
153+
return loss

0 commit comments

Comments
 (0)