Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit ce7ac2f

Browse files
committed
fix error in tversky-loss
1 parent d11de27 commit ce7ac2f

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

TverskyLoss/binarytverskyloss.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
102102
"""Dice loss of binary class
103103
Args:
104104
alpha: controls the penalty for false positives.
105-
beta: penalty for false negative.
105+
beta: penalty for false negative. Larger beta weigh recall higher
106106
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
107107
reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
108108
Shapes:
@@ -117,29 +117,31 @@ def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
117117
self.alpha = alpha
118118
self.beta = beta
119119
self.ignore_index = ignore_index
120-
self.epsilon = 1e-6
120+
self.smooth = 10
121121
self.reduction = reduction
122122
s = self.beta + self.alpha
123-
if sum != 1:
123+
if s != 1:
124124
self.beta = self.beta / s
125125
self.alpha = self.alpha / s
126126

127-
def forward(self, output, target):
127+
def forward(self, output, target, mask=None):
128128
batch_size = output.size(0)
129-
129+
bg_target = 1 - target
130130
if self.ignore_index is not None:
131131
valid_mask = (target != self.ignore_index).float()
132132
output = output.float().mul(valid_mask) # can not use inplace for bp
133133
target = target.float().mul(valid_mask)
134+
bg_target = bg_target.float().mul(valid_mask)
134135

135136
output = torch.sigmoid(output).view(batch_size, -1)
136137
target = target.view(batch_size, -1)
138+
bg_target = bg_target.view(batch_size, -1)
137139

138140
P_G = torch.sum(output * target, 1) # TP
139-
P_NG = torch.sum(output * (1 - target), 1) # FP
141+
P_NG = torch.sum(output * bg_target, 1) # FP
140142
NP_G = torch.sum((1 - output) * target, 1) # FN
141143

142-
tversky_index = P_G / (P_G + self.alpha * P_NG + self.beta * NP_G + self.epsilon)
144+
tversky_index = P_G / (P_G + self.alpha * P_NG + self.beta * NP_G + self.smooth)
143145

144146
loss = 1. - tversky_index
145147
# target_area = torch.sum(target_label, 1)

0 commit comments

Comments
 (0)