Skip to content

Commit 7ddd2fe

Browse files
authored
[Feature] Support ConvNext (open-mmlab#1216)
* upload original backbone and configs * ConvNext Refactor * ConvNext Refactor * convnext customization refactor with mmseg style * convnext customization refactor with mmseg style * add ade20k_640x640.py * upload files for training * delete dist_optimizer_hook and remove layer_decay_optimizer_constructor * check max(out_indices) < num_stages * add unittest * fix lint error * use MMClassification backbone * fix bugs in base_1k * add mmcls in requirements/mminstall.txt * add mmcls in requirements/mminstall.txt * fix drop_path_rate and layer_scale_init_value * use logger.info instead of print * add mmcls in runtime.txt * fix f string && delete * add doctring in LearningRateDecayOptimizerConstructor and fix mmcls version in requirements * fix typo in LearningRateDecayOptimizerConstructor * use ConvNext models in unit test for LearningRateDecayOptimizerConstructor * add unit test * fix typo * fix typo * add layer_wise and fix redundant backbone.downsample_norm in it * fix unit test * give a ground truth lr_scale and weight_decay * upload models and readme * delete 'backbone.stem_norm' and 'backbone.downsample_norm' in get_num_layer() * fix unit test and use mmcls url * update md2yml.py and metafile * fix typo
1 parent 369a2ee commit 7ddd2fe

19 files changed

+936
-4
lines changed

.dev/md2yml.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ def parse_md(md_file):
162162
model_name = fn[:-3]
163163
fps = els[fps_id] if els[fps_id] != '-' and els[
164164
fps_id] != '' else -1
165-
mem = els[mem_id] if els[mem_id] != '-' and els[
166-
mem_id] != '' else -1
165+
mem = els[mem_id].split(
166+
'\\'
167+
)[0] if els[mem_id] != '-' and els[mem_id] != '' else -1
167168
crop_size = els[crop_size_id].split('x')
168169
assert len(crop_size) == 2
169170
method = els[method_id].split()[0].split('-')[-1]

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ Supported backbones:
8484
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
8585
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
8686
- [x] [Twins (NeurIPS'2021)](configs/twins)
87+
- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
8788

8889
Supported methods:
8990

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
8383
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
8484
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
8585
- [x] [Twins (NeurIPS'2021)](configs/twins)
86+
- [x] [ConvNeXt (ArXiv'2022)](configs/convnext)
8687

8788
已支持的算法:
8889

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# dataset settings
2+
dataset_type = 'ADE20KDataset'
3+
data_root = 'data/ade/ADEChallengeData2016'
4+
img_norm_cfg = dict(
5+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6+
crop_size = (640, 640)
7+
train_pipeline = [
8+
dict(type='LoadImageFromFile'),
9+
dict(type='LoadAnnotations', reduce_zero_label=True),
10+
dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
11+
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12+
dict(type='RandomFlip', prob=0.5),
13+
dict(type='PhotoMetricDistortion'),
14+
dict(type='Normalize', **img_norm_cfg),
15+
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16+
dict(type='DefaultFormatBundle'),
17+
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18+
]
19+
test_pipeline = [
20+
dict(type='LoadImageFromFile'),
21+
dict(
22+
type='MultiScaleFlipAug',
23+
img_scale=(2560, 640),
24+
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25+
flip=False,
26+
transforms=[
27+
dict(type='Resize', keep_ratio=True),
28+
dict(type='RandomFlip'),
29+
dict(type='Normalize', **img_norm_cfg),
30+
dict(type='ImageToTensor', keys=['img']),
31+
dict(type='Collect', keys=['img']),
32+
])
33+
]
34+
data = dict(
35+
samples_per_gpu=4,
36+
workers_per_gpu=4,
37+
train=dict(
38+
type=dataset_type,
39+
data_root=data_root,
40+
img_dir='images/training',
41+
ann_dir='annotations/training',
42+
pipeline=train_pipeline),
43+
val=dict(
44+
type=dataset_type,
45+
data_root=data_root,
46+
img_dir='images/validation',
47+
ann_dir='annotations/validation',
48+
pipeline=test_pipeline),
49+
test=dict(
50+
type=dataset_type,
51+
data_root=data_root,
52+
img_dir='images/validation',
53+
ann_dir='annotations/validation',
54+
pipeline=test_pipeline))
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
norm_cfg = dict(type='SyncBN', requires_grad=True)
2+
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
3+
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth' # noqa
4+
model = dict(
5+
type='EncoderDecoder',
6+
pretrained=None,
7+
backbone=dict(
8+
type='mmcls.ConvNeXt',
9+
arch='base',
10+
out_indices=[0, 1, 2, 3],
11+
drop_path_rate=0.4,
12+
layer_scale_init_value=1.0,
13+
gap_before_final_norm=False,
14+
init_cfg=dict(
15+
type='Pretrained', checkpoint=checkpoint_file,
16+
prefix='backbone.')),
17+
decode_head=dict(
18+
type='UPerHead',
19+
in_channels=[128, 256, 512, 1024],
20+
in_index=[0, 1, 2, 3],
21+
pool_scales=(1, 2, 3, 6),
22+
channels=512,
23+
dropout_ratio=0.1,
24+
num_classes=19,
25+
norm_cfg=norm_cfg,
26+
align_corners=False,
27+
loss_decode=dict(
28+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29+
auxiliary_head=dict(
30+
type='FCNHead',
31+
in_channels=384,
32+
in_index=2,
33+
channels=256,
34+
num_convs=1,
35+
concat_input=False,
36+
dropout_ratio=0.1,
37+
num_classes=19,
38+
norm_cfg=norm_cfg,
39+
align_corners=False,
40+
loss_decode=dict(
41+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42+
# model training and testing settings
43+
train_cfg=dict(),
44+
test_cfg=dict(mode='whole'))

configs/convnext/README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# ConvNeXt
2+
3+
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)
4+
5+
## Introduction
6+
7+
<!-- [BACKBONE] -->
8+
9+
<a href="https://github.com/facebookresearch/ConvNeXt">Official Repo</a>
10+
11+
<a href="https://github.com/open-mmlab/mmclassification/blob/v0.20.1/mmcls/models/backbones/convnext.py#L133">Code Snippet</a>
12+
13+
## Abstract
14+
15+
<!-- [ABSTRACT] -->
16+
17+
The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets.
18+
19+
<!-- [IMAGE] -->
20+
<div align=center>
21+
<img src="https://user-images.githubusercontent.com/8370623/148624004-e9581042-ea4d-4e10-b3bd-42c92b02053b.png" width="90%"/>
22+
</div>
23+
24+
```bibtex
25+
@article{liu2022convnet,
26+
title={A ConvNet for the 2020s},
27+
author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
28+
journal={arXiv preprint arXiv:2201.03545},
29+
year={2022}
30+
}
31+
```
32+
33+
### Usage
34+
35+
- This backbone need to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
36+
37+
```shell
38+
pip install mmcls>=0.20.1
39+
```
40+
41+
### Pre-trained Models
42+
43+
The pre-trained models on ImageNet-1k or ImageNet-21k are used to fine-tune on the downstream tasks.
44+
45+
| Model | Training Data | Params(M) | Flops(G) | Download |
46+
|:--------------:|:-------------:|:---------:|:--------:|:--------:|
47+
| ConvNeXt-T\* | ImageNet-1k | 28.59 | 4.46 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth) |
48+
| ConvNeXt-S\* | ImageNet-1k | 50.22 | 8.69 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-small_3rdparty_32xb128-noema_in1k_20220301-303e75e3.pth) |
49+
| ConvNeXt-B\* | ImageNet-1k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_32xb128-noema_in1k_20220301-2a0ee547.pth) |
50+
| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-base_3rdparty_in21k_20220301-262fd037.pth) |
51+
| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth) |
52+
| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-xlarge_3rdparty_in21k_20220301-08aa5ddc.pth) |
53+
54+
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt/tree/main/semantic_segmentation#results-and-fine-tuned-models).*
55+
56+
## Results and models
57+
58+
### ADE20K
59+
60+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
61+
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
62+
| UperNet | ConvNeXt-T | 512x512 | 160000 | 4.23 | 19.90 | 46.11 | 46.62 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553.log.json) |
63+
| UperNet | ConvNeXt-S | 512x512 | 160000 | 5.16 | 15.18 | 48.56 | 49.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208.log.json) |
64+
| UperNet | ConvNeXt-B | 512x512 | 160000 | 6.33 | 14.41 | 48.71 | 49.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227.log.json) |
65+
| UperNet | ConvNeXt-B |640x640 | 160000 | 8.53 | 10.88 | 52.13 | 52.66 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859.log.json) |
66+
| UperNet | ConvNeXt-L |640x640 | 160000 | 12.08 | 7.69 | 53.16 | 53.38 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532.log.json) |
67+
| UperNet | ConvNeXt-XL |640x640 | 160000 | 26.16\* | 6.33 | 53.58 | 54.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344.log.json) |
68+
69+
Note:
70+
71+
- `Mem (GB)` with \* is collected when `cudnn_benchmark=True`, and hardware is V100.

configs/convnext/convnext.yml

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
Models:
2+
- Name: upernet_convnext_tiny_fp16_512x512_160k_ade20k
3+
In Collection: UperNet
4+
Metadata:
5+
backbone: ConvNeXt-T
6+
crop size: (512,512)
7+
lr schd: 160000
8+
inference time (ms/im):
9+
- value: 50.25
10+
hardware: V100
11+
backend: PyTorch
12+
batch size: 1
13+
mode: FP16
14+
resolution: (512,512)
15+
Training Memory (GB): 4.23
16+
Results:
17+
- Task: Semantic Segmentation
18+
Dataset: ADE20K
19+
Metrics:
20+
mIoU: 46.11
21+
mIoU(ms+flip): 46.62
22+
Config: configs/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k.py
23+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth
24+
- Name: upernet_convnext_small_fp16_512x512_160k_ade20k
25+
In Collection: UperNet
26+
Metadata:
27+
backbone: ConvNeXt-S
28+
crop size: (512,512)
29+
lr schd: 160000
30+
inference time (ms/im):
31+
- value: 65.88
32+
hardware: V100
33+
backend: PyTorch
34+
batch size: 1
35+
mode: FP16
36+
resolution: (512,512)
37+
Training Memory (GB): 5.16
38+
Results:
39+
- Task: Semantic Segmentation
40+
Dataset: ADE20K
41+
Metrics:
42+
mIoU: 48.56
43+
mIoU(ms+flip): 49.02
44+
Config: configs/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k.py
45+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth
46+
- Name: upernet_convnext_base_fp16_512x512_160k_ade20k
47+
In Collection: UperNet
48+
Metadata:
49+
backbone: ConvNeXt-B
50+
crop size: (512,512)
51+
lr schd: 160000
52+
inference time (ms/im):
53+
- value: 69.4
54+
hardware: V100
55+
backend: PyTorch
56+
batch size: 1
57+
mode: FP16
58+
resolution: (512,512)
59+
Training Memory (GB): 6.33
60+
Results:
61+
- Task: Semantic Segmentation
62+
Dataset: ADE20K
63+
Metrics:
64+
mIoU: 48.71
65+
mIoU(ms+flip): 49.54
66+
Config: configs/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k.py
67+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth
68+
- Name: upernet_convnext_base_fp16_640x640_160k_ade20k
69+
In Collection: UperNet
70+
Metadata:
71+
backbone: ConvNeXt-B
72+
crop size: (640,640)
73+
lr schd: 160000
74+
inference time (ms/im):
75+
- value: 91.91
76+
hardware: V100
77+
backend: PyTorch
78+
batch size: 1
79+
mode: FP16
80+
resolution: (640,640)
81+
Training Memory (GB): 8.53
82+
Results:
83+
- Task: Semantic Segmentation
84+
Dataset: ADE20K
85+
Metrics:
86+
mIoU: 52.13
87+
mIoU(ms+flip): 52.66
88+
Config: configs/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k.py
89+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_640x640_160k_ade20k/upernet_convnext_base_fp16_640x640_160k_ade20k_20220227_182859-9280e39b.pth
90+
- Name: upernet_convnext_large_fp16_640x640_160k_ade20k
91+
In Collection: UperNet
92+
Metadata:
93+
backbone: ConvNeXt-L
94+
crop size: (640,640)
95+
lr schd: 160000
96+
inference time (ms/im):
97+
- value: 130.04
98+
hardware: V100
99+
backend: PyTorch
100+
batch size: 1
101+
mode: FP16
102+
resolution: (640,640)
103+
Training Memory (GB): 12.08
104+
Results:
105+
- Task: Semantic Segmentation
106+
Dataset: ADE20K
107+
Metrics:
108+
mIoU: 53.16
109+
mIoU(ms+flip): 53.38
110+
Config: configs/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k.py
111+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth
112+
- Name: upernet_convnext_xlarge_fp16_640x640_160k_ade20k
113+
In Collection: UperNet
114+
Metadata:
115+
backbone: ConvNeXt-XL
116+
crop size: (640,640)
117+
lr schd: 160000
118+
inference time (ms/im):
119+
- value: 157.98
120+
hardware: V100
121+
backend: PyTorch
122+
batch size: 1
123+
mode: FP16
124+
resolution: (640,640)
125+
Training Memory (GB): 26.16
126+
Results:
127+
- Task: Semantic Segmentation
128+
Dataset: ADE20K
129+
Metrics:
130+
mIoU: 53.58
131+
mIoU(ms+flip): 54.11
132+
Config: configs/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k.py
133+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
_base_ = [
2+
'../_base_/models/upernet_convnext.py', '../_base_/datasets/ade20k.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4+
]
5+
crop_size = (512, 512)
6+
model = dict(
7+
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
8+
auxiliary_head=dict(in_channels=512, num_classes=150),
9+
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341)),
10+
)
11+
12+
optimizer = dict(
13+
constructor='LearningRateDecayOptimizerConstructor',
14+
_delete_=True,
15+
type='AdamW',
16+
lr=0.0001,
17+
betas=(0.9, 0.999),
18+
weight_decay=0.05,
19+
paramwise_cfg={
20+
'decay_rate': 0.9,
21+
'decay_type': 'stage_wise',
22+
'num_layers': 12
23+
})
24+
25+
lr_config = dict(
26+
_delete_=True,
27+
policy='poly',
28+
warmup='linear',
29+
warmup_iters=1500,
30+
warmup_ratio=1e-6,
31+
power=1.0,
32+
min_lr=0.0,
33+
by_epoch=False)
34+
35+
# By default, models are trained on 8 GPUs with 2 images per GPU
36+
data = dict(samples_per_gpu=2)
37+
# fp16 settings
38+
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
39+
# fp16 placeholder
40+
fp16 = dict()

0 commit comments

Comments
 (0)