Skip to content

Commit 2825efe

Browse files
谢昕辰Junjun2016
谢昕辰
andauthored
[Feature] add DPT head (open-mmlab#605)
* add DPT head * [fix] fix init error * use mmcv function * delete code * remove transpose clas * support NLC output shape * Delete post_process_layer.py * add unittest and docstring * rename variables * fix project error and add unittest * match dpt weights * add configs * fix vit pos_embed bug and dpt feature fusion bug * match vit output * fix gelu * minor change * update unitest * fix configs error * inference test * remove auxilary * use local pretrain * update training results * update yml * update fps and memory test * update doc * update readme * add yml * update doc * remove with_cp * update config * update docstring * remove dpt-l * add init_cfg and modify readme.md * Update dpt_vit-b16.py * zh-n README * use constructor instead of build function * prevent tensor being modified by ConvModule * fix unittest Co-authored-by: Junjun2016 <[email protected]>
1 parent 5753f41 commit 2825efe

File tree

8 files changed

+482
-1
lines changed

8 files changed

+482
-1
lines changed

configs/_base_/models/dpt_vit-b16.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
norm_cfg = dict(type='SyncBN', requires_grad=True)
2+
model = dict(
3+
type='EncoderDecoder',
4+
pretrained='pretrain/vit-b16_p16_224-80ecf9dd.pth', # noqa
5+
backbone=dict(
6+
type='VisionTransformer',
7+
img_size=224,
8+
embed_dims=768,
9+
num_layers=12,
10+
num_heads=12,
11+
out_indices=(2, 5, 8, 11),
12+
final_norm=False,
13+
with_cls_token=True,
14+
output_cls_token=True),
15+
decode_head=dict(
16+
type='DPTHead',
17+
in_channels=(768, 768, 768, 768),
18+
channels=256,
19+
embed_dims=768,
20+
post_process_channels=[96, 192, 384, 768],
21+
num_classes=150,
22+
readout_type='project',
23+
input_transform='multiple_select',
24+
in_index=(0, 1, 2, 3),
25+
norm_cfg=norm_cfg,
26+
loss_decode=dict(
27+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
28+
auxiliary_head=None,
29+
# model training and testing settings
30+
train_cfg=dict(),
31+
test_cfg=dict(mode='whole')) # yapf: disable

configs/dpt/README.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Vision Transformer for Dense Prediction
2+
3+
## Introduction
4+
5+
<!-- [ALGORITHM] -->
6+
7+
```latex
8+
@article{dosoViTskiy2020,
9+
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
10+
author={DosoViTskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
11+
journal={arXiv preprint arXiv:2010.11929},
12+
year={2020}
13+
}
14+
15+
@article{Ranftl2021,
16+
author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
17+
title = {Vision Transformers for Dense Prediction},
18+
journal = {ArXiv preprint},
19+
year = {2021},
20+
}
21+
```
22+
23+
## Usage
24+
25+
To use other repositories' pre-trained models, it is necessary to convert keys.
26+
27+
We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the key of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.
28+
29+
```shell
30+
python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
31+
```
32+
33+
E.g.
34+
35+
```shell
36+
python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
37+
```
38+
39+
This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
40+
41+
## Results and models
42+
43+
### ADE20K
44+
45+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
46+
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
47+
| DPT | ViT-B | 512x512 | 160000 | 8.09 | 10.41 | 46.97 | 48.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-20210809_172025.log.json) |

configs/dpt/dpt.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
Collections:
2+
- Metadata:
3+
Training Data:
4+
- ADE20K
5+
Name: dpt
6+
Models:
7+
- Config: configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
8+
In Collection: dpt
9+
Metadata:
10+
backbone: ViT-B
11+
crop size: (512,512)
12+
inference time (ms/im):
13+
- backend: PyTorch
14+
batch size: 1
15+
hardware: V100
16+
mode: FP32
17+
resolution: (512,512)
18+
value: 96.06
19+
lr schd: 160000
20+
memory (GB): 8.09
21+
Name: dpt_vit-b16_512x512_160k_ade20k
22+
Results:
23+
Dataset: ADE20K
24+
Metrics:
25+
mIoU: 46.97
26+
mIoU(ms+flip): 48.34
27+
Task: Semantic Segmentation
28+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
_base_ = [
2+
'../_base_/models/dpt_vit-b16.py', '../_base_/datasets/ade20k.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4+
]
5+
6+
# AdamW optimizer, no weight decay for position embedding & layer norm
7+
# in backbone
8+
optimizer = dict(
9+
_delete_=True,
10+
type='AdamW',
11+
lr=0.00006,
12+
betas=(0.9, 0.999),
13+
weight_decay=0.01,
14+
paramwise_cfg=dict(
15+
custom_keys={
16+
'pos_embed': dict(decay_mult=0.),
17+
'cls_token': dict(decay_mult=0.),
18+
'norm': dict(decay_mult=0.)
19+
}))
20+
21+
lr_config = dict(
22+
_delete_=True,
23+
policy='poly',
24+
warmup='linear',
25+
warmup_iters=1500,
26+
warmup_ratio=1e-6,
27+
power=1.0,
28+
min_lr=0.0,
29+
by_epoch=False)
30+
31+
# By default, models are trained on 8 GPUs with 2 images per GPU
32+
data = dict(samples_per_gpu=2, workers_per_gpu=2)

mmseg/models/decode_heads/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .da_head import DAHead
77
from .dm_head import DMHead
88
from .dnl_head import DNLHead
9+
from .dpt_head import DPTHead
910
from .ema_head import EMAHead
1011
from .enc_head import EncHead
1112
from .fcn_head import FCNHead
@@ -29,5 +30,5 @@
2930
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
3031
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
3132
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
32-
'SETRMLAHead', 'SegformerHead'
33+
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead'
3334
]

0 commit comments

Comments
 (0)