@@ -331,31 +331,45 @@ def main():
     parser.add_argument("--val_data", default="build/data/dev-v1.1.json", help="Path to validation data")
     parser.add_argument("--log_file", default="build/logs/mlperf_log_accuracy.json", help="Path to LoadGen accuracy log")
     parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output prediction file")
+    parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file")
     parser.add_argument("--output_transposed", action="store_true", help="Transpose the output")
     args = parser.parse_args()
 
-    print("Creating tokenizer...")
-    tokenizer = BertTokenizer(args.vocab_file)
-
-    print("Reading examples...")
-    eval_examples = read_squad_examples(
-        input_file=args.val_data, is_training=False,
-        version_2_with_negative=False)
-
-    print("Converting examples to features...")
     eval_features = []
-    def append_feature(feature):
-        eval_features.append(feature)
-
-    convert_examples_to_features(
-        examples=eval_examples,
-        tokenizer=tokenizer,
-        max_seq_length=max_seq_length,
-        doc_stride=doc_stride,
-        max_query_length=max_query_length,
-        is_training=False,
-        output_fn=append_feature,
-        verbose_logging=False)
+    # Load features if cached, convert from examples otherwise.
+    cache_path = args.features_cache_file
+    if os.path.exists(cache_path):
+        print("Loading cached features from '%s'..." % cache_path)
+        with open(cache_path, 'rb') as cache_file:
+            eval_features = pickle.load(cache_file)
+    else:
+        print("No cached features at '%s'... converting from examples..." % cache_path)
+
+        print("Creating tokenizer...")
+        tokenizer = BertTokenizer(args.vocab_file)
+
+        print("Reading examples...")
+        eval_examples = read_squad_examples(
+            input_file=args.val_data, is_training=False,
+            version_2_with_negative=False)
+
+        print("Converting examples to features...")
+        def append_feature(feature):
+            eval_features.append(feature)
+
+        convert_examples_to_features(
+            examples=eval_examples,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+            doc_stride=doc_stride,
+            max_query_length=max_query_length,
+            is_training=False,
+            output_fn=append_feature,
+            verbose_logging=False)
+
+        print("Caching features at '%s'..." % cache_path)
+        with open(cache_path, 'wb') as cache_file:
+            pickle.dump(eval_features, cache_file)
 
     print("Loading loadgen logs...")
     results = load_loadgen_log(args.log_file, eval_features, args.output_transposed)
0 commit comments