Skip to content

Commit 65c8d77

Browse files
authored
[Project] Medical semantic seg dataset: chest_image_pneum (open-mmlab#2727)
1 parent 041f1f0 commit 65c8d77

7 files changed

+343
-0
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Chest Image Dataset for Pneumothorax Segmentation
2+
3+
## Description
4+
5+
This project supports **`Chest Image Dataset for Pneumothorax Segmentation`**, which can be downloaded from [here](https://tianchi.aliyun.com/dataset/83075).
6+
7+
### Dataset Overview
8+
9+
Pneumothorax can be caused by a blunt chest injury, damage from underlying lung disease, or most horrifying—it may occur for no obvious reason at all. On some occasions, a collapsed lung can be a life-threatening event.
10+
Pneumothorax is usually diagnosed by a radiologist on a chest x-ray, and can sometimes be very difficult to confirm. An accurate AI algorithm to detect pneumothorax would be useful in a lot of clinical scenarios. AI could be used to triage chest radiographs for priority interpretation, or to provide a more confident diagnosis for non-radiologists.
11+
12+
The dataset is provided by the Society for Imaging Informatics in Medicine(SIIM), American College of Radiology (ACR),Society of Thoracic Radiology (STR) and MD.ai. You can develop a model to classify (and if present, segment) pneumothorax from a set of chest radiographic images. If successful, you could aid in the early recognition of pneumothoraces and save lives.
13+
14+
### Original Statistic Information
15+
16+
| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
17+
| --------------------------------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------------ |
18+
| [pneumothorax segmentation](https://tianchi.aliyun.com/dataset/83075) | thorax | segmentation | x_ray | 2 | 12089/-/3205 | yes/-/no | - | [CC-BY-SA-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |
19+
20+
| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
21+
| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
22+
| normal | 12089 | 99.75 | - | - | - | - |
23+
| pneumothorax area | 2669 | 0.25 | - | - | - | - |
24+
25+
Note:
26+
27+
- `Pct` means percentage of pixels in this category in all pixels.
28+
29+
### Visualization
30+
31+
![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/chest_image_pneum/chest_image_pneum_dataset.png)
32+
33+
### Prerequisites
34+
35+
- Python v3.8
36+
- PyTorch v1.10.0
37+
- [MIM](https://github.com/open-mmlab/mim) v0.3.4
38+
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
39+
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
40+
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
41+
42+
All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `chest_image_pneum/` root directory, run the following line to add the current directory to `PYTHONPATH`:
43+
44+
```shell
45+
export PYTHONPATH=`pwd`:$PYTHONPATH
46+
```
47+
48+
### Dataset preparing
49+
50+
- download dataset from [here](https://tianchi.aliyun.com/dataset/83075) and decompress data to path `'data/'`.
51+
- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below.
52+
- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set can't be obtained, we generate `train.txt` and `val.txt` from the training set randomly.
53+
54+
```none
55+
mmsegmentation
56+
├── mmseg
57+
├── projects
58+
│ ├── medical
59+
│ │ ├── 2d_image
60+
│ │ │ ├── x_ray
61+
│ │ │ │ ├── chest_image_pneum
62+
│ │ │ │ │ ├── configs
63+
│ │ │ │ │ ├── datasets
64+
│ │ │ │ │ ├── tools
65+
│ │ │ │ │ ├── data
66+
│ │ │ │ │ │ ├── train.txt
67+
│ │ │ │ │ │ ├── test.txt
68+
│ │ │ │ │ │ ├── images
69+
│ │ │ │ │ │ │ ├── train
70+
│ │ │ │ | │ │ │ ├── xxx.png
71+
│ │ │ │ | │ │ │ ├── ...
72+
│ │ │ │ | │ │ │ └── xxx.png
73+
│ │ │ │ │ │ ├── masks
74+
│ │ │ │ │ │ │ ├── train
75+
│ │ │ │ | │ │ │ ├── xxx.png
76+
│ │ │ │ | │ │ │ ├── ...
77+
│ │ │ │ | │ │ │ └── xxx.png
78+
```
79+
80+
### Divided Dataset Information
81+
82+
***Note: The table information below is divided by ourselves.***
83+
84+
| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
85+
| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
86+
| normal | 9637 | 99.75 | 2410 | 99.74 | - | - |
87+
| pneumothorax area | 2137 | 0.25 | 532 | 0.26 | - | - |
88+
89+
### Training commands
90+
91+
Train models on a single server with one GPU.
92+
93+
```shell
94+
mim train mmseg ./configs/${CONFIG_FILE}
95+
```
96+
97+
### Testing commands
98+
99+
Test models on a single server with one GPU.
100+
101+
```shell
102+
mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
103+
```
104+
105+
<!-- List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models)
106+
107+
You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project. -->
108+
109+
## Results
110+
111+
### Bactteria detection with darkfield microscopy
112+
113+
| Method | Backbone | Crop Size | lr | mIoU | mDice | config | download |
114+
| :-------------: | :------: | :-------: | :----: | :--: | :---: | :------------------------------------------------------------------------------------: | :----------------------: |
115+
| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
116+
| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
117+
| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
118+
119+
## Checklist
120+
121+
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
122+
123+
- [x] Finish the code
124+
125+
- [x] Basic docstrings & proper citation
126+
127+
- [x] Test-time correctness
128+
129+
- [x] A full README
130+
131+
- [x] Milestone 2: Indicates a successful model implementation.
132+
133+
- [x] Training-time correctness
134+
135+
- [ ] Milestone 3: Good to be a part of our core package!
136+
137+
- [ ] Type hints and docstrings
138+
139+
- [ ] Unit tests
140+
141+
- [ ] Code polishing
142+
143+
- [ ] Metafile.yml
144+
145+
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
146+
147+
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
dataset_type = 'ChestImagePneumDataset'
2+
data_root = 'data/'
3+
img_scale = (512, 512)
4+
train_pipeline = [
5+
dict(type='LoadImageFromFile'),
6+
dict(type='LoadAnnotations'),
7+
dict(type='Resize', scale=img_scale, keep_ratio=False),
8+
dict(type='RandomFlip', prob=0.5),
9+
dict(type='PhotoMetricDistortion'),
10+
dict(type='PackSegInputs')
11+
]
12+
test_pipeline = [
13+
dict(type='LoadImageFromFile'),
14+
dict(type='Resize', scale=img_scale, keep_ratio=False),
15+
dict(type='LoadAnnotations'),
16+
dict(type='PackSegInputs')
17+
]
18+
train_dataloader = dict(
19+
batch_size=16,
20+
num_workers=4,
21+
persistent_workers=True,
22+
sampler=dict(type='InfiniteSampler', shuffle=True),
23+
dataset=dict(
24+
type=dataset_type,
25+
data_root=data_root,
26+
ann_file='train.txt',
27+
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
28+
pipeline=train_pipeline))
29+
val_dataloader = dict(
30+
batch_size=1,
31+
num_workers=4,
32+
persistent_workers=True,
33+
sampler=dict(type='DefaultSampler', shuffle=False),
34+
dataset=dict(
35+
type=dataset_type,
36+
data_root=data_root,
37+
ann_file='val.txt',
38+
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
39+
pipeline=test_pipeline))
40+
test_dataloader = val_dataloader
41+
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
42+
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
_base_ = [
2+
'./chest-image-pneum_512x512.py',
3+
'mmseg::_base_/models/fcn_unet_s5-d16.py',
4+
'mmseg::_base_/default_runtime.py',
5+
'mmseg::_base_/schedules/schedule_20k.py'
6+
]
7+
custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
8+
img_scale = (512, 512)
9+
data_preprocessor = dict(size=img_scale)
10+
optimizer = dict(lr=0.0001)
11+
optim_wrapper = dict(optimizer=optimizer)
12+
model = dict(
13+
data_preprocessor=data_preprocessor,
14+
decode_head=dict(num_classes=2),
15+
auxiliary_head=None,
16+
test_cfg=dict(mode='whole', _delete_=True))
17+
vis_backends = None
18+
visualizer = dict(vis_backends=vis_backends)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
_base_ = [
2+
'./chest-image-pneum_512x512.py',
3+
'mmseg::_base_/models/fcn_unet_s5-d16.py',
4+
'mmseg::_base_/default_runtime.py',
5+
'mmseg::_base_/schedules/schedule_20k.py'
6+
]
7+
custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
8+
img_scale = (512, 512)
9+
data_preprocessor = dict(size=img_scale)
10+
optimizer = dict(lr=0.001)
11+
optim_wrapper = dict(optimizer=optimizer)
12+
model = dict(
13+
data_preprocessor=data_preprocessor,
14+
decode_head=dict(num_classes=2),
15+
auxiliary_head=None,
16+
test_cfg=dict(mode='whole', _delete_=True))
17+
vis_backends = None
18+
visualizer = dict(vis_backends=vis_backends)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
_base_ = [
2+
'./chest-image-pneum_512x512.py',
3+
'mmseg::_base_/models/fcn_unet_s5-d16.py',
4+
'mmseg::_base_/default_runtime.py',
5+
'mmseg::_base_/schedules/schedule_20k.py'
6+
]
7+
custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
8+
img_scale = (512, 512)
9+
data_preprocessor = dict(size=img_scale)
10+
optimizer = dict(lr=0.01)
11+
optim_wrapper = dict(optimizer=optimizer)
12+
model = dict(
13+
data_preprocessor=data_preprocessor,
14+
decode_head=dict(num_classes=2),
15+
auxiliary_head=None,
16+
test_cfg=dict(mode='whole', _delete_=True))
17+
vis_backends = None
18+
visualizer = dict(vis_backends=vis_backends)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from mmseg.datasets import BaseSegDataset
2+
from mmseg.registry import DATASETS
3+
4+
5+
@DATASETS.register_module()
6+
class ChestImagePneumDataset(BaseSegDataset):
7+
"""ChestImagePneumDataset dataset.
8+
9+
In segmentation map annotation for ChestImagePneumDataset,
10+
``reduce_zero_label`` is fixed to False. The ``img_suffix``
11+
is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
12+
13+
Args:
14+
img_suffix (str): Suffix of images. Default: '.png'
15+
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
16+
"""
17+
METAINFO = dict(classes=('normal', 'pneumothorax area'))
18+
19+
def __init__(self,
20+
img_suffix='.png',
21+
seg_map_suffix='.png',
22+
**kwargs) -> None:
23+
super().__init__(
24+
img_suffix=img_suffix,
25+
seg_map_suffix=seg_map_suffix,
26+
reduce_zero_label=False,
27+
**kwargs)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import os
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pydicom
6+
from PIL import Image
7+
8+
root_path = 'data/'
9+
img_suffix = '.dcm'
10+
seg_map_suffix = '.png'
11+
save_img_suffix = '.png'
12+
save_seg_map_suffix = '.png'
13+
14+
x_train = []
15+
for fpath, dirname, fnames in os.walk('data/chestimage_train_datasets'):
16+
for fname in fnames:
17+
if fname.endswith('.dcm'):
18+
x_train.append(os.path.join(fpath, fname))
19+
x_test = []
20+
for fpath, dirname, fnames in os.walk('data/chestimage_test_datasets/'):
21+
for fname in fnames:
22+
if fname.endswith('.dcm'):
23+
x_test.append(os.path.join(fpath, fname))
24+
25+
os.system('mkdir -p ' + root_path + 'images/train/')
26+
os.system('mkdir -p ' + root_path + 'images/test/')
27+
os.system('mkdir -p ' + root_path + 'masks/train/')
28+
29+
30+
def rle_decode(rle, width, height):
31+
mask = np.zeros(width * height, dtype=np.uint8)
32+
array = np.asarray([int(x) for x in rle.split()])
33+
starts = array[0::2]
34+
lengths = array[1::2]
35+
36+
current_position = 0
37+
for index, start in enumerate(starts):
38+
current_position += start
39+
mask[current_position:current_position + lengths[index]] = 1
40+
current_position += lengths[index]
41+
42+
return mask.reshape(width, height, order='F')
43+
44+
45+
part_dir_dict = {0: 'train/', 1: 'test/'}
46+
dict_from_csv = pd.read_csv(
47+
root_path + 'chestimage_train-rle_datasets.csv', sep=',',
48+
index_col=0).to_dict()[' EncodedPixels']
49+
50+
for ith, part in enumerate([x_train, x_test]):
51+
part_dir = part_dir_dict[ith]
52+
for img in part:
53+
basename = os.path.basename(img)
54+
img_id = '.'.join(basename.split('.')[:-1])
55+
if ith == 0 and (img_id not in dict_from_csv.keys()):
56+
continue
57+
image = pydicom.read_file(img).pixel_array
58+
save_img_path = root_path + 'images/' + part_dir + '.'.join(
59+
basename.split('.')[:-1]) + save_img_suffix
60+
print(save_img_path)
61+
img_h, img_w = image.shape[:2]
62+
image = Image.fromarray(image)
63+
image.save(save_img_path)
64+
if ith == 1:
65+
continue
66+
if dict_from_csv[img_id] == '-1':
67+
mask = np.zeros((img_h, img_w), dtype=np.uint8)
68+
else:
69+
mask = rle_decode(dict_from_csv[img_id], img_h, img_w)
70+
save_mask_path = root_path + 'masks/' + part_dir + '.'.join(
71+
basename.split('.')[:-1]) + save_seg_map_suffix
72+
mask = Image.fromarray(mask)
73+
mask.save(save_mask_path)

0 commit comments

Comments
 (0)