
Commit 24f1563

[Feature] Add BEiT backbone (open-mmlab#1404)
* [Feature] Add BEiT backbone * fix * fix * fix * fix * add readme * fix * fix * fix * fix * fix * add link * fix memory * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix test_beit.py * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
1 parent 30864ea commit 24f1563

20 files changed, +1345 −2 lines

README.md
Lines changed: 1 addition & 0 deletions

```diff
@@ -85,6 +85,7 @@ Supported backbones:
 - [x] [Swin Transformer (ICCV'2021)](configs/swin)
 - [x] [Twins (NeurIPS'2021)](configs/twins)
 - [x] [ConvNeXt (CVPR'2022)](configs/convnext)
+- [x] [BEiT (ICLR'2022)](configs/beit)

 Supported methods:
```

README_zh-CN.md
Lines changed: 1 addition & 0 deletions

```diff
@@ -84,6 +84,7 @@ MMSegmentation is an open-source semantic segmentation toolbox based on PyTorch. It is a
 - [x] [Swin Transformer (ICCV'2021)](configs/swin)
 - [x] [Twins (NeurIPS'2021)](configs/twins)
 - [x] [ConvNeXt (CVPR'2022)](configs/convnext)
+- [x] [BEiT (ICLR'2022)](configs/beit)

 Supported methods:
```

configs/_base_/models/upernet_beit.py
Lines changed: 50 additions & 0 deletions

```python
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='BEiT',
        img_size=(640, 640),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(3, 5, 7, 11),
        qv_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        init_values=0.1),
    neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=768,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
```
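The `neck=dict(type='Feature2Pyramid', ...)` entry is what adapts a plain ViT-style backbone to UPerHead: BEiT emits four feature maps that all share the same 1/16 resolution, and the neck resamples them by the `rescales` factors into a conventional multi-scale pyramid. A minimal sketch of the effect, with bilinear interpolation standing in for whatever resampling ops (e.g. deconvolution and pooling) the real neck uses:

```python
import torch
import torch.nn.functional as F

# Four BEiT outputs at 1/16 of a 640x640 input, i.e. 40x40 each.
feats = [torch.randn(1, 768, 40, 40) for _ in range(4)]
rescales = [4, 2, 1, 0.5]  # matches the Feature2Pyramid config above

pyramid = [
    x if s == 1 else
    F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
    for x, s in zip(feats, rescales)
]
print([tuple(p.shape[2:]) for p in pyramid])
# [(160, 160), (80, 80), (40, 40), (20, 20)] -> 1/4, 1/8, 1/16, 1/32 strides
```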

configs/beit/README.md
Lines changed: 84 additions & 0 deletions

# BEiT

[BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)

## Introduction

<!-- [BACKBONE] -->

<a href="https://github.com/microsoft/unilm/tree/master/beit">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/backbones/beit.py#1404">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% using only ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%). The code and pretrained models are available at [this https URL](https://github.com/microsoft/unilm/tree/master/beit).

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/93248678/160155758-781c9a45-b1d7-4530-9015-88eca6645006.png" width="70%"/>
</div>

## Citation

```bibtex
@inproceedings{beit,
  title={{BEiT}: {BERT} Pre-Training of Image Transformers},
  author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=p-BhZSz59o4}
}
```

## Usage

To use pre-trained models from other repositories, you first need to convert the checkpoint keys.

We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/microsoft/unilm/tree/master/beit/semantic_segmentation) to MMSegmentation style.

```shell
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```

For example:

```shell
python tools/model_converters/beit2mmseg.py https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth pretrain/beit_base_patch16_224_pt22k_ft22k.pth
```

This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
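Under the hood, conversion is a key-renaming pass over the checkpoint's `state_dict`. A minimal sketch of the idea, assuming the official weights sit under a `model` key; the two renames shown are illustrative examples, and the authoritative mapping lives in `beit2mmseg.py`:

```python
import torch

def convert_beit(src_path, dst_path):
    ckpt = torch.load(src_path, map_location='cpu')
    # Official BEiT checkpoints typically nest the weights under 'model'.
    state_dict = ckpt.get('model', ckpt)
    new_state_dict = {}
    for k, v in state_dict.items():
        new_k = k
        if 'patch_embed.proj' in new_k:  # illustrative rename
            new_k = new_k.replace('patch_embed.proj', 'patch_embed.projection')
        if new_k.startswith('blocks'):   # transformer blocks -> 'layers'
            new_k = new_k.replace('blocks', 'layers')
        new_state_dict[new_k] = v
    torch.save(new_state_dict, dst_path)
```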
In our default setting, the pretrained models correspond to the following original checkpoints:

| pretrained models | original models |
| ----------------- | --------------- |
| BEiT_base.pth | [BEiT_base](https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth) |
| BEiT_large.pth | [BEiT_large](https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth) |

Verify the single-scale results of the model:

```shell
sh tools/dist_test.sh \
    configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py \
    upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```

Since the relative position embedding requires the input height and width to be equal, a sliding window is adopted for multi-scale inference. We therefore set `min_size=640`, so that the shortest edge is 640 pixels, and run multi-scale inference through a dedicated config rather than the `--aug-test` flag:

```shell
sh tools/dist_test.sh \
    configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py \
    upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```
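For reference, the `mode='slide'` setting in `test_cfg` covers each (possibly upscaled) image with overlapping 640x640 windows and averages the logits where windows overlap, which is what keeps arbitrary aspect ratios safe for the relative position embedding. A minimal sketch of the idea, assuming `model` maps an image crop to per-class logits of the same spatial size (illustrative, not MMSegmentation's internal implementation):

```python
import torch

def slide_inference(model, img, num_classes, crop=(640, 640), stride=(426, 426)):
    """Tile img with overlapping crops and average the per-pixel logits."""
    _, _, h, w = img.shape
    logits = img.new_zeros((1, num_classes, h, w))
    count = img.new_zeros((1, 1, h, w))
    h_grids = max(h - crop[0] + stride[0] - 1, 0) // stride[0] + 1
    w_grids = max(w - crop[1] + stride[1] - 1, 0) // stride[1] + 1
    for i in range(h_grids):
        for j in range(w_grids):
            y1 = min(i * stride[0], max(h - crop[0], 0))
            x1 = min(j * stride[1], max(w - crop[1], 0))
            y2, x2 = y1 + crop[0], x1 + crop[1]
            logits[:, :, y1:y2, x1:x2] += model(img[:, :, y1:y2, x1:x2])
            count[:, :, y1:y2, x1:x2] += 1
    return logits / count
```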
## Results and models

### ADE20K

| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | -------- | ----------------- | ---------- | ------- | -------- | -------------- | ---- | ------------: | ------ | -------- |
| UperNet | BEiT-B | 640x640 | ImageNet-22K | 224x224 | 16 | 160000 | 15.88 | 2.00 | 53.08 | 53.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k.log.json) |
| UperNet | BEiT-L | 640x640 | ImageNet-22K | 224x224 | 8 | 320000 | 22.64 | 0.96 | 56.33 | 56.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.log.json) |

configs/beit/beit.yml
Lines changed: 45 additions & 0 deletions

```yaml
Models:
- Name: upernet_beit-base_8x2_640x640_160k_ade20k
  In Collection: UperNet
  Metadata:
    backbone: BEiT-B
    crop size: (640,640)
    lr schd: 160000
    inference time (ms/im):
    - value: 500.0
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (640,640)
    Training Memory (GB): 15.88
  Results:
  - Task: Semantic Segmentation
    Dataset: ADE20K
    Metrics:
      mIoU: 53.08
      mIoU(ms+flip): 53.84
  Config: configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth
- Name: upernet_beit-large_fp16_8x1_640x640_160k_ade20k
  In Collection: UperNet
  Metadata:
    backbone: BEiT-L
    crop size: (640,640)
    lr schd: 320000
    inference time (ms/im):
    - value: 1041.67
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP16
      resolution: (640,640)
    Training Memory (GB): 22.64
  Results:
  - Task: Semantic Segmentation
    Dataset: ADE20K
    Metrics:
      mIoU: 56.33
      mIoU(ms+flip): 56.84
  Config: configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth
```
configs/beit/upernet_beit-base_640x640_160k_ade20k_ms.py
Lines changed: 24 additions & 0 deletions

```python
_base_ = './upernet_beit-base_8x2_640x640_160k_ade20k.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True, min_size=640),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
    samples_per_gpu=2)
```
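With `keep_ratio=True`, each entry in `img_ratios` scales `img_scale=(2560, 640)` proportionally, and `min_size=640` then keeps the shortest edge at or above the 640 pixels that the sliding-window evaluator expects. A rough sketch of the resulting test scales, under the stated (assumed) reading of how `min_size` interacts with the ratio:

```python
def test_scales(img_scale=(2560, 640),
                ratios=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75),
                min_size=640):
    """Return the (long, short) edge targets for each test-time ratio."""
    scales = []
    for r in ratios:
        long_edge, short_edge = img_scale[0] * r, img_scale[1] * r
        if short_edge < min_size:  # enforce the 640px short-edge floor
            long_edge *= min_size / short_edge
            short_edge = min_size
        scales.append((round(long_edge), round(short_edge)))
    return scales

print(test_scales())
# [(2560, 640), (2560, 640), (2560, 640), (3200, 800), (3840, 960), (4480, 1120)]
```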
configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
Lines changed: 30 additions & 0 deletions

```python
_base_ = [
    '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
    pretrained='pretrain/beit_base_patch16_224_pt22k_ft22k.pth',
    test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=3e-5,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
```
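The `constructor='LayerDecayOptimizerConstructor'` line enables layer-wise learning-rate decay: with `num_layers=12` and `layer_decay_rate=0.9`, parameters in transformer layer `i` train at roughly `lr * 0.9 ** (12 + 1 - i)`, so blocks close to the pretrained patch embedding move more slowly than the decode head. A minimal sketch of the grouping idea (an illustration, not the MMSeg constructor; the name-based layer lookup is an assumption about how backbone parameters are named):

```python
import torch

def layer_decay_param_groups(model, base_lr, num_layers=12, decay=0.9):
    """Build per-layer param groups with lr = base_lr * decay**(depth gap)."""
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'layers.' in name:     # e.g. 'backbone.layers.3.attn.qkv.weight'
            layer_id = int(name.split('layers.')[1].split('.')[0]) + 1
        elif 'backbone' in name:  # patch/position embeddings, cls token
            layer_id = 0
        else:                     # neck and decode/auxiliary heads
            layer_id = num_layers + 1
        scale = decay ** (num_layers + 1 - layer_id)
        group = groups.setdefault(layer_id, {'params': [], 'lr': base_lr * scale})
        group['params'].append(param)
    return list(groups.values())

# Usage sketch:
# optimizer = torch.optim.AdamW(
#     layer_decay_param_groups(model, base_lr=3e-5), weight_decay=0.05)
```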
configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py
Lines changed: 22 additions & 0 deletions

```python
_base_ = './upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True, min_size=640),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))
```
configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
Lines changed: 47 additions & 0 deletions

```python
_base_ = [
    '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_320k.py'
]

model = dict(
    pretrained='pretrain/beit_large_patch16_224_pt22k_ft22k.pth',
    backbone=dict(
        type='BEiT',
        embed_dims=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4,
        qv_bias=True,
        init_values=1e-6,
        drop_path_rate=0.2,
        out_indices=[7, 11, 15, 23]),
    neck=dict(embed_dim=1024, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024], num_classes=150, channels=1024),
    auxiliary_head=dict(in_channels=1024, num_classes=150),
    test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=2e-5,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=3000,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=1)
optimizer_config = dict(
    type='GradientCumulativeFp16OptimizerHook', cumulative_iters=2)

fp16 = dict()
```
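Here `GradientCumulativeFp16OptimizerHook` with `cumulative_iters=2` runs two forward/backward passes per optimizer step, so 8 GPUs x 1 image x 2 iterations recovers an effective batch size of 16 under mixed precision. A minimal sketch of that pattern in plain PyTorch AMP; the tiny model, loader, and loss below are placeholders, not the MMCV hook:

```python
import torch
from torch import nn

# Placeholders standing in for the real segmentor, optimizer, and dataloader.
model = nn.Conv2d(3, 150, kernel_size=1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.05)
loader = [(torch.randn(1, 3, 640, 640), torch.randint(0, 150, (1, 640, 640)))
          for _ in range(4)]

scaler = torch.cuda.amp.GradScaler()
cumulative_iters = 2  # matches the config: one update per two backward passes

for i, (img, gt) in enumerate(loader):
    img, gt = img.cuda(), gt.cuda()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(img), gt) / cumulative_iters
    scaler.scale(loss).backward()      # fp16-safe gradient accumulation
    if (i + 1) % cumulative_iters == 0:
        scaler.step(optimizer)         # one optimizer update per window
        scaler.update()
        optimizer.zero_grad()
```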

mmseg/core/__init__.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .evaluation import *  # noqa: F401, F403
+from .layer_decay_optimizer_constructor import \
+    LayerDecayOptimizerConstructor  # noqa: F401
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
```
