@@ -205,7 +205,8 @@ def update_args(args):
205205 args .test_add_geo_coords = bool (args .test_add_geo_coords )
206206
207207 # set cuda values
208- if args .gpu >= 0 :
208+ # if args.gpu >= 0:
209+ if args .gpu != "-1" :
209210 args .use_GPU , args .use_CUDNN = 1 , 1
210211 else :
211212 args .use_GPU , args .use_CUDNN = 0 , 0
@@ -661,6 +662,7 @@ def yolt_command(framework='yolt2',
661662 else :
662663 gpu_cmd = '-i ' + str (gpu )
663664 # gpu_cmd = '-i ' + str(3-args.gpu) # originally, numbers were reversed
665+ ngpus = len (gpu .split (',' ))
664666
665667 ##########################
666668 # SET VARIABLES ACCORDING TO MODE (SET UNNECCESSARY VALUES TO 0 OR NULL)
@@ -713,6 +715,7 @@ def yolt_command(framework='yolt2',
713715 str (nbands ),
714716 yolt_loss_file ,
715717 str (min_retain_prob ),
718+ str (ngpus ),
716719 suffix
717720 ]
718721
@@ -1732,8 +1735,9 @@ def main():
17321735 help = "object detection framework [yolt2, 'yolt3', ssd, faster_rcnn]" )
17331736 parser .add_argument ('--mode' , type = str , default = 'test' ,
17341737 help = "[compile, test, train, test]" )
1735- parser .add_argument ('--gpu' , type = int , default = 0 ,
1736- help = "GPU number, set < 0 to turn off GPU support" )
1738+ parser .add_argument ('--gpu' , type = str , default = "0" ,
1739+ help = "GPU number, set < 0 to turn off GPU support " \
1740+ "to use multiple, use '0,1'" )
17371741 parser .add_argument ('--single_gpu_machine' , type = int , default = 0 ,
17381742 help = "Switch to use a machine with just one gpu" )
17391743 parser .add_argument ('--nbands' , type = int , default = 3 ,
0 commit comments