
Commit 409caf8

zhby99 and MeowZheng authored
[DEST] add DEST model (open-mmlab#2482)
## Motivation

We are from NVIDIA and we have developed a simplified, inference-efficient transformer for dense prediction tasks. The method is based on SegFormer with hardware-friendly design choices, resulting in better accuracy and over 2x faster inference compared to the baseline. We believe this model will be of particular interest to those who want to deploy an efficient vision transformer in production, and it is easily adaptable to other tasks. We would therefore like to contribute our method to mmsegmentation to benefit a larger audience.

The paper was accepted to the [Transformer for Vision workshop](https://sites.google.com/view/t4v-cvpr22/papers?authuser=0) at CVPR 2022. Related resources:

- Paper: <https://arxiv.org/pdf/2204.13791.pdf> (Table 3 shows the semantic segmentation results)
- Code: <https://github.com/NVIDIA/DL4AGX/tree/master/DEST>
- A webinar on its application: <https://www.nvidia.com/en-us/on-demand/session/other2022-drivetraining/>

## Modification

Add the backbone (`smit.py`) and decode head (`dest_head.py`) of DEST.

## BC-breaking (Optional)

N/A

## Use cases (Optional)

N/A

---------

Co-authored-by: MeowZheng <[email protected]>
1 parent 486a409 commit 409caf8

14 files changed: +924 −3 lines

LICENSES.md

Lines changed: 4 additions & 3 deletions

The license table gains a DEST entry:

In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.

| Feature | Files | License |
| :-------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
| DEST | [projects/dest/models/smit.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/smit.py) [projects/dest/models/dest_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/dest_head.py) | [NVIDIA License](https://github.com/NVIDIA/DL4AGX/blob/master/DEST/LICENSE) |

README.md

Lines changed: 1 addition & 0 deletions

The DEST entry is added to the list of supported methods:

- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [DEST (CVPRW'2022)](projects/dest)
projects/dest/README.md

Lines changed: 99 additions & 0 deletions

# DEST

[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)

## Description

Transformer and its variants have recently shown state-of-the-art results in many vision tasks, ranging from image classification to dense prediction. Despite their success, limited work has been reported on improving model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim to improve upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model achieves significant reductions in model size, complexity, and inference latency, while attaining superior accuracy compared to the state of the art in self-supervised monocular depth estimation. We also show that our design generalizes well to other dense prediction tasks such as semantic segmentation without bells and whistles.

## Usage

### Prerequisites

- Python 3.8.12
- PyTorch 1.11
- mmcv v1.7.0
- Install [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) from source

All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the mmsegmentation directory so that Python can locate the configuration files in mmsegmentation.
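As an alternative to exporting `PYTHONPATH`, a minimal sketch of making the repository importable from inside a Python session; the clone location used here is an assumption, adjust it to your checkout:

```python
import os
import sys

# Assumed location of the mmsegmentation clone; adjust to your checkout.
MMSEG_ROOT = os.path.expanduser('~/mmsegmentation')
sys.path.insert(0, MMSEG_ROOT)
```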
### Dataset preparation

Prepare the `cityscapes` dataset following the [dataset preparation guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets).

### Training commands

```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest
```

To train on multiple GPUs, e.g. 8 GPUs, run the following command:

```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --launcher pytorch --gpus 8
```

### Testing commands

```shell
mim test mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --checkpoint ${CHECKPOINT_PATH} --eval mIoU
```
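For quick sanity checks outside of `mim`, a hedged sketch of single-image inference with the mmseg 0.x Python API (`init_segmentor` / `inference_segmentor`); the checkpoint and image paths are placeholders, not files shipped with this project:

```python
# Hedged sketch: single-image inference with a trained DEST checkpoint.
# `Config.fromfile` inside `init_segmentor` handles `custom_imports`,
# which registers the DEST backbone and head with the mmseg registries.
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py'
checkpoint_file = 'work_dirs/dest/latest.pth'  # placeholder checkpoint path
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')  # list of seg maps
```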
## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ------ | -------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025.log) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.log) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943.log) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800.log) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155.log) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411.log) |

Note:

- The above models are all trained from scratch without pretrained backbones. Accuracy can be further improved with appropriate pretraining.
- Training of DEST is not very stable and is sensitive to random seeds; see the seeding sketch below.
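Because results depend on the seed, it can help to fix it explicitly. A minimal sketch, assuming the mmseg 0.x API where `set_random_seed` is exported from `mmseg.apis` (the training script exposes matching `--seed` / `--deterministic` flags):

```python
# Minimal seeding sketch: `set_random_seed` seeds Python, NumPy, and
# PyTorch; `deterministic=True` additionally fixes cuDNN behaviour.
from mmseg.apis import set_random_seed

set_random_seed(0, deterministic=True)
```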
## Citation

```bibtex
@article{YangDEST,
  title={Depth Estimation with Simplified Transformer},
  author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
  journal={arXiv preprint arXiv:2204.13791},
  year={2022}
}
```

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
  - [x] Finish the code
  - [x] Basic docstrings & proper citation
  - [x] Test-time correctness
  - [x] A full README
- [x] Milestone 2: Indicates a successful model implementation.
  - [x] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
  - [ ] Type hints and docstrings
  - [ ] Unit tests
  - [ ] Code polishing
  - [ ] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
projects/dest/configs/README.md

Lines changed: 50 additions & 0 deletions

# DEST

[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/NVIDIA/DL4AGX/tree/master/DEST">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Transformer and its variants have recently shown state-of-the-art results in many vision tasks, ranging from image classification to dense prediction. Despite their success, limited work has been reported on improving model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim to improve upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model achieves significant reductions in model size, complexity, and inference latency, while attaining superior accuracy compared to the state of the art in self-supervised monocular depth estimation. We also show that our design generalizes well to other dense prediction tasks such as semantic segmentation without bells and whistles.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/76149310/219313665-49fa89ed-4973-4496-bb33-3256f107e82d.png" width="70%"/>
</div>

## Citation

```bibtex
@article{YangDEST,
  title={Depth Estimation with Simplified Transformer},
  author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
  journal={arXiv preprint arXiv:2204.13791},
  year={2022}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ------ | -------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) |

Note:

- The above models are all trained from scratch without pretrained backbones. Accuracy can be further improved with appropriate pretraining.
- Training of DEST is not very stable and is sensitive to random seeds.
projects/dest/configs/dest_simpatt-b0.py

Lines changed: 37 additions & 0 deletions

```python
# model settings
embed_dims = [32, 64, 160, 256]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='SimplifiedMixTransformer',
        in_channels=3,
        embed_dims=embed_dims,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_cfg=norm_cfg),
    decode_head=dict(
        type='DESTHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=32,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
```
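To inspect the assembled model, a hedged sketch of building the segmentor defined by this base config with the mmcv 1.x / mmseg 0.x APIs (`Config.fromfile`, `build_segmentor`); it assumes the mmsegmentation repository root as the working directory:

```python
# Hedged sketch: build the DEST segmentor from this base config. The DEST
# modules live under projects/dest/models, so import that package first to
# register SimplifiedMixTransformer and DESTHead with the mmseg registries.
import projects.dest.models  # noqa: F401
from mmcv import Config
from mmseg.models import build_segmentor

cfg = Config.fromfile('projects/dest/configs/dest_simpatt-b0.py')
model = build_segmentor(cfg.model)  # train_cfg/test_cfg are embedded above
print(model)
```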
projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py

Lines changed: 33 additions & 0 deletions

```python
_base_ = [
    './dest_simpatt-b0.py',
    '../../../configs/_base_/datasets/cityscapes_1024x1024.py',
    '../../../configs/_base_/default_runtime.py',
    '../../../configs/_base_/schedules/schedule_160k.py'
]

custom_imports = dict(imports=['projects.dest.models'])

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=1, workers_per_gpu=1)
```
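For reference, a small illustrative reimplementation (not mmcv's own code) of the learning-rate curve that the `lr_config` above describes: a linear warmup over the first 1500 iterations, followed by polynomial decay with `power=1.0` down to `min_lr`; the base learning rate comes from the optimizer:

```python
# Illustrative poly-with-linear-warmup schedule using the values above.
def dest_lr(cur_iter, base_lr=6e-5, max_iters=160000, power=1.0,
            min_lr=0.0, warmup_iters=1500, warmup_ratio=1e-6):
    if cur_iter < warmup_iters:
        # Linear warmup from base_lr * warmup_ratio up to base_lr.
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        return base_lr * (1 - k)
    # Polynomial decay from base_lr to min_lr over the remaining iterations.
    coeff = (1 - cur_iter / max_iters) ** power
    return (base_lr - min_lr) * coeff + min_lr

print(dest_lr(0), dest_lr(1500), dest_lr(80000), dest_lr(160000))
```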
projects/dest/configs/dest_simpatt-b1_1024x1024_160k_cityscapes.py

Lines changed: 9 additions & 0 deletions

```python
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(embed_dims=embed_dims),
    decode_head=dict(in_channels=embed_dims, channels=64))
```
projects/dest/configs/dest_simpatt-b2_1024x1024_160k_cityscapes.py

Lines changed: 9 additions & 0 deletions

```python
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(embed_dims=embed_dims, num_layers=[3, 3, 6, 3]),
    decode_head=dict(in_channels=embed_dims, channels=64))
```
projects/dest/configs/dest_simpatt-b3_1024x1024_160k_cityscapes.py

Lines changed: 22 additions & 0 deletions

```python
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=1.)
        }))

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(embed_dims=embed_dims, num_layers=[3, 6, 8, 3]),
    decode_head=dict(in_channels=embed_dims, channels=64))
```
projects/dest/configs/dest_simpatt-b4_1024x1024_160k_cityscapes.py

Lines changed: 22 additions & 0 deletions

```python
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=1.)
        }))

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(embed_dims=embed_dims, num_layers=[3, 8, 12, 5]),
    decode_head=dict(in_channels=embed_dims, channels=64))
```
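The `custom_keys` blocks in these configs assign per-parameter learning-rate and weight-decay multipliers by substring match on parameter names. A hedged, illustrative sketch of the resolution rule (modeled on mmcv's `DefaultOptimizerConstructor`, not its actual code):

```python
# Illustrative resolution of paramwise_cfg custom_keys: a parameter takes
# the lr_mult / decay_mult of the most specific key its name contains.
def resolve_multipliers(param_name, custom_keys):
    lr_mult, decay_mult = 1.0, 1.0
    # Sort keys alphabetically, then by length descending, so that longer
    # (more specific) keys win when several keys match the same name.
    for key in sorted(sorted(custom_keys), key=len, reverse=True):
        if key in param_name:
            lr_mult = custom_keys[key].get('lr_mult', 1.0)
            decay_mult = custom_keys[key].get('decay_mult', 1.0)
            break
    return lr_mult, decay_mult

custom_keys = {
    'pos_block': dict(decay_mult=0.),
    'norm': dict(decay_mult=0.),
    'head': dict(lr_mult=1.),
}
# Norm layers inside the backbone keep lr_mult=1 but lose weight decay:
print(resolve_multipliers('backbone.layers.0.norm1.weight', custom_keys))
# -> (1.0, 0.0)
```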
