Skip to content

Commit d452612

Browse files
Putting nampledtuple in global scope
1 parent 68d5076 commit d452612

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

create_pretraining_data.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def create_instances_from_document(
332332
return instances
333333

334334

335+
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
336+
["index", "label"])
337+
338+
335339
def create_masked_lm_predictions(tokens, masked_lm_prob,
336340
max_predictions_per_seq, vocab_words, rng):
337341
"""Creates the predictions for the masked LM objective."""
@@ -346,8 +350,6 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
346350

347351
output_tokens = list(tokens)
348352

349-
masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
350-
351353
num_to_predict = min(max_predictions_per_seq,
352354
max(1, int(round(len(tokens) * masked_lm_prob))))
353355

@@ -374,7 +376,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
374376

375377
output_tokens[index] = masked_token
376378

377-
masked_lms.append(masked_lm(index=index, label=tokens[index]))
379+
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
378380

379381
masked_lms = sorted(masked_lms, key=lambda x: x.index)
380382

0 commit comments

Comments
 (0)