
Commit c923f4d

[Project] Medical semantic seg dataset: Crass (open-mmlab#2690)
1 parent c1de52a commit c923f4d

8 files changed: +369 −0 lines
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
# Chest Radiograph Anatomical Structure Segmentation (CRASS)

## Description

This project supports **`Chest Radiograph Anatomical Structure Segmentation (CRASS)`**, which can be downloaded from [here](https://crass.grand-challenge.org/).

### Dataset Overview

A set of consecutively obtained posterior-anterior chest radiographs was selected from a database containing images acquired at two sites in sub-Saharan Africa with a high tuberculosis incidence. All subjects were 15 years or older. The images were acquired with digital chest radiography units (Delft Imaging Systems, The Netherlands) at varying resolutions, typically 1800–2000 pixels, with an isotropic pixel size of 250 µm. From the total set of images, 225 were considered normal by an expert radiologist, while 333 contained abnormalities. Of the abnormal images, 220 contained abnormalities in the upper area of the lung, where the clavicle is located. The data was divided into a training set of 299 images and a test set of 249 images.

The current data is still incomplete and will be added later.
### Information Statistics

| Dataset Name                                 | Anatomical Region | Task Type    | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License                                                       |
| -------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------- |
| [crass](https://crass.grand-challenge.org/)  | pulmonary         | segmentation | x_ray    | 2            | 299/-/234             | yes/-/no               | 2021         | [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/) |

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| background |    299     |   98.38    |    -     |    -     |     -     |     -     |
| clavicles  |    299     |    1.62    |    -     |    -     |     -     |     -     |

Note:

- `Pct` means the percentage of pixels belonging to this category among all pixels.
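The pixel percentages above can be reproduced from the converted masks with a small script along these lines (an illustrative sketch, assuming the `data/masks/train/` layout produced by `tools/prepare_dataset.py`; the helper name is ours):

```python
import glob

import numpy as np
from PIL import Image


def class_pixel_percentages(mask_dir, num_classes=2):
    """Count pixels per class over all PNG masks in ``mask_dir``."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for path in glob.glob(f'{mask_dir}/*.png'):
        mask = np.asarray(Image.open(path))
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return counts / counts.sum() * 100


print(class_pixel_percentages('data/masks/train/'))  # roughly [98.38, 1.62]
```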
### Visualization

![crass](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/crass/crass_dataset.png?raw=true)

### Dataset Citation

```
@article{HOGEWEG20121490,
  title={Clavicle segmentation in chest radiographs},
  journal={Medical Image Analysis},
  volume={16},
  number={8},
  pages={1490-1502},
  year={2012}
}
```
### Prerequisites

- Python v3.8
- PyTorch v1.10.0
- pillow (PIL) v9.3.0
- scikit-learn (sklearn) v1.2.0
- [MIM](https://github.com/open-mmlab/mim) v0.3.4
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5

All the commands below rely on a correctly configured `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `crass/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
### Dataset Preparing

- Download the dataset from [here](https://crass.grand-challenge.org/) and decompress it to the path `data/`.
- Run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as shown below.
- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. Since the labels of the official validation and test sets cannot be obtained, `train.txt` and `val.txt` are generated randomly from the training set (see the sketch after the directory tree below).
```none
mmsegmentation
├── mmseg
├── projects
│ ├── medical
│ │ ├── 2d_image
│ │ │ ├── x_ray
│ │ │ │ ├── crass
│ │ │ │ │ ├── configs
│ │ │ │ │ ├── datasets
│ │ │ │ │ ├── tools
│ │ │ │ │ ├── data
│ │ │ │ │ │ ├── train.txt
│ │ │ │ │ │ ├── val.txt
│ │ │ │ │ │ ├── images
│ │ │ │ │ │ │ ├── train
│ │ │ │ │ │ │ │ ├── xxx.png
│ │ │ │ │ │ │ │ ├── ...
│ │ │ │ │ │ │ │ └── xxx.png
│ │ │ │ │ │ ├── masks
│ │ │ │ │ │ │ ├── train
│ │ │ │ │ │ │ │ ├── xxx.png
│ │ │ │ │ │ │ │ ├── ...
│ │ │ │ │ │ │ │ └── xxx.png
```
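The script `../../tools/split_seg_dataset.py` itself is not part of this commit; a rough sketch of the random train/val split it performs might look like the following (hypothetical helper, assuming an 80/20 split of the converted training images; the exact list format written by the official script may differ):

```python
import glob
import os
import random


def split_train_val(data_root='data/', val_ratio=0.2, seed=42):
    """Randomly split the converted training images into train.txt / val.txt."""
    names = sorted(
        os.path.splitext(os.path.basename(p))[0]
        for p in glob.glob(os.path.join(data_root, 'images/train/*.png')))
    random.Random(seed).shuffle(names)
    num_val = int(len(names) * val_ratio)
    splits = {'val.txt': names[:num_val], 'train.txt': names[num_val:]}
    for txt, subset in splits.items():
        with open(os.path.join(data_root, txt), 'w') as f:
            f.write('\n'.join(sorted(subset)) + '\n')


split_train_val()
```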
### Divided Dataset Information

***Note: The table below describes the train/val split that we produced ourselves.***

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| background |    227     |   98.38    |    57    |  98.39   |     -     |     -     |
| clavicles  |    227     |    1.62    |    57    |   1.61   |     -     |     -     |
### Training commands
103+
104+
To train models on a single server with one GPU. (default)
105+
106+
```shell
107+
mim train mmseg ./configs/${CONFIG_FILE}
108+
```
109+
110+
### Testing commands
111+
112+
To test models on a single server with one GPU. (default)
113+
114+
```shell
115+
mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
116+
```
117+
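As an alternative to `mim`, the same configs can also be launched with MMEngine's `Runner` API (a sketch; `CONFIG_FILE.py` is a placeholder for one of the configs in `configs/`, and `PYTHONPATH` must be set as described above so that `custom_imports` can resolve `datasets.crass_dataset`):

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load one of the project configs and launch training.
cfg = Config.fromfile('./configs/CONFIG_FILE.py')  # placeholder config name
cfg.work_dir = './work_dirs/crass'                 # Runner requires a work_dir
runner = Runner.from_cfg(cfg)
runner.train()
# For evaluation, e.g. with cfg.load_from pointing to a checkpoint:
# runner.test()
```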
<!-- List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models)

You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project. -->
## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

  - [x] Finish the code
  - [x] Basic docstrings & proper citation
  - [ ] Test-time correctness
  - [x] A full README

- [x] Milestone 2: Indicates a successful model implementation.

  - [x] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

  - [ ] Type hints and docstrings
  - [ ] Unit tests
  - [ ] Code polishing
  - [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
dataset_type = 'CRASSDataset'
data_root = 'data/'
img_scale = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', scale=img_scale, keep_ratio=False),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=False),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
train_dataloader = dict(
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='train.txt',
        data_prefix=dict(img_path='images/', seg_map_path='masks/'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='val.txt',
        data_prefix=dict(img_path='images/', seg_map_path='masks/'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
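A quick sanity check of this data config could look like the following sketch (assuming the file is saved as `configs/crass_512x512.py` under the `crass/` project root, as in the directory tree above):

```python
from mmengine.config import Config

cfg = Config.fromfile('configs/crass_512x512.py')
# Confirm the annotation files and prefixes resolve as expected.
print(cfg.train_dataloader.dataset.ann_file)  # train.txt
print(cfg.val_dataloader.dataset.ann_file)    # val.txt
print(cfg.test_pipeline)
```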
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
_base_ = [
    'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.crass_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.0001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
_base_ = [
    'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.crass_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
_base_ = [
    'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.crass_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.01)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
_base_ = [
    'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.crass_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.01)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(
        num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class CRASSDataset(BaseSegDataset):
    """CRASS dataset.

    In segmentation map annotation for CRASS, 0 stands for background, which
    is included in the 2 categories. ``reduce_zero_label`` is fixed to False.
    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.png'.

    Args:
        img_suffix (str): Suffix of images. Default: '.png'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
    """
    METAINFO = dict(classes=('background', 'clavicles'))

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
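A minimal usage sketch for the registered dataset, assuming the `data/` layout described in the README and `PYTHONPATH` set as above (the empty pipeline only indexes the samples):

```python
from mmseg.registry import DATASETS

import datasets.crass_dataset  # noqa: F401, registers CRASSDataset

crass_train = DATASETS.build(
    dict(
        type='CRASSDataset',
        data_root='data/',
        ann_file='train.txt',
        data_prefix=dict(img_path='images/', seg_map_path='masks/'),
        pipeline=[]))
print(len(crass_train), crass_train.metainfo['classes'])
```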
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import glob
import os

import cv2
import SimpleITK as sitk
from PIL import Image

root_path = 'data/'
img_suffix = '.tif'
seg_map_suffix = '.png'
save_img_suffix = '.png'
save_seg_map_suffix = '.png'

src_img_train_dir = os.path.join(root_path, 'CRASS/data_train')
src_mask_train_dir = os.path.join(root_path, 'CRASS/mask_mhd')
src_img_test_dir = os.path.join(root_path, 'CRASS/data_test')

tgt_img_train_dir = os.path.join(root_path, 'images/train/')
tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
tgt_img_test_dir = os.path.join(root_path, 'images/test/')
os.system('mkdir -p ' + tgt_img_train_dir)
os.system('mkdir -p ' + tgt_mask_train_dir)
os.system('mkdir -p ' + tgt_img_test_dir)


def filter_suffix_recursive(src_dir, suffix):
    """Recursively collect file paths and names ending with ``suffix``."""
    suffix = '.' + suffix if '.' not in suffix else suffix
    file_paths = glob.glob(
        os.path.join(src_dir, '**', '*' + suffix), recursive=True)
    file_names = [_.split('/')[-1] for _ in file_paths]
    return sorted(file_paths), sorted(file_names)


def read_single_array_from_med(path):
    """Read a medical image (e.g. .mhd) and return it as a 2D numpy array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(path)).squeeze()


def convert_meds_into_pngs(src_dir,
                           tgt_dir,
                           suffix='.dcm',
                           norm_min=0,
                           norm_max=255,
                           convert='RGB'):
    """Convert medical images under ``src_dir`` into PNGs in ``tgt_dir``,
    rescaling intensities to [norm_min, norm_max]."""
    if not os.path.exists(tgt_dir):
        os.makedirs(tgt_dir)

    src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
    num = len(src_paths)
    for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
        tgt_name = src_name.replace(suffix, '.png')
        tgt_path = os.path.join(tgt_dir, tgt_name)

        img = read_single_array_from_med(src_path)
        if norm_min is not None and norm_max is not None:
            img = cv2.normalize(img, None, norm_min, norm_max, cv2.NORM_MINMAX,
                                cv2.CV_8U)
        pil = Image.fromarray(img).convert(convert)
        pil.save(tgt_path)
        print(f'processed {i+1}/{num}.')


# Training images: intensities rescaled to [0, 255], saved as RGB PNGs.
convert_meds_into_pngs(
    src_img_train_dir,
    tgt_img_train_dir,
    suffix='.mhd',
    norm_min=0,
    norm_max=255,
    convert='RGB')

# Test images (unlabeled), converted the same way.
convert_meds_into_pngs(
    src_img_test_dir,
    tgt_img_test_dir,
    suffix='.mhd',
    norm_min=0,
    norm_max=255,
    convert='RGB')

# Training masks: values normalized to {0, 1}, saved as single-channel PNGs.
convert_meds_into_pngs(
    src_mask_train_dir,
    tgt_mask_train_dir,
    suffix='.mhd',
    norm_min=0,
    norm_max=1,
    convert='L')
