|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
2 | 2 | import copy |
| 3 | +import warnings |
3 | 4 | from typing import Dict, Sequence, Tuple, Union |
4 | 5 |
|
5 | 6 | import cv2 |
@@ -1310,3 +1311,199 @@ def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: |
def transform(self, results: Dict) -> Dict:
    """Resize ``results`` using a scale derived from the input image.

    Args:
        results (Dict): Result dict; its ``'img'`` entry is passed to
            ``self._get_output_shape`` together with ``self.scale``.

    Returns:
        Dict: Whatever ``self.resize`` returns for ``results``.
    """
    # The wrapped resize transform reads its target size from ``scale``,
    # so it has to be refreshed for every sample before delegating.
    target_scale = self._get_output_shape(results['img'], self.scale)
    self.resize.scale = target_scale
    resized = self.resize(results)
    return resized
| 1314 | + |
| 1315 | + |
@TRANSFORMS.register_module()
class BioMedical3DRandomCrop(BaseTransform):
    """Randomly crop a 3D patch from a medical image & segmentation mask.

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.
    - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation
        mask with shape (Z, Y, X).

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Args:
        crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
            cropping with the format of (z, y, x). If set to an integer,
            that size is used for all three axes.
        keep_foreground (bool): If keep_foreground is True, a voxel of a
            foreground class is sampled randomly and its position is used
            to bound the range from which the crop offset is drawn (see
            ``generate_margin``). Has no effect when ``gt_seg_map`` is
            absent. Default to True.
    """

    def __init__(self,
                 crop_shape: Union[int, Tuple[int, int, int]],
                 keep_foreground: bool = True):
        super().__init__()
        # The message is parenthesised so that both string fragments belong
        # to the assert; a bare continuation line would be a separate no-op
        # expression statement and the message would be cut short.
        assert isinstance(crop_shape, int) or (
            isinstance(crop_shape, tuple) and len(crop_shape) == 3
        ), ('The expected crop_shape is an integer, or a tuple containing '
            'three integers')

        if isinstance(crop_shape, int):
            crop_shape = (crop_shape, crop_shape, crop_shape)
        assert all(size > 0 for size in crop_shape), (
            'Every entry of crop_shape must be positive')
        self.crop_shape = crop_shape
        self.keep_foreground = keep_foreground

    def random_sample_location(
            self, seg_map: np.ndarray) -> Union[np.ndarray, None]:
        """Sample a foreground voxel when keep_foreground is True.

        Args:
            seg_map (np.ndarray): gt seg map. Label 0 is treated as
                background.

        Returns:
            Union[np.ndarray, None]: Coordinates of the selected foreground
            voxel (one index per axis of ``seg_map``), or None if
            ``seg_map`` contains no label other than 0.
        """
        num_samples = 10000
        # at least 1% of the class voxels need to be selected,
        # otherwise it may be too sparse
        min_percent_coverage = 0.01
        class_locs = {}
        foreground_classes = []
        for c in np.unique(seg_map):
            if c == 0:
                # background is never a candidate
                continue
            all_locs = np.argwhere(seg_map == c)
            target_num_samples = min(num_samples, len(all_locs))
            target_num_samples = max(
                target_num_samples,
                int(np.ceil(len(all_locs) * min_percent_coverage)))

            class_locs[c] = all_locs[np.random.choice(
                len(all_locs), target_num_samples, replace=False)]
            foreground_classes.append(c)

        if not foreground_classes:
            return None

        # Pick a class first so that small structures are as likely to be
        # chosen as large ones, then pick one of its sampled voxels.
        selected_class = np.random.choice(foreground_classes)
        voxels_of_that_class = class_locs[selected_class]
        return voxels_of_that_class[np.random.choice(
            len(voxels_of_that_class))]

    def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
                                  margin_x: int) -> tuple:
        """Randomly get a crop bounding box.

        Args:
            margin_z (int): Largest allowed start offset along z.
            margin_y (int): Largest allowed start offset along y.
            margin_x (int): Largest allowed start offset along x.

        Returns:
            tuple: Coordinates of the cropped image in the order
            (z1, z2, y1, y2, x1, x2); each end is ``start + crop size``.
        """
        # ``randint`` excludes its upper bound, hence the ``+ 1``.
        offset_z = np.random.randint(0, margin_z + 1)
        offset_y = np.random.randint(0, margin_y + 1)
        offset_x = np.random.randint(0, margin_x + 1)
        crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
        crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
        crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]

        return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2

    def generate_margin(self, results: dict) -> tuple:
        """Generate margin of crop bounding-box.

        If keep_foreground is True and a foreground voxel can be sampled
        from ``gt_seg_map``, the margin along each axis is
        ``voxel - crop_shape // 2`` clamped to
        ``[0, max(volume size - crop size, 0)]``.
        Otherwise (keep_foreground is False, the mask holds no foreground,
        or there is no ``gt_seg_map``) the margin is the difference between
        the volume shape and the crop shape, floored at 0.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            tuple: The margin for 3 dimensions of crop bounding-box and image.
        """
        seg_map = results.get('gt_seg_map')
        # gt_seg_map is optional: without it the spatial extent is read
        # from the image, whose leading axis is the modality axis.
        if seg_map is None:
            spatial_shape = results['img'].shape[1:]
        else:
            spatial_shape = seg_map.shape

        max_margin = tuple(
            max(size - crop, 0)
            for size, crop in zip(spatial_shape, self.crop_shape))

        selected_voxel = None
        if self.keep_foreground and seg_map is not None:
            selected_voxel = self.random_sample_location(seg_map)
            if selected_voxel is None:
                # this only happens if some image does not contain
                # foreground voxels at all
                warnings.warn(f'case does not contain any foreground classes'
                              f': {results.get("img_path")}')

        if selected_voxel is None:
            return max_margin

        # NOTE(review): these values are later used as the *upper bound* of
        # a random offset in ``random_generate_crop_bbox``, so the sampled
        # voxel is not guaranteed to lie inside the final crop - confirm
        # whether centring on the voxel was intended.
        return tuple(
            int(min(limit, max(0, voxel - crop // 2)))
            for limit, voxel, crop in zip(max_margin, selected_voxel,
                                          self.crop_shape))

    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Crop from ``img``

        Args:
            img (np.ndarray): Original input image, either a 3-D seg map
                or a 4-D image whose first axis is left untouched.
            crop_bbox (tuple): Coordinates of the cropped image in the
                order (z1, z2, y1, y2, x1, x2).

        Returns:
            np.ndarray: The cropped image.
        """
        crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        assert img.ndim in (3, 4)
        # The ellipsis keeps a leading modality axis (if any) intact and
        # slices the trailing three spatial axes.
        return img[..., crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """
        margin = self.generate_margin(results)
        crop_bbox = self.random_generate_crop_bbox(*margin)

        # crop the image
        results['img'] = self.crop(results['img'], crop_bbox)
        results['img_shape'] = results['img'].shape[1:]

        # crop semantic seg with the same box, if there is one
        seg_map = results.get('gt_seg_map')
        if seg_map is not None:
            results['gt_seg_map'] = self.crop(seg_map, crop_bbox)

        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
0 commit comments