Skip to content

Commit 061a295

Browse files
authored
Support resize data augmentation according to original image size (open-mmlab#291)
* Support resize data augmentation according to original image size (img_scale=None and ratio_range is tuple) * fix docstring * fix bug * add unittest * img_scale=None in TTA * fix bug * add unittest * fix typos * fix bug
1 parent 7970e0f commit 061a295

File tree

4 files changed

+199
-17
lines changed

4 files changed

+199
-17
lines changed

mmseg/datasets/pipelines/test_time_aug.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class MultiScaleFlipAug(object):
4141
4242
Args:
4343
transforms (list[dict]): Transforms to apply in each augmentation.
44-
img_scale (tuple | list[tuple]): Images scales for resizing.
44+
img_scale (None | tuple | list[tuple]): Images scales for resizing.
4545
img_ratios (float | list[float]): Image ratios for resizing
4646
flip (bool): Whether apply flip augmentation. Default: False.
4747
flip_direction (str | list[str]): Flip augmentation directions,
@@ -58,20 +58,27 @@ def __init__(self,
5858
flip_direction='horizontal'):
5959
self.transforms = Compose(transforms)
6060
if img_ratios is not None:
61-
# mode 1: given a scale and a range of image ratio
6261
img_ratios = img_ratios if isinstance(img_ratios,
6362
list) else [img_ratios]
6463
assert mmcv.is_list_of(img_ratios, float)
65-
assert isinstance(img_scale, tuple) and len(img_scale) == 2
64+
if img_scale is None:
65+
# mode 1: given img_scale=None and a range of image ratio
66+
self.img_scale = None
67+
assert mmcv.is_list_of(img_ratios, float)
68+
elif isinstance(img_scale, tuple) and mmcv.is_list_of(
69+
img_ratios, float):
70+
assert len(img_scale) == 2
71+
# mode 2: given a scale and a range of image ratio
6672
self.img_scale = [(int(img_scale[0] * ratio),
6773
int(img_scale[1] * ratio))
6874
for ratio in img_ratios]
6975
else:
70-
# mode 2: given multiple scales
76+
# mode 3: given multiple scales
7177
self.img_scale = img_scale if isinstance(img_scale,
7278
list) else [img_scale]
73-
assert mmcv.is_list_of(self.img_scale, tuple)
79+
assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
7480
self.flip = flip
81+
self.img_ratios = img_ratios
7582
self.flip_direction = flip_direction if isinstance(
7683
flip_direction, list) else [flip_direction]
7784
assert mmcv.is_list_of(self.flip_direction, str)
@@ -95,8 +102,14 @@ def __call__(self, results):
95102
"""
96103

97104
aug_data = []
105+
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
106+
h, w = results['img'].shape[:2]
107+
img_scale = [(int(h * ratio), int(w * ratio))
108+
for ratio in self.img_ratios]
109+
else:
110+
img_scale = self.img_scale
98111
flip_aug = [False, True] if self.flip else [False]
99-
for scale in self.img_scale:
112+
for scale in img_scale:
100113
for flip in flip_aug:
101114
for direction in self.flip_direction:
102115
_results = results.copy()

mmseg/datasets/pipelines/transforms.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@ class Resize(object):
1414
contains the key "scale", then the scale in the input dict is used,
1515
otherwise the specified scale in the init method is used.
1616
17-
``img_scale`` can either be a tuple (single-scale) or a list of tuple
18-
(multi-scale). There are 3 multiscale modes:
17+
``img_scale`` can be None, a tuple (single-scale) or a list of tuple
18+
(multi-scale). There are 4 multiscale modes:
1919
20-
- ``ratio_range is not None``: randomly sample a ratio from the ratio range
21-
and multiply it with the image scale.
20+
- ``ratio_range is not None``:
21+
1. When img_scale is None, img_scale is the shape of image in results
22+
(img_scale = results['img'].shape[:2]) and the image is resized based
23+
on the original size. (mode 1)
24+
2. When img_scale is a tuple (single-scale), randomly sample a ratio from
25+
the ratio range and multiply it with the image scale. (mode 2)
2226
2327
- ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
24-
scale from the a range.
28+
scale from a range. (mode 3)
2529
2630
- ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
27-
scale from multiple scales.
31+
scale from multiple scales. (mode 4)
2832
2933
Args:
3034
img_scale (tuple or list[tuple]): Images scales for resizing.
@@ -49,10 +53,11 @@ def __init__(self,
4953
assert mmcv.is_list_of(self.img_scale, tuple)
5054

5155
if ratio_range is not None:
52-
# mode 1: given a scale and a range of image ratio
53-
assert len(self.img_scale) == 1
56+
# mode 1: given img_scale=None and a range of image ratio
57+
# mode 2: given a scale and a range of image ratio
58+
assert self.img_scale is None or len(self.img_scale) == 1
5459
else:
55-
# mode 2: given multiple scales or a range of scales
60+
# mode 3 and 4: given multiple scales or a range of scales
5661
assert multiscale_mode in ['value', 'range']
5762

5863
self.multiscale_mode = multiscale_mode
@@ -150,8 +155,12 @@ def _random_scale(self, results):
150155
"""
151156

152157
if self.ratio_range is not None:
153-
scale, scale_idx = self.random_sample_ratio(
154-
self.img_scale[0], self.ratio_range)
158+
if self.img_scale is None:
159+
scale, scale_idx = self.random_sample_ratio(
160+
results['img'].shape[:2], self.ratio_range)
161+
else:
162+
scale, scale_idx = self.random_sample_ratio(
163+
self.img_scale[0], self.ratio_range)
155164
elif len(self.img_scale) == 1:
156165
scale, scale_idx = self.img_scale[0], 0
157166
elif self.multiscale_mode == 'range':

tests/test_data/test_transform.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_resize():
3838
resize_module = build_from_cfg(transform, PIPELINES)
3939

4040
results = dict()
41+
# (288, 512, 3)
4142
img = mmcv.imread(
4243
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
4344
results['img'] = img
@@ -92,6 +93,15 @@ def test_resize():
9293
resized_results = resize_module(results.copy())
9394
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1
9495

96+
# test img_scale=None and ratio_range is tuple.
97+
# img shape: (288, 512, 3)
98+
transform = dict(
99+
type='Resize', img_scale=None, ratio_range=(0.5, 2.0), keep_ratio=True)
100+
resize_module = build_from_cfg(transform, PIPELINES)
101+
resized_results = resize_module(results.copy())
102+
assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0
103+
assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0
104+
95105

96106
def test_flip():
97107
# test assertion for invalid prob

tests/test_data/test_tta.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import os.path as osp
2+
3+
import mmcv
4+
import pytest
5+
from mmcv.utils import build_from_cfg
6+
7+
from mmseg.datasets.builder import PIPELINES
8+
9+
10+
def test_multi_scale_flip_aug():
11+
# test assertion if img_scale=None, img_ratios=1 (not float).
12+
with pytest.raises(AssertionError):
13+
tta_transform = dict(
14+
type='MultiScaleFlipAug',
15+
img_scale=None,
16+
img_ratios=1,
17+
transforms=[dict(type='Resize', keep_ratio=False)],
18+
)
19+
build_from_cfg(tta_transform, PIPELINES)
20+
21+
# test assertion if img_scale=None, img_ratios=None.
22+
with pytest.raises(AssertionError):
23+
tta_transform = dict(
24+
type='MultiScaleFlipAug',
25+
img_scale=None,
26+
img_ratios=None,
27+
transforms=[dict(type='Resize', keep_ratio=False)],
28+
)
29+
build_from_cfg(tta_transform, PIPELINES)
30+
31+
# test assertion if img_scale=(512, 512), img_ratios=1 (not float).
32+
with pytest.raises(AssertionError):
33+
tta_transform = dict(
34+
type='MultiScaleFlipAug',
35+
img_scale=(512, 512),
36+
img_ratios=1,
37+
transforms=[dict(type='Resize', keep_ratio=False)],
38+
)
39+
build_from_cfg(tta_transform, PIPELINES)
40+
41+
tta_transform = dict(
42+
type='MultiScaleFlipAug',
43+
img_scale=(512, 512),
44+
img_ratios=[0.5, 1.0, 2.0],
45+
flip=False,
46+
transforms=[dict(type='Resize', keep_ratio=False)],
47+
)
48+
tta_module = build_from_cfg(tta_transform, PIPELINES)
49+
50+
results = dict()
51+
# (288, 512, 3)
52+
img = mmcv.imread(
53+
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
54+
results['img'] = img
55+
results['img_shape'] = img.shape
56+
results['ori_shape'] = img.shape
57+
# Set initial values for default meta_keys
58+
results['pad_shape'] = img.shape
59+
results['scale_factor'] = 1.0
60+
61+
tta_results = tta_module(results.copy())
62+
assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
63+
assert tta_results['flip'] == [False, False, False]
64+
65+
tta_transform = dict(
66+
type='MultiScaleFlipAug',
67+
img_scale=(512, 512),
68+
img_ratios=[0.5, 1.0, 2.0],
69+
flip=True,
70+
transforms=[dict(type='Resize', keep_ratio=False)],
71+
)
72+
tta_module = build_from_cfg(tta_transform, PIPELINES)
73+
tta_results = tta_module(results.copy())
74+
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
75+
(512, 512), (1024, 1024), (1024, 1024)]
76+
assert tta_results['flip'] == [False, True, False, True, False, True]
77+
78+
tta_transform = dict(
79+
type='MultiScaleFlipAug',
80+
img_scale=(512, 512),
81+
img_ratios=1.0,
82+
flip=False,
83+
transforms=[dict(type='Resize', keep_ratio=False)],
84+
)
85+
tta_module = build_from_cfg(tta_transform, PIPELINES)
86+
tta_results = tta_module(results.copy())
87+
assert tta_results['scale'] == [(512, 512)]
88+
assert tta_results['flip'] == [False]
89+
90+
tta_transform = dict(
91+
type='MultiScaleFlipAug',
92+
img_scale=(512, 512),
93+
img_ratios=1.0,
94+
flip=True,
95+
transforms=[dict(type='Resize', keep_ratio=False)],
96+
)
97+
tta_module = build_from_cfg(tta_transform, PIPELINES)
98+
tta_results = tta_module(results.copy())
99+
assert tta_results['scale'] == [(512, 512), (512, 512)]
100+
assert tta_results['flip'] == [False, True]
101+
102+
tta_transform = dict(
103+
type='MultiScaleFlipAug',
104+
img_scale=None,
105+
img_ratios=[0.5, 1.0, 2.0],
106+
flip=False,
107+
transforms=[dict(type='Resize', keep_ratio=False)],
108+
)
109+
tta_module = build_from_cfg(tta_transform, PIPELINES)
110+
tta_results = tta_module(results.copy())
111+
assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)]
112+
assert tta_results['flip'] == [False, False, False]
113+
114+
tta_transform = dict(
115+
type='MultiScaleFlipAug',
116+
img_scale=None,
117+
img_ratios=[0.5, 1.0, 2.0],
118+
flip=True,
119+
transforms=[dict(type='Resize', keep_ratio=False)],
120+
)
121+
tta_module = build_from_cfg(tta_transform, PIPELINES)
122+
tta_results = tta_module(results.copy())
123+
assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512),
124+
(288, 512), (576, 1024), (576, 1024)]
125+
assert tta_results['flip'] == [False, True, False, True, False, True]
126+
127+
tta_transform = dict(
128+
type='MultiScaleFlipAug',
129+
img_scale=[(256, 256), (512, 512), (1024, 1024)],
130+
img_ratios=None,
131+
flip=False,
132+
transforms=[dict(type='Resize', keep_ratio=False)],
133+
)
134+
tta_module = build_from_cfg(tta_transform, PIPELINES)
135+
tta_results = tta_module(results.copy())
136+
assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
137+
assert tta_results['flip'] == [False, False, False]
138+
139+
tta_transform = dict(
140+
type='MultiScaleFlipAug',
141+
img_scale=[(256, 256), (512, 512), (1024, 1024)],
142+
img_ratios=None,
143+
flip=True,
144+
transforms=[dict(type='Resize', keep_ratio=False)],
145+
)
146+
tta_module = build_from_cfg(tta_transform, PIPELINES)
147+
tta_results = tta_module(results.copy())
148+
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
149+
(512, 512), (1024, 1024), (1024, 1024)]
150+
assert tta_results['flip'] == [False, True, False, True, False, True]

0 commit comments

Comments
 (0)