@@ -186,9 +186,8 @@ def trainEpoch(epoch):
186186 # shuffle mini batch order
187187 batchOrder = torch .randperm (len (trainData ))
188188
189- total_loss , report_loss = 0 , 0
190- total_words , report_tgt_words , report_src_words = 0 , 0 , 0
191- total_num_correct = 0
189+ total_loss , total_words , total_num_correct = 0
190+ report_loss , report_tgt_words , report_src_words , report_num_correct = 0
192191 start = time .time ()
193192 for i in range (len (trainData )):
194193
@@ -206,23 +205,24 @@ def trainEpoch(epoch):
206205 # update the parameters
207206 optim .step ()
208207
209- report_loss += loss
210- total_num_correct += num_correct
211- total_loss += loss
212208 num_words = targets .data .ne (onmt .Constants .PAD ).sum ()
213- total_words += num_words
209+ report_loss += loss
210+ report_num_correct += num_correct
214211 report_tgt_words += num_words
215212 report_src_words += batch [0 ].data .ne (onmt .Constants .PAD ).sum ()
213+ total_loss += loss
214+ total_num_correct += num_correct
215+ total_words += num_words
216216 if i % opt .log_interval == - 1 % opt .log_interval :
217217 print ("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
218218 (epoch , i , len (trainData ),
219- num_correct / num_words * 100 ,
219+ report_num_correct / report_tgt_words * 100 ,
220220 math .exp (report_loss / report_tgt_words ),
221221 report_src_words / (time .time ()- start ),
222222 report_tgt_words / (time .time ()- start ),
223223 time .time ()- start_time ))
224224
225- report_loss = report_tgt_words = report_src_words = 0
225+ report_loss = report_tgt_words = report_src_words = report_num_correct = 0
226226 start = time .time ()
227227
228228 return total_loss / total_words , total_num_correct / total_words
0 commit comments