Skip to content

Commit 4dc809a

Browse files
authored
[Feature] add AdjustGamma transform (open-mmlab#232)
* add AdjustGamma transform * restore * change cv2 to mmcv * simplify AdjustGamma * fix syntax error * modify * fix syntax error * change mmcv version to 1.3.0 * fix lut function name error * fix syntax error * fix range
1 parent 993be25 commit 4dc809a

File tree

3 files changed

+73
-1
lines changed

3 files changed

+73
-1
lines changed

mmseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .version import __version__, version_info
44

55
MMCV_MIN = '1.1.4'
6-
MMCV_MAX = '1.2.0'
6+
MMCV_MAX = '1.3.0'
77

88

99
def digit_version(version_str):

mmseg/datasets/pipelines/transforms.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,42 @@ def __repr__(self):
650650
return repr_str
651651

652652

653+
@PIPELINES.register_module()
654+
class AdjustGamma(object):
655+
"""Using gamma correction to process the image.
656+
657+
Args:
658+
gamma (float or int): Gamma value used in gamma correction.
659+
Default: 1.0.
660+
"""
661+
662+
def __init__(self, gamma=1.0):
663+
assert isinstance(gamma, float) or isinstance(gamma, int)
664+
assert gamma > 0
665+
self.gamma = gamma
666+
inv_gamma = 1.0 / gamma
667+
self.table = np.array([(i / 255.0)**inv_gamma * 255
668+
for i in np.arange(256)]).astype('uint8')
669+
670+
def __call__(self, results):
671+
"""Call function to process the image with gamma correction.
672+
673+
Args:
674+
results (dict): Result dict from loading pipeline.
675+
676+
Returns:
677+
dict: Processed results.
678+
"""
679+
680+
results['img'] = mmcv.lut_transform(
681+
np.array(results['img'], dtype=np.uint8), self.table)
682+
683+
return results
684+
685+
def __repr__(self):
686+
return self.__class__.__name__ + f'(gamma={self.gamma})'
687+
688+
653689
@PIPELINES.register_module()
654690
class SegRescale(object):
655691
"""Rescale semantic segmentation maps.

tests/test_data/test_transform.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,42 @@ def test_rgb2gray():
330330
assert results['ori_shape'] == (h, w, c)
331331

332332

333+
def test_adjust_gamma():
334+
# test assertion if gamma <= 0
335+
with pytest.raises(AssertionError):
336+
transform = dict(type='AdjustGamma', gamma=0)
337+
build_from_cfg(transform, PIPELINES)
338+
339+
# test assertion if gamma is list
340+
with pytest.raises(AssertionError):
341+
transform = dict(type='AdjustGamma', gamma=[1.2])
342+
build_from_cfg(transform, PIPELINES)
343+
344+
# test with gamma = 1.2
345+
transform = dict(type='AdjustGamma', gamma=1.2)
346+
transform = build_from_cfg(transform, PIPELINES)
347+
results = dict()
348+
img = mmcv.imread(
349+
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
350+
original_img = copy.deepcopy(img)
351+
results['img'] = img
352+
results['img_shape'] = img.shape
353+
results['ori_shape'] = img.shape
354+
# Set initial values for default meta_keys
355+
results['pad_shape'] = img.shape
356+
results['scale_factor'] = 1.0
357+
358+
results = transform(results)
359+
360+
inv_gamma = 1.0 / 1.2
361+
table = np.array([((i / 255.0)**inv_gamma) * 255
362+
for i in np.arange(0, 256)]).astype('uint8')
363+
converted_img = mmcv.lut_transform(
364+
np.array(original_img, dtype=np.uint8), table)
365+
assert np.allclose(results['img'], converted_img)
366+
assert str(transform) == f'AdjustGamma(gamma={1.2})'
367+
368+
333369
def test_rerange():
334370
# test assertion if min_value or max_value is illegal
335371
with pytest.raises(AssertionError):

0 commit comments

Comments
 (0)