11import mmcv
22import numpy as np
3+ from mmcv .utils import deprecated_api_warning
34from numpy import random
45
56from ..builder import PIPELINES
@@ -232,16 +233,17 @@ class RandomFlip(object):
232233 method.
233234
234235 Args:
235- flip_ratio (float, optional): The flipping probability. Default: None.
236+ prob (float, optional): The flipping probability. Default: None.
236237 direction(str, optional): The flipping direction. Options are
237238 'horizontal' and 'vertical'. Default: 'horizontal'.
238239 """
239240
240- def __init__ (self , flip_ratio = None , direction = 'horizontal' ):
241- self .flip_ratio = flip_ratio
241+ @deprecated_api_warning ({'flip_ratio' : 'prob' }, cls_name = 'RandomFlip' )
242+ def __init__ (self , prob = None , direction = 'horizontal' ):
243+ self .prob = prob
242244 self .direction = direction
243- if flip_ratio is not None :
244- assert flip_ratio >= 0 and flip_ratio <= 1
245+ if prob is not None :
246+ assert prob >= 0 and prob <= 1
245247 assert direction in ['horizontal' , 'vertical' ]
246248
247249 def __call__ (self , results ):
@@ -257,7 +259,7 @@ def __call__(self, results):
257259 """
258260
259261 if 'flip' not in results :
260- flip = True if np .random .rand () < self .flip_ratio else False
262+ flip = True if np .random .rand () < self .prob else False
261263 results ['flip' ] = flip
262264 if 'flip_direction' not in results :
263265 results ['flip_direction' ] = self .direction
@@ -274,7 +276,7 @@ def __call__(self, results):
274276 return results
275277
276278 def __repr__ (self ):
277- return self .__class__ .__name__ + f'(flip_ratio ={ self .flip_ratio } )'
279+ return self .__class__ .__name__ + f'(prob ={ self .prob } )'
278280
279281
280282@PIPELINES .register_module ()
@@ -463,6 +465,89 @@ def __repr__(self):
463465 return self .__class__ .__name__ + f'(crop_size={ self .crop_size } )'
464466
465467
468+ @PIPELINES .register_module ()
469+ class RandomRotate (object ):
470+ """Rotate the image & seg.
471+
472+ Args:
473+ prob (float): The rotation probability.
474+ degree (float, tuple[float]): Range of degrees to select from. If
475+ degree is a number instead of tuple like (min, max),
476+ the range of degree will be (``-degree``, ``+degree``)
477+ pad_val (float, optional): Padding value of image. Default: 0.
478+ seg_pad_val (float, optional): Padding value of segmentation map.
479+ Default: 255.
480+ center (tuple[float], optional): Center point (w, h) of the rotation in
481+ the source image. If not specified, the center of the image will be
482+ used. Default: None.
483+ auto_bound (bool): Whether to adjust the image size to cover the whole
484+ rotated image. Default: False
485+ """
486+
487+ def __init__ (self ,
488+ prob ,
489+ degree ,
490+ pad_val = 0 ,
491+ seg_pad_val = 255 ,
492+ center = None ,
493+ auto_bound = False ):
494+ self .prob = prob
495+ assert prob >= 0 and prob <= 1
496+ if isinstance (degree , (float , int )):
497+ assert degree > 0 , f'degree { degree } should be positive'
498+ self .degree = (- degree , degree )
499+ else :
500+ self .degree = degree
501+ assert len (self .degree ) == 2 , f'degree { self .degree } should be a ' \
502+ f'tuple of (min, max)'
503+ self .pal_val = pad_val
504+ self .seg_pad_val = seg_pad_val
505+ self .center = center
506+ self .auto_bound = auto_bound
507+
508+ def __call__ (self , results ):
509+ """Call function to rotate image, semantic segmentation maps.
510+
511+ Args:
512+ results (dict): Result dict from loading pipeline.
513+
514+ Returns:
515+ dict: Rotated results.
516+ """
517+
518+ rotate = True if np .random .rand () < self .prob else False
519+ degree = np .random .uniform (min (* self .degree ), max (* self .degree ))
520+ if rotate :
521+ # rotate image
522+ results ['img' ] = mmcv .imrotate (
523+ results ['img' ],
524+ angle = degree ,
525+ border_value = self .pal_val ,
526+ center = self .center ,
527+ auto_bound = self .auto_bound )
528+
529+ # rotate segs
530+ for key in results .get ('seg_fields' , []):
531+ results [key ] = mmcv .imrotate (
532+ results [key ],
533+ angle = degree ,
534+ border_value = self .seg_pad_val ,
535+ center = self .center ,
536+ auto_bound = self .auto_bound ,
537+ interpolation = 'nearest' )
538+ return results
539+
540+ def __repr__ (self ):
541+ repr_str = self .__class__ .__name__
542+ repr_str += f'(prob={ self .prob } , ' \
543+ f'degree={ self .degree } , ' \
544+ f'pad_val={ self .pal_val } , ' \
545+ f'seg_pad_val={ self .seg_pad_val } , ' \
546+ f'center={ self .center } , ' \
547+ f'auto_bound={ self .auto_bound } )'
548+ return repr_str
549+
550+
466551@PIPELINES .register_module ()
467552class SegRescale (object ):
468553 """Rescale semantic segmentation maps.
0 commit comments