Skip to content

Commit 6f43f4d

Browse files
authored
[Enchance] support infererence with padding (open-mmlab#1607)
* [Enchance] support infererence with padding * limite pad after flip when inference * add test code
1 parent 2dede04 commit 6f43f4d

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

mmseg/datasets/pipelines/test_time_aug.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def __init__(self,
5757
img_ratios=None,
5858
flip=False,
5959
flip_direction='horizontal'):
60+
if flip:
61+
trans_index = {
62+
key['type']: index
63+
for index, key in enumerate(transforms)
64+
}
65+
if 'RandomFlip' in trans_index and 'Pad' in trans_index:
66+
assert trans_index['RandomFlip'] < trans_index['Pad'], \
67+
'Pad must be executed after RandomFlip when flip is True'
6068
self.transforms = Compose(transforms)
6169
if img_ratios is not None:
6270
img_ratios = img_ratios if isinstance(img_ratios,

mmseg/models/segmentors/encoder_decoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ def slide_inference(self, img, img_meta, rescale):
189189
count_mat.cpu().detach().numpy()).to(device=img.device)
190190
preds = preds / count_mat
191191
if rescale:
192+
# remove padding area
193+
resize_shape = img_meta[0]['img_shape'][:2]
194+
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
192195
preds = resize(
193196
preds,
194197
size=img_meta[0]['ori_shape'][:2],
@@ -206,6 +209,9 @@ def whole_inference(self, img, img_meta, rescale):
206209
if torch.onnx.is_in_onnx_export():
207210
size = img.shape[2:]
208211
else:
212+
# remove padding area
213+
resize_shape = img_meta[0]['img_shape'][:2]
214+
seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]]
209215
size = img_meta[0]['ori_shape'][:2]
210216
seg_logit = resize(
211217
seg_logit,

tests/test_data/test_tta.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,41 @@ def test_multi_scale_flip_aug():
149149
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
150150
(512, 512), (1024, 1024), (1024, 1024)]
151151
assert tta_results['flip'] == [False, True, False, True, False, True]
152+
153+
# test assertion if flip is True and Pad executed before RandomFlip
154+
with pytest.raises(AssertionError):
155+
tta_transform = dict(
156+
type='MultiScaleFlipAug',
157+
img_scale=[(256, 256), (512, 512), (1024, 1024)],
158+
img_ratios=None,
159+
flip=True,
160+
transforms=[
161+
dict(type='Resize', keep_ratio=False),
162+
dict(type='Pad', size_divisor=32),
163+
dict(type='RandomFlip'),
164+
])
165+
tta_module = build_from_cfg(tta_transform, PIPELINES)
166+
167+
tta_transform = dict(
168+
type='MultiScaleFlipAug',
169+
img_scale=[(256, 256), (512, 512), (1024, 1024)],
170+
img_ratios=None,
171+
flip=True,
172+
transforms=[
173+
dict(type='Resize', keep_ratio=True),
174+
dict(type='RandomFlip'),
175+
dict(type='Pad', size_divisor=32),
176+
])
177+
tta_module = build_from_cfg(tta_transform, PIPELINES)
178+
tta_results = tta_module(results.copy())
179+
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
180+
(512, 512), (1024, 1024), (1024, 1024)]
181+
assert tta_results['flip'] == [False, True, False, True, False, True]
182+
assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3),
183+
(288, 512, 3), (288, 512, 3),
184+
(576, 1024, 3), (576, 1024, 3)]
185+
assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3),
186+
(288, 512, 3), (288, 512, 3),
187+
(576, 1024, 3), (576, 1024, 3)]
188+
for i in range(len(tta_results['img'])):
189+
assert tta_results['img'][i].shape == tta_results['pad_shape'][i]

0 commit comments

Comments
 (0)