Skip to content

Commit 3ca690b

Browse files
[Feature] Add BioMedicalRandomGamma (open-mmlab#2406)
Add the random gamma correction transform for biomedical images, which follows the design of the nnUNet.
1 parent 26f3df7 commit 3ca690b

File tree

5 files changed

+195
-8
lines changed

5 files changed

+195
-8
lines changed

docs/zh_cn/advanced_guides/datasets.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 数据集
1+
# 数据集
22

33
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
44

mmseg/datasets/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from .pascal_context import PascalContextDataset, PascalContextDataset59
1919
from .potsdam import PotsdamDataset
2020
from .stare import STAREDataset
21+
# yapf: disable
2122
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
2223
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
23-
GenerateEdge, LoadAnnotations,
24+
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
2425
LoadBiomedicalAnnotation, LoadBiomedicalData,
2526
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
2627
PackSegInputs, PhotoMetricDistortion, RandomCrop,
@@ -30,7 +31,6 @@
3031
from .voc import PascalVOCDataset
3132

3233
# yapf: enable
33-
3434
__all__ = [
3535
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
3636
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
@@ -44,5 +44,6 @@
4444
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
4545
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
4646
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
47-
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
47+
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
48+
'BioMedicalRandomGamma'
4849
]

mmseg/datasets/transforms/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
# yapf: disable
77
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
88
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
9-
GenerateEdge, PhotoMetricDistortion, RandomCrop,
10-
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
9+
BioMedicalRandomGamma, GenerateEdge,
10+
PhotoMetricDistortion, RandomCrop, RandomCutOut,
11+
RandomMosaic, RandomRotate, Rerange,
1112
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
1213
SegRescale)
1314

@@ -18,5 +19,6 @@
1819
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
1920
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
2021
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
21-
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
22+
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
23+
'BioMedicalRandomGamma'
2224
]

mmseg/datasets/transforms/transforms.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,3 +1686,122 @@ def __repr__(self):
16861686
repr_str += 'different_sigma_per_axis='\
16871687
f'{self.different_sigma_per_axis})'
16881688
return repr_str
1689+
1690+
1691+
@TRANSFORMS.register_module()
1692+
class BioMedicalRandomGamma(BaseTransform):
1693+
"""Using random gamma correction to process the biomedical image.
1694+
1695+
Modified from
1696+
https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
1697+
With licence: Apache 2.0
1698+
1699+
Required Keys:
1700+
1701+
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
1702+
N is the number of modalities, and data type is float32.
1703+
1704+
Modified Keys:
1705+
- img
1706+
1707+
Args:
1708+
prob (float): The probability to perform this transform. Default: 0.5.
1709+
gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
1710+
invert_image (bool): Whether invert the image before applying gamma
1711+
augmentation. Default: False.
1712+
per_channel (bool): Whether perform the transform each channel
1713+
individually. Default: False
1714+
retain_stats (bool): Gamma transformation will alter the mean and std
1715+
of the data in the patch. If retain_stats=True, the data will be
1716+
transformed to match the mean and standard deviation before gamma
1717+
augmentation. Default: False.
1718+
"""
1719+
1720+
def __init__(self,
1721+
prob: float = 0.5,
1722+
gamma_range: Tuple[float] = (0.5, 2),
1723+
invert_image: bool = False,
1724+
per_channel: bool = False,
1725+
retain_stats: bool = False):
1726+
assert 0 <= prob and prob <= 1
1727+
assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
1728+
assert isinstance(invert_image, bool)
1729+
assert isinstance(per_channel, bool)
1730+
assert isinstance(retain_stats, bool)
1731+
self.prob = prob
1732+
self.gamma_range = gamma_range
1733+
self.invert_image = invert_image
1734+
self.per_channel = per_channel
1735+
self.retain_stats = retain_stats
1736+
1737+
@cache_randomness
1738+
def _do_gamma(self):
1739+
"""Whether do adjust gamma for image."""
1740+
return np.random.rand() < self.prob
1741+
1742+
def _adjust_gamma(self, img: np.array):
1743+
"""Gamma adjustment for image.
1744+
1745+
Args:
1746+
img (np.array): Input image before gamma adjust.
1747+
1748+
Returns:
1749+
np.arrays: Image after gamma adjust.
1750+
"""
1751+
1752+
if self.invert_image:
1753+
img = -img
1754+
1755+
def _do_adjust(img):
1756+
if retain_stats_here:
1757+
img_mean = img.mean()
1758+
img_std = img.std()
1759+
if np.random.random() < 0.5 and self.gamma_range[0] < 1:
1760+
gamma = np.random.uniform(self.gamma_range[0], 1)
1761+
else:
1762+
gamma = np.random.uniform(
1763+
max(self.gamma_range[0], 1), self.gamma_range[1])
1764+
img_min = img.min()
1765+
img_range = img.max() - img_min # range
1766+
img = np.power(((img - img_min) / float(img_range + 1e-7)),
1767+
gamma) * img_range + img_min
1768+
if retain_stats_here:
1769+
img = img - img.mean()
1770+
img = img / (img.std() + 1e-8) * img_std
1771+
img = img + img_mean
1772+
return img
1773+
1774+
if not self.per_channel:
1775+
retain_stats_here = self.retain_stats
1776+
img = _do_adjust(img)
1777+
else:
1778+
for c in range(img.shape[0]):
1779+
img[c] = _do_adjust(img[c])
1780+
if self.invert_image:
1781+
img = -img
1782+
return img
1783+
1784+
def transform(self, results: dict) -> dict:
1785+
"""Call function to perform random gamma correction
1786+
Args:
1787+
results (dict): Result dict from loading pipeline.
1788+
1789+
Returns:
1790+
dict: Result dict with random gamma correction performed.
1791+
"""
1792+
do_gamma = self._do_gamma()
1793+
1794+
if do_gamma:
1795+
results['img'] = self._adjust_gamma(results['img'])
1796+
else:
1797+
pass
1798+
return results
1799+
1800+
def __repr__(self):
1801+
repr_str = self.__class__.__name__
1802+
repr_str += f'(prob={self.prob}, '
1803+
repr_str += f'gamma_range={self.gamma_range},'
1804+
repr_str += f'invert_image={self.invert_image},'
1805+
repr_str += f'per_channel={self.per_channel},'
1806+
repr_str += f'retain_stats={self.retain_stats}'
1807+
return repr_str

tests/test_datasets/test_transform.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from PIL import Image
99

1010
from mmseg.datasets.transforms import * # noqa
11-
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
11+
from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
12+
PhotoMetricDistortion, RandomCrop)
1213
from mmseg.registry import TRANSFORMS
1314
from mmseg.utils import register_all_modules
1415

@@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur():
886887
# the max value in the smoothed image should be less than the original one
887888
assert original_img.max() >= results['img'].max()
888889
assert original_img.min() <= results['img'].min()
890+
891+
892+
def test_BioMedicalRandomGamma():
893+
894+
with pytest.raises(AssertionError):
895+
transform = dict(
896+
type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2))
897+
TRANSFORMS.build(transform)
898+
899+
with pytest.raises(AssertionError):
900+
transform = dict(
901+
type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2))
902+
TRANSFORMS.build(transform)
903+
904+
with pytest.raises(AssertionError):
905+
transform = dict(
906+
type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7))
907+
TRANSFORMS.build(transform)
908+
909+
with pytest.raises(AssertionError):
910+
transform = dict(
911+
type='BioMedicalRandomGamma',
912+
prob=1.0,
913+
gamma_range=(0.7, 0.2, 0.3))
914+
TRANSFORMS.build(transform)
915+
916+
with pytest.raises(AssertionError):
917+
transform = dict(
918+
type='BioMedicalRandomGamma',
919+
prob=1.0,
920+
gamma_range=(0.7, 2),
921+
invert_image=1)
922+
TRANSFORMS.build(transform)
923+
924+
with pytest.raises(AssertionError):
925+
transform = dict(
926+
type='BioMedicalRandomGamma',
927+
prob=1.0,
928+
gamma_range=(0.7, 2),
929+
per_channel=1)
930+
TRANSFORMS.build(transform)
931+
932+
with pytest.raises(AssertionError):
933+
transform = dict(
934+
type='BioMedicalRandomGamma',
935+
prob=1.0,
936+
gamma_range=(0.7, 2),
937+
retain_stats=1)
938+
TRANSFORMS.build(transform)
939+
940+
test_img = 'tests/data/biomedical.nii.gz'
941+
results = dict(img_path=test_img)
942+
transform = LoadBiomedicalImageFromFile()
943+
results = transform(copy.deepcopy(results))
944+
origin_img = results['img']
945+
transform2 = dict(
946+
type='BioMedicalRandomGamma',
947+
prob=1.0,
948+
gamma_range=(0.7, 2),
949+
)
950+
transform2 = TRANSFORMS.build(transform2)
951+
results = transform2(results)
952+
transformed_img = results['img']
953+
assert origin_img.shape == transformed_img.shape

0 commit comments

Comments
 (0)