@@ -60,6 +60,10 @@ class CustomDataset(Dataset):
60
60
Default: False
61
61
classes (str | Sequence[str], optional): Specify classes to load.
62
62
If is None, ``cls.CLASSES`` will be used. Default: None.
63
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
64
+ The palette of segmentation map. If None is given, and
65
+ self.PALETTE is None, random palette will be generated.
66
+ Default: None
63
67
"""
64
68
65
69
CLASSES = None
@@ -77,7 +81,8 @@ def __init__(self,
77
81
test_mode = False ,
78
82
ignore_index = 255 ,
79
83
reduce_zero_label = False ,
80
- classes = None ):
84
+ classes = None ,
85
+ palette = None ):
81
86
self .pipeline = Compose (pipeline )
82
87
self .img_dir = img_dir
83
88
self .img_suffix = img_suffix
@@ -89,7 +94,8 @@ def __init__(self,
89
94
self .ignore_index = ignore_index
90
95
self .reduce_zero_label = reduce_zero_label
91
96
self .label_map = None
92
- self .CLASSES , self .PALETTE = self .get_classes_and_palette (classes )
97
+ self .CLASSES , self .PALETTE = self .get_classes_and_palette (
98
+ classes , palette )
93
99
94
100
# join paths if data_root is specified
95
101
if self .data_root is not None :
@@ -241,7 +247,7 @@ def get_gt_seg_maps(self):
241
247
242
248
return gt_seg_maps
243
249
244
- def get_classes_and_palette (self , classes = None ):
250
+ def get_classes_and_palette (self , classes = None , palette = None ):
245
251
"""Get class names of current dataset.
246
252
247
253
Args:
@@ -250,6 +256,9 @@ def get_classes_and_palette(self, classes=None):
250
256
string, take it as a file name. The file contains the name of
251
257
classes where each line contains one class name. If classes is
252
258
a tuple or list, override the CLASSES defined by the dataset.
259
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
260
+ The palette of segmentation map. If None is given, random
261
+ palette will be generated. Default: None
253
262
"""
254
263
if classes is None :
255
264
self .custom_classes = False
@@ -278,11 +287,11 @@ def get_classes_and_palette(self, classes=None):
278
287
else :
279
288
self .label_map [i ] = classes .index (c )
280
289
281
- palette = self .get_palette_for_custom_classes ()
290
+ palette = self .get_palette_for_custom_classes (class_names , palette )
282
291
283
292
return class_names , palette
284
293
285
- def get_palette_for_custom_classes (self ):
294
+ def get_palette_for_custom_classes (self , class_names , palette = None ):
286
295
287
296
if self .label_map is not None :
288
297
# return subset of palette
@@ -293,8 +302,11 @@ def get_palette_for_custom_classes(self):
293
302
palette .append (self .PALETTE [old_id ])
294
303
palette = type (self .PALETTE )(palette )
295
304
296
- else :
297
- palette = self .PALETTE
305
+ elif palette is None :
306
+ if self .PALETTE is None :
307
+ palette = np .random .randint (0 , 255 , size = (len (class_names ), 3 ))
308
+ else :
309
+ palette = self .PALETTE
298
310
299
311
return palette
300
312
0 commit comments