Skip to content

Commit 796d5ed

Browse files
authored
[Fix] Fix bug when loading class name form file in custom dataset (open-mmlab#923)
* [Fix] open-mmlab#916 expection string type classes * add unittests for string path classes * fix double quote string in test_dataset.py * move the import to the top of the file * fix isort lint error fix isort lint error when move the import to the top of the file
1 parent 8d49dd3 commit 796d5ed

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

mmseg/datasets/custom.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def get_classes_and_palette(self, classes=None, palette=None):
319319
raise ValueError(f'Unsupported type {type(classes)} of classes.')
320320

321321
if self.CLASSES:
322-
if not set(classes).issubset(self.CLASSES):
322+
if not set(class_names).issubset(self.CLASSES):
323323
raise ValueError('classes is not a subset of CLASSES.')
324324

325325
# dictionary, its keys are the old label ids and its values
@@ -330,7 +330,7 @@ def get_classes_and_palette(self, classes=None, palette=None):
330330
if c not in class_names:
331331
self.label_map[i] = -1
332332
else:
333-
self.label_map[i] = classes.index(c)
333+
self.label_map[i] = class_names.index(c)
334334

335335
palette = self.get_palette_for_custom_classes(class_names, palette)
336336

tests/test_data/test_dataset.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
23
import os.path as osp
34
import shutil
5+
import tempfile
46
from typing import Generator
57
from unittest.mock import MagicMock, patch
68

@@ -26,6 +28,37 @@ def test_classes():
2628
get_classes('unsupported')
2729

2830

31+
def test_classes_file_path():
32+
tmp_file = tempfile.NamedTemporaryFile()
33+
classes_path = f'{tmp_file.name}.txt'
34+
train_pipeline = [dict(type='LoadImageFromFile')]
35+
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
36+
37+
# classes.txt with full categories
38+
categories = get_classes('cityscapes')
39+
with open(classes_path, 'w') as f:
40+
f.write('\n'.join(categories))
41+
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
42+
43+
# classes.txt with sub categories
44+
categories = ['road', 'sidewalk', 'building']
45+
with open(classes_path, 'w') as f:
46+
f.write('\n'.join(categories))
47+
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
48+
49+
# classes.txt with unknown categories
50+
categories = ['road', 'sidewalk', 'unknown']
51+
with open(classes_path, 'w') as f:
52+
f.write('\n'.join(categories))
53+
54+
with pytest.raises(ValueError):
55+
CityscapesDataset(**kwargs)
56+
57+
tmp_file.close()
58+
os.remove(classes_path)
59+
assert not osp.exists(classes_path)
60+
61+
2962
def test_palette():
3063
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
3164
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(

0 commit comments

Comments
 (0)