@@ -39,7 +39,13 @@ def parse_args():
3939 parser = argparse .ArgumentParser (description = 'MMDet test detector' )
4040 parser .add_argument ('config' , help = 'test config file path' )
4141 parser .add_argument ('checkpoint' , help = 'checkpoint file' )
42- parser .add_argument ('--gpus' , default = 1 , type = int )
42+ parser .add_argument (
43+ '--gpus' , default = 1 , type = int , help = 'GPU number used for testing' )
44+ parser .add_argument (
45+ '--proc_per_gpu' ,
46+ default = 1 ,
47+ type = int ,
48+ help = 'Number of processes per GPU' )
4349 parser .add_argument ('--out' , help = 'output result file' )
4450 parser .add_argument (
4551 '--eval' ,
@@ -55,6 +61,9 @@ def parse_args():
5561def main ():
5662 args = parse_args ()
5763
64+ if args .out is not None and not args .out .endswith (('.pkl' , '.pickle' )):
65+ raise ValueError ('The output file must be a pkl file.' )
66+
5867 cfg = mmcv .Config .fromfile (args .config )
5968 cfg .model .pretrained = None
6069 cfg .data .test .test_mode = True
@@ -78,15 +87,27 @@ def main():
7887 model_args = cfg .model .copy ()
7988 model_args .update (train_cfg = None , test_cfg = cfg .test_cfg )
8089 model_type = getattr (detectors , model_args .pop ('type' ))
81- outputs = parallel_test (model_type , model_args , args .checkpoint ,
82- dataset , _data_func , range (args .gpus ))
90+ outputs = parallel_test (
91+ model_type ,
92+ model_args ,
93+ args .checkpoint ,
94+ dataset ,
95+ _data_func ,
96+ range (args .gpus ),
97+ workers_per_gpu = args .proc_per_gpu )
8398
8499 if args .out :
100+ print ('writing results to {}' .format (args .out ))
85101 mmcv .dump (outputs , args .out )
86- if args .eval :
87- json_file = args .out + '.json'
88- results2json (dataset , outputs , json_file )
89- coco_eval (json_file , args .eval , dataset .coco )
102+ eval_types = args .eval
103+ if eval_types :
104+ print ('Starting evaluate {}' .format (' and ' .join (eval_types )))
105+ if eval_types == ['proposal_fast' ]:
106+ result_file = args .out
107+ else :
108+ result_file = args .out + '.json'
109+ results2json (dataset , outputs , result_file )
110+ coco_eval (result_file , eval_types , dataset .coco )
90111
91112
92113if __name__ == '__main__' :
0 commit comments