From 8f681140080674e3fb21b3a87aff990adc4aaefc Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Tue, 10 May 2022 18:22:57 -0400 Subject: [PATCH] Updating dataset code to avoid creating multiple iterators from a DataPipe --- torchtext/datasets/iwslt2016.py | 14 ++++++-------- torchtext/datasets/iwslt2017.py | 9 +++++---- torchtext/datasets/multi30k.py | 6 ++++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index a79c189613..9d3443a5bb 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -125,7 +125,7 @@ # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) - cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").load_from_tar() + cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( lambda x: os.path.basename(uncleaned_filename) in x[0] ) @@ -261,12 +261,10 @@ def IWSLT2016( ) cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b") - .load_from_tar() - .filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) - ) + cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() + cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) src_filename = file_path_by_lang_and_split[src_language][split] uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split] @@ -276,7 +274,7 @@ def IWSLT2016( full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) cache_inner_src_decompressed_dp = _filter_clean_cache( - cache_decompressed_dp, full_src_filepath, uncleaned_src_filename + cache_decompressed_dp_1, full_src_filepath, uncleaned_src_filename ) tgt_filename = file_path_by_lang_and_split[tgt_language][split] @@ -287,7 +285,7 @@ def IWSLT2016( full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename) cache_inner_tgt_decompressed_dp = _filter_clean_cache( - cache_decompressed_dp, full_tgt_filepath, uncleaned_tgt_filename + cache_decompressed_dp_2, full_tgt_filepath, uncleaned_tgt_filename ) tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index e97ce9fbf5..0fb865d4e0 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -104,7 +104,7 @@ # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) - cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").load_from_tar() + cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( lambda x: os.path.basename(uncleaned_filename) in x[0] ) @@ -208,8 +208,9 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de ) cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar() + cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) src_filename = file_path_by_lang_and_split[src_language][split] uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split] @@ -224,7 +225,7 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de ) cache_inner_src_decompressed_dp = _filter_clean_cache( - cache_decompressed_dp, full_src_filepath, uncleaned_src_filename + cache_decompressed_dp_1, full_src_filepath, uncleaned_src_filename ) tgt_filename = file_path_by_lang_and_split[tgt_language][split] @@ -240,7 +241,7 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de ) cache_inner_tgt_decompressed_dp = _filter_clean_cache( - cache_decompressed_dp, full_tgt_filepath, uncleaned_tgt_filename + cache_decompressed_dp_2, full_tgt_filepath, uncleaned_tgt_filename ) tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index 8382fdc57c..26390379ba 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -79,7 +79,9 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - src_cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + cache_compressed_dp_1, cache_compressed_dp_2 = cache_compressed_dp.fork(num_instances=2) + + src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache( filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}") ) src_cache_decompressed_dp = ( @@ -89,7 +91,7 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] ) src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache( filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}") ) tgt_cache_decompressed_dp = (