Adding support for batch input in BERT Tokenizer with perf benchmark #1745

Merged: 4 commits, merged on Jun 1, 2022
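
In short, this PR adds BatchTokenize/BatchEncode kernels to the C++ BERTEncoder, exposes them through both the pybind and torchbind registrations, and routes List[str] inputs of the BERTTokenizer transform through a single call into C++ instead of a Python-level loop. A minimal sketch of the resulting user-facing behavior (sentences are illustrative; the vocab URL is the one used by the benchmark below):

from torchtext.transforms import BERTTokenizer

VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
tokenizer = BERTTokenizer(VOCAB_FILE, return_tokens=True)

# Single sentence -> List[str]
tokenizer("Hello World, How are you!")

# Batch of sentences -> List[List[str]], dispatched to the new batch kernels
tokenizer(["Hello World, How are you!", "It is a good day."])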
42 changes: 42 additions & 0 deletions benchmark/benchmark_bert_tokenizer.py
@@ -0,0 +1,42 @@
from argparse import ArgumentParser

from benchmark.utils import Timer
from tokenizers import Tokenizer as hf_tokenizer_lib
from torchtext.datasets import EnWik9
from torchtext.transforms import BERTTokenizer as tt_bert_tokenizer
from transformers import BertTokenizer as hf_bert_tokenizer_slow


VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"


def benchmark_bert_tokenizer(args):
    # The three tokenizers under comparison: TorchText's C++-backed BERTTokenizer,
    # HF's pure-Python ("slow") BertTokenizer, and HF's Rust-backed ("fast") tokenizer.
    tt_tokenizer = tt_bert_tokenizer(VOCAB_FILE, return_tokens=True)
    hf_tokenizer_slow = hf_bert_tokenizer_slow.from_pretrained("bert-base-uncased")
    hf_tokenizer_fast = hf_tokenizer_lib.from_pretrained("bert-base-uncased")
    # Use the first `num_samples` lines of EnWik9 as benchmark input.
    dp = EnWik9().header(args.num_samples)
    samples = list(dp)

    with Timer("Running TorchText BERT Tokenizer on non-batched input"):
        for s in samples:
            tt_tokenizer(s)

    with Timer("Running HF BERT Tokenizer (slow) on non-batched input"):
        for s in samples:
            hf_tokenizer_slow.tokenize(s)

    with Timer("Running HF BERT Tokenizer (fast) on non-batched input"):
        for s in samples:
            hf_tokenizer_fast.encode(s)

    # The slow HF tokenizer exposes no batch API, so only the other two are
    # compared on batched input.
    with Timer("Running TorchText BERT Tokenizer on batched input"):
        tt_tokenizer(samples)

    with Timer("Running HF BERT Tokenizer (fast) on batched input"):
        hf_tokenizer_fast.encode_batch(samples)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--num-samples", default=1000, type=int)
    benchmark_bert_tokenizer(parser.parse_args())
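
A note on invocation (assumed, not part of the diff): since the script imports benchmark.utils, it presumably runs with the repository root on sys.path, e.g.

python -m benchmark.benchmark_bert_tokenizer --num-samples 1000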
18 changes: 18 additions & 0 deletions torchtext/csrc/bert_tokenizer.cpp
@@ -292,6 +292,24 @@ std::vector<int64_t> BERTEncoder::Encode(std::string text) {
  return indices;
}

std::vector<std::vector<std::string>> BERTEncoder::BatchTokenize(
    std::vector<std::string> text) {
  // Tokenize each string in the batch by delegating to the single-string kernel.
  std::vector<std::vector<std::string>> output;
  for (const auto& t : text) {
    output.push_back(Tokenize(t));
  }
  return output;
}

std::vector<std::vector<int64_t>> BERTEncoder::BatchEncode(
    std::vector<std::string> text) {
  // Encode each string in the batch by delegating to the single-string kernel.
  std::vector<std::vector<int64_t>> output;
  for (const auto& t : text) {
    output.push_back(Encode(t));
  }
  return output;
}

BERTEncoderStates _serialize_bert_encoder(
    const c10::intrusive_ptr<BERTEncoder>& self) {
  return std::make_tuple(
4 changes: 4 additions & 0 deletions torchtext/csrc/bert_tokenizer.h
@@ -21,6 +21,10 @@ struct BERTEncoder : torch::CustomClassHolder {
      c10::optional<bool> strip_accents);
  std::vector<std::string> Tokenize(std::string text);
  std::vector<int64_t> Encode(std::string text);
  std::vector<std::vector<std::string>> BatchTokenize(
      std::vector<std::string> text);
  std::vector<std::vector<int64_t>> BatchEncode(std::vector<std::string> text);

  Vocab vocab_;
  bool do_lower_case_;
  c10::optional<bool> strip_accents_ = {};
24 changes: 24 additions & 0 deletions torchtext/csrc/register_pybindings.cpp
@@ -220,6 +220,30 @@ PYBIND11_MODULE(_torchtext, m) {
      .def(py::init<const std::string, bool, c10::optional<bool>>())
      .def("encode", &BERTEncoder::Encode)
      .def("tokenize", &BERTEncoder::Tokenize)
      .def(
          "batch_encode",
          [](const c10::intrusive_ptr<BERTEncoder>& self,
             const py::list& items) {
            // Pull the UTF-8 buffer out of each Python str via the CPython API;
            // pass the reported length so the copy is not cut short at an
            // embedded NUL byte.
            std::vector<std::string> input;
            for (const auto& item : items) {
              Py_ssize_t length;
              const char* buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
              input.push_back(std::string(buffer, length));
            }
            return self->BatchEncode(input);
          })
      .def(
          "batch_tokenize",
          [](const c10::intrusive_ptr<BERTEncoder>& self,
             const py::list& items) {
            std::vector<std::string> input;
            for (const auto& item : items) {
              Py_ssize_t length;
              const char* buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
              input.push_back(std::string(buffer, length));
            }
            return self->BatchTokenize(input);
          })
      .def(py::pickle(
          // __getstate__
          [](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
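
These pybind registrations are what the eager-mode BERTTokenizer ends up calling. A rough sketch of exercising them directly, assuming the class is exposed as BERTEncoderPyBind in torchtext._torchtext (the name transforms.py uses for the existing encode/tokenize bindings) and that a local vocab.txt exists:

from torchtext._torchtext import BERTEncoderPyBind  # internal binding name assumed

# Constructor args mirror py::init: vocab path, do_lower_case, strip_accents.
encoder = BERTEncoderPyBind("vocab.txt", True, None)
encoder.batch_encode(["hello world", "it is a good day"])    # -> List[List[int]]
encoder.batch_tokenize(["hello world", "it is a good day"])  # -> List[List[str]]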
12 changes: 12 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
@@ -177,6 +177,18 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
      .def(torch::init<const std::string, bool, c10::optional<bool>>())
      .def("encode", &BERTEncoder::Encode)
      .def("tokenize", &BERTEncoder::Tokenize)
      .def(
          "batch_encode",
          [](const c10::intrusive_ptr<BERTEncoder>& self,
             const std::vector<std::string>& items) {
            // Unlike the pybind path, torchbind converts the incoming
            // List[str] itself, so the lambda just forwards to the kernel.
            return self->BatchEncode(items);
          })
      .def(
          "batch_tokenize",
          [](const c10::intrusive_ptr<BERTEncoder>& self,
             const std::vector<std::string>& items) {
            return self->BatchTokenize(items);
          })
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
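
The torchbind registrations are what scripted transforms dispatch to. A rough sketch of calling them through torch.classes, assuming the torchtext extension library is loaded (importing torchtext normally does this) and a local vocab.txt:

import torch
import torchtext  # loads the extension that registers torchtext::BERTEncoder

encoder = torch.classes.torchtext.BERTEncoder("vocab.txt", True, None)
encoder.batch_tokenize(["hello world", "it is a good day"])  # -> List[List[str]]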
21 changes: 16 additions & 5 deletions torchtext/transforms.py
@@ -597,6 +597,13 @@ def _encode(self, text: str) -> List[str]:
        tokens_ids_str: List[str] = [str(token_id) for token_id in token_ids]
        return tokens_ids_str

    @torch.jit.export
    def _batch_encode(self, text: List[str]) -> List[List[str]]:
        """Batch version of _encode, i.e. operates on a list of strings."""
        token_ids: List[List[int]] = self.bert_model.batch_encode([t.strip() for t in text])
        tokens_ids_str: List[List[str]] = [[str(t) for t in token_id] for token_id in token_ids]
        return tokens_ids_str

    @torch.jit.export
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into a list of tokens
@@ -612,6 +619,11 @@ def _tokenize(self, text: str) -> List[str]:
        """
        return self.bert_model.tokenize(text.strip())

    @torch.jit.export
    def _batch_tokenize(self, text: List[str]) -> List[List[str]]:
        """Batch version of _tokenize, i.e. operates on a list of strings."""
        return self.bert_model.batch_tokenize([t.strip() for t in text])

    def forward(self, input: Any) -> Any:
        """
        :param input: Input sentence or list of sentences on which to apply tokenizer.
@@ -621,11 +633,10 @@ def forward(self, input: Any) -> Any:
        """
        if torch.jit.isinstance(input, List[str]):
            tokens: List[List[str]] = []
-           for text in input:
-               if self._return_tokens:
-                   tokens.append(self._tokenize(text))
-               else:
-                   tokens.append(self._encode(text))
+           if self._return_tokens:
+               tokens = self._batch_tokenize(input)
+           else:
+               tokens = self._batch_encode(input)
            return tokens
        elif torch.jit.isinstance(input, str):
            if self._return_tokens:
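
With this change, forward routes a List[str] input through one call into the C++ batch kernel rather than a Python loop over _tokenize/_encode. A small end-to-end sketch of eager and scripted use (values illustrative; scripting is assumed to work here as it does for the other torchtext transforms):

import torch
from torchtext.transforms import BERTTokenizer

VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
tokenizer = BERTTokenizer(VOCAB_FILE, return_tokens=False)

# Batched input returns one list of token-id strings per sentence.
tokenizer(["Hello World!", "How are you?"])  # -> List[List[str]]

# The same batched path works after scripting, via the torchbind batch_encode.
scripted = torch.jit.script(tokenizer)
scripted(["Hello World!", "How are you?"])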