Description
Checklist
- I have searched related issues but cannot get the expected help.
- The bug has not been fixed in the latest version.
Describe the bug
When calculating losses in the STDCHead module, seg_logit needs to be resized back to the size of seg_label. The head has an align_corners attribute that controls whether the resize uses align_corners=True or False; it is set on line 71 of "configs/_base_/models/stdc.py":
mmsegmentation/configs/_base_/models/stdc.py
Lines 61 to 72 in 4b905cb
    dict(
        type='STDCHead',
        in_channels=256,
        channels=64,
        num_convs=1,
        num_classes=2,
        boundary_threshold=0.1,
        in_index=0,
        norm_cfg=norm_cfg,
        concat_input=False,
        align_corners=False,
        loss_decode=[
However, the implementation hard-codes align_corners=True on line 87, ignoring the value set in the configuration file.
mmsegmentation/mmseg/models/decode_heads/stdc_head.py
Lines 83 to 89 in 4b905cb
    seg_logit = F.interpolate(
        seg_logit,
        boundary_targets.shape[2:],
        mode='bilinear',
        align_corners=True)
    loss = super(STDCHead, self).losses(seg_logit,
                                        boudary_targets_pyramid.long())
Bug fix
STDCHead.losses already calls BaseDecodeHead.losses (lines 88-89), and BaseDecodeHead.losses itself resizes seg_logit to the size of seg_label using the align_corners value from the configuration (line 239).
mmsegmentation/mmseg/models/decode_heads/stdc_head.py
Lines 88 to 89 in 4b905cb
    loss = super(STDCHead, self).losses(seg_logit,
                                        boudary_targets_pyramid.long())
mmsegmentation/mmseg/models/decode_heads/decode_head.py
Lines 232 to 239 in 4b905cb
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
Therefore, there is no need to interpolate again in STDCHead.losses. Deleting lines 83-87 of "mmseg/models/decode_heads/stdc_head.py" fixes the bug.
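The effect of the proposed fix can be sketched with toy stand-ins. Everything below is hypothetical and only mimics the call structure described in this issue: `resize_1d`, `BaseHead`, `BuggyHead`, `FixedHead`, and the toy L1 loss are not mmsegmentation code. The point is that once the subclass has already resized to the label size, the base-class resize is a no-op, so the configured align_corners never takes effect.

```python
def resize_1d(values, size, align_corners):
    """Linearly resample a 1-D list to `size` points (illustrative helper)."""
    n = len(values)
    out = []
    for i in range(size):
        if align_corners:
            src = i * (n - 1) / (size - 1) if size > 1 else 0.0
        else:
            src = max((i + 0.5) * n / size - 0.5, 0.0)
        lo = min(int(src), n - 1)
        hi = min(lo + 1, n - 1)
        w = src - lo
        out.append((1 - w) * values[lo] + w * values[hi])
    return out


class BaseHead:
    """Stand-in for the base head: resizes with the configured flag."""

    def __init__(self, align_corners):
        self.align_corners = align_corners

    def losses(self, logit, label):
        logit = resize_1d(logit, len(label), self.align_corners)
        # Toy mean absolute error, used only to make the difference visible.
        return sum(abs(p - t) for p, t in zip(logit, label)) / len(label)


class BuggyHead(BaseHead):
    """Resizes with a hard-coded True before delegating (the reported bug)."""

    def losses(self, logit, label):
        logit = resize_1d(logit, len(label), True)  # ignores self.align_corners
        return super().losses(logit, label)


class FixedHead(BaseHead):
    """Delegates directly, so the base class applies the configured flag."""

    def losses(self, logit, label):
        return super().losses(logit, label)


logit, label = [0.0, 1.0], [0.0, 0.25, 0.75, 1.0]
buggy_false = BuggyHead(align_corners=False).losses(logit, label)
buggy_true = BuggyHead(align_corners=True).losses(logit, label)
fixed_false = FixedHead(align_corners=False).losses(logit, label)
fixed_true = FixedHead(align_corners=True).losses(logit, label)
print(buggy_false == buggy_true)  # True: the config value has no effect
print(fixed_false == fixed_true)  # False: the config value is honoured
```

In the buggy variant the loss is identical for both settings; after removing the extra resize, the two settings give different losses, which is the behaviour the configuration file implies.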