import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def make_one_hot(input, num_classes=None):
    """Convert a class index tensor to a one hot encoding tensor.
    Args:
        input: A tensor of shape [N, 1, *] holding integer class indices
        num_classes: An int, the number of classes
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    if num_classes is None:
        num_classes = int(input.max()) + 1  # infer from the largest class index
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape, device=input.device)
    result = result.scatter_(1, input.long(), 1)  # scatter 1s along the class dim

    return result


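# A minimal usage sketch for make_one_hot (hypothetical shapes and values):
#
# labels = torch.randint(0, 3, (2, 1, 4, 4))     # class indices in {0, 1, 2}
# one_hot = make_one_hot(labels, num_classes=3)  # -> shape [2, 3, 4, 4]

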
class BinaryDiceLoss(nn.Module):
    """Dice loss for binary classification.
    Args:
        ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
        reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
        from_logits: If True, apply sigmoid to the output before computing the loss
    Shapes:
        output: A tensor of shape [N, *] without sigmoid activation function applied
        target: A tensor of the same shape as output
    Returns:
        Loss tensor according to arg reduction
    Raises:
        ValueError if an unexpected reduction is given
    """

    def __init__(self, ignore_index=None, reduction='mean', from_logits=True):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = 1  # smoothing term, avoids division by zero
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.from_logits = from_logits

    def forward(self, output, target):
        assert output.shape[0] == target.shape[0], "output & target batch size don't match"
        if self.from_logits:
            output = torch.sigmoid(output)

        if self.ignore_index is not None:
            validmask = (target != self.ignore_index).float()
            output = output.mul(validmask)  # can not use inplace ops for backprop
            target = target.float().mul(validmask)

        # flatten everything except the batch dimension
        output = output.contiguous().view(output.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1).float()

        # squared-denominator dice: 2*|X.Y| / (|X|^2 + |Y|^2), with smoothing
        num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
        den = torch.sum(output.pow(2) + target.pow(2), dim=1) + self.smooth

        loss = 1 - (num / den)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise ValueError('Unexpected reduction {}'.format(self.reduction))


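# A minimal usage sketch for BinaryDiceLoss (hypothetical shapes): raw logits
# go in, sigmoid is applied internally.
#
# criterion = BinaryDiceLoss()
# logits = torch.randn(4, 1, 64, 64, requires_grad=True)
# target = torch.randint(0, 2, (4, 1, 64, 64)).float()
# loss = criterion(logits, target)  # scalar with the default 'mean' reduction

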
class DiceLoss(nn.Module):
    """Dice loss, requires a one hot encoded target.
    Args:
        weight: A tensor of shape [num_classes,]
        ignore_index: a class index, or a list of class indices, to exclude from the loss
        other args pass to BinaryDiceLoss
    Shapes:
        output: A tensor of shape [N, C, *]
        target: A one hot tensor of the same shape as output
    Returns:
        same as BinaryDiceLoss
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        # normalize ignore_index to a list of class indices
        if ignore_index is None:
            self.ignore_index = []
        elif isinstance(ignore_index, int):
            self.ignore_index = [ignore_index]
        else:
            self.ignore_index = ignore_index

    def forward(self, output, target):
        assert output.shape == target.shape, 'output & target shape do not match'
        # softmax already turns the logits into probabilities, so skip the
        # sigmoid inside BinaryDiceLoss to avoid a double activation
        dice = BinaryDiceLoss(from_logits=False, **self.kwargs)
        total_loss = 0
        output = F.softmax(output, dim=1)
        for i in range(target.shape[1]):
            if i not in self.ignore_index:
                dice_loss = dice(output[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], got [{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weight[i]
                total_loss += dice_loss
        # average over the classes that actually contribute
        loss = total_loss / (target.shape[1] - len(self.ignore_index))
        return loss


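# A minimal usage sketch for DiceLoss together with make_one_hot
# (hypothetical shapes):
#
# logits = torch.randn(2, 3, 32, 32, requires_grad=True)  # [N, C, H, W]
# labels = torch.randint(0, 3, (2, 1, 32, 32))            # class index map
# target = make_one_hot(labels, num_classes=3)            # -> [2, 3, 32, 32]
# loss = DiceLoss()(logits, target)

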
class BCE_DiceLoss(nn.Module):
    def __init__(self, alpha=0.5, ignore_index=None, reduction='mean'):
        """Combination of Binary Cross Entropy and Binary Dice Loss.
        Args:
            alpha: weight between BCE ('Binary Cross Entropy') and binary dice,
                i.e. loss = alpha * bce + (1 - alpha) * dice
            ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
            reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
        Shapes:
            output: A tensor of shape [N, *] without sigmoid activation function applied
            target: A tensor of the same shape as output
        """
        super(BCE_DiceLoss, self).__init__()
        assert reduction in ['none', 'mean', 'sum'], 'Unexpected reduction {}'.format(reduction)
        assert 0 <= alpha <= 1, '`alpha` should be in [0, 1]'
        self.alpha = alpha
        self.ignore_index = ignore_index
        self.reduction = reduction
        # compute both terms per sample and apply the reduction once at the end
        self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction='none')
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, output, target):
        dice_loss = self.dice(output, target)  # per-sample loss, shape [N]

        bce = self.bce(output, target.float())  # per-element loss
        if self.ignore_index is not None:
            # mask out ignored positions so they contribute neither loss nor
            # gradient, then average over the valid positions only
            mask = (target != self.ignore_index).float()
            bce = bce * mask
            bce_loss = bce.view(bce.shape[0], -1).sum(dim=1) \
                / mask.view(mask.shape[0], -1).sum(dim=1).clamp(min=1)
        else:
            bce_loss = bce.view(bce.shape[0], -1).mean(dim=1)
        loss = self.alpha * bce_loss + (1.0 - self.alpha) * dice_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss  # 'none': per-sample loss of shape [N]


def test():
    input = torch.rand((1, 1, 32, 32, 32))
    model = nn.Conv3d(1, 1, 3, padding=1)
    target = torch.randint(0, 3, (1, 1, 32, 32, 32)).float()
    # with a batch of one, the 'none' reduction still yields a single element,
    # so backward() and item() work below
    criterion = BCE_DiceLoss(ignore_index=2, reduction='none')
    loss = criterion(model(input), target)
    loss.backward()
    print(loss.item())

    # input = torch.zeros((1, 2, 32, 32, 32))
    # input[:, 0, ...] = 1
    # target = torch.ones((1, 1, 32, 32, 32)).long()
    # target_one_hot = make_one_hot(target, num_classes=2)
    # # print(target_one_hot.size())
    # criterion = DiceLoss()
    # loss = criterion(input, target_one_hot)
    # print(loss.item())


if __name__ == '__main__':
    test()