|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# -------------------------------------------------------- |
| 4 | +# Tensorflow Faster R-CNN |
| 5 | +# Licensed under The MIT License [see LICENSE for details] |
| 6 | +# Written by Xinlei Chen, based on code from Ross Girshick |
| 7 | +# Edited by Matthew Seals |
| 8 | +# -------------------------------------------------------- |
| 9 | + |
| 10 | +""" |
| 11 | +Demo script showing detections in sample images. |
| 12 | +
|
| 13 | +See README.md for installation instructions before running. |
| 14 | +""" |
| 15 | +from __future__ import absolute_import |
| 16 | +from __future__ import division |
| 17 | +from __future__ import print_function |
| 18 | + |
| 19 | +import _init_paths |
| 20 | +from model.config import cfg |
| 21 | +from model.test import im_detect |
| 22 | +from model.nms_wrapper import nms |
| 23 | + |
| 24 | +from utils.timer import Timer |
| 25 | +import matplotlib.pyplot as plt |
| 26 | +import numpy as np |
| 27 | +import os |
| 28 | +import cv2 |
| 29 | +import argparse |
| 30 | +from matplotlib import cm |
| 31 | + |
| 32 | +from nets.vgg16 import vgg16 |
| 33 | +from nets.resnet_v1 import resnetv1 |
| 34 | + |
| 35 | +import torch |
| 36 | + |
| 37 | +CLASSES = ('__background__', |
| 38 | + 'aeroplane', 'bicycle', 'bird', 'boat', |
| 39 | + 'bottle', 'bus', 'car', 'cat', 'chair', |
| 40 | + 'cow', 'diningtable', 'dog', 'horse', |
| 41 | + 'motorbike', 'person', 'pottedplant', |
| 42 | + 'sheep', 'sofa', 'train', 'tvmonitor') |
| 43 | + |
| 44 | +NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',), 'res101': ('res101_faster_rcnn_iter_%d.pth',)} |
| 45 | +DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)} |
| 46 | + |
| 47 | +COLORS = [cm.tab10(i) for i in np.linspace(0., 1., 10)] |
| 48 | + |
| 49 | + |
| 50 | +def demo(net, image_name): |
| 51 | + """Detect object classes in an image using pre-computed object proposals.""" |
| 52 | + |
| 53 | + # Load the demo image |
| 54 | + im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name) |
| 55 | + im = cv2.imread(im_file) |
| 56 | + |
| 57 | + # Detect all object classes and regress object bounds |
| 58 | + timer = Timer() |
| 59 | + timer.tic() |
| 60 | + scores, boxes = im_detect(net, im) |
| 61 | + timer.toc() |
| 62 | + print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0])) |
| 63 | + |
| 64 | + # Visualize detections for each class |
| 65 | + thresh = 0.8 # CONF_THRESH |
| 66 | + NMS_THRESH = 0.3 |
| 67 | + |
| 68 | + im = im[:, :, (2, 1, 0)] |
| 69 | + fig, ax = plt.subplots(figsize=(12, 12)) |
| 70 | + ax.imshow(im, aspect='equal') |
| 71 | + cntr = -1 |
| 72 | + |
| 73 | + for cls_ind, cls in enumerate(CLASSES[1:]): |
| 74 | + cls_ind += 1 # because we skipped background |
| 75 | + cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] |
| 76 | + cls_scores = scores[:, cls_ind] |
| 77 | + dets = np.hstack((cls_boxes, |
| 78 | + cls_scores[:, np.newaxis])).astype(np.float32) |
| 79 | + keep = nms(torch.from_numpy(dets), NMS_THRESH) |
| 80 | + dets = dets[keep.numpy(), :] |
| 81 | + inds = np.where(dets[:, -1] >= thresh)[0] |
| 82 | + if len(inds) == 0: |
| 83 | + continue |
| 84 | + else: |
| 85 | + cntr += 1 |
| 86 | + |
| 87 | + for i in inds: |
| 88 | + bbox = dets[i, :4] |
| 89 | + score = dets[i, -1] |
| 90 | + |
| 91 | + ax.add_patch( |
| 92 | + plt.Rectangle((bbox[0], bbox[1]), |
| 93 | + bbox[2] - bbox[0], |
| 94 | + bbox[3] - bbox[1], fill=False, |
| 95 | + edgecolor=COLORS[cntr % len(COLORS)], linewidth=3.5) |
| 96 | + ) |
| 97 | + ax.text(bbox[0], bbox[1] - 2, |
| 98 | + '{:s} {:.3f}'.format(cls, score), |
| 99 | + bbox=dict(facecolor='blue', alpha=0.5), |
| 100 | + fontsize=14, color='white') |
| 101 | + |
| 102 | + ax.set_title('All detections with threshold >= {:.1f}'.format(thresh), fontsize=14) |
| 103 | + |
| 104 | + plt.axis('off') |
| 105 | + plt.tight_layout() |
| 106 | + plt.savefig('demo_' + image_name) |
| 107 | + print('Saved to `{}`'.format(os.path.join(os.getcwd(), 'demo_' + image_name))) |
| 108 | + |
| 109 | + |
| 110 | +def parse_args(): |
| 111 | + """Parse input arguments.""" |
| 112 | + parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo') |
| 113 | + parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]', |
| 114 | + choices=NETS.keys(), default='res101') |
| 115 | + parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]', |
| 116 | + choices=DATASETS.keys(), default='pascal_voc_0712') |
| 117 | + args = parser.parse_args() |
| 118 | + |
| 119 | + return args |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == '__main__': |
| 123 | + cfg.TEST.HAS_RPN = True # Use RPN for proposals |
| 124 | + args = parse_args() |
| 125 | + |
| 126 | + # model path |
| 127 | + demonet = args.demo_net |
| 128 | + dataset = args.dataset |
| 129 | + saved_model = os.path.join('output', demonet, DATASETS[dataset][0], 'default', |
| 130 | + NETS[demonet][0] % (70000 if dataset == 'pascal_voc' else 110000)) |
| 131 | + |
| 132 | + if not os.path.isfile(saved_model): |
| 133 | + raise IOError(('{:s} not found.\nDid you download the proper networks from ' |
| 134 | + 'our server and place them properly?').format(saved_model)) |
| 135 | + |
| 136 | + # load network |
| 137 | + if demonet == 'vgg16': |
| 138 | + net = vgg16() |
| 139 | + elif demonet == 'res101': |
| 140 | + net = resnetv1(num_layers=101) |
| 141 | + else: |
| 142 | + raise NotImplementedError |
| 143 | + net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32]) |
| 144 | + |
| 145 | + net.load_state_dict(torch.load(saved_model)) |
| 146 | + |
| 147 | + net.eval() |
| 148 | + net.cuda() |
| 149 | + |
| 150 | + print('Loaded network {:s}'.format(saved_model)) |
| 151 | + |
| 152 | + im_names = [i for i in os.listdir('data/demo/') # Pull in all jpgs |
| 153 | + if i.lower().endswith(".jpg")] |
| 154 | + |
| 155 | + for im_name in im_names: |
| 156 | + print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') |
| 157 | + print('Demo for data/demo/{}'.format(im_name)) |
| 158 | + demo(net, im_name) |
| 159 | + |
| 160 | + plt.show() |
0 commit comments