77
88import os
99from datasets .imdb import imdb
10- import xml .dom . minidom as minidom
10+ import xml .etree . ElementTree as ET
1111import numpy as np
1212import scipy .sparse
1313import scipy .io as sio
1414import utils .cython_bbox
1515import cPickle
1616import subprocess
17+ import uuid
18+ from voc_eval import voc_eval
1719from fast_rcnn .config import cfg
1820
1921class pascal_voc (imdb ):
@@ -35,13 +37,16 @@ def __init__(self, image_set, year, devkit_path=None):
3537 self ._image_index = self ._load_image_set_index ()
3638 # Default to roidb handler
3739 self ._roidb_handler = self .selective_search_roidb
40+ self ._salt = str (uuid .uuid4 ())
41+ self ._comp_id = 'comp4'
3842
3943 # PASCAL specific config options
40- self .config = {'cleanup' : True ,
41- 'use_salt' : True ,
42- 'top_k' : 2000 ,
43- 'use_diff' : False ,
44- 'rpn_file' : None }
44+ self .config = {'cleanup' : True ,
45+ 'use_salt' : True ,
46+ 'top_k' : 2000 ,
47+ 'use_diff' : False ,
48+ 'matlab_eval' : False ,
49+ 'rpn_file' : None }
4550
4651 assert os .path .exists (self ._devkit_path ), \
4752 'VOCdevkit path does not exist: {}' .format (self ._devkit_path )
@@ -172,21 +177,15 @@ def _load_pascal_annotation(self, index):
172177 format.
173178 """
174179 filename = os .path .join (self ._data_path , 'Annotations' , index + '.xml' )
175- # print 'Loading: {}'.format(filename)
176- def get_data_from_tag (node , tag ):
177- return node .getElementsByTagName (tag )[0 ].childNodes [0 ].data
178-
179- with open (filename ) as f :
180- data = minidom .parseString (f .read ())
181-
182- objs = data .getElementsByTagName ('object' )
180+ tree = ET .parse (filename )
181+ objs = tree .findall ('object' )
183182 if not self .config ['use_diff' ]:
184183 # Exclude the samples labeled as difficult
185- non_diff_objs = [obj for obj in objs
186- if int (get_data_from_tag ( obj , 'difficult' )) == 0 ]
184+ non_diff_objs = [
185+ obj for obj in objs if int (obj . find ( 'difficult' ). text ) == 0 ]
187186 if len (non_diff_objs ) != len (objs ):
188- print 'Removed {} difficult objects' \
189- . format ( len (objs ) - len (non_diff_objs ))
187+ print 'Removed {} difficult objects' . format (
188+ len (objs ) - len (non_diff_objs ))
190189 objs = non_diff_objs
191190 num_objs = len (objs )
192191
@@ -196,13 +195,13 @@ def get_data_from_tag(node, tag):
196195
197196 # Load object bounding boxes into a data frame.
198197 for ix , obj in enumerate (objs ):
198+ bbox = obj .find ('bndbox' )
199199 # Make pixel indexes 0-based
200- x1 = float (get_data_from_tag (obj , 'xmin' )) - 1
201- y1 = float (get_data_from_tag (obj , 'ymin' )) - 1
202- x2 = float (get_data_from_tag (obj , 'xmax' )) - 1
203- y2 = float (get_data_from_tag (obj , 'ymax' )) - 1
204- cls = self ._class_to_ind [
205- str (get_data_from_tag (obj , "name" )).lower ().strip ()]
200+ x1 = float (bbox .find ('xmin' ).text ) - 1
201+ y1 = float (bbox .find ('ymin' ).text ) - 1
202+ x2 = float (bbox .find ('xmax' ).text ) - 1
203+ y2 = float (bbox .find ('ymax' ).text ) - 1
204+ cls = self ._class_to_ind [obj .find ('name' ).text .lower ().strip ()]
206205 boxes [ix , :] = [x1 , y1 , x2 , y2 ]
207206 gt_classes [ix ] = cls
208207 overlaps [ix , cls ] = 1.0
@@ -214,20 +213,28 @@ def get_data_from_tag(node, tag):
214213 'gt_overlaps' : overlaps ,
215214 'flipped' : False }
216215
216+ def _get_comp_id (self ):
217+ comp_id = (self ._comp_id + '_' + self ._salt if self .config ['use_salt' ]
218+ else self ._comp_id )
219+ return comp_id
220+
221+ def _get_voc_results_file_template (self ):
222+ # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
223+ filename = self ._get_comp_id () + '_det_' + self ._image_set + '_{:s}.txt'
224+ path = os .path .join (
225+ self ._devkit_path ,
226+ 'results' ,
227+ 'VOC' + self ._year ,
228+ 'Main' ,
229+ filename )
230+ return path
231+
217232 def _write_voc_results_file (self , all_boxes ):
218- use_salt = self .config ['use_salt' ]
219- comp_id = 'comp4'
220- if use_salt :
221- comp_id += '-{}' .format (os .getpid ())
222-
223- # VOCdevkit/results/VOC2007/Main/comp4-44503_det_test_aeroplane.txt
224- path = os .path .join (self ._devkit_path , 'results' , 'VOC' + self ._year ,
225- 'Main' , comp_id + '_' )
226233 for cls_ind , cls in enumerate (self .classes ):
227234 if cls == '__background__' :
228235 continue
229236 print 'Writing {} VOC results file' .format (cls )
230- filename = path + 'det_' + self ._image_set + '_' + cls + '.txt'
237+ filename = self ._get_voc_results_file_template (). format ( cls )
231238 with open (filename , 'wt' ) as f :
232239 for im_ind , index in enumerate (self .image_index ):
233240 dets = all_boxes [cls_ind ][im_ind ]
@@ -239,25 +246,76 @@ def _write_voc_results_file(self, all_boxes):
239246 format (index , dets [k , - 1 ],
240247 dets [k , 0 ] + 1 , dets [k , 1 ] + 1 ,
241248 dets [k , 2 ] + 1 , dets [k , 3 ] + 1 ))
242- return comp_id
243-
244- def _do_matlab_eval (self , comp_id , output_dir = 'output' ):
245- rm_results = self .config ['cleanup' ]
246249
250+ def _do_python_eval (self , output_dir = 'output' ):
251+ print '--------------------------------------------------------------'
252+ print 'Computing results with **unofficial** Python eval code.'
253+ print 'Results should be very close to the official MATLAB eval code.'
254+ print 'Recompute with `./tools/reval.py --matlab ...` for your paper.'
255+ print '--------------------------------------------------------------'
256+ annopath = os .path .join (
257+ self ._devkit_path ,
258+ 'VOC' + self ._year ,
259+ 'Annotations' ,
260+ '{:s}.xml' )
261+ imagesetfile = os .path .join (
262+ self ._devkit_path ,
263+ 'VOC' + self ._year ,
264+ 'ImageSets' ,
265+ 'Main' ,
266+ self ._image_set + '.txt' )
267+ cachedir = os .path .join (self ._devkit_path , 'annotations_cache' )
268+ aps = []
269+ # The PASCAL VOC metric changed in 2010
270+ use_07_metric = True if int (self ._year ) < 2010 else False
271+ print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No' )
272+ if not os .path .isdir (output_dir ):
273+ os .mkdir (output_dir )
274+ for i , cls in enumerate (self ._classes ):
275+ if cls == '__background__' :
276+ continue
277+ filename = self ._get_voc_results_file_template ().format (cls )
278+ rec , prec , ap = voc_eval (
279+ filename , annopath , imagesetfile , cls , cachedir , ovthresh = 0.5 ,
280+ use_07_metric = use_07_metric )
281+ aps += [ap ]
282+ print ('AP for {} = {:.4f}' .format (cls , ap ))
283+ with open (os .path .join (output_dir , cls + '_pr.pkl' ), 'w' ) as f :
284+ cPickle .dump ({'rec' : rec , 'prec' : prec , 'ap' : ap }, f )
285+ print ('Mean AP = {:.4f}' .format (np .mean (aps )))
286+ print ('~~~~~~~~' )
287+ print ('Results:' )
288+ for ap in aps :
289+ print ('{:.3f}' .format (ap ))
290+ print ('{:.3f}' .format (np .mean (aps )))
291+ print ('~~~~~~~~' )
292+
293+ def _do_matlab_eval (self , output_dir = 'output' ):
294+ print '-----------------------------------------------------'
295+ print 'Computing results with the official MATLAB eval code.'
296+ print '-----------------------------------------------------'
247297 path = os .path .join (cfg .ROOT_DIR , 'lib' , 'datasets' ,
248298 'VOCdevkit-matlab-wrapper' )
249299 cmd = 'cd {} && ' .format (path )
250300 cmd += '{:s} -nodisplay -nodesktop ' .format (cfg .MATLAB )
251301 cmd += '-r "dbstop if error; '
252- cmd += 'voc_eval(\' {:s}\' ,\' {:s}\' ,\' {:s}\' ,\' {:s}\' ,{:d} ); quit;"' \
253- .format (self ._devkit_path , comp_id ,
254- self ._image_set , output_dir , int ( rm_results ) )
302+ cmd += 'voc_eval(\' {:s}\' ,\' {:s}\' ,\' {:s}\' ,\' {:s}\' ); quit;"' \
303+ .format (self ._devkit_path , self . _get_comp_id () ,
304+ self ._image_set , output_dir )
255305 print ('Running:\n {}' .format (cmd ))
256306 status = subprocess .call (cmd , shell = True )
257307
258308 def evaluate_detections (self , all_boxes , output_dir ):
259- comp_id = self ._write_voc_results_file (all_boxes )
260- self ._do_matlab_eval (comp_id , output_dir )
309+ self ._write_voc_results_file (all_boxes )
310+ self ._do_python_eval (output_dir )
311+ if self .config ['matlab_eval' ]:
312+ self ._do_matlab_eval (output_dir )
313+ if self .config ['cleanup' ]:
314+ for cls in self ._classes :
315+ if cls == '__background__' :
316+ continue
317+ filename = self ._get_voc_results_file_template ().format (cls )
318+ os .remove (filename )
261319
262320 def competition_mode (self , on ):
263321 if on :
0 commit comments