Skip to content

Commit d28a139

Browse files
ejguan authored and facebook-github-bot committed
Fix ArchiveReader to keep archive path (#73)
Summary: Pull Request resolved: #73
The previous implementation of `ArchiveReader` had a bug. For reference, see pytorch/pytorch#65424 (comment).
Reviewed By: NivekT
Differential Revision: D31797765
fbshipit-source-id: 494e1a49b43d5a846de971a67586089e6d7ebafc
1 parent 2894636 commit d28a139

File tree

5 files changed

+38
-73
lines changed

5 files changed

+38
-73
lines changed

examples/text/amazonreviewpolarity.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@
2424
_PATH = "amazon_review_polarity_csv.tar.gz"
2525

2626
_EXTRACTED_FILES = {
27-
"train": f"{os.sep}".join(["amazon_review_polarity_csv", "train.csv"]),
28-
"test": f"{os.sep}".join(["amazon_review_polarity_csv", "test.csv"]),
27+
"train": f"{os.sep}".join([_PATH, "amazon_review_polarity_csv", "train.csv"]),
28+
"test": f"{os.sep}".join([_PATH, "amazon_review_polarity_csv", "test.csv"]),
2929
}
3030

3131
_EXTRACTED_FILES_MD5 = {

examples/text/sst2.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

test/test_examples.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
import sys
44
import unittest
55

6+
from torch.testing._internal.common_utils import slowTest
67

78
current = os.path.dirname(os.path.realpath(__file__))
89
ROOT = os.path.dirname(current)
910
sys.path.append(ROOT)
1011

12+
from examples.text.ag_news import AG_NEWS
13+
from examples.text.amazonreviewpolarity import AmazonReviewPolarity
14+
from examples.text.imdb import IMDB
15+
from examples.text.squad1 import SQuAD1
16+
from examples.text.squad2 import SQuAD2
1117
from examples.vision.caltech101 import Caltech101
1218
from examples.vision.caltech256 import Caltech256
1319

@@ -42,5 +48,33 @@ def test_Caltech256(self) -> None:
4248
self.assertEqual(6, len(samples))
4349

4450

51+
# TODO: Replace the following tests with the corresponding tests in TorchText
52+
class TestTextExamples(unittest.TestCase):
53+
def _test_helper(self, fn):
54+
dp = fn()
55+
for stage_dp in dp:
56+
_ = list(stage_dp)
57+
58+
@slowTest
59+
def test_AG_NEWS(self) -> None:
60+
self._test_helper(AG_NEWS)
61+
62+
@slowTest
63+
def test_AmazonReviewPolarity(self) -> None:
64+
self._test_helper(AmazonReviewPolarity)
65+
66+
@slowTest
67+
def test_IMDB(self) -> None:
68+
self._test_helper(IMDB)
69+
70+
@slowTest
71+
def test_SQuAD1(self) -> None:
72+
self._test_helper(SQuAD1)
73+
74+
@slowTest
75+
def test_SQuAD2(self) -> None:
76+
self._test_helper(SQuAD2)
77+
78+
4579
if __name__ == "__main__":
4680
unittest.main()

torchdata/datapipes/iter/util/tararchivereader.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,6 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
4040
for data in self.datapipe:
4141
validate_pathname_binary_tuple(data)
4242
pathname, data_stream = data
43-
folder_name = os.path.dirname(pathname)
4443
try:
4544
# typing.cast is used here to silence mypy's type checker
4645
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
@@ -51,7 +50,7 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
5150
if extracted_fobj is None:
5251
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
5352
raise tarfile.ExtractError
54-
inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
53+
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
5554
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
5655
except Exception as e:
5756
warnings.warn(

torchdata/datapipes/iter/util/ziparchivereader.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
3939
for data in self.datapipe:
4040
validate_pathname_binary_tuple(data)
4141
pathname, data_stream = data
42-
folder_name = os.path.dirname(pathname)
4342
try:
4443
# typing.cast is used here to silence mypy's type checker
4544
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
@@ -51,7 +50,7 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
5150
elif zipinfo.filename.endswith("/"):
5251
continue
5352
extracted_fobj = zips.open(zipinfo)
54-
inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
53+
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
5554
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
5655
except Exception as e:
5756
warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")

0 commit comments

Comments
 (0)