Skip to content

Commit 2860152

Browse files
committed
use distributed strategy in PREDICTION mode
1 parent e29e77d commit 2860152

File tree

1 file changed

+66 / -61 lines changed

1 file changed

+66 / -61 lines changed

run_custom_classifier.py

Lines changed: 66 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -729,8 +729,9 @@ def tpu_scaffold():
729729
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
730730
init_string)
731731

732+
is_multi_gpu = use_gpu and int(num_gpu_cores) >= 2
732733
if mode == tf.estimator.ModeKeys.TRAIN:
733-
if use_gpu and int(num_gpu_cores) >= 2:
734+
if is_multi_gpu:
734735
train_op = custom_optimization.create_optimizer(
735736
total_loss, learning_rate, num_train_steps, num_warmup_steps, fp16=fp16)
736737
output_spec = tf.estimator.EstimatorSpec(
@@ -777,11 +778,15 @@ def metric_fn(per_example_loss, label_ids, logits, is_real_example):
777778
eval_metrics=eval_metrics,
778779
scaffold_fn=scaffold_fn)
779780
else:
780-
# predict on single-gpu only
781-
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
782-
mode=mode,
783-
predictions={"probabilities": probabilities},
784-
scaffold_fn=scaffold_fn)
781+
if is_multi_gpu:
782+
output_spec = tf.estimator.EstimatorSpec(
783+
mode=mode,
784+
predictions={"probabilities": probabilities})
785+
else:
786+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
787+
mode=mode,
788+
predictions={"probabilities": probabilities},
789+
scaffold_fn=scaffold_fn)
785790

786791
return output_spec
787792

@@ -947,45 +952,47 @@ def main(_):
947952
num_shards=FLAGS.num_tpu_cores,
948953
per_host_input_for_training=is_per_host))
949954

950-
num_train_steps = 0
951-
num_warmup_steps = 0
952-
init_checkpoint = FLAGS.init_checkpoint
953-
is_multi_gpu = FLAGS.use_gpu and int(FLAGS.num_gpu_cores) >= 2
955+
train_examples = None
956+
num_train_steps = None
957+
num_warmup_steps = None
954958
if FLAGS.do_train:
955959
train_examples = processor.get_train_examples(FLAGS.data_dir)
956960
num_train_steps = int(
957961
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
958962
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
959963

960-
model_fn = model_fn_builder(
961-
bert_config=bert_config,
962-
num_labels=len(label_list),
963-
init_checkpoint=init_checkpoint,
964-
learning_rate=FLAGS.learning_rate,
965-
num_train_steps=num_train_steps,
966-
num_warmup_steps=num_warmup_steps,
967-
use_tpu=FLAGS.use_tpu,
968-
use_one_hot_embeddings=FLAGS.use_tpu,
969-
use_gpu=FLAGS.use_gpu,
970-
num_gpu_cores=FLAGS.num_gpu_cores,
971-
fp16=FLAGS.use_fp16)
964+
init_checkpoint = FLAGS.init_checkpoint
965+
is_multi_gpu = FLAGS.use_gpu and int(FLAGS.num_gpu_cores) >= 2
966+
model_fn = model_fn_builder(
967+
bert_config=bert_config,
968+
num_labels=len(label_list),
969+
init_checkpoint=init_checkpoint,
970+
learning_rate=FLAGS.learning_rate,
971+
num_train_steps=num_train_steps,
972+
num_warmup_steps=num_warmup_steps,
973+
use_tpu=FLAGS.use_tpu,
974+
use_one_hot_embeddings=FLAGS.use_tpu,
975+
use_gpu=FLAGS.use_gpu,
976+
num_gpu_cores=FLAGS.num_gpu_cores,
977+
fp16=FLAGS.use_fp16)
972978

973-
# If TPU is not available, this will fall back to normal Estimator on CPU
974-
# or GPU.
975-
if is_multi_gpu:
976-
estimator = Estimator(
977-
model_fn=model_fn,
978-
params={},
979-
config=dist_run_config)
980-
else:
981-
estimator = tf.contrib.tpu.TPUEstimator(
982-
use_tpu=FLAGS.use_tpu,
983-
model_fn=model_fn,
984-
config=tpu_run_config,
985-
train_batch_size=FLAGS.train_batch_size,
986-
eval_batch_size=FLAGS.eval_batch_size,
987-
predict_batch_size=FLAGS.predict_batch_size)
979+
# If TPU is not available, this will fall back to normal Estimator on CPU
980+
# or GPU.
981+
if is_multi_gpu:
982+
estimator = Estimator(
983+
model_fn=model_fn,
984+
params={},
985+
config=dist_run_config)
986+
else:
987+
estimator = tf.contrib.tpu.TPUEstimator(
988+
use_tpu=FLAGS.use_tpu,
989+
model_fn=model_fn,
990+
config=tpu_run_config,
991+
train_batch_size=FLAGS.train_batch_size,
992+
eval_batch_size=FLAGS.eval_batch_size,
993+
predict_batch_size=FLAGS.predict_batch_size)
988994

995+
if FLAGS.do_train:
989996
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
990997
file_based_convert_examples_to_features(
991998
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
@@ -1011,32 +1018,30 @@ def main(_):
10111018
if filename.startswith('model.ckpt-'):
10121019
max_idx = max(int(filename.split('.')[1].split('-')[1]), max_idx)
10131020
init_checkpoint = os.path.join(FLAGS.output_dir, f'model.ckpt-{max_idx}')
1021+
tf.logging.info(f'Current checkpoint: {init_checkpoint}')
10141022

1015-
if not FLAGS.do_eval and not FLAGS.do_predict:
1016-
return
1017-
1018-
model_fn = model_fn_builder(
1019-
bert_config=bert_config,
1020-
num_labels=len(label_list),
1021-
init_checkpoint=init_checkpoint,
1022-
learning_rate=FLAGS.learning_rate,
1023-
num_train_steps=num_train_steps,
1024-
num_warmup_steps=num_warmup_steps,
1025-
use_tpu=FLAGS.use_tpu,
1026-
use_one_hot_embeddings=FLAGS.use_tpu,
1027-
use_gpu=FLAGS.use_gpu,
1028-
num_gpu_cores=FLAGS.num_gpu_cores,
1029-
fp16=FLAGS.use_fp16)
1023+
if FLAGS.do_eval:
1024+
model_fn = model_fn_builder(
1025+
bert_config=bert_config,
1026+
num_labels=len(label_list),
1027+
init_checkpoint=init_checkpoint,
1028+
learning_rate=FLAGS.learning_rate,
1029+
num_train_steps=num_train_steps,
1030+
num_warmup_steps=num_warmup_steps,
1031+
use_tpu=FLAGS.use_tpu,
1032+
use_one_hot_embeddings=FLAGS.use_tpu,
1033+
use_gpu=FLAGS.use_gpu,
1034+
num_gpu_cores=FLAGS.num_gpu_cores,
1035+
fp16=FLAGS.use_fp16)
10301036

1031-
estimator = tf.contrib.tpu.TPUEstimator(
1032-
use_tpu=FLAGS.use_tpu,
1033-
model_fn=model_fn,
1034-
config=tpu_run_config,
1035-
train_batch_size=FLAGS.train_batch_size,
1036-
eval_batch_size=FLAGS.eval_batch_size,
1037-
predict_batch_size=FLAGS.predict_batch_size)
1037+
eval_estimator = tf.contrib.tpu.TPUEstimator(
1038+
use_tpu=FLAGS.use_tpu,
1039+
model_fn=model_fn,
1040+
config=tpu_run_config,
1041+
train_batch_size=FLAGS.train_batch_size,
1042+
eval_batch_size=FLAGS.eval_batch_size,
1043+
predict_batch_size=FLAGS.predict_batch_size)
10381044

1039-
if FLAGS.do_eval:
10401045
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
10411046
num_actual_eval_examples = len(eval_examples)
10421047
if FLAGS.use_tpu:
@@ -1074,7 +1079,7 @@ def main(_):
10741079
drop_remainder=eval_drop_remainder,
10751080
batch_size=FLAGS.eval_batch_size)
10761081

1077-
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
1082+
result = eval_estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
10781083

10791084
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
10801085
with tf.gfile.GFile(output_eval_file, "w") as writer:

0 commit comments

Comments
 (0)