Skip to content

Commit 8de0050

Browse files
authored
[Refactor] data flow (#1956)
* [WIP] Refactor data flow * model return * [WIP] Refactor data flow * support data_samples is optional * fix benchmark * fix base * minors * rebase * fix api * ut * fix api inference * comments * docstring * docstring * docstring * fix bug of slide inference * add assert c > 1
1 parent 50546da commit 8de0050

File tree

21 files changed

+536
-535
lines changed

21 files changed

+536
-535
lines changed

mmseg/apis/inference.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import warnings
3+
from collections import defaultdict
34
from pathlib import Path
45
from typing import Optional, Sequence, Union
56

@@ -11,9 +12,9 @@
1112
from mmengine.runner import load_checkpoint
1213
from mmengine.utils import mkdir_or_exist
1314

14-
from mmseg.data import SegDataSample
1515
from mmseg.models import BaseSegmentor
1616
from mmseg.registry import MODELS
17+
from mmseg.structures import SegDataSample
1718
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
1819
from mmseg.visualization import SegLocalVisualizer
1920

@@ -50,7 +51,6 @@ def init_model(config: Union[str, Path, Config],
5051
model = MODELS.build(config.model)
5152
if checkpoint is not None:
5253
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
53-
5454
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
5555
# save the dataset_meta in the model for convenience
5656
if 'dataset_meta' in checkpoint.get('meta', {}):
@@ -108,14 +108,15 @@ def _preprare_data(imgs: ImageType, model: BaseSegmentor):
108108
# a pipeline for each inference
109109
pipeline = Compose(cfg.test_pipeline)
110110

111-
data = []
111+
data = defaultdict(list)
112112
for img in imgs:
113113
if isinstance(img, np.ndarray):
114114
data_ = dict(img=img)
115115
else:
116116
data_ = dict(img_path=img)
117117
data_ = pipeline(data_)
118-
data.append(data_)
118+
data['inputs'].append(data_['inputs'])
119+
data['data_samples'].append(data_['data_samples'])
119120

120121
return data, is_batch
121122

@@ -187,11 +188,12 @@ def show_result_pyplot(model: BaseSegmentor,
187188
save_dir=save_dir,
188189
alpha=opacity)
189190
visualizer.dataset_meta = dict(
190-
classes=model.CLASSES, palette=model.PALETTE)
191+
classes=model.dataset_meta['classes'],
192+
palette=model.dataset_meta['palette'])
191193
visualizer.add_datasample(
192194
name=title,
193195
image=image,
194-
pred_sample=result[0],
196+
data_sample=result[0],
195197
draw_gt=draw_gt,
196198
draw_pred=draw_pred,
197199
wait_time=wait_time,

mmseg/datasets/transforms/formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def transform(self, results: dict) -> dict:
7878
if key in results:
7979
img_meta[key] = results[key]
8080
data_sample.set_metainfo(img_meta)
81-
packed_results['data_sample'] = data_sample
81+
packed_results['data_samples'] = data_sample
8282

8383
return packed_results
8484

mmseg/engine/hooks/visualization_hook.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def __init__(self,
6666
def _after_iter(self,
6767
runner: Runner,
6868
batch_idx: int,
69-
data_batch: Sequence[dict],
69+
data_batch: dict,
7070
outputs: Sequence[SegDataSample],
7171
mode: str = 'val') -> None:
7272
"""Run after every ``self.interval`` validation iterations.
7373
7474
Args:
7575
runner (:obj:`Runner`): The runner of the validation process.
7676
batch_idx (int): The index of the current batch in the val loop.
77-
data_batch (Sequence[dict]): Data from dataloader.
77+
data_batch (dict): Data from dataloader.
7878
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
7979
mode (str): mode (str): Current mode of runner. Defaults to 'val'.
8080
"""
@@ -85,18 +85,16 @@ def _after_iter(self,
8585
self.file_client = FileClient(**self.file_client_args)
8686

8787
if self.every_n_inner_iters(batch_idx, self.interval):
88-
for input_data, output in zip(data_batch, outputs):
89-
img_path = input_data['data_sample'].img_path
88+
for output in outputs:
89+
img_path = output.img_path
9090
img_bytes = self.file_client.get(img_path)
9191
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
9292
window_name = f'{mode}_{osp.basename(img_path)}'
9393

94-
gt_sample = input_data['data_sample']
9594
self._visualizer.add_datasample(
9695
window_name,
9796
img,
98-
gt_sample=gt_sample,
99-
pred_sample=output,
97+
data_sample=output,
10098
show=self.show,
10199
wait_time=self.wait_time,
102100
step=runner.iter)

mmseg/evaluation/metrics/citys_metric.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,24 @@ def __init__(self,
4949
self.to_label_id = to_label_id
5050
self.suffix = suffix
5151

52-
def process(self, data_batch: Sequence[dict],
53-
predictions: Sequence[dict]) -> None:
54-
"""Process one batch of data and predictions.
52+
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
53+
"""Process one batch of data and data_samples.
5554
5655
The processed results should be stored in ``self.results``, which will
5756
be used to computed the metrics when all batches have been processed.
5857
5958
Args:
60-
data_batch (Sequence[dict]): A batch of data from the dataloader.
61-
predictions (Sequence[dict]): A batch of outputs from the model.
59+
data_batch (dict): A batch of data from the dataloader.
60+
data_samples (Sequence[dict]): A batch of outputs from the model.
6261
"""
6362
mkdir_or_exist(self.suffix)
6463

65-
for pred in predictions:
66-
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
64+
for data_sample in data_samples:
65+
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
6766
# results2img
6867
if self.to_label_id:
6968
pred_label = self._convert_to_label_id(pred_label)
70-
basename = osp.splitext(osp.basename(pred['img_path']))[0]
69+
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
7170
png_filename = osp.join(self.suffix, f'{basename}.png')
7271
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
7372
import cityscapesscripts.helpers.labels as CSLabels

mmseg/evaluation/metrics/iou_metric.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,20 @@ def __init__(self,
4747
self.nan_to_num = nan_to_num
4848
self.beta = beta
4949

50-
def process(self, data_batch: Sequence[dict],
51-
predictions: Sequence[dict]) -> None:
52-
"""Process one batch of data and predictions.
50+
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
51+
"""Process one batch of data and data_samples.
5352
5453
The processed results should be stored in ``self.results``, which will
5554
be used to computed the metrics when all batches have been processed.
5655
5756
Args:
58-
data_batch (Sequence[dict]): A batch of data from the dataloader.
59-
predictions (Sequence[dict]): A batch of outputs from the model.
57+
data_batch (dict): A batch of data from the dataloader.
58+
data_samples (Sequence[dict]): A batch of outputs from the model.
6059
"""
6160
num_classes = len(self.dataset_meta['classes'])
62-
for data, pred in zip(data_batch, predictions):
63-
pred_label = pred['pred_sem_seg']['data'].squeeze()
64-
label = data['data_sample']['gt_sem_seg']['data'].squeeze().to(
65-
pred_label)
61+
for data_sample in data_samples:
62+
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
63+
label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
6664
self.results.append(
6765
self.intersect_and_union(pred_label, label, num_classes,
6866
self.ignore_index))

mmseg/models/data_preprocessor.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from numbers import Number
3-
from typing import List, Optional, Sequence, Tuple
3+
from typing import Any, Dict, List, Optional, Sequence
44

55
import torch
66
from mmengine.model import BaseDataPreprocessor
7-
from torch import Tensor
87

98
from mmseg.registry import MODELS
10-
from mmseg.utils import OptSampleList, stack_batch
9+
from mmseg.utils import stack_batch
1110

1211

1312
@MODELS.register_module()
@@ -87,22 +86,20 @@ def __init__(self,
8786
# TODO: support batch augmentations.
8887
self.batch_augments = batch_augments
8988

90-
def forward(self,
91-
data: Sequence[dict],
92-
training: bool = False) -> Tuple[Tensor, OptSampleList]:
89+
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
9390
"""Perform normalization、padding and bgr2rgb conversion based on
9491
``BaseDataPreprocessor``.
9592
9693
Args:
97-
data (Sequence[dict]): data sampled from dataloader.
94+
data (dict): data sampled from dataloader.
9895
training (bool): Whether to enable training time augmentation.
9996
10097
Returns:
101-
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
102-
model input.
98+
Dict: Data in the same format as the model input.
10399
"""
104-
inputs, batch_data_samples = self.collate_data(data)
105-
100+
data = self.cast_data(data) # type: ignore
101+
inputs = data['inputs']
102+
data_samples = data.get('data_samples', None)
106103
# TODO: whether normalize should be after stack_batch
107104
if self.channel_conversion and inputs[0].size(0) == 3:
108105
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
@@ -113,20 +110,23 @@ def forward(self,
113110
inputs = [_input.float() for _input in inputs]
114111

115112
if training:
116-
batch_inputs, batch_data_samples = stack_batch(
113+
assert data_samples is not None, ('During training, ',
114+
'`data_samples` must be define.')
115+
inputs, data_samples = stack_batch(
117116
inputs=inputs,
118-
batch_data_samples=batch_data_samples,
117+
data_samples=data_samples,
119118
size=self.size,
120119
size_divisor=self.size_divisor,
121120
pad_val=self.pad_val,
122121
seg_pad_val=self.seg_pad_val)
123122

124123
if self.batch_augments is not None:
125-
inputs, batch_data_samples = self.batch_augments(
126-
inputs, batch_data_samples)
127-
return batch_inputs, batch_data_samples
124+
inputs, data_samples = self.batch_augments(
125+
inputs, data_samples)
126+
return dict(inputs=inputs, data_samples=data_samples)
128127
else:
129128
assert len(inputs) == 1, (
130129
'Batch inference is not support currently, '
131130
'as the image size might be different in a batch')
132-
return torch.stack(inputs, dim=0), batch_data_samples
131+
return dict(
132+
inputs=torch.stack(inputs, dim=0), data_samples=data_samples)

0 commit comments

Comments
 (0)