# Copyright (c) OpenMMLab. All rights reserved.
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).

Requirement: pip install grad-cam
"""
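# A minimal usage sketch; every path below is illustrative and should be
# replaced with your own script location, image, config and checkpoint:
#   python vis_cam.py demo/demo.png pspnet_cityscapes.py pspnet.pth \
#       --target-layers 'backbone.layer4[2]' --category-index 7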

from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules


class SemanticSegmentationTarget:
    """Wrap the target category and its predicted mask for CAM computation.

    Requirement: pip install grad-cam

    Args:
        category (int): Index of the class to visualize.
        mask (ndarray): Binary mask of the predicted class.
        size (tuple): Original image size as (height, width).
    """

    def __init__(self, category, mask, size):
        self.category = category
        self.mask = torch.from_numpy(mask)
        self.size = size
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

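    # pytorch-grad-cam calls this wrapper with the segmentation logits of a
    # single image; the scalar it returns (the sum of the target-class
    # logits inside the predicted mask) is what the CAM gradients are
    # taken with respect to.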
    def __call__(self, model_output):
        # upsample the (C, H, W) logits to the original image size
        # before masking
        model_output = torch.unsqueeze(model_output, dim=0)
        model_output = F.interpolate(
            model_output, size=self.size, mode='bilinear')
        model_output = torch.squeeze(model_output, dim=0)

        return (model_output[self.category, :, :] * self.mask).sum()


def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file',
        default='vis_cam.png',
        help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    parser.add_argument(
        '--category-index',
        default='7',
        help='Category to visualize CAM')
    parser.add_argument(
        '--device',
        default='cuda:0',
        help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)

    # show the results
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=args.out_file is None,
        out_file=args.out_file)

    # convert the prediction to a (H, W) numpy label map
    prediction_data = result.pred_sem_seg.data
    pre_np_data = prediction_data.cpu().numpy().squeeze(0)

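    # ``eval`` resolves the dotted path against the model, so the default
    # 'backbone.layer4[2]' becomes model.backbone.layer4[2] (for a
    # ResNet-style backbone, the last block of the final stage)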
    target_layers = args.target_layers
    target_layers = [eval(f'model.{target_layers}')]

    category = int(args.category_index)
    mask_float = np.float32(pre_np_data == category)

    # data processing
    image = np.array(Image.open(args.img).convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    # the config stores mean/std on the 0-255 scale, while preprocess_image
    # expects them on the 0-1 scale used by rgb_img
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM (Class Activation Maps)
    # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM or
    # EigenGradCAM, all importable from pytorch_grad_cam
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    with GradCAM(
            model=model,
            target_layers=target_layers,
            use_cuda=torch.cuda.is_available()) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # save cam file
    Image.fromarray(cam_image).save(args.cam_file)


if __name__ == '__main__':
    main()