|
1 |
| -from torchtext.utils import ( |
2 |
| - download_from_url, |
3 |
| - extract_archive, |
4 |
| -) |
| 1 | +from torchtext._internal.module_utils import is_module_available |
| 2 | +from typing import Union, Tuple |
| 3 | + |
| 4 | +if is_module_available("torchdata"): |
| 5 | + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper |
| 6 | + |
5 | 7 | from torchtext.data.datasets_utils import (
|
6 |
| - _RawTextIterableDataset, |
7 | 8 | _wrap_split_argument,
|
8 | 9 | _add_docstring_header,
|
9 |
| - _find_match, |
10 | 10 | _create_dataset_directory,
|
11 |
| - _create_data_from_csv, |
12 | 11 | )
|
| 12 | + |
13 | 13 | import os
|
14 | 14 |
|
15 | 15 | URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0'
|
|
25 | 25 |
|
26 | 26 | DATASET_NAME = "YelpReviewFull"
|
27 | 27 |
|
| 28 | +_EXTRACTED_FILES = { |
| 29 | + 'train': os.path.join('yelp_review_full_csv', 'train.csv'), |
| 30 | + 'test': os.path.join('yelp_review_full_csv', 'test.csv'), |
| 31 | +} |
| 32 | + |
28 | 33 |
|
@_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def YelpReviewFull(root: str, split: Union[Tuple[str, ...], str]):
    """Build a datapipe over the YelpReviewFull dataset.

    Downloads (and caches) the dataset tar from Google Drive, extracts the
    CSV for the requested split, and yields ``(label, text)`` tuples where
    ``label`` is an ``int`` and ``text`` is the review string.

    Args:
        root: Directory where the cached archive is stored (handled by
            ``_create_dataset_directory``).
        split: Split name(s) to return; ``_wrap_split_argument`` invokes this
            body once per split, so ``split`` is a single string here.

    Raises:
        ModuleNotFoundError: If ``torchdata`` is not installed.
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    url_dp = IterableWrapper([URL])

    # Cache the downloaded archive at root/_PATH; the cached file is verified
    # against the expected MD5 so a corrupt/partial download is re-fetched.
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, _PATH),
        hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
    )
    # The URL is a Google Drive link, so a plain HTTP reader is not enough.
    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
    cache_dp = FileOpener(cache_dp, mode="b")

    extracted_files = cache_dp.read_from_tar()

    # Keep only the member whose path matches this split's CSV
    # (e.g. 'yelp_review_full_csv/train.csv'); x is (path, stream).
    filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])

    # Each CSV row is [label, text-fragments...]: cast the label to int and
    # re-join the remaining fields into a single review string.
    return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
0 commit comments