Skip to content

Commit 7c1a4bf

Browse files
Replace the usage of FLAGS.bert_hub_module_handle with function argument to facilitate code reuse in colabs.
1 parent ffbda2a commit 7c1a4bf

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

run_classifier_with_tfhub.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@
3535

3636

3737
def create_model(is_training, input_ids, input_mask, segment_ids, labels,
38-
num_labels):
38+
num_labels, bert_hub_module_handle):
3939
"""Creates a classification model."""
4040
tags = set()
4141
if is_training:
4242
tags.add("train")
43-
bert_module = hub.Module(
44-
FLAGS.bert_hub_module_handle,
45-
tags=tags,
46-
trainable=True)
43+
bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True)
4744
bert_inputs = dict(
4845
input_ids=input_ids,
4946
input_mask=input_mask,
@@ -87,7 +84,7 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
8784

8885

8986
def model_fn_builder(num_labels, learning_rate, num_train_steps,
90-
num_warmup_steps, use_tpu):
87+
num_warmup_steps, use_tpu, bert_hub_module_handle):
9188
"""Returns `model_fn` closure for TPUEstimator."""
9289

9390
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
@@ -105,7 +102,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
105102
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
106103

107104
(total_loss, per_example_loss, logits) = create_model(
108-
is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)
105+
is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
106+
bert_hub_module_handle)
109107

110108
output_spec = None
111109
if mode == tf.estimator.ModeKeys.TRAIN:
@@ -140,10 +138,10 @@ def metric_fn(per_example_loss, label_ids, logits):
140138
return model_fn
141139

142140

143-
def create_tokenizer_from_hub_module():
141+
def create_tokenizer_from_hub_module(bert_hub_module_handle):
144142
"""Get the vocab file and casing info from the Hub module."""
145143
with tf.Graph().as_default():
146-
bert_module = hub.Module(FLAGS.bert_hub_module_handle)
144+
bert_module = hub.Module(bert_hub_module_handle)
147145
tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
148146
with tf.Session() as sess:
149147
vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
@@ -175,7 +173,7 @@ def main(_):
175173

176174
label_list = processor.get_labels()
177175

178-
tokenizer = create_tokenizer_from_hub_module()
176+
tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)
179177

180178
tpu_cluster_resolver = None
181179
if FLAGS.use_tpu and FLAGS.tpu_name:
@@ -207,7 +205,8 @@ def main(_):
207205
learning_rate=FLAGS.learning_rate,
208206
num_train_steps=num_train_steps,
209207
num_warmup_steps=num_warmup_steps,
210-
use_tpu=FLAGS.use_tpu)
208+
use_tpu=FLAGS.use_tpu,
209+
bert_hub_module_handle=FLAGS.bert_hub_module_handle)
211210

212211
# If TPU is not available, this will fall back to normal Estimator on CPU
213212
# or GPU.

0 commit comments

Comments
 (0)