Commit 7933c79

Merge pull request utkuozbulak#102 from PengtaoJiang/master
LayerCam added.
2 parents 16eddfa + b176231 commit 7933c79

14 files changed: +144 -0 lines changed

README.md

Lines changed: 42 additions & 0 deletions
@@ -14,6 +14,7 @@ This repository contains a number of convolutional neural network visualization
 * [Gradient-weighted class activation mapping](#gradient-visualization) [3] (Generalization of [2])
 * [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
 * [Score-weighted class activation mapping](#gradient-visualization) [15] (Gradient-free generalization of [2])
+* [Element-wise gradient-weighted class activation mapping](#hierarchical-gradient-visualization) [16] (Visualization of any CNN layer)
 * [Smooth grad](#smooth-grad) [8]
 * [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
 * [Inverted image representations](#inverted-image-representations) [5]
@@ -163,6 +164,45 @@ If you find the code in this repository useful for your research consider citing
 </tbody>
 </table>

+## Hierarchical Gradient Visualization
+**Element-wise gradient-weighted class activation mapping**: LayerCAM [16] is a simple modification of Grad-CAM [3] that can generate reliable class activation maps from different layers. For this example I used a pre-trained **VGG16**.
+
+<table border=0 width="50px" >
+<tbody>
+<tr>
+<td> </td>
+<td align="center"> Class Activation Map </td>
+<td align="center"> Class Activation HeatMap </td>
+<td align="center"> Class Activation HeatMap on Image</td>
+</tr>
+<tr>
+<td width="19%" align="center"> LayerCAM <br /> (Layer 9)</td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool2_Grayscale.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool2_Heatmap.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool2_On_Image.png"> </td>
+</tr>
+<tr>
+<td width="19%" align="center"> LayerCAM <br /> (Layer 16)</td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool3_Grayscale.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool3_Heatmap.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool3_On_Image.png"> </td>
+</tr>
+<tr>
+<td width="19%" align="center"> LayerCAM <br /> (Layer 23)</td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool4_Grayscale.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool4_Heatmap.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool4_On_Image.png"> </td>
+</tr>
+<tr>
+<td width="19%" align="center"> LayerCAM <br /> (Layer 30)</td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool5_Grayscale.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool5_Heatmap.png"> </td>
+<td width="27%" align="center"> <img src="results/hierarchical_gradient_visualization/snake_LayerCam_pool5_On_Image.png"> </td>
+</tr>
+</tbody>
+</table>
+
+
 ## Grad Times Image
 Another technique that is proposed is simply multiplying the gradients with the image itself. Results obtained with the usage of multiple gradient techniques are below.
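The "Grad Times Image" idea carried over in the context lines above is just an element-wise product of the input image and its gradients. A minimal NumPy sketch, assuming `vanilla_grads` is a (3, H, W) gradient array for the preprocessed input `prep_img` (both names are illustrative, not from this commit):

```python
import numpy as np

# Element-wise product of the gradients and the preprocessed image (both (3, H, W))
grad_times_image = vanilla_grads * prep_img.detach().numpy()[0]
# Collapse to a single-channel saliency map by summing absolute values over channels
grad_times_image_gray = np.sum(np.abs(grad_times_image), axis=0)
grad_times_image_gray /= grad_times_image_gray.max()  # normalize to 0-1 for saving
```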

@@ -401,3 +441,5 @@ PIL >= 1.1.7
 [14] J. Yosinski, J. Clune, A. Nguyen, T. Fuchs, Hod Lipson, *Understanding Neural Networks Through Deep Visualization* https://arxiv.org/abs/1506.06579

 [15] H. Wang, Z. Wang, M. Du, F. Yang, Z. Zhang, S. Ding, P. Mardziel, X. Hu. *Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks* https://arxiv.org/abs/1910.01279
+
+[16] P. Jiang, C. Zhang, Q. Hou, M. Cheng, Y. Wei. *LayerCAM: Exploring Hierarchical Class Activation Maps for Localization* http://mmcheng.net/mftp/Papers/21TIP_LayerCAM.pdf
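The element-wise weighting is the only place where LayerCAM [16] departs from Grad-CAM [3]: Grad-CAM averages the gradients over the spatial dimensions to get one weight per channel, whereas LayerCAM keeps the full gradient map, clamps negative values to zero, and weights every activation individually. A minimal NumPy sketch of the two weightings, assuming `grads` and `acts` are the (C, H, W) gradient and activation arrays of one layer (the function names are illustrative only):

```python
import numpy as np

def gradcam_map(grads, acts):
    # Grad-CAM: one scalar weight per channel = spatial mean of the gradients
    weights = grads.mean(axis=(1, 2))                 # shape (C,)
    return np.maximum((weights[:, None, None] * acts).sum(axis=0), 0)

def layercam_map(grads, acts):
    # LayerCAM: element-wise weights = ReLU of the gradients themselves
    weights = np.maximum(grads, 0)                    # shape (C, H, W)
    return (weights * acts).sum(axis=0)
```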

src/layercam.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
"""
Created on Mon Jul 5 12:39:11 2021

@author: Peng-Tao Jiang - github.com/PengtaoJiang
"""
from PIL import Image
import numpy as np
import torch

from misc_functions import get_example_params, save_class_activation_images


class CamExtractor():
    """
        Extracts cam features from the model
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None

    def save_gradient(self, grad):
        self.gradients = grad

    def forward_pass_on_convolutions(self, x):
        """
            Does a forward pass on convolutions, hooks the function at given layer
        """
        conv_output = None
        for module_pos, module in self.model.features._modules.items():
            x = module(x)  # Forward
            if int(module_pos) == self.target_layer:
                x.register_hook(self.save_gradient)
                conv_output = x  # Save the convolution output on that layer
        return conv_output, x

    def forward_pass(self, x):
        """
            Does a full forward pass on the model
        """
        # Forward pass on the convolutions
        conv_output, x = self.forward_pass_on_convolutions(x)
        x = x.view(x.size(0), -1)  # Flatten
        # Forward pass on the classifier
        x = self.model.classifier(x)
        return conv_output, x


class LayerCam():
    """
        Produces class activation map
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.model.eval()
        # Define extractor
        self.extractor = CamExtractor(self.model, target_layer)

    def generate_cam(self, input_image, target_class=None):
        # Full forward pass
        # conv_output is the output of convolutions at specified layer
        # model_output is the final output of the model (1, 1000)
        conv_output, model_output = self.extractor.forward_pass(input_image)
        if target_class is None:
            target_class = np.argmax(model_output.data.numpy())
        # Target for backprop
        one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_()
        one_hot_output[0][target_class] = 1
        # Zero grads
        self.model.features.zero_grad()
        self.model.classifier.zero_grad()
        # Backward pass with specified target
        model_output.backward(gradient=one_hot_output, retain_graph=True)
        # Get hooked gradients
        guided_gradients = self.extractor.gradients.data.numpy()[0]
        # Get convolution outputs
        target = conv_output.data.numpy()[0]
        # Get weights from gradients
        weights = guided_gradients
        weights[weights < 0] = 0  # Discard negative gradients (element-wise ReLU)
        # Element-wise multiply the weights with the conv output, then sum over channels
        cam = np.sum(weights * target, axis=0)
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam))  # Normalize between 0-1
        cam = np.uint8(cam * 255)  # Scale between 0-255 to visualize
        cam = np.uint8(Image.fromarray(cam).resize((input_image.shape[2],
                       input_image.shape[3]), Image.ANTIALIAS)) / 255

        return cam


if __name__ == '__main__':
    # Get params
    target_example = 0  # Snake
    (original_image, prep_img, target_class, file_name_to_export, pretrained_model) =\
        get_example_params(target_example)
    # Layer cam
    layer_cam = LayerCam(pretrained_model, target_layer=9)
    # Generate cam mask
    cam = layer_cam.generate_cam(prep_img, target_class)
    # Save mask
    save_class_activation_images(original_image, cam, file_name_to_export)
    print('Layer cam completed')
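For reference, a minimal sketch of how the four rows of the README table above could be reproduced by looping over the VGG16 feature indices used there (in torchvision's VGG16, indices 9, 16, 23 and 30 of `features` are the pool2-pool5 layers). It reuses the repository's `misc_functions` helpers but loads VGG16 explicitly; the `_layer` suffix on the output name is a hypothetical naming choice, not part of this commit:

```python
# Sketch only: assumes it is run from the repository's src/ directory
from torchvision import models
from layercam import LayerCam
from misc_functions import get_example_params, save_class_activation_images

# Reuse the example preprocessing, but use a pre-trained VGG16 as in the README results
(original_image, prep_img, target_class, file_name_to_export, _) = \
    get_example_params(0)  # example 0 is the snake image
vgg16 = models.vgg16(pretrained=True)
for layer in (9, 16, 23, 30):  # pool2, pool3, pool4, pool5 in vgg16.features
    cam = LayerCam(vgg16, target_layer=layer).generate_cam(prep_img, target_class)
    # Hypothetical '_layer' suffix so each layer gets its own output files
    save_class_activation_images(original_image, cam,
                                 file_name_to_export + '_layer' + str(layer))
print('Layer cams for layers 9, 16, 23 and 30 completed')
```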
