From d1ce4926dec33636ec6b94c2b61bd994a5a6550b Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 4 May 2022 19:20:34 -0400
Subject: [PATCH 1/7] Add support for CoLA dataset + unit tests

---
 test/datasets/test_cola.py     | 82 ++++++++++++++++++++++++++++++++++
 torchtext/datasets/__init__.py |  2 +
 torchtext/datasets/cola.py     | 72 +++++++++++++++++++++++++++++++
 3 files changed, 156 insertions(+)
 create mode 100644 test/datasets/test_cola.py
 create mode 100644 torchtext/datasets/cola.py

diff --git a/test/datasets/test_cola.py b/test/datasets/test_cola.py
new file mode 100644
index 0000000000..6d0e4c96fe
--- /dev/null
+++ b/test/datasets/test_cola.py
@@ -0,0 +1,82 @@
+import os
+import zipfile
+from collections import defaultdict
+from unittest.mock import patch
+
+from parameterized import parameterized
+from torchtext.datasets.cola import CoLA
+
+from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+def _get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    base_dir = os.path.join(root_dir, "CoLA")
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
+    os.makedirs(temp_dataset_dir, exist_ok=True)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    # map raw file names to the split names that CoLA and the tests use as keys
+    split_map = {"in_domain_train": "train", "in_domain_dev": "dev", "out_of_domain_dev": "test"}
+    for file_name in ("in_domain_train.tsv", "in_domain_dev.tsv", "out_of_domain_dev.tsv"):
+        txt_file = os.path.join(temp_dataset_dir, file_name)
+        with open(txt_file, "w", encoding="utf-8") as f:
+            # the dataset skips the first line of each raw file, so write a dummy header
+            f.write("source\tlabel\toriginal_label\tsentence\n")
+            for _ in range(5):
+                label = seed % 2
+                rand_string = get_random_unicode(seed)
+                dataset_line = (rand_string, label, rand_string)
+                # append line to correct dataset split
+                mocked_data[split_map[os.path.splitext(file_name)[0]]].append(dataset_line)
+                # four unquoted tab-separated columns, matching the raw CoLA layout
+                f.write(f"{rand_string}\t{label}\t{label}\t{rand_string}\n")
+                seed += 1
+
+    compressed_dataset_path = os.path.join(base_dir, "cola_public_1.1.zip")
+    # create zip file from dataset folder
+    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
+        for file_name in ("in_domain_train.tsv", "in_domain_dev.tsv", "out_of_domain_dev.tsv"):
+            txt_file = os.path.join(temp_dataset_dir, file_name)
+            zip_file.write(txt_file, arcname=os.path.join("cola_public", "raw", file_name))
+
+    return mocked_data
+
+
+class TestCoLA(TempDirMixin, TorchtextTestCase):
+    root_dir = None
+    samples = []
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.samples = _get_mock_dataset(cls.root_dir)
+        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
+        cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.patcher.stop()
+        super().tearDownClass()
+
+    @parameterized.expand(["train", "test", "dev"])
+    def test_cola(self, split):
+        dataset = CoLA(root=self.root_dir, split=split)
+
+        samples = list(dataset)
+        expected_samples = self.samples[split]
+        for sample, expected_sample in zip_equal(samples, expected_samples):
+            self.assertEqual(sample, expected_sample)
+
+    @parameterized.expand(["train", "test", "dev"])
+    def test_cola_split_argument(self, split):
+        dataset1 = CoLA(root=self.root_dir, split=split)
+        (dataset2,) = CoLA(root=self.root_dir, split=(split,))
+
+        for d1, d2 in zip_equal(dataset1, dataset2):
+            self.assertEqual(d1, d2)

diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index d7d33298ad..29dcc5f165 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -4,6 +4,7 @@
 from .amazonreviewfull import AmazonReviewFull
 from .amazonreviewpolarity import AmazonReviewPolarity
 from .cc100 import CC100
+from .cola import CoLA
 from .conll2000chunking import CoNLL2000Chunking
 from .dbpedia import DBpedia
 from .enwik9 import EnWik9
@@ -28,6 +29,7 @@
     "AmazonReviewFull": AmazonReviewFull,
     "AmazonReviewPolarity": AmazonReviewPolarity,
     "CC100": CC100,
+    "CoLA": CoLA,
     "CoNLL2000Chunking": CoNLL2000Chunking,
     "DBpedia": DBpedia,
     "EnWik9": EnWik9,

diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
new file mode 100644
index 0000000000..f29f27c617
--- /dev/null
+++ b/torchtext/datasets/cola.py
@@ -0,0 +1,72 @@
+import os
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, IterableWrapper
+    from torchtext._download_hooks import HttpReader
+
+URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"
+
+MD5 = "9f6d88c3558ec424cd9d66ea03589aba"
+
+_PATH = "cola_public_1.1.zip"
+
+NUM_LINES = {"train": 8551, "dev": 527, "test": 516}
+
+_EXTRACTED_FILES = {
+    "train": os.path.join("cola_public", "raw", "in_domain_train.tsv"),
+    "dev": os.path.join("cola_public", "raw", "in_domain_dev.tsv"),
+    "test": os.path.join("cola_public", "raw", "out_of_domain_dev.tsv"),
+}
+
+DATASET_NAME = "CoLA"
+
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev", "test"))
+def CoLA(root: str, split: Union[Tuple[str], str]):
+    """CoLA dataset
+
+    For additional details refer to https://nyu-mll.github.io/CoLA/
+
+    Number of lines per split: 
+        - train: 8551 
+        - dev: 527
+        - test: 516
+
+    Args:
+        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
+        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)
+
+
+    :returns: DataPipe that yields rows from CoLA dataset (source (str), label (int), sentence (str))
+    :rtype: str
+    """
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
+    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
+    # the first line of each raw file holds context rather than data, so it is skipped
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").filter(lambda x: len(x) == 4).map(lambda t: (t[0], int(t[1]), t[3]))
+    return parsed_data
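To make the new entry point concrete, here is a minimal usage sketch; it is annotation, not part of the patch. It assumes `torchdata` is installed, and the first iteration downloads and caches the zip under the torchtext cache directory:

    from torchtext.datasets import CoLA

    dev_dp = CoLA(split="dev")
    # each row is (source, label, sentence); the label is an acceptability
    # judgment, 0 (unacceptable) or 1 (acceptable)
    source, label, sentence = next(iter(dev_dp))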
From a43a1a08f744e37cbe250f7baa36591e8bced47a Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 11 May 2022 12:11:47 -0400
Subject: [PATCH 2/7] Better test with differentiated rand_string

---
 test/datasets/test_cola.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/test/datasets/test_cola.py b/test/datasets/test_cola.py
index 6d0e4c96fe..4b84b1f5df 100644
--- a/test/datasets/test_cola.py
+++ b/test/datasets/test_cola.py
@@ -29,12 +29,13 @@ def _get_mock_dataset(root_dir):
             f.write("source\tlabel\toriginal_label\tsentence\n")
             for _ in range(5):
                 label = seed % 2
-                rand_string = get_random_unicode(seed)
-                dataset_line = (rand_string, label, rand_string)
+                rand_string_1 = get_random_unicode(seed)
+                rand_string_2 = get_random_unicode(seed+1)
+                dataset_line = (rand_string_1, label, rand_string_2)
                 # append line to correct dataset split
                 mocked_data[split_map[os.path.splitext(file_name)[0]]].append(dataset_line)
                 # four unquoted tab-separated columns, matching the raw CoLA layout
-                f.write(f"{rand_string}\t{label}\t{label}\t{rand_string}\n")
+                f.write(f"{rand_string_1}\t{label}\t{label}\t{rand_string_2}\n")
                 seed += 1
 
     compressed_dataset_path = os.path.join(base_dir, "cola_public_1.1.zip")
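Using two distinct random strings means a parser that swapped or duplicated the source and sentence columns would now fail the comparison. The comparison itself relies on `zip_equal`, a helper from the shared test utilities that is not shown in this diff; judging from its use, it presumably behaves like `zip` but fails loudly on a length mismatch, so a missing or extra row breaks the test instead of being silently truncated. A sketch of such a helper, offered purely as an assumption about its behavior:

    from itertools import zip_longest

    def zip_equal(*iterables):
        """Like zip(), but raise if the inputs have unequal lengths."""
        sentinel = object()
        for combo in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in combo:
                raise ValueError("iterables have different lengths")
            yield combo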
From 321d42b2315a85bf86755a62ca8503b74c97cf50 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Thu, 12 May 2022 10:49:17 -0400
Subject: [PATCH 3/7] Remove lambda functions

---
 torchtext/datasets/cola.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
index f29f27c617..297c2b5ab7 100644
--- a/torchtext/datasets/cola.py
+++ b/torchtext/datasets/cola.py
@@ -1,4 +1,5 @@
 import os
+import csv
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument
@@ -50,23 +51,38 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
         )
 
+    def _filepath_fn(_=None):
+        return os.path.join(root, _PATH)
+
+    def _extracted_filepath_fn(_=None):
+        return os.path.join(root, _EXTRACTED_FILES[split])
+
+    def _filter_fn(x):
+        return _EXTRACTED_FILES[split] in x[0]
+
+    def _modify_res(t):
+        return (t[0], int(t[1]), t[3])
+
+    def _filter_res(x):
+        return len(x) == 4
+
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _PATH),
-        hash_dict={os.path.join(root, _PATH): MD5},
+        filepath_fn=_filepath_fn,
+        hash_dict={_filepath_fn(): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
     cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+        filepath_fn=_extracted_filepath_fn
     )
     cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
     # the first line of each raw file holds context rather than data, so it is skipped
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").filter(lambda x: len(x) == 4).map(lambda t: (t[0], int(t[1]), t[3]))
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
     return parsed_data
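The commit message does not spell out why the lambdas had to go, but the usual motivation is serialization: a datapipe graph may be pickled, for example by a multi-process DataLoader, and the standard pickler rejects lambdas while it can handle named functions that are importable by reference. A hypothetical demonstration, separate from the patch:

    import pickle

    def double(x):  # named, module-level: picklable by reference
        return 2 * x

    double_lambda = lambda x: 2 * x  # anonymous: refused by the standard pickler

    pickle.dumps(double)  # works
    try:
        pickle.dumps(double_lambda)
    except pickle.PicklingError as exc:
        print(f"lambda not picklable: {exc}")

The same hunk also adds quoting=csv.QUOTE_NONE, which keeps the csv reader from treating stray double quotes inside CoLA sentences as field delimiters.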
From 60e8ee1f4daf4273aeeaff3d4832bef0aac09d85 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Mon, 16 May 2022 09:52:04 -0400
Subject: [PATCH 4/7] Fix lint

---
 test/datasets/test_cola.py |  2 +-
 torchtext/datasets/cola.py | 20 +++++++++-----------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/test/datasets/test_cola.py b/test/datasets/test_cola.py
index 4b84b1f5df..d55eb31fbf 100644
--- a/test/datasets/test_cola.py
+++ b/test/datasets/test_cola.py
@@ -30,7 +30,7 @@ def _get_mock_dataset(root_dir):
             for _ in range(5):
                 label = seed % 2
                 rand_string_1 = get_random_unicode(seed)
-                rand_string_2 = get_random_unicode(seed+1)
+                rand_string_2 = get_random_unicode(seed + 1)
                 dataset_line = (rand_string_1, label, rand_string_2)
                 # append line to correct dataset split
                 mocked_data[split_map[os.path.splitext(file_name)[0]]].append(dataset_line)

diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
index 297c2b5ab7..4b4e710b53 100644
--- a/torchtext/datasets/cola.py
+++ b/torchtext/datasets/cola.py
@@ -1,9 +1,9 @@
-import os
 import csv
+import os
+from typing import Union, Tuple
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument
-from typing import Union, Tuple
 
 if is_module_available("torchdata"):
     from torchdata.datapipes.iter import FileOpener, IterableWrapper
@@ -33,8 +33,8 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
 
     For additional details refer to https://nyu-mll.github.io/CoLA/
 
-    Number of lines per split: 
-        - train: 8551 
+    Number of lines per split:
+        - train: 8551
         - dev: 527
         - test: 516
 
@@ -74,15 +74,13 @@ def _filter_res(x):
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=_extracted_filepath_fn
-    )
-    cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
-    )
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
     # the first line of each raw file holds context rather than data, so it is skipped
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
+    parsed_data = (
+        data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
+    )
     return parsed_data

From 3917c12fd0f98e6172c698674c766c3c48d4e1fc Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Mon, 16 May 2022 09:53:45 -0400
Subject: [PATCH 5/7] Fix docstring

---
 torchtext/datasets/cola.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
index 4b4e710b53..0cb4df038b 100644
--- a/torchtext/datasets/cola.py
+++ b/torchtext/datasets/cola.py
@@ -44,7 +44,7 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
 
 
     :returns: DataPipe that yields rows from CoLA dataset (source (str), label (int), sentence (str))
-    :rtype: str
+    :rtype: (str, int, str)
     """
     if not is_module_available("torchdata"):
         raise ModuleNotFoundError(
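With the docstring now matching the yielded tuples, the split handling provided by `_wrap_split_argument` deserves an illustration. This sketch mirrors what `test_cola_split_argument` exercises and is not itself part of the patch:

    from torchtext.datasets import CoLA

    dev_dp = CoLA(split="dev")  # a string yields a single datapipe
    train_dp, dev_dp = CoLA(split=("train", "dev"))  # a tuple yields one datapipe per split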
From f2bff4180e5532a2b55eb52c1b1773e25db5358c Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Tue, 17 May 2022 11:52:51 -0400
Subject: [PATCH 6/7] Add dataset documentation

---
 docs/source/datasets.rst | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 33eb44b21d..d9c32f201c 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -42,6 +42,11 @@ AmazonReviewPolarity
 
 .. autofunction:: AmazonReviewPolarity
 
+CoLA
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: CoLA
+
 DBpedia
 ~~~~~~~

From 23b6826b97cf59509fb03633ba3ffd0ab5e9849d Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 18 May 2022 14:53:50 -0400
Subject: [PATCH 7/7] Add shuffle and sharding

---
 torchtext/datasets/cola.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
index 0cb4df038b..d52cb0be66 100644
--- a/torchtext/datasets/cola.py
+++ b/torchtext/datasets/cola.py
@@ -83,4 +83,4 @@ def _filter_res(x):
     parsed_data = (
         data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
     )
-    return parsed_data
+    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
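The trailing `shuffle().set_shuffle(False).sharding_filter()` chain leaves shuffling off by default, so bare iteration stays deterministic (and the order-sensitive tests above keep passing), while a DataLoader can flip shuffling on and shard rows across workers. A usage sketch, assuming a torch release whose DataLoader integrates with DataPipes (shuffle toggling and per-worker sharding):

    from torch.utils.data import DataLoader
    from torchtext.datasets import CoLA

    train_dp = CoLA(split="train")
    # shuffle=True re-enables the shuffle step added above; with num_workers > 0,
    # sharding_filter() gives each worker a disjoint slice of the rows
    loader = DataLoader(train_dp, batch_size=32, shuffle=True, num_workers=2)
    for sources, labels, sentences in loader:
        print(labels.shape)  # the int labels are collated into a tensor of shape (batch,)
        break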