Commit c56bfbd

mthrok authored and facebook-github-bot committed
Import torchtext #1325 57a1df3
Reviewed By: NicolasHug
Differential Revision: D28994054
fbshipit-source-id: 4c679f56ef37b18f6d2acaaaed8518facbeaa41c
1 parent e9d7593 commit c56bfbd

8 files changed: +47 -52 lines changed

README.rst

Lines changed: 17 additions & 6 deletions
@@ -15,20 +15,22 @@ This repository consists of:
 * `torchtext.datasets <https://github.com/pytorch/text/tree/master/torchtext/datasets>`_: The raw text iterators for common NLP datasets
 * `torchtext.data <https://github.com/pytorch/text/tree/master/torchtext/data>`_: Some basic NLP building blocks (tokenizers, metrics, functionals etc.)
 * `torchtext.nn <https://github.com/pytorch/text/tree/master/torchtext/nn>`_: NLP related modules
+* `torchtext.vocab <https://github.com/pytorch/text/tree/master/torchtext/vocab.py>`_: Vocab and Vectors related classes and factory functions
 * `examples <https://github.com/pytorch/text/tree/master/examples>`_: Example NLP workflows with PyTorch and torchtext library.

-Note: the legacy code discussed in `torchtext v0.7.0 release note <https://github.com/pytorch/text/releases/tag/v0.7.0-rc3>`_ has been retired to `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_ folder. Those legacy code will not be maintained by the development team, and we plan to fully remove them in the future release. See `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_ folder for more details.
+Note: The legacy code discussed in `torchtext v0.7.0 release note <https://github.com/pytorch/text/releases/tag/v0.7.0-rc3>`_ has been retired to `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_ folder. Those legacy code will not be maintained by the development team, and we plan to fully remove them in the future release. See `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_ folder for more details.

 Installation
 ============

-We recommend Anaconda as Python package management system. Please refer to `pytorch.org <https://pytorch.org/>`_ for the detail of PyTorch installation. The following is the corresponding ``torchtext`` versions and supported Python versions.
+We recommend Anaconda as a Python package management system. Please refer to `pytorch.org <https://pytorch.org/>`_ for the details of PyTorch installation. The following are the corresponding ``torchtext`` versions and supported Python versions.

 .. csv-table:: Version Compatibility
    :header: "PyTorch version", "torchtext version", "Supported Python version"
    :widths: 10, 10, 10

    nightly build, master, 3.6+
+   1.9, 0.10, 3.6+
    1.8, 0.9, 3.6+
    1.7, 0.8, 3.6+
    1.6, 0.7, 3.6+
@@ -93,7 +95,7 @@ Datasets
 The datasets module currently contains:

 * Language modeling: WikiText2, WikiText103, PennTreebank, EnWik9
-* Machine translation: IWSLT2016, IWSLT2017
+* Machine translation: IWSLT2016, IWSLT2017, Multi30k
 * Sequence tagging (e.g. POS/NER): UDPOS, CoNLL2000Chunking
 * Question answering: SQuAD1, SQuAD2
 * Text classification: AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, YelpReviewFull, YahooAnswers, AmazonReviewPolarity, AmazonReviewFull, IMDB
@@ -113,15 +115,22 @@ For example, to access the raw text from the AG_NEWS dataset:
       >>> train_iter = AG_NEWS(split='train')
       >>> dataloader = DataLoader(train_iter, batch_size=8, shuffle=False)

-A tutorial for the end-to-end text classification workflow can be found in `PyTorch tutorial <https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html>`_
+Tutorials
+=========
+
+To get started with torchtext, users may refer to the following tutorials available on PyTorch website.
+
+* `Text classification with AG_NEWS dataset <https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html>`_
+* `Translation trained with Multi30k dataset using transformers and torchtext <https://pytorch.org/tutorials/beginner/translation_transformer.html>`_
+* `Language modeling using transforms and torchtext <https://pytorch.org/tutorials/beginner/transformer_tutorial.html>`_
+

 [Prototype] Experimental Code
 =============================

 We have re-written several building blocks under ``torchtext.experimental``:

 * `Transforms <https://github.com/pytorch/text/blob/master/torchtext/experimental/transforms.py>`_: some basic data processing building blocks
-* `Vocabulary <https://github.com/pytorch/text/blob/master/torchtext/experimental/vocab.py>`_: a vocabulary to numericalize tokens
 * `Vectors <https://github.com/pytorch/text/blob/master/torchtext/experimental/vectors.py>`_: the vectors to convert tokens into tensors.

 These prototype building blocks in the experimental folder are available in the nightly release only. The nightly packages are accessible via Pip and Conda for Windows, Mac, and Linux. For example, Linux users can install the nightly wheels with the following command::
@@ -133,7 +142,7 @@ For more detailed instructions, please refer to `Install PyTorch <https://pytorc
 [BC Breaking] Legacy
 ====================

-In v0.9.0 release, we move the following legacy code to `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_. This is part of the work to revamp the torchtext library and the motivation has been discussed in `Issue #664 <https://github.com/pytorch/text/issues/664>`_:
+In the v0.9.0 release, we moved the following legacy code to `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_. This is part of the work to revamp the torchtext library and the motivation has been discussed in `Issue #664 <https://github.com/pytorch/text/issues/664>`_:

 * ``torchtext.legacy.data.field``
 * ``torchtext.legacy.data.batch``
@@ -144,6 +153,8 @@ In v0.9.0 release, we move the following legacy code to `torchtext.legacy <https

 We have a `migration tutorial <https://colab.research.google.com/github/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb>`_ to help users switch to the torchtext datasets in ``v0.9.0`` release. For the users who still want the legacy components, they can add ``legacy`` to the import path.

+In the v0.10.0 release, we retire the Vocab class to `torchtext.legacy <https://github.com/pytorch/text/tree/master/torchtext/legacy>`_. Users can still access the legacy Vocab via ``torchtext.legacy.vocab``. This class has been replaced by a Vocab module that is backed by efficient C++ implementation and provides common functional APIs for NLP workflows.
+
 Disclaimer on Datasets
 ======================

test/data/test_builtin_datasets.py

Lines changed: 4 additions & 31 deletions
@@ -207,23 +207,14 @@ def test_next_method_dataset(self):

     def test_imdb(self):
         from torchtext.experimental.datasets import IMDB
-        from torchtext.legacy.vocab import Vocab
         # smoke test to ensure imdb works properly
         train_dataset, test_dataset = IMDB()
         self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
                                [13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
         self._helper_test_func(len(test_dataset), 25000, test_dataset[0][1][:10],
                                [13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20])

-        # Test API with a vocab input object
-        old_vocab = train_dataset.get_vocab()
-        new_vocab = Vocab(counter=old_vocab.freqs, max_size=2500)
-        new_train_data, new_test_data = IMDB(vocab=new_vocab)
-
         # Add test for the subset of the standard datasets
-        train_dataset = IMDB(split='train')
-        self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
-                               [13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
         train_iter, test_iter = torchtext.datasets.IMDB()
         self._helper_test_func(len(train_iter), 25000, next(train_iter)[1][:25], 'I rented I AM CURIOUS-YEL')
         self._helper_test_func(len(test_iter), 25000, next(test_iter)[1][:25], 'I love sci-fi and am will')
@@ -241,8 +232,8 @@ def test_iwslt2017(self):
         de_vocab, en_vocab = train_dataset.get_vocab()

         def assert_nth_pair_is_equal(n, expected_sentence_pair):
-            de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
-            en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
+            de_sentence = [de_vocab.lookup_token(index) for index in train_dataset[n][0]]
+            en_sentence = [en_vocab.lookup_token(index) for index in train_dataset[n][1]]

             expected_de_sentence, expected_en_sentence = expected_sentence_pair

@@ -267,8 +258,8 @@ def test_iwslt2016(self):
         de_vocab, en_vocab = train_dataset.get_vocab()

         def assert_nth_pair_is_equal(n, expected_sentence_pair):
-            de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
-            en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
+            de_sentence = [de_vocab.lookup_token(index) for index in train_dataset[n][0]]
+            en_sentence = [en_vocab.lookup_token(index) for index in train_dataset[n][1]]
             expected_de_sentence, expected_en_sentence = expected_sentence_pair

             self.assertEqual(de_sentence, expected_de_sentence)
@@ -462,7 +453,6 @@ def test_conll_sequence_tagging(self):

     def test_squad1(self):
         from torchtext.experimental.datasets import SQuAD1
-        from torchtext.legacy.vocab import Vocab
         # smoke test to ensure imdb works properly
         train_dataset, dev_dataset = SQuAD1()
         context, question, answers, ans_pos = train_dataset[100]
@@ -472,16 +462,8 @@ def test_squad1(self):
         self._helper_test_func(len(dev_dataset), 10570, (question, ans_pos[0]),
                                ([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16], [45, 48]))

-        # Test API with a vocab input object
-        old_vocab = train_dataset.get_vocab()
-        new_vocab = Vocab(counter=old_vocab.freqs, max_size=2500)
-        new_train_data, new_test_data = SQuAD1(vocab=new_vocab)
-
         # Add test for the subset of the standard datasets
         train_dataset = SQuAD1(split='train')
-        context, question, answers, ans_pos = train_dataset[100]
-        self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
-                               ([7, 24, 86, 52, 2], [72, 72]))
         train_iter, dev_iter = torchtext.datasets.SQuAD1()
         self._helper_test_func(len(train_iter), 87599, next(train_iter)[0][:50],
                                'Architecturally, the school has a Catholic charact')
@@ -491,7 +473,6 @@ def test_squad1(self):

     def test_squad2(self):
         from torchtext.experimental.datasets import SQuAD2
-        from torchtext.legacy.vocab import Vocab
         # smoke test to ensure imdb works properly
         train_dataset, dev_dataset = SQuAD2()
         context, question, answers, ans_pos = train_dataset[200]
@@ -501,16 +482,8 @@ def test_squad2(self):
         self._helper_test_func(len(dev_dataset), 11873, (question, ans_pos[0]),
                                ([41, 29, 2, 66, 17016, 30, 0, 1955, 16], [40, 46]))

-        # Test API with a vocab input object
-        old_vocab = train_dataset.get_vocab()
-        new_vocab = Vocab(counter=old_vocab.freqs, max_size=2500)
-        new_train_data, new_test_data = SQuAD2(vocab=new_vocab)
-
         # Add test for the subset of the standard datasets
         train_dataset = SQuAD2(split='train')
-        context, question, answers, ans_pos = train_dataset[200]
-        self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
-                               ([84, 50, 1421, 12, 5439], [9, 9]))
         train_iter, dev_iter = torchtext.datasets.SQuAD2()
         self._helper_test_func(len(train_iter), 130319, next(train_iter)[0][:50],
                                'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-Y')

torchtext/experimental/datasets/language_modeling.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
 import logging
 from torchtext.data.utils import get_tokenizer
-from torchtext.legacy.vocab import build_vocab_from_iterator
+from torchtext.vocab import build_vocab_from_iterator
 from torchtext import datasets as raw
 from torchtext.experimental.datasets import raw as experimental_raw
 from torchtext.data.datasets_utils import _check_default_set
@@ -15,7 +15,9 @@ def apply_transforms(data):
         for line in data:
             tokens = transforms(line)
             yield tokens
-    return build_vocab_from_iterator(apply_transforms(data), len(data))
+    vocab = build_vocab_from_iterator(apply_transforms(data), specials=['<unk>', '<pad>'])
+    vocab.set_default_index(vocab['<unk>'])
+    return vocab


 class LanguageModelingDataset(torch.utils.data.Dataset):

torchtext/experimental/datasets/question_answer.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
 import logging
 from torchtext.data.utils import get_tokenizer
-from torchtext.legacy.vocab import build_vocab_from_iterator
+from torchtext.vocab import build_vocab_from_iterator
 from torchtext import datasets as raw
 from torchtext.data.datasets_utils import _check_default_set
 from torchtext.data.datasets_utils import _wrap_datasets
@@ -81,7 +81,8 @@ def apply_transform(data):
                     tok_ans += text_transform(item)
                 yield text_transform(_context) + text_transform(_question) + tok_ans
         logger_.info('Building Vocab based on train data')
-        vocab = build_vocab_from_iterator(apply_transform(raw_data['train']), len(raw_data['train']))
+        vocab = build_vocab_from_iterator(apply_transform(raw_data['train']), specials=['<unk>', '<pad>'])
+        vocab.set_default_index(vocab['<unk>'])
     logger_.info('Vocab has %d entries', len(vocab))
     text_transform = sequential_transforms(text_transform, vocab_func(vocab), totensor(dtype=torch.long))
     transforms = {'context': text_transform, 'question': text_transform,

torchtext/experimental/datasets/sequence_tagging.py

Lines changed: 4 additions & 2 deletions
@@ -3,7 +3,7 @@
 from torchtext.data.datasets_utils import _check_default_set
 from torchtext.data.datasets_utils import _wrap_datasets
 from torchtext import datasets as raw
-from torchtext.legacy.vocab import build_vocab_from_iterator
+from torchtext.vocab import build_vocab_from_iterator
 from torchtext.experimental.functional import (
     vocab_func,
     totensor,
@@ -22,7 +22,9 @@ def build_vocab(data):
         for idx, col in enumerate(line):
             data_list[idx].append(col)
     for it in data_list:
-        vocabs.append(build_vocab_from_iterator(it, len(it)))
+        vocab = build_vocab_from_iterator(it, specials=['<unk>', '<pad>'])
+        vocab.set_default_index(vocab['<unk>'])
+        vocabs.append(vocab)

     return vocabs

torchtext/experimental/datasets/text_classification.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
 import logging
 from torchtext.data.utils import get_tokenizer
-from torchtext.legacy.vocab import build_vocab_from_iterator
+from torchtext.vocab import build_vocab_from_iterator
 from torchtext import datasets as raw
 from torchtext.data.datasets_utils import _check_default_set
 from torchtext.data.datasets_utils import _wrap_datasets
@@ -19,7 +19,9 @@ def build_vocab(data, transforms):
     def apply_transforms(data):
         for _, line in data:
             yield transforms(line)
-    return build_vocab_from_iterator(apply_transforms(data), len(data))
+    vocab = build_vocab_from_iterator(apply_transforms(data), specials=['<unk>', '<pad>'])
+    vocab.set_default_index(vocab['<unk>'])
+    return vocab


 class TextClassificationDataset(torch.utils.data.Dataset):

torchtext/experimental/datasets/translation.py

Lines changed: 4 additions & 2 deletions
@@ -4,7 +4,7 @@
 from torchtext.data.datasets_utils import _wrap_datasets
 from torchtext import datasets as raw
 from torchtext.experimental.datasets import raw as experimental_raw
-from torchtext.legacy.vocab import Vocab, build_vocab_from_iterator
+from torchtext.vocab import Vocab, build_vocab_from_iterator
 from torchtext.data.utils import get_tokenizer
 from ..functional import vocab_func, totensor, sequential_transforms

@@ -15,7 +15,9 @@ def build_vocab(data, transforms, index):
     def apply_transforms(data):
         for line in data:
             yield transforms(line[index])
-    return build_vocab_from_iterator(apply_transforms(data), len(data))
+    vocab = build_vocab_from_iterator(apply_transforms(data), specials=['<unk>', '<pad>'])
+    vocab.set_default_index(vocab['<unk>'])
+    return vocab


 def _setup_datasets(dataset_name,
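
The five experimental dataset builders above now share one pattern: reserve `<unk>` and `<pad>` as specials, then make `<unk>` the default index. The sketch below mimics that behaviour with a toy, pure-Python class so the effect on out-of-vocabulary tokens can be seen without installing torchtext. `ToyVocab` is hypothetical and is not torchtext's `Vocab`; in particular, the error raised when no default index is set is an assumption about the library, not something this diff shows.

```python
class ToyVocab:
    """Hypothetical stand-in for illustration only; NOT torchtext's Vocab."""

    def __init__(self, tokens):
        # Index-to-string list and the reverse string-to-index mapping.
        self._itos = list(tokens)
        self._stoi = {tok: idx for idx, tok in enumerate(self._itos)}
        self._default_index = None

    def set_default_index(self, index):
        # Index returned for tokens that are not in the vocabulary.
        self._default_index = index

    def __getitem__(self, token):
        if token in self._stoi:
            return self._stoi[token]
        if self._default_index is None:
            # Assumption: without a default index, unknown tokens are an error.
            raise RuntimeError(f"token {token!r} not found and no default index set")
        return self._default_index

    def lookup_token(self, index):
        # Counterpart of the `itos[index]` access that the tests used before.
        return self._itos[index]


# Specials first, then regular tokens, as in the patched builders.
vocab = ToyVocab(['<unk>', '<pad>', 'the', 'cat'])
vocab.set_default_index(vocab['<unk>'])

ids = [vocab[tok] for tok in ['the', 'dog', 'cat']]   # 'dog' is out of vocabulary
tokens = [vocab.lookup_token(i) for i in ids]
print(ids, tokens)
```

Here `'dog'` maps to the index of `<unk>` instead of failing, which is the behaviour the `set_default_index(vocab['<unk>'])` calls in this commit appear intended to provide.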

torchtext/vocab.py

Lines changed: 7 additions & 5 deletions
@@ -258,14 +258,16 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     counter = Counter()
     for tokens in iterator:
         counter.update(tokens)
-    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-    ordered_dict = OrderedDict(sorted_by_freq_tuples)

     if specials is not None:
-        for symbol in specials:
-            if symbol in ordered_dict:
-                del ordered_dict[symbol]
+        for tok in specials:
+            del counter[tok]

+    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
+    sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
+    ordered_dict = OrderedDict(sorted_by_freq_tuples)
+
+    if specials is not None:
         if special_first:
             specials = specials[::-1]
         for symbol in specials:
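
The reordering in this hunk makes the token order deterministic: tokens are first sorted alphabetically, then stably sorted by descending frequency, so ties in frequency always resolve in alphabetical order. The sketch below reproduces only that part with the standard library. `order_tokens` is a hypothetical helper, and the way specials are placed at the front or back is a simplification, since the rest of the real function is not shown in the diff.

```python
from collections import Counter


def order_tokens(token_lists, specials=None, special_first=True):
    """Sketch of the ordering logic; not the library function itself."""
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)

    if specials is not None:
        for tok in specials:
            # Counter.__delitem__ does not raise for missing keys, which is
            # why the patch can drop the explicit membership check.
            del counter[tok]

    # Alphabetical sort first, then a stable sort by descending frequency:
    # tokens with equal counts keep their alphabetical order.
    pairs = sorted(counter.items(), key=lambda x: x[0])
    pairs.sort(key=lambda x: x[1], reverse=True)
    ordered = [tok for tok, _ in pairs]

    # Simplified placement of specials (assumption, see lead-in).
    if specials is not None:
        ordered = list(specials) + ordered if special_first else ordered + list(specials)
    return ordered


result = order_tokens(
    [["b", "a", "c"], ["c", "a"], ["d", "<unk>"]],
    specials=["<unk>", "<pad>"],
)
print(result)  # ['<unk>', '<pad>', 'a', 'c', 'b', 'd']
```

With the old single sort by frequency, the relative order of `a` and `c` (both seen twice) depended on the order in which they were first encountered; after this change it no longer does.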
