Skip to content

Commit 56a40d7

Browse files
authored
Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM) (open-mmlab#3324)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM). ## Modification Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM). requirement: pip install grad-cam run commad: python tools/analysis_tools/visualization_cam.py ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 3. The documentation has been modified accordingly, like docstring or example tutorials.
1 parent 743171d commit 56a40d7

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
3+
4+
requirement: pip install grad-cam
5+
"""
6+
7+
from argparse import ArgumentParser
8+
9+
import numpy as np
10+
import torch
11+
import torch.nn.functional as F
12+
from mmengine.model import revert_sync_batchnorm
13+
from PIL import Image
14+
from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
15+
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
16+
17+
from mmengine import Config
18+
from mmseg.apis import inference_model, init_model, show_result_pyplot
19+
from mmseg.utils import register_all_modules
20+
21+
22+
class SemanticSegmentationTarget:
23+
"""wrap the model.
24+
25+
requirement: pip install grad-cam
26+
27+
Args:
28+
category (int): Visualization class.
29+
mask (ndarray): Mask of class.
30+
size (tuple): Image size.
31+
"""
32+
33+
def __init__(self, category, mask, size):
34+
self.category = category
35+
self.mask = torch.from_numpy(mask)
36+
self.size = size
37+
if torch.cuda.is_available():
38+
self.mask = self.mask.cuda()
39+
40+
def __call__(self, model_output):
41+
model_output = torch.unsqueeze(model_output, dim=0)
42+
model_output = F.interpolate(
43+
model_output, size=self.size, mode='bilinear')
44+
model_output = torch.squeeze(model_output, dim=0)
45+
46+
return (model_output[self.category, :, :] * self.mask).sum()
47+
48+
49+
def main():
50+
parser = ArgumentParser()
51+
parser.add_argument('img', help='Image file')
52+
parser.add_argument('config', help='Config file')
53+
parser.add_argument('checkpoint', help='Checkpoint file')
54+
parser.add_argument(
55+
'--out-file',
56+
default='prediction.png',
57+
help='Path to output prediction file')
58+
parser.add_argument(
59+
'--cam-file',
60+
default='vis_cam.png',
61+
help='Path to output cam file')
62+
parser.add_argument(
63+
'--target-layers',
64+
default='backbone.layer4[2]',
65+
help='Target layers to visualize CAM')
66+
parser.add_argument(
67+
'--category-index',
68+
default='7',
69+
help='Category to visualize CAM')
70+
parser.add_argument(
71+
'--device',
72+
default='cuda:0',
73+
help='Device used for inference')
74+
args = parser.parse_args()
75+
76+
# build the model from a config file and a checkpoint file
77+
register_all_modules()
78+
model = init_model(args.config, args.checkpoint, device=args.device)
79+
if args.device == 'cpu':
80+
model = revert_sync_batchnorm(model)
81+
82+
# test a single image
83+
result = inference_model(model, args.img)
84+
85+
# show the results
86+
show_result_pyplot(
87+
model,
88+
args.img,
89+
result,
90+
draw_gt=False,
91+
show=False if args.out_file is not None else True,
92+
out_file=args.out_file)
93+
94+
# result data conversion
95+
prediction_data = result.pred_sem_seg.data
96+
pre_np_data = prediction_data.cpu().numpy().squeeze(0)
97+
98+
target_layers = args.target_layers
99+
target_layers = [eval(f'model.{target_layers}')]
100+
101+
category = int(args.category_index)
102+
mask_float = np.float32(pre_np_data == category)
103+
104+
# data processing
105+
image = np.array(Image.open(args.img).convert('RGB'))
106+
height, width = image.shape[0], image.shape[1]
107+
rgb_img = np.float32(image) / 255
108+
config = Config.fromfile(args.config)
109+
image_mean = config.data_preprocessor['mean']
110+
image_std = config.data_preprocessor['std']
111+
input_tensor = preprocess_image(
112+
rgb_img,
113+
mean=[x / 255 for x in image_mean],
114+
std=[x / 255 for x in image_std])
115+
116+
# Grad CAM(Class Activation Maps)
117+
# Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
118+
targets = [
119+
SemanticSegmentationTarget(category, mask_float,
120+
(height, width))
121+
]
122+
with GradCAM(
123+
model=model,
124+
target_layers=target_layers,
125+
use_cuda=torch.cuda.is_available()) as cam:
126+
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
127+
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
128+
129+
# save cam file
130+
Image.fromarray(cam_image).save(args.cam_file)
131+
132+
133+
if __name__ == '__main__':
134+
main()

0 commit comments

Comments
 (0)