
Commit 743171d

[Feature] Support inference and visualization of VPD (open-mmlab#3331)
## Motivation

Support inference and visualization of VPD.

## Modification

1. Add a new VPD model that does not generate a black border in predictions.
2. Update `SegLocalVisualizer` to support depth visualization.
3. Update `MMSegInferencer` to support saving depth estimation predictions in its `postprocess` method.

## Use cases

Run inference with VPD using this command:

```sh
python demo/image_demo_with_inferencer.py demo/classroom__rgb_00283.jpg vpd_depth --out-dir vis_results
```

The following image will be saved under `vis_results/vis`:

![classroom__rgb_00283](https://github.com/open-mmlab/mmsegmentation/assets/26127467/051e8c4b-8f92-495f-8c3e-f249aac888e3)
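The same workflow is available through the Python API. Below is a minimal sketch, assuming the `vpd_depth` alias registered in `configs/vpd/metafile.yaml` and `MMSegInferencer`'s default `vis`/`pred` output subdirectories:

```python
from mmseg.apis import MMSegInferencer

# Build the inferencer from the metafile alias; the checkpoint recorded in
# configs/vpd/metafile.yaml is downloaded automatically on first use.
inferencer = MMSegInferencer(model='vpd_depth')

# Visualizations should land in vis_results/vis and raw predictions (a .npy
# float depth map for this model) in vis_results/pred.
inferencer('demo/classroom__rgb_00283.jpg', out_dir='vis_results')
```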
1 parent f1fa61a · commit 743171d

File tree: 15 files changed (+366 −36 lines)


.readthedocs.yml

Lines changed: 5 additions & 1 deletion

```diff
@@ -1,10 +1,14 @@
 version: 2
 
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.7"
+
 formats:
   - epub
 
 python:
-  version: 3.7
   install:
     - requirements: requirements/docs.txt
     - requirements: requirements/readthedocs.txt
```

configs/_base_/datasets/nyu.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -24,6 +24,7 @@
 
 test_pipeline = [
     dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(2000, 480), keep_ratio=True),
     dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
     dict(
         type='PackSegInputs',
```
configs/_base_/datasets/nyu_512x512.py (new file)

Lines changed: 72 additions & 0 deletions

```python
# dataset settings
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
    dict(type='RandomDepthMix', prob=0.25),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomResize',
        scale=(768, 512),
        ratio_range=(0.8, 1.5),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(512, 512)),
    dict(
        type='Albu',
        transforms=[
            dict(type='RandomBrightnessContrast'),
            dict(type='RandomGamma'),
            dict(type='HueSaturationValue'),
        ]),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id')),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
    dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id'))
]

train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', depth_map_path='annotations/train'),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(
            img_path='images/test', depth_map_path='annotations/test'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='DepthMetric',
    min_depth_eval=0.001,
    max_depth_eval=10.0,
    crop_type='nyu_crop')
test_evaluator = val_evaluator
```
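As a quick sanity check, the new dataset config can be loaded and built through mmengine's registry. A sketch, assuming the NYU data is prepared under `data/nyu` and the snippet runs from the repository root:

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmseg.registry import DATASETS

cfg = Config.fromfile('configs/_base_/datasets/nyu_512x512.py')
init_default_scope('mmseg')  # resolve 'NYUDataset' and the transform names

val_dataset = DATASETS.build(cfg.val_dataloader.dataset)
sample = val_dataset[0]  # dict with 'inputs' and 'data_samples' from PackSegInputs
print(len(val_dataset), sample['inputs'].shape)
```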

configs/_base_/models/vpd_sd.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
     std=[127.5, 127.5, 127.5],
     bgr_to_rgb=True,
     pad_val=0,
-    seg_pad_val=255)
+    seg_pad_val=0)
 
 # adapted from stable-diffusion/configs/stable-diffusion/v1-inference.yaml
 stable_diffusion_cfg = dict(
```

configs/vpd/README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 
 ## Introduction
 
-<!-- [BACKBONE] -->
+<!-- [ALGORITHM] -->
 
 <a href = "https://github.com/wl-zhao/VPD">Official Repo</a>
 
@@ -36,6 +36,7 @@ pip install -r requirements/optional.txt
 | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | RMSE | d1 | d2 | d3 | REL | log_10 | config | download |
 | ------ | --------------------- | --------- | ------- | -------- | -------------- | ------ | ----- | ----- | ----- | ----- | ----- | ------ | ------ | -------- |
 | VPD | Stable-Diffusion-v1-5 | 480x480 | 25000 | - | - | A100 | 0.253 | 0.964 | 0.995 | 0.999 | 0.069 | 0.030 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json) |
+| VPD | Stable-Diffusion-v1-5 | 512x512 | 25000 | - | - | A100 | 0.258 | 0.963 | 0.995 | 0.999 | 0.072 | 0.031 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918-60cefcff.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918.json) |
 
 ## Citation
 
```

configs/vpd/metafile.yaml

Lines changed: 22 additions & 0 deletions

```diff
@@ -32,3 +32,25 @@ Models:
     URL: https://arxiv.org/abs/2112.10752
   Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333
   Framework: PyTorch
+- Name: vpd_sd_4xb8-25k_nyu-512x512
+  In Collection: VPD
+  Alias: vpd_depth
+  Results:
+    Task: Depth Estimation
+    Dataset: NYU
+    Metrics:
+      RMSE: 0.258
+  Config: configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py
+  Metadata:
+    Training Data: NYU
+    Batch Size: 32
+    Architecture:
+    - Stable-Diffusion
+    Training Resources: 8x A100 GPUS
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918-60cefcff.pth
+  Training log: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918.json
+  Paper:
+    Title: 'High-Resolution Image Synthesis with Latent Diffusion Models'
+    URL: https://arxiv.org/abs/2112.10752
+  Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333
+  Framework: PyTorch
```

configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -23,7 +23,8 @@
     ),
     test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(160, 160)))
 
-default_hooks = dict(checkpoint=dict(save_best='rmse', rule='less'))
+default_hooks = dict(
+    checkpoint=dict(save_best='rmse', rule='less', max_keep_ckpts=1))
 
 # custom optimizer
 optim_wrapper = dict(
```
configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py (new file)

Lines changed: 37 additions & 0 deletions

```python
_base_ = [
    '../_base_/models/vpd_sd.py', '../_base_/datasets/nyu_512x512.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_25k.py'
]

crop_size = (512, 512)

model = dict(
    type='DepthEstimator',
    data_preprocessor=dict(size=crop_size),
    backbone=dict(
        class_embed_path='https://download.openmmlab.com/mmsegmentation/'
        'v0.5/vpd/nyu_class_embeddings.pth',
        class_embed_select=True,
        pad_shape=512,
        unet_cfg=dict(use_attn=False),
    ),
    decode_head=dict(
        type='VPDDepthHead',
        in_channels=[320, 640, 1280, 1280],
        max_depth=10,
    ),
    test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(128, 128)))

default_hooks = dict(
    checkpoint=dict(save_best='rmse', rule='less', max_keep_ckpts=1))

# custom optimizer
optim_wrapper = dict(
    constructor='ForceDefaultOptimWrapperConstructor',
    paramwise_cfg=dict(
        bias_decay_mult=0,
        force_default_settings=True,
        custom_keys={
            'backbone.encoder_vq': dict(lr_mult=0),
            'backbone.unet': dict(lr_mult=0.01),
        }))
```
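To see how the `_base_` files merge into the final runtime config, the file can be inspected with mmengine. A sketch, assuming it runs from the repository root:

```python
from mmengine.config import Config

cfg = Config.fromfile('configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py')

# slide_flip inference over (512, 512) crops with a (128, 128) stride
print(cfg.model.test_cfg)

# encoder_vq is effectively frozen (lr_mult=0); the UNet trains at 1% of
# the base learning rate
print(cfg.optim_wrapper.paramwise_cfg['custom_keys'])
```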

mmseg/apis/mmseg_inferencer.py

Lines changed: 22 additions & 10 deletions

```diff
@@ -306,17 +306,28 @@ def postprocess(self,
         results_dict['visualization'] = []
 
         for i, pred in enumerate(preds):
-            pred_data = pred.pred_sem_seg.numpy().data[0]
-            results_dict['predictions'].append(pred_data)
+            pred_data = dict()
+            if 'pred_sem_seg' in pred.keys():
+                pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
+            elif 'pred_depth_map' in pred.keys():
+                pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
+
             if visualization is not None:
                 vis = visualization[i]
                 results_dict['visualization'].append(vis)
             if pred_out_dir != '':
                 mmengine.mkdir_or_exist(pred_out_dir)
-                img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
-                img_path = osp.join(pred_out_dir, img_name)
-                output = Image.fromarray(pred_data.astype(np.uint8))
-                output.save(img_path)
+                for key, data in pred_data.items():
+                    post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
+                    img_name = str(self.num_pred_imgs).zfill(8) + post_fix
+                    img_path = osp.join(pred_out_dir, img_name)
+                    if key == 'sem_seg':
+                        output = Image.fromarray(data.astype(np.uint8))
+                        output.save(img_path)
+                    else:
+                        np.save(img_path, data)
+            pred_data = next(iter(pred_data.values()))
+            results_dict['predictions'].append(pred_data)
             self.num_pred_imgs += 1
 
         if len(results_dict['predictions']) == 1:
@@ -344,12 +355,13 @@ def preprocess(self, inputs, batch_size, **kwargs):
         """
         pipeline_cfg = cfg.test_dataloader.dataset.pipeline
         # Loading annotations is also not applicable
-        idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
-        if idx != -1:
-            del pipeline_cfg[idx]
+        for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
+            idx = self._get_transform_idx(pipeline_cfg, transform)
+            if idx != -1:
+                del pipeline_cfg[idx]
+
         load_img_idx = self._get_transform_idx(pipeline_cfg,
                                                'LoadImageFromFile')
-
         if load_img_idx == -1:
             raise ValueError(
                 'LoadImageFromFile is not found in the test pipeline')
```
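Since depth predictions are now saved with `np.save` rather than as uint8 PNGs, they can be read back as float arrays. A small sketch, assuming `out_dir='vis_results'` was passed to the inferencer (predictions then go to its `pred` subfolder):

```python
import numpy as np

# Files are named 00000000_pred.npy, 00000001_pred.npy, ... per self.num_pred_imgs
depth = np.load('vis_results/pred/00000000_pred.npy')
print(depth.shape, float(depth.min()), float(depth.max()))
```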

mmseg/datasets/transforms/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -11,7 +11,7 @@
                          BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                          PhotoMetricDistortion, RandomCrop, RandomCutOut,
                          RandomDepthMix, RandomFlip, RandomMosaic,
-                         RandomRotate, RandomRotFlip, Rerange,
+                         RandomRotate, RandomRotFlip, Rerange, Resize,
                          ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
                          SegRescale)
 
@@ -26,5 +26,5 @@
     'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
     'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
     'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
-    'RandomFlip'
+    'RandomFlip', 'Resize'
 ]
```
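With `Resize` re-exported here, the transform used by the new NYU test pipelines can be imported from the package directly, e.g.:

```python
# Resize is mmcv's transform surfaced through mmseg.datasets.transforms;
# same settings as the updated nyu.py test pipeline.
from mmseg.datasets.transforms import Resize

resize = Resize(scale=(2000, 480), keep_ratio=True)
```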
