Skip to content

Commit b5418c2

Browse files
luxiinhellock
authored andcommitted
Code for "Grid R-CNN" (open-mmlab#810)
* Grid R-CNN * add grid_rcnn_res50fpn2x config * add assertion that grid_head should exist * fix bugs and remove SharedFCBBoxHeadGrid * remove the property with_grid * format fixes for grad_head and add config dir * move random_jitter to grid_head and some refactoring * simplify the calculation of num_edges * refactoring * refactoring * rename config files and add x101 config * bug fix for inference * remove random_jitter_single * add readme of grid rcnn * add bibtex of grid rcnn plus * update work_dir
1 parent 466926e commit b5418c2

File tree

8 files changed

+966
-4
lines changed

8 files changed

+966
-4
lines changed

configs/grid_rcnn/README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Grid R-CNN
2+
3+
## Introduction
4+
5+
```
6+
@inproceedings{lu2019grid,
7+
title={Grid r-cnn},
8+
author={Lu, Xin and Li, Buyu and Yue, Yuxin and Li, Quanquan and Yan, Junjie},
9+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
10+
year={2019}
11+
}
12+
13+
@article{lu2019grid,
14+
title={Grid R-CNN Plus: Faster and Better},
15+
author={Lu, Xin and Li, Buyu and Yue, Yuxin and Li, Quanquan and Yan, Junjie},
16+
journal={arXiv preprint arXiv:1906.05688},
17+
year={2019}
18+
}
19+
```
20+
21+
## Results and Models
22+
23+
| Backbone | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
24+
|:-----------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
25+
| R-50 | 2x | 4.8 | | | 40.3 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/grid_rcnn/grid_rcnn_gn_head_r50_fpn_2x_20190619-5b29cf9d.pth) |
26+
| R-101 | 2x | 6.7 | | | 41.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/grid_rcnn/grid_rcnn_gn_head_r101_fpn_2x_20190619-a4b61645.pth) |
27+
| X-101-32x4d | 2x | 8.0 | | | 43.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/grid_rcnn/grid_rcnn_gn_head_x101_32x4d_fpn_2x_20190619-0bbfd87a.pth) |
28+
| X-101-64x4d | 2x | 10.9 | | | 43.1 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/grid_rcnn/grid_rcnn_gn_head_x101_64x4d_fpn_2x_20190619-8f4e20bb.pth) |
29+
30+
**Notes:**
31+
- All models are trained with 8 GPUs instead of 32 GPUs in the original paper.
32+
- The warming up lasts for 1 epoch and `2x` here indicates 25 epochs.
33+
- The training speed is about 3 times slower than Faster R-CNN.
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# model settings
2+
model = dict(
3+
type='GridRCNN',
4+
pretrained='modelzoo://resnet50',
5+
backbone=dict(
6+
type='ResNet',
7+
depth=50,
8+
num_stages=4,
9+
out_indices=(0, 1, 2, 3),
10+
frozen_stages=1,
11+
style='pytorch'),
12+
neck=dict(
13+
type='FPN',
14+
in_channels=[256, 512, 1024, 2048],
15+
out_channels=256,
16+
num_outs=5),
17+
rpn_head=dict(
18+
type='RPNHead',
19+
in_channels=256,
20+
feat_channels=256,
21+
anchor_scales=[8],
22+
anchor_ratios=[0.5, 1.0, 2.0],
23+
anchor_strides=[4, 8, 16, 32, 64],
24+
target_means=[.0, .0, .0, .0],
25+
target_stds=[1.0, 1.0, 1.0, 1.0],
26+
loss_cls=dict(
27+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
28+
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
29+
bbox_roi_extractor=dict(
30+
type='SingleRoIExtractor',
31+
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
32+
out_channels=256,
33+
featmap_strides=[4, 8, 16, 32]),
34+
bbox_head=dict(
35+
type='SharedFCBBoxHead',
36+
with_reg=False,
37+
num_fcs=2,
38+
in_channels=256,
39+
fc_out_channels=1024,
40+
roi_feat_size=7,
41+
num_classes=81,
42+
target_means=[0., 0., 0., 0.],
43+
target_stds=[0.1, 0.1, 0.2, 0.2],
44+
reg_class_agnostic=False),
45+
grid_roi_extractor=dict(
46+
type='SingleRoIExtractor',
47+
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
48+
out_channels=256,
49+
featmap_strides=[4, 8, 16, 32]),
50+
grid_head=dict(
51+
type='GridHead',
52+
grid_points=9,
53+
num_convs=8,
54+
in_channels=256,
55+
point_feat_channels=64,
56+
norm_cfg=dict(type='GN', num_groups=36),
57+
loss_grid=dict(
58+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15)))
59+
# model training and testing settings
60+
train_cfg = dict(
61+
rpn=dict(
62+
assigner=dict(
63+
type='MaxIoUAssigner',
64+
pos_iou_thr=0.7,
65+
neg_iou_thr=0.3,
66+
min_pos_iou=0.3,
67+
ignore_iof_thr=-1),
68+
sampler=dict(
69+
type='RandomSampler',
70+
num=256,
71+
pos_fraction=0.5,
72+
neg_pos_ub=-1,
73+
add_gt_as_proposals=False),
74+
allowed_border=0,
75+
pos_weight=-1,
76+
debug=False),
77+
rpn_proposal=dict(
78+
nms_across_levels=False,
79+
nms_pre=2000,
80+
nms_post=2000,
81+
max_num=2000,
82+
nms_thr=0.7,
83+
min_bbox_size=0),
84+
rcnn=dict(
85+
assigner=dict(
86+
type='MaxIoUAssigner',
87+
pos_iou_thr=0.5,
88+
neg_iou_thr=0.5,
89+
min_pos_iou=0.5,
90+
ignore_iof_thr=-1),
91+
sampler=dict(
92+
type='RandomSampler',
93+
num=512,
94+
pos_fraction=0.25,
95+
neg_pos_ub=-1,
96+
add_gt_as_proposals=True),
97+
pos_radius=1,
98+
pos_weight=-1,
99+
max_num_grid=192,
100+
debug=False))
101+
test_cfg = dict(
102+
rpn=dict(
103+
nms_across_levels=False,
104+
nms_pre=1000,
105+
nms_post=1000,
106+
max_num=1000,
107+
nms_thr=0.7,
108+
min_bbox_size=0),
109+
rcnn=dict(
110+
score_thr=0.03, nms=dict(type='nms', iou_thr=0.3), max_per_img=100))
111+
# dataset settings
112+
dataset_type = 'CocoDataset'
113+
data_root = 'data/coco/'
114+
img_norm_cfg = dict(
115+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
116+
data = dict(
117+
imgs_per_gpu=2,
118+
workers_per_gpu=2,
119+
train=dict(
120+
type=dataset_type,
121+
ann_file=data_root + 'annotations/instances_train2017.json',
122+
img_prefix=data_root + 'train2017/',
123+
img_scale=(1333, 800),
124+
img_norm_cfg=img_norm_cfg,
125+
size_divisor=32,
126+
flip_ratio=0.5,
127+
with_mask=True,
128+
with_crowd=True,
129+
with_label=True),
130+
val=dict(
131+
type=dataset_type,
132+
ann_file=data_root + 'annotations/instances_val2017.json',
133+
img_prefix=data_root + 'val2017/',
134+
img_scale=(1333, 800),
135+
img_norm_cfg=img_norm_cfg,
136+
size_divisor=32,
137+
flip_ratio=0,
138+
with_mask=True,
139+
with_crowd=True,
140+
with_label=True),
141+
test=dict(
142+
type=dataset_type,
143+
ann_file=data_root + 'annotations/instances_val2017.json',
144+
img_prefix=data_root + 'val2017/',
145+
img_scale=(1333, 800),
146+
img_norm_cfg=img_norm_cfg,
147+
size_divisor=32,
148+
flip_ratio=0,
149+
with_mask=False,
150+
with_label=False,
151+
test_mode=True))
152+
# optimizer
153+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
154+
optimizer_config = dict(grad_clip=None)
155+
# learning policy
156+
lr_config = dict(
157+
policy='step',
158+
warmup='linear',
159+
warmup_iters=3665,
160+
warmup_ratio=1.0 / 80,
161+
step=[17, 23])
162+
checkpoint_config = dict(interval=1)
163+
# yapf:disable
164+
log_config = dict(
165+
interval=50,
166+
hooks=[
167+
dict(type='TextLoggerHook'),
168+
# dict(type='TensorboardLoggerHook')
169+
])
170+
# yapf:enable
171+
# runtime settings
172+
total_epochs = 25
173+
dist_params = dict(backend='nccl')
174+
log_level = 'INFO'
175+
work_dir = './work_dirs/grid_rcnn_gn_head_r50_fpn_2x'
176+
load_from = None
177+
resume_from = None
178+
workflow = [('train', 1)]
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# model settings
2+
model = dict(
3+
type='GridRCNN',
4+
pretrained='open-mmlab://resnext101_32x4d',
5+
backbone=dict(
6+
type='ResNeXt',
7+
depth=101,
8+
groups=32,
9+
base_width=4,
10+
num_stages=4,
11+
out_indices=(0, 1, 2, 3),
12+
frozen_stages=1,
13+
style='pytorch'),
14+
neck=dict(
15+
type='FPN',
16+
in_channels=[256, 512, 1024, 2048],
17+
out_channels=256,
18+
num_outs=5),
19+
rpn_head=dict(
20+
type='RPNHead',
21+
in_channels=256,
22+
feat_channels=256,
23+
anchor_scales=[8],
24+
anchor_ratios=[0.5, 1.0, 2.0],
25+
anchor_strides=[4, 8, 16, 32, 64],
26+
target_means=[.0, .0, .0, .0],
27+
target_stds=[1.0, 1.0, 1.0, 1.0],
28+
loss_cls=dict(
29+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
30+
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
31+
bbox_roi_extractor=dict(
32+
type='SingleRoIExtractor',
33+
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
34+
out_channels=256,
35+
featmap_strides=[4, 8, 16, 32]),
36+
bbox_head=dict(
37+
type='SharedFCBBoxHead',
38+
with_reg=False,
39+
num_fcs=2,
40+
in_channels=256,
41+
fc_out_channels=1024,
42+
roi_feat_size=7,
43+
num_classes=81,
44+
target_means=[0., 0., 0., 0.],
45+
target_stds=[0.1, 0.1, 0.2, 0.2],
46+
reg_class_agnostic=False),
47+
grid_roi_extractor=dict(
48+
type='SingleRoIExtractor',
49+
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
50+
out_channels=256,
51+
featmap_strides=[4, 8, 16, 32]),
52+
grid_head=dict(
53+
type='GridHead',
54+
grid_points=9,
55+
num_convs=8,
56+
in_channels=256,
57+
point_feat_channels=64,
58+
norm_cfg=dict(type='GN', num_groups=36),
59+
loss_grid=dict(
60+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15)))
61+
# model training and testing settings
62+
train_cfg = dict(
63+
rpn=dict(
64+
assigner=dict(
65+
type='MaxIoUAssigner',
66+
pos_iou_thr=0.7,
67+
neg_iou_thr=0.3,
68+
min_pos_iou=0.3,
69+
ignore_iof_thr=-1),
70+
sampler=dict(
71+
type='RandomSampler',
72+
num=256,
73+
pos_fraction=0.5,
74+
neg_pos_ub=-1,
75+
add_gt_as_proposals=False),
76+
allowed_border=0,
77+
pos_weight=-1,
78+
debug=False),
79+
rpn_proposal=dict(
80+
nms_across_levels=False,
81+
nms_pre=2000,
82+
nms_post=2000,
83+
max_num=2000,
84+
nms_thr=0.7,
85+
min_bbox_size=0),
86+
rcnn=dict(
87+
assigner=dict(
88+
type='MaxIoUAssigner',
89+
pos_iou_thr=0.5,
90+
neg_iou_thr=0.5,
91+
min_pos_iou=0.5,
92+
ignore_iof_thr=-1),
93+
sampler=dict(
94+
type='RandomSampler',
95+
num=512,
96+
pos_fraction=0.25,
97+
neg_pos_ub=-1,
98+
add_gt_as_proposals=True),
99+
pos_radius=1,
100+
pos_weight=-1,
101+
max_num_grid=192,
102+
debug=False))
103+
test_cfg = dict(
104+
rpn=dict(
105+
nms_across_levels=False,
106+
nms_pre=1000,
107+
nms_post=1000,
108+
max_num=1000,
109+
nms_thr=0.7,
110+
min_bbox_size=0),
111+
rcnn=dict(
112+
score_thr=0.03, nms=dict(type='nms', iou_thr=0.3), max_per_img=100))
113+
# dataset settings
114+
dataset_type = 'CocoDataset'
115+
data_root = 'data/coco/'
116+
img_norm_cfg = dict(
117+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
118+
data = dict(
119+
imgs_per_gpu=2,
120+
workers_per_gpu=2,
121+
train=dict(
122+
type=dataset_type,
123+
ann_file=data_root + 'annotations/instances_train2017.json',
124+
img_prefix=data_root + 'train2017/',
125+
img_scale=(1333, 800),
126+
img_norm_cfg=img_norm_cfg,
127+
size_divisor=32,
128+
flip_ratio=0.5,
129+
with_mask=True,
130+
with_crowd=True,
131+
with_label=True),
132+
val=dict(
133+
type=dataset_type,
134+
ann_file=data_root + 'annotations/instances_val2017.json',
135+
img_prefix=data_root + 'val2017/',
136+
img_scale=(1333, 800),
137+
img_norm_cfg=img_norm_cfg,
138+
size_divisor=32,
139+
flip_ratio=0,
140+
with_mask=True,
141+
with_crowd=True,
142+
with_label=True),
143+
test=dict(
144+
type=dataset_type,
145+
ann_file=data_root + 'annotations/instances_val2017.json',
146+
img_prefix=data_root + 'val2017/',
147+
img_scale=(1333, 800),
148+
img_norm_cfg=img_norm_cfg,
149+
size_divisor=32,
150+
flip_ratio=0,
151+
with_mask=False,
152+
with_label=False,
153+
test_mode=True))
154+
# optimizer
155+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
156+
optimizer_config = dict(grad_clip=None)
157+
# learning policy
158+
lr_config = dict(
159+
policy='step',
160+
warmup='linear',
161+
warmup_iters=3665,
162+
warmup_ratio=1.0 / 80,
163+
step=[17, 23])
164+
checkpoint_config = dict(interval=1)
165+
# yapf:disable
166+
log_config = dict(
167+
interval=50,
168+
hooks=[
169+
dict(type='TextLoggerHook'),
170+
# dict(type='TensorboardLoggerHook')
171+
])
172+
# yapf:enable
173+
# runtime settings
174+
total_epochs = 25
175+
dist_params = dict(backend='nccl')
176+
log_level = 'INFO'
177+
work_dir = './work_dirs/grid_rcnn_gn_head_x101_32x4d_fpn_2x'
178+
load_from = None
179+
resume_from = None
180+
workflow = [('train', 1)]

0 commit comments

Comments
 (0)