
Commit 1cbc88e

Code for CVPR 2019 paper "Hybrid Task Cascade for Instance Segmentation" (open-mmlab#478)
* add HybridTaskCascade
* add configs for other backbones
* add keep_ratio argument for segmap transform
* add readme for HTC
* fix linting errors
* split assign and sampling as in Cascade R-CNN
* remove unused imports
* add a large model
* update htc
1 parent a9e21cf commit 1cbc88e

17 files changed (+2373, −4 lines)

MODEL_ZOO.md

Lines changed: 4 additions & 0 deletions
@@ -154,6 +154,10 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m

- The `20e` schedule in Cascade (Mask) R-CNN indicates decreasing the lr at epochs 16 and 19, with a total of 20 epochs.
- Cascade Mask R-CNN with X-101-64x4d-FPN was trained using 16 GPUs with a batch size of 16 (1 image per GPU).

### Hybrid Task Cascade (HTC)

Please refer to [HTC](configs/htc/README.md) for details.

### SSD

| Backbone | Size | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |

README.md

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ Results and models are available in the [Model zoo](MODEL_ZOO.md).

| Cascade Mask R-CNN |||||
| SSD |||||
| RetinaNet |||||
| Hybrid Task Cascade |||||

Other features
- [x] DCNv2

configs/htc/README.md

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@

# Hybrid Task Cascade for Instance Segmentation

## Introduction

We provide config files to reproduce the results of the CVPR 2019 paper [Hybrid Task Cascade for Instance Segmentation](https://arxiv.org/abs/1901.07518).

```
@inproceedings{chen2019hybrid,
  title={Hybrid task cascade for instance segmentation},
  author={Chen, Kai and Pang, Jiangmiao and Wang, Jiaqi and Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and Liu, Ziwei and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  year={2019}
}
```

## Dataset

HTC requires the COCO and COCO-stuff datasets for training. You need to download and extract them under the COCO dataset path.
The directory structure should look like this:

```
mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   │   ├── stuffthingmaps
```
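
Before launching training, it can help to sanity-check this layout. The snippet below is a minimal sketch: the directory names come from the tree above, and the `check_htc_data_layout` helper is ours, not part of mmdetection.

```python
import os

def check_htc_data_layout(data_root='data/coco'):
    """Verify the COCO + COCO-stuff layout that HTC expects."""
    required = [
        'annotations',
        'train2017',
        'val2017',
        'test2017',
        'stuffthingmaps',  # COCO-stuff semantic segmentation maps
    ]
    missing = [d for d in required
               if not os.path.isdir(os.path.join(data_root, d))]
    if missing:
        raise FileNotFoundError(
            'missing under {}: {}'.format(data_root, ', '.join(missing)))

check_htc_data_layout()
```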

## Results and Models

The results on COCO 2017val are shown in the table below. (Results on test-dev are usually slightly higher than those on val.)

| Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
|:---------------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:|
| R-50-FPN | pytorch | 1x | | | | 42.2 | 37.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_r50_fpn_1x_20190408-878c1712.pth) |
| R-50-FPN | pytorch | 20e | | | | 43.2 | 38.0 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_r50_fpn_20e_20190408-c03b7015.pth) |
| R-101-FPN | pytorch | 20e | | | | 44.9 | 39.4 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_r101_fpn_20e_20190408-a2e586db.pth) |
| X-101-32x4d-FPN | pytorch | 20e | | | | 46.1 | 40.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_x101_32x4d_fpn_20e_20190408-9eae4d0b.pth) |
| X-101-64x4d-FPN | pytorch | 20e | | | | 47.0 | 40.9 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_x101_64x4d_fpn_20e_20190408-497f2561.pth) |

- In the HTC paper and the COCO 2018 Challenge, `score_thr` is set to 0.001 for both the baselines and HTC.
- We use 8 GPUs with 2 images/GPU for the R-50 and R-101 models, and 16 GPUs with 1 image/GPU for the X-101 models. If you would like to train X-101 HTC with 8 GPUs, you need to change the lr from 0.02 to 0.01 following the linear scaling rule, as in the sketch below.
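
The lr adjustment follows the linear scaling rule: lr is proportional to the total batch size, and both default setups (8 GPUs × 2 images and 16 GPUs × 1 image) amount to 16 images per iteration at lr 0.02. A minimal sketch; the `scaled_lr` helper is ours, not part of the codebase.

```python
def scaled_lr(num_gpus, imgs_per_gpu, base_lr=0.02, base_batch_size=16):
    """Linear scaling rule: lr is proportional to the total batch size."""
    return base_lr * (num_gpus * imgs_per_gpu) / base_batch_size

print(scaled_lr(8, 2))   # 0.02 -> R-50 / R-101 default
print(scaled_lr(16, 1))  # 0.02 -> X-101 default (same total batch size)
print(scaled_lr(8, 1))   # 0.01 -> X-101 on 8 GPUs, as noted above
```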

We also provide a more powerful HTC model trained with DCN and multi-scale training. No test-time augmentation is used.

| Backbone | Style | DCN | Training scales | Lr schd | box AP | mask AP | Download |
|:----------------:|:-------:|:-----:|:---------------:|:-------:|:------:|:-------:|:--------:|
| X-101-64x4d-FPN | pytorch | c3-c5 | 400~1400 | 20e | 50.7 | 43.9 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e_20190408-0e50669c.pth) |

configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@

```python
# model settings
model = dict(
    type='HybridTaskCascade',
    num_stages=3,
    pretrained='open-mmlab://resnext101_64x4d',
    interleaved=True,
    mask_info_flow=True,
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch',
        dcn=dict(
            modulated=False,
            groups=64,
            deformable_groups=1,
            fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),  # deformable conv in c3-c5
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    # three cascade stages with progressively tighter regression stds
    bbox_head=[
        dict(
            type='SharedFCBBoxHead',
            num_fcs=2,
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=81,
            target_means=[0., 0., 0., 0.],
            target_stds=[0.1, 0.1, 0.2, 0.2],
            reg_class_agnostic=True),
        dict(
            type='SharedFCBBoxHead',
            num_fcs=2,
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=81,
            target_means=[0., 0., 0., 0.],
            target_stds=[0.05, 0.05, 0.1, 0.1],
            reg_class_agnostic=True),
        dict(
            type='SharedFCBBoxHead',
            num_fcs=2,
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=81,
            target_means=[0., 0., 0., 0.],
            target_stds=[0.033, 0.033, 0.067, 0.067],
            reg_class_agnostic=True)
    ],
    mask_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    mask_head=dict(
        type='HTCMaskHead',
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=81),
    semantic_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
        out_channels=256,
        featmap_strides=[8]),
    semantic_head=dict(
        type='FusedSemanticHead',
        num_ins=5,
        fusion_level=1,
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=183,
        ignore_label=255,
        loss_weight=0.2))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    # per-stage assign/sample settings with increasing IoU thresholds
    rcnn=[
        dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False),
        dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.6,
                neg_iou_thr=0.6,
                min_pos_iou=0.6,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False),
        dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.7,
                min_pos_iou=0.7,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)
    ],
    stage_loss_weights=[1, 0.5, 0.25])  # later stages are weighted less
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.001,
        nms=dict(type='nms', iou_thr=0.5),
        max_per_img=100,
        mask_thr_binary=0.5),
    keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=1,
    workers_per_gpu=1,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=[(1600, 400), (1600, 1400)],  # multi-scale training, short edge 400~1400
        multiscale_mode='range',
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        seg_prefix=data_root + 'stuffthingmaps/train2017/',
        seg_scale_factor=1 / 8,
        with_mask=True,
        with_crowd=True,
        with_label=True,
        with_semantic_seg=True),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=True,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=True,
        with_label=False,
        test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[16, 19])  # the 20e schedule: decay lr at epochs 16 and 19
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 20
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e'
load_from = None
resume_from = None
workflow = [('train', 1)]
```
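
As a quick check that the file parses, the config can be loaded and inspected with mmcv before launching training. A minimal sketch, assuming mmcv is installed and the config path matches the work_dir above.

```python
from mmcv import Config

# Load the HTC config and print a few key settings.
cfg = Config.fromfile(
    'configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py')
print(cfg.model.type)            # HybridTaskCascade
print(cfg.model.num_stages)      # 3
print(cfg.data.train.img_scale)  # [(1600, 400), (1600, 1400)]
```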

0 commit comments