Skip to content

Commit 3d18775

Browse files
xvjiarui and hkzhang95 authored
[Feature] Add RandomRotate transform (open-mmlab#215)
* add RandomRotate for transforms * change rotation function to mmcv.imrotate * refactor * add unittest * fixed test * fixed docstring * fixed test * add more test * fixed repr * rename to prob * fixed unittest Co-authored-by: hkzhang95 <[email protected]>
1 parent 0d10921 commit 3d18775

File tree

8 files changed

+143
-18
lines changed

8 files changed

+143
-18
lines changed

configs/_base_/datasets/ade20k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
dict(type='LoadAnnotations', reduce_zero_label=True),
1010
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
1111
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12-
dict(type='RandomFlip', flip_ratio=0.5),
12+
dict(type='RandomFlip', prob=0.5),
1313
dict(type='PhotoMetricDistortion'),
1414
dict(type='Normalize', **img_norm_cfg),
1515
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

configs/_base_/datasets/cityscapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
dict(type='LoadAnnotations'),
1010
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
1111
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12-
dict(type='RandomFlip', flip_ratio=0.5),
12+
dict(type='RandomFlip', prob=0.5),
1313
dict(type='PhotoMetricDistortion'),
1414
dict(type='Normalize', **img_norm_cfg),
1515
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

configs/_base_/datasets/cityscapes_769x769.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
dict(type='LoadAnnotations'),
88
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
99
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
10-
dict(type='RandomFlip', flip_ratio=0.5),
10+
dict(type='RandomFlip', prob=0.5),
1111
dict(type='PhotoMetricDistortion'),
1212
dict(type='Normalize', **img_norm_cfg),
1313
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

configs/_base_/datasets/pascal_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
dict(type='LoadAnnotations'),
1313
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
1414
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
15-
dict(type='RandomFlip', flip_ratio=0.5),
15+
dict(type='RandomFlip', prob=0.5),
1616
dict(type='PhotoMetricDistortion'),
1717
dict(type='Normalize', **img_norm_cfg),
1818
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

configs/_base_/datasets/pascal_voc12.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
dict(type='LoadAnnotations'),
1010
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
1111
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12-
dict(type='RandomFlip', flip_ratio=0.5),
12+
dict(type='RandomFlip', prob=0.5),
1313
dict(type='PhotoMetricDistortion'),
1414
dict(type='Normalize', **img_norm_cfg),
1515
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

mmseg/datasets/pipelines/transforms.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import mmcv
22
import numpy as np
3+
from mmcv.utils import deprecated_api_warning
34
from numpy import random
45

56
from ..builder import PIPELINES
@@ -232,16 +233,17 @@ class RandomFlip(object):
232233
method.
233234
234235
Args:
235-
flip_ratio (float, optional): The flipping probability. Default: None.
236+
prob (float, optional): The flipping probability. Default: None.
236237
direction(str, optional): The flipping direction. Options are
237238
'horizontal' and 'vertical'. Default: 'horizontal'.
238239
"""
239240

240-
def __init__(self, flip_ratio=None, direction='horizontal'):
241-
self.flip_ratio = flip_ratio
241+
@deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
242+
def __init__(self, prob=None, direction='horizontal'):
243+
self.prob = prob
242244
self.direction = direction
243-
if flip_ratio is not None:
244-
assert flip_ratio >= 0 and flip_ratio <= 1
245+
if prob is not None:
246+
assert prob >= 0 and prob <= 1
245247
assert direction in ['horizontal', 'vertical']
246248

247249
def __call__(self, results):
@@ -257,7 +259,7 @@ def __call__(self, results):
257259
"""
258260

259261
if 'flip' not in results:
260-
flip = True if np.random.rand() < self.flip_ratio else False
262+
flip = True if np.random.rand() < self.prob else False
261263
results['flip'] = flip
262264
if 'flip_direction' not in results:
263265
results['flip_direction'] = self.direction
@@ -274,7 +276,7 @@ def __call__(self, results):
274276
return results
275277

276278
def __repr__(self):
277-
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
279+
return self.__class__.__name__ + f'(prob={self.prob})'
278280

279281

280282
@PIPELINES.register_module()
@@ -463,6 +465,89 @@ def __repr__(self):
463465
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
464466

465467

468+
@PIPELINES.register_module()
class RandomRotate(object):
    """Randomly rotate the image and its segmentation maps.

    With probability ``prob`` the array stored under ``results['img']`` and
    every map listed in ``results['seg_fields']`` are rotated by one shared
    angle, drawn uniformly from the configured degree range.

    Args:
        prob (float): The rotation probability, must be within [0, 1].
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False

    Raises:
        AssertionError: If ``prob`` is outside [0, 1], a scalar ``degree`` is
            not positive, or a non-scalar ``degree`` does not have exactly
            two elements.
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert 0 <= prob <= 1, f'prob {prob} should be in [0, 1]'
        if isinstance(degree, (float, int)):
            # a single number describes a range symmetric around zero
            assert degree > 0, f'degree {degree} should be positive'
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'
        # Stored under the same name as the constructor argument; the old
        # misspelled ``pal_val`` remains reachable through the property below.
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @property
    def pal_val(self):
        """Misspelled alias of ``pad_val``, kept for backward compatibility."""
        return self.pad_val

    @pal_val.setter
    def pal_val(self, value):
        self.pad_val = value

    def __call__(self, results):
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline. ``'img'`` is
                read and overwritten, as is every key listed under the
                optional ``'seg_fields'`` entry.

        Returns:
            dict: Rotated results (the same dict object, updated in place).
        """

        rotate = np.random.rand() < self.prob
        # The angle is drawn even when no rotation is applied, so every call
        # consumes the same number of random draws.
        degree = np.random.uniform(min(self.degree), max(self.degree))
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pad_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs; nearest-neighbour interpolation is requested so
            # that no blended label values are produced
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pad_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str
550+
466551
@PIPELINES.register_module()
467552
class SegRescale(object):
468553
"""Rescale semantic segmentation maps.

tests/test_data/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_custom_dataset():
6969
dict(type='LoadAnnotations'),
7070
dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
7171
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
72-
dict(type='RandomFlip', flip_ratio=0.5),
72+
dict(type='RandomFlip', prob=0.5),
7373
dict(type='PhotoMetricDistortion'),
7474
dict(type='Normalize', **img_norm_cfg),
7575
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),

tests/test_data/test_transform.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,17 @@ def test_resize():
9494

9595

9696
def test_flip():
97-
# test assertion for invalid flip_ratio
97+
# test assertion for invalid prob
9898
with pytest.raises(AssertionError):
99-
transform = dict(type='RandomFlip', flip_ratio=1.5)
99+
transform = dict(type='RandomFlip', prob=1.5)
100100
build_from_cfg(transform, PIPELINES)
101101

102102
# test assertion for invalid direction
103103
with pytest.raises(AssertionError):
104-
transform = dict(
105-
type='RandomFlip', flip_ratio=1, direction='horizonta')
104+
transform = dict(type='RandomFlip', prob=1, direction='horizonta')
106105
build_from_cfg(transform, PIPELINES)
107106

108-
transform = dict(type='RandomFlip', flip_ratio=1)
107+
transform = dict(type='RandomFlip', prob=1)
109108
flip_module = build_from_cfg(transform, PIPELINES)
110109

111110
results = dict()
@@ -197,6 +196,47 @@ def test_pad():
197196
assert img_shape[1] % 32 == 0
198197

199198

199+
def test_rotate():
    """Check RandomRotate argument validation, repr and output shapes."""
    # Invalid ``degree`` values must be rejected at build time: a scalar has
    # to be positive, and a tuple has to be exactly (min, max).
    for bad_degree in (-10, (10., 20., 30.)):
        with pytest.raises(AssertionError):
            build_from_cfg(
                dict(type='RandomRotate', prob=0.5, degree=bad_degree),
                PIPELINES)

    # prob=1. makes the rotation unconditional for the checks below
    rotate = build_from_cfg(
        dict(type='RandomRotate', degree=10., prob=1.), PIPELINES)

    expected_repr = (f'RandomRotate('
                     f'prob={1.}, '
                     f'degree=({-10.}, {10.}), '
                     f'pad_val={0}, '
                     f'seg_pad_val={255}, '
                     f'center={None}, '
                     f'auto_bound={False})')
    assert str(rotate) == expected_repr

    data_dir = osp.dirname(__file__)
    img = mmcv.imread(osp.join(data_dir, '../data/color.jpg'), 'color')
    h, w, _ = img.shape
    seg = np.array(Image.open(osp.join(data_dir, '../data/seg.png')))
    results = dict(
        img=img,
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'],
        img_shape=img.shape,
        ori_shape=img.shape,
        # Set initial values for default meta_keys
        pad_shape=img.shape,
        scale_factor=1.0)

    # with the default auto_bound=False the spatial size must be preserved
    rotated = rotate(results)
    assert rotated['img'].shape[:2] == (h, w)
    assert rotated['gt_semantic_seg'].shape[:2] == (h, w)
238+
239+
200240
def test_normalize():
201241
img_norm_cfg = dict(
202242
mean=[123.675, 116.28, 103.53],

0 commit comments

Comments
 (0)