Skip to content

Commit 25604a1

Browse files
authored
[Feature] Support PoolFormer in MMSegmentation 2.0 (open-mmlab#2191)
* [Feature] 2.0 PoolFormer * fix mmcls version * fix ut error * fix ut * fix ut
1 parent e8af7a0 commit 25604a1

12 files changed

+360
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Supported backbones:
102102
- [x] [BEiT (ICLR'2022)](configs/beit)
103103
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
104104
- [x] [MAE (CVPR'2022)](configs/mae)
105+
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
105106

106107
Supported methods:
107108

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
9696
- [x] [BEiT (ICLR'2022)](configs/beit)
9797
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
9898
- [x] [MAE (CVPR'2022)](configs/mae)
99+
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
99100

100101
已支持的算法:
101102

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
4+
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
5+
data_preprocessor = dict(
6+
type='SegDataPreProcessor',
7+
mean=[123.675, 116.28, 103.53],
8+
std=[58.395, 57.12, 57.375],
9+
bgr_to_rgb=True,
10+
pad_val=0,
11+
seg_pad_val=255)
12+
model = dict(
13+
type='EncoderDecoder',
14+
data_preprocessor=data_preprocessor,
15+
backbone=dict(
16+
type='mmcls.PoolFormer',
17+
arch='s12',
18+
init_cfg=dict(
19+
type='Pretrained', checkpoint=checkpoint_file, prefix='backbone.'),
20+
in_patch_size=7,
21+
in_stride=4,
22+
in_pad=2,
23+
down_patch_size=3,
24+
down_stride=2,
25+
down_pad=1,
26+
drop_rate=0.,
27+
drop_path_rate=0.,
28+
out_indices=(0, 2, 4, 6),
29+
frozen_stages=0,
30+
),
31+
neck=dict(
32+
type='FPN',
33+
in_channels=[256, 512, 1024, 2048],
34+
out_channels=256,
35+
num_outs=4),
36+
decode_head=dict(
37+
type='FPNHead',
38+
in_channels=[256, 256, 256, 256],
39+
in_index=[0, 1, 2, 3],
40+
feature_strides=[4, 8, 16, 32],
41+
channels=128,
42+
dropout_ratio=0.1,
43+
num_classes=19,
44+
norm_cfg=norm_cfg,
45+
align_corners=False,
46+
loss_decode=dict(
47+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
48+
# model training and testing settings
49+
train_cfg=dict(),
50+
test_cfg=dict(mode='whole'))

configs/poolformer/README.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# PoolFormer
2+
3+
[MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418)
4+
5+
## Introduction
6+
7+
<!-- [BACKBONE] -->
8+
9+
<a href="https://github.com/sail-sg/poolformer/tree/main/segmentation">Official Repo</a>
10+
11+
<a href="https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/poolformer.py#L198">Code Snippet</a>
12+
13+
## Abstract
14+
15+
<!-- [ABSTRACT] -->
16+
17+
Transformers have shown great potential in computer vision tasks. A common belief is their attention-based token mixer module contributes most to their competence. However, recent works show the attention-based module in transformers can be replaced by spatial MLPs and the resulted models still perform quite well. Based on this observation, we hypothesize that the general architecture of the transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed as PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. Based on the extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer could serve as a starting baseline for future MetaFormer architecture design. Code is available at [this https URL](https://github.com/sail-sg/poolformer)
18+
19+
<!-- [IMAGE] -->
20+
21+
<div align=center>
22+
<img src="https://user-images.githubusercontent.com/15921929/144710761-1635f59a-abde-4946-984c-a2c3f22a19d2.png" width="70%"/>
23+
</div>
24+
25+
## Citation
26+
27+
```bibtex
28+
@inproceedings{yu2022metaformer,
29+
title={Metaformer is actually what you need for vision},
30+
author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
31+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
32+
pages={10819--10829},
33+
year={2022}
34+
}
35+
```
36+
37+
### Usage
38+
39+
- PoolFormer backbone needs to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
40+
41+
```shell
42+
pip install "mmcls>=1.0.0rc0"
43+
```
44+
45+
- The pretrained models could also be downloaded from [PoolFormer config of MMClassification](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer).
46+
47+
## Results and models
48+
49+
### ADE20K
50+
51+
| Method | Backbone | Crop Size | pretrain | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | mIoU\* | mIoU\*(ms+flip) | config | download |
52+
| ------ | -------------- | --------- | ----------- | ---------- | ------- | -------- | -------------- | ----- | ------------: | ------ | --------------: | ------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
53+
| FPN | PoolFormer-S12 | 512x512 | ImageNet-1K | 32 | 40000 | 4.17 | 23.48 | 36.68 | - | 37.07 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154.log.json) |
54+
| FPN | PoolFormer-S24 | 512x512 | ImageNet-1K | 32 | 40000 | 5.47 | 15.74 | 40.12 | - | 40.36 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s24_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049.log.json) |
55+
| FPN | PoolFormer-S36 | 512x512 | ImageNet-1K | 32 | 40000 | 6.77 | 11.34 | 41.61 | - | 41.81 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_s36_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122-b47e607d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122.log.json) |
56+
| FPN | PoolFormer-M36 | 512x512 | ImageNet-1K | 32 | 40000 | 8.59 | 8.97 | 41.95 | - | 42.35 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_m36_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230.log.json) |
57+
| FPN | PoolFormer-M48 | 512x512 | ImageNet-1K | 32 | 40000 | 10.48 | 6.69 | 42.43 | - | 42.76 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/poolformer/fpn_poolformer_m48_8xb4-40k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923.log.json) |
58+
59+
Note:
60+
61+
- We replace `AlignedResize` in original PoolFormer implementation to `Resize + ResizeToMultiple`.
62+
63+
- `mIoU` with * is collected when `Resize + ResizeToMultiple` is adopted in `test_pipeline`, so do `mIoU` in logs.
64+
65+
- The Test Time Augmentation i.e., "ms+flip" in MMSegmentation v1.x is developing, stay tuned!
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
2+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m36_3rdparty_32xb128_in1k_20220414-c55e0949.pth' # noqa
3+
4+
# model settings
5+
model = dict(
6+
backbone=dict(
7+
arch='m36',
8+
init_cfg=dict(
9+
type='Pretrained', checkpoint=checkpoint_file,
10+
prefix='backbone.')),
11+
neck=dict(in_channels=[96, 192, 384, 768]))
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
2+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth' # noqa
3+
4+
# model settings
5+
model = dict(
6+
backbone=dict(
7+
arch='m48',
8+
init_cfg=dict(
9+
type='Pretrained', checkpoint=checkpoint_file,
10+
prefix='backbone.')),
11+
neck=dict(in_channels=[96, 192, 384, 768]))
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
_base_ = [
2+
'../_base_/models/fpn_poolformer_s12.py', '../_base_/default_runtime.py',
3+
'../_base_/schedules/schedule_40k.py'
4+
]
5+
6+
# dataset settings
7+
dataset_type = 'ADE20KDataset'
8+
data_root = 'data/ade/ADEChallengeData2016'
9+
img_norm_cfg = dict(
10+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
11+
crop_size = (512, 512)
12+
data_preprocessor = dict(size=crop_size)
13+
train_pipeline = [
14+
dict(type='LoadImageFromFile'),
15+
dict(type='LoadAnnotations', reduce_zero_label=True),
16+
dict(
17+
type='RandomResize',
18+
scale=(2048, 512),
19+
ratio_range=(0.5, 2.0),
20+
keep_ratio=True),
21+
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
22+
dict(type='RandomFlip', prob=0.5),
23+
dict(type='PhotoMetricDistortion'),
24+
dict(type='PackSegInputs')
25+
]
26+
test_pipeline = [
27+
dict(type='LoadImageFromFile'),
28+
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
29+
dict(type='ResizeToMultiple', size_divisor=32),
30+
# add loading annotation after ``Resize`` because ground truth
31+
# does not need to do resize data transform
32+
dict(type='LoadAnnotations', reduce_zero_label=True),
33+
dict(type='PackSegInputs')
34+
]
35+
36+
train_dataloader = dict(
37+
batch_size=4,
38+
num_workers=4,
39+
persistent_workers=True,
40+
sampler=dict(type='InfiniteSampler', shuffle=True),
41+
dataset=dict(
42+
type='RepeatDataset',
43+
times=50,
44+
dataset=dict(
45+
type=dataset_type,
46+
data_root=data_root,
47+
data_prefix=dict(
48+
img_path='images/training',
49+
seg_map_path='annotations/training'),
50+
pipeline=train_pipeline)))
51+
val_dataloader = dict(
52+
batch_size=1,
53+
num_workers=4,
54+
persistent_workers=True,
55+
sampler=dict(type='DefaultSampler', shuffle=False),
56+
dataset=dict(
57+
type=dataset_type,
58+
data_root=data_root,
59+
data_prefix=dict(
60+
img_path='images/validation',
61+
seg_map_path='annotations/validation'),
62+
pipeline=test_pipeline))
63+
test_dataloader = val_dataloader
64+
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
65+
test_evaluator = val_evaluator
66+
67+
# model settings
68+
model = dict(
69+
data_preprocessor=data_preprocessor,
70+
neck=dict(in_channels=[64, 128, 320, 512]),
71+
decode_head=dict(num_classes=150))
72+
73+
# optimizer
74+
# optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.0001)
75+
# optimizer_config = dict()
76+
# # learning policy
77+
# lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False)
78+
optim_wrapper = dict(
79+
_delete_=True,
80+
type='AmpOptimWrapper',
81+
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001))
82+
param_scheduler = [
83+
dict(
84+
type='PolyLR',
85+
power=0.9,
86+
begin=0,
87+
end=40000,
88+
eta_min=0.0,
89+
by_epoch=False,
90+
)
91+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
2+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s24_3rdparty_32xb128_in1k_20220414-d7055904.pth' # noqa
3+
# model settings
4+
model = dict(
5+
backbone=dict(
6+
arch='s24',
7+
init_cfg=dict(
8+
type='Pretrained', checkpoint=checkpoint_file,
9+
prefix='backbone.')))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
_base_ = './fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py'
2+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s36_3rdparty_32xb128_in1k_20220414-d78ff3e8.pth' # noqa
3+
4+
# model settings
5+
model = dict(
6+
backbone=dict(
7+
arch='s36',
8+
init_cfg=dict(
9+
type='Pretrained', checkpoint=checkpoint_file,
10+
prefix='backbone.')))

configs/poolformer/poolformer.yml

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
Models:
2+
- Name: fpn_poolformer_s12_8xb4-40k_ade20k-512x512
3+
In Collection: FPN
4+
Metadata:
5+
backbone: PoolFormer-S12
6+
crop size: (512,512)
7+
lr schd: 40000
8+
inference time (ms/im):
9+
- value: 42.59
10+
hardware: V100
11+
backend: PyTorch
12+
batch size: 1
13+
mode: FP32
14+
resolution: (512,512)
15+
Training Memory (GB): 4.17
16+
Results:
17+
- Task: Semantic Segmentation
18+
Dataset: ADE20K
19+
Metrics:
20+
mIoU: 36.68
21+
Config: configs/poolformer/fpn_poolformer_s12_8xb4-40k_ade20k-512x512.py
22+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth
23+
- Name: fpn_poolformer_s24_8xb4-40k_ade20k-512x512
24+
In Collection: FPN
25+
Metadata:
26+
backbone: PoolFormer-S24
27+
crop size: (512,512)
28+
lr schd: 40000
29+
inference time (ms/im):
30+
- value: 63.53
31+
hardware: V100
32+
backend: PyTorch
33+
batch size: 1
34+
mode: FP32
35+
resolution: (512,512)
36+
Training Memory (GB): 5.47
37+
Results:
38+
- Task: Semantic Segmentation
39+
Dataset: ADE20K
40+
Metrics:
41+
mIoU: 40.12
42+
Config: configs/poolformer/fpn_poolformer_s24_8xb4-40k_ade20k-512x512.py
43+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth
44+
- Name: ''
45+
In Collection: FPN
46+
Metadata:
47+
backbone: PoolFormer-S36
48+
crop size: (512,512)
49+
lr schd: 40000
50+
inference time (ms/im):
51+
- value: 88.18
52+
hardware: V100
53+
backend: PyTorch
54+
batch size: 1
55+
mode: FP32
56+
resolution: (512,512)
57+
Training Memory (GB): 6.77
58+
Results:
59+
- Task: Semantic Segmentation
60+
Dataset: ADE20K
61+
Metrics:
62+
mIoU: 41.61
63+
Config: ''
64+
Weights: ''
65+
- Name: fpn_poolformer_m36_8xb4-40k_ade20k-512x512
66+
In Collection: FPN
67+
Metadata:
68+
backbone: PoolFormer-M36
69+
crop size: (512,512)
70+
lr schd: 40000
71+
inference time (ms/im):
72+
- value: 111.48
73+
hardware: V100
74+
backend: PyTorch
75+
batch size: 1
76+
mode: FP32
77+
resolution: (512,512)
78+
Training Memory (GB): 8.59
79+
Results:
80+
- Task: Semantic Segmentation
81+
Dataset: ADE20K
82+
Metrics:
83+
mIoU: 41.95
84+
Config: configs/poolformer/fpn_poolformer_m36_8xb4-40k_ade20k-512x512.py
85+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth
86+
- Name: fpn_poolformer_m48_8xb4-40k_ade20k-512x512
87+
In Collection: FPN
88+
Metadata:
89+
backbone: PoolFormer-M48
90+
crop size: (512,512)
91+
lr schd: 40000
92+
inference time (ms/im):
93+
- value: 149.48
94+
hardware: V100
95+
backend: PyTorch
96+
batch size: 1
97+
mode: FP32
98+
resolution: (512,512)
99+
Training Memory (GB): 10.48
100+
Results:
101+
- Task: Semantic Segmentation
102+
Dataset: ADE20K
103+
Metrics:
104+
mIoU: 42.43
105+
Config: configs/poolformer/fpn_poolformer_m48_8xb4-40k_ade20k-512x512.py
106+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth

0 commit comments

Comments
 (0)