Skip to content

Commit cef8a4f

Browse files
authored
[Enhance] Support reading class_weight from file in loss functions to help MMDet3D (open-mmlab#513)
* support reading class_weight from file in loss function * add unit test of loss with class_weight from file * minor fix * move get_class_weight to utils
1 parent 768f704 commit cef8a4f

File tree

7 files changed

+120
-12
lines changed

7 files changed

+120
-12
lines changed

mmseg/models/losses/cross_entropy_loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44

55
from ..builder import LOSSES
6-
from .utils import weight_reduce_loss
6+
from .utils import get_class_weight, weight_reduce_loss
77

88

99
def cross_entropy(pred,
@@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
146146
Defaults to False.
147147
reduction (str, optional): . Defaults to 'mean'.
148148
Options are "none", "mean" and "sum".
149-
class_weight (list[float], optional): Weight of each class.
150-
Defaults to None.
149+
class_weight (list[float] | str, optional): Weight of each class. If in
150+
str format, read them from a file. Defaults to None.
151151
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
152152
"""
153153

@@ -163,7 +163,7 @@ def __init__(self,
163163
self.use_mask = use_mask
164164
self.reduction = reduction
165165
self.loss_weight = loss_weight
166-
self.class_weight = class_weight
166+
self.class_weight = get_class_weight(class_weight)
167167

168168
if self.use_sigmoid:
169169
self.cls_criterion = binary_cross_entropy

mmseg/models/losses/dice_loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn.functional as F
66

77
from ..builder import LOSSES
8-
from .utils import weighted_loss
8+
from .utils import get_class_weight, weighted_loss
99

1010

1111
@weighted_loss
@@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
6363
reduction (str, optional): The method used to reduce the loss. Options
6464
are "none", "mean" and "sum". This parameter only works when
6565
per_image is True. Default: 'mean'.
66-
class_weight (list[float], optional): The weight for each class.
67-
Default: None.
66+
class_weight (list[float] | str, optional): Weight of each class. If in
67+
str format, read them from a file. Defaults to None.
6868
loss_weight (float, optional): Weight of the loss. Default to 1.0.
6969
ignore_index (int | None): The label index to be ignored. Default: 255.
7070
"""
@@ -81,7 +81,7 @@ def __init__(self,
8181
self.smooth = smooth
8282
self.exponent = exponent
8383
self.reduction = reduction
84-
self.class_weight = class_weight
84+
self.class_weight = get_class_weight(class_weight)
8585
self.loss_weight = loss_weight
8686
self.ignore_index = ignore_index
8787

mmseg/models/losses/lovasz_loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.nn.functional as F
99

1010
from ..builder import LOSSES
11-
from .utils import weight_reduce_loss
11+
from .utils import get_class_weight, weight_reduce_loss
1212

1313

1414
def lovasz_grad(gt_sorted):
@@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
240240
reduction (str, optional): The method used to reduce the loss. Options
241241
are "none", "mean" and "sum". This parameter only works when
242242
per_image is True. Default: 'mean'.
243-
class_weight (list[float], optional): The weight for each class.
244-
Default: None.
243+
class_weight (list[float] | str, optional): Weight of each class. If in
244+
str format, read them from a file. Defaults to None.
245245
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
246246
"""
247247

@@ -269,7 +269,7 @@ def __init__(self,
269269
self.per_image = per_image
270270
self.reduction = reduction
271271
self.loss_weight = loss_weight
272-
self.class_weight = class_weight
272+
self.class_weight = get_class_weight(class_weight)
273273

274274
def forward(self,
275275
cls_score,

mmseg/models/losses/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,28 @@
11
import functools
22

3+
import mmcv
4+
import numpy as np
35
import torch.nn.functional as F
46

57

8+
def get_class_weight(class_weight):
9+
"""Get class weight for loss function.
10+
11+
Args:
12+
class_weight (list[float] | str | None): If class_weight is a str,
13+
take it as a file name and read from it.
14+
"""
15+
if isinstance(class_weight, str):
16+
# take it as a file path
17+
if class_weight.endswith('.npy'):
18+
class_weight = np.load(class_weight)
19+
else:
20+
# pkl, json or yaml
21+
class_weight = mmcv.load(class_weight)
22+
23+
return class_weight
24+
25+
626
def reduce_loss(loss, reduction):
727
"""Reduce loss as specified.
828

tests/test_models/test_losses/test_ce_loss.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,34 @@ def test_ce_loss():
2525
fake_label = torch.Tensor([1]).long()
2626
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
2727

28+
# test loss with class weights from file
29+
import os
30+
import tempfile
31+
import mmcv
32+
import numpy as np
33+
tmp_file = tempfile.NamedTemporaryFile()
34+
35+
mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
36+
loss_cls_cfg = dict(
37+
type='CrossEntropyLoss',
38+
use_sigmoid=False,
39+
class_weight=f'{tmp_file.name}.pkl',
40+
loss_weight=1.0)
41+
loss_cls = build_loss(loss_cls_cfg)
42+
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
43+
44+
np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file
45+
loss_cls_cfg = dict(
46+
type='CrossEntropyLoss',
47+
use_sigmoid=False,
48+
class_weight=f'{tmp_file.name}.npy',
49+
loss_weight=1.0)
50+
loss_cls = build_loss(loss_cls_cfg)
51+
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
52+
tmp_file.close()
53+
os.remove(f'{tmp_file.name}.pkl')
54+
os.remove(f'{tmp_file.name}.npy')
55+
2856
loss_cls_cfg = dict(
2957
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
3058
loss_cls = build_loss(loss_cls_cfg)

tests/test_models/test_losses/test_dice_loss.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,36 @@ def test_dice_lose():
1616
labels = (torch.rand(8, 4, 4) * 3).long()
1717
dice_loss(logits, labels)
1818

19+
# test loss with class weights from file
20+
import os
21+
import tempfile
22+
import mmcv
23+
import numpy as np
24+
tmp_file = tempfile.NamedTemporaryFile()
25+
26+
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
27+
loss_cfg = dict(
28+
type='DiceLoss',
29+
reduction='none',
30+
class_weight=f'{tmp_file.name}.pkl',
31+
loss_weight=1.0,
32+
ignore_index=1)
33+
dice_loss = build_loss(loss_cfg)
34+
dice_loss(logits, labels, ignore_index=None)
35+
36+
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
37+
loss_cfg = dict(
38+
type='DiceLoss',
39+
reduction='none',
40+
class_weight=f'{tmp_file.name}.pkl',
41+
loss_weight=1.0,
42+
ignore_index=1)
43+
dice_loss = build_loss(loss_cfg)
44+
dice_loss(logits, labels, ignore_index=None)
45+
tmp_file.close()
46+
os.remove(f'{tmp_file.name}.pkl')
47+
os.remove(f'{tmp_file.name}.npy')
48+
1949
# test dice loss with loss_type = 'binary'
2050
loss_cfg = dict(
2151
type='DiceLoss',

tests/test_models/test_losses/test_lovasz_loss.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,36 @@ def test_lovasz_loss():
3838
labels = (torch.rand(1, 4, 4) * 2).long()
3939
lovasz_loss(logits, labels, ignore_index=None)
4040

41+
# test loss with class weights from file
42+
import os
43+
import tempfile
44+
import mmcv
45+
import numpy as np
46+
tmp_file = tempfile.NamedTemporaryFile()
47+
48+
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
49+
loss_cfg = dict(
50+
type='LovaszLoss',
51+
per_image=True,
52+
reduction='mean',
53+
class_weight=f'{tmp_file.name}.pkl',
54+
loss_weight=1.0)
55+
lovasz_loss = build_loss(loss_cfg)
56+
lovasz_loss(logits, labels, ignore_index=None)
57+
58+
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
59+
loss_cfg = dict(
60+
type='LovaszLoss',
61+
per_image=True,
62+
reduction='mean',
63+
class_weight=f'{tmp_file.name}.npy',
64+
loss_weight=1.0)
65+
lovasz_loss = build_loss(loss_cfg)
66+
lovasz_loss(logits, labels, ignore_index=None)
67+
tmp_file.close()
68+
os.remove(f'{tmp_file.name}.pkl')
69+
os.remove(f'{tmp_file.name}.npy')
70+
4171
# test lovasz loss with loss_type = 'binary' and per_image = False
4272
loss_cfg = dict(
4373
type='LovaszLoss',

0 commit comments

Comments
 (0)