Commit b6090a1

[CodeCamp2023-608] Add Adabins model (open-mmlab#3257)
1 parent c46cc85 commit b6090a1

File tree

11 files changed
+483 -0 lines changed

demo/classroom__rgb_00283.jpg

49.1 KB

projects/Adabins/README.md

Lines changed: 46 additions & 0 deletions
# AdaBins: Depth Estimation Using Adaptive Bins

## Reference

> [AdaBins: Depth Estimation Using Adaptive Bins](https://arxiv.org/abs/2011.14141)

## Introduction

<a href="https://github.com/shariqfarooq123/AdaBins">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/Adabins">Code Snippet</a>
## <img src="https://user-images.githubusercontent.com/34859558/190043857-bfbdaf8b-d2dc-4fff-81c7-e0aac50851f9.png" width="25"/> Abstract

We address the problem of estimating a high-quality dense depth map from a single RGB input image. We start out with a baseline encoder-decoder convolutional neural network architecture and pose the question of how the global processing of information can help improve overall depth estimation. To this end, we propose a transformer-based architecture block that divides the depth range into bins whose center value is estimated adaptively per image. The final depth values are estimated as linear combinations of the bin centers. We call our new building block AdaBins. Our results show a decisive improvement over the state-of-the-art on several popular depth datasets across all metrics. We also validate the effectiveness of the proposed block with an ablation study and provide the code and corresponding pre-trained weights of the new state-of-the-art model.

Our main contributions are the following:

- We propose an architecture building block that performs global processing of the scene's information. We propose to divide the predicted depth range into bins whose widths change per image. The final depth estimation is a linear combination of the bin center values; a minimal sketch of this computation follows the figure below.
- We show a decisive improvement for supervised single image depth estimation across all metrics for the two most popular datasets, NYU and KITTI.
- We analyze our findings and investigate different modifications of the proposed AdaBins block and study their effect on the accuracy of the depth estimation.

<div align="center">
<img src="https://github.com/open-mmlab/mmsegmentation/assets/15952744/915bcd5a-9dc2-4602-a6e7-055ff5d4889f" width = "1000" />
</div>
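To make the bin-to-depth computation concrete, here is a minimal PyTorch sketch. It is an illustration with assumed tensor names and shapes, not the code added by this commit; `n_bins`, `min_val`, and `max_val` mirror the NYU model config further below.

```python
import torch
import torch.nn.functional as F

# Assumed shapes (quarter resolution to keep the demo light):
# bin_logits:  (B, n_bins)        per-image bin-width logits
# range_probs: (B, n_bins, H, W)  per-pixel scores over the bins
min_val, max_val = 0.001, 10.0  # NYU depth range from the model config
B, n_bins, H, W = 1, 256, 104, 136

bin_logits = torch.randn(B, n_bins)
range_probs = torch.randn(B, n_bins, H, W).softmax(dim=1)

# Normalized adaptive bin widths, then cumulative edges over [min_val, max_val].
bin_widths = F.relu(bin_logits) + 1e-3                    # keep widths positive
bin_widths = bin_widths / bin_widths.sum(dim=1, keepdim=True)
bin_widths = (max_val - min_val) * bin_widths
bin_edges = torch.cumsum(F.pad(bin_widths, (1, 0), value=min_val), dim=1)

# Bin centers; the final depth is a per-pixel linear combination of them.
centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])    # (B, n_bins)
depth = (range_probs * centers[:, :, None, None]).sum(dim=1, keepdim=True)
print(depth.shape)  # torch.Size([1, 1, 104, 136])
```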
## <img src="https://user-images.githubusercontent.com/34859558/190044217-8f6befc2-7f20-473d-b356-148e06265205.png" width="25"/> Performance

### NYU and KITTI

| Model | Encoder | Training epochs | Batch size | Train resolution | δ1 | δ2 | δ3 | REL | RMS | RMS log | Params (M) | Links |
| ------------- | --------------- | --------------- | ---------- | ---------------- | ----- | ----- | ----- | ----- | ----- | ------- | ---------- | ----------------------------------------------------------------------------------------------------------------------- |
| AdaBins_nyu | EfficientNet-B5 | 25 | 16 | 416x544 | 0.903 | 0.984 | 0.997 | 0.103 | 0.364 | 0.044 | 78 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient_b5_nyu_third-party-f68d6bd3.pth) |
| AdaBins_kitti | EfficientNet-B5 | 25 | 16 | 352x704 | 0.964 | 0.995 | 0.999 | 0.058 | 2.360 | 0.088 | 78 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient-b5_kitty_third-party-a1aa6f36.pth) |
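For reference, a hedged sketch of the standard definitions behind these columns; the project's `DepthMetric` may differ in masking and cropping details such as `crop_type='nyu_crop'`.

```python
import torch

def depth_metrics(pred, gt):
    """pred, gt: positive depth tensors of the same shape, in meters."""
    ratio = torch.maximum(pred / gt, gt / pred)
    d1 = (ratio < 1.25).float().mean()          # δ1: accuracy under threshold
    d2 = (ratio < 1.25 ** 2).float().mean()     # δ2
    d3 = (ratio < 1.25 ** 3).float().mean()     # δ3
    rel = ((pred - gt).abs() / gt).mean()       # REL: absolute relative error
    rms = ((pred - gt) ** 2).mean().sqrt()      # RMS
    rms_log = ((pred.log() - gt.log()) ** 2).mean().sqrt()  # RMS log
    return d1, d2, d3, rel, rms, rms_log

pred = torch.rand(10, 10) * 9 + 0.5  # toy example
gt = torch.rand(10, 10) * 9 + 0.5
print(depth_metrics(pred, gt))
```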
## Citation

```bibtex
@article{10.1109/cvpr46437.2021.00400,
  author = {Bhat, S. F. and Alhashim, I. and Wonka, P.},
  title = {{AdaBins}: Depth Estimation Using Adaptive Bins},
  journal = {2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2021},
  doi = {10.1109/cvpr46437.2021.00400}
}
```
projects/Adabins/backbones/__init__.py

Lines changed: 4 additions & 0 deletions
# Copyright (c) OpenMMLab. All rights reserved.
from .adabins_backbone import AdabinsBackbone

__all__ = ['AdabinsBackbone']
projects/Adabins/backbones/adabins_backbone.py

Lines changed: 141 additions & 0 deletions
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS


class UpSampleBN(nn.Module):
    """Upsample-and-fuse decoder block.

    Upsamples the input to the skip feature's resolution, concatenates the
    two, and applies two 3x3 conv-norm-act layers.

    Args:
        skip_input (int): Number of channels after concatenating the
            upsampled feature with the skip feature.
        output_features (int): Number of output channels.
        norm_cfg (dict, optional): Config dict for the normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for the activation layer.
            Default: dict(type='LeakyReLU').
    """

    def __init__(self,
                 skip_input,
                 output_features,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='LeakyReLU')):
        super().__init__()

        self._net = nn.Sequential(
            ConvModule(
                in_channels=skip_input,
                out_channels=output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
            ConvModule(
                in_channels=output_features,
                out_channels=output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ))

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode='bilinear',
            align_corners=True)
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


class Encoder(nn.Module):
    """timm EfficientNet encoder exposing intermediate features.

    Args:
        basemodel_name (str): Name of the timm model to build,
            e.g. 'tf_efficientnet_b5_ap'.
    """

    def __init__(self, basemodel_name):
        super().__init__()
        self.original_model = timm.create_model(
            basemodel_name, pretrained=True)
        # Remove the classification head; only features are needed.
        self.original_model.global_pool = nn.Identity()
        self.original_model.classifier = nn.Identity()

    def forward(self, x):
        # Collect the output of every top-level module, and of every
        # sub-block inside 'blocks', so the decoder can pick skip features.
        features = [x]
        for k, v in self.original_model._modules.items():
            if k == 'blocks':
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


@MODELS.register_module()
class AdabinsBackbone(BaseModule):
    """Encoder-decoder backbone of AdaBins.

    Args:
        basemodel_name (str): Name of the timm encoder model.
        num_features (int): Channel width at the decoder entry point.
            Default: 2048.
        num_classes (int): Number of output channels fed to the AdaBins
            head. Default: 128.
        bottleneck_features (int): Number of channels of the encoder's
            bottleneck feature. Default: 2048.
        conv_cfg (dict): Config dict for the convolution layer.
    """

    def __init__(self,
                 basemodel_name,
                 num_features=2048,
                 num_classes=128,
                 bottleneck_features=2048,
                 conv_cfg=dict(type='Conv')):
        super().__init__()
        self.encoder = Encoder(basemodel_name)
        features = int(num_features)
        # Note: kernel_size=1 with padding=1 follows the official AdaBins
        # implementation.
        self.conv2 = build_conv_layer(
            conv_cfg,
            bottleneck_features,
            features,
            kernel_size=1,
            stride=1,
            padding=1)
        # Skip-channel counts are hard-coded to the EfficientNet-B5
        # features selected in forward().
        self.up1 = UpSampleBN(
            skip_input=features // 1 + 112 + 64, output_features=features // 2)
        self.up2 = UpSampleBN(
            skip_input=features // 2 + 40 + 24, output_features=features // 4)
        self.up3 = UpSampleBN(
            skip_input=features // 4 + 24 + 16, output_features=features // 8)
        self.up4 = UpSampleBN(
            skip_input=features // 8 + 16 + 8, output_features=features // 16)

        self.conv3 = build_conv_layer(
            conv_cfg,
            features // 16,
            num_classes,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        features = self.encoder(x)
        x_block0, x_block1, x_block2, x_block3, x_block4 = features[
            3], features[4], features[5], features[7], features[10]
        x_d0 = self.conv2(x_block4)
        x_d1 = self.up1(x_d0, x_block3)
        x_d2 = self.up2(x_d1, x_block2)
        x_d3 = self.up3(x_d2, x_block1)
        x_d4 = self.up4(x_d3, x_block0)
        out = self.conv3(x_d4)
        return out
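A quick, hedged smoke test for the backbone above; it assumes the mmsegmentation repo root is on `PYTHONPATH` and that the pretrained timm weights can be downloaded.

```python
import torch

from projects.Adabins.backbones import AdabinsBackbone

backbone = AdabinsBackbone(basemodel_name='tf_efficientnet_b5_ap')
backbone.eval()
x = torch.randn(1, 3, 416, 544)  # NYU training resolution from the README
with torch.no_grad():
    out = backbone(x)
# 128 output channels (num_classes); spatial size depends on the
# EfficientNet feature strides.
print(out.shape)
```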
projects/Adabins/configs/_base_/datasets/nyu.py

Lines changed: 32 additions & 0 deletions
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    # depth_rescale_factor=1e-3 converts the stored depth values to meters.
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id'))
]

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(
            img_path='images/test', depth_map_path='annotations/test'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='DepthMetric', max_depth_eval=10.0, crop_type='nyu_crop')
test_evaluator = val_evaluator
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
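A hedged sketch of building the validation dataset from this file with mmengine; the path and data layout are assumptions, and NYU data must be prepared under `data/nyu`.

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmseg.registry import DATASETS

cfg = Config.fromfile('projects/Adabins/configs/_base_/datasets/nyu.py')
init_default_scope('mmseg')  # resolve 'NYUDataset' and the transforms
dataset = DATASETS.build(cfg.val_dataloader.dataset)
print(len(dataset))
```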
projects/Adabins/configs/_base_/default_runtime.py

Lines changed: 15 additions & 0 deletions
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
projects/Adabins/configs/_base_/models/Adabins.py

Lines changed: 35 additions & 0 deletions
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
model = dict(
    type='DepthEstimator',
    data_preprocessor=data_preprocessor,
    # pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='AdabinsBackbone',
        basemodel_name='tf_efficientnet_b5_ap',
        num_features=2048,
        num_classes=128,
        bottleneck_features=2048,
    ),
    decode_head=dict(
        type='AdabinsHead',
        in_channels=128,
        n_query_channels=128,
        patch_size=16,
        embedding_dim=128,
        num_heads=4,
        n_bins=256,
        min_val=0.001,
        max_val=10,
        norm='linear'),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
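A hedged sketch of instantiating this model config; the explicit `projects.Adabins.*` imports stand in for the `custom_imports` mechanism used by the full configs, and the repo root is assumed to be on `PYTHONPATH`.

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

import projects.Adabins.backbones  # noqa: F401, registers AdabinsBackbone
import projects.Adabins.decode_head  # noqa: F401, registers AdabinsHead
from mmseg.registry import MODELS

cfg = Config.fromfile('projects/Adabins/configs/_base_/models/Adabins.py')
init_default_scope('mmseg')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # DepthEstimator
```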
Lines changed: 15 additions & 0 deletions
_base_ = [
    '../_base_/models/Adabins.py', '../_base_/datasets/nyu.py',
    '../_base_/default_runtime.py'
]
custom_imports = dict(
    imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
    allow_failed_imports=False)
crop_size = (416, 544)
data_preprocessor = dict(size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(),
    decode_head=dict(),
)
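A hedged sketch of evaluating this NYU config against the released checkpoint from the README table; the config filename below is a placeholder, since the actual name is not shown in this diff.

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('projects/Adabins/configs/<adabins_nyu_config>.py')
cfg.load_from = ('https://download.openmmlab.com/mmsegmentation/v0.5/adabins/'
                 'adabins_efficient_b5_nyu_third-party-f68d6bd3.pth')
cfg.work_dir = './work_dirs/adabins_nyu_eval'
Runner.from_cfg(cfg).test()  # only val/test loops are defined in these files
```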
Lines changed: 12 additions & 0 deletions
_base_ = ['../_base_/models/Adabins.py']
custom_imports = dict(
    imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
    allow_failed_imports=False)
crop_size = (352, 704)
data_preprocessor = dict(size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(),
    decode_head=dict(min_val=0.001, max_val=80),
)
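Since this KITTI variant only overrides the head's depth range and the crop size, a small sketch of inspecting the merged config (the filename is again a placeholder):

```python
from mmengine.config import Config

cfg = Config.fromfile('projects/Adabins/configs/<adabins_kitti_config>.py')
# The _base_ model config is merged, with the KITTI depth range applied.
print(cfg.model.decode_head.min_val, cfg.model.decode_head.max_val)  # 0.001 80
```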
projects/Adabins/decode_head/__init__.py

Lines changed: 4 additions & 0 deletions
# Copyright (c) OpenMMLab. All rights reserved.
from .adabins_head import AdabinsHead

__all__ = ['AdabinsHead']
