11# coding=utf-8
22from __future__ import print_function
3+
4+ import logging
35import os
46
57os .environ ["CUDA_VISIBLE_DEVICES" ] = "0"
2426parser .add_argument ('--ngpu' , default = 1 , type = int , help = 'gpus' )
2527parser .add_argument ('--lr' , '--learning-rate' , default = 1e-3 , type = float , help = 'initial learning rate' )
2628parser .add_argument ('--momentum' , default = 0.9 , type = float , help = 'momentum' )
27- parser .add_argument ('--resume_net' , default = None , help = 'resume net for retraining' )
29+ parser .add_argument ('--resume_net' , default = "./weights/FaceBoxes_epoch_295.pth" , help = 'resume net for retraining' )
2830parser .add_argument ('--resume_epoch' , default = 0 , type = int , help = 'resume iter for retraining' )
2931parser .add_argument ('-max' , '--max_epoch' , default = 300 , type = int , help = 'max epoch for retraining' )
3032parser .add_argument ('--weight_decay' , default = 5e-4 , type = float , help = 'Weight decay for SGD' )
7173 cudnn .benchmark = True
7274
7375optimizer = optim .SGD (net .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
74- criterion = MultiBoxLoss (num_classes , 0.35 , True , 0 , True , 7 , 0.35 , False )
76+ criterion = MultiBoxLoss (num_classes , 0.35 , True , 0 , True , 3 , 0.35 , False )
7577
7678priorbox = PriorBox (cfg )
7779with torch .no_grad ():
8183
8284
8385def train ():
86+ prefix = time .strftime ("%Y-%m-%d-%H:%M:%S" )
87+ file_path = "models_{}" .format (prefix )
88+ if not os .path .exists (file_path ):
89+ os .mkdir (file_path )
90+ logging .basicConfig ()
91+ logging .getLogger ().setLevel (logging .INFO )
92+ fh = logging .FileHandler ("{}/train.log" .format (file_path ))
93+ # create formatter#
94+ formatter = logging .Formatter ("%(asctime)s - %(levelname)s - %(message)s" )
95+ # add formatter to ch
96+ fh .setFormatter (formatter )
97+ logging .getLogger ().addHandler (fh )
98+
8499 net .train ()
85100 epoch = 0 + args .resume_epoch
86- print ('Loading Dataset...' )
101+ logging . info ('Loading Dataset...' )
87102
88103 args .training_dataset = os .path .expanduser (args .training_dataset )
89104 dataset = VOCDetection (args .training_dataset , preproc (img_dim , rgb_means ), AnnotationTransform ())
90105
91- print ("len(dataset):" , len (dataset ))
106+ logging . info ("len(dataset): %s " , len (dataset ))
92107 epoch_size = int (math .ceil (len (dataset ) / args .batch_size ))
93- print ("epoch_size:" , epoch_size )
108+ logging . info ("epoch_size: %s " , epoch_size )
94109 max_iter = args .max_epoch * epoch_size
95110
96111 stepvalues = (200 * epoch_size , 250 * epoch_size )
@@ -106,7 +121,7 @@ def train():
106121 # create batch iterator
107122 batch_iterator = iter (data .DataLoader (dataset , batch_size , shuffle = True , num_workers = args .num_workers , collate_fn = detection_collate ))
108123 if (epoch % 10 == 0 and epoch > 0 ) or (epoch % 5 == 0 and epoch > 200 ):
109- torch .save (net .state_dict (), args . save_folder + 'FaceBoxes_epoch_' + repr (epoch ) + '.pth' )
124+ torch .save (net .state_dict (), file_path + 'FaceBoxes_epoch_' + repr (epoch ) + '.pth' )
110125 epoch += 1
111126
112127 load_t0 = time .time ()
@@ -133,10 +148,12 @@ def train():
133148 loss .backward ()
134149 optimizer .step ()
135150 load_t1 = time .time ()
136- print ('Epoch:' + repr (epoch ) + ' || epochiter: ' + repr (iteration % epoch_size ) + '/' + repr (epoch_size ) +
137- '|| Totel iter ' + repr (iteration ) + ' || L: %.4f C: %.4f||' % (cfg ['loc_weight' ] * loss_l .item (), loss_c .item ()) +
138- 'Batch time: %.4f sec. ||' % (load_t1 - load_t0 ) + 'LR: %.8f' % (lr ))
151+ # logging.info ('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) +
152+ # '|| Totel iter ' + repr(iteration) + ' || L: %.4f C: %.4f||' % (cfg['loc_weight'] * loss_l.item(), loss_c.item()) +
153+ # 'Batch time: %.4f sec. ||' % (load_t1 - load_t0) + 'LR: %.8f' % (lr))
139154
155+ logging .info ("epoch %s epochiter %s epoch_size %s iteration %s loss %s l_loss %s c_loss %s" ,
156+ epoch , iteration % epoch_size , epoch_size , iteration , loss .item (), loss_l .item (), loss_c .item ())
140157 torch .save (net .state_dict (), args .save_folder + 'Final_FaceBoxes.pth' )
141158
142159
0 commit comments