diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py
index 76f84b66aa..995f3585d6 100644
--- a/test/torchtext_unittest/test_transforms.py
+++ b/test/torchtext_unittest/test_transforms.py
@@ -1,5 +1,6 @@
 import os
 from collections import OrderedDict
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
@@ -586,7 +587,9 @@ def test_clip_tokenizer_save_load_torchscript(self) -> None:
 
 
 class TestBERTTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_tokens: bool):
+    def _load_tokenizer(
+        self, test_scripting: bool, do_lower_case: bool, return_tokens: bool, never_split: Optional[List[str]] = None
+    ):
         if do_lower_case:
             vocab_file = "bert_base_uncased_vocab.txt"
         else:
@@ -596,46 +599,117 @@ def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_toke
             vocab_path=get_asset_path(vocab_file),
             do_lower_case=do_lower_case,
             return_tokens=return_tokens,
+            never_split=never_split,
         )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer
 
-    def _bert_tokenizer(self, tokenizer, do_lower_case):
+    def _bert_tokenizer(self, tokenizer, do_lower_case, never_split: Optional[List[str]] = None):
         sample_texts = [
             "Hello World!, how are you?",
             "Hélló WoŕlḊ¿",
             "Respublica superiorem",
             "Avdija Vršajević în",
+            " \tHeLLo!how \n Are yoU? [UNK]",
+            "hi world [UNK] [CLS]",
+            "testing, [UNK] words! [SEP]",
         ]
 
-        if do_lower_case:
-            expected_tokens = [
-                ["hello", "world", "!", ",", "how", "are", "you", "?"],
-                ["hello", "world", "¿"],
-                ["res", "##pu", "##bl", "##ica", "superior", "##em"],
-                ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
-            ]
-            expected_token_ids = [
-                ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
-                ["7592", "2088", "1094"],
-                ["24501", "14289", "16558", "5555", "6020", "6633"],
-                ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
-            ]
+        if not never_split:
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[", "un", "##k", "]"],
+                    ["hi", "world", "[", "un", "##k", "]", "[", "cl", "##s", "]"],
+                    ["testing", ",", "[", "un", "##k", "]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "1031", "4895", "2243", "1033"],
+                    ["7632", "2088", "1031", "4895", "2243", "1033", "1031", "18856", "2015", "1033"],
+                    ["5604", "1010", "1031", "4895", "2243", "1033", "2616", "999", "1031", "19802", "1033"],
+                ]
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[", "UN", "##K", "]"],
+                    ["hi", "world", "[", "UN", "##K", "]", "[", "C", "##LS", "]"],
+                    ["testing", ",", "[", "UN", "##K", "]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    [
+                        "1124",
+                        "23955",
+                        "1186",
+                        "106",
+                        "1293",
+                        "2372",
+                        "26063",
+                        "2591",
+                        "136",
+                        "164",
+                        "7414",
+                        "2428",
+                        "166",
+                    ],
+                    ["20844", "1362", "164", "7414", "2428", "166", "164", "140", "15928", "166"],
+                    ["5193", "117", "164", "7414", "2428", "166", "1734", "106", "164", "12342", "2101", "166"],
+                ]
         else:
-            expected_tokens = [
-                ["Hello", "World", "!", ",", "how", "are", "you", "?"],
-                ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
-                ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
-                ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
-            ]
-            expected_token_ids = [
-                ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
-                ["145", "2744", "2339", "7774", "100", "225"],
-                ["11336", "20080", "10354", "9538", "7298", "5521"],
-                ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
-            ]
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "100"],
+                    ["7632", "2088", "100", "101"],
+                    ["5604", "1010", "100", "2616", "999", "1031", "19802", "1033"],
+                ]
+
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    ["1124", "23955", "1186", "106", "1293", "2372", "26063", "2591", "136", "100"],
+                    ["20844", "1362", "100", "101"],
+                    ["5193", "117", "100", "1734", "106", "164", "12342", "2101", "166"],
+                ]
 
         # test batch of sentences
         if tokenizer._return_tokens:
@@ -650,14 +724,18 @@ def _bert_tokenizer(self, tokenizer, do_lower_case):
             else:
                 self.assertEqual(tokenizer(txt), expected_token_ids[idx])
 
-    @nested_params([True, False], [True, False], [True, False])
-    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens):
+    @nested_params([True, False], [True, False], [True, False], [[], None, ["[UNK]", "[CLS]"]])
+    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens, never_split):
         """test tokenization on single sentence input as well as batch on sentences"""
         self._bert_tokenizer(
             self._load_tokenizer(
-                test_scripting=test_scripting, do_lower_case=do_lower_case, return_tokens=return_tokens
+                test_scripting=test_scripting,
+                do_lower_case=do_lower_case,
+                return_tokens=return_tokens,
+                never_split=never_split,
             ),
             do_lower_case=do_lower_case,
+            never_split=never_split,
         )
 
     @nested_params([True, False], [True, False], [True, False])
diff --git a/torchtext/csrc/bert_tokenizer.cpp b/torchtext/csrc/bert_tokenizer.cpp
index 7c75995c9b..2242588d1e 100644
--- a/torchtext/csrc/bert_tokenizer.cpp
+++ b/torchtext/csrc/bert_tokenizer.cpp
@@ -132,38 +132,52 @@ static std::string _convert_from_unicode(const UString& text) {
   return ret;
 }
 
-static void to_lower(UString& text) {
-  for (size_t i = 0; i < text.size(); i++) {
-    text[i] = utf8proc_tolower(text[i]);
+static void to_lower(UString& token) {
+  for (size_t i = 0; i < token.size(); i++) {
+    token[i] = utf8proc_tolower(token[i]);
   }
 }
 
 BERTEncoder::BERTEncoder(
     const std::string& vocab_file,
     bool do_lower_case,
-    c10::optional<bool> strip_accents)
+    c10::optional<bool> strip_accents,
+    std::vector<std::string> never_split)
     : vocab_{_read_vocab(vocab_file)},
       do_lower_case_{do_lower_case},
-      strip_accents_{strip_accents} {}
+      strip_accents_{strip_accents},
+      never_split_{never_split} {
+  never_split_set_.insert(never_split_.begin(), never_split_.end());
+}
 
 BERTEncoder::BERTEncoder(
     Vocab vocab,
     bool do_lower_case,
-    c10::optional<bool> strip_accents)
+    c10::optional<bool> strip_accents,
+    std::vector<std::string> never_split)
    : vocab_{vocab},
       do_lower_case_{do_lower_case},
-      strip_accents_{strip_accents} {}
+      strip_accents_{strip_accents},
+      never_split_{never_split} {
+  never_split_set_.insert(never_split_.begin(), never_split_.end());
+}
 
-UString BERTEncoder::_clean(const UString& text, bool strip_accents) {
+UString BERTEncoder::_clean(
+    const UString& token,
+    bool strip_accents,
+    bool is_never_split_token) {
   /* This function combines:
    * cleaning
    * strip accents
    */
-  size_t len = text.size();
+  size_t len = token.size();
   UString ret;
   for (size_t i = 0; i < len; i++) {
-    uint32_t c = text[i];
-    if (c == 0 || c == 0xFFFD || _is_control(c) ||
+    uint32_t c = token[i];
+    if (c == 0 || c == 0xFFFD || _is_control(c)) {
+      continue;
+    }
+    if ((!is_never_split_token) &&
         (utf8proc_category(c) == UTF8PROC_CATEGORY_MN && strip_accents)) {
       continue;
     }
@@ -221,7 +235,9 @@ void BERTEncoder::_max_seg(
   }
 }
 
-UString BERTEncoder::_basic_tokenize(const UString& text) {
+UString BERTEncoder::_basic_tokenize(
+    const UString& token,
+    bool is_never_split_token) {
   /* This function enables white space based tokenization for following:
    *
    * chinese character
    */
   UString ret;
-  size_t len = text.size();
+  size_t len = token.size();
   for (size_t i = 0; i < len; i++) {
-    uint32_t c = text[i];
-    if (_is_chinese_char(c) || _is_punct_char(c)) {
+    uint32_t c = token[i];
+    if (_is_chinese_char(c) || (_is_punct_char(c) && !is_never_split_token)) {
       if (!ret.empty() && ret.back() != ' ') {
         ret.append(1, ' ');
       }
@@ -254,51 +270,56 @@ std::vector<std::string> BERTEncoder::Tokenize(std::string text) {
   std::vector<std::string> results;
+  std::vector<std::string> interim_results;
+  std::vector<std::string> tokens;
 
-  // normalize
+  // split based on whitespace
+  split_(text, tokens);
 
-  bool strip_accents = do_lower_case_;
+  for (auto& token : tokens) {
+    bool is_never_split_token =
+        never_split_set_.find(token) != never_split_set_.end();
 
-  if (strip_accents_.has_value()) {
-    strip_accents = strip_accents_.has_value();
-  }
+    // normalize
 
-  if (strip_accents) {
-    char* nfkcstr = reinterpret_cast<char*>(
-        utf8proc_NFD(reinterpret_cast<const unsigned char*>(text.c_str())));
-    if (nfkcstr == nullptr) {
-      return {};
-    }
+    bool strip_accents = do_lower_case_;
 
-    text.assign(nfkcstr, strlen(nfkcstr));
+    if (strip_accents_.has_value()) {
+      strip_accents = strip_accents_.has_value();
+    }
 
-    free(nfkcstr);
-  }
+    if (strip_accents) {
+      char* nfkcstr = reinterpret_cast<char*>(
+          utf8proc_NFD(reinterpret_cast<const unsigned char*>(token.c_str())));
+      if (nfkcstr == nullptr) {
+        return {};
+      }
 
-  // convert to unicode codepoints
-  UString unicodes = _convert_to_unicode(text);
+      token.assign(nfkcstr, strlen(nfkcstr));
 
-  // clean -> invalid character removal, whitespce cleanup, strip accents
-  unicodes = _clean(unicodes, strip_accents);
+      free(nfkcstr);
+    }
 
-  // Add whitespace in front/back of tokens to enable splitting based on
-  // white-space Enables tokenization on chinese characters, Punctuations
-  unicodes = _basic_tokenize(unicodes);
+    // convert to unicode codepoints
+    UString unicodes = _convert_to_unicode(token);
 
-  // Convert text to lower-case
-  if (do_lower_case_)
-    to_lower(unicodes);
+    // clean -> invalid character removal, whitespace cleanup, strip accents
+    unicodes = _clean(unicodes, strip_accents, is_never_split_token);
 
-  // Convert back to string from code-points
-  std::string newtext = _convert_from_unicode(unicodes);
+    // Add whitespace in front/back of tokens to enable splitting based on
+    // white-space Enables tokenization on chinese characters, Punctuations
+    unicodes = _basic_tokenize(unicodes, is_never_split_token);
 
-  std::vector<std::string> tokens;
+    // Convert token to lower-case
+    if (do_lower_case_ && !is_never_split_token)
+      to_lower(unicodes);
 
-  // split based on whitespace
-  split_(newtext, tokens);
+    // Convert back to string from code-points
+    split_(_convert_from_unicode(unicodes), interim_results);
+  }
 
   // Perform WORDPIECE tokenization
-  for (auto s : tokens) {
+  for (auto s : interim_results) {
     if (s.size() > kMaxCharsPerWords) {
       results.push_back(kUnkToken);
     } else {
@@ -338,16 +359,20 @@ std::vector<std::vector<int64_t>> BERTEncoder::BatchEncode(
 BERTEncoderStates _serialize_bert_encoder(
     const c10::intrusive_ptr<BERTEncoder>& self) {
   return std::make_tuple(
-      self->do_lower_case_, self->strip_accents_, self->vocab_.itos_);
+      self->do_lower_case_,
+      self->strip_accents_,
+      self->never_split_,
+      self->vocab_.itos_);
 }
 
 c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(
     BERTEncoderStates states) {
   auto do_lower_case = std::get<0>(states);
   auto strip_accents = std::get<1>(states);
-  auto strings = std::get<2>(states);
+  auto never_split = std::get<2>(states);
+  auto strings = std::get<3>(states);
   return c10::make_intrusive<BERTEncoder>(
-      Vocab(std::move(strings)), do_lower_case, strip_accents);
+      Vocab(std::move(strings)), do_lower_case, strip_accents, never_split);
 }
 
 } // namespace torchtext
diff --git a/torchtext/csrc/bert_tokenizer.h b/torchtext/csrc/bert_tokenizer.h
index 66ad101419..87b0eb2cf7 100644
--- a/torchtext/csrc/bert_tokenizer.h
+++ b/torchtext/csrc/bert_tokenizer.h
@@ -9,19 +9,26 @@ typedef std::basic_string<uint32_t> UString;
 typedef ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>
     IndexDict;
 
-// stores (do_lower_case, strip_accents, list of tokens in vocabulary)
-typedef std::tuple<bool, c10::optional<bool>, std::vector<std::string>>
+// stores (do_lower_case, strip_accents, never_split, list of tokens in
+// vocabulary)
+typedef std::tuple<
+    bool,
+    c10::optional<bool>,
+    std::vector<std::string>,
+    std::vector<std::string>>
     BERTEncoderStates;
 
 struct BERTEncoder : torch::CustomClassHolder {
   TORCHTEXT_API BERTEncoder(
      const std::string& vocab_file,
       bool do_lower_case,
-      c10::optional<bool> strip_accents);
+      c10::optional<bool> strip_accents,
+      std::vector<std::string> never_split);
   BERTEncoder(
       Vocab vocab,
       bool do_lower_case,
-      c10::optional<bool> strip_accents);
+      c10::optional<bool> strip_accents,
+      std::vector<std::string> never_split);
   TORCHTEXT_API std::vector<std::string> Tokenize(std::string text);
   TORCHTEXT_API std::vector<int64_t> Encode(std::string text);
   TORCHTEXT_API std::vector<std::vector<std::string>> BatchTokenize(
@@ -32,11 +39,16 @@ struct BERTEncoder : torch::CustomClassHolder {
   Vocab vocab_;
   bool do_lower_case_;
   c10::optional<bool> strip_accents_ = {};
+  std::vector<std::string> never_split_;
+  std::set<std::string> never_split_set_;
 
 protected:
-  UString _clean(const UString& text, bool strip_accents);
+  UString _clean(
+      const UString& text,
+      bool strip_accents,
+      bool is_never_split_token);
   void _max_seg(const std::string& s, std::vector<std::string>& results);
-  UString _basic_tokenize(const UString& text);
+  UString _basic_tokenize(const UString& token, bool is_never_split_token);
   void split_(
       const std::string& str,
       std::vector<std::string>& tokens,
diff --git a/torchtext/csrc/register_pybindings.cpp b/torchtext/csrc/register_pybindings.cpp
index ce8a1297a7..80f3591cf3 100644
--- a/torchtext/csrc/register_pybindings.cpp
+++ b/torchtext/csrc/register_pybindings.cpp
@@ -217,7 +217,11 @@ PYBIND11_MODULE(_torchtext, m) {
       }));
 
   py::class_<BERTEncoder, c10::intrusive_ptr<BERTEncoder>>(m, "BERTEncoder")
-      .def(py::init<const std::string, bool, c10::optional<bool>>())
+      .def(py::init<
+          const std::string,
+          bool,
+          c10::optional<bool>,
+          std::vector<std::string>>())
       .def("encode", &BERTEncoder::Encode)
       .def("tokenize", &BERTEncoder::Tokenize)
       .def(
diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp
index 6b03cfea53..64427f12e4 100644
--- a/torchtext/csrc/register_torchbindings.cpp
+++ b/torchtext/csrc/register_torchbindings.cpp
@@ -174,7 +174,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
   });
 
   m.class_<BERTEncoder>("BERTEncoder")
-      .def(torch::init<const std::string, bool, c10::optional<bool>>())
+      .def(torch::init<
+          const std::string,
+          bool,
+          c10::optional<bool>,
+          std::vector<std::string>>())
      .def("encode", &BERTEncoder::Encode)
       .def("tokenize", &BERTEncoder::Tokenize)
       .def(
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 84e93aa3cc..b2d90c88ac 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -564,21 +564,31 @@ class BERTTokenizer(Module):
     :type strip_accents: Optional[bool]
     :param return_tokens: Indicate whether to return tokens. If false, returns corresponding token IDs as strings (default: False)
     :type return_tokens: bool
+    :param never_split: Collection of tokens which will not be split during tokenization. (default: None)
+    :type never_split: Optional[List[str]]
     """
 
     __jit_unused_properties__ = ["is_jitable"]
 
     def __init__(
-        self, vocab_path: str, do_lower_case: bool = True, strip_accents: Optional[bool] = None, return_tokens=False
+        self,
+        vocab_path: str,
+        do_lower_case: bool = True,
+        strip_accents: Optional[bool] = None,
+        return_tokens=False,
+        never_split: Optional[List[str]] = None,
     ) -> None:
         super().__init__()
+        if never_split is None:
+            never_split = []
         self.bert_model = BERTEncoderPyBind(
-            get_asset_local_path(vocab_path, overwite=True), do_lower_case, strip_accents
+            get_asset_local_path(vocab_path, overwite=True), do_lower_case, strip_accents, never_split
        )
         self._return_tokens = return_tokens
         self._vocab_path = vocab_path
         self._do_lower_case = do_lower_case
         self._strip_accents = strip_accents
+        self._never_split = never_split
 
     @property
     def is_jitable(self):
@@ -654,7 +664,7 @@ def __prepare_scriptable__(self):
         if not self.is_jitable:
             tokenizer_copy = deepcopy(self)
             tokenizer_copy.bert_model = torch.classes.torchtext.BERTEncoder(
-                self._vocab_path, self._do_lower_case, self._strip_accents
+                self._vocab_path, self._do_lower_case, self._strip_accents, self._never_split
             )
             return tokenizer_copy
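
Usage note for reviewers (a minimal sketch, not part of the patch): the snippet below exercises the new `never_split` argument from the Python side. It assumes a local uncased WordPiece vocab file such as the `bert_base_uncased_vocab.txt` asset used by the tests; the expected output is taken from the test cases above.

```python
import torch
from torchtext.transforms import BERTTokenizer

VOCAB_PATH = "bert_base_uncased_vocab.txt"  # placeholder path to a WordPiece vocab file

tokenizer = BERTTokenizer(
    vocab_path=VOCAB_PATH,
    do_lower_case=True,
    return_tokens=True,
    never_split=["[UNK]", "[CLS]"],  # these strings are never cleaned, lower-cased, or split
)

# "[UNK]" and "[CLS]" pass through whole; everything else is tokenized as before.
print(tokenizer("hi world [UNK] [CLS]"))
# per the tests above: ['hi', 'world', '[UNK]', '[CLS]']

# never_split is also forwarded to the torchbind BERTEncoder in
# __prepare_scriptable__, so the scripted transform behaves the same way.
scripted = torch.jit.script(tokenizer)
print(scripted("testing, [UNK] words! [SEP]"))
```

Passing `never_split=None` (the default) or an empty list keeps the previous behaviour, since `__init__` normalizes `None` to `[]` before handing it to the C++ encoder.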