diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index bcfe5c25ed..ed46dfeea9 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -55,9 +55,9 @@ def test_text_classification(self):
         datadir = os.path.join(self.project_root, ".data")
         if not os.path.exists(datadir):
             os.mkdir(datadir)
-        ag_news_cls = AG_NEWS(root=datadir, ngrams=3)
-        self.assertEqual(len(ag_news_cls.train_examples), 120000)
-        self.assertEqual(len(ag_news_cls.test_examples), 7600)
+        ag_news_train, ag_news_test = AG_NEWS(root=datadir, ngrams=3)
+        self.assertEqual(len(ag_news_train), 120000)
+        self.assertEqual(len(ag_news_test), 7600)

         # Delete the dataset after we're done to save disk space on CI
         if os.environ.get("TRAVIS") == "true":
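As the test change above reflects, AG_NEWS (and the other text-classification constructors) now returns the training and test datasets as a tuple, instead of a single object carrying train_examples/test_examples attributes. A minimal sketch of consuming the new API, assuming the AG_NEWS data gets downloaded into the default .data directory:

    import torchtext

    # The constructor now returns (train, test); per _create_data below,
    # each example is a (label, tensor-of-token-ids) tuple.
    train_dataset, test_dataset = torchtext.datasets.AG_NEWS(root='.data', ngrams=3)
    label, tokens = train_dataset[0]
    print(len(train_dataset), len(test_dataset))  # 120000 and 7600 for AG_NEWS
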
diff --git a/test/data/test_utils.py b/test/data/test_utils.py
index 40bcb01633..7e4ed423cc 100644
--- a/test/data/test_utils.py
+++ b/test/data/test_utils.py
@@ -1,35 +1,42 @@
 import six
 import torchtext.data as data
-
+import pytest
 from ..common.torchtext_test_case import TorchtextTestCase


 class TestUtils(TorchtextTestCase):
-    def test_get_tokenizer(self):
+    TEST_STR = "A string, particularly one with slightly complex punctuation."
+
+    def test_get_tokenizer_split(self):
         # Test the default case with str.split
         assert data.get_tokenizer(str.split) == str.split
-        test_str = "A string, particularly one with slightly complex punctuation."
-        assert data.get_tokenizer(str.split)(test_str) == str.split(test_str)
+        assert data.get_tokenizer(str.split)(self.TEST_STR) == str.split(self.TEST_STR)

+    def test_get_tokenizer_spacy(self):
         # Test SpaCy option, and verify it properly handles punctuation.
-        assert data.get_tokenizer("spacy")(six.text_type(test_str)) == [
+        assert data.get_tokenizer("spacy")(six.text_type(self.TEST_STR)) == [
             "A", "string", ",", "particularly", "one", "with", "slightly",
             "complex", "punctuation", "."]

+    # TODO: Remove this once the issue has been resolved.
+    @pytest.mark.skip(reason=("Impractically slow! "
+                              "https://github.com/alvations/sacremoses/issues/61"))
+    def test_get_tokenizer_moses(self):
         # Test Moses option.
         # Note that internally, MosesTokenizer converts to unicode if applicable
         moses_tokenizer = data.get_tokenizer("moses")
-        assert moses_tokenizer(test_str) == [
+        assert moses_tokenizer(self.TEST_STR) == [
             "A", "string", ",", "particularly", "one", "with", "slightly",
             "complex", "punctuation", "."]

         # Nonbreaking prefixes should tokenize the final period.
         assert moses_tokenizer(six.text_type("abc def.")) == ["abc", "def", "."]

+    def test_get_tokenizer_toktok(self):
         # Test Toktok option. Test strings taken from NLTK doctests.
         # Note that internally, MosesTokenizer converts to unicode if applicable
         toktok_tokenizer = data.get_tokenizer("toktok")
-        assert toktok_tokenizer(test_str) == [
+        assert toktok_tokenizer(self.TEST_STR) == [
             "A", "string", ",", "particularly", "one", "with", "slightly",
             "complex", "punctuation", "."]
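For reference, the factory these tests exercise can also be used directly; a small sketch (the "spacy", "moses", and "toktok" options each need their optional third-party backend installed, and the expected token list here is the one asserted in the tests above):

    import torchtext.data as data

    toktok = data.get_tokenizer("toktok")
    print(toktok("A string, particularly one with slightly complex punctuation."))
    # ['A', 'string', ',', 'particularly', 'one', 'with', 'slightly',
    #  'complex', 'punctuation', '.']
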
diff --git a/torchtext/datasets/text_classification.py b/torchtext/datasets/text_classification.py
index 36b72b8cb5..6d6bff0354 100644
--- a/torchtext/datasets/text_classification.py
+++ b/torchtext/datasets/text_classification.py
@@ -5,6 +5,7 @@
 import io
 from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
 from torchtext.data.utils import generate_ngrams
+from tqdm import tqdm

 from collections import Counter
 from collections import OrderedDict
@@ -28,6 +29,7 @@
     'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
 }

+
 # TODO: Replicate below
 # tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \
 #   sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/<br \/>/ /g' \
@@ -55,10 +57,12 @@ def _build_dictionary_from_path(data_path, ngrams):
     dictionary = Counter()
     with io.open(data_path, encoding="utf8") as f:
         reader = unicode_csv_reader(f)
-        for row in reader:
-            tokens = text_normalize(row[1])
-            tokens = generate_ngrams(tokens, ngrams)
-            dictionary.update(tokens)
+        with tqdm(unit_scale=0, unit='lines') as t:
+            for row in reader:
+                tokens = text_normalize(row[1])
+                tokens = generate_ngrams(tokens, ngrams)
+                dictionary.update(tokens)
+                t.update(1)
     word_dictionary = OrderedDict()
     for (token, frequency) in dictionary.most_common():
         word_dictionary[token] = len(word_dictionary)
@@ -76,9 +80,9 @@ def _create_data(dictionary, data_path):
             tokens = generate_ngrams(tokens, 2)
             tokens = torch.tensor(
                 [dictionary.get(entry, dictionary['<unk>']) for entry in tokens])
+            data.append((cls, tokens))
             labels.append(cls)
-            data.append(tokens)
-    return data, labels
+    return data, set(labels)


 def _extract_data(root, dataset_name):
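The dictionary-building loop above now reports progress by driving a tqdm bar manually (unit_scale=0 leaves the line counter unscaled, which is also the default). In isolation, the pattern looks like this sketch, where the csv path is hypothetical:

    import io
    from tqdm import tqdm
    from torchtext.utils import unicode_csv_reader

    with io.open('train.csv', encoding="utf8") as f:
        with tqdm(unit='lines') as t:  # manual-update progress bar
            for row in unicode_csv_reader(f):
                # ... tokenize row[1] and update the Counter here ...
                t.update(1)
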
@@ -115,7 +119,7 @@ class TextClassificationDataset(torch.utils.data.Dataset):

     """

-    def __init__(self, root, ngrams):
+    def __init__(self, dictionary, data, labels):
         """Initiate text-classification dataset.

         Arguments:
@@ -125,66 +129,72 @@ def __init__(self, root, ngrams):
         """

         super(TextClassificationDataset, self).__init__()
-        train_csv_path, test_csv_path = _extract_data(root, self.__class__.__name__)
-
-        # TODO: Clean up and use Vocab object
-        # Standardized on torchtext.Vocab
-        UNK = '<unk>'
-        dictionary = _build_dictionary_from_path(train_csv_path, ngrams)
-        dictionary[UNK] = len(dictionary)
-        self.dictionary = dictionary
-        self.train_data, self.train_labels = _create_data(dictionary, train_csv_path)
-        self.test_data, self.test_labels = _create_data(dictionary, test_csv_path)
-        self.train_examples = []
-        for data, label in zip(self.train_data, self.train_labels):
-            self.train_examples.append((label, data))
-        self.test_examples = []
-        for data, label in zip(self.test_data, self.test_labels):
-            self.test_examples.append((label, data))
-        self.data = self.train_data + self.test_data
-        self.labels = self.train_labels + self.test_labels
-        self._entries = zip(self.data, self.labels)
+        self._data = data
+        self._labels = labels
+        self._dictionary = dictionary

     def __getitem__(self, i):
-        return self._entries[i]
+        return self._data[i]

     def __len__(self):
-        try:
-            return len(self._entries)
-        except TypeError:
-            return 2**32
+        return len(self._data)

     def __iter__(self):
-        for x in self._entries:
+        for x in self._data:
             yield x

+    def get_labels(self):
+        return self._labels
+
+    def get_dictionary(self):
+        return self._dictionary
+
+
+def _setup_datasets(root, ngrams, dataset_name):
+    train_csv_path, test_csv_path = _extract_data(root, dataset_name)
+
+    logging.info('Building dictionary based on {}'.format(train_csv_path))
+    # TODO: Clean up and use Vocab object
+    # Standardized on torchtext.Vocab
+    UNK = '<unk>'
+    dictionary = _build_dictionary_from_path(train_csv_path, ngrams)
+    dictionary[UNK] = len(dictionary)
+    logging.info('Creating training data')
+    train_data, train_labels = _create_data(dictionary, train_csv_path)
+    logging.info('Creating testing data')
+    test_data, test_labels = _create_data(dictionary, test_csv_path)
+    if len(train_labels ^ test_labels) > 0:
+        raise ValueError("Training and test labels don't match")
+    return (TextClassificationDataset(dictionary, train_data, train_labels),
+            TextClassificationDataset(dictionary, test_data, test_labels))

-class AG_NEWS(TextClassificationDataset):
+
+def AG_NEWS(root='.data', ngrams=1):
     """ Defines AG_NEWS datasets.
         The labels includes:
             - 1 : World
             - 2 : Sports
             - 3 : Business
             - 4 : Sci/Tech
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: AG_NEWS
+    Create supervised learning dataset: AG_NEWS

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.AG_NEWS(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)
+
+    """

-        super(AG_NEWS, self).__init__(root, ngrams)
+    return _setup_datasets(root, ngrams, "AG_NEWS")

-class SogouNews(TextClassificationDataset):
+
+def SogouNews(root='.data', ngrams=1):
     """ Defines SogouNews datasets.
         The labels includes:
             - 1 : Sports
@@ -192,25 +202,25 @@ class SogouNews(TextClassificationDataset):
             - 3 : Entertainment
             - 4 : Automobile
             - 5 : Technology
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: SogouNews
+    Create supervised learning dataset: SogouNews

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.SogouNews(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.SogouNews(ngrams=3)

-        super(SogouNews, self).__init__(root, ngrams)
+    """
+    return _setup_datasets(root, ngrams, "SogouNews")

-class DBpedia(TextClassificationDataset):
+
+def DBpedia(root='.data', ngrams=1):
     """ Defines DBpedia datasets.
         The labels includes:
             - 1 : Company
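A small worked illustration of the new label check in _setup_datasets: _create_data now returns the labels as a set, and ^ (symmetric difference) is empty exactly when the training and test data contain the same classes:

    train_labels = {1, 2, 3, 4}
    test_labels = {1, 2, 4}
    print(train_labels ^ test_labels)        # {3}: present in only one split, raises
    print(len({1, 2, 3} ^ {1, 2, 3}))        # 0: identical label sets pass the check
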
@@ -227,70 +237,70 @@ class DBpedia(TextClassificationDataset):
             - 12 : Album
             - 13 : Film
             - 14 : WrittenWork
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: DBpedia
+    Create supervised learning dataset: DBpedia

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.DBpedia(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.DBpedia(ngrams=3)
+
+    """

-        super(DBpedia, self).__init__(root, ngrams)
+    return _setup_datasets(root, ngrams, "DBpedia")

-class YelpReviewPolarity(TextClassificationDataset):
+
+def YelpReviewPolarity(root='.data', ngrams=1):
     """ Defines YelpReviewPolarity datasets.
         The labels includes:
             - 1 : Negative polarity.
             - 2 : Positive polarity.
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: YelpReviewPolarity
+    Create supervised learning dataset: YelpReviewPolarity

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.YelpReviewPolarity(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewPolarity(ngrams=3)

-        super(YelpReviewPolarity, self).__init__(root, ngrams)
+    """
+
+    return _setup_datasets(root, ngrams, "YelpReviewPolarity")

-class YelpReviewFull(TextClassificationDataset):
+
+def YelpReviewFull(root='.data', ngrams=1):
     """ Defines YelpReviewFull datasets.
         The labels includes:
             1 - 5 : rating classes (5 is highly recommended).
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: YelpReviewFull
+    Create supervised learning dataset: YelpReviewFull

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.YelpReviewFull(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewFull(ngrams=3)
+
+    """

-        super(YelpReviewFull, self).__init__(root, ngrams)
+    return _setup_datasets(root, ngrams, "YelpReviewFull")

-class YahooAnswers(TextClassificationDataset):
+
+def YahooAnswers(root='.data', ngrams=1):
     """ Defines YahooAnswers datasets.
         The labels includes:
             - 1 : Society & Culture
@@ -303,64 +313,76 @@ class YahooAnswers(TextClassificationDataset):
             - 8 : Entertainment & Music
             - 9 : Family & Relationships
             - 10 : Politics & Government
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: YahooAnswers
+    Create supervised learning dataset: YahooAnswers

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.YahooAnswers(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.YahooAnswers(ngrams=3)

-        super(YahooAnswers, self).__init__(root, ngrams)
+    """
+
+    return _setup_datasets(root, ngrams, "YahooAnswers")

-class AmazonReviewPolarity(TextClassificationDataset):
+
+def AmazonReviewPolarity(root='.data', ngrams=1):
     """ Defines AmazonReviewPolarity datasets.
         The labels includes:
             - 1 : Negative polarity
             - 2 : Positive polarity
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: AmazonReviewPolarity
+    Create supervised learning dataset: AmazonReviewPolarity

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.AmazonReviewPolarity(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewPolarity(ngrams=3)
+
+    """

-        super(AmazonReviewPolarity, self).__init__(root, ngrams)
+    return _setup_datasets(root, ngrams, "AmazonReviewPolarity")

-class AmazonReviewFull(TextClassificationDataset):
+
+def AmazonReviewFull(root='.data', ngrams=1):
     """ Defines AmazonReviewFull datasets.
         The labels includes:
             1 - 5 : rating classes (5 is highly recommended)
-    """
-    def __init__(self, root='.data', ngrams=1):
-        """Create supervised learning dataset: AmazonReviewFull
+    Create supervised learning dataset: AmazonReviewFull

-        Arguments:
-            root: Directory where the dataset are saved. Default: ".data"
-            ngrams: a contiguous sequence of n items from s string text.
-                Default: 1
+    Separately returns the training and test datasets

-        Examples:
-            >>> text_cls = torchtext.datasets.AmazonReviewFull(ngrams=3)
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        ngrams: a contiguous sequence of n items from a string of text.
+            Default: 1

-        """
+    Examples:
+        >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewFull(ngrams=3)
+
+    """

-        super(AmazonReviewFull, self).__init__(root, ngrams)
+    return _setup_datasets(root, ngrams, "AmazonReviewFull")
+
+
+DATASETS = {
+    'AG_NEWS': AG_NEWS,
+    'SogouNews': SogouNews,
+    'DBpedia': DBpedia,
+    'YelpReviewPolarity': YelpReviewPolarity,
+    'YelpReviewFull': YelpReviewFull,
+    'YahooAnswers': YahooAnswers,
+    'AmazonReviewPolarity': AmazonReviewPolarity,
+    'AmazonReviewFull': AmazonReviewFull
+}
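With the class hierarchy replaced by plain constructor functions, the new DATASETS mapping gives string-keyed access to all eight datasets. A hedged usage sketch (the import path follows this file; get_labels()/get_dictionary() and the shared dictionary come from the code above):

    from torchtext.datasets.text_classification import DATASETS

    train_dataset, test_dataset = DATASETS['AG_NEWS'](root='.data', ngrams=2)
    print(train_dataset.get_labels())   # e.g. {1, 2, 3, 4} for AG_NEWS
    # Both splits share the dictionary built from the training csv.
    assert train_dataset.get_dictionary() is test_dataset.get_dictionary()
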
diff --git a/torchtext/utils.py b/torchtext/utils.py
index cb9160a44b..8b1b663710 100644
--- a/torchtext/utils.py
+++ b/torchtext/utils.py
@@ -121,16 +121,19 @@ def extract_archive(from_path, to_path=None, overwrite=False, archive='tar'):
     if archive != 'tar':
         raise NotImplementedError("We currently only support tar achives.")

-    with tarfile.open(from_path, 'r:gz') as tar:
+    logging.info('Opening tar file {}.'.format(from_path))
+    with tarfile.open(from_path, 'r') as tar:
         files = []
-        for file_ in tar:
-            if file_.isfile():
-                if os.path.exists(file_.name):
+        for file_ in tar.getnames():
+            file_path = os.path.join(to_path, file_)
+            files.append(file_path)
+            if os.path.isfile(file_path):
+                if os.path.exists(file_path):
+                    logging.info('{} already extracted.'.format(file_path))
                 if overwrite:
                     tar.extract(file_, to_path)
             else:
                 tar.extract(file_, to_path)
-            files.append(os.path.join(to_path, file_.name))
         else:
             tar.extract(file_, to_path)
     return files
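Condensed to its effect, the rewritten loop iterates over member names, records the target paths up front, and skips files already on disk unless overwrite is set. A standalone sketch of the same logic, assuming a local tar archive (extract_tar is a hypothetical helper, not the torchtext API):

    import logging
    import os
    import tarfile

    def extract_tar(from_path, to_path='.', overwrite=False):
        logging.info('Opening tar file %s.', from_path)
        with tarfile.open(from_path, 'r') as tar:  # 'r' auto-detects compression
            files = []
            for name in tar.getnames():
                file_path = os.path.join(to_path, name)
                files.append(file_path)
                if os.path.isfile(file_path) and not overwrite:
                    logging.info('%s already extracted.', file_path)
                    continue
                tar.extract(name, to_path)
        return files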