Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit f9ce2bd

Browse files
committed
first pass fix at autograd fix
1 parent 0cee7e8 commit f9ce2bd

File tree

2 files changed

+23
-34
lines changed

2 files changed

+23
-34
lines changed

TverskyLoss/binarytverskyloss.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,15 @@
88

99
class FocalBinaryTverskyLoss(Function):
1010

11-
def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean'):
12-
"""
13-
14-
:param alpha: controls the penalty for false positives.
15-
:param beta: penalty for false negative.
16-
:param gamma : focal coefficient range[1,3]
17-
:param reduction: return mode
18-
Notes:
19-
alpha = beta = 0.5 -> dice coeff
20-
alpha = beta = 1 -> tanimoto coeff
21-
alpha + beta = 1 -> F beta coeff
22-
add focal index -> loss=(1-T_index)**(1/gamma)
23-
24-
"""
25-
ctx.alpha = alpha
26-
ctx.beta = beta
27-
ctx.epsilon = 1e-6
28-
ctx.reduction = reduction
29-
ctx.gamma = gamma
30-
s = ctx.beta + ctx.alpha
31-
if sum != 1:
32-
ctx.beta = ctx.beta / s
33-
ctx.alpha = ctx.alpha / s
3411

35-
# @staticmethod
12+
@staticmethod
3613
def forward(ctx, input, target):
14+
_alpha = 0.5
15+
_beta = 0.5
16+
_gamma = 1.0
17+
_epsilon = 1e-6
18+
_reduction = 'mean'
19+
3720
batch_size = input.size(0)
3821
_, input_label = input.max(1)
3922

@@ -49,13 +32,13 @@ def forward(ctx, input, target):
4932
ctx.P_NG = torch.sum(input_label * (1 - target_label), 1) # FP
5033
ctx.NP_G = torch.sum((1 - input_label) * target_label, 1) # FN
5134

52-
index = ctx.P_G / (ctx.P_G + ctx.alpha * ctx.P_NG + ctx.beta * ctx.NP_G + ctx.epsilon)
53-
loss = torch.pow((1 - index), 1 / ctx.gamma)
35+
index = ctx.P_G / (ctx.P_G + _alpha * ctx.P_NG + _beta * ctx.NP_G + _epsilon)
36+
loss = torch.pow((1 - index), 1 / _gamma)
5437
# target_area = torch.sum(target_label, 1)
5538
# loss[target_area == 0] = 0
56-
if ctx.reduction == 'none':
39+
if _reduction == 'none':
5740
loss = loss
58-
elif ctx.reduction == 'sum':
41+
elif _reduction == 'sum':
5942
loss = torch.sum(loss)
6043
else:
6144
loss = torch.mean(loss)
@@ -72,24 +55,30 @@ def backward(ctx, grad_out):
7255
= 2*P_G
7356
(dT_loss/d_p0)=
7457
"""
58+
_alpha = 0.5
59+
_beta = 0.5
60+
_gamma = 1.0
61+
_reduction = 'mean'
62+
_epsilon = 1e-6
63+
7564
inputs, target = ctx.saved_tensors
7665
inputs = inputs.float()
7766
target = target.float()
7867
batch_size = inputs.size(0)
79-
sum = ctx.P_G + ctx.alpha * ctx.P_NG + ctx.beta * ctx.NP_G + ctx.epsilon
68+
sum = ctx.P_G + _alpha * ctx.P_NG + _beta * ctx.NP_G + _epsilon
8069
P_G = ctx.P_G.view(batch_size, 1, 1, 1, 1)
8170
if inputs.dim() == 5:
8271
sum = sum.view(batch_size, 1, 1, 1, 1)
8372
elif inputs.dim() == 4:
8473
sum = sum.view(batch_size, 1, 1, 1)
8574
P_G = ctx.P_G.view(batch_size, 1, 1, 1)
86-
sub = (ctx.alpha * (1 - target) + target) * P_G
75+
sub = (_alpha * (1 - target) + target) * P_G
8776

88-
dL_dT = (1 / ctx.gamma) * torch.pow((P_G / sum), (1 / ctx.gamma - 1))
77+
dL_dT = (1 / _gamma) * torch.pow((P_G / sum), (1 / _gamma - 1))
8978
dT_dp0 = -2 * (target / sum - sub / sum / sum)
9079
dL_dp0 = dL_dT * dT_dp0
9180

92-
dT_dp1 = ctx.beta * (1 - target) * P_G / sum / sum
81+
dT_dp1 = _beta * (1 - target) * P_G / sum / sum
9382
dL_dp1 = dL_dT * dT_dp1
9483
grad_input = torch.cat((dL_dp1, dL_dp0), dim=1)
9584
# grad_input = torch.cat((grad_out.item() * dL_dp0, dL_dp0 * grad_out.item()), dim=1)
@@ -152,4 +141,4 @@ def forward(self, output, target, mask=None):
152141
loss = torch.sum(loss)
153142
else:
154143
loss = torch.mean(loss)
155-
return loss
144+
return loss

TverskyLoss/multitverskyloss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def forward(self, inputs, targets):
3838
input_idx = torch.cat((1 - input_idx, input_idx), dim=1)
3939
target_idx = (targets == idx) * 1
4040
loss_func = FocalBinaryTverskyLoss(self.alpha, self.beta, self.gamma)
41-
loss_idx = loss_func(input_idx, target_idx)
41+
loss_idx = loss_func.apply(input_idx, target_idx)
4242
weight_losses+=loss_idx * weights[idx]
4343
# loss = torch.Tensor(weight_losses)
4444
# loss = loss.to(inputs.device)

0 commit comments

Comments
 (0)