Add never_split feature to BERTTokenizer #1898

Merged · 3 commits · Sep 19, 2022
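The PR carries no written description, so here is a minimal usage sketch of the new argument, pieced together from the test diff below. `BERTTokenizer` with `vocab_path`, `do_lower_case`, and `return_tokens` appears in the existing tests; `never_split` is the addition made by this PR, and the vocab path shown is a placeholder for a local uncased BERT vocab file.

```python
from torchtext.transforms import BERTTokenizer

tokenizer = BERTTokenizer(
    vocab_path="bert_base_uncased_vocab.txt",  # placeholder path to a local vocab asset
    do_lower_case=True,
    return_tokens=True,
    never_split=["[UNK]", "[CLS]"],  # new in this PR: tokens the basic tokenizer keeps whole
)

print(tokenizer("hi world [UNK] [CLS]"))
# per the updated test expectations: ['hi', 'world', '[UNK]', '[CLS]']
```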
138 changes: 108 additions & 30 deletions test/torchtext_unittest/test_transforms.py
@@ -1,5 +1,6 @@
import os
from collections import OrderedDict
from typing import List, Optional
from unittest.mock import patch

import torch
@@ -586,7 +587,9 @@ def test_clip_tokenizer_save_load_torchscript(self) -> None:


class TestBERTTokenizer(TorchtextTestCase):
def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_tokens: bool):
def _load_tokenizer(
self, test_scripting: bool, do_lower_case: bool, return_tokens: bool, never_split: Optional[List[str]] = None
):
if do_lower_case:
vocab_file = "bert_base_uncased_vocab.txt"
else:
@@ -596,46 +599,117 @@ def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_toke
vocab_path=get_asset_path(vocab_file),
do_lower_case=do_lower_case,
return_tokens=return_tokens,
never_split=never_split,
)
if test_scripting:
tokenizer = torch.jit.script(tokenizer)
return tokenizer

def _bert_tokenizer(self, tokenizer, do_lower_case):
def _bert_tokenizer(self, tokenizer, do_lower_case, never_split: Optional[List[str]] = None):
sample_texts = [
"Hello World!, how are you?",
"Hélló WoŕlḊ¿",
"Respublica superiorem",
"Avdija Vršajević în",
" \tHeLLo!how \n Are yoU? [UNK]",
"hi world [UNK] [CLS]",
"testing, [UNK] words! [SEP]",
]

if do_lower_case:
expected_tokens = [
["hello", "world", "!", ",", "how", "are", "you", "?"],
["hello", "world", "¿"],
["res", "##pu", "##bl", "##ica", "superior", "##em"],
["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
]
expected_token_ids = [
["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
["7592", "2088", "1094"],
["24501", "14289", "16558", "5555", "6020", "6633"],
["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
]
if not never_split:
if do_lower_case:
expected_tokens = [
["hello", "world", "!", ",", "how", "are", "you", "?"],
["hello", "world", "¿"],
["res", "##pu", "##bl", "##ica", "superior", "##em"],
["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
["hello", "!", "how", "are", "you", "?", "[", "un", "##k", "]"],
["hi", "world", "[", "un", "##k", "]", "[", "cl", "##s", "]"],
["testing", ",", "[", "un", "##k", "]", "words", "!", "[", "sep", "]"],
]
expected_token_ids = [
["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
["7592", "2088", "1094"],
["24501", "14289", "16558", "5555", "6020", "6633"],
["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
["7592", "999", "2129", "2024", "2017", "1029", "1031", "4895", "2243", "1033"],
["7632", "2088", "1031", "4895", "2243", "1033", "1031", "18856", "2015", "1033"],
["5604", "1010", "1031", "4895", "2243", "1033", "2616", "999", "1031", "19802", "1033"],
]

else:
expected_tokens = [
["Hello", "World", "!", ",", "how", "are", "you", "?"],
["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
["Re", "##sp", "##ub", "##lica", "superior", "##em"],
["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[", "UN", "##K", "]"],
["hi", "world", "[", "UN", "##K", "]", "[", "C", "##LS", "]"],
["testing", ",", "[", "UN", "##K", "]", "words", "!", "[", "SE", "##P", "]"],
]
expected_token_ids = [
["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
["145", "2744", "2339", "7774", "100", "225"],
["11336", "20080", "10354", "9538", "7298", "5521"],
["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
[
"1124",
"23955",
"1186",
"106",
"1293",
"2372",
"26063",
"2591",
"136",
"164",
"7414",
"2428",
"166",
],
["20844", "1362", "164", "7414", "2428", "166", "164", "140", "15928", "166"],
["5193", "117", "164", "7414", "2428", "166", "1734", "106", "164", "12342", "2101", "166"],
]
else:
expected_tokens = [
["Hello", "World", "!", ",", "how", "are", "you", "?"],
["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
["Re", "##sp", "##ub", "##lica", "superior", "##em"],
["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
]
expected_token_ids = [
["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
["145", "2744", "2339", "7774", "100", "225"],
["11336", "20080", "10354", "9538", "7298", "5521"],
["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
]
if do_lower_case:
expected_tokens = [
["hello", "world", "!", ",", "how", "are", "you", "?"],
["hello", "world", "¿"],
["res", "##pu", "##bl", "##ica", "superior", "##em"],
["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
["hello", "!", "how", "are", "you", "?", "[UNK]"],
["hi", "world", "[UNK]", "[CLS]"],
["testing", ",", "[UNK]", "words", "!", "[", "sep", "]"],
]
expected_token_ids = [
["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
["7592", "2088", "1094"],
["24501", "14289", "16558", "5555", "6020", "6633"],
["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
["7592", "999", "2129", "2024", "2017", "1029", "100"],
["7632", "2088", "100", "101"],
["5604", "1010", "100", "2616", "999", "1031", "19802", "1033"],
]

else:
expected_tokens = [
["Hello", "World", "!", ",", "how", "are", "you", "?"],
["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
["Re", "##sp", "##ub", "##lica", "superior", "##em"],
["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[UNK]"],
["hi", "world", "[UNK]", "[CLS]"],
["testing", ",", "[UNK]", "words", "!", "[", "SE", "##P", "]"],
]
expected_token_ids = [
["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
["145", "2744", "2339", "7774", "100", "225"],
["11336", "20080", "10354", "9538", "7298", "5521"],
["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
["1124", "23955", "1186", "106", "1293", "2372", "26063", "2591", "136", "100"],
["20844", "1362", "100", "101"],
["5193", "117", "100", "1734", "106", "164", "12342", "2101", "166"],
]

# test batch of sentences
if tokenizer._return_tokens:
@@ -650,14 +724,18 @@ def _bert_tokenizer(self, tokenizer, do_lower_case):
else:
self.assertEqual(tokenizer(txt), expected_token_ids[idx])

@nested_params([True, False], [True, False], [True, False])
def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens):
@nested_params([True, False], [True, False], [True, False], [[], None, ["[UNK]", "[CLS]"]])
def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens, never_split):
"""test tokenization on single sentence input as well as batch on sentences"""
self._bert_tokenizer(
self._load_tokenizer(
test_scripting=test_scripting, do_lower_case=do_lower_case, return_tokens=return_tokens
test_scripting=test_scripting,
do_lower_case=do_lower_case,
return_tokens=return_tokens,
never_split=never_split,
),
do_lower_case=do_lower_case,
never_split=never_split,
)

@nested_params([True, False], [True, False], [True, False])
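The parameter grid above also runs every combination through TorchScript. A sketch of that path, under the same assumptions as the example at the top (the constructor is taken from the test; the vocab path is a placeholder, since the cased vocab asset name is truncated out of the diff):

```python
import torch
from torchtext.transforms import BERTTokenizer

tokenizer = BERTTokenizer(
    vocab_path="path/to/bert_cased_vocab.txt",  # placeholder; the test uses a cased vocab asset here
    do_lower_case=False,
    return_tokens=True,
    never_split=["[UNK]", "[CLS]"],
)
scripted = torch.jit.script(tokenizer)  # the test grid exercises never_split through scripting as well

print(scripted("hi world [UNK] [CLS]"))
# cased expectation from the updated test: ['hi', 'world', '[UNK]', '[CLS]']
```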
121 changes: 73 additions & 48 deletions torchtext/csrc/bert_tokenizer.cpp
@@ -132,38 +132,52 @@ static std::string _convert_from_unicode(const UString& text) {
return ret;
}

static void to_lower(UString& text) {
for (size_t i = 0; i < text.size(); i++) {
text[i] = utf8proc_tolower(text[i]);
static void to_lower(UString& token) {
for (size_t i = 0; i < token.size(); i++) {
token[i] = utf8proc_tolower(token[i]);
}
}

BERTEncoder::BERTEncoder(
const std::string& vocab_file,
bool do_lower_case,
c10::optional<bool> strip_accents)
c10::optional<bool> strip_accents,
std::vector<std::string> never_split)
: vocab_{_read_vocab(vocab_file)},
do_lower_case_{do_lower_case},
strip_accents_{strip_accents} {}
strip_accents_{strip_accents},
never_split_{never_split} {
never_split_set_.insert(never_split_.begin(), never_split_.end());
}

BERTEncoder::BERTEncoder(
Vocab vocab,
bool do_lower_case,
c10::optional<bool> strip_accents)
c10::optional<bool> strip_accents,
std::vector<std::string> never_split)
: vocab_{vocab},
do_lower_case_{do_lower_case},
strip_accents_{strip_accents} {}
strip_accents_{strip_accents},
never_split_{never_split} {
never_split_set_.insert(never_split_.begin(), never_split_.end());
}

UString BERTEncoder::_clean(const UString& text, bool strip_accents) {
UString BERTEncoder::_clean(
const UString& token,
bool strip_accents,
bool is_never_split_token) {
/* This function combines:
* cleaning
* strip accents
*/
size_t len = text.size();
size_t len = token.size();
UString ret;
for (size_t i = 0; i < len; i++) {
uint32_t c = text[i];
if (c == 0 || c == 0xFFFD || _is_control(c) ||
uint32_t c = token[i];
if (c == 0 || c == 0xFFFD || _is_control(c)) {
continue;
}
if ((!is_never_split_token) &&
(utf8proc_category(c) == UTF8PROC_CATEGORY_MN && strip_accents)) {
continue;
}
@@ -221,18 +235,20 @@ void BERTEncoder::_max_seg(
}
}

UString BERTEncoder::_basic_tokenize(const UString& text) {
UString BERTEncoder::_basic_tokenize(
const UString& token,
bool is_never_split_token) {
/*
This function enables white space based tokenization for following:
* chinese character
* punctuation
*/

UString ret;
size_t len = text.size();
size_t len = token.size();
for (size_t i = 0; i < len; i++) {
uint32_t c = text[i];
if (_is_chinese_char(c) || _is_punct_char(c)) {
uint32_t c = token[i];
if (_is_chinese_char(c) || (_is_punct_char(c) && !is_never_split_token)) {
if (!ret.empty() && ret.back() != ' ') {
ret.append(1, ' ');
}
@@ -254,51 +270,56 @@ UString BERTEncoder::_basic_tokenize(const UString& text) {

std::vector<std::string> BERTEncoder::Tokenize(std::string text) {
std::vector<std::string> results;
std::vector<std::string> interim_results;
std::vector<std::string> tokens;

// normalize
// split based on whitespace
split_(text, tokens);

bool strip_accents = do_lower_case_;
for (auto& token : tokens) {
bool is_never_split_token =
never_split_set_.find(token) != never_split_set_.end();

if (strip_accents_.has_value()) {
strip_accents = strip_accents_.has_value();
}
// normalize

if (strip_accents) {
char* nfkcstr = reinterpret_cast<char*>(
utf8proc_NFD(reinterpret_cast<const unsigned char*>(text.c_str())));
if (nfkcstr == nullptr) {
return {};
}
bool strip_accents = do_lower_case_;

text.assign(nfkcstr, strlen(nfkcstr));
if (strip_accents_.has_value()) {
strip_accents = strip_accents_.has_value();
}

free(nfkcstr);
}
if (strip_accents) {
char* nfkcstr = reinterpret_cast<char*>(
utf8proc_NFD(reinterpret_cast<const unsigned char*>(token.c_str())));
if (nfkcstr == nullptr) {
return {};
}

// convert to unicode codepoints
UString unicodes = _convert_to_unicode(text);
token.assign(nfkcstr, strlen(nfkcstr));

// clean -> invalid character removal, whitespce cleanup, strip accents
unicodes = _clean(unicodes, strip_accents);
free(nfkcstr);
}

// Add whitespace in front/back of tokens to enable splitting based on
// white-space Enables tokenization on chinese characters, Punctuations
unicodes = _basic_tokenize(unicodes);
// convert to unicode codepoints
UString unicodes = _convert_to_unicode(token);

// Convert text to lower-case
if (do_lower_case_)
to_lower(unicodes);
// clean -> invalid character removal, whitespce cleanup, strip accents
unicodes = _clean(unicodes, strip_accents, is_never_split_token);

// Convert back to string from code-points
std::string newtext = _convert_from_unicode(unicodes);
// Add whitespace in front/back of tokens to enable splitting based on
// white-space Enables tokenization on chinese characters, Punctuations
unicodes = _basic_tokenize(unicodes, is_never_split_token);

std::vector<std::string> tokens;
// Convert token to lower-case
if (do_lower_case_ && !is_never_split_token)
to_lower(unicodes);

// split based on whitespace
split_(newtext, tokens);
// Convert back to string from code-points
split_(_convert_from_unicode(unicodes), interim_results);
}

// Perform WORDPIECE tokenization
for (auto s : tokens) {
for (auto s : interim_results) {
if (s.size() > kMaxCharsPerWords) {
results.push_back(kUnkToken);
} else {
@@ -338,16 +359,20 @@ std::vector<std::vector<int64_t>> BERTEncoder::BatchEncode(
BERTEncoderStates _serialize_bert_encoder(
const c10::intrusive_ptr<BERTEncoder>& self) {
return std::make_tuple(
self->do_lower_case_, self->strip_accents_, self->vocab_.itos_);
self->do_lower_case_,
self->strip_accents_,
self->never_split_,
self->vocab_.itos_);
}

c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(
BERTEncoderStates states) {
auto do_lower_case = std::get<0>(states);
auto strip_accents = std::get<1>(states);
auto strings = std::get<2>(states);
auto never_split = std::get<2>(states);
auto strings = std::get<3>(states);
return c10::make_intrusive<BERTEncoder>(
Vocab(std::move(strings)), do_lower_case, strip_accents);
Vocab(std::move(strings)), do_lower_case, strip_accents, never_split);
}

} // namespace torchtext
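For orientation, a rough Python rendition of the restructured BERTEncoder::Tokenize loop above: whitespace splitting now happens first, each whole token is checked against the never_split set, and matching tokens skip accent stripping, punctuation splitting, and lower-casing. The regex below is a crude stand-in for _basic_tokenize, not the C++ logic, and the WordPiece pass is omitted.

```python
import re
from typing import List, Optional


def basic_pass_sketch(
    text: str, do_lower_case: bool = True, never_split: Optional[List[str]] = None
) -> List[str]:
    keep = set(never_split or [])
    out: List[str] = []
    for token in text.split():  # split on whitespace before any normalization
        if token in keep:
            out.append(token)  # never_split token: no cleaning, punctuation splitting, or lower-casing
            continue
        if do_lower_case:
            token = token.lower()
        # crude stand-in for _basic_tokenize: isolate punctuation as separate tokens
        out.extend(re.findall(r"\w+|[^\w\s]", token))
    return out


print(basic_pass_sketch("hi world [UNK] [CLS]", never_split=["[UNK]", "[CLS]"]))
# ['hi', 'world', '[UNK]', '[CLS]']
print(basic_pass_sketch("hi world [UNK] [CLS]"))
# ['hi', 'world', '[', 'unk', ']', '[', 'cls', ']']  (WordPiece would further split 'unk')
```

Moving normalization inside the per-token loop is what lets a single membership check gate all three steps, which is why the PR hoists the whitespace split to the top of Tokenize.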