Adding parameterized dataset pickling tests #1732


Merged (6 commits, May 23, 2022)
15 changes: 15 additions & 0 deletions test/datasets/common.py
@@ -1,3 +1,5 @@
import pickle

from parameterized import parameterized
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
@@ -7,6 +9,19 @@
from ..common.torchtext_test_case import TorchtextTestCase


class TestDatasetPickling(TorchtextTestCase):
    @parameterized.expand([(f,) for f in DATASETS.values()])
    def test_pickling(self, dataset_fn):
        dp = dataset_fn()
        if type(dp) == tuple:
            dp = list(dp)
        else:
            dp = [dp]

        # Round-trip every datapipe through pickle; this raises if any stage
        # still captures an unpicklable closure.
        for dp_split in dp:
            pickle.loads(pickle.dumps(dp_split))


class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    # Note that for ordering, i.e. shuffle before sharding, TorchData will provide a linter warning
    # Modify this test when the linter warning is available
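Why this test catches the bug the rest of this PR fixes: pickle serializes a function by its qualified module-level name, so a function defined inside another function (as the dataset builders below did before this change) cannot be pickled, while a module-level function with its arguments bound via functools.partial can. A minimal standalone sketch of the difference, separate from the PR's own code:

import pickle
from functools import partial


def _module_level_fn(root, _=None):
    # Picklable: pickle stores the qualified name and re-imports it on load.
    return root


def make_closure(root):
    def _nested_fn(_=None):
        # Not picklable: its qualified name is "make_closure.<locals>._nested_fn".
        return root

    return _nested_fn


pickle.loads(pickle.dumps(partial(_module_level_fn, "/tmp")))  # round-trips fine
try:
    pickle.dumps(make_closure("/tmp"))
except AttributeError as err:
    print(err)  # Can't pickle local object 'make_closure.<locals>._nested_fn'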
46 changes: 27 additions & 19 deletions torchtext/datasets/cola.py
@@ -1,5 +1,6 @@
import csv
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -26,6 +27,26 @@
DATASET_NAME = "CoLA"


def _filepath_fn(root, _=None):
    return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
    return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
    return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
    return (t[0], int(t[1]), t[3])


def _filter_res(x):
    return len(x) == 4


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def CoLA(root: str, split: Union[Tuple[str], str]):
@@ -51,31 +72,18 @@ def CoLA(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

    def _filepath_fn(_=None):
        return os.path.join(root, _PATH)

    def _extracted_filepath_fn(_=None):
        return os.path.join(root, _EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _modify_res(t):
        return (t[0], int(t[1]), t[3])

    def _filter_res(x):
        return len(x) == 4

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(): MD5},
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
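The same refactor repeats in every dataset module below: each closure that captured root or split becomes a module-level helper taking those values as leading parameters, and the call sites bind them with functools.partial. A small sketch of the equivalence for the CoLA-style path helper (the _PATH value is a placeholder, not necessarily the module's real constant):

import os
from functools import partial

_PATH = "cola_public_1.1.zip"  # placeholder archive name


def _filepath_fn(root, _=None):
    return os.path.join(root, _PATH)


root = "/datasets"
bound = partial(_filepath_fn, root)
# Behaves like the removed closure; the cache's extra argument is ignored either way.
assert bound() == bound("ignored") == os.path.join(root, _PATH)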
13 changes: 7 additions & 6 deletions torchtext/datasets/imdb.py
@@ -53,6 +53,13 @@ def _modify_res(t):
    return Path(t[0]).parts[-1], t[1]


def filter_imdb_data(key, fname):
    labels = {"neg", "pos"}
    # eg. fname = "aclImdb/train/neg/12416_3.txt"
    *_, split, label, file = Path(fname).parts
    return key == split and label in labels


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def IMDB(root: str, split: Union[Tuple[str], str]):
@@ -92,12 +99,6 @@ def IMDB(root: str, split: Union[Tuple[str], str]):
    )
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
    cache_decompressed_dp = cache_decompressed_dp.load_from_tar()

    def filter_imdb_data(key, fname):
        # eg. fname = "aclImdb/train/neg/12416_3.txt"
        *_, split, label, file = Path(fname).parts
        return key == split and label in labels

    cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, filter_imdb_data, split))

    # eg. "aclImdb/train/neg/12416_3.txt" -> "neg"
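For reference, filter_imdb_data keeps only the files whose path matches the requested split and a neg/pos label directory; a standalone check built around the example path from the code comment:

from pathlib import Path


def filter_imdb_data(key, fname):
    labels = {"neg", "pos"}
    # eg. fname = "aclImdb/train/neg/12416_3.txt"
    *_, split, label, file = Path(fname).parts
    return key == split and label in labels


assert filter_imdb_data("train", "aclImdb/train/neg/12416_3.txt")
assert not filter_imdb_data("test", "aclImdb/train/neg/12416_3.txt")
assert not filter_imdb_data("train", "aclImdb/train/unsup/0_0.txt")  # unlabeled data is dropped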
46 changes: 27 additions & 19 deletions torchtext/datasets/mnli.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import csv
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
@@ -39,6 +40,26 @@
LABEL_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}


def _filepath_fn(root, x=None):
    return os.path.join(root, os.path.basename(x))


def _extracted_filepath_fn(root, split, _=None):
    return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
    return _EXTRACTED_FILES[split] in x[0]


def _filter_res(x):
    return x[0] in LABEL_TO_INT


def _modify_res(x):
    return (LABEL_TO_INT[x[0]], x[5], x[6])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev_matched", "dev_mismatched"))
def MNLI(root, split):
@@ -64,31 +85,18 @@ def MNLI(root, split):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

    def _filepath_fn(x=None):
        return os.path.join(root, os.path.basename(x))

    def _extracted_filepath_fn(_=None):
        return os.path.join(root, _EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _filter_res(x):
        return x[0] in LABEL_TO_INT

    def _modify_res(x):
        return (LABEL_TO_INT[x[0]], x[5], x[6])

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL): MD5},
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root, URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
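One detail worth noting in these call sites: hash_dict={_filepath_fn(root, URL): MD5} invokes the helper eagerly while the datapipe graph is built, so the expected archive path is computed once up front, whereas filepath_fn=partial(_filepath_fn, root) stays a deferred, picklable callable. A sketch with placeholder URL and checksum values:

import os
from functools import partial

URL = "https://example.com/multinli_1.0.zip"  # placeholder, not the real MNLI URL
MD5 = "0" * 32  # placeholder checksum


def _filepath_fn(root, x=None):
    return os.path.join(root, os.path.basename(x))


root = "/datasets"
hash_dict = {_filepath_fn(root, URL): MD5}  # evaluated immediately
filepath_fn = partial(_filepath_fn, root)  # evaluated later, per incoming path
assert filepath_fn(URL) in hash_dict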
19 changes: 11 additions & 8 deletions torchtext/datasets/mrpc.py
@@ -1,5 +1,6 @@
import csv
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,14 @@
DATASET_NAME = "MRPC"


def _filepath_fn(root, x):
    return os.path.join(root, os.path.basename(x))


def _modify_res(x):
    return (int(x[0]), x[3], x[4])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def MRPC(root: str, split: Union[Tuple[str], str]):
@@ -54,17 +63,11 @@ def MRPC(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

    def _filepath_fn(x):
        return os.path.join(root, os.path.basename(x))

    def _modify_res(x):
        return (int(x[0]), x[3], x[4])

    url_dp = IterableWrapper([URL[split]])
    # cache data on-disk with sanity check
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL[split]): MD5[split]},
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root, URL[split]): MD5[split]},
        hash_type="md5",
    )
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
19 changes: 11 additions & 8 deletions torchtext/datasets/qqp.py
@@ -1,4 +1,5 @@
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory
@@ -18,6 +19,14 @@
DATASET_NAME = "QQP"


def _filepath_fn(root, _=None):
    return os.path.join(root, _PATH)


def _modify_res(x):
    return (int(x[-1]), x[3], x[4])


@_create_dataset_directory(dataset_name=DATASET_NAME)
def QQP(root: str):
"""QQP dataset
@@ -34,16 +43,10 @@ def QQP(root: str):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

    def _filepath_fn(_=None):
        return os.path.join(root, _PATH)

    def _modify_res(x):
        return (int(x[-1]), x[3], x[4])

    url_dp = IterableWrapper([URL])
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(): MD5},
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root): MD5},
        hash_type="md5",
    )
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
39 changes: 23 additions & 16 deletions torchtext/datasets/stsb.py
@@ -1,5 +1,6 @@
import csv
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
@@ -36,6 +37,22 @@
}


def _filepath_fn(root, x=_PATH):
    return os.path.join(root, os.path.basename(x))


def _extracted_filepath_fn(root, split, _=None):
    return _filepath_fn(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
    return _EXTRACTED_FILES[split] in x[0]


def _modify_res(x):
    return (int(x[3]), float(x[4]), x[5], x[6])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def STSB(root, split):
@@ -61,28 +78,18 @@ def STSB(root, split):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

    def _filepath_fn(x=_PATH):
        return os.path.join(root, os.path.basename(x))

    def _extracted_filepath_fn(_=None):
        return _filepath_fn(_EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _modify_res(x):
        return (int(x[3]), float(x[4]), x[5], x[6])

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL): MD5},
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root, URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn)
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(partial(_filter_fn, split))
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")