
Commit d33af52

fix ut
1 parent d0b35cd commit d33af52

File tree

2 files changed: 26 additions, 24 deletions

mmseg/models/data_preprocessor.py

Lines changed: 24 additions & 23 deletions
@@ -48,9 +48,6 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         rgb_to_bgr (bool): whether to convert image from RGB to RGB.
             Defaults to False.
         batch_augments (list[dict], optional): Batch-level augmentations
-        train_cfg (dict, optional): The padding size config in training, if
-            not specify, will use `size` and `size_divisor` params as default.
-            Defaults to None, only supports keys `size` or `size_divisor`.
         test_cfg (dict, optional): The padding size config in testing, if not
             specify, will use `size` and `size_divisor` params as default.
             Defaults to None, only supports keys `size` or `size_divisor`.
@@ -67,7 +64,6 @@ def __init__(
         bgr_to_rgb: bool = False,
         rgb_to_bgr: bool = False,
         batch_augments: Optional[List[dict]] = None,
-        train_cfg: dict = None,
         test_cfg: dict = None,
     ):
         super().__init__()
@@ -96,10 +92,8 @@ def __init__(
         # TODO: support batch augmentations.
         self.batch_augments = batch_augments

-        # Support different padding methods in training and testing
-        default_size_cfg = dict(size=size, size_divisor=size_divisor)
-        self.train_cfg = train_cfg if train_cfg else default_size_cfg
-        self.test_cfg = test_cfg if test_cfg else default_size_cfg
+        # Support different padding methods in testing
+        self.test_cfg = test_cfg

     def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
         """Perform normalization、padding and bgr2rgb conversion based on
@@ -126,24 +120,31 @@ def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
         if training:
             assert data_samples is not None, ('During training, ',
                                               '`data_samples` must be define.')
+            inputs, data_samples = stack_batch(
+                inputs=inputs,
+                data_samples=data_samples,
+                size=self.size,
+                size_divisor=self.size_divisor,
+                pad_val=self.pad_val,
+                seg_pad_val=self.seg_pad_val)
+
+            if self.batch_augments is not None:
+                inputs, data_samples = self.batch_augments(
+                    inputs, data_samples)
         else:
             assert len(inputs) == 1, (
                 'Batch inference is not support currently, '
                 'as the image size might be different in a batch')
-
-        size_cfg = self.train_cfg if training else self.test_cfg
-        size = size_cfg.get('size', None)
-        size_divisor = size_cfg.get('size_divisor', None)
-
-        inputs, data_samples = stack_batch(
-            inputs=inputs,
-            data_samples=data_samples,
-            size=size,
-            size_divisor=size_divisor,
-            pad_val=self.pad_val,
-            seg_pad_val=self.seg_pad_val)
-
-        if self.batch_augments is not None:
-            inputs, data_samples = self.batch_augments(inputs, data_samples)
+            # pad images when testing
+            if self.test_cfg:
+                inputs, data_samples = stack_batch(
+                    inputs=inputs,
+                    data_samples=data_samples,
+                    size=self.test_cfg.get('size', None),
+                    size_divisor=self.test_cfg.get('size_divisor', None),
+                    pad_val=self.pad_val,
+                    seg_pad_val=self.seg_pad_val)
+            else:
+                inputs = torch.stack(inputs, dim=0)

         return dict(inputs=inputs, data_samples=data_samples)
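
The net effect in `forward` is that training always pads via `stack_batch` with `self.size`/`self.size_divisor`, while testing pads only when `test_cfg` is given and otherwise just stacks the single image. A minimal, self-contained sketch of that test-time branch (assumptions: `pad_to_divisor` and `preprocess_test_inputs` are illustrative stand-ins, not mmseg APIs; only torch is required):

from typing import List, Optional

import torch
import torch.nn.functional as F


def pad_to_divisor(img: torch.Tensor, size_divisor: int,
                   pad_val: float = 0) -> torch.Tensor:
    # Pad a (C, H, W) image on the bottom/right so H and W become
    # multiples of `size_divisor`, roughly what `stack_batch` does.
    h, w = img.shape[-2:]
    pad_h = (size_divisor - h % size_divisor) % size_divisor
    pad_w = (size_divisor - w % size_divisor) % size_divisor
    # F.pad pads the last two dims in the order (left, right, top, bottom).
    return F.pad(img, (0, pad_w, 0, pad_h), value=pad_val)


def preprocess_test_inputs(inputs: List[torch.Tensor],
                           test_cfg: Optional[dict],
                           pad_val: float = 0) -> torch.Tensor:
    # Mirrors the new `else` branch: pad only when `test_cfg` asks for it,
    # otherwise just stack the (single) image.
    assert len(inputs) == 1, 'Batch inference is not supported currently'
    if test_cfg and test_cfg.get('size_divisor') is not None:
        inputs = [pad_to_divisor(i, test_cfg['size_divisor'], pad_val)
                  for i in inputs]
    return torch.stack(inputs, dim=0)


img = torch.zeros(3, 65, 65)
print(preprocess_test_inputs([img], test_cfg=None).shape)                   # (1, 3, 65, 65)
print(preprocess_test_inputs([img], test_cfg=dict(size_divisor=32)).shape)  # (1, 3, 96, 96)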

mmseg/models/segmentors/base.py

Lines changed: 2 additions & 1 deletion
@@ -192,7 +192,8 @@ def postprocess_result(self,
                 'pred_sem_seg':
                 PixelData(**{'data': i_seg_pred}),
                 'gt_sem_seg':
-                PixelData(**{'data': i_gt_sem_seg})
+                PixelData() if only_prediction else PixelData(
+                    **{'data': i_gt_sem_seg})
             })

         return data_samples
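
The `base.py` tweak keeps `postprocess_result` usable when no annotation is attached: with `only_prediction` set, an empty `PixelData` is stored instead of indexing a missing ground-truth map. A tiny illustrative sketch (assumption: the `PixelData` import path shown is for recent mmengine; older versions expose it from `mmengine.data`):

import torch
from mmengine.structures import PixelData  # older mmengine: from mmengine.data import PixelData

only_prediction = True                       # e.g. plain inference, no GT loaded
i_seg_pred = torch.zeros(1, 4, 4, dtype=torch.long)
i_gt_sem_seg = None                          # would be a tensor if GT existed

gt_sem_seg = PixelData() if only_prediction else PixelData(
    **{'data': i_gt_sem_seg})
print(list(gt_sem_seg.keys()))               # [] -- empty placeholder, no crash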
