2727import random
2828import shutil
2929
30+ import paddle
3031import numpy as np
3132import paddle .fluid as fluid
3233from paddle .fluid import profiler
@@ -158,6 +159,15 @@ def load_checkpoint(exe, program):
158159 return begin_epoch
159160
160161
162+ def save_infer_program (test_program , ckpt_dir ):
163+ _test_program = test_program .clone ()
164+ _test_program .desc .flush ()
165+ _test_program .desc ._set_version ()
166+ paddle .fluid .core .save_op_compatible_info (_test_program .desc )
167+ with open (os .path .join (ckpt_dir , 'model' ) + ".pdmodel" , "wb" ) as f :
168+ f .write (_test_program .desc .serialize_to_string ())
169+
170+
161171def update_best_model (ckpt_dir ):
162172 best_model_dir = os .path .join (cfg .TRAIN .MODEL_SAVE_DIR , 'best_model' )
163173 if os .path .exists (best_model_dir ):
@@ -173,6 +183,7 @@ def print_info(*msg):
173183def train (cfg ):
174184 startup_prog = fluid .Program ()
175185 train_prog = fluid .Program ()
186+ test_prog = fluid .Program ()
176187 if args .enable_ce :
177188 startup_prog .random_seed = 1000
178189 train_prog .random_seed = 1000
@@ -224,6 +235,7 @@ def data_generator():
224235
225236 data_loader , avg_loss , lr , pred , grts , masks = build_model (
226237 train_prog , startup_prog , phase = ModelPhase .TRAIN )
238+ build_model (test_prog , fluid .Program (), phase = ModelPhase .EVAL )
227239 data_loader .set_sample_generator (
228240 data_generator , batch_size = batch_size_per_dev , drop_last = drop_last )
229241
@@ -387,6 +399,7 @@ def data_generator():
387399 if (epoch % cfg .TRAIN .SNAPSHOT_EPOCH == 0
388400 or epoch == cfg .SOLVER .NUM_EPOCHS ) and cfg .TRAINER_ID == 0 :
389401 ckpt_dir = save_checkpoint (train_prog , epoch )
402+ save_infer_program (test_prog , ckpt_dir )
390403
391404 if args .do_eval :
392405 print ("Evaluation start" )
@@ -419,7 +432,8 @@ def data_generator():
419432
420433 # save final model
421434 if cfg .TRAINER_ID == 0 :
422- save_checkpoint (train_prog , 'final' )
435+ ckpt_dir = save_checkpoint (train_prog , 'final' )
436+ save_infer_program (test_prog , ckpt_dir )
423437
424438
425439def main (args ):
0 commit comments