diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index bcfe5c25ed..ed46dfeea9 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -55,9 +55,9 @@ def test_text_classification(self):
datadir = os.path.join(self.project_root, ".data")
if not os.path.exists(datadir):
os.mkdir(datadir)
- ag_news_cls = AG_NEWS(root=datadir, ngrams=3)
- self.assertEqual(len(ag_news_cls.train_examples), 120000)
- self.assertEqual(len(ag_news_cls.test_examples), 7600)
+ ag_news_train, ag_news_test = AG_NEWS(root=datadir, ngrams=3)
+ self.assertEqual(len(ag_news_train), 120000)
+ self.assertEqual(len(ag_news_test), 7600)
# Delete the dataset after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
diff --git a/test/data/test_utils.py b/test/data/test_utils.py
index 40bcb01633..7e4ed423cc 100644
--- a/test/data/test_utils.py
+++ b/test/data/test_utils.py
@@ -1,35 +1,42 @@
import six
import torchtext.data as data
-
+import pytest
from ..common.torchtext_test_case import TorchtextTestCase
class TestUtils(TorchtextTestCase):
- def test_get_tokenizer(self):
+ TEST_STR = "A string, particularly one with slightly complex punctuation."
+
+ def test_get_tokenizer_split(self):
# Test the default case with str.split
assert data.get_tokenizer(str.split) == str.split
- test_str = "A string, particularly one with slightly complex punctuation."
- assert data.get_tokenizer(str.split)(test_str) == str.split(test_str)
+ assert data.get_tokenizer(str.split)(self.TEST_STR) == str.split(self.TEST_STR)
+ def test_get_tokenizer_spacy(self):
# Test SpaCy option, and verify it properly handles punctuation.
- assert data.get_tokenizer("spacy")(six.text_type(test_str)) == [
+ assert data.get_tokenizer("spacy")(six.text_type(self.TEST_STR)) == [
"A", "string", ",", "particularly", "one", "with", "slightly",
"complex", "punctuation", "."]
+    # TODO: Remove this once the issue has been resolved.
+ @pytest.mark.skip(reason=("Impractically slow! "
+ "https://github.com/alvations/sacremoses/issues/61"))
+ def test_get_tokenizer_moses(self):
# Test Moses option.
# Note that internally, MosesTokenizer converts to unicode if applicable
moses_tokenizer = data.get_tokenizer("moses")
- assert moses_tokenizer(test_str) == [
+ assert moses_tokenizer(self.TEST_STR) == [
"A", "string", ",", "particularly", "one", "with", "slightly",
"complex", "punctuation", "."]
# Nonbreaking prefixes should tokenize the final period.
assert moses_tokenizer(six.text_type("abc def.")) == ["abc", "def", "."]
+    def test_get_tokenizer_toktok(self):
# Test Toktok option. Test strings taken from NLTK doctests.
# Note that internally, MosesTokenizer converts to unicode if applicable
toktok_tokenizer = data.get_tokenizer("toktok")
- assert toktok_tokenizer(test_str) == [
+ assert toktok_tokenizer(self.TEST_STR) == [
"A", "string", ",", "particularly", "one", "with", "slightly",
"complex", "punctuation", "."]
diff --git a/torchtext/datasets/text_classification.py b/torchtext/datasets/text_classification.py
index 36b72b8cb5..6d6bff0354 100644
--- a/torchtext/datasets/text_classification.py
+++ b/torchtext/datasets/text_classification.py
@@ -5,6 +5,7 @@
import io
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
from torchtext.data.utils import generate_ngrams
+from tqdm import tqdm
from collections import Counter
from collections import OrderedDict
@@ -28,6 +29,7 @@
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}
+
# TODO: Replicate below
# tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \
# sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/
/ /g' \
@@ -55,10 +57,12 @@ def _build_dictionary_from_path(data_path, ngrams):
dictionary = Counter()
with io.open(data_path, encoding="utf8") as f:
reader = unicode_csv_reader(f)
- for row in reader:
- tokens = text_normalize(row[1])
- tokens = generate_ngrams(tokens, ngrams)
- dictionary.update(tokens)
+ with tqdm(unit_scale=0, unit='lines') as t:
+ for row in reader:
+ tokens = text_normalize(row[1])
+ tokens = generate_ngrams(tokens, ngrams)
+ dictionary.update(tokens)
+ t.update(1)
word_dictionary = OrderedDict()
for (token, frequency) in dictionary.most_common():
word_dictionary[token] = len(word_dictionary)
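
The tqdm wrapper in this hunk drives the progress bar manually because `unicode_csv_reader` yields rows of unknown count; a self-contained sketch of the same counting pattern (the toy `lines` list stands in for the CSV rows):

```python
from collections import Counter, OrderedDict
from tqdm import tqdm

lines = ["the cat sat", "the dog sat", "the cat ran"]  # stand-in for the CSV rows

dictionary = Counter()
with tqdm(unit='lines') as t:   # no total: the reader's length is unknown
    for line in lines:
        dictionary.update(line.split())
        t.update(1)             # advance the bar one row at a time

# Most-frequent-first token -> index mapping, as built above.
word_dictionary = OrderedDict(
    (token, i) for i, (token, _) in enumerate(dictionary.most_common()))
print(word_dictionary)  # OrderedDict([('the', 0), ('cat', 1), ('sat', 2), ...])
```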
@@ -76,9 +80,9 @@ def _create_data(dictionary, data_path):
tokens = generate_ngrams(tokens, 2)
tokens = torch.tensor(
[dictionary.get(entry, dictionary['<unk>']) for entry in tokens])
+ data.append((cls, tokens))
labels.append(cls)
- data.append(tokens)
- return data, labels
+ return data, set(labels)
def _extract_data(root, dataset_name):
@@ -115,7 +119,7 @@ class TextClassificationDataset(torch.utils.data.Dataset):
"""
- def __init__(self, root, ngrams):
+ def __init__(self, dictionary, data, labels):
"""Initiate text-classification dataset.
Arguments:
@@ -125,66 +129,72 @@ def __init__(self, root, ngrams):
"""
super(TextClassificationDataset, self).__init__()
- train_csv_path, test_csv_path = _extract_data(root, self.__class__.__name__)
-
- # TODO: Clean up and use Vocab object
- # Standardized on torchtext.Vocab
- UNK = '<unk>'
- dictionary = _build_dictionary_from_path(train_csv_path, ngrams)
- dictionary[UNK] = len(dictionary)
- self.dictionary = dictionary
- self.train_data, self.train_labels = _create_data(dictionary, train_csv_path)
- self.test_data, self.test_labels = _create_data(dictionary, test_csv_path)
- self.train_examples = []
- for data, label in zip(self.train_data, self.train_labels):
- self.train_examples.append((label, data))
- self.test_examples = []
- for data, label in zip(self.test_data, self.test_labels):
- self.test_examples.append((label, data))
- self.data = self.train_data + self.test_data
- self.labels = self.train_labels + self.test_labels
- self._entries = zip(self.data, self.labels)
+ self._data = data
+ self._labels = labels
+ self._dictionary = dictionary
def __getitem__(self, i):
- return self._entries[i]
+ return self._data[i]
def __len__(self):
- try:
- return len(self._entries)
- except TypeError:
- return 2**32
+ return len(self._data)
def __iter__(self):
- for x in self._entries:
+ for x in self._data:
yield x
+ def get_labels(self):
+ return self._labels
+
+ def get_dictionary(self):
+ return self._dictionary
+
+
+def _setup_datasets(root, ngrams, dataset_name):
+ train_csv_path, test_csv_path = _extract_data(root, dataset_name)
+
+ logging.info('Building dictionary based on {}'.format(train_csv_path))
+ # TODO: Clean up and use Vocab object
+ # Standardized on torchtext.Vocab
+ UNK = '<unk>'
+ dictionary = _build_dictionary_from_path(train_csv_path, ngrams)
+ dictionary[UNK] = len(dictionary)
+ logging.info('Creating training data')
+ train_data, train_labels = _create_data(dictionary, train_csv_path)
+ logging.info('Creating testing data')
+ test_data, test_labels = _create_data(dictionary, test_csv_path)
+ if len(train_labels ^ test_labels) > 0:
+ raise ValueError("Training and test labels don't match")
+ return (TextClassificationDataset(dictionary, train_data, train_labels),
+ TextClassificationDataset(dictionary, test_data, test_labels))
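
Because `_create_data` now returns labels as a set, the consistency check above is a set symmetric difference: it is non-empty exactly when some class appears in only one of the two splits. A self-contained illustration:

```python
train_labels = {1, 2, 3, 4}
test_labels = {1, 2, 4}

# '^' keeps elements present in exactly one of the two sets.
print(train_labels ^ test_labels)        # {3}
if len(train_labels ^ test_labels) > 0:  # as in _setup_datasets above
    raise ValueError("Training and test labels don't match")
```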
-class AG_NEWS(TextClassificationDataset):
+
+def AG_NEWS(root='.data', ngrams=1):
""" Defines AG_NEWS datasets.
The labels include:
- 1 : World
- 2 : Sports
- 3 : Business
- 4 : Sci/Tech
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: AG_NEWS
+ Create supervised learning dataset: AG_NEWS
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.AG_NEWS(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)
+
+ """
- super(AG_NEWS, self).__init__(root, ngrams)
+ return _setup_datasets(root, ngrams, "AG_NEWS")
-class SogouNews(TextClassificationDataset):
+def SogouNews(root='.data', ngrams=1):
""" Defines SogouNews datasets.
The labels include:
- 1 : Sports
@@ -192,25 +202,25 @@ class SogouNews(TextClassificationDataset):
- 3 : Entertainment
- 4 : Automobile
- 5 : Technology
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: SogouNews
+ Create supervised learning dataset: SogouNews
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.SogouNews(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.SogouNews(ngrams=3)
- super(SogouNews, self).__init__(root, ngrams)
+ """
+ return _setup_datasets(root, ngrams, "SogouNews")
-class DBpedia(TextClassificationDataset):
+
+def DBpedia(root='.data', ngrams=1):
""" Defines DBpedia datasets.
The labels include:
- 1 : Company
@@ -227,70 +237,70 @@ class DBpedia(TextClassificationDataset):
- 12 : Album
- 13 : Film
- 14 : WrittenWork
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: DBpedia
+ Create supervised learning dataset: DBpedia
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.DBpedia(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.DBpedia(ngrams=3)
+
+ """
- super(DBpedia, self).__init__(root, ngrams)
+ return _setup_datasets(root, ngrams, "DBpedia")
-class YelpReviewPolarity(TextClassificationDataset):
+def YelpReviewPolarity(root='.data', ngrams=1):
""" Defines YelpReviewPolarity datasets.
The labels include:
- 1 : Negative polarity.
- 2 : Positive polarity.
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: YelpReviewPolarity
+ Create supervised learning dataset: YelpReviewPolarity
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.YelpReviewPolarity(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewPolarity(ngrams=3)
- super(YelpReviewPolarity, self).__init__(root, ngrams)
+ """
+
+ return _setup_datasets(root, ngrams, "YelpReviewPolarity")
-class YelpReviewFull(TextClassificationDataset):
+def YelpReviewFull(root='.data', ngrams=1):
""" Defines YelpReviewFull datasets.
The labels include:
1 - 5 : rating classes (5 is highly recommended).
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: YelpReviewFull
+ Create supervised learning dataset: YelpReviewFull
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.YelpReviewFull(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewFull(ngrams=3)
+
+ """
- super(YelpReviewFull, self).__init__(root, ngrams)
+ return _setup_datasets(root, ngrams, "YelpReviewFull")
-class YahooAnswers(TextClassificationDataset):
+def YahooAnswers(root='.data', ngrams=1):
""" Defines YahooAnswers datasets.
The labels include:
- 1 : Society & Culture
@@ -303,64 +313,76 @@ class YahooAnswers(TextClassificationDataset):
- 8 : Entertainment & Music
- 9 : Family & Relationships
- 10 : Politics & Government
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: YahooAnswers
+ Create supervised learning dataset: YahooAnswers
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.YahooAnswers(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.YahooAnswers(ngrams=3)
- super(YahooAnswers, self).__init__(root, ngrams)
+ """
+
+ return _setup_datasets(root, ngrams, "YahooAnswers")
-class AmazonReviewPolarity(TextClassificationDataset):
+def AmazonReviewPolarity(root='.data', ngrams=1):
""" Defines AmazonReviewPolarity datasets.
The labels include:
- 1 : Negative polarity
- 2 : Positive polarity
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: AmazonReviewPolarity
+ Create supervised learning dataset: AmazonReviewPolarity
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.AmazonReviewPolarity(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewPolarity(ngrams=3)
+
+ """
- super(AmazonReviewPolarity, self).__init__(root, ngrams)
+ return _setup_datasets(root, ngrams, "AmazonReviewPolarity")
-class AmazonReviewFull(TextClassificationDataset):
+def AmazonReviewFull(root='.data', ngrams=1):
""" Defines AmazonReviewFull datasets.
The labels include:
1 - 5 : rating classes (5 is highly recommended)
- """
- def __init__(self, root='.data', ngrams=1):
- """Create supervised learning dataset: AmazonReviewFull
+ Create supervised learning dataset: AmazonReviewFull
- Arguments:
- root: Directory where the dataset are saved. Default: ".data"
- ngrams: a contiguous sequence of n items from s string text.
- Default: 1
+ Separately returns the training and test datasets
- Examples:
- >>> text_cls = torchtext.datasets.AmazonReviewFull(ngrams=3)
+ Arguments:
+ root: Directory where the datasets are saved. Default: ".data"
+ ngrams: a contiguous sequence of n items from a string of text.
+ Default: 1
- """
+ Examples:
+ >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewFull(ngrams=3)
+
+ """
- super(AmazonReviewFull, self).__init__(root, ngrams)
+ return _setup_datasets(root, ngrams, "AmazonReviewFull")
+
+
+DATASETS = {
+ 'AG_NEWS': AG_NEWS,
+ 'SogouNews': SogouNews,
+ 'DBpedia': DBpedia,
+ 'YelpReviewPolarity': YelpReviewPolarity,
+ 'YelpReviewFull': YelpReviewFull,
+ 'YahooAnswers': YahooAnswers,
+ 'AmazonReviewPolarity': AmazonReviewPolarity,
+ 'AmazonReviewFull': AmazonReviewFull
+}
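
The `DATASETS` registry makes every constructor addressable by name, which is convenient for e.g. a command-line flag; a minimal lookup sketch (the `name` value is illustrative):

```python
from torchtext.datasets.text_classification import DATASETS

name = 'AG_NEWS'  # e.g. parsed from a CLI argument
train_dataset, test_dataset = DATASETS[name](root='.data', ngrams=2)
```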
diff --git a/torchtext/utils.py b/torchtext/utils.py
index cb9160a44b..8b1b663710 100644
--- a/torchtext/utils.py
+++ b/torchtext/utils.py
@@ -121,16 +121,19 @@ def extract_archive(from_path, to_path=None, overwrite=False, archive='tar'):
if archive != 'tar':
raise NotImplementedError("We currently only support tar archives.")
- with tarfile.open(from_path, 'r:gz') as tar:
+ logging.info('Opening tar file {}.'.format(from_path))
+ with tarfile.open(from_path, 'r') as tar:
files = []
- for file_ in tar:
- if file_.isfile():
- if os.path.exists(file_.name):
+ for file_ in tar.getnames():
+ file_path = os.path.join(to_path, file_)
+ files.append(file_path)
+ if os.path.isfile(file_path):
+ if os.path.exists(file_path):
+ logging.info('{} already extracted.'.format(file_path))
if overwrite:
tar.extract(file_, to_path)
else:
tar.extract(file_, to_path)
- files.append(os.path.join(to_path, file_.name))
else:
tar.extract(file_, to_path)
return files
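
With this change `extract_archive` returns the target paths joined under `to_path` and, unless `overwrite=True`, skips files that are already on disk (logging '... already extracted.'). A usage sketch (the archive path is a placeholder):

```python
from torchtext.utils import extract_archive

# First call extracts everything; repeating it leaves existing files
# in place unless overwrite=True is passed.
files = extract_archive('.data/ag_news_csv.tar.gz', to_path='.data')
print(files)  # paths of the archive members joined under to_path
```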