Skip to content

Commit 78a6ff6

Browse files
authored
[Feature] Add Cutout transform (open-mmlab#1022)
* Fix typo in usage example * [Feature] Add CutOut transform * CutOut repr covered by unittests * Cutout ignore index, test * ignore_index -> seg_fill_in, default is None * seg_fill_in is added to repr * test is modified for seg_fill_in is None * seg_fill_in (int), 0-255 * add seg_fill_in test * doc string for seg_fill_in * rename CutOut to RandomCutOut, add prob * Add unittest when cutout is False
1 parent 08272b6 commit 78a6ff6

File tree

3 files changed

+213
-3
lines changed

3 files changed

+213
-3
lines changed

mmseg/datasets/pipelines/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from .loading import LoadAnnotations, LoadImageFromFile
66
from .test_time_aug import MultiScaleFlipAug
77
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
8-
PhotoMetricDistortion, RandomCrop, RandomFlip,
9-
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
8+
PhotoMetricDistortion, RandomCrop, RandomCutOut,
9+
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
10+
SegRescale)
1011

1112
__all__ = [
1213
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
1314
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
1415
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
1516
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
16-
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
17+
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut'
1718
]

mmseg/datasets/pipelines/transforms.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,3 +948,95 @@ def __repr__(self):
948948
f'{self.saturation_upper}), '
949949
f'hue_delta={self.hue_delta})')
950950
return repr_str
951+
952+
953+
@PIPELINES.register_module()
954+
class RandomCutOut(object):
    """CutOut operation.

    Randomly drop some regions of image used in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Args:
        prob (float): Cutout probability, in the closed interval [0, 1].
        n_holes (int | tuple[int, int] | list[int]): Number of regions to
            be dropped. If it is given as a tuple or list of two ints,
            the number of holes will be randomly selected from the closed
            interval [`n_holes[0]`, `n_holes[1]`], which requires
            `0 <= n_holes[0] < n_holes[1]`.
        cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
            shape of dropped regions, as (width, height). It can be
            `tuple[int, int]` to use a fixed cutout shape, or
            `list[tuple[int, int]]` to randomly choose shape from the list.
        cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
            candidate ratio of dropped regions, as (width ratio, height
            ratio). It can be `tuple[float, float]` to use a fixed ratio or
            `list[tuple[float, float]]` to randomly choose ratio from the
            list. Please note that `cutout_shape` and `cutout_ratio` cannot
            be both given at the same time.
        fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
            of pixel to fill in the dropped regions. It is broadcast over
            the dropped region, so a scalar can be used for a 2-D
            (single-channel) image. Default: (0, 0, 0).
        seg_fill_in (int): The labels of pixel to fill in the dropped regions,
            in the range [0, 255]. If seg_fill_in is None, the segmentation
            maps are left untouched. Default: None.
    """

    def __init__(self,
                 prob,
                 n_holes,
                 cutout_shape=None,
                 cutout_ratio=None,
                 fill_in=(0, 0, 0),
                 seg_fill_in=None):

        # Validation deliberately stays as asserts: existing callers and
        # tests expect AssertionError for invalid configurations.
        assert 0 <= prob and prob <= 1
        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'
        assert (isinstance(cutout_shape, (list, tuple))
                or isinstance(cutout_ratio, (list, tuple)))
        if isinstance(n_holes, (list, tuple)):
            # Accept a list as well as a tuple; a list used to slip through
            # to the scalar branch and break hole sampling at call time.
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
            n_holes = tuple(n_holes)
        else:
            n_holes = (n_holes, n_holes)
        if seg_fill_in is not None:
            assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in
                    and seg_fill_in <= 255)
        self.prob = prob
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.seg_fill_in = seg_fill_in
        self.with_ratio = cutout_ratio is not None
        # A single (w, h) tuple is wrapped so that candidates is always a
        # list that can be indexed uniformly.
        self.candidates = cutout_ratio if self.with_ratio else cutout_shape
        if not isinstance(self.candidates, list):
            self.candidates = [self.candidates]

    def __call__(self, results):
        """Call function to drop some regions of image.

        Args:
            results (dict): Result dict holding the image under ``'img'``
                and, optionally, a ``'seg_fields'`` list naming the
                segmentation maps stored in the same dict.

        Returns:
            dict: The same dict; ``'img'`` (and the segmentation maps when
                ``seg_fill_in`` is set) are modified in place.
        """
        if np.random.rand() >= self.prob:
            return results

        img = results['img']
        # Only the spatial size is needed, so 2-D images work too.
        h, w = img.shape[:2]
        # randint's upper bound is exclusive, hence the +1.
        n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        if self.seg_fill_in is not None:
            seg_keys = results.get('seg_fields', [])
        else:
            seg_keys = []

        for _ in range(n_holes):
            x1 = np.random.randint(0, w)
            y1 = np.random.randint(0, h)
            index = np.random.randint(0, len(self.candidates))
            if self.with_ratio:
                cutout_w = int(self.candidates[index][0] * w)
                cutout_h = int(self.candidates[index][1] * h)
            else:
                cutout_w, cutout_h = self.candidates[index]

            # Holes hanging over the right/bottom border are truncated.
            x2 = np.clip(x1 + cutout_w, 0, w)
            y2 = np.clip(y1 + cutout_h, 0, h)
            img[y1:y2, x1:x2] = self.fill_in

            for key in seg_keys:
                results[key][y1:y2, x1:x2] = self.seg_fill_in

        return results

    def __repr__(self):
        shape_key = 'cutout_ratio' if self.with_ratio else 'cutout_shape'
        return (f'{self.__class__.__name__}('
                f'prob={self.prob}, '
                f'n_holes={self.n_holes}, '
                f'{shape_key}={self.candidates}, '
                f'fill_in={self.fill_in}, '
                f'seg_fill_in={self.seg_fill_in})')

tests/test_data/test_transform.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,120 @@ def test_seg_rescale():
497497
rescale_module = build_from_cfg(transform, PIPELINES)
498498
rescale_results = rescale_module(results.copy())
499499
assert rescale_results['gt_semantic_seg'].shape == (h, w)
500+
501+
502+
def test_cutout():
    """Check RandomCutOut argument validation and its effect on a sample."""

    def build(**kwargs):
        # Go through the registry, the same way a config file would.
        return build_from_cfg(dict(type='RandomCutOut', **kwargs), PIPELINES)

    # Each entry must be rejected with an AssertionError at build time.
    rejected = (
        # test prob
        dict(prob=1.5, n_holes=1),
        # test n_holes
        dict(prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)),
        dict(prob=0.5, n_holes=(3, 4, 5), cutout_shape=(8, 8)),
        # test cutout_shape and cutout_ratio
        dict(prob=0.5, n_holes=1, cutout_shape=8),
        dict(prob=0.5, n_holes=1, cutout_ratio=0.2),
        # either of cutout_shape and cutout_ratio should be given
        dict(prob=0.5, n_holes=1),
        dict(
            prob=0.5,
            n_holes=1,
            cutout_shape=(2, 2),
            cutout_ratio=(0.4, 0.4)),
        # test seg_fill_in
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in='a'),
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in=256),
    )
    for kwargs in rejected:
        with pytest.raises(AssertionError):
            build(**kwargs)

    # Load the reference image and segmentation map shipped with the tests.
    here = osp.dirname(__file__)
    img = mmcv.imread(osp.join(here, '../data/color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(here, '../data/seg.png')))
    results = {
        'img': img,
        'gt_semantic_seg': seg,
        'seg_fields': ['gt_semantic_seg'],
        'img_shape': img.shape,
        'ori_shape': img.shape,
        'pad_shape': img.shape,
        'img_fields': ['img'],
    }
    img_total = img.sum()
    seg_total = seg.sum()

    # Fixed cutout shape; the default fill (0, 0, 0) darkens the image.
    module = build(prob=1, n_holes=1, cutout_shape=(10, 10))
    assert 'cutout_shape' in repr(module)
    output = module(copy.deepcopy(results))
    assert output['img'].sum() < img_total

    # Fixed cutout ratio.
    module = build(prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
    assert 'cutout_ratio' in repr(module)
    output = module(copy.deepcopy(results))
    assert output['img'].sum() < img_total

    # prob=0 never applies the cutout, so nothing changes.
    module = build(prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
    output = module(copy.deepcopy(results))
    assert output['img'].sum() == img_total
    assert output['gt_semantic_seg'].sum() == seg_total

    # Several holes with candidate shapes, filled with white; with
    # seg_fill_in=None the segmentation map must stay untouched.
    module = build(
        prob=1,
        n_holes=(2, 4),
        cutout_shape=[(10, 10), (15, 15)],
        fill_in=(255, 255, 255),
        seg_fill_in=None)
    output = module(copy.deepcopy(results))
    assert output['img'].sum() > img_total
    assert output['gt_semantic_seg'].sum() == seg_total

    # seg_fill_in=255 also overwrites the labels inside the hole.
    module = build(
        prob=1,
        n_holes=1,
        cutout_ratio=(0.8, 0.8),
        fill_in=(255, 255, 255),
        seg_fill_in=255)
    output = module(copy.deepcopy(results))
    assert output['img'].sum() > img_total
    assert output['gt_semantic_seg'].sum() > seg_total

0 commit comments

Comments
 (0)