@@ -206,38 +206,40 @@ def Train(self):
206206 Train network in n epochs, n is defined in params.num_epoch
207207 """
208208 self .init_epoch = self .epoch
209- assert self .epoch <= self .params .num_epoch , 'Current epoch should not be larger than num_epoch'
210- for _ in range (self .epoch , self .params .num_epoch ):
211- self .epoch += 1
212- print ('-' * 20 + 'Epoch.' + str (self .epoch ) + '-' * 20 )
209+ if self .epoch >= self .params .num_epoch :
210+ WARNING ('Num_epoch should be smaller than current epoch. Skip training......\n ' )
211+ else :
212+ for _ in range (self .epoch , self .params .num_epoch ):
213+ self .epoch += 1
214+ print ('-' * 20 + 'Epoch.' + str (self .epoch ) + '-' * 20 )
213215
214- # train one epoch
215- self .train_one_epoch ()
216+ # train one epoch
217+ self .train_one_epoch ()
216218
217- # should display
218- if self .epoch % self .params .display == 0 :
219- print ('\t Train loss: %.4f' % self .train_loss [- 1 ])
219+ # should display
220+ if self .epoch % self .params .display == 0 :
221+ print ('\t Train loss: %.4f' % self .train_loss [- 1 ])
220222
221- # should save
222- if self .params .should_save :
223- if self .epoch % self .params .save_every == 0 :
224- self .save_checkpoint ()
223+ # should save
224+ if self .params .should_save :
225+ if self .epoch % self .params .save_every == 0 :
226+ self .save_checkpoint ()
225227
226- # test every params.test_every epoch
227- if self .params .should_val :
228- if self .epoch % self .params .val_every == 0 :
229- self .val_one_epoch ()
230- print ('\t Val loss: %.4f' % self .val_loss [- 1 ])
228+ # test every params.test_every epoch
229+ if self .params .should_val :
230+ if self .epoch % self .params .val_every == 0 :
231+ self .val_one_epoch ()
232+ print ('\t Val loss: %.4f' % self .val_loss [- 1 ])
231233
232- # adjust learning rate
233- self .adjust_lr ()
234+ # adjust learning rate
235+ self .adjust_lr ()
234236
235- # save the last network state
236- if self .params .should_save :
237- self .save_checkpoint ()
237+ # save the last network state
238+ if self .params .should_save :
239+ self .save_checkpoint ()
238240
239- # train visualization
240- self .plot_curve ()
241+ # train visualization
242+ self .plot_curve ()
241243
242244 def Test (self ):
243245 """
@@ -276,8 +278,8 @@ def Test(self):
276278 image_orig = image [i ].numpy ().transpose (1 , 2 , 0 )
277279 image_orig = image_orig * 255
278280 image_orig = image_orig .astype (np .uint8 )
279- self .summary_writer .add_image ('img_%d/orig' % idx , image_orig , idx )
280- self .summary_writer .add_image ('img_%d/seg' % idx , color_map , idx )
281+ self .summary_writer .add_image ('test/ img_%d/orig' % idx , image_orig , idx )
282+ self .summary_writer .add_image ('test/ img_%d/seg' % idx , color_map , idx )
281283
282284 """##########################"""
283285 """# Model Save and Restore #"""
@@ -367,9 +369,11 @@ def plot_curve(self):
367369 """
368370 Plot train/val loss curve
369371 """
370- x = np .arange (self .init_epoch , self .params .num_epoch + 1 , dtype = np .int ).tolist ()
371- plt .plot (x , self .train_loss , label = 'train_loss' )
372- plt .plot (x , self .val_loss , label = 'val_loss' )
372+ x1 = np .arange (self .init_epoch , self .params .num_epoch + 1 , dtype = np .int ).tolist ()
373+ x2 = np .linspace (self .init_epoch , self .epoch ,
374+ num = (self .epoch - self .init_epoch )// self .params .val_every + 1 , dtype = np .int64 )
375+ plt .plot (x1 , self .train_loss , label = 'train_loss' )
376+ plt .plot (x2 , self .val_loss , label = 'val_loss' )
373377 plt .legend (loc = 'best' )
374378 plt .title ('Train/Val loss' )
375379 plt .grid ()
0 commit comments