2020
2121import collections
2222import csv
23+ import json
2324import os
25+
26+ import numpy as np
27+ import tensorflow as tf
28+ from tensorflow .python .distribute .cross_device_ops import AllReduceCrossDeviceOps
29+ from tensorflow .python .estimator .estimator import Estimator
30+ from tensorflow .python .estimator .run_config import RunConfig
31+
32+ import custom_optimization
2433import modeling
2534import optimization
26- import custom_optimization
2735import tokenization
28- from tensorflow .python .distribute .cross_device_ops import AllReduceCrossDeviceOps
29- import tensorflow as tf
30- from tensorflow .python .estimator .run_config import RunConfig
31- from tensorflow .python .estimator .estimator import Estimator
3236
3337flags = tf .flags
3438
@@ -749,35 +753,35 @@ def metric_fn(per_example_loss, label_ids, logits, is_real_example):
749753 predictions = tf .argmax (logits , axis = - 1 , output_type = tf .int32 )
750754 accuracy = tf .metrics .accuracy (
751755 labels = label_ids , predictions = predictions , weights = is_real_example )
756+ # add more metrics
757+ pr , pr_op = tf .metrics .precision (
758+ labels = label_ids , predictions = predictions , weights = is_real_example )
759+ re , re_op = tf .metrics .recall (
760+ labels = label_ids , predictions = predictions , weights = is_real_example )
761+ f1 = (2 * pr * re ) / (pr + re ) # f1-score for binary classification
752762 loss = tf .metrics .mean (values = per_example_loss , weights = is_real_example )
753763 return {
754764 "eval_accuracy" : accuracy ,
755- "eval_loss" : loss ,
765+ "eval_precision" : (pr , pr_op ),
766+ "eval_recall" : (re , re_op ),
767+ "eval_f1score" : (f1 , tf .identity (f1 )),
768+ "eval_loss" : loss
756769 }
757770
758771 eval_metrics = (metric_fn ,
759772 [per_example_loss , label_ids , logits , is_real_example ])
760- if use_gpu and int (num_gpu_cores ) >= 2 :
761- output_spec = tf .estimator .EstimatorSpec (
762- mode = mode ,
763- loss = total_loss ,
764- eval_metric_ops = eval_metrics [0 ](* eval_metrics [1 ]))
765- else :
766- output_spec = tf .contrib .tpu .TPUEstimatorSpec (
767- mode = mode ,
768- loss = total_loss ,
769- eval_metrics = eval_metrics ,
770- scaffold_fn = scaffold_fn )
773+ # eval on single-gpu only
774+ output_spec = tf .contrib .tpu .TPUEstimatorSpec (
775+ mode = mode ,
776+ loss = total_loss ,
777+ eval_metrics = eval_metrics ,
778+ scaffold_fn = scaffold_fn )
771779 else :
772- if use_gpu and int (num_gpu_cores ) >= 2 :
773- output_spec = tf .estimator .EstimatorSpec (
774- mode = mode ,
775- predictions = {"probabilities" : probabilities })
776- else :
777- output_spec = tf .contrib .tpu .TPUEstimatorSpec (
778- mode = mode ,
779- predictions = {"probabilities" : probabilities },
780- scaffold_fn = scaffold_fn )
780+ # predict on single-gpu only
781+ output_spec = tf .contrib .tpu .TPUEstimatorSpec (
782+ mode = mode ,
783+ predictions = {"probabilities" : probabilities },
784+ scaffold_fn = scaffold_fn )
781785
782786 return output_spec
783787
@@ -918,76 +922,70 @@ def main(_):
918922 FLAGS .tpu_name , zone = FLAGS .tpu_zone , project = FLAGS .gcp_project )
919923
920924 is_per_host = tf .contrib .tpu .InputPipelineConfig .PER_HOST_V2
921- if FLAGS .use_gpu and int (FLAGS .num_gpu_cores ) >= 2 :
922- tf .logging .info ("Use normal RunConfig" )
923- # https://github.com/tensorflow/tensorflow/issues/21470#issuecomment-422506263
924- dist_strategy = tf .contrib .distribute .MirroredStrategy (
925- num_gpus = FLAGS .num_gpu_cores ,
926- cross_device_ops = AllReduceCrossDeviceOps ('nccl' , num_packs = FLAGS .num_gpu_cores ),
927- # cross_device_ops=AllReduceCrossDeviceOps('hierarchical_copy'),
928- )
929- log_every_n_steps = 8
930- run_config = RunConfig (
931- train_distribute = dist_strategy ,
932- eval_distribute = dist_strategy ,
933- log_step_count_steps = log_every_n_steps ,
934- model_dir = FLAGS .output_dir ,
935- save_checkpoints_steps = FLAGS .save_checkpoints_steps )
936- else :
937- tf .logging .info ("Use TPURunConfig" )
938- run_config = tf .contrib .tpu .RunConfig (
939- cluster = tpu_cluster_resolver ,
940- master = FLAGS .master ,
941- model_dir = FLAGS .output_dir ,
942- save_checkpoints_steps = FLAGS .save_checkpoints_steps ,
943- tpu_config = tf .contrib .tpu .TPUConfig (
944- iterations_per_loop = FLAGS .iterations_per_loop ,
945- num_shards = FLAGS .num_tpu_cores ,
946- per_host_input_for_training = is_per_host ))
947-
948- train_examples = None
949- num_train_steps = None
950- num_warmup_steps = None
925+
926+ # https://github.com/tensorflow/tensorflow/issues/21470#issuecomment-422506263
927+ dist_strategy = tf .contrib .distribute .MirroredStrategy (
928+ num_gpus = FLAGS .num_gpu_cores ,
929+ cross_device_ops = AllReduceCrossDeviceOps ('nccl' , num_packs = FLAGS .num_gpu_cores ),
930+ # cross_device_ops=AllReduceCrossDeviceOps('hierarchical_copy'),
931+ )
932+ log_every_n_steps = 8
933+ dist_run_config = RunConfig (
934+ train_distribute = dist_strategy ,
935+ eval_distribute = dist_strategy ,
936+ log_step_count_steps = log_every_n_steps ,
937+ model_dir = FLAGS .output_dir ,
938+ save_checkpoints_steps = FLAGS .save_checkpoints_steps )
939+
940+ tpu_run_config = tf .contrib .tpu .RunConfig (
941+ cluster = tpu_cluster_resolver ,
942+ master = FLAGS .master ,
943+ model_dir = FLAGS .output_dir ,
944+ save_checkpoints_steps = FLAGS .save_checkpoints_steps ,
945+ tpu_config = tf .contrib .tpu .TPUConfig (
946+ iterations_per_loop = FLAGS .iterations_per_loop ,
947+ num_shards = FLAGS .num_tpu_cores ,
948+ per_host_input_for_training = is_per_host ))
949+
950+ num_train_steps = 0
951+ num_warmup_steps = 0
952+ init_checkpoint = FLAGS .init_checkpoint
953+ is_multi_gpu = FLAGS .use_gpu and int (FLAGS .num_gpu_cores ) >= 2
951954 if FLAGS .do_train :
952955 train_examples = processor .get_train_examples (FLAGS .data_dir )
953956 num_train_steps = int (
954957 len (train_examples ) / FLAGS .train_batch_size * FLAGS .num_train_epochs )
955958 num_warmup_steps = int (num_train_steps * FLAGS .warmup_proportion )
956959
957- init_checkpoint = FLAGS .init_checkpoint
958-
959- model_fn = model_fn_builder (
960- bert_config = bert_config ,
961- num_labels = len (label_list ),
962- init_checkpoint = init_checkpoint ,
963- learning_rate = FLAGS .learning_rate ,
964- num_train_steps = num_train_steps ,
965- num_warmup_steps = num_warmup_steps ,
966- use_tpu = FLAGS .use_tpu ,
967- use_one_hot_embeddings = FLAGS .use_tpu ,
968- use_gpu = FLAGS .use_gpu ,
969- num_gpu_cores = FLAGS .num_gpu_cores ,
970- fp16 = FLAGS .use_fp16 )
971-
972- # If TPU is not available, this will fall back to normal Estimator on CPU
973- # or GPU.
974- if FLAGS .use_gpu and int (FLAGS .num_gpu_cores ) >= 2 :
975- tf .logging .info ("Use normal Estimator" )
976- estimator = Estimator (
977- model_fn = model_fn ,
978- params = {},
979- config = run_config )
980- else :
981- tf .logging .info ("Use TPUEstimator" )
982- estimator = tf .contrib .tpu .TPUEstimator (
960+ model_fn = model_fn_builder (
961+ bert_config = bert_config ,
962+ num_labels = len (label_list ),
963+ init_checkpoint = init_checkpoint ,
964+ learning_rate = FLAGS .learning_rate ,
965+ num_train_steps = num_train_steps ,
966+ num_warmup_steps = num_warmup_steps ,
983967 use_tpu = FLAGS .use_tpu ,
984- model_fn = model_fn ,
985- config = run_config ,
986- train_batch_size = FLAGS .train_batch_size ,
987- eval_batch_size = FLAGS .eval_batch_size ,
988- predict_batch_size = FLAGS .predict_batch_size )
968+ use_one_hot_embeddings = FLAGS .use_tpu ,
969+ use_gpu = FLAGS .use_gpu ,
970+ num_gpu_cores = FLAGS .num_gpu_cores ,
971+ fp16 = FLAGS .use_fp16 )
972+
973+ # If TPU is not available, this will fall back to normal Estimator on CPU
974+ # or GPU.
975+ if is_multi_gpu :
976+ estimator = Estimator (
977+ model_fn = model_fn ,
978+ params = {},
979+ config = dist_run_config )
980+ else :
981+ estimator = tf .contrib .tpu .TPUEstimator (
982+ use_tpu = FLAGS .use_tpu ,
983+ model_fn = model_fn ,
984+ config = tpu_run_config ,
985+ train_batch_size = FLAGS .train_batch_size ,
986+ eval_batch_size = FLAGS .eval_batch_size ,
987+ predict_batch_size = FLAGS .predict_batch_size )
989988
990- if FLAGS .do_train :
991989 train_file = os .path .join (FLAGS .output_dir , "train.tf_record" )
992990 file_based_convert_examples_to_features (
993991 train_examples , label_list , FLAGS .max_seq_length , tokenizer , train_file )
@@ -1002,6 +1000,41 @@ def main(_):
10021000 drop_remainder = True ,
10031001 batch_size = FLAGS .train_batch_size )
10041002 estimator .train (input_fn = train_input_fn , max_steps = num_train_steps )
1003+ # TF Serving
1004+ if FLAGS .save_for_serving :
1005+ serving_dir = os .path .join (FLAGS .output_dir , 'serving' )
1006+ save_for_serving (estimator , serving_dir , FLAGS .max_seq_length , not is_multi_gpu )
1007+
1008+ # Find the latest checkpoint
1009+ max_idx = 0
1010+ for filename in os .listdir (FLAGS .output_dir ):
1011+ if filename .startswith ('model.ckpt-' ):
1012+ max_idx = max (int (filename .split ('.' )[1 ].split ('-' )[1 ]), max_idx )
1013+ init_checkpoint = os .path .join (FLAGS .output_dir , f'model.ckpt-{ max_idx } ' )
1014+
1015+ if not FLAGS .do_eval and not FLAGS .do_predict :
1016+ return
1017+
1018+ model_fn = model_fn_builder (
1019+ bert_config = bert_config ,
1020+ num_labels = len (label_list ),
1021+ init_checkpoint = init_checkpoint ,
1022+ learning_rate = FLAGS .learning_rate ,
1023+ num_train_steps = num_train_steps ,
1024+ num_warmup_steps = num_warmup_steps ,
1025+ use_tpu = FLAGS .use_tpu ,
1026+ use_one_hot_embeddings = FLAGS .use_tpu ,
1027+ use_gpu = FLAGS .use_gpu ,
1028+ num_gpu_cores = FLAGS .num_gpu_cores ,
1029+ fp16 = FLAGS .use_fp16 )
1030+
1031+ estimator = tf .contrib .tpu .TPUEstimator (
1032+ use_tpu = FLAGS .use_tpu ,
1033+ model_fn = model_fn ,
1034+ config = tpu_run_config ,
1035+ train_batch_size = FLAGS .train_batch_size ,
1036+ eval_batch_size = FLAGS .eval_batch_size ,
1037+ predict_batch_size = FLAGS .predict_batch_size )
10051038
10061039 if FLAGS .do_eval :
10071040 eval_examples = processor .get_dev_examples (FLAGS .data_dir )
@@ -1050,6 +1083,22 @@ def main(_):
10501083 tf .logging .info (" %s = %s" , key , str (result [key ]))
10511084 writer .write ("%s = %s\n " % (key , str (result [key ])))
10521085
1086+ # dump result as json file (easy parsing for other tasks)
1087+ class ExtEncoder (json .JSONEncoder ):
1088+ def default (self , obj ):
1089+ if isinstance (obj , np .integer ):
1090+ return int (obj )
1091+ if isinstance (obj , np .floating ):
1092+ return float (obj )
1093+ if isinstance (obj , np .ndarray ):
1094+ return obj .tolist ()
1095+ else :
1096+ return super (ExtEncoder , self ).default (obj )
1097+
1098+ output_eval_file2 = os .path .join (FLAGS .output_dir , "eval_results.json" )
1099+ with tf .gfile .GFile (output_eval_file2 , "w" ) as writer :
1100+ json .dump (result , writer , indent = 4 , cls = ExtEncoder )
1101+
10531102 if FLAGS .do_predict :
10541103 predict_examples = processor .get_test_examples (FLAGS .data_dir )
10551104 num_actual_predict_examples = len (predict_examples )
@@ -1097,11 +1146,6 @@ def main(_):
10971146 num_written_lines += 1
10981147 assert num_written_lines == num_actual_predict_examples
10991148
1100- if FLAGS .do_train and FLAGS .save_for_serving :
1101- serving_dir = os .path .join (FLAGS .output_dir , 'serving' )
1102- is_tpu_estimator = not FLAGS .use_gpu or int (FLAGS .num_gpu_cores ) < 2
1103- save_for_serving (estimator , serving_dir , FLAGS .max_seq_length , is_tpu_estimator )
1104-
11051149
11061150if __name__ == "__main__" :
11071151 flags .mark_flag_as_required ("data_dir" )
0 commit comments