Skip to content

Commit f6fb0ac

Browse files
mseals1ruotianluo
mseals1
authored andcommitted
Output image containing all bboxes
Write all bounding boxes into a single image and output. The original demo script writes one bounding box type per output image, i.e. one image for 2 detected persons, one image for 1 detected horse, etc. This script draws all of the bounding boxes for each class onto one image and saves it.
1 parent d63004e commit f6fb0ac

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed

tools/demo_all_bboxes.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/usr/bin/env python
2+
3+
# --------------------------------------------------------
4+
# Tensorflow Faster R-CNN
5+
# Licensed under The MIT License [see LICENSE for details]
6+
# Written by Xinlei Chen, based on code from Ross Girshick
7+
# Edited by Matthew Seals
8+
# --------------------------------------------------------
9+
10+
"""
11+
Demo script showing detections in sample images.
12+
13+
See README.md for installation instructions before running.
14+
"""
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import _init_paths
20+
from model.config import cfg
21+
from model.test import im_detect
22+
from model.nms_wrapper import nms
23+
24+
from utils.timer import Timer
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
import os
28+
import cv2
29+
import argparse
30+
from matplotlib import cm
31+
32+
from nets.vgg16 import vgg16
33+
from nets.resnet_v1 import resnetv1
34+
35+
import torch
36+
37+
CLASSES = ('__background__',
38+
'aeroplane', 'bicycle', 'bird', 'boat',
39+
'bottle', 'bus', 'car', 'cat', 'chair',
40+
'cow', 'diningtable', 'dog', 'horse',
41+
'motorbike', 'person', 'pottedplant',
42+
'sheep', 'sofa', 'train', 'tvmonitor')
43+
44+
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',), 'res101': ('res101_faster_rcnn_iter_%d.pth',)}
45+
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
46+
47+
COLORS = [cm.tab10(i) for i in np.linspace(0., 1., 10)]
48+
49+
50+
def demo(net, image_name):
51+
"""Detect object classes in an image using pre-computed object proposals."""
52+
53+
# Load the demo image
54+
im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
55+
im = cv2.imread(im_file)
56+
57+
# Detect all object classes and regress object bounds
58+
timer = Timer()
59+
timer.tic()
60+
scores, boxes = im_detect(net, im)
61+
timer.toc()
62+
print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))
63+
64+
# Visualize detections for each class
65+
thresh = 0.8 # CONF_THRESH
66+
NMS_THRESH = 0.3
67+
68+
im = im[:, :, (2, 1, 0)]
69+
fig, ax = plt.subplots(figsize=(12, 12))
70+
ax.imshow(im, aspect='equal')
71+
cntr = -1
72+
73+
for cls_ind, cls in enumerate(CLASSES[1:]):
74+
cls_ind += 1 # because we skipped background
75+
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
76+
cls_scores = scores[:, cls_ind]
77+
dets = np.hstack((cls_boxes,
78+
cls_scores[:, np.newaxis])).astype(np.float32)
79+
keep = nms(torch.from_numpy(dets), NMS_THRESH)
80+
dets = dets[keep.numpy(), :]
81+
inds = np.where(dets[:, -1] >= thresh)[0]
82+
if len(inds) == 0:
83+
continue
84+
else:
85+
cntr += 1
86+
87+
for i in inds:
88+
bbox = dets[i, :4]
89+
score = dets[i, -1]
90+
91+
ax.add_patch(
92+
plt.Rectangle((bbox[0], bbox[1]),
93+
bbox[2] - bbox[0],
94+
bbox[3] - bbox[1], fill=False,
95+
edgecolor=COLORS[cntr % len(COLORS)], linewidth=3.5)
96+
)
97+
ax.text(bbox[0], bbox[1] - 2,
98+
'{:s} {:.3f}'.format(cls, score),
99+
bbox=dict(facecolor='blue', alpha=0.5),
100+
fontsize=14, color='white')
101+
102+
ax.set_title('All detections with threshold >= {:.1f}'.format(thresh), fontsize=14)
103+
104+
plt.axis('off')
105+
plt.tight_layout()
106+
plt.savefig('demo_' + image_name)
107+
print('Saved to `{}`'.format(os.path.join(os.getcwd(), 'demo_' + image_name)))
108+
109+
110+
def parse_args():
111+
"""Parse input arguments."""
112+
parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
113+
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
114+
choices=NETS.keys(), default='res101')
115+
parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
116+
choices=DATASETS.keys(), default='pascal_voc_0712')
117+
args = parser.parse_args()
118+
119+
return args
120+
121+
122+
if __name__ == '__main__':
123+
cfg.TEST.HAS_RPN = True # Use RPN for proposals
124+
args = parse_args()
125+
126+
# model path
127+
demonet = args.demo_net
128+
dataset = args.dataset
129+
saved_model = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
130+
NETS[demonet][0] % (70000 if dataset == 'pascal_voc' else 110000))
131+
132+
if not os.path.isfile(saved_model):
133+
raise IOError(('{:s} not found.\nDid you download the proper networks from '
134+
'our server and place them properly?').format(saved_model))
135+
136+
# load network
137+
if demonet == 'vgg16':
138+
net = vgg16()
139+
elif demonet == 'res101':
140+
net = resnetv1(num_layers=101)
141+
else:
142+
raise NotImplementedError
143+
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])
144+
145+
net.load_state_dict(torch.load(saved_model))
146+
147+
net.eval()
148+
net.cuda()
149+
150+
print('Loaded network {:s}'.format(saved_model))
151+
152+
im_names = [i for i in os.listdir('data/demo/') # Pull in all jpgs
153+
if i.lower().endswith(".jpg")]
154+
155+
for im_name in im_names:
156+
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
157+
print('Demo for data/demo/{}'.format(im_name))
158+
demo(net, im_name)
159+
160+
plt.show()

0 commit comments

Comments
 (0)