unit testing for experimental vocab #1301

Closed · wants to merge 1 commit
9 changes: 9 additions & 0 deletions test/experimental/test_vocab.py
@@ -76,6 +76,10 @@ def test_vocab_insert_token(self):

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
with self.assertRaises(RuntimeError) as context:
    v.insert_token('b', 0)

self.assertTrue("Token b already exists in the Vocab with index: 0" in str(context.exception))

def test_vocab_append_token(self):
c = OrderedDict({'a': 2})
@@ -88,6 +92,11 @@ def test_vocab_append_token(self):
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

with self.assertRaises(RuntimeError) as context:
    v.append_token('b')

self.assertTrue("Token b already exists in the Vocab with index: 2" in str(context.exception))

def test_vocab_len(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
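For reference, a minimal usage sketch of the behavior these tests pin down. The vocab factory call and the exact indices are assumptions; the error messages come from the assertions above.

    # Minimal sketch; assumes the vocab factory from torchtext.experimental.vocab.
    from collections import OrderedDict
    from torchtext.experimental.vocab import vocab

    v = vocab(OrderedDict([('a', 2), ('b', 2)]))
    v.append_token('c')      # fine: 'c' is not in the vocab yet
    v.insert_token('b', 0)   # raises RuntimeError: Token b already exists in the Vocab with index: ...
    v.append_token('b')      # would raise the same RuntimeError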
18 changes: 16 additions & 2 deletions torchtext/csrc/vocab.cpp
@@ -38,7 +38,6 @@ bool Vocab::__contains__(const c10::string_view &token) const {
return false;
}


int64_t Vocab::__getitem__(const c10::string_view &token) const {
int64_t id = _find(token);
if (stoi_[id] != -1) {
@@ -47,7 +46,22 @@ int64_t Vocab::__getitem__(const c10::string_view &token) const {
return unk_index_;
}

void Vocab::append_token(const std::string &token) { _add(token); }
void Vocab::append_token(const std::string &token) {
  // if the token is already in stoi, throw an error
  auto token_position = _find(c10::string_view{token.data(), token.size()});
  if (stoi_[token_position] != -1) {
#ifdef _MSC_VER
    std::cerr << "[RuntimeError] Token " << token
              << " already exists in the Vocab with index: "
              << stoi_[token_position] << std::endl;
#endif
    throw std::runtime_error("Token " + token +
                             " already exists in the Vocab with index: " +
                             std::to_string(stoi_[token_position]) + ".");
  }

  _add(token);
}

void Vocab::insert_token(const std::string &token, const int64_t &index) {
if (index < 0 || index > itos_.size()) {
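The new guard reuses the open-addressing lookup: _find returns a slot, and an occupied slot (stoi_[token_position] != -1) means the token is already present, so append_token now throws instead of silently re-adding. A rough Python equivalent of the guard, where the class and its stoi/_find/_add members are stand-ins mirroring the C++, not the shipped API:

    # Hypothetical Python mirror of the C++ duplicate check in append_token.
    def append_token(self, token):
        slot = self._find(token)
        if self.stoi[slot] != -1:  # occupied slot => token already present
            raise RuntimeError("Token " + token +
                               " already exists in the Vocab with index: " +
                               str(self.stoi[slot]) + ".")
        self._add(token)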
2 changes: 1 addition & 1 deletion torchtext/csrc/vocab.h
@@ -44,7 +44,7 @@ struct Vocab : torch::CustomClassHolder {
uint32_t _find(const c10::string_view &w) const {
  uint32_t stoi_size = stoi_.size();
  uint32_t id = _hash(w) % stoi_size;
  while (stoi_[id] != -1 && itos_[stoi_[id]]!= w) {
  while (stoi_[id] != -1 && itos_[stoi_[id]] != w) {
    id = (id + 1) % stoi_size;
  }
  return id;
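_find is a linear probe over an open-addressing table: hash the token, then walk forward until hitting an empty slot (-1, token absent) or a slot whose stored index maps back to the same token. A self-contained Python sketch of the same probe; the table layout (stoi holding indices into itos, -1 for empty) follows the C++ above, while the hash function is a stand-in:

    # Sketch of the linear probe in _find; hash() stands in for _hash.
    def find(stoi, itos, token):
        size = len(stoi)
        slot = hash(token) % size
        # probe forward until an empty slot or the matching token
        while stoi[slot] != -1 and itos[stoi[slot]] != token:
            slot = (slot + 1) % size
        return slot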
3 changes: 3 additions & 0 deletions torchtext/experimental/vocab.py
@@ -214,6 +214,9 @@ def append_token(self, token: str) -> None:
r"""
Args:
token (str): the token to be appended to the vocab.

Raises:
RuntimeError: if the token already exists in the vocab.
"""
self.vocab.append_token(token)
