Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit 4ade65e

Browse files
committed
add dice focal
1 parent 1806792 commit 4ade65e

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

seg_loss/dice_loss.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch.nn.functional as F
44
import numpy as np
55

6+
from .focal_loss import BinaryFocalLoss
7+
68

79
def make_one_hot(input, num_classes=None):
810
"""Convert class index tensor to one hot encoding tensor.
@@ -207,6 +209,19 @@ def forward(self, output, target):
207209
return loss
208210

209211

212+
class Binary_Focal_Dice(nn.Module):
213+
def __init__(self, **kwargs):
214+
super(Binary_Focal_Dice, self).__init__()
215+
self.dice = BinaryDiceLoss(**kwargs)
216+
self.focal = BinaryFocalLoss(**kwargs)
217+
218+
def forward(self, logits, target):
219+
dice_loss = self.dice(logits, target)
220+
focal_loss = self.focal(logits, target)
221+
loss = dice_loss + focal_loss
222+
return loss, (dice_loss.detach(), focal_loss.detach())
223+
224+
210225
def test():
211226
input = torch.rand((3, 1, 32, 32, 32))
212227
model = nn.Conv3d(1, 4, 3, padding=1)

0 commit comments

Comments
 (0)