Skip to content

Commit 0066ce8

Browse files
authored
[Feature]add CLAHE transform (open-mmlab#229)
* add CLAHE transform * fix syntax error * fix syntax error * restore * add a test * modify cv2 to mmcv * add docstring * modify * restore * fix mmcv.clahe error * change mmcv version to 1.3.0 * fix bugs * add all data transformers to __init__ * fix __init__ * fix test_transform
1 parent 4dc809a commit 0066ce8

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

mmseg/datasets/pipelines/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
Transpose, to_tensor)
44
from .loading import LoadAnnotations, LoadImageFromFile
55
from .test_time_aug import MultiScaleFlipAug
6-
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
7-
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
8-
SegRescale)
6+
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
7+
PhotoMetricDistortion, RandomCrop, RandomFlip,
8+
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
99

1010
__all__ = [
1111
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
1212
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
1313
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
1414
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
15-
'Rerange', 'RGB2Gray'
15+
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
1616
]

mmseg/datasets/pipelines/transforms.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import mmcv
22
import numpy as np
3-
from mmcv.utils import deprecated_api_warning
3+
from mmcv.utils import deprecated_api_warning, is_tuple_of
44
from numpy import random
55

66
from ..builder import PIPELINES
@@ -415,7 +415,6 @@ def __call__(self, results):
415415
416416
Args:
417417
results (dict): Result dict from loading pipeline.
418-
419418
Returns:
420419
dict: Reranged results.
421420
"""
@@ -439,6 +438,51 @@ def __repr__(self):
439438
return repr_str
440439

441440

441+
@PIPELINES.register_module()
442+
class CLAHE(object):
443+
"""Use CLAHE method to process the image.
444+
445+
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
446+
Graphics Gems, 1994:474-485.` for more information.
447+
448+
Args:
449+
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
450+
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
451+
Input image will be divided into equally sized rectangular tiles.
452+
It defines the number of tiles in row and column. Default: (8, 8).
453+
"""
454+
455+
def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
456+
assert isinstance(clip_limit, (float, int))
457+
self.clip_limit = clip_limit
458+
assert is_tuple_of(tile_grid_size, int)
459+
assert len(tile_grid_size) == 2
460+
self.tile_grid_size = tile_grid_size
461+
462+
def __call__(self, results):
463+
"""Call function to Use CLAHE method process images.
464+
465+
Args:
466+
results (dict): Result dict from loading pipeline.
467+
468+
Returns:
469+
dict: Processed results.
470+
"""
471+
472+
for i in range(results['img'].shape[2]):
473+
results['img'][:, :, i] = mmcv.clahe(
474+
np.array(results['img'][:, :, i], dtype=np.uint8),
475+
self.clip_limit, self.tile_grid_size)
476+
477+
return results
478+
479+
def __repr__(self):
480+
repr_str = self.__class__.__name__
481+
repr_str += f'(clip_limit={self.clip_limit}, '\
482+
f'tile_grid_size={self.tile_grid_size})'
483+
return repr_str
484+
485+
442486
@PIPELINES.register_module()
443487
class RandomCrop(object):
444488
"""Random crop the image & seg.

tests/test_data/test_transform.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,46 @@ def test_rerange():
409409
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
410410

411411

412+
def test_CLAHE():
413+
# test assertion if clip_limit is None
414+
with pytest.raises(AssertionError):
415+
transform = dict(type='CLAHE', clip_limit=None)
416+
build_from_cfg(transform, PIPELINES)
417+
418+
# test assertion if tile_grid_size is illegal
419+
with pytest.raises(AssertionError):
420+
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
421+
build_from_cfg(transform, PIPELINES)
422+
423+
# test assertion if tile_grid_size is illegal
424+
with pytest.raises(AssertionError):
425+
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
426+
build_from_cfg(transform, PIPELINES)
427+
428+
transform = dict(type='CLAHE', clip_limit=2)
429+
transform = build_from_cfg(transform, PIPELINES)
430+
results = dict()
431+
img = mmcv.imread(
432+
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
433+
original_img = copy.deepcopy(img)
434+
results['img'] = img
435+
results['img_shape'] = img.shape
436+
results['ori_shape'] = img.shape
437+
# Set initial values for default meta_keys
438+
results['pad_shape'] = img.shape
439+
results['scale_factor'] = 1.0
440+
441+
results = transform(results)
442+
443+
converted_img = np.empty(original_img.shape)
444+
for i in range(original_img.shape[2]):
445+
converted_img[:, :, i] = mmcv.clahe(
446+
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
447+
448+
assert np.allclose(results['img'], converted_img)
449+
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
450+
451+
412452
def test_seg_rescale():
413453
results = dict()
414454
seg = np.array(

0 commit comments

Comments
 (0)