Skip to content

Commit d84c25b

Browse files
committed
add robin dataset
1 parent e64548f commit d84c25b

File tree

7 files changed

+2036
-185
lines changed

7 files changed

+2036
-185
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Dataset settings for the Robin dataset.
# Name under which the dataset class is looked up ('RobinDataset').
dataset_type = 'RobinDataset'
2+
# Root directory of the Robin data (relative path).
data_root = 'data/robin/'

demo/MMSegmentation_Tutorial.ipynb

Lines changed: 1754 additions & 185 deletions
Large diffs are not rendered by default.
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Model settings: EncoderDecoder segmentor with a ResNetV1c-50 backbone,
# a PSPHead decode head and an FCNHead auxiliary head.
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(512, 1024))
model = dict(
    type='EncoderDecoder',
    # Same settings as the top-level ``data_preprocessor`` above.
    data_preprocessor=data_preprocessor,
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
59+
# Dataset settings.
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
crop_size = (512, 1024)

# Training-time transforms.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# Validation / test transforms.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# Test-time augmentation: one Resize per ratio, each without and with a
# horizontal flip.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0.0, direction='horizontal'),
                dict(type='RandomFlip', prob=1.0, direction='horizontal')
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')]
        ])
]

train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='leftImg8bit/train', seg_map_path='gtFine/train'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
        pipeline=test_pipeline))
# The test loader uses exactly the same settings as the validation loader.
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
182+
# Runtime settings.
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')

# Optimizer and learning-rate schedule (iteration based, 40000 iterations).
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0.0001,
        power=0.9,
        begin=0,
        end=40000,
        by_epoch=False)
]

# Train / val / test loops.
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=40000, val_interval=4000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Hooks.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

demo/robin_dataset_demo.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from argparse import ArgumentParser
3+
4+
from mmengine.model import revert_sync_batchnorm
5+
6+
from mmseg.apis import inference_model, init_model, show_result_pyplot
7+
8+
from mmengine.registry import init_default_scope
9+
from mmseg.datasets import RobinDataset
10+
11+
12+
def main():
13+
print("start")
14+
init_default_scope('mmseg')
15+
16+
data_root = 'mmsegmentation/data/robin/'
17+
18+
data_prefix=dict(img_path='img_dir/train', seg_map_path='ann_dir/train')
19+
dataset = RobinDataset(data_root=data_root, data_prefix=data_prefix,
20+
pipeline=[],
21+
img_suffix = '.png',
22+
ann_suffix = '.png'
23+
)
24+
25+
print(f"len(robin_dataset): {len(dataset)}")
26+
27+
28+
if __name__ == '__main__':
29+
main()

mmseg/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from .refuge import REFUGEDataset
2222
from .stare import STAREDataset
2323
from .synapse import SynapseDataset
24+
from .robin import RobinDataset
25+
2426
# yapf: disable
2527
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
2628
BioMedical3DRandomCrop, BioMedical3DRandomFlip,

mmseg/datasets/robin.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from mmseg.registry import DATASETS
2+
from .basesegdataset import BaseSegDataset
3+
4+
5+
@DATASETS.register_module()
class RobinDataset(BaseSegDataset):
    """Robin segmentation dataset.

    Declares two classes, ``box`` and ``ice_pack``, with a fixed palette and
    forwards every constructor argument to :class:`BaseSegDataset`.

    Args:
        data_root: Forwarded to the base class as ``data_root``.
        data_prefix: Forwarded to the base class as ``data_prefix``.
        pipeline: Sequence of transforms forwarded to the base class.
            ``None`` (the default) is treated as an empty pipeline.
        img_suffix: Image file suffix. Defaults to ``'.png'``.
        ann_suffix: Annotation file suffix, forwarded to the base class as
            ``seg_map_suffix``. Defaults to ``'.png'``.
        ann_file: Forwarded to the base class as ``ann_file``. Defaults to
            an empty string.
        **kwargs: Any further keyword arguments accepted by the base class.
    """

    METAINFO = dict(
        classes=('box', 'ice_pack'),
        palette=[[20, 20, 255], [255, 20, 20]])

    def __init__(self,
                 data_root,
                 data_prefix,
                 pipeline=None,
                 img_suffix='.png',
                 ann_suffix='.png',
                 ann_file='',
                 **kwargs):
        # A literal ``[]`` default would be a single list object shared by
        # every call; build a fresh one per instance instead.
        if pipeline is None:
            pipeline = []
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=pipeline,
            img_suffix=img_suffix,
            seg_map_suffix=ann_suffix,
            ann_file=ann_file,
            **kwargs)

mmseg/utils/class_names.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
from mmengine.utils import is_str
33

44

5+
def robin_classes():
    """Return the Robin class names as a new list: box, then ice_pack."""
    names = ['box', 'ice_pack']
    return names
8+
9+
def robin_palette():
    """Return the Robin palette: one three-value colour per class."""
    first_colour = [20, 20, 255]
    second_colour = [255, 20, 20]
    return [first_colour, second_colour]
12+
513
def cityscapes_classes():
614
"""Cityscapes class names for external use."""
715
return [
@@ -420,6 +428,7 @@ def lip_palette():
420428

421429

422430
dataset_aliases = {
431+
'robin': ['robin'],
423432
'cityscapes': ['cityscapes'],
424433
'ade': ['ade', 'ade20k'],
425434
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],

0 commit comments

Comments
 (0)