From 58a2b72a931172e39b21256671d6854e2439c7d4 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Sat, 9 Sep 2017 09:12:23 +0000 Subject: [PATCH 1/5] custom vec paths; vocab extend; rm str wrapper for vecs --- torchtext/vocab.py | 80 ++++++++++++++-------------------------------- 1 file changed, 24 insertions(+), 56 deletions(-) diff --git a/torchtext/vocab.py b/torchtext/vocab.py index 19186f1147..dbfdfa491e 100644 --- a/torchtext/vocab.py +++ b/torchtext/vocab.py @@ -29,7 +29,7 @@ class Vocab(object): itos: A list of token strings indexed by their numerical identifiers. """ def __init__(self, counter, max_size=None, min_freq=1, specials=[''], - vectors=None, unk_init=torch.Tensor.zero_, expand_vocab=False): + vectors=None): """Create a Vocab object from a collections.Counter. Arguments: @@ -42,15 +42,9 @@ def __init__(self, counter, max_size=None, min_freq=1, specials=[''], specials: The list of special tokens (e.g., padding or eos) that will be prepended to the vocabulary in addition to an token. Default: [''] - vectors: one of the available pretrained vectors or a list with each - element one of the available pretrained vectors - (see Vocab.load_vectors). Default: None - unk_init (callback): by default, initialize out-of-vocabulary word vectors - to zero vectors; can be any function that takes in a Tensor and - returns a Tensor of the same size. Default: torch.Tensor.zero_ - expand_vocab (bool): If True, expand vocabulary to include all - words for which the specified pretrained word vectors are - available. Default: False + vectors: one of either the available pretrained vectors + or custom pretrained vectors (see Vocab.load_vectors); + or a list of aforementioned vectors """ self.freqs = counter.copy() min_freq = max(min_freq, 1) @@ -76,7 +70,7 @@ def __init__(self, counter, max_size=None, min_freq=1, specials=[''], self.vectors = None if vectors is not None: - self.load_vectors(vectors, unk_init=unk_init, expand_vocab=expand_vocab) + self.load_vectors(vectors) def __eq__(self, other): if self.freqs != other.freqs: @@ -92,50 +86,19 @@ def __eq__(self, other): def __len__(self): return len(self.itos) - def load_vectors(self, vectors, unk_init=torch.Tensor.zero_, expand_vocab=False): + def extend(self, v, sort=True): + words = sorted(v.itos) if sort else v.itos + for w in words: + self.itos.append(w) + self.stoi[w] = len(self.itos) - 1 + + def load_vectors(self, vectors): """Arguments: - vectors: one of the available pretrained vectors or a list with each - element one of the available pretrained vectors: - charngram.100d - fasttext.en.300d - fasttext.simple.300d - glove.42B.300d - glove.840B.300d - glove.twitter.27B.25d - glove.twitter.27B.50d - glove.twitter.27B.100d - glove.twitter.27B.200d - glove.6B.50d - glove.6B.100d - glove.6B.200d - glove.6B.300d - unk_init (callback): by default, initialize out-of-vocabulary word vectors - to zero vectors; can be any function that takes in a Tensor and - returns a Tensor of the same size. 
Default: torch.Tensor.zero_ - expand_vocab (bool): expand vocabulary to include all words for which - the specified pretrained word vectors are available + vectors: one of or a list containing instantiations of the + GloVe, CharNGram, or Vectors classes """ if not isinstance(vectors, list): vectors = [vectors] - vecs = [] - tot_dim = 0 - for v in vectors: - wv_type, _, rest = v.partition('.') - rest, _, wv_dim = rest.rpartition('.') - wv_dim = int(wv_dim[:-1]) - if wv_type == 'glove': - wv_name = rest - vecs.append(GloVe(name=wv_name, dim=wv_dim, unk_init=unk_init)) - if expand_vocab: - for w in sorted(vecs[-1].stoi.keys()): - self.itos.append(w) - self.stoi[w] = len(self.itos) - 1 - elif 'charngram' in v: - vecs.append(CharNGram(unk_init=unk_init)) - elif 'fasttext' in v: - wv_language = rest - vecs.append(FastText(language=wv_language, unk_init=unk_init)) - tot_dim += wv_dim self.vectors = torch.Tensor(len(self), tot_dim) for i, token in enumerate(self.itos): @@ -174,6 +137,11 @@ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): class Vectors(object): def __init__(self, unk_init=torch.Tensor.zero_): + """Arguments: + unk_init (callback): by default, initalize out-of-vocabulary word vectors + to zero vectors; can be any function that takes in a Tensor and + returns a Tensor of the same size + """ self.unk_init = unk_init def __getitem__(self, token): @@ -182,7 +150,7 @@ def __getitem__(self, token): else: return self.unk_init(torch.Tensor(1, self.dim)) - def vector_cache(self, url, root, fname): + def vector_cache(self, fname, root='.vector_cache', url=None): desc = fname fname = os.path.join(root, fname) fname_pt = fname + '.pt' @@ -190,11 +158,11 @@ def vector_cache(self, url, root, fname): desc = os.path.basename(fname) if not os.path.isfile(fname_pt): - dest = os.path.join(root, os.path.basename(url)) - if not os.path.isfile(fname_txt): + if not os.path.isfile(fname_txt) and url: logger.info('Downloading vectors from {}'.format(url)) if not os.path.exists(root): os.makedirs(root) + dest = os.path.join(root, os.path.basename(url)) with tqdm(unit='B', unit_scale=True, miniters=1, desc=desc) as t: urlretrieve(url, dest, reporthook=reporthook(t)) logger.info('Extracting vectors into {}'.format(root)) @@ -211,7 +179,7 @@ def vector_cache(self, url, root, fname): else: raise RuntimeError('unsupported compression format {}'.format(ext)) if not os.path.isfile(fname_txt): - raise RuntimeError('no vectors found') + raise RuntimeError('no vectors found at {}'.format(fname_txt)) # str call is necessary for Python 2/3 compatibility, since # argument must be Python 2 str (Python 3 bytes) or @@ -224,7 +192,7 @@ def vector_cache(self, url, root, fname): with io.open(fname_txt, encoding="utf8") as f: lines = [line for line in f] # If there are malformed lines, read in binary mode - # and manually decode each word form utf-8 + # and manually decode each word from utf-8 except: logger.warning("Could not read {} as UTF8 file, " "reading file as bytes and skipping " From 4d63e0a2c4c2b957339db2eb8c3dd5020b7dce5b Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Sat, 9 Sep 2017 10:46:48 +0000 Subject: [PATCH 2/5] simplifying Vector inheritance --- test/test_vocab.py | 7 ++-- torchtext/vocab.py | 84 +++++++++++++++++++++------------------------- 2 files changed, 43 insertions(+), 48 deletions(-) diff --git a/test/test_vocab.py b/test/test_vocab.py index a7e3c498c2..67436a4a3a 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -9,6 +9,7 @@ from numpy.testing import 
assert_allclose import torch from torchtext import vocab +from vocab import Vectors, FastText, GloVe, CharNGram from .common.test_markers import slow from .common.torchtext_test_case import TorchtextTestCase @@ -44,7 +45,7 @@ def test_vocab_download_fasttext_vectors(self): # Build a vocab and get vectors twice to test caching. for i in range(2): v = vocab.Vocab(c, min_freq=3, specials=['', ''], - vectors='fasttext.simple.300d') + vectors=FastText(language='simple') expected_itos = ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] @@ -76,7 +77,7 @@ def test_vocab_download_glove_vectors(self): # Build a vocab and get vectors twice to test caching. for i in range(2): v = vocab.Vocab(c, min_freq=3, specials=['', ''], - vectors='glove.twitter.27B.25d') + vectors=GloVe(name='twitter.27B', dim='25')) expected_itos = ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] @@ -114,7 +115,7 @@ def test_vocab_download_charngram_vectors(self): # Build a vocab and get vectors twice to test caching. for i in range(2): v = vocab.Vocab(c, min_freq=3, specials=['', ''], - vectors='charngram.100d') + vectors=CharNGram()) expected_itos = ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} diff --git a/torchtext/vocab.py b/torchtext/vocab.py index dbfdfa491e..42aa547a4e 100644 --- a/torchtext/vocab.py +++ b/torchtext/vocab.py @@ -136,13 +136,17 @@ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): class Vectors(object): - def __init__(self, unk_init=torch.Tensor.zero_): + def __init__(self, name, cache='.vector_cache', url=None, unk_init=torch.Tensor.zero_): """Arguments: + name: name of the file that contains the vectors + cache: directory for cached vectors + url: url for download if vectors not found in cache unk_init (callback): by default, initalize out-of-vocabulary word vectors to zero vectors; can be any function that takes in a Tensor and returns a Tensor of the same size """ self.unk_init = unk_init + self.cache(name, cache, url=url) def __getitem__(self, token): if token in self.stoi: @@ -150,36 +154,30 @@ def __getitem__(self, token): else: return self.unk_init(torch.Tensor(1, self.dim)) - def vector_cache(self, fname, root='.vector_cache', url=None): - desc = fname - fname = os.path.join(root, fname) - fname_pt = fname + '.pt' - fname_txt = fname + '.txt' - desc = os.path.basename(fname) + def cache(self, name, cache, url=None): + path = os.join([cache, name]) + path_pt = path + '.pt' - if not os.path.isfile(fname_pt): - if not os.path.isfile(fname_txt) and url: + if not os.path.isfile(path_pt): + if not os.path.isfile(path) and url: logger.info('Downloading vectors from {}'.format(url)) - if not os.path.exists(root): - os.makedirs(root) - dest = os.path.join(root, os.path.basename(url)) + if not os.path.exists(cache): + os.makedirs(cache) + dest = os.path.join(cache, os.path.basename(url)) with tqdm(unit='B', unit_scale=True, miniters=1, desc=desc) as t: urlretrieve(url, dest, reporthook=reporthook(t)) - logger.info('Extracting vectors into {}'.format(root)) + logger.info('Extracting vectors into {}'.format(cache)) ext = os.path.splitext(dest)[1][1:] if ext == 'zip': with zipfile.ZipFile(dest, "r") as zf: - zf.extractall(root) + zf.extractall(cache) elif ext == 'gz': with tarfile.open(dest, 'r:gz') as tar: - tar.extractall(path=root) - elif ext == 'vec' or ext == 'txt': - if dest != fname_txt: - shutil.copy(dest, fname_txt) - else: - raise RuntimeError('unsupported compression format {}'.format(ext)) - if not 
os.path.isfile(fname_txt): - raise RuntimeError('no vectors found at {}'.format(fname_txt)) + tar.extractall(path=cache) + elif dest != path: + shutil.copy(dest, path) + if not os.path.isfile(path): + raise RuntimeError('no vectors found at {}'.format(path)) # str call is necessary for Python 2/3 compatibility, since # argument must be Python 2 str (Python 3 bytes) or @@ -189,19 +187,19 @@ def vector_cache(self, fname, root='.vector_cache', url=None): # Try to read the whole file with utf-8 encoding. binary_lines = False try: - with io.open(fname_txt, encoding="utf8") as f: + with io.open(path, encoding="utf8") as f: lines = [line for line in f] # If there are malformed lines, read in binary mode # and manually decode each word from utf-8 except: logger.warning("Could not read {} as UTF8 file, " "reading file as bytes and skipping " - "words with malformed UTF8.".format(fname_txt)) - with open(fname_txt, 'rb') as f: + "words with malformed UTF8.".format(path)) + with open(path, 'rb') as f: lines = [line for line in f] binary_lines = True - logger.info("Loading vectors from {}".format(fname_txt)) + logger.info("Loading vectors from {}".format(path)) for line in tqdm(lines, total=len(lines)): # Explicitly splitting on " " is important, so we don't # get rid of Unicode non-breaking spaces in the vectors. @@ -232,15 +230,14 @@ def vector_cache(self, fname, root='.vector_cache', url=None): self.stoi = {word: i for i, word in enumerate(itos)} self.vectors = torch.Tensor(vectors).view(-1, dim) self.dim = dim - logger.info('Saving vectors to {}'.format(fname_pt)) - torch.save((self.stoi, self.vectors, self.dim), fname_pt) + logger.info('Saving vectors to {}'.format(path_pt)) + torch.save((self.stoi, self.vectors, self.dim), path_pt) else: - logger.info('Loading vectors from {}'.format(fname_pt)) - self.stoi, self.vectors, self.dim = torch.load(fname_pt) + logger.info('Loading vectors from {}'.format(path_pt)) + self.stoi, self.vectors, self.dim = torch.load(path_pt) class GloVe(Vectors): - url = { 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', @@ -248,13 +245,12 @@ class GloVe(Vectors): 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', } - def __init__(self, root='.vector_cache', name='840B', dim=300, **kwargs): - super(GloVe, self).__init__(**kwargs) - dim = str(dim) + 'd' - name = '.'.join(['glove', name]) - fname = name + '.' 
+ dim - self.vector_cache(self.url[name], root, fname) - + def __init__(self, name='840B', dim=300, **kwargs): + name = 'glove.{}'.format(name) + url = self.url[name] + name = '{}.{}'.format(name, str(dim)+'d') + super(GloVe, self).__init__(name, url=url, **kwargs) + class FastText(Vectors): url = { @@ -264,21 +260,19 @@ class FastText(Vectors): 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec' } - def __init__(self, root='.vector_cache', language="en", **kwargs): - super(FastText, self).__init__(**kwargs) + def __init__(self, language="en", **kwargs): name = "fasttext.{}.300d".format(language) - self.vector_cache(self.url[name], root, name) + super(FastText, self).__init__(url=self.url[name], **kwargs) class CharNGram(Vectors): + name = 'charNgram.jmt.100d' url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 'jmt_pre-trained_embeddings.tar.gz') - filename = 'charNgram' - def __init__(self, root='.vector_cache', **kwargs): - super(CharNGram, self).__init__(**kwargs) - self.vector_cache(self.url, root, self.filename) + def __init__(self, **kwargs): + super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) def __getitem__(self, token): vector = torch.Tensor(1, self.dim).zero_() From 423238e0d3d117f3802f9ce7361227ac906f9435 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Sat, 9 Sep 2017 15:25:25 +0000 Subject: [PATCH 3/5] more simplification and testing --- test/imdb.py | 4 +-- test/sst.py | 23 ++++++++++++++-- test/test_vocab.py | 68 +++++++++++++++++++++++++++++++++++++++++++--- test/trec.py | 6 ++-- torchtext/vocab.py | 63 ++++++++++++++++++++---------------------- 5 files changed, 119 insertions(+), 45 deletions(-) diff --git a/test/imdb.py b/test/imdb.py index 595c0de488..dc7f2c3589 100644 --- a/test/imdb.py +++ b/test/imdb.py @@ -1,5 +1,6 @@ from torchtext import data from torchtext import datasets +from torchtext.vocab import GloVe, CharNGram, Vectors # Approach 1: @@ -17,8 +18,7 @@ print('vars(train[0])', vars(train[0])) # build the vocabulary - -TEXT.build_vocab(train, vectors='glove.6B.300d') +TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) LABEL.build_vocab(train) # print vocab information diff --git a/test/sst.py b/test/sst.py index 827fe59d90..aabdad78fc 100644 --- a/test/sst.py +++ b/test/sst.py @@ -1,5 +1,6 @@ from torchtext import data from torchtext import datasets +from torchtext.vocab import Vectors, GloVe, CharNGram, FastText # Approach 1: @@ -18,8 +19,8 @@ print('vars(train[0])', vars(train[0])) # build the vocabulary - -TEXT.build_vocab(train, vectors='glove.6B.300d') +TEXT.build_vocab(train, vectors=Vectors('wiki.simple.vec', + url='https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec')) LABEL.build_vocab(train) # print vocab information @@ -36,13 +37,29 @@ print(batch.label) # Approach 2: -TEXT.build_vocab(train, vectors=['glove.840B.300d', 'charngram.100d']) +TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram(), FastText()]) LABEL.build_vocab(train) # print vocab information print('len(TEXT.vocab)', len(TEXT.vocab)) print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) +train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4) + +# print batch information +batch = next(iter(train_iter)) +print(batch.text) +print(batch.label) + +# Approach 3: +f = FastText() +TEXT.build_vocab(train, vectors=f) +TEXT.vocab.extend(f) +LABEL.build_vocab(train) + +# print vocab information +print('len(TEXT.vocab)', len(TEXT.vocab)) 
+print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4) diff --git a/test/test_vocab.py b/test/test_vocab.py index 67436a4a3a..4731a58fb3 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -9,7 +9,7 @@ from numpy.testing import assert_allclose import torch from torchtext import vocab -from vocab import Vectors, FastText, GloVe, CharNGram +from torchtext.vocab import Vectors, FastText, GloVe, CharNGram from .common.test_markers import slow from .common.torchtext_test_case import TorchtextTestCase @@ -45,7 +45,7 @@ def test_vocab_download_fasttext_vectors(self): # Build a vocab and get vectors twice to test caching. for i in range(2): v = vocab.Vocab(c, min_freq=3, specials=['', ''], - vectors=FastText(language='simple') + vectors=FastText(language='simple')) expected_itos = ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] @@ -71,6 +71,66 @@ def test_vocab_download_fasttext_vectors(self): os.remove(os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec")) + def test_vocab_extend(self): + c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) + # Build a vocab and get vectors twice to test caching. + for i in range(2): + f = FastText(language='simple') + v = vocab.Vocab(c, min_freq=3, specials=['', ''], + vectors=f) + n_vocab = len(v) + v.extend(f) # extend the vocab with the words contained in f.itos + self.assertGreater(len(v), n_vocab) + + self.assertEqual(v.itos[:6], ['', '', '', + 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) + vectors = v.vectors.numpy() + + # The first 5 entries in each vector. + expected_fasttext_simple_en = { + 'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645], + 'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165], + } + + for word in expected_fasttext_simple_en: + assert_allclose(vectors[v.stoi[word], :5], + expected_fasttext_simple_en[word]) + + assert_allclose(vectors[v.stoi['']], np.zeros(300)) + # Delete the vectors after we're done to save disk space on CI + if os.environ["TRAVIS"] == "true": + os.remove(os.path.join(self.project_root, ".vector_cache", + "wiki.simple.vec")) + + def test_vocab_download_custom_vectors(self): + c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) + # Build a vocab and get vectors twice to test caching. + for i in range(2): + v = vocab.Vocab(c, min_freq=3, specials=['', ''], + vectors=Vectors('wiki.simple.vec', + url=FastText.url_base.format('simple'))) + + self.assertEqual(v.itos, ['', '', '', + 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) + vectors = v.vectors.numpy() + + # The first 5 entries in each vector. 
+ expected_fasttext_simple_en = { + 'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645], + 'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165], + } + + for word in expected_fasttext_simple_en: + assert_allclose(vectors[v.stoi[word], :5], + expected_fasttext_simple_en[word]) + + assert_allclose(vectors[v.stoi['']], np.zeros(300)) + # Delete the vectors after we're done to save disk space on CI + if os.environ["TRAVIS"] == "true": + os.remove(os.path.join(self.project_root, ".vector_cache", + "wiki.simple.vec")) + + @slow def test_vocab_download_glove_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -105,7 +165,7 @@ def test_vocab_download_glove_vectors(self): "glove.twitter.27B.zip")) for dim in ["25", "50", "100", "200"]: os.remove(os.path.join(self.project_root, ".vector_cache", - "glove.twitter.27B.{}d.txt".format(dim))) + "glove.twitter.27B.{}d".format(dim))) os.remove(os.path.join(self.project_root, ".vector_cache", "glove.twitter.27B.25d.pt")) @@ -139,7 +199,7 @@ def test_vocab_download_charngram_vectors(self): # Delete the vectors after we're done to save disk space on CI if os.environ.get("TRAVIS") == "true": os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.txt")) - os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.pt")) + os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.txt.pt")) os.remove(os.path.join(self.project_root, ".vector_cache", "jmt_pre-trained_embeddings.tar.gz")) diff --git a/test/trec.py b/test/trec.py index 68a3b8a023..d31dfbfd0f 100644 --- a/test/trec.py +++ b/test/trec.py @@ -1,5 +1,6 @@ from torchtext import data from torchtext import datasets +from torchtext.vocab import GloVe, CharNGram # Approach 1: @@ -17,8 +18,7 @@ print('vars(train[0])', vars(train[0])) # build the vocabulary - -TEXT.build_vocab(train, vectors='glove.6B.300d') +TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) LABEL.build_vocab(train) # print vocab information @@ -35,7 +35,7 @@ print(batch.label) # Approach 2: -TEXT.build_vocab(train, vectors=['glove.840B.300d', 'charngram.100d']) +TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram()]) LABEL.build_vocab(train) train_iter, test_iter = datasets.TREC.iters(batch_size=4) diff --git a/torchtext/vocab.py b/torchtext/vocab.py index 42aa547a4e..3e646d4814 100644 --- a/torchtext/vocab.py +++ b/torchtext/vocab.py @@ -100,12 +100,13 @@ def load_vectors(self, vectors): if not isinstance(vectors, list): vectors = [vectors] + tot_dim = sum(v.dim for v in vectors) self.vectors = torch.Tensor(len(self), tot_dim) for i, token in enumerate(self.itos): start_dim = 0 - for j, v in enumerate(vectors): - end_dim = start_dim + vecs[j].dim - self.vectors[i][start_dim:end_dim] = vecs[j][token] + for v in vectors: + end_dim = start_dim + v.dim + self.vectors[i][start_dim:end_dim] = v[token] start_dim = end_dim assert(start_dim == tot_dim) @@ -137,14 +138,14 @@ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): class Vectors(object): def __init__(self, name, cache='.vector_cache', url=None, unk_init=torch.Tensor.zero_): - """Arguments: - name: name of the file that contains the vectors - cache: directory for cached vectors - url: url for download if vectors not found in cache - unk_init (callback): by default, initalize out-of-vocabulary word vectors - to zero vectors; can be any function that takes in a Tensor and - returns a Tensor of the same size - """ + """Arguments: + name: name of the file 
that contains the vectors + cache: directory for cached vectors + url: url for download if vectors not found in cache + unk_init (callback): by default, initalize out-of-vocabulary word vectors + to zero vectors; can be any function that takes in a Tensor and + returns a Tensor of the same size + """ self.unk_init = unk_init self.cache(name, cache, url=url) @@ -155,7 +156,7 @@ def __getitem__(self, token): return self.unk_init(torch.Tensor(1, self.dim)) def cache(self, name, cache, url=None): - path = os.join([cache, name]) + path = os.path.join(cache, name) path_pt = path + '.pt' if not os.path.isfile(path_pt): @@ -164,8 +165,9 @@ def cache(self, name, cache, url=None): if not os.path.exists(cache): os.makedirs(cache) dest = os.path.join(cache, os.path.basename(url)) - with tqdm(unit='B', unit_scale=True, miniters=1, desc=desc) as t: - urlretrieve(url, dest, reporthook=reporthook(t)) + if not os.path.isfile(dest): + with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: + urlretrieve(url, dest, reporthook=reporthook(t)) logger.info('Extracting vectors into {}'.format(cache)) ext = os.path.splitext(dest)[1][1:] if ext == 'zip': @@ -174,8 +176,6 @@ def cache(self, name, cache, url=None): elif ext == 'gz': with tarfile.open(dest, 'r:gz') as tar: tar.extractall(path=cache) - elif dest != path: - shutil.copy(dest, path) if not os.path.isfile(path): raise RuntimeError('no vectors found at {}'.format(path)) @@ -227,47 +227,44 @@ def cache(self, name, cache, url=None): vectors.extend(float(x) for x in entries) itos.append(word) + self.itos = itos self.stoi = {word: i for i, word in enumerate(itos)} self.vectors = torch.Tensor(vectors).view(-1, dim) self.dim = dim logger.info('Saving vectors to {}'.format(path_pt)) - torch.save((self.stoi, self.vectors, self.dim), path_pt) + torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) else: logger.info('Loading vectors from {}'.format(path_pt)) - self.stoi, self.vectors, self.dim = torch.load(path_pt) + self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) class GloVe(Vectors): url = { - 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', - 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', - 'glove.twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', - 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', + '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', + '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', + 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', + '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', } def __init__(self, name='840B', dim=300, **kwargs): - name = 'glove.{}'.format(name) url = self.url[name] - name = '{}.{}'.format(name, str(dim)+'d') + name = 'glove.{}.{}d.txt'.format(name, str(dim)) super(GloVe, self).__init__(name, url=url, **kwargs) class FastText(Vectors): - url = { - 'fasttext.en.300d': - 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec', - 'fasttext.simple.300d': - 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec' - } + + url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' def __init__(self, language="en", **kwargs): - name = "fasttext.{}.300d".format(language) - super(FastText, self).__init__(url=self.url[name], **kwargs) + url = self.url_base.format(language) + name = os.path.basename(url) + super(FastText, self).__init__(name, url=url, **kwargs) class CharNGram(Vectors): - name = 'charNgram.jmt.100d' + name = 'charNgram.txt' url = 
('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 'jmt_pre-trained_embeddings.tar.gz') From cfb3d7d5d92d80368fd29c5df220e481d8ab745d Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Sat, 9 Sep 2017 18:54:11 +0000 Subject: [PATCH 4/5] bug in extend --- test/test_vocab.py | 2 +- torchtext/vocab.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_vocab.py b/test/test_vocab.py index 4731a58fb3..6ab86b3735 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -165,7 +165,7 @@ def test_vocab_download_glove_vectors(self): "glove.twitter.27B.zip")) for dim in ["25", "50", "100", "200"]: os.remove(os.path.join(self.project_root, ".vector_cache", - "glove.twitter.27B.{}d".format(dim))) + "glove.twitter.27B.{}d.txt".format(dim))) os.remove(os.path.join(self.project_root, ".vector_cache", "glove.twitter.27B.25d.pt")) diff --git a/torchtext/vocab.py b/torchtext/vocab.py index 3e646d4814..2db2c13a5f 100644 --- a/torchtext/vocab.py +++ b/torchtext/vocab.py @@ -86,11 +86,12 @@ def __eq__(self, other): def __len__(self): return len(self.itos) - def extend(self, v, sort=True): + def extend(self, v, sort=False): words = sorted(v.itos) if sort else v.itos for w in words: - self.itos.append(w) - self.stoi[w] = len(self.itos) - 1 + if w not in self.stoi: + self.itos.append(w) + self.stoi[w] = len(self.itos) - 1 def load_vectors(self, vectors): """Arguments: From eea3d33499f16088ea3c9449f1950405673cbb3e Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Sat, 9 Sep 2017 19:43:10 +0000 Subject: [PATCH 5/5] flake8 and tests --- test/imdb.py | 2 +- test/sst.py | 4 ++-- test/test_vocab.py | 48 +++++++++++++++++++++++++--------------------- torchtext/vocab.py | 6 +++--- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/test/imdb.py b/test/imdb.py index dc7f2c3589..87b74d8e17 100644 --- a/test/imdb.py +++ b/test/imdb.py @@ -1,6 +1,6 @@ from torchtext import data from torchtext import datasets -from torchtext.vocab import GloVe, CharNGram, Vectors +from torchtext.vocab import GloVe # Approach 1: diff --git a/test/sst.py b/test/sst.py index aabdad78fc..33c369e558 100644 --- a/test/sst.py +++ b/test/sst.py @@ -19,8 +19,8 @@ print('vars(train[0])', vars(train[0])) # build the vocabulary -TEXT.build_vocab(train, vectors=Vectors('wiki.simple.vec', - url='https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec')) +url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec' +TEXT.build_vocab(train, vectors=Vectors('wiki.simple.vec', url=url)) LABEL.build_vocab(train) # print vocab information diff --git a/test/test_vocab.py b/test/test_vocab.py index 6ab86b3735..0976dcc77b 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -15,6 +15,11 @@ from .common.torchtext_test_case import TorchtextTestCase +def conditional_remove(f): + if os.path.isfile(f): + os.remove(f) + + class TestVocab(TorchtextTestCase): def test_vocab_basic(self): @@ -68,8 +73,8 @@ def test_vocab_download_fasttext_vectors(self): assert_allclose(vectors[v.stoi['OOV token']], np.zeros(300)) # Delete the vectors after we're done to save disk space on CI if os.environ.get("TRAVIS") == "true": - os.remove(os.path.join(self.project_root, ".vector_cache", - "wiki.simple.vec")) + vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") + conditional_remove(vec_file) def test_vocab_extend(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -79,11 +84,11 @@ def test_vocab_extend(self): v = 
vocab.Vocab(c, min_freq=3, specials=['', ''], vectors=f) n_vocab = len(v) - v.extend(f) # extend the vocab with the words contained in f.itos + v.extend(f) # extend the vocab with the words contained in f.itos self.assertGreater(len(v), n_vocab) self.assertEqual(v.itos[:6], ['', '', '', - 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) + 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) vectors = v.vectors.numpy() # The first 5 entries in each vector. @@ -98,16 +103,16 @@ def test_vocab_extend(self): assert_allclose(vectors[v.stoi['']], np.zeros(300)) # Delete the vectors after we're done to save disk space on CI - if os.environ["TRAVIS"] == "true": - os.remove(os.path.join(self.project_root, ".vector_cache", - "wiki.simple.vec")) + if os.environ.get("TRAVIS") == "true": + vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") + conditional_remove(vec_file) def test_vocab_download_custom_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) # Build a vocab and get vectors twice to test caching. for i in range(2): v = vocab.Vocab(c, min_freq=3, specials=['', ''], - vectors=Vectors('wiki.simple.vec', + vectors=Vectors('wiki.simple.vec', url=FastText.url_base.format('simple'))) self.assertEqual(v.itos, ['', '', '', @@ -126,10 +131,9 @@ def test_vocab_download_custom_vectors(self): assert_allclose(vectors[v.stoi['']], np.zeros(300)) # Delete the vectors after we're done to save disk space on CI - if os.environ["TRAVIS"] == "true": - os.remove(os.path.join(self.project_root, ".vector_cache", - "wiki.simple.vec")) - + if os.environ.get("TRAVIS") == "true": + vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") + conditional_remove(vec_file) @slow def test_vocab_download_glove_vectors(self): @@ -161,13 +165,12 @@ def test_vocab_download_glove_vectors(self): assert_allclose(vectors[v.stoi['OOV token']], np.zeros(25)) # Delete the vectors after we're done to save disk space on CI if os.environ.get("TRAVIS") == "true": - os.remove(os.path.join(self.project_root, ".vector_cache", - "glove.twitter.27B.zip")) + zip_file = os.path.join(self.project_root, ".vector_cache", + "glove.twitter.27B.zip") + conditional_remove(zip_file) for dim in ["25", "50", "100", "200"]: - os.remove(os.path.join(self.project_root, ".vector_cache", - "glove.twitter.27B.{}d.txt".format(dim))) - os.remove(os.path.join(self.project_root, ".vector_cache", - "glove.twitter.27B.25d.pt")) + conditional_remove(os.path.join(self.project_root, ".vector_cache", + "glove.twitter.27B.{}d.txt".format(dim))) @slow def test_vocab_download_charngram_vectors(self): @@ -198,10 +201,11 @@ def test_vocab_download_charngram_vectors(self): assert_allclose(vectors[v.stoi['OOV token']], np.zeros(100)) # Delete the vectors after we're done to save disk space on CI if os.environ.get("TRAVIS") == "true": - os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.txt")) - os.remove(os.path.join(self.project_root, ".vector_cache", "charNgram.txt.pt")) - os.remove(os.path.join(self.project_root, ".vector_cache", - "jmt_pre-trained_embeddings.tar.gz")) + conditional_remove( + os.path.join(self.project_root, ".vector_cache", "charNgram.txt")) + conditional_remove( + os.path.join(self.project_root, ".vector_cache", + "jmt_pre-trained_embeddings.tar.gz")) def test_serialization(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) diff --git a/torchtext/vocab.py b/torchtext/vocab.py index 2db2c13a5f..e3e5b1eca0 100644 --- a/torchtext/vocab.py +++ 
b/torchtext/vocab.py @@ -4,7 +4,6 @@ import io import logging import os -import shutil import zipfile import six @@ -138,7 +137,8 @@ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): class Vectors(object): - def __init__(self, name, cache='.vector_cache', url=None, unk_init=torch.Tensor.zero_): + def __init__(self, name, cache='.vector_cache', + url=None, unk_init=torch.Tensor.zero_): """Arguments: name: name of the file that contains the vectors cache: directory for cached vectors @@ -251,7 +251,7 @@ def __init__(self, name='840B', dim=300, **kwargs): url = self.url[name] name = 'glove.{}.{}d.txt'.format(name, str(dim)) super(GloVe, self).__init__(name, url=url, **kwargs) - + class FastText(Vectors):
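
For reference, a minimal usage sketch of the API this series introduces (custom vector files, built-in vector classes, and Vocab.extend). The counter contents and the custom file name below are illustrative assumptions only; the calls themselves follow the patched torchtext.vocab module.

    from collections import Counter
    from torchtext.vocab import Vocab, Vectors, GloVe, CharNGram

    counter = Counter({'hello': 4, 'world': 3})

    # Built-in pretrained vectors; passing a list concatenates their
    # dimensions per token (downloads to .vector_cache on first use).
    v = Vocab(counter, vectors=[GloVe(name='6B', dim=100), CharNGram()])

    # Custom pretrained vectors: `name` is a file under the cache directory,
    # and `url` is only consulted when that file is missing. The file name
    # here is hypothetical and assumed to already sit in .vector_cache.
    custom = Vectors('my_vectors.txt', cache='.vector_cache')

    # Grow the vocabulary with every word the custom file covers, then
    # rebuild the vector table for the enlarged vocabulary.
    v.extend(custom)
    v.load_vectors(custom)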