
Commit 608e319

angiecao, CastleDream, yeedrag, Yang-Changhui and SheffieldCao authored
[Feature] Support Side Adapter Network (open-mmlab#3232)
## Motivation

Support SAN for open-vocabulary semantic segmentation.
Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the ViT backbone parameters needed to implement the CLIP image encoder.
- Added the text encoder code.
- Added the multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added the SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added the loss implementation for mask-classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and Hungarian assigner.

## Use cases

### Convert models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth). Use tools/model_converters/san2mmseg.py to convert an official model into MMSegmentation style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

SAN is trained with the CLIP models released by OpenAI, which can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt). Use tools/model_converters/clip2mmseg.py to convert a model into MMSegmentation style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison results

### Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on VOC12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <[email protected]>
Co-authored-by: yeedrag <[email protected]>
Co-authored-by: Yang-ChangHui <[email protected]>
Co-authored-by: Xu CAO <[email protected]>
Co-authored-by: xiexinch <[email protected]>
Co-authored-by: 小飞猪 <[email protected]>
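As a supplement to the commands above, inference can also be run through the MMSegmentation Python API. A minimal sketch, assuming the converted checkpoint and image paths shown here are placeholders:

```python
# Minimal inference sketch using the mmseg 1.x Python API.
# Checkpoint and image paths are placeholders; the config is the one added by this PR.
from mmseg.apis import init_model, inference_model

config_file = './configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint_file = 'san-vit-b16_converted.pth'  # e.g. output of san2mmseg.py (placeholder name)

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')   # any RGB image path
print(result.pred_sem_seg.data.shape)              # predicted label map
```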
1 parent 1471d1e commit 608e319


42 files changed: +4114 / -29 lines changed
Lines changed: 137 additions & 0 deletions
```python
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size_divisor=640,
    test_cfg=dict(size_divisor=32))

num_classes = 171
model = dict(
    type='MultimodalEncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/clip_vit_base_patch16_224.pth',
    asymetric_input=True,
    encoder_resolution=0.5,
    image_encoder=dict(
        type='VisionTransformer',
        img_size=(224, 224),
        patch_size=16,
        patch_pad=0,
        in_channels=3,
        embed_dims=768,
        num_layers=9,
        num_heads=12,
        mlp_ratio=4,
        out_origin=True,
        out_indices=(2, 5, 8),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        output_cls_token=True,
        patch_bias=False,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        act_cfg=dict(type='QuickGELU'),
        norm_eval=False,
        interpolate_mode='bicubic',
        frozen_exclude=['pos_embed']),
    text_encoder=dict(
        type='CLIPTextEncoder',
        dataset_name=None,
        templates='vild',
        embed_dims=512,
        num_layers=12,
        num_heads=8,
        mlp_ratio=4,
        output_dims=512,
        cache_feature=True,
        cat_bg=True,
        norm_cfg=dict(type='LN', eps=1e-5)),
    decode_head=dict(
        type='SideAdapterCLIPHead',
        num_classes=num_classes,
        deep_supervision_idxs=[7],
        san_cfg=dict(
            in_channels=3,
            clip_channels=768,
            embed_dims=240,
            patch_size=16,
            patch_bias=True,
            num_queries=100,
            cfg_encoder=dict(
                num_encode_layer=8,
                num_heads=6,
                mlp_ratio=4),
            fusion_index=[0, 1, 2, 3],
            cfg_decoder=dict(
                num_heads=12,
                num_layers=1,
                embed_channels=256,
                mlp_channels=256,
                num_mlp=3,
                rescale=True),
            norm_cfg=dict(type='LN', eps=1e-6)),
        maskgen_cfg=dict(
            sos_token_format='cls_token',
            sos_token_num=100,
            cross_attn=False,
            num_layers=3,
            embed_dims=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            out_dims=512,
            final_norm=True,
            act_cfg=dict(type='QuickGELU'),
            norm_cfg=dict(type='LN', eps=1e-5),
            frozen_exclude=[]),
        align_corners=False,
        train_cfg=dict(
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(
                        type='DiceCost',
                        weight=5.0,
                        pred_act=True,
                        eps=1.0)
                ])),
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_cls_ce',
                loss_weight=2.0,
                class_weight=[1.0] * num_classes + [0.1]),
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                loss_name='loss_mask_ce',
                loss_weight=5.0),
            dict(
                type='DiceLoss',
                ignore_index=None,
                naive_dice=True,
                eps=1,
                loss_name='loss_mask_dice',
                loss_weight=5.0)
        ]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
```
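A quick way to sanity-check the base model settings above is to build the segmentor from them directly. A minimal sketch, assuming the file lives at configs/_base_/models/san_vit-b16.py (the path the dataset-specific configs below inherit via `_base_`):

```python
# Sketch: instantiate the SAN segmentor from the base model config.
# The config path is an assumption based on the `_base_` references below;
# no checkpoint is loaded at build time.
from mmengine.config import Config
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()  # register mmseg models, datasets and transforms
cfg = Config.fromfile('configs/_base_/models/san_vit-b16.py')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # MultimodalEncoderDecoder
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```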

configs/san/README.md

Lines changed: 47 additions & 0 deletions
# SAN

> [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/MendelXu/SAN">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

This paper presents a new framework for open-vocabulary semantic segmentation with the pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention bias which is applied in the CLIP model to recognize the class of masks. This decoupled design has the benefit of aiding CLIP in recognizing the class of mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. Our approach is fast, accurate, and only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. We hope our approach will serve as a solid baseline and help ease future research in open-vocabulary semantic segmentation.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/MendelXu/SAN/blob/main/resources/arch.png" width="800"/>
</div>

## Results and models

### COCO-Stuff164k

| Method | Backbone | Pretrained   | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU  | mIoU(ms+flip) | config | download |
| ------ | -------- | ------------ | --------- | ------- | -------- | -------------- | ------ | ----- | ------------- | ------ | -------- |
| SAN    | ViT-B_16 | CLIP_ViT-B16 | 640x640   | 60000   | 12.61    | -              | V100   | 41.93 | 41.77         | -      | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906-fd0a7684.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906.log) |
| SAN    | ViT-L_14 | CLIP_ViT-L14 | 640x640   | 60000   | 22.84    | -              | V100   | 45.78 | 43.99         | -      | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907-a11e098f.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907.log) |

## Notes

The pretrained weights in the config files are converted from open_clip models using tools/model_converters/clip2mmseg.py.

## Citation

```bibtex
@inproceedings{xu2023side,
  title={Side adapter network for open-vocabulary semantic segmentation},
  author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2945--2954},
  year={2023}
}
```
Lines changed: 82 additions & 0 deletions
```python
_base_ = [
    '../_base_/models/san_vit-b16.py', '../_base_/datasets/coco-stuff164k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomChoiceResize',
        scales=[int(640 * x * 0.1) for x in range(5, 16)],
        resize_type='ResizeShortestEdge',
        max_size=2560),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.0),
    dict(type='PhotoMetricDistortion'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSegInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# By default, models are trained on 4 GPUs with 8 images per GPU
train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-base-patch16-224_3rdparty-d08f8887.pth'  # noqa
data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    pretrained=pretrained,
    text_encoder=dict(dataset_name='coco-stuff164k'),
    decode_head=dict(num_classes=171))

# training schedule for 60k
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=60000,
    val_interval=500,
    val_begin=55000)
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        interval=10000,
        save_best='mIoU'))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001),
    paramwise_cfg=dict(
        custom_keys={
            'img_encoder': dict(lr_mult=0.1, decay_mult=1.0),
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }),
    loss_scale='dynamic',
    clip_grad=dict(max_norm=0.01, norm_type=2))

param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=0,
        end=60000,
        by_epoch=False,
    )
]
```
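Besides `tools/train.py`, training with this config can be launched from Python via MMEngine's Runner. A minimal sketch; the work directory is a placeholder, and overriding `model.pretrained` is only needed if locally converted CLIP weights are preferred over the download URL set above:

```python
# Sketch: launch SAN training from Python, equivalent to
# `python tools/train.py configs/san/san-vit-b16_coco-stuff164k-640x640.py ...`.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/san/san-vit-b16_coco-stuff164k-640x640.py')
cfg.work_dir = './work_dirs/san-vit-b16_coco-stuff164k'  # placeholder
# Optional: use locally converted CLIP weights (clip2mmseg.py output).
# cfg.model.pretrained = 'pretrain/clip_vit_base_patch16_224.pth'

runner = Runner.from_cfg(cfg)
runner.train()
```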
Lines changed: 56 additions & 0 deletions
```python
_base_ = [
    '../_base_/models/san_vit-b16.py',
    '../_base_/datasets/pascal_context_59.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/vit_base_patch16_224.pth',
    text_encoder=dict(dataset_name='pascal_context'),
    decode_head=dict(num_classes=59))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=160000,
        by_epoch=False,
    )
]
```
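This Pascal Context config is used to evaluate weights trained on COCO-Stuff164k. A minimal sketch of running that evaluation from Python; the config filename and checkpoint path are assumptions based on this PR's naming:

```python
# Sketch: evaluate a trained SAN checkpoint on Pascal Context 59,
# equivalent to `python tools/test.py <config> <checkpoint>`.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/san/san-vit-b16_pascal-context-59-640x640.py')  # assumed name
cfg.work_dir = './work_dirs/san-vit-b16_pascal-context-59'  # placeholder
cfg.load_from = 'san-vit-b16_coco-stuff164k.pth'            # placeholder checkpoint

runner = Runner.from_cfg(cfg)
metrics = runner.test()  # returns IoU metrics (mIoU / mAcc / pAcc)
print(metrics)
```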
Lines changed: 65 additions & 0 deletions
```python
_base_ = [
    '../_base_/models/san_vit-b16.py',
    '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

metainfo = dict(
    classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
             'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
             'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
    palette=[[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
             [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
             [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
             [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
             [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(
    batch_size=1, dataset=dict(metainfo=metainfo, pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
    mean=[122.7709, 116.7460, 104.0937],
    std=[68.5005, 66.6322, 70.3232],
    size_divisor=640,
    test_cfg=dict(size_divisor=32))
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='pretrain/vit_base_patch16_224.pth',
    text_encoder=dict(dataset_name='voc'),
    decode_head=dict(num_classes=20))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=160000,
        by_epoch=False,
    )
]
```
