|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
2 | | -from typing import Dict, List, Optional, Tuple |
| 2 | +from typing import Dict, List, Optional, Tuple, Union |
3 | 3 |
|
4 | 4 | import mmcv |
5 | 5 | import numpy as np |
|
9 | 9 |
|
10 | 10 | from mmseg.registry import VISUALIZERS |
11 | 11 | from mmseg.structures import SegDataSample |
| 12 | +from mmseg.utils import get_classes, get_palette |
12 | 13 |
|
13 | 14 |
|
14 | 15 | @VISUALIZERS.register_module() |
@@ -55,14 +56,23 @@ def __init__(self, |
55 | 56 | image: Optional[np.ndarray] = None, |
56 | 57 | vis_backends: Optional[Dict] = None, |
57 | 58 | save_dir: Optional[str] = None, |
| 59 | + palette: Optional[Union[str, List]] = None, |
| 60 | + classes: Optional[Union[str, List]] = None, |
| 61 | + dataset_name: Optional[str] = None, |
58 | 62 | alpha: float = 0.8, |
59 | 63 | **kwargs): |
60 | 64 | super().__init__(name, image, vis_backends, save_dir, **kwargs) |
61 | | - self.alpha = alpha |
| 65 | + self.alpha: float = alpha |
62 | 66 | # Set default value. When calling |
63 | 67 | # `SegLocalVisualizer().dataset_meta=xxx`, |
64 | 68 | # it will override the default value. |
65 | | - self.dataset_meta = {} |
| 69 | + if dataset_name is None: |
| 70 | + dataset_name = 'cityscapes' |
| 71 | + classes = classes if classes else get_classes(dataset_name) |
| 72 | + palette = palette if palette else get_palette(dataset_name) |
| 73 | + assert len(classes) == len( |
| 74 | + palette), 'The length of classes should be equal to palette' |
| 75 | + self.dataset_meta: dict = {'classes': classes, 'palette': palette} |
66 | 76 |
|
67 | 77 | def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, |
68 | 78 | classes: Optional[Tuple[str]], |
|
0 commit comments