# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/open-mmlab/mmdetection
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from ..builder import LOSSES
from .utils import weight_reduce_loss


# This method is used when cuda is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction with
            shape (N, C).
        one_hot_target (None): Placeholder. It should be None.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
            valid samples and 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)

    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss

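# The helper below is an illustrative sketch added for clarity; it is not part
# of the upstream mmsegmentation API. It shows one way to call the pure
# PyTorch fallback above with binary (per-class) targets of shape (N, C); the
# tensor shapes and hyper-parameters are assumptions chosen for the example.
def _demo_py_sigmoid_focal_loss():
    pred = torch.randn(8, 4)                      # (N, C) logits
    target = torch.randint(0, 2, (8, 4)).float()  # (N, C) binary labels
    valid_mask = torch.ones(8, 1)                 # keep every sample
    return py_sigmoid_focal_loss(
        pred,
        target,
        gamma=2.0,
        alpha=0.25,
        valid_mask=valid_mask,
        reduction='mean')

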
def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction. Its shape
            should be (N, ).
        one_hot_target (torch.Tensor): The learning label with shape (N, C).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
            valid samples and 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
        # a list is given, we set the input alpha as 0.5. This means setting
        # equal weight for foreground class and background class. By
        # multiplying the loss by 2, the effect of setting alpha as 0.5 is
        # undone. The alpha of type list is used to regulate the loss in the
        # post-processing process.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss

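# Illustrative sketch (not part of the upstream API): the CUDA wrapper above
# expects integer class indices of shape (N, ) plus a matching one-hot tensor
# of shape (N, C). It runs only when a CUDA device is present, since the
# ``mmcv.ops`` focal-loss kernel is GPU-only; the shapes and the per-class
# alpha list below are assumptions made for the example.
def _demo_sigmoid_focal_loss():
    if not torch.cuda.is_available():
        return None
    pred = torch.randn(8, 3, device='cuda')             # (N, C) logits
    target = torch.randint(0, 3, (8, ), device='cuda')  # (N, ) class indices
    one_hot_target = F.one_hot(target, num_classes=3)   # (N, C) one-hot labels
    # A list-valued alpha weights each foreground class separately.
    return sigmoid_focal_loss(
        pred,
        target,
        one_hot_target,
        gamma=2.0,
        alpha=[0.25, 0.5, 0.75],
        reduction='mean')

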
@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether to use the sigmoid
                formulation (rather than softmax) for the prediction.
                Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background; each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use the `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float or list'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255

        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
               (pred.size(0) == target.size(0) and
                pred.shape[2:] == target.shape[1:]), \
               "The shape of pred doesn't match the shape of target"

        original_shape = pred.shape

        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...]
            # transform its shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...]
            # transform its shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot()
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                if target.dim() == 1:
                    one_hot_target = F.one_hot(
                        target, num_classes=num_classes)
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes)
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
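

# Illustrative usage sketch (not part of the upstream mmsegmentation code):
# it builds the module with assumed hyper-parameters and runs it on a random
# (B, C, H, W) batch with (B, H, W) integer labels, exercising the layout
# handling in ``forward`` above. All shapes are assumptions chosen for the
# example.
def _demo_focal_loss_module():
    loss_fn = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.5)
    pred = torch.randn(2, 4, 16, 16)           # (B, C, H, W) logits
    target = torch.randint(0, 4, (2, 16, 16))  # (B, H, W) class indices
    # Returns a scalar because the default reduction is 'mean'.
    return loss_fn(pred, target, ignore_index=255)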