
Custom vectors and refactor #115


Merged
5 commits merged on Sep 11, 2017
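
This PR replaces the string-alias form of specifying pretrained vectors with vector classes. A minimal sketch of the new surface, pieced together from the diffs below (illustrative only; the name, dim, and url values are the ones the changed scripts use):

from torchtext.vocab import Vectors, GloVe, CharNGram, FastText

# Pretrained vectors are constructed as objects instead of magic strings:
glove = GloVe(name='6B', dim=300)        # previously vectors='glove.6B.300d'
fasttext = FastText(language='simple')   # previously vectors='fasttext.simple.300d'
charngram = CharNGram()                  # previously vectors='charngram.100d'

# Arbitrary word-vector files can be fetched and loaded through the Vectors base class:
url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'
custom = Vectors('wiki.simple.vec', url=url)
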
4 changes: 2 additions & 2 deletions test/imdb.py
@@ -1,5 +1,6 @@
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe


# Approach 1:
@@ -17,8 +18,7 @@
print('vars(train[0])', vars(train[0]))

# build the vocabulary

TEXT.build_vocab(train, vectors='glove.6B.300d')
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300))
LABEL.build_vocab(train)

# print vocab information
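
The imdb.py change above only swaps the string alias for a GloVe object. A common next step, not part of this diff, is to copy the loaded vectors into an embedding layer; a hedged sketch assuming the standard torch.nn API and the TEXT field defined in the script:

import torch.nn as nn

# After build_vocab, TEXT.vocab.vectors is a FloatTensor of shape (len(TEXT.vocab), 300).
embedding = nn.Embedding(len(TEXT.vocab), 300)
embedding.weight.data.copy_(TEXT.vocab.vectors)
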
23 changes: 20 additions & 3 deletions test/sst.py
@@ -1,5 +1,6 @@
from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe, CharNGram, FastText


# Approach 1:
@@ -18,8 +19,8 @@
print('vars(train[0])', vars(train[0]))

# build the vocabulary

TEXT.build_vocab(train, vectors='glove.6B.300d')
url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'
TEXT.build_vocab(train, vectors=Vectors('wiki.simple.vec', url=url))
LABEL.build_vocab(train)

# print vocab information
@@ -36,13 +37,29 @@
print(batch.label)

# Approach 2:
TEXT.build_vocab(train, vectors=['glove.840B.300d', 'charngram.100d'])
TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram(), FastText()])
LABEL.build_vocab(train)

# print vocab information
print('len(TEXT.vocab)', len(TEXT.vocab))
print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size())

train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4)

# print batch information
batch = next(iter(train_iter))
print(batch.text)
print(batch.label)

# Approach 3:
f = FastText()
TEXT.build_vocab(train, vectors=f)
TEXT.vocab.extend(f)
LABEL.build_vocab(train)

# print vocab information
print('len(TEXT.vocab)', len(TEXT.vocab))
print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size())

train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4)

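
Approach 3 above is the new piece in sst.py: build the vocabulary from the training set, then extend it with every word in the pretrained file. A condensed sketch of what that does (TEXT and train as defined in the script; the behaviour mirrors the test_vocab_extend test added below):

from torchtext.vocab import FastText

f = FastText(language='simple')         # downloads and caches wiki.simple.vec on first use
TEXT.build_vocab(train, vectors=f)      # vocabulary limited to words that occur in train
n_train_vocab = len(TEXT.vocab)

TEXT.vocab.extend(f)                    # add every word from f.itos to the vocabulary
assert len(TEXT.vocab) > n_train_vocab  # vocab now also contains words seen only in the vector file
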
95 changes: 80 additions & 15 deletions test/test_vocab.py
@@ -9,11 +9,17 @@
from numpy.testing import assert_allclose
import torch
from torchtext import vocab
from torchtext.vocab import Vectors, FastText, GloVe, CharNGram

from .common.test_markers import slow
from .common.torchtext_test_case import TorchtextTestCase


def conditional_remove(f):
if os.path.isfile(f):
os.remove(f)


class TestVocab(TorchtextTestCase):

def test_vocab_basic(self):
@@ -44,7 +50,7 @@ def test_vocab_download_fasttext_vectors(self):
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors='fasttext.simple.300d')
vectors=FastText(language='simple'))

expected_itos = ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
@@ -67,16 +73,75 @@ def test_vocab_download_fasttext_vectors(self):
assert_allclose(vectors[v.stoi['OOV token']], np.zeros(300))
# Delete the vectors after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
os.remove(os.path.join(self.project_root, ".vector_cache",
"wiki.simple.vec"))
vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec")
conditional_remove(vec_file)

def test_vocab_extend(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
f = FastText(language='simple')
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors=f)
n_vocab = len(v)
v.extend(f) # extend the vocab with the words contained in f.itos
self.assertGreater(len(v), n_vocab)

self.assertEqual(v.itos[:6], ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
vectors = v.vectors.numpy()

# The first 5 entries in each vector.
expected_fasttext_simple_en = {
'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645],
'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165],
}

for word in expected_fasttext_simple_en:
assert_allclose(vectors[v.stoi[word], :5],
expected_fasttext_simple_en[word])

assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
# Delete the vectors after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec")
conditional_remove(vec_file)

def test_vocab_download_custom_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors=Vectors('wiki.simple.vec',
url=FastText.url_base.format('simple')))

self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
vectors = v.vectors.numpy()

# The first 5 entries in each vector.
expected_fasttext_simple_en = {
'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645],
'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165],
}

for word in expected_fasttext_simple_en:
assert_allclose(vectors[v.stoi[word], :5],
expected_fasttext_simple_en[word])

assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
# Delete the vectors after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec")
conditional_remove(vec_file)

@slow
def test_vocab_download_glove_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors='glove.twitter.27B.25d')
vectors=GloVe(name='twitter.27B', dim='25'))

expected_itos = ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
@@ -100,21 +165,20 @@ def test_vocab_download_glove_vectors(self):
assert_allclose(vectors[v.stoi['OOV token']], np.zeros(25))
# Delete the vectors after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
os.remove(os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.zip"))
zip_file = os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.zip")
conditional_remove(zip_file)
for dim in ["25", "50", "100", "200"]:
os.remove(os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.{}d.txt".format(dim)))
os.remove(os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.25d.pt"))
conditional_remove(os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.{}d.txt".format(dim)))

@slow
def test_vocab_download_charngram_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors='charngram.100d')
vectors=CharNGram())
expected_itos = ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
@@ -137,10 +201,11 @@ def test_vocab_download_charngram_vectors(self):
assert_allclose(vectors[v.stoi['OOV token']], np.zeros(100))
# Delete the vectors after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.txt"))
os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.pt"))
os.remove(os.path.join(self.project_root, ".vector_cache",
"jmt_pre-trained_embeddings.tar.gz"))
conditional_remove(
os.path.join(self.project_root, ".vector_cache", "charNgram.txt"))
conditional_remove(
os.path.join(self.project_root, ".vector_cache",
"jmt_pre-trained_embeddings.tar.gz"))

def test_serialization(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
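
The assertions in these tests all follow one pattern: rows of v.vectors line up with v.itos, and entries without a pretrained vector come back as all zeros. A minimal hedged sketch of that lookup, using a reduced version of the Counter fixture above:

from collections import Counter
from torchtext import vocab
from torchtext.vocab import FastText

c = Counter({'hello': 4, 'world': 3})
v = vocab.Vocab(c, specials=['<pad>'], vectors=FastText(language='simple'))

hello_vec = v.vectors[v.stoi['hello']]  # 300-dim row taken from wiki.simple.vec
unk_vec = v.vectors[v.stoi['<unk>']]    # all zeros, as the tests assert for OOV entries
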
6 changes: 3 additions & 3 deletions test/trec.py
@@ -1,5 +1,6 @@
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe, CharNGram


# Approach 1:
@@ -17,8 +18,7 @@
print('vars(train[0])', vars(train[0]))

# build the vocabulary

TEXT.build_vocab(train, vectors='glove.6B.300d')
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300))
LABEL.build_vocab(train)

# print vocab information
@@ -35,7 +35,7 @@
print(batch.label)

# Approach 2:
TEXT.build_vocab(train, vectors=['glove.840B.300d', 'charngram.100d'])
TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram()])
LABEL.build_vocab(train)

train_iter, test_iter = datasets.TREC.iters(batch_size=4)
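
Approach 2 in trec.py (and in sst.py) passes a list of vector objects. A hedged sketch of the expected result, assuming the list form concatenates each word's vectors along the feature dimension (inferred, not stated in the diff; TEXT and train as in the script above, with 300-dim GloVe and 100-dim CharNGram):

from torchtext.vocab import GloVe, CharNGram

TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim=300), CharNGram()])
# Each word's 300-dim GloVe vector and 100-dim CharNGram vector sit side by side.
print(TEXT.vocab.vectors.size())        # expected: torch.Size([len(TEXT.vocab), 400])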