Commit 9c45a94

[Fix] fix import error raised by ldm (open-mmlab#3338)
1 parent 56a40d7 commit 9c45a94

File tree

mmseg/models/backbones/vpd.py
tools/analysis_tools/visualization_cam.py

2 files changed: +20 −15 lines

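The fix applies a standard optional-dependency guard: the `ldm` imports move inside a try/except at module load, a module-level flag records whether they succeeded, and the classes that actually need `ldm` assert on that flag in their constructors. A minimal standalone sketch of the pattern, using a hypothetical package name rather than anything from this commit:

# optional_feature.py -- sketch of the guard pattern used in this commit.
# `some_optional_pkg` is a hypothetical optional dependency, not a real one.
try:
    import some_optional_pkg
    has_pkg = True
except ImportError:
    has_pkg = False


class FeatureNeedingPkg:
    """Importing this module never fails; only constructing this class
    requires the optional package."""

    def __init__(self):
        # Fail at use time, with an actionable install hint.
        assert has_pkg, ('FeatureNeedingPkg requires some_optional_pkg; '
                         'please install it first.')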

mmseg/models/backbones/vpd.py

Lines changed: 14 additions & 2 deletions
@@ -10,14 +10,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ldm.modules.diffusionmodules.util import timestep_embedding
-from ldm.util import instantiate_from_config
 from mmengine.model import BaseModule
 from mmengine.runner import CheckpointLoader, load_checkpoint
 
 from mmseg.registry import MODELS
 from mmseg.utils import ConfigType, OptConfigType
 
+try:
+    from ldm.modules.diffusionmodules.util import timestep_embedding
+    from ldm.util import instantiate_from_config
+    has_ldm = True
+except ImportError:
+    has_ldm = False
+
 
 def register_attention_control(model, controller):
     """Registers a control function to manage attention within a model.
@@ -205,6 +210,10 @@ def __init__(self,
                  max_attn_size=None,
                  attn_selector='up_cross+down_cross'):
         super().__init__()
+
+        assert has_ldm, 'To use UNetWrapper, please install required ' \
+            'packages via `pip install -r requirements/optional.txt`.'
+
         self.unet = unet
         self.attention_store = AttentionStore(
             base_size=base_size // 8, max_size=max_attn_size)
@@ -321,6 +330,9 @@ def __init__(self,
 
         super().__init__(init_cfg=init_cfg)
 
+        assert has_ldm, 'To use VPD model, please install required packages' \
+            ' via `pip install -r requirements/optional.txt`.'
+
         if pad_shape is not None:
             if not isinstance(pad_shape, (list, tuple)):
                 pad_shape = (pad_shape, pad_shape)
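With this change, `import mmseg.models.backbones.vpd` no longer hard-depends on `ldm`; the failure is deferred to the point where a VPD component is constructed. A quick check of the behavior, assuming an environment with mmseg's required dependencies installed but without `ldm` (the `unet=None` argument below is purely illustrative, since the assert fires before the argument is used):

# The module import itself now succeeds without `ldm`.
import mmseg.models.backbones.vpd as vpd

print(vpd.has_ldm)  # False when `ldm` is missing

# Only constructing a component that needs `ldm` fails, with an
# actionable message.
try:
    vpd.UNetWrapper(unet=None, base_size=512)
except AssertionError as err:
    print(err)  # suggests `pip install -r requirements/optional.txt`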

tools/analysis_tools/visualization_cam.py

Lines changed: 6 additions & 13 deletions
@@ -9,12 +9,12 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+from mmengine import Config
 from mmengine.model import revert_sync_batchnorm
 from PIL import Image
-from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
+from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
 
-from mmengine import Config
 from mmseg.apis import inference_model, init_model, show_result_pyplot
 from mmseg.utils import register_all_modules
 
@@ -56,21 +56,15 @@ def main():
         default='prediction.png',
         help='Path to output prediction file')
     parser.add_argument(
-        '--cam-file',
-        default='vis_cam.png',
-        help='Path to output cam file')
+        '--cam-file', default='vis_cam.png', help='Path to output cam file')
     parser.add_argument(
         '--target-layers',
         default='backbone.layer4[2]',
         help='Target layers to visualize CAM')
     parser.add_argument(
-        '--category-index',
-        default='7',
-        help='Category to visualize CAM')
+        '--category-index', default='7', help='Category to visualize CAM')
     parser.add_argument(
-        '--device',
-        default='cuda:0',
-        help='Device used for inference')
+        '--device', default='cuda:0', help='Device used for inference')
     args = parser.parse_args()
 
     # build the model from a config file and a checkpoint file
@@ -116,8 +110,7 @@ def main():
     # Grad CAM(Class Activation Maps)
     # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
     targets = [
-        SemanticSegmentationTarget(category, mask_float,
-                                   (height, width))
+        SemanticSegmentationTarget(category, mask_float, (height, width))
     ]
     with GradCAM(
         model=model,
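Note that the import line now brings in only `GradCAM`, so following the in-code comment about other CAM variants means importing them explicitly. A sketch of that swap, reusing the `model`, `target_layers`, `input_tensor`, and `targets` variables defined in the script's `main()`:

# LayerCAM must now be imported explicitly, since the script
# imports only GradCAM.
from pytorch_grad_cam import LayerCAM

# Mirrors the script's own `with GradCAM(...)` usage.
with LayerCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]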
