@@ -1686,3 +1686,122 @@ def __repr__(self):
16861686 repr_str += 'different_sigma_per_axis=' \
16871687 f'{ self .different_sigma_per_axis } )'
16881688 return repr_str
1689+
1690+
1691+ @TRANSFORMS .register_module ()
1692+ class BioMedicalRandomGamma (BaseTransform ):
1693+ """Using random gamma correction to process the biomedical image.
1694+
1695+ Modified from
1696+ https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
1697+ With licence: Apache 2.0
1698+
1699+ Required Keys:
1700+
1701+ - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
1702+ N is the number of modalities, and data type is float32.
1703+
1704+ Modified Keys:
1705+ - img
1706+
1707+ Args:
1708+ prob (float): The probability to perform this transform. Default: 0.5.
1709+ gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
1710+ invert_image (bool): Whether invert the image before applying gamma
1711+ augmentation. Default: False.
1712+ per_channel (bool): Whether perform the transform each channel
1713+ individually. Default: False
1714+ retain_stats (bool): Gamma transformation will alter the mean and std
1715+ of the data in the patch. If retain_stats=True, the data will be
1716+ transformed to match the mean and standard deviation before gamma
1717+ augmentation. Default: False.
1718+ """
1719+
1720+ def __init__ (self ,
1721+ prob : float = 0.5 ,
1722+ gamma_range : Tuple [float ] = (0.5 , 2 ),
1723+ invert_image : bool = False ,
1724+ per_channel : bool = False ,
1725+ retain_stats : bool = False ):
1726+ assert 0 <= prob and prob <= 1
1727+ assert isinstance (gamma_range , tuple ) and len (gamma_range ) == 2
1728+ assert isinstance (invert_image , bool )
1729+ assert isinstance (per_channel , bool )
1730+ assert isinstance (retain_stats , bool )
1731+ self .prob = prob
1732+ self .gamma_range = gamma_range
1733+ self .invert_image = invert_image
1734+ self .per_channel = per_channel
1735+ self .retain_stats = retain_stats
1736+
1737+ @cache_randomness
1738+ def _do_gamma (self ):
1739+ """Whether do adjust gamma for image."""
1740+ return np .random .rand () < self .prob
1741+
1742+ def _adjust_gamma (self , img : np .array ):
1743+ """Gamma adjustment for image.
1744+
1745+ Args:
1746+ img (np.array): Input image before gamma adjust.
1747+
1748+ Returns:
1749+ np.arrays: Image after gamma adjust.
1750+ """
1751+
1752+ if self .invert_image :
1753+ img = - img
1754+
1755+ def _do_adjust (img ):
1756+ if retain_stats_here :
1757+ img_mean = img .mean ()
1758+ img_std = img .std ()
1759+ if np .random .random () < 0.5 and self .gamma_range [0 ] < 1 :
1760+ gamma = np .random .uniform (self .gamma_range [0 ], 1 )
1761+ else :
1762+ gamma = np .random .uniform (
1763+ max (self .gamma_range [0 ], 1 ), self .gamma_range [1 ])
1764+ img_min = img .min ()
1765+ img_range = img .max () - img_min # range
1766+ img = np .power (((img - img_min ) / float (img_range + 1e-7 )),
1767+ gamma ) * img_range + img_min
1768+ if retain_stats_here :
1769+ img = img - img .mean ()
1770+ img = img / (img .std () + 1e-8 ) * img_std
1771+ img = img + img_mean
1772+ return img
1773+
1774+ if not self .per_channel :
1775+ retain_stats_here = self .retain_stats
1776+ img = _do_adjust (img )
1777+ else :
1778+ for c in range (img .shape [0 ]):
1779+ img [c ] = _do_adjust (img [c ])
1780+ if self .invert_image :
1781+ img = - img
1782+ return img
1783+
1784+ def transform (self , results : dict ) -> dict :
1785+ """Call function to perform random gamma correction
1786+ Args:
1787+ results (dict): Result dict from loading pipeline.
1788+
1789+ Returns:
1790+ dict: Result dict with random gamma correction performed.
1791+ """
1792+ do_gamma = self ._do_gamma ()
1793+
1794+ if do_gamma :
1795+ results ['img' ] = self ._adjust_gamma (results ['img' ])
1796+ else :
1797+ pass
1798+ return results
1799+
1800+ def __repr__ (self ):
1801+ repr_str = self .__class__ .__name__
1802+ repr_str += f'(prob={ self .prob } , '
1803+ repr_str += f'gamma_range={ self .gamma_range } ,'
1804+ repr_str += f'invert_image={ self .invert_image } ,'
1805+ repr_str += f'per_channel={ self .per_channel } ,'
1806+ repr_str += f'retain_stats={ self .retain_stats } '
1807+ return repr_str
0 commit comments