@@ -729,8 +729,9 @@ def tpu_scaffold():
729729 tf .logging .info (" name = %s, shape = %s%s" , var .name , var .shape ,
730730 init_string )
731731
732+ is_multi_gpu = use_gpu and int (num_gpu_cores ) >= 2
732733 if mode == tf .estimator .ModeKeys .TRAIN :
733- if use_gpu and int ( num_gpu_cores ) >= 2 :
734+ if is_multi_gpu :
734735 train_op = custom_optimization .create_optimizer (
735736 total_loss , learning_rate , num_train_steps , num_warmup_steps , fp16 = fp16 )
736737 output_spec = tf .estimator .EstimatorSpec (
@@ -777,11 +778,15 @@ def metric_fn(per_example_loss, label_ids, logits, is_real_example):
777778 eval_metrics = eval_metrics ,
778779 scaffold_fn = scaffold_fn )
779780 else :
780- # predict on single-gpu only
781- output_spec = tf .contrib .tpu .TPUEstimatorSpec (
782- mode = mode ,
783- predictions = {"probabilities" : probabilities },
784- scaffold_fn = scaffold_fn )
781+ if is_multi_gpu :
782+ output_spec = tf .estimator .EstimatorSpec (
783+ mode = mode ,
784+ predictions = {"probabilities" : probabilities })
785+ else :
786+ output_spec = tf .contrib .tpu .TPUEstimatorSpec (
787+ mode = mode ,
788+ predictions = {"probabilities" : probabilities },
789+ scaffold_fn = scaffold_fn )
785790
786791 return output_spec
787792
@@ -947,45 +952,47 @@ def main(_):
947952 num_shards = FLAGS .num_tpu_cores ,
948953 per_host_input_for_training = is_per_host ))
949954
950- num_train_steps = 0
951- num_warmup_steps = 0
952- init_checkpoint = FLAGS .init_checkpoint
953- is_multi_gpu = FLAGS .use_gpu and int (FLAGS .num_gpu_cores ) >= 2
955+ train_examples = None
956+ num_train_steps = None
957+ num_warmup_steps = None
954958 if FLAGS .do_train :
955959 train_examples = processor .get_train_examples (FLAGS .data_dir )
956960 num_train_steps = int (
957961 len (train_examples ) / FLAGS .train_batch_size * FLAGS .num_train_epochs )
958962 num_warmup_steps = int (num_train_steps * FLAGS .warmup_proportion )
959963
960- model_fn = model_fn_builder (
961- bert_config = bert_config ,
962- num_labels = len (label_list ),
963- init_checkpoint = init_checkpoint ,
964- learning_rate = FLAGS .learning_rate ,
965- num_train_steps = num_train_steps ,
966- num_warmup_steps = num_warmup_steps ,
967- use_tpu = FLAGS .use_tpu ,
968- use_one_hot_embeddings = FLAGS .use_tpu ,
969- use_gpu = FLAGS .use_gpu ,
970- num_gpu_cores = FLAGS .num_gpu_cores ,
971- fp16 = FLAGS .use_fp16 )
964+ init_checkpoint = FLAGS .init_checkpoint
965+ is_multi_gpu = FLAGS .use_gpu and int (FLAGS .num_gpu_cores ) >= 2
966+ model_fn = model_fn_builder (
967+ bert_config = bert_config ,
968+ num_labels = len (label_list ),
969+ init_checkpoint = init_checkpoint ,
970+ learning_rate = FLAGS .learning_rate ,
971+ num_train_steps = num_train_steps ,
972+ num_warmup_steps = num_warmup_steps ,
973+ use_tpu = FLAGS .use_tpu ,
974+ use_one_hot_embeddings = FLAGS .use_tpu ,
975+ use_gpu = FLAGS .use_gpu ,
976+ num_gpu_cores = FLAGS .num_gpu_cores ,
977+ fp16 = FLAGS .use_fp16 )
972978
973- # If TPU is not available, this will fall back to normal Estimator on CPU
974- # or GPU.
975- if is_multi_gpu :
976- estimator = Estimator (
977- model_fn = model_fn ,
978- params = {},
979- config = dist_run_config )
980- else :
981- estimator = tf .contrib .tpu .TPUEstimator (
982- use_tpu = FLAGS .use_tpu ,
983- model_fn = model_fn ,
984- config = tpu_run_config ,
985- train_batch_size = FLAGS .train_batch_size ,
986- eval_batch_size = FLAGS .eval_batch_size ,
987- predict_batch_size = FLAGS .predict_batch_size )
979+ # If TPU is not available, this will fall back to normal Estimator on CPU
980+ # or GPU.
981+ if is_multi_gpu :
982+ estimator = Estimator (
983+ model_fn = model_fn ,
984+ params = {},
985+ config = dist_run_config )
986+ else :
987+ estimator = tf .contrib .tpu .TPUEstimator (
988+ use_tpu = FLAGS .use_tpu ,
989+ model_fn = model_fn ,
990+ config = tpu_run_config ,
991+ train_batch_size = FLAGS .train_batch_size ,
992+ eval_batch_size = FLAGS .eval_batch_size ,
993+ predict_batch_size = FLAGS .predict_batch_size )
988994
995+ if FLAGS .do_train :
989996 train_file = os .path .join (FLAGS .output_dir , "train.tf_record" )
990997 file_based_convert_examples_to_features (
991998 train_examples , label_list , FLAGS .max_seq_length , tokenizer , train_file )
@@ -1011,32 +1018,30 @@ def main(_):
10111018 if filename .startswith ('model.ckpt-' ):
10121019 max_idx = max (int (filename .split ('.' )[1 ].split ('-' )[1 ]), max_idx )
10131020 init_checkpoint = os .path .join (FLAGS .output_dir , f'model.ckpt-{ max_idx } ' )
1021+ tf .logging .info (f'Current checkpoint: { init_checkpoint } ' )
10141022
1015- if not FLAGS .do_eval and not FLAGS .do_predict :
1016- return
1017-
1018- model_fn = model_fn_builder (
1019- bert_config = bert_config ,
1020- num_labels = len (label_list ),
1021- init_checkpoint = init_checkpoint ,
1022- learning_rate = FLAGS .learning_rate ,
1023- num_train_steps = num_train_steps ,
1024- num_warmup_steps = num_warmup_steps ,
1025- use_tpu = FLAGS .use_tpu ,
1026- use_one_hot_embeddings = FLAGS .use_tpu ,
1027- use_gpu = FLAGS .use_gpu ,
1028- num_gpu_cores = FLAGS .num_gpu_cores ,
1029- fp16 = FLAGS .use_fp16 )
1023+ if FLAGS .do_eval :
1024+ model_fn = model_fn_builder (
1025+ bert_config = bert_config ,
1026+ num_labels = len (label_list ),
1027+ init_checkpoint = init_checkpoint ,
1028+ learning_rate = FLAGS .learning_rate ,
1029+ num_train_steps = num_train_steps ,
1030+ num_warmup_steps = num_warmup_steps ,
1031+ use_tpu = FLAGS .use_tpu ,
1032+ use_one_hot_embeddings = FLAGS .use_tpu ,
1033+ use_gpu = FLAGS .use_gpu ,
1034+ num_gpu_cores = FLAGS .num_gpu_cores ,
1035+ fp16 = FLAGS .use_fp16 )
10301036
1031- estimator = tf .contrib .tpu .TPUEstimator (
1032- use_tpu = FLAGS .use_tpu ,
1033- model_fn = model_fn ,
1034- config = tpu_run_config ,
1035- train_batch_size = FLAGS .train_batch_size ,
1036- eval_batch_size = FLAGS .eval_batch_size ,
1037- predict_batch_size = FLAGS .predict_batch_size )
1037+ eval_estimator = tf .contrib .tpu .TPUEstimator (
1038+ use_tpu = FLAGS .use_tpu ,
1039+ model_fn = model_fn ,
1040+ config = tpu_run_config ,
1041+ train_batch_size = FLAGS .train_batch_size ,
1042+ eval_batch_size = FLAGS .eval_batch_size ,
1043+ predict_batch_size = FLAGS .predict_batch_size )
10381044
1039- if FLAGS .do_eval :
10401045 eval_examples = processor .get_dev_examples (FLAGS .data_dir )
10411046 num_actual_eval_examples = len (eval_examples )
10421047 if FLAGS .use_tpu :
@@ -1074,7 +1079,7 @@ def main(_):
10741079 drop_remainder = eval_drop_remainder ,
10751080 batch_size = FLAGS .eval_batch_size )
10761081
1077- result = estimator .evaluate (input_fn = eval_input_fn , steps = eval_steps )
1082+ result = eval_estimator .evaluate (input_fn = eval_input_fn , steps = eval_steps )
10781083
10791084 output_eval_file = os .path .join (FLAGS .output_dir , "eval_results.txt" )
10801085 with tf .gfile .GFile (output_eval_file , "w" ) as writer :