Skip to content

Commit 5d49918

Browse files
authored
[Fix] Switch order of reduce_zero_label and applying label_map (open-mmlab#2500)
## Motivation I want to fix a bug through this PR. The bug occurs when two options -- `reduce_zero_label=True`, and custom classes are used. `reduce_zero_label` remaps the GT seg labels by remapping the zero-class to 255 which is ignored. Conceptually, this should occur *before* the `label_map` is applied, which maps *already reduced labels*. However, currently, the `label_map` is applied before the zero label is reduced. ## Modification The modification is simple: - I've just interchanged the order of the two operations by moving 4 lines from bottom to top. - I've added a test that passes when the fix is introduced, and fails on the original `master` branch. ## BC-breaking (Optional) I do not anticipate this change braking any backward-compatibility. ## Checklist - [x] Pre-commit or other linting tools are used to fix the potential lint issues. - _I've fixed all linting/pre-commit errors._ - [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. - _I've added a unit test._ - [x] If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. - _I don't think this change affects MMDet or MMDet3D._ - [x] The documentation has been modified accordingly, like docstring or example tutorials. - _This change fixes an existing bug and doesn't require modifying any documentation/docstring._
1 parent 6cb7fe0 commit 5d49918

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

mmseg/core/evaluation/metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ def intersect_and_union(pred_label,
6363
else:
6464
label = torch.from_numpy(label)
6565

66-
if label_map is not None:
67-
label_copy = label.clone()
68-
for old_id, new_id in label_map.items():
69-
label[label_copy == old_id] = new_id
7066
if reduce_zero_label:
7167
label[label == 0] = 255
7268
label = label - 1
7369
label[label == 254] = 255
70+
if label_map is not None:
71+
label_copy = label.clone()
72+
for old_id, new_id in label_map.items():
73+
label[label_copy == old_id] = new_id
7474

7575
mask = (label != ignore_index)
7676
pred_label = pred_label[mask]

mmseg/datasets/pipelines/loading.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def __call__(self, results):
133133
gt_semantic_seg = mmcv.imfrombytes(
134134
img_bytes, flag='unchanged',
135135
backend=self.imdecode_backend).squeeze().astype(np.uint8)
136+
# reduce zero_label
137+
if self.reduce_zero_label:
138+
# avoid using underflow conversion
139+
gt_semantic_seg[gt_semantic_seg == 0] = 255
140+
gt_semantic_seg = gt_semantic_seg - 1
141+
gt_semantic_seg[gt_semantic_seg == 254] = 255
136142
# modify if custom classes
137143
if results.get('label_map', None) is not None:
138144
# Add deep copy to solve bug of repeatedly
@@ -141,12 +147,6 @@ def __call__(self, results):
141147
gt_semantic_seg_copy = gt_semantic_seg.copy()
142148
for old_id, new_id in results['label_map'].items():
143149
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
144-
# reduce zero_label
145-
if self.reduce_zero_label:
146-
# avoid using underflow conversion
147-
gt_semantic_seg[gt_semantic_seg == 0] = 255
148-
gt_semantic_seg = gt_semantic_seg - 1
149-
gt_semantic_seg[gt_semantic_seg == 254] = 255
150150
results['gt_semantic_seg'] = gt_semantic_seg
151151
results['seg_fields'].append('gt_semantic_seg')
152152
return results

tests/test_data/test_loading.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,61 @@ def test_load_seg_custom_classes(self):
177177
assert gt_array.dtype == np.uint8
178178
np.testing.assert_array_equal(gt_array, true_mask)
179179

180+
# test with removing a class and reducing zero label simultaneously
181+
results = dict(
182+
img_info=dict(filename=img_path),
183+
ann_info=dict(seg_map=gt_path),
184+
# since reduce_zero_label is True, there are only 4 real classes.
185+
# if the full set of classes is ["A", "B", "C", "D"], the
186+
# following label map simulates the dataset option
187+
# classes=["A", "C", "D"] which removes class "B".
188+
label_map={
189+
0: 0,
190+
1: -1, # simulate removing class 1
191+
2: 1,
192+
3: 2
193+
},
194+
seg_fields=[])
195+
196+
load_imgs = LoadImageFromFile()
197+
results = load_imgs(copy.deepcopy(results))
198+
199+
# reduce zero label
200+
load_anns = LoadAnnotations(reduce_zero_label=True)
201+
results = load_anns(copy.deepcopy(results))
202+
203+
gt_array = results['gt_semantic_seg']
204+
205+
true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255
206+
true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0
207+
true_mask[2:4, 6:8] = -1 # 2s are reduced to class 1 which is removed
208+
true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1
209+
true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2
210+
211+
assert results['seg_fields'] == ['gt_semantic_seg']
212+
assert gt_array.shape == (10, 10)
213+
assert gt_array.dtype == np.uint8
214+
np.testing.assert_array_equal(gt_array, true_mask)
215+
216+
# test no custom classes
217+
results = dict(
218+
img_info=dict(filename=img_path),
219+
ann_info=dict(seg_map=gt_path),
220+
seg_fields=[])
221+
222+
load_imgs = LoadImageFromFile()
223+
results = load_imgs(copy.deepcopy(results))
224+
225+
load_anns = LoadAnnotations()
226+
results = load_anns(copy.deepcopy(results))
227+
228+
gt_array = results['gt_semantic_seg']
229+
230+
assert results['seg_fields'] == ['gt_semantic_seg']
231+
assert gt_array.shape == (10, 10)
232+
assert gt_array.dtype == np.uint8
233+
np.testing.assert_array_equal(gt_array, test_gt)
234+
180235
# test no custom classes
181236
results = dict(
182237
img_info=dict(filename=img_path),

0 commit comments

Comments
 (0)