@@ -35,15 +35,12 @@
 
 
 def create_model(is_training, input_ids, input_mask, segment_ids, labels,
-                 num_labels):
+                 num_labels, bert_hub_module_handle):
   """Creates a classification model."""
   tags = set()
   if is_training:
     tags.add("train")
-  bert_module = hub.Module(
-      FLAGS.bert_hub_module_handle,
-      tags=tags,
-      trainable=True)
+  bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True)
   bert_inputs = dict(
       input_ids=input_ids,
       input_mask=input_mask,
@@ -87,7 +84,7 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
 
 
 def model_fn_builder(num_labels, learning_rate, num_train_steps,
-                     num_warmup_steps, use_tpu):
+                     num_warmup_steps, use_tpu, bert_hub_module_handle):
   """Returns `model_fn` closure for TPUEstimator."""
 
   def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
@@ -105,7 +102,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
     (total_loss, per_example_loss, logits) = create_model(
-        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)
+        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
+        bert_hub_module_handle)
 
     output_spec = None
     if mode == tf.estimator.ModeKeys.TRAIN:
@@ -140,10 +138,10 @@ def metric_fn(per_example_loss, label_ids, logits):
   return model_fn
 
 
-def create_tokenizer_from_hub_module():
+def create_tokenizer_from_hub_module(bert_hub_module_handle):
   """Get the vocab file and casing info from the Hub module."""
   with tf.Graph().as_default():
-    bert_module = hub.Module(FLAGS.bert_hub_module_handle)
+    bert_module = hub.Module(bert_hub_module_handle)
     tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
     with tf.Session() as sess:
       vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
@@ -175,7 +173,7 @@ def main(_):
 
   label_list = processor.get_labels()
 
-  tokenizer = create_tokenizer_from_hub_module()
+  tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)
 
   tpu_cluster_resolver = None
   if FLAGS.use_tpu and FLAGS.tpu_name:
@@ -207,7 +205,8 @@ def main(_):
       learning_rate=FLAGS.learning_rate,
       num_train_steps=num_train_steps,
       num_warmup_steps=num_warmup_steps,
-      use_tpu=FLAGS.use_tpu)
+      use_tpu=FLAGS.use_tpu,
+      bert_hub_module_handle=FLAGS.bert_hub_module_handle)
 
   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
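
The net effect of this change is that the TF Hub module handle is threaded through as an explicit parameter of `create_model`, `model_fn_builder`, and `create_tokenizer_from_hub_module` rather than read from the global `FLAGS` inside each function; only `main` touches `FLAGS`. A minimal sketch of calling the refactored API, assuming the signatures above; the module URL and hyperparameter values are illustrative, and `label_list`, `num_train_steps`, and `num_warmup_steps` are assumed to come from the surrounding `main`:

```python
# Illustrative handle; any BERT TF Hub module exposing the same signatures works.
BERT_MODULE = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

# The handle is now injected explicitly, so tests (or a second model in the
# same process) can pass a different module without mutating global FLAGS.
tokenizer = create_tokenizer_from_hub_module(BERT_MODULE)

model_fn = model_fn_builder(
    num_labels=len(label_list),   # e.g. 2 for a binary classification task
    learning_rate=2e-5,           # illustrative value
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    use_tpu=False,
    bert_hub_module_handle=BERT_MODULE)
```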