diff --git a/test/datasets/common.py b/test/datasets/common.py
index 81edc565bf..77df60fa77 100644
--- a/test/datasets/common.py
+++ b/test/datasets/common.py
@@ -1,3 +1,5 @@
+import pickle
+
 from parameterized import parameterized
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
@@ -7,6 +9,19 @@
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
+class TestDatasetPickling(TorchtextTestCase):
+    @parameterized.expand([(f,) for f in DATASETS.values()])
+    def test_pickling(self, dataset_fn):
+        dp = dataset_fn()
+        if type(dp) == tuple:
+            dp = list(dp)
+        else:
+            dp = [dp]
+
+        for dp_split in dp:
+            pickle.loads(pickle.dumps(dp_split))
+
+
 class TestShuffleShardDatasetWrapper(TorchtextTestCase):
     # Note that for order i.e shuffle before sharding, TorchData will provide linter warning
     # Modify this test when linter warning is available
diff --git a/torchtext/datasets/cola.py b/torchtext/datasets/cola.py
index d52cb0be66..97b4bd6f77 100644
--- a/torchtext/datasets/cola.py
+++ b/torchtext/datasets/cola.py
@@ -1,5 +1,6 @@
 import csv
 import os
+from functools import partial
 from typing import Union, Tuple
 
 from torchtext._internal.module_utils import is_module_available
@@ -26,6 +27,26 @@
 DATASET_NAME = "CoLA"
 
 
+def _filepath_fn(root, _=None):
+    return os.path.join(root, _PATH)
+
+
+def _extracted_filepath_fn(root, split, _=None):
+    return os.path.join(root, _EXTRACTED_FILES[split])
+
+
+def _filter_fn(split, x):
+    return _EXTRACTED_FILES[split] in x[0]
+
+
+def _modify_res(t):
+    return (t[0], int(t[1]), t[3])
+
+
+def _filter_res(x):
+    return len(x) == 4
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def CoLA(root: str, split: Union[Tuple[str], str]):
@@ -51,31 +72,18 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
         )
 
-    def _filepath_fn(_=None):
-        return os.path.join(root, _PATH)
-
-    def _extracted_filepath_fn(_=None):
-        return os.path.join(root, _EXTRACTED_FILES[split])
-
-    def _filter_fn(x):
-        return _EXTRACTED_FILES[split] in x[0]
-
-    def _modify_res(t):
-        return (t[0], int(t[1]), t[3])
-
-    def _filter_res(x):
-        return len(x) == 4
-
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(): MD5},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
+    )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py
index ec64e6a437..c684261272 100644
--- a/torchtext/datasets/imdb.py
+++ b/torchtext/datasets/imdb.py
@@ -53,6 +53,13 @@ def _modify_res(t):
     return Path(t[0]).parts[-1], t[1]
 
 
+def filter_imdb_data(key, fname):
+    labels = {"neg", "pos"}
+    # eg. fname = "aclImdb/train/neg/12416_3.txt"
+    *_, split, label, file = Path(fname).parts
+    return key == split and label in labels
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "test"))
 def IMDB(root: str, split: Union[Tuple[str], str]):
@@ -92,12 +99,6 @@
     )
     cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
     cache_decompressed_dp = cache_decompressed_dp.load_from_tar()
-
-    def filter_imdb_data(key, fname):
-        # eg. fname = "aclImdb/train/neg/12416_3.txt"
-        *_, split, label, file = Path(fname).parts
-        return key == split and label in labels
-
     cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, filter_imdb_data, split))
 
     # eg. "aclImdb/train/neg/12416_3.txt" -> "neg"
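
Every file in this patch follows the same recipe: the local functions that captured `root` and `split` from the enclosing dataset builder are hoisted to module scope, and the captured values are re-bound with `functools.partial`. Local functions cannot be pickled, while `partial` objects over module-level functions can, which is exactly what the new TestDatasetPickling case exercises. A minimal sketch of the idea (the `_keep_split` helper and file names here are hypothetical, not part of this patch):

import pickle
from functools import partial

from torch.utils.data.datapipes.iter import IterableWrapper


def _keep_split(split, x):
    # Module-level function: pickle serializes it by reference.
    return x.startswith(split)


def build(split):
    dp = IterableWrapper(["train/a.txt", "test/b.txt"])
    # A closure over `split` here would make pickle.dumps raise;
    # partial(...) is equivalent at call time but round-trips.
    return dp.filter(partial(_keep_split, split))


pickle.loads(pickle.dumps(build("train")))  # succeeds

Because `partial(fn, root)` behaves identically to the old closure at call time, no datapipe behavior changes; only serializability does.
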
diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
index 43bcfe7d9f..9c2a3f71ad 100644
--- a/torchtext/datasets/mnli.py
+++ b/torchtext/datasets/mnli.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import csv
 import os
+from functools import partial
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -39,6 +40,26 @@
 LABEL_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}
 
 
+def _filepath_fn(root, x=None):
+    return os.path.join(root, os.path.basename(x))
+
+
+def _extracted_filepath_fn(root, split, _=None):
+    return os.path.join(root, _EXTRACTED_FILES[split])
+
+
+def _filter_fn(split, x):
+    return _EXTRACTED_FILES[split] in x[0]
+
+
+def _filter_res(x):
+    return x[0] in LABEL_TO_INT
+
+
+def _modify_res(x):
+    return (LABEL_TO_INT[x[0]], x[5], x[6])
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev_matched", "dev_mismatched"))
 def MNLI(root, split):
@@ -64,31 +85,18 @@
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
         )
 
-    def _filepath_fn(x=None):
-        return os.path.join(root, os.path.basename(x))
-
-    def _extracted_filepath_fn(_=None):
-        return os.path.join(root, _EXTRACTED_FILES[split])
-
-    def _filter_fn(x):
-        return _EXTRACTED_FILES[split] in x[0]
-
-    def _filter_res(x):
-        return x[0] in LABEL_TO_INT
-
-    def _modify_res(x):
-        return (LABEL_TO_INT[x[0]], x[5], x[6])
-
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(URL): MD5},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root, URL): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
+    )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
diff --git a/torchtext/datasets/mrpc.py b/torchtext/datasets/mrpc.py
index c5077a195f..85ac94d1ba 100644
--- a/torchtext/datasets/mrpc.py
+++ b/torchtext/datasets/mrpc.py
@@ -1,5 +1,6 @@
 import csv
 import os
+from functools import partial
 from typing import Union, Tuple
 
 from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,14 @@
 DATASET_NAME = "MRPC"
 
 
+def _filepath_fn(root, x):
+    return os.path.join(root, os.path.basename(x))
+
+
+def _modify_res(x):
+    return (int(x[0]), x[3], x[4])
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "test"))
 def MRPC(root: str, split: Union[Tuple[str], str]):
@@ -54,17 +63,11 @@
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
         )
 
-    def _filepath_fn(x):
-        return os.path.join(root, os.path.basename(x))
-
-    def _modify_res(x):
-        return (int(x[0]), x[3], x[4])
-
     url_dp = IterableWrapper([URL[split]])
     # cache data on-disk with sanity check
     cache_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(URL[split]): MD5[split]},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root, URL[split]): MD5[split]},
         hash_type="md5",
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
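
Note the asymmetry in the `on_disk_cache` calls above: the `hash_dict` key is computed eagerly with `root` applied (`_filepath_fn(root, URL)`), while `filepath_fn` receives the lazily-applied `partial`. Both spellings resolve to the same path; a quick sketch using the MNLI-style helper (the `root` and `url` values here are placeholders, not from the patch):

import os
from functools import partial


def _filepath_fn(root, x=None):
    return os.path.join(root, os.path.basename(x))


root = "/tmp/MNLI"  # example cache directory
url = "https://example.org/multinli_1.0.zip"  # placeholder URL

# Eager call, used as the hash_dict key ...
assert _filepath_fn(root, url) == "/tmp/MNLI/multinli_1.0.zip"
# ... and the bound callable handed to on_disk_cache(filepath_fn=...).
assert partial(_filepath_fn, root)(url) == _filepath_fn(root, url)
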
diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py
index 387cbffaa5..86abea4343 100644
--- a/torchtext/datasets/qqp.py
+++ b/torchtext/datasets/qqp.py
@@ -1,4 +1,5 @@
 import os
+from functools import partial
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import _create_dataset_directory
@@ -18,6 +19,14 @@
 DATASET_NAME = "QQP"
 
 
+def _filepath_fn(root, _=None):
+    return os.path.join(root, _PATH)
+
+
+def _modify_res(x):
+    return (int(x[-1]), x[3], x[4])
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 def QQP(root: str):
     """QQP dataset
@@ -34,16 +43,10 @@
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )
 
-    def _filepath_fn(_=None):
-        return os.path.join(root, _PATH)
-
-    def _modify_res(x):
-        return (int(x[-1]), x[3], x[4])
-
     url_dp = IterableWrapper([URL])
     cache_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(): MD5},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root): MD5},
         hash_type="md5",
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
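
The new TestDatasetPickling case at the top of this patch drives the round-trip check across every entry in DATASETS. An equivalent standalone loop would look roughly like this (a sketch only; the cache directory is hypothetical and constructing each dataset triggers its download):

import pickle

from torchtext.datasets import DATASETS


def check_all_picklable(root="/tmp/torchtext-data"):  # hypothetical cache dir
    for name, dataset_fn in DATASETS.items():
        dp = dataset_fn(root=root)
        # Multi-split datasets return a tuple of datapipes; normalize to a list.
        splits = list(dp) if isinstance(dp, tuple) else [dp]
        for split_dp in splits:
            pickle.loads(pickle.dumps(split_dp))  # raises if any closure remains
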
diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py
index c239de4465..7c46fa3ae9 100644
--- a/torchtext/datasets/stsb.py
+++ b/torchtext/datasets/stsb.py
@@ -1,5 +1,6 @@
 import csv
 import os
+from functools import partial
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -36,6 +37,22 @@
 }
 
 
+def _filepath_fn(root, x=_PATH):
+    return os.path.join(root, os.path.basename(x))
+
+
+def _extracted_filepath_fn(root, split, _=None):
+    return _filepath_fn(root, _EXTRACTED_FILES[split])
+
+
+def _filter_fn(split, x):
+    return _EXTRACTED_FILES[split] in x[0]
+
+
+def _modify_res(x):
+    return (int(x[3]), float(x[4]), x[5], x[6])
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def STSB(root, split):
@@ -61,28 +78,18 @@
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )
 
-    def _filepath_fn(x=_PATH):
-        return os.path.join(root, os.path.basename(x))
-
-    def _extracted_filepath_fn(_=None):
-        return _filepath_fn(_EXTRACTED_FILES[split])
-
-    def _filter_fn(x):
-        return _EXTRACTED_FILES[split] in x[0]
-
-    def _modify_res(x):
-        return (int(x[3]), float(x[4]), x[5], x[6])
-
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(URL): MD5},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root, URL): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(partial(_filter_fn, split))
+    )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
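
The practical payoff of the patch: the returned datapipes can now cross process boundaries, for example when handed to a DataLoader with num_workers > 0. A usage sketch, assuming the post-patch API (downloads CoLA on first run):

import pickle

from torchtext.datasets import CoLA

train_dp = CoLA(split="train")
restored = pickle.loads(pickle.dumps(train_dp))
print(next(iter(restored)))  # e.g. a (source, label, sentence) tuple
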