diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py
index 386a8bfc43..a1a41631ae 100644
--- a/torchtext/datasets/imdb.py
+++ b/torchtext/datasets/imdb.py
@@ -1,10 +1,15 @@
-from torchtext.utils import download_from_url, extract_archive
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
+import os
+from pathlib import Path
+from typing import Union, Tuple
+
+from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import _add_docstring_header
 from torchtext.data.datasets_utils import _create_dataset_directory
-import io
-from pathlib import Path
+from torchtext.data.datasets_utils import _wrap_split_argument
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, IterableWrapper, HttpReader
+
 
 URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
 
@@ -23,16 +28,41 @@
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
-def IMDB(root, split):
-    def generate_imdb_data(key, extracted_files):
-        for fname in extracted_files:
-            *_, split, label, file = Path(fname).parts
-
-            if key == split and (label in ['pos', 'neg']):
-                with io.open(fname, encoding="utf8") as f:
-                    yield label, f.read()
-    dataset_tar = download_from_url(URL, root=root,
-                                    hash_value=MD5, hash_type='md5')
-    extracted_files = extract_archive(dataset_tar)
-    iterator = generate_imdb_data(split, extracted_files)
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], iterator)
+def IMDB(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install it by following the instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL])
+
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    labels = {"neg", "pos"}
+    decompressed_folder = "aclImdb_v1"
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: [os.path.join(root, decompressed_folder, split, label) for label in labels]
+    )
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
+    cache_decompressed_dp = cache_decompressed_dp.read_from_tar()
+
+    def filter_imdb_data(key, fname):
+        # e.g. fname = "aclImdb/train/neg/12416_3.txt"
+        *_, split, label, file = Path(fname).parts
+        return key == split and label in labels
+
+    cache_decompressed_dp = cache_decompressed_dp.filter(lambda t: filter_imdb_data(split, t[0]))
+
+    # e.g. "aclImdb/train/neg/12416_3.txt" -> "neg"
+    cache_decompressed_dp = cache_decompressed_dp.map(lambda t: (Path(t[0]).parts[-2], t[1]))
+    cache_decompressed_dp = cache_decompressed_dp.readlines(decode=True)
+    cache_decompressed_dp = cache_decompressed_dp.lines_to_paragraphs()  # group lines by label in the cache file
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wt", filepath_fn=lambda x: os.path.join(root, decompressed_folder, split, x)
+    )
+
+    data_dp = FileOpener(cache_decompressed_dp, mode="t")
+    # get the label from the cache file path, e.g. "aclImdb_v1/train/neg" -> "neg"
+    return data_dp.readlines().map(lambda t: (Path(t[0]).parts[-1], t[1]))
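
For reference, a minimal usage sketch of the datapipe-backed dataset this diff introduces. The `root=".data"` value and the printed slice are illustrative choices; the `(label, text)` item shape follows from the final `map` in the diff above.

```python
# Minimal usage sketch, assuming `torchdata` is installed and a torchtext
# build that contains this change.
from torchtext.datasets import IMDB

train_dp = IMDB(root=".data", split="train")  # split may be "train" or "test"
for label, text in train_dp:
    # Each item is a (label, text) pair, with label in {"neg", "pos"}.
    print(label, text[:80])
    break
```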