Skip to content

Commit 79e8578

Browse files
Yejin0111MeowZheng
andcommitted
[Feature] Add Biomedical 3D array random crop transform (#2378)
* [Feature] Add Biomedical 3D array random crop transform * fix lint * fix gen crop bbox * fix gen crop bbox * docstring * typo Co-authored-by: MeowZheng <[email protected]>
1 parent ad99ad1 commit 79e8578

File tree

4 files changed

+259
-16
lines changed

4 files changed

+259
-16
lines changed

mmseg/datasets/__init__.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
# yapf: disable
23
from .ade import ADE20KDataset
34
from .basesegdataset import BaseSegDataset
45
from .chase_db1 import ChaseDB1Dataset
@@ -17,7 +18,8 @@
1718
from .pascal_context import PascalContextDataset, PascalContextDataset59
1819
from .potsdam import PotsdamDataset
1920
from .stare import STAREDataset
20-
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
21+
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
22+
GenerateEdge, LoadAnnotations,
2123
LoadBiomedicalAnnotation, LoadBiomedicalData,
2224
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
2325
PackSegInputs, PhotoMetricDistortion, RandomCrop,
@@ -26,15 +28,18 @@
2628
SegRescale)
2729
from .voc import PascalVOCDataset
2830

31+
# yapf: enable
32+
2933
__all__ = [
30-
'BaseSegDataset', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
31-
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
32-
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
33-
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
34-
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
35-
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
36-
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
37-
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
34+
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
35+
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
36+
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
37+
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
38+
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
39+
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations',
40+
'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
41+
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
42+
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
3843
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
3944
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
4045
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'

mmseg/datasets/transforms/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
44
LoadBiomedicalData, LoadBiomedicalImageFromFile,
55
LoadImageFromNDArray)
6-
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
7-
PhotoMetricDistortion, RandomCrop, RandomCutOut,
8-
RandomMosaic, RandomRotate, Rerange,
6+
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
7+
GenerateEdge, PhotoMetricDistortion, RandomCrop,
8+
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
99
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
1010
SegRescale)
1111

1212
__all__ = [
13-
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
14-
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
15-
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
16-
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
13+
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
14+
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
15+
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
16+
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
1717
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
1818
'ResizeShortestEdge'
1919
]

mmseg/datasets/transforms/transforms.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import copy
3+
import warnings
34
from typing import Dict, Sequence, Tuple, Union
45

56
import cv2
@@ -1310,3 +1311,199 @@ def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
13101311
def transform(self, results: Dict) -> Dict:
13111312
self.resize.scale = self._get_output_shape(results['img'], self.scale)
13121313
return self.resize(results)
1314+
1315+
1316+
@TRANSFORMS.register_module()
class BioMedical3DRandomCrop(BaseTransform):
    """Randomly crop a 3D patch from a biomedical image and its mask.

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.
    - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask
        with shape (Z, Y, X).

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Args:
        crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
            cropping with the format of (z, y, x). If set to an integer,
            then the cropping depth, height and width are all equal to
            this integer.
        keep_foreground (bool): If True and ``gt_seg_map`` is available, a
            voxel of a randomly chosen foreground (non-zero) class is
            sampled, and the start offset of the crop along each axis is
            limited to the offset that would center the crop on that voxel
            (clamped to the volume). If False, or if there is no
            segmentation map or no foreground voxel, the crop is placed
            uniformly at random inside the volume. Default to True.
    """

    def __init__(self,
                 crop_shape: Union[int, Tuple[int, int, int]],
                 keep_foreground: bool = True):
        super().__init__()
        # The message is wrapped in parentheses so that both string pieces
        # belong to the assert; without them the second piece is a separate
        # no-op statement and the message is truncated.
        assert isinstance(crop_shape, int) or (
            isinstance(crop_shape, tuple) and len(crop_shape) == 3
        ), ('The expected crop_shape is an integer, or a tuple containing '
            'three integers')

        if isinstance(crop_shape, int):
            crop_shape = (crop_shape, crop_shape, crop_shape)
        assert all(size > 0 for size in crop_shape), \
            'Every entry of crop_shape must be positive'
        self.crop_shape = crop_shape
        self.keep_foreground = keep_foreground

    def random_sample_location(self,
                               seg_map: np.ndarray
                               ) -> Union[np.ndarray, None]:
        """Sample one foreground voxel from the segmentation map.

        Label 0 is treated as background; every other label found in
        ``seg_map`` is a foreground class. One class is drawn uniformly
        from the foreground classes, then one voxel of that class.

        Args:
            seg_map (np.ndarray): gt seg map.

        Returns:
            np.ndarray or None: Index of the selected foreground voxel (one
            entry per axis of ``seg_map``), or None if ``seg_map`` contains
            no foreground label.
        """
        # upper bound on the number of candidate voxels kept per class ...
        num_samples = 10000
        # ... unless that would keep less than 1% of the class voxels,
        # which may be too sparse
        min_percent_coverage = 0.01
        class_locs = {}
        foreground_classes = []
        for c in np.unique(seg_map):
            if c == 0:
                # background is never used as a crop anchor
                continue
            all_locs = np.argwhere(seg_map == c)
            target_num_samples = min(num_samples, len(all_locs))
            target_num_samples = max(
                target_num_samples,
                int(np.ceil(len(all_locs) * min_percent_coverage)))

            selected = all_locs[np.random.choice(
                len(all_locs), target_num_samples, replace=False)]
            class_locs[c] = selected
            foreground_classes.append(c)

        if len(foreground_classes) == 0:
            return None

        selected_class = np.random.choice(foreground_classes)
        voxels_of_that_class = class_locs[selected_class]
        return voxels_of_that_class[np.random.choice(
            len(voxels_of_that_class))]

    def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
                                  margin_x: int) -> tuple:
        """Randomly get a crop bounding box.

        The start offset along each axis is drawn uniformly from
        ``[0, margin]`` (both ends included).

        Args:
            margin_z (int): Largest allowed start offset along z.
            margin_y (int): Largest allowed start offset along y.
            margin_x (int): Largest allowed start offset along x.

        Returns:
            tuple: Coordinates of the cropped image, in the order
            (z1, z2, y1, y2, x1, x2).
        """
        # np.random.randint excludes the upper bound, hence the "+ 1"
        offset_z = np.random.randint(0, margin_z + 1)
        offset_y = np.random.randint(0, margin_y + 1)
        offset_x = np.random.randint(0, margin_x + 1)
        crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
        crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
        crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]

        return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2

    def generate_margin(self, results: dict) -> tuple:
        """Generate margin of crop bounding-box.

        Without a foreground voxel (``keep_foreground`` is False, there is
        no ``gt_seg_map``, or the map holds only background) the margin of
        each axis is the difference between the volume size and the crop
        size, floored at 0.
        With a foreground voxel, the margin of each axis is the start
        offset that would center the crop on that voxel, clamped to the
        range above.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            tuple: The margin for 3 dimensions of crop bounding-box and image.
        """
        seg_map = results.get('gt_seg_map', None)
        if seg_map is not None:
            volume_shape = seg_map.shape
        else:
            # no annotation available: take (Z, Y, X) from the image
            volume_shape = results['img'].shape[1:]

        # largest start offset that keeps the crop inside the volume;
        # 0 when the crop is larger than the volume along that axis
        max_offsets = tuple(
            max(volume_shape[axis] - self.crop_shape[axis], 0)
            for axis in range(3))

        selected_voxel = None
        if self.keep_foreground and seg_map is not None:
            selected_voxel = self.random_sample_location(seg_map)
            if selected_voxel is None:
                # this only happens if some image does not contain
                # foreground voxels at all
                warnings.warn(f'case does not contain any foreground classes'
                              f': {results.get("img_path")}')

        if selected_voxel is None:
            return max_offsets

        # NOTE(review): the value returned here is only the upper bound of
        # the random offset drawn in ``random_generate_crop_bbox``, so the
        # sampled voxel is not guaranteed to lie inside the final crop.
        # Confirm whether a crop centered on the voxel was intended.
        margins = []
        for axis in range(3):
            centered = max(
                0, selected_voxel[axis] - self.crop_shape[axis] // 2)
            margins.append(min(max_offsets[axis], centered))
        return tuple(margins)

    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Crop from ``img``

        Args:
            img (np.ndarray): Original input image. A 3-dim array is
                cropped on all of its axes; a 4-dim array keeps its first
                axis and is cropped on the remaining three.
            crop_bbox (tuple): Coordinates of the cropped image, in the
                order (z1, z2, y1, y2, x1, x2).

        Returns:
            np.ndarray: The cropped image.
        """
        crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        if img.ndim == 3:
            # crop seg map
            return img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
        # crop image
        assert img.ndim == 4
        return img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """
        margin = self.generate_margin(results)
        crop_bbox = self.random_generate_crop_bbox(*margin)

        # crop the image
        results['img'] = self.crop(results['img'], crop_bbox)
        results['img_shape'] = results['img'].shape[1:]

        # crop semantic seg with the same bbox, when one is present
        seg_map = results.get('gt_seg_map', None)
        if seg_map is not None:
            results['gt_seg_map'] = self.crop(seg_map, crop_bbox)

        return results

    def __repr__(self):
        return (self.__class__.__name__ +
                f'(crop_shape={self.crop_shape}, '
                f'keep_foreground={self.keep_foreground})')

tests/test_datasets/test_transform.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,3 +737,44 @@ def test_generate_edge():
737737
[1, 1, 0, 0, 0],
738738
[1, 0, 0, 0, 0],
739739
]))
740+
741+
742+
def test_biomedical3d_random_crop():
    """Check BioMedical3DRandomCrop argument validation and output shapes."""
    # test assertion for invalid random crop
    with pytest.raises(AssertionError):
        transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0))
        transform = TRANSFORMS.build(transform)

    from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
                                           LoadBiomedicalImageFromFile)
    results = dict()
    results['img_path'] = osp.join(
        osp.dirname(__file__), '../data', 'biomedical.nii.gz')
    transform = LoadBiomedicalImageFromFile()
    results = transform(copy.deepcopy(results))

    results['seg_map_path'] = osp.join(
        osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz')
    transform = LoadBiomedicalAnnotation()
    results = transform(copy.deepcopy(results))

    d, h, w = results['img_shape']
    crop_shape = (d - 20, h - 20, w - 20)

    for keep_foreground in (True, False):
        transform = TRANSFORMS.build(
            dict(
                type='BioMedical3DRandomCrop',
                crop_shape=crop_shape,
                keep_foreground=keep_foreground))
        # The transform overwrites entries of the dict it receives, so each
        # case gets its own copy of the full-size volume. Otherwise the
        # second case would run on the output of the first and pass
        # trivially.
        crop_results = transform(copy.deepcopy(results))
        assert crop_results['img'].shape[1:] == crop_shape
        assert crop_results['img_shape'] == crop_shape
        assert crop_results['gt_seg_map'].shape == crop_shape

0 commit comments

Comments
 (0)