@@ -497,3 +497,120 @@ def test_seg_rescale():
497497 rescale_module = build_from_cfg (transform , PIPELINES )
498498 rescale_results = rescale_module (results .copy ())
499499 assert rescale_results ['gt_semantic_seg' ].shape == (h , w )
500+
501+
502+ def test_cutout ():
503+ # test prob
504+ with pytest .raises (AssertionError ):
505+ transform = dict (type = 'RandomCutOut' , prob = 1.5 , n_holes = 1 )
506+ build_from_cfg (transform , PIPELINES )
507+ # test n_holes
508+ with pytest .raises (AssertionError ):
509+ transform = dict (
510+ type = 'RandomCutOut' , prob = 0.5 , n_holes = (5 , 3 ), cutout_shape = (8 , 8 ))
511+ build_from_cfg (transform , PIPELINES )
512+ with pytest .raises (AssertionError ):
513+ transform = dict (
514+ type = 'RandomCutOut' ,
515+ prob = 0.5 ,
516+ n_holes = (3 , 4 , 5 ),
517+ cutout_shape = (8 , 8 ))
518+ build_from_cfg (transform , PIPELINES )
519+ # test cutout_shape and cutout_ratio
520+ with pytest .raises (AssertionError ):
521+ transform = dict (
522+ type = 'RandomCutOut' , prob = 0.5 , n_holes = 1 , cutout_shape = 8 )
523+ build_from_cfg (transform , PIPELINES )
524+ with pytest .raises (AssertionError ):
525+ transform = dict (
526+ type = 'RandomCutOut' , prob = 0.5 , n_holes = 1 , cutout_ratio = 0.2 )
527+ build_from_cfg (transform , PIPELINES )
528+ # either of cutout_shape and cutout_ratio should be given
529+ with pytest .raises (AssertionError ):
530+ transform = dict (type = 'RandomCutOut' , prob = 0.5 , n_holes = 1 )
531+ build_from_cfg (transform , PIPELINES )
532+ with pytest .raises (AssertionError ):
533+ transform = dict (
534+ type = 'RandomCutOut' ,
535+ prob = 0.5 ,
536+ n_holes = 1 ,
537+ cutout_shape = (2 , 2 ),
538+ cutout_ratio = (0.4 , 0.4 ))
539+ build_from_cfg (transform , PIPELINES )
540+ # test seg_fill_in
541+ with pytest .raises (AssertionError ):
542+ transform = dict (
543+ type = 'RandomCutOut' ,
544+ prob = 0.5 ,
545+ n_holes = 1 ,
546+ cutout_shape = (8 , 8 ),
547+ seg_fill_in = 'a' )
548+ build_from_cfg (transform , PIPELINES )
549+ with pytest .raises (AssertionError ):
550+ transform = dict (
551+ type = 'RandomCutOut' ,
552+ prob = 0.5 ,
553+ n_holes = 1 ,
554+ cutout_shape = (8 , 8 ),
555+ seg_fill_in = 256 )
556+ build_from_cfg (transform , PIPELINES )
557+
558+ results = dict ()
559+ img = mmcv .imread (
560+ osp .join (osp .dirname (__file__ ), '../data/color.jpg' ), 'color' )
561+
562+ seg = np .array (
563+ Image .open (osp .join (osp .dirname (__file__ ), '../data/seg.png' )))
564+
565+ results ['img' ] = img
566+ results ['gt_semantic_seg' ] = seg
567+ results ['seg_fields' ] = ['gt_semantic_seg' ]
568+ results ['img_shape' ] = img .shape
569+ results ['ori_shape' ] = img .shape
570+ results ['pad_shape' ] = img .shape
571+ results ['img_fields' ] = ['img' ]
572+
573+ transform = dict (
574+ type = 'RandomCutOut' , prob = 1 , n_holes = 1 , cutout_shape = (10 , 10 ))
575+ cutout_module = build_from_cfg (transform , PIPELINES )
576+ assert 'cutout_shape' in repr (cutout_module )
577+ cutout_result = cutout_module (copy .deepcopy (results ))
578+ assert cutout_result ['img' ].sum () < img .sum ()
579+
580+ transform = dict (
581+ type = 'RandomCutOut' , prob = 1 , n_holes = 1 , cutout_ratio = (0.8 , 0.8 ))
582+ cutout_module = build_from_cfg (transform , PIPELINES )
583+ assert 'cutout_ratio' in repr (cutout_module )
584+ cutout_result = cutout_module (copy .deepcopy (results ))
585+ assert cutout_result ['img' ].sum () < img .sum ()
586+
587+ transform = dict (
588+ type = 'RandomCutOut' , prob = 0 , n_holes = 1 , cutout_ratio = (0.8 , 0.8 ))
589+ cutout_module = build_from_cfg (transform , PIPELINES )
590+ cutout_result = cutout_module (copy .deepcopy (results ))
591+ assert cutout_result ['img' ].sum () == img .sum ()
592+ assert cutout_result ['gt_semantic_seg' ].sum () == seg .sum ()
593+
594+ transform = dict (
595+ type = 'RandomCutOut' ,
596+ prob = 1 ,
597+ n_holes = (2 , 4 ),
598+ cutout_shape = [(10 , 10 ), (15 , 15 )],
599+ fill_in = (255 , 255 , 255 ),
600+ seg_fill_in = None )
601+ cutout_module = build_from_cfg (transform , PIPELINES )
602+ cutout_result = cutout_module (copy .deepcopy (results ))
603+ assert cutout_result ['img' ].sum () > img .sum ()
604+ assert cutout_result ['gt_semantic_seg' ].sum () == seg .sum ()
605+
606+ transform = dict (
607+ type = 'RandomCutOut' ,
608+ prob = 1 ,
609+ n_holes = 1 ,
610+ cutout_ratio = (0.8 , 0.8 ),
611+ fill_in = (255 , 255 , 255 ),
612+ seg_fill_in = 255 )
613+ cutout_module = build_from_cfg (transform , PIPELINES )
614+ cutout_result = cutout_module (copy .deepcopy (results ))
615+ assert cutout_result ['img' ].sum () > img .sum ()
616+ assert cutout_result ['gt_semantic_seg' ].sum () > seg .sum ()
0 commit comments