Commit 11db102 (parent d914775)

Support feature cache in squad_eval.py too.


1 file changed: v0.7/language/bert/squad_eval.py (35 additions, 21 deletions)
@@ -331,31 +331,45 @@ def main():
     parser.add_argument("--val_data", default="build/data/dev-v1.1.json", help="Path to validation data")
     parser.add_argument("--log_file", default="build/logs/mlperf_log_accuracy.json", help="Path to loadgen accuracy log")
     parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output prediction file")
+    parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file")
     parser.add_argument("--output_transposed", action="store_true", help="Transpose the output")
     args = parser.parse_args()
 
-    print("Creating tokenizer...")
-    tokenizer = BertTokenizer(args.vocab_file)
-
-    print("Reading examples...")
-    eval_examples = read_squad_examples(
-        input_file=args.val_data, is_training=False,
-        version_2_with_negative=False)
-
-    print("Converting examples to features...")
     eval_features = []
-    def append_feature(feature):
-        eval_features.append(feature)
-
-    convert_examples_to_features(
-        examples=eval_examples,
-        tokenizer=tokenizer,
-        max_seq_length=max_seq_length,
-        doc_stride=doc_stride,
-        max_query_length=max_query_length,
-        is_training=False,
-        output_fn=append_feature,
-        verbose_logging=False)
+    # Load features if cached, convert from examples otherwise.
+    cache_path = args.features_cache_file
+    if os.path.exists(cache_path):
+        print("Loading cached features from '%s'..." % cache_path)
+        with open(cache_path, 'rb') as cache_file:
+            eval_features = pickle.load(cache_file)
+    else:
+        print("No cached features at '%s'... converting from examples..." % cache_path)
+
+        print("Creating tokenizer...")
+        tokenizer = BertTokenizer(args.vocab_file)
+
+        print("Reading examples...")
+        eval_examples = read_squad_examples(
+            input_file=args.val_data, is_training=False,
+            version_2_with_negative=False)
+
+        print("Converting examples to features...")
+        def append_feature(feature):
+            eval_features.append(feature)
+
+        convert_examples_to_features(
+            examples=eval_examples,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+            doc_stride=doc_stride,
+            max_query_length=max_query_length,
+            is_training=False,
+            output_fn=append_feature,
+            verbose_logging=False)
+
+        print("Caching features at '%s'..." % cache_path)
+        with open(cache_path, 'wb') as cache_file:
+            pickle.dump(eval_features, cache_file)
 
     print("Loading loadgen logs...")
     results = load_loadgen_log(args.log_file, eval_features, args.output_transposed)
