Skip to content

Commit 118fd9d

Browse files
daavooxvjiarui
andauthored
Support custom palette (open-mmlab#157)
* Fix split * Update tests/test_data/test_dataset.py Co-authored-by: Jerry Jiarui XU <[email protected]> Co-authored-by: Jerry Jiarui XU <[email protected]>
1 parent f705071 commit 118fd9d

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

mmseg/datasets/custom.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class CustomDataset(Dataset):
6060
Default: False
6161
classes (str | Sequence[str], optional): Specify classes to load.
6262
If is None, ``cls.CLASSES`` will be used. Default: None.
63+
palette (Sequence[Sequence[int]]] | np.ndarray | None):
64+
The palette of segmentation map. If None is given, and
65+
self.PALETTE is None, random palette will be generated.
66+
Default: None
6367
"""
6468

6569
CLASSES = None
@@ -77,7 +81,8 @@ def __init__(self,
7781
test_mode=False,
7882
ignore_index=255,
7983
reduce_zero_label=False,
80-
classes=None):
84+
classes=None,
85+
palette=None):
8186
self.pipeline = Compose(pipeline)
8287
self.img_dir = img_dir
8388
self.img_suffix = img_suffix
@@ -89,7 +94,8 @@ def __init__(self,
8994
self.ignore_index = ignore_index
9095
self.reduce_zero_label = reduce_zero_label
9196
self.label_map = None
92-
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
97+
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
98+
classes, palette)
9399

94100
# join paths if data_root is specified
95101
if self.data_root is not None:
@@ -241,7 +247,7 @@ def get_gt_seg_maps(self):
241247

242248
return gt_seg_maps
243249

244-
def get_classes_and_palette(self, classes=None):
250+
def get_classes_and_palette(self, classes=None, palette=None):
245251
"""Get class names of current dataset.
246252
247253
Args:
@@ -250,6 +256,9 @@ def get_classes_and_palette(self, classes=None):
250256
string, take it as a file name. The file contains the name of
251257
classes where each line contains one class name. If classes is
252258
a tuple or list, override the CLASSES defined by the dataset.
259+
palette (Sequence[Sequence[int]]] | np.ndarray | None):
260+
The palette of segmentation map. If None is given, random
261+
palette will be generated. Default: None
253262
"""
254263
if classes is None:
255264
self.custom_classes = False
@@ -278,11 +287,11 @@ def get_classes_and_palette(self, classes=None):
278287
else:
279288
self.label_map[i] = classes.index(c)
280289

281-
palette = self.get_palette_for_custom_classes()
290+
palette = self.get_palette_for_custom_classes(class_names, palette)
282291

283292
return class_names, palette
284293

285-
def get_palette_for_custom_classes(self):
294+
def get_palette_for_custom_classes(self, class_names, palette=None):
286295

287296
if self.label_map is not None:
288297
# return subset of palette
@@ -293,8 +302,11 @@ def get_palette_for_custom_classes(self):
293302
palette.append(self.PALETTE[old_id])
294303
palette = type(self.PALETTE)(palette)
295304

296-
else:
297-
palette = self.PALETTE
305+
elif palette is None:
306+
if self.PALETTE is None:
307+
palette = np.random.randint(0, 255, size=(len(class_names), 3))
308+
else:
309+
palette = self.PALETTE
298310

299311
return palette
300312

tests/test_data/test_dataset.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,33 @@ def test_custom_classes_override_default(dataset, classes):
231231
test_mode=True)
232232

233233
assert custom_dataset.CLASSES == original_classes
234+
235+
236+
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
237+
@patch('mmseg.datasets.CustomDataset.__getitem__',
238+
MagicMock(side_effect=lambda idx: idx))
239+
def test_custom_dataset_random_palette_is_generated():
240+
dataset = CustomDataset(
241+
pipeline=[],
242+
img_dir=MagicMock(),
243+
split=MagicMock(),
244+
classes=('bus', 'car'),
245+
test_mode=True)
246+
assert len(dataset.PALETTE) == 2
247+
for class_color in dataset.PALETTE:
248+
assert len(class_color) == 3
249+
assert all(x >= 0 and x <= 255 for x in class_color)
250+
251+
252+
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
253+
@patch('mmseg.datasets.CustomDataset.__getitem__',
254+
MagicMock(side_effect=lambda idx: idx))
255+
def test_custom_dataset_custom_palette():
256+
dataset = CustomDataset(
257+
pipeline=[],
258+
img_dir=MagicMock(),
259+
split=MagicMock(),
260+
classes=('bus', 'car'),
261+
palette=[[100, 100, 100], [200, 200, 200]],
262+
test_mode=True)
263+
assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])

0 commit comments

Comments
 (0)