Skip to content

Commit 37af545

Browse files
[Fix] Fix inference api and support setting palette to SegLocalVisualizer (open-mmlab#2475)
as title Co-authored-by: MengzhangLI <[email protected]>
1 parent 7fc8ca0 commit 37af545

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

mmseg/apis/inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ def init_model(config: Union[str, Path, Config],
9393
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
9494

9595
cfg = model.cfg
96-
if dict(type='LoadAnnotations') in cfg.test_pipeline:
97-
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
96+
for t in cfg.test_pipeline:
97+
if t.get('type') == 'LoadAnnotations':
98+
cfg.test_pipeline.remove(t)
9899

99100
is_batch = True
100101
if not isinstance(imgs, (list, tuple)):

mmseg/visualization/local_visualizer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Dict, List, Optional, Tuple
2+
from typing import Dict, List, Optional, Tuple, Union
33

44
import mmcv
55
import numpy as np
@@ -9,6 +9,7 @@
99

1010
from mmseg.registry import VISUALIZERS
1111
from mmseg.structures import SegDataSample
12+
from mmseg.utils import get_classes, get_palette
1213

1314

1415
@VISUALIZERS.register_module()
@@ -55,14 +56,23 @@ def __init__(self,
5556
image: Optional[np.ndarray] = None,
5657
vis_backends: Optional[Dict] = None,
5758
save_dir: Optional[str] = None,
59+
palette: Optional[Union[str, List]] = None,
60+
classes: Optional[Union[str, List]] = None,
61+
dataset_name: Optional[str] = None,
5862
alpha: float = 0.8,
5963
**kwargs):
6064
super().__init__(name, image, vis_backends, save_dir, **kwargs)
61-
self.alpha = alpha
65+
self.alpha: float = alpha
6266
# Set default value. When calling
6367
# `SegLocalVisualizer().dataset_meta=xxx`,
6468
# it will override the default value.
65-
self.dataset_meta = {}
69+
if dataset_name is None:
70+
dataset_name = 'cityscapes'
71+
classes = classes if classes else get_classes(dataset_name)
72+
palette = palette if palette else get_palette(dataset_name)
73+
assert len(classes) == len(
74+
palette), 'The length of classes should be equal to palette'
75+
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
6676

6777
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
6878
classes: Optional[Tuple[str]],

0 commit comments

Comments
 (0)