Skip to content

Commit d13997c

Browse files
authored
Merge pull request open-mmlab#12 from hellock/dev
Bug fix for recall evaluation
2 parents 14a7dfb + 7c2b814 commit d13997c

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

mmdet/core/evaluation/coco_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
1616
coco = COCO(coco)
1717
assert isinstance(coco, COCO)
1818

19-
if res_type == 'proposal_fast':
20-
ar = fast_eval_recall(result_file, coco, max_dets)
19+
if result_types == ['proposal_fast']:
20+
ar = fast_eval_recall(result_file, coco, np.array(max_dets))
2121
for i, num in enumerate(max_dets):
2222
print('AR@{}\t= {:.4f}'.format(num, ar[i]))
2323
return

tools/test.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@ def parse_args():
3939
parser = argparse.ArgumentParser(description='MMDet test detector')
4040
parser.add_argument('config', help='test config file path')
4141
parser.add_argument('checkpoint', help='checkpoint file')
42-
parser.add_argument('--gpus', default=1, type=int)
42+
parser.add_argument(
43+
'--gpus', default=1, type=int, help='GPU number used for testing')
44+
parser.add_argument(
45+
'--proc_per_gpu',
46+
default=1,
47+
type=int,
48+
help='Number of processes per GPU')
4349
parser.add_argument('--out', help='output result file')
4450
parser.add_argument(
4551
'--eval',
@@ -55,6 +61,9 @@ def parse_args():
5561
def main():
5662
args = parse_args()
5763

64+
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
65+
raise ValueError('The output file must be a pkl file.')
66+
5867
cfg = mmcv.Config.fromfile(args.config)
5968
cfg.model.pretrained = None
6069
cfg.data.test.test_mode = True
@@ -78,15 +87,27 @@ def main():
7887
model_args = cfg.model.copy()
7988
model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
8089
model_type = getattr(detectors, model_args.pop('type'))
81-
outputs = parallel_test(model_type, model_args, args.checkpoint,
82-
dataset, _data_func, range(args.gpus))
90+
outputs = parallel_test(
91+
model_type,
92+
model_args,
93+
args.checkpoint,
94+
dataset,
95+
_data_func,
96+
range(args.gpus),
97+
workers_per_gpu=args.proc_per_gpu)
8398

8499
if args.out:
100+
print('writing results to {}'.format(args.out))
85101
mmcv.dump(outputs, args.out)
86-
if args.eval:
87-
json_file = args.out + '.json'
88-
results2json(dataset, outputs, json_file)
89-
coco_eval(json_file, args.eval, dataset.coco)
102+
eval_types = args.eval
103+
if eval_types:
104+
print('Starting evaluate {}'.format(' and '.join(eval_types)))
105+
if eval_types == ['proposal_fast']:
106+
result_file = args.out
107+
else:
108+
result_file = args.out + '.json'
109+
results2json(dataset, outputs, result_file)
110+
coco_eval(result_file, eval_types, dataset.coco)
90111

91112

92113
if __name__ == '__main__':

0 commit comments

Comments (0)