
experimental.dataset WikiText2, WikiText103, PennTreeBank, WMTNewsCrawl #774


Merged
merged 28 commits into from
Jun 4, 2020
Changes from all commits (28 commits)
e4b094c
torchtext.experimental.raw: update __init__.py
rmz59 May 11, 2020
eb5409c
add language_modeling.py in raw dataset
rmz59 May 13, 2020
ebfd0a7
fix typo
rmz59 May 14, 2020
5dcf6e2
add new language_modeling dataset
rmz59 May 14, 2020
7be2bfe
add new language_modeling dataset
rmz59 May 14, 2020
5d71b3c
Merge remote-tracking branch 'origin/new_language_modeling' into new_…
rmz59 May 14, 2020
71baaf1
Revert "fix typo". Will submit another dedicated PR for typos
rmz59 May 15, 2020
7704105
remove duplicated functions.
rmz59 May 15, 2020
b9e4645
fix incorrect dataset orders
rmz59 May 15, 2020
37514b5
remove setup_iter
rmz59 May 15, 2020
0117041
explicitly select data
rmz59 May 15, 2020
ecafa7b
remove sys
rmz59 May 15, 2020
f56dd2f
Merge branch 'master' into new_language_modeling
rmz59 May 20, 2020
777c8f5
use functionals
rmz59 May 20, 2020
f433b40
restore the order of vocab/tokenizer
rmz59 May 20, 2020
645a749
Point language_modeling.DATASETS to local functions
rmz59 May 20, 2020
e27058f
get rid of _get_datafile_path
rmz59 May 21, 2020
f77c53c
really get rid of _get_datafile_path.
rmz59 May 21, 2020
fb36e7b
add WMTNewsCrawl
rmz59 May 24, 2020
321871d
restore setup_iter
rmz59 May 24, 2020
0409da8
add `single_line` option
rmz59 Jun 3, 2020
0a6f6ac
minor change
rmz59 Jun 3, 2020
66fe231
Update docs.
rmz59 Jun 4, 2020
5f8a63a
take care of `single_line` in WMTNewsCrawl
rmz59 Jun 4, 2020
1ed393e
add unit test for WMTNewsCrawl
rmz59 Jun 4, 2020
066bc35
Merge branch 'master' into new_language_modeling
rmz59 Jun 4, 2020
e55b60b
remove the unit test for WMTNewsCrawl because it takes too long time …
Jun 4, 2020
a2e7b02
raise an error for single_line
Jun 4, 2020
2 changes: 1 addition & 1 deletion test/data/test_builtin_datasets.py
@@ -78,7 +78,7 @@ def test_penntreebank_legacy(self):

def test_penntreebank(self):
from torchtext.experimental.datasets import PennTreebank
# smoke test to ensure wikitext2 works properly
# smoke test to ensure penn treebank works properly
train_dataset, test_dataset, valid_dataset = PennTreebank()
self.assertEqual(len(train_dataset), 924412)
self.assertEqual(len(test_dataset), 82114)
3 changes: 2 additions & 1 deletion torchtext/experimental/datasets/__init__.py
@@ -1,4 +1,4 @@
from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA
from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank, WMTNewsCrawl # NOQA
from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
YelpReviewFull, YahooAnswers, \
AmazonReviewPolarity, AmazonReviewFull, IMDB
@@ -7,6 +7,7 @@
'WikiText2',
'WikiText103',
'PennTreebank',
'WMTNewsCrawl',
'IMDB',
'AG_NEWS',
'SogouNews',
191 changes: 106 additions & 85 deletions torchtext/experimental/datasets/language_modeling.py
@@ -1,22 +1,15 @@
import torch
import logging
import io
from torchtext.utils import download_from_url, extract_archive
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab
from torchtext.data.functional import numericalize_tokens_from_iterator

URLS = {
'WikiText2':
'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip',
'WikiText103':
'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip',
'PennTreebank':
['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt',
'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt',
'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt']
}
from torchtext.vocab import build_vocab_from_iterator
from torchtext.experimental.datasets.raw import language_modeling as raw
from torchtext.experimental.functional import vocab_func, totensor, sequential_transforms


def build_vocab(data, transforms):
tok_list = []
for txt in data:
tok_list.append(transforms(txt))
return build_vocab_from_iterator(tok_list)
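
As a quick illustration of the helper above, here is a minimal sketch (the sample lines are invented for this example): build_vocab simply tokenizes every raw line and hands the resulting token lists to build_vocab_from_iterator.

>>> from torchtext.data.utils import get_tokenizer
>>> tokenizer = get_tokenizer('basic_english')
>>> vocab = build_vocab(['language modeling is fun .', 'more language data .'], tokenizer)
>>> len(vocab)  # vocabulary size, including special tokens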


class LanguageModelingDataset(torch.utils.data.Dataset):
Expand All @@ -26,33 +19,36 @@ class LanguageModelingDataset(torch.utils.data.Dataset):
- WikiText2
- WikiText103
- PennTreebank
- WMTNewsCrawl

"""

def __init__(self, data, vocab):
def __init__(self, data, vocab, transforms, single_line):
"""Initiate language modeling dataset.

Arguments:
data: a tensor of tokens. tokens are ids after
numericalizing the string tokens.
torch.tensor([token_id_1, token_id_2, token_id_3, token_id1]).long()
vocab: Vocabulary object used for dataset.

Examples:
>>> from torchtext.vocab import build_vocab_from_iterator
>>> data = torch.tensor([token_id_1, token_id_2,
token_id_3, token_id_1]).long()
>>> vocab = build_vocab_from_iterator([['language', 'modeling']])
>>> dataset = LanguageModelingDataset(data, vocab)
transforms: Text string transforms.

Contributor (review comment): and docs for single_line

"""

super(LanguageModelingDataset, self).__init__()
self.data = data
self.vocab = vocab
self.transforms = transforms
self.single_line = single_line
if single_line:
self.data = torch.cat(tuple(transforms(row) for row in data), axis=0)
else:
self.data = data

def __getitem__(self, i):
return self.data[i]
if self.single_line:
return self.data[i]
else:
return self.transforms(self.data[i])

def __len__(self):
return len(self.data)
@@ -65,63 +61,45 @@ def get_vocab(self):
return self.vocab
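
To make the single_line flag concrete, the sketch below uses a toy stand-in numericalizer (invented for illustration, not the transform that _setup_datasets builds). With single_line=True the rows are numericalized once and concatenated, so indexing returns individual token ids; with single_line=False the raw rows are kept and numericalized lazily, so indexing returns one line's worth of token ids.

>>> import torch
>>> transforms = lambda line: torch.tensor([len(tok) for tok in line.split()], dtype=torch.long)
>>> lines = ['two words', 'three more words']
>>> flat = LanguageModelingDataset(lines, None, transforms, single_line=True)
>>> flat[0]      # a single token id from the concatenated stream; len(flat) == 5
>>> per_line = LanguageModelingDataset(lines, None, transforms, single_line=False)
>>> per_line[1]  # tensor of token ids for the second line; len(per_line) == 2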


def _get_datafile_path(key, extracted_files):
for fname in extracted_files:
if key in fname:
return fname


def _setup_datasets(dataset_name, tokenizer=get_tokenizer("basic_english"),
root='.data', vocab=None, removed_tokens=[],
data_select=('train', 'test', 'valid')):
def _setup_datasets(dataset_name, tokenizer=None, root='.data', vocab=None,
data_select=('train', 'test', 'valid'), single_line=True):
if tokenizer is None:
tokenizer = get_tokenizer('basic_english')
text_transform = sequential_transforms(tokenizer)

if isinstance(data_select, str):
data_select = [data_select]
if not set(data_select).issubset(set(('train', 'test', 'valid'))):
raise TypeError('data_select is not supported!')

if dataset_name == 'PennTreebank':
extracted_files = []
select_to_index = {'train': 0, 'test': 1, 'valid': 2}
extracted_files = [download_from_url(URLS['PennTreebank'][select_to_index[key]],
root=root) for key in data_select]
if not set(data_select).issubset(set(('train', 'valid', 'test'))):
raise TypeError('Given data selection {} is not supported!'.format(data_select))

if not single_line and dataset_name != 'WikiText103':
raise TypeError('single_line must be True except for WikiText103')
if dataset_name == 'WMTNewsCrawl':
train, = raw.DATASETS[dataset_name](root=root, data_select=('train',))
if single_line:
raw_data = {'train': [" ".join([txt for txt in train]), ]}
else:
raw_data = {'train': [txt for txt in train]}
else:
dataset_tar = download_from_url(URLS[dataset_name], root=root)
extracted_files = extract_archive(dataset_tar)

_path = {}
for item in data_select:
_path[item] = _get_datafile_path(item, extracted_files)
train, test, valid = raw.DATASETS[dataset_name](root=root, data_select=('train', 'test', 'valid'))
# Cache raw text iterable dataset
if single_line:
raw_data = {'train': [" ".join([txt for txt in train]), ],
'valid': [" ".join(txt for txt in valid), ],
'test': [" ".join(txt for txt in test), ]}
else:
raw_data = {'train': [txt for txt in train],
'valid': [txt for txt in valid],
'test': [txt for txt in test]}

if vocab is None:
if 'train' not in _path.keys():
if 'train' not in data_select:
raise TypeError("Must pass a vocab if train is not selected.")
logging.info('Building Vocab based on {}'.format(_path['train']))
txt_iter = iter(tokenizer(row) for row in io.open(_path['train'],
encoding="utf8"))
vocab = build_vocab_from_iterator(txt_iter)
logging.info('Vocab has {} entries'.format(len(vocab)))
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")

data = {}
for item in _path.keys():
data[item] = []
logging.info('Creating {} data'.format(item))
txt_iter = iter(tokenizer(row) for row in io.open(_path[item],
encoding="utf8"))
_iter = numericalize_tokens_from_iterator(
vocab, txt_iter, removed_tokens)
for tokens in _iter:
data[item] += [token_id for token_id in tokens]

for key in data_select:
if data[key] == []:
raise TypeError('Dataset {} is empty!'.format(key))

return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab)
for d in data_select)
vocab = build_vocab(raw_data['train'], text_transform)
text_transform = sequential_transforms(text_transform, vocab_func(vocab),
totensor(dtype=torch.long))
return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform, single_line)
for item in data_select)
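
The transform chain assembled above tokenizes a raw line, maps the tokens to ids through the vocab, and packs the ids into a torch.long tensor. The helpers below are simplified stand-ins for the torchtext.experimental.functional utilities, written out only to show the shape of that pipeline (they are not the library implementations):

import torch

def sequential_transforms(*transforms):
    # compose the given transforms left to right
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

def vocab_func(vocab):
    # map a list of string tokens to a list of integer ids
    return lambda tokens: [vocab[tok] for tok in tokens]

def totensor(dtype):
    # pack a list of ids into a tensor of the requested dtype
    return lambda ids: torch.tensor(ids, dtype=dtype)

With these pieces, text_transform(raw_line) takes a string to a 1-D torch.long tensor of token ids, which is the form LanguageModelingDataset consumes.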


def WikiText2(*args, **kwargs):
@@ -138,14 +116,17 @@ def WikiText2(*args, **kwargs):
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
removed_tokens: removed tokens from output dataset (Default: [])
data_select: a string or tuple for the returned datasets
(Default: ('train', 'test','valid'))
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
object should be provided which will be used to process valid and/or test
data.
single_line: whether to return all tokens in a single line.
(Default: True)
By default, all lines in raw text file are concatenated into a single line.
Use `single_line = False` if one wants to get data line by line.

Examples:
>>> from torchtext.experimental.datasets import WikiText2
@@ -175,19 +156,17 @@ def WikiText103(*args, **kwargs):
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
data_select: the returned datasets (Default: ('train', 'test','valid'))
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test').
If 'train' is not in the tuple, an vocab object should be provided which will
be used to process valid and/or test data.
removed_tokens: removed tokens from output dataset (Default: [])
data_select: a string or tuple for the returned datasets
(Default: ('train', 'test','valid'))
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
object should be provided which will be used to process valid and/or test
data.
single_line: whether to return all tokens in a single line.
(Default: True)
By default, all lines in raw text file are concatenated into a single line.
Use `single_line = False` if one wants to get data line by line.

Examples:
>>> from torchtext.experimental.datasets import WikiText103
@@ -217,14 +196,17 @@ def PennTreebank(*args, **kwargs):
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
removed_tokens: removed tokens from output dataset (Default: [])
data_select: a string or tuple for the returned datasets
(Default: ('train', 'test','valid'))
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
object should be provided which will be used to process valid and/or test
data.
single_line: whether to return all tokens in a single line.
(Default: True)
By default, all lines in raw text file are concatenated into a single line.
Use `single_line = False` if one wants to get data line by line.

Examples:
>>> from torchtext.experimental.datasets import PennTreebank
@@ -238,3 +220,42 @@
"""

return _setup_datasets(*(("PennTreebank",) + args), **kwargs)
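
A common pattern implied by the docstrings above is to build the vocabulary from the train split once and pass it back in when loading only the valid or test split (a sketch; the data is downloaded into the default root '.data'):

>>> from torchtext.experimental.datasets import WikiText2
>>> train_dataset, = WikiText2(data_select='train')
>>> vocab = train_dataset.get_vocab()
>>> test_dataset, = WikiText2(data_select='test', vocab=vocab)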


def WMTNewsCrawl(*args, **kwargs):
""" Defines WMTNewsCrawl datasets.

Create language modeling dataset: WMTNewsCrawl
returns the train set

Arguments:
tokenizer: the tokenizer used to preprocess raw text data.
The default one is basic_english tokenizer in fastText. spacy tokenizer
is supported as well (see example below). A custom tokenizer is callable
function with input of a string and output of a token list.
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
data_select: a string or tuple for the returned datasets
(Default: ('train',))
single_line: whether to return all tokens in a single line.
(Default: True)
By default, all lines in raw text file are concatenated into a single line.
Use `single_line = False` if one wants to get data line by line.
Examples:
>>> from torchtext.experimental.datasets import WMTNewsCrawl
>>> from torchtext.data.utils import get_tokenizer
>>> tokenizer = get_tokenizer("spacy")
>>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, data_select='train')

"""

return _setup_datasets(*(("WMTNewsCrawl",) + args), **kwargs)
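
Note that _setup_datasets above only accepts single_line=False for WikiText103; every other dataset raises a TypeError. A short sketch of both sides of that check:

>>> from torchtext.experimental.datasets import WikiText103, WMTNewsCrawl
>>> train, test, valid = WikiText103(single_line=False)  # one example per raw line
>>> WMTNewsCrawl(single_line=False)  # TypeError: single_line must be True except for WikiText103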


DATASETS = {
'WikiText2': WikiText2,
'WikiText103': WikiText103,
'PennTreebank': PennTreebank,
'WMTNewsCrawl': WMTNewsCrawl
}
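
The DATASETS dict mirrors the registry in torchtext.experimental.datasets.raw, so a dataset can also be picked by name at run time (a minimal sketch):

>>> from torchtext.experimental.datasets.language_modeling import DATASETS
>>> train, test, valid = DATASETS['WikiText2']()
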
7 changes: 6 additions & 1 deletion torchtext/experimental/datasets/raw/__init__.py
@@ -1,6 +1,7 @@
from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
YelpReviewFull, YahooAnswers, \
AmazonReviewPolarity, AmazonReviewFull, IMDB
from .language_modeling import WikiText2, WikiText103, PennTreebank, WMTNewsCrawl

__all__ = ['IMDB',
'AG_NEWS',
Expand All @@ -10,4 +11,8 @@
'YelpReviewFull',
'YahooAnswers',
'AmazonReviewPolarity',
'AmazonReviewFull']
'AmazonReviewFull',
'WikiText2',
'WikiText103',
'PennTreebank',
'WMTNewsCrawl']
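
For context, the raw variants re-exported here yield plain text lines rather than tensors; the experimental datasets above wrap them with tokenization and numericalization. A hedged usage sketch:

>>> from torchtext.experimental.datasets.raw import WikiText2
>>> raw_train, raw_test, raw_valid = WikiText2(root='.data', data_select=('train', 'test', 'valid'))
>>> next(iter(raw_train))  # a single untokenized line of text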