Skip to content

Commit 8e5bfd8

Browse files
authored
Merge pull request open-mmlab#14 from OceanPang/dev
update fast rcnn configs
2 parents d13997c + 15e538f commit 8e5bfd8

File tree

4 files changed: 252 additions and 4 deletions
configs/fast_mask_rcnn_r50_fpn_1x.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# model settings
2+
model = dict(
3+
type='FastRCNN',
4+
pretrained='modelzoo://resnet50',
5+
backbone=dict(
6+
type='ResNet',
7+
depth=50,
8+
num_stages=4,
9+
out_indices=(0, 1, 2, 3),
10+
frozen_stages=1,
11+
style='pytorch'),
12+
neck=dict(
13+
type='FPN',
14+
in_channels=[256, 512, 1024, 2048],
15+
out_channels=256,
16+
num_outs=5),
17+
bbox_roi_extractor=dict(
18+
type='SingleRoIExtractor',
19+
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
20+
out_channels=256,
21+
featmap_strides=[4, 8, 16, 32]),
22+
bbox_head=dict(
23+
type='SharedFCRoIHead',
24+
num_fcs=2,
25+
in_channels=256,
26+
fc_out_channels=1024,
27+
roi_feat_size=7,
28+
num_classes=81,
29+
target_means=[0., 0., 0., 0.],
30+
target_stds=[0.1, 0.1, 0.2, 0.2],
31+
reg_class_agnostic=False),
32+
mask_roi_extractor=dict(
33+
type='SingleRoIExtractor',
34+
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
35+
out_channels=256,
36+
featmap_strides=[4, 8, 16, 32]),
37+
mask_head=dict(
38+
type='FCNMaskHead',
39+
num_convs=4,
40+
in_channels=256,
41+
conv_out_channels=256,
42+
num_classes=81))
43+
# model training and testing settings
44+
train_cfg = dict(
45+
rcnn=dict(
46+
mask_size=28,
47+
pos_iou_thr=0.5,
48+
neg_iou_thr=0.5,
49+
crowd_thr=1.1,
50+
roi_batch_size=512,
51+
add_gt_as_proposals=True,
52+
pos_fraction=0.25,
53+
pos_balance_sampling=False,
54+
neg_pos_ub=512,
55+
neg_balance_thr=0,
56+
min_pos_iou=0.5,
57+
pos_weight=-1,
58+
debug=False))
59+
test_cfg = dict(
60+
rcnn=dict(
61+
score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
62+
# dataset settings
63+
dataset_type = 'CocoDataset'
64+
data_root = 'data/coco/'
65+
img_norm_cfg = dict(
66+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
67+
data = dict(
68+
imgs_per_gpu=2,
69+
workers_per_gpu=2,
70+
train=dict(
71+
type=dataset_type,
72+
ann_file=data_root + 'annotations/instances_train2017.json',
73+
img_prefix=data_root + 'train2017/',
74+
img_scale=(1333, 800),
75+
img_norm_cfg=img_norm_cfg,
76+
size_divisor=32,
77+
proposal_file=data_root + 'proposals/train2017_r50_fpn_rpn_1x.pkl',
78+
flip_ratio=0.5,
79+
with_mask=True,
80+
with_crowd=True,
81+
with_label=True),
82+
val=dict(
83+
type=dataset_type,
84+
ann_file=data_root + 'annotations/instances_val2017.json',
85+
img_prefix=data_root + 'val2017/',
86+
img_scale=(1333, 800),
87+
img_norm_cfg=img_norm_cfg,
88+
proposal_file=data_root + 'proposals/val2017_r50_fpn_rpn_1x.pkl',
89+
size_divisor=32,
90+
flip_ratio=0,
91+
with_mask=True,
92+
with_crowd=True,
93+
with_label=True),
94+
test=dict(
95+
type=dataset_type,
96+
ann_file=data_root + 'annotations/instances_val2017.json',
97+
img_prefix=data_root + 'val2017/',
98+
img_scale=(1333, 800),
99+
img_norm_cfg=img_norm_cfg,
100+
proposal_file=data_root + 'proposals/val2017_r50_fpn_rpn_1x.pkl',
101+
size_divisor=32,
102+
flip_ratio=0,
103+
with_mask=False,
104+
with_label=False,
105+
test_mode=True))
106+
# optimizer
107+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
108+
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
109+
# learning policy
110+
lr_config = dict(
111+
policy='step',
112+
warmup='linear',
113+
warmup_iters=500,
114+
warmup_ratio=1.0 / 3,
115+
step=[8, 11])
116+
checkpoint_config = dict(interval=1)
117+
# yapf:disable
118+
log_config = dict(
119+
interval=50,
120+
hooks=[
121+
dict(type='TextLoggerHook'),
122+
# dict(type='TensorboardLoggerHook')
123+
])
124+
# yapf:enable
125+
# runtime settings
126+
total_epochs = 12
127+
dist_params = dict(backend='nccl')
128+
log_level = 'INFO'
129+
work_dir = './work_dirs/fast_mask_rcnn_r50_fpn_1x'
130+
load_from = None
131+
resume_from = None
132+
workflow = [('train', 1)]

configs/fast_rcnn_r50_fpn_1x.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# model settings
2+
model = dict(
3+
type='FastRCNN',
4+
pretrained='modelzoo://resnet50',
5+
backbone=dict(
6+
type='ResNet',
7+
depth=50,
8+
num_stages=4,
9+
out_indices=(0, 1, 2, 3),
10+
frozen_stages=1,
11+
style='pytorch'),
12+
neck=dict(
13+
type='FPN',
14+
in_channels=[256, 512, 1024, 2048],
15+
out_channels=256,
16+
num_outs=5),
17+
bbox_roi_extractor=dict(
18+
type='SingleRoIExtractor',
19+
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
20+
out_channels=256,
21+
featmap_strides=[4, 8, 16, 32]),
22+
bbox_head=dict(
23+
type='SharedFCRoIHead',
24+
num_fcs=2,
25+
in_channels=256,
26+
fc_out_channels=1024,
27+
roi_feat_size=7,
28+
num_classes=81,
29+
target_means=[0., 0., 0., 0.],
30+
target_stds=[0.1, 0.1, 0.2, 0.2],
31+
reg_class_agnostic=False))
32+
# model training and testing settings
33+
train_cfg = dict(
34+
rcnn=dict(
35+
pos_iou_thr=0.5,
36+
neg_iou_thr=0.5,
37+
crowd_thr=1.1,
38+
roi_batch_size=512,
39+
add_gt_as_proposals=True,
40+
pos_fraction=0.25,
41+
pos_balance_sampling=False,
42+
neg_pos_ub=512,
43+
neg_balance_thr=0,
44+
min_pos_iou=0.5,
45+
pos_weight=-1,
46+
debug=False))
47+
test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
48+
# dataset settings
49+
dataset_type = 'CocoDataset'
50+
data_root = 'data/coco/'
51+
img_norm_cfg = dict(
52+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
53+
data = dict(
54+
imgs_per_gpu=2,
55+
workers_per_gpu=2,
56+
train=dict(
57+
type=dataset_type,
58+
ann_file=data_root + 'annotations/instances_train2017.json',
59+
img_prefix=data_root + 'train2017/',
60+
img_scale=(1333, 800),
61+
img_norm_cfg=img_norm_cfg,
62+
size_divisor=32,
63+
proposal_file=data_root + 'proposals/train2017_r50_fpn_rpn_1x.pkl',
64+
flip_ratio=0.5,
65+
with_mask=False,
66+
with_crowd=True,
67+
with_label=True),
68+
val=dict(
69+
type=dataset_type,
70+
ann_file=data_root + 'annotations/instances_val2017.json',
71+
img_prefix=data_root + 'val2017/',
72+
img_scale=(1333, 800),
73+
img_norm_cfg=img_norm_cfg,
74+
proposal_file=data_root + 'proposals/val2017_r50_fpn_rpn_1x.pkl',
75+
size_divisor=32,
76+
flip_ratio=0,
77+
with_mask=False,
78+
with_crowd=True,
79+
with_label=True),
80+
test=dict(
81+
type=dataset_type,
82+
ann_file=data_root + 'annotations/instances_val2017.json',
83+
img_prefix=data_root + 'val2017/',
84+
img_scale=(1333, 800),
85+
img_norm_cfg=img_norm_cfg,
86+
proposal_file=data_root + 'proposals/val2017_r50_fpn_rpn_1x.pkl',
87+
size_divisor=32,
88+
flip_ratio=0,
89+
with_mask=False,
90+
with_label=False,
91+
test_mode=True))
92+
# optimizer
93+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
94+
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
95+
# learning policy
96+
lr_config = dict(
97+
policy='step',
98+
warmup='linear',
99+
warmup_iters=500,
100+
warmup_ratio=1.0 / 3,
101+
step=[8, 11])
102+
checkpoint_config = dict(interval=1)
103+
# yapf:disable
104+
log_config = dict(
105+
interval=50,
106+
hooks=[
107+
dict(type='TextLoggerHook'),
108+
# dict(type='TensorboardLoggerHook')
109+
])
110+
# yapf:enable
111+
# runtime settings
112+
total_epochs = 12
113+
dist_params = dict(backend='nccl')
114+
log_level = 'INFO'
115+
work_dir = './work_dirs/fast_rcnn_r50_fpn_1x'
116+
load_from = None
117+
resume_from = None
118+
workflow = [('train', 1)]

configs/faster_rcnn_r50_fpn_1x.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
65 65
pos_balance_sampling=False,
66 66
neg_pos_ub=512,
67 67
neg_balance_thr=0,
68-
min_pos_iou=1.1,
68+
min_pos_iou=0.5,
69 69
pos_weight=-1,
70 70
debug=False))
71 71
test_cfg = dict(
@@ -139,7 +139,6 @@
139 139
# yapf:enable
140 140
# runtime settings
141 141
total_epochs = 12
142-
device_ids = range(8)
143 142
dist_params = dict(backend='nccl')
144 143
log_level = 'INFO'
145 144
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x'

configs/mask_rcnn_r50_fpn_1x.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
77 77
pos_balance_sampling=False,
78 78
neg_pos_ub=512,
79 79
neg_balance_thr=0,
80-
min_pos_iou=1.1,
80+
min_pos_iou=0.5,
81 81
pos_weight=-1,
82 82
debug=False))
83 83
test_cfg = dict(
@@ -152,7 +152,6 @@
152 152
# yapf:enable
153 153
# runtime settings
154 154
total_epochs = 12
155-
device_ids = range(8)
156 155
dist_params = dict(backend='nccl')
157 156
log_level = 'INFO'
158 157
work_dir = './work_dirs/mask_rcnn_r50_fpn_1x'

0 commit comments

Comments
 (0)