diff --git a/docs/source/utils.rst b/docs/source/utils.rst
index e45c977bca..93ef7fe163 100644
--- a/docs/source/utils.rst
+++ b/docs/source/utils.rst
@@ -17,11 +17,6 @@ torchtext.utils
 
 .. autofunction:: download_from_url
 
-:hidden:`unicode_csv_reader`
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: unicode_csv_reader
-
 :hidden:`extract_archive`
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/test/data/test_utils.py b/test/data/test_utils.py
index 349e8ca25a..7d4344ab50 100644
--- a/test/data/test_utils.py
+++ b/test/data/test_utils.py
@@ -1,9 +1,5 @@
-import io
-
 from torchtext.data import get_tokenizer
-from torchtext.utils import unicode_csv_reader
-from ..common.assets import get_asset_path
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
@@ -37,25 +33,3 @@ def test_get_tokenizer_toktokt(self):
             get_tokenizer(1)
         with self.assertRaises(ValueError):
             get_tokenizer("some other string")
-
-    def test_text_nomalize_function(self):
-        # Test text_nomalize function in torchtext.datasets.text_classification
-        ref_lines = []
-        test_lines = []
-
-        tokenizer = get_tokenizer("basic_english")
-        data_path = get_asset_path("text_normalization_ag_news_test.csv")
-        with io.open(data_path, encoding="utf8") as f:
-            reader = unicode_csv_reader(f)
-            for row in reader:
-                test_lines.append(tokenizer(" , ".join(row)))
-
-        data_path = get_asset_path("text_normalization_ag_news_ref_results.test")
-        with io.open(data_path, encoding="utf8") as ref_data:
-            for line in ref_data:
-                line = line.split()
-                self.assertEqual(line[0][:9], "__label__")
-                line[0] = line[0][9:]  # remove '__label__'
-                ref_lines.append(line)
-
-        self.assertEqual(ref_lines, test_lines)
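Note on the removed test: `test_text_nomalize_function` was an end-to-end check of `unicode_csv_reader` plus the `basic_english` tokenizer against AG_NEWS fixtures, so it goes away with the reader itself. The tokenizer it exercised is unchanged; a minimal sketch of its behavior (the sample sentence is illustrative, not taken from the fixtures):

```python
from torchtext.data import get_tokenizer

# "basic_english" normalizes: it lowercases and splits off punctuation.
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer("You can now install TorchText using pip!")
# expected: ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
```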
diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 425eafcea6..725218d357 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -1,13 +1,10 @@
 import codecs
 import functools
 import inspect
-import io
 import os
 
-import torch
 from torch.utils.data import functional_datapipe, IterDataPipe
 from torch.utils.data.datapipes.utils.common import StreamWrapper
-from torchtext.utils import download_from_url, extract_archive, validate_file
 
 try:
     import defusedxml.ElementTree as ET
@@ -100,12 +97,6 @@ def _clean_files(outfile, fname, stream):
         return _rewrite_text_file(outfile, stream)
 
 
-def _read_text_iterator(path):
-    with io.open(path, encoding="utf8") as f:
-        for row in f:
-            yield row
-
-
 def _check_default_set(split, target_select, dataset_name):
     # Check whether given object split is either a tuple of strings or string
     # and represents a valid selection of options given by the tuple of strings
@@ -135,65 +126,6 @@ def _wrap_datasets(datasets, split):
     return datasets
 
 
-def _dataset_docstring_header(fn, num_lines=None, num_classes=None):
-    """
-    Returns docstring for a dataset based on function arguments.
-
-    Assumes function signature of form (root='.data', split=<default split>, **kwargs)
-    """
-    argspec = inspect.getfullargspec(fn)
-    if not (argspec.args[0] == "root" and argspec.args[1] == "split"):
-        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
-    default_split = argspec.defaults[1]
-
-    if not (isinstance(default_split, tuple) or isinstance(default_split, str)):
-        raise ValueError("default_split type expected to be of string or tuple but got {}".format(type(default_split)))
-
-    header_s = fn.__name__ + " dataset\n"
-
-    if isinstance(default_split, tuple):
-        header_s += "\nSeparately returns the {} split".format("/".join(default_split))
-
-    if isinstance(default_split, str):
-        header_s += "\nOnly returns the {} split".format(default_split)
-
-    if num_lines is not None:
-        header_s += "\n\nNumber of lines per split:"
-        for k, v in num_lines.items():
-            header_s += "\n    {}: {}\n".format(k, v)
-
-    if num_classes is not None:
-        header_s += "\n\nNumber of classes"
-        header_s += "\n    {}\n".format(num_classes)
-
-    args_s = "\nArgs:"
-    args_s += "\n    root: Directory where the datasets are saved."
-    args_s += "\n        Default: .data"
-
-    if isinstance(default_split, tuple):
-        args_s += "\n    split: split or splits to be returned. Can be a string or tuple of strings."
-        args_s += "\n        Default: {}" "".format(str(default_split))
-
-    if isinstance(default_split, str):
-        args_s += "\n    split: Only {default_split} is available."
-        args_s += "\n        Default: {default_split}".format(default_split=default_split)
-
-    return "\n".join([header_s, args_s]) + "\n"
-
-
-def _add_docstring_header(docstring=None, num_lines=None, num_classes=None):
-    def docstring_decorator(fn):
-        old_doc = fn.__doc__
-        fn.__doc__ = _dataset_docstring_header(fn, num_lines, num_classes)
-        if docstring is not None:
-            fn.__doc__ += docstring
-        if old_doc is not None:
-            fn.__doc__ += old_doc
-        return fn
-
-    return docstring_decorator
-
-
 def _wrap_split_argument_with_fn(fn, splits):
     """
     Wraps given function of specific signature to extend behavior of split
@@ -265,67 +197,6 @@ def wrapper(root=_CACHE_DIR, *args, **kwargs):
     return decorator
 
 
-def _download_extract_validate(
-    root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5, hash_type="sha256"
-):
-    root = os.path.abspath(root)
-    downloaded_file = os.path.abspath(downloaded_file)
-    extracted_file = os.path.abspath(extracted_file)
-    if os.path.exists(extracted_file):
-        with open(os.path.join(root, extracted_file), "rb") as f:
-            if validate_file(f, extracted_file_md5, hash_type):
-                return extracted_file
-
-    dataset_tar = download_from_url(
-        url, path=os.path.join(root, downloaded_file), hash_value=url_md5, hash_type=hash_type
-    )
-    extracted_files = extract_archive(dataset_tar)
-    assert os.path.exists(extracted_file), "extracted_file [{}] was not found in the archive [{}]".format(
-        extracted_file, extracted_files
-    )
-
-    return extracted_file
-
-
-class _RawTextIterableDataset(torch.utils.data.IterableDataset):
-    """Defines an abstraction for raw text iterable datasets."""
-
-    def __init__(self, description, full_num_lines, iterator):
-        """Initiate the dataset abstraction."""
-        super(_RawTextIterableDataset, self).__init__()
-        self.description = description
-        self.full_num_lines = full_num_lines
-        self._iterator = iterator
-        self.num_lines = full_num_lines
-        self.current_pos = None
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.current_pos == self.num_lines - 1:
-            raise StopIteration
-        item = next(self._iterator)
-        if self.current_pos is None:
-            self.current_pos = 0
-        else:
-            self.current_pos += 1
-        return item
-
-    def __len__(self):
-        return self.num_lines
-
-    def pos(self):
-        """
-        Returns current position of the iterator. This returns None
-        if the iterator hasn't been used yet.
-        """
-        return self.current_pos
-
-    def __str__(self):
-        return self.description
-
-
 def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, valid_set, test_set):
     train_filenames = (
         "train.{}-{}.{}".format(src_language, tgt_language, src_language),
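The removed `_download_extract_validate` helper only composed public utilities that remain in `torchtext.utils`, so callers can inline the same flow. A minimal sketch assuming hypothetical `url`, `url_md5`, and file names; this helper itself is no longer part of torchtext:

```python
import os

from torchtext.utils import download_from_url, extract_archive, validate_file


def download_extract_validate(root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5):
    # Reuse a previously extracted file if its hash still validates.
    extracted_path = os.path.join(root, extracted_file)
    if os.path.exists(extracted_path):
        with open(extracted_path, "rb") as f:
            if validate_file(f, extracted_file_md5, "md5"):
                return extracted_path
    # Otherwise download the archive (hash-checked) and extract it.
    archive = download_from_url(url, path=os.path.join(root, downloaded_file), hash_value=url_md5, hash_type="md5")
    extract_archive(archive, to_path=root)
    return extracted_path
```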
diff --git a/torchtext/utils.py b/torchtext/utils.py
index b6e50b9632..e448faae01 100644
--- a/torchtext/utils.py
+++ b/torchtext/utils.py
@@ -1,9 +1,7 @@
-import csv
 import gzip
 import hashlib
 import logging
 import os
-import sys
 import tarfile
 import zipfile
 
@@ -38,14 +36,12 @@ def inner(b=1, bsize=1, tsize=None):
 
 def validate_file(file_obj, hash_value, hash_type="sha256"):
     """Validate a given file object with its hash.
-
     Args:
         file_obj: File object to read from.
         hash_value (str): Hash for url.
         hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
 
     Returns:
         bool: return True if its a valid file, else False.
-
     """
 
     if hash_type == "sha256":
@@ -76,7 +72,6 @@ def _check_hash(path, hash_value, hash_type):
 def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=None, hash_type="sha256"):
     """Download file, with logic (from tensor2tensor) for Google Drive. Returns
     the path to the downloaded file.
-
     Args:
         url: the url of the file from URL header. (None)
         path: path where file will be saved
@@ -84,14 +79,12 @@ def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=
         overwrite: overwrite existing files (False)
         hash_value (str, optional): hash for url (Default: ``None``).
         hash_type (str, optional): hash type, among "sha256" and "md5" (Default: ``"sha256"``).
-
     Examples:
         >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
         >>> torchtext.utils.download_from_url(url)
         >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> torchtext.utils.download_from_url(url)
         >>> '.data/validation.tar.gz'
-
     """
     # figure out filename and root
     if path is None:
@@ -130,54 +123,14 @@ def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=
     return path
 
 
-def unicode_csv_reader(unicode_csv_data, **kwargs):
-    r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
-    Borrowed and slightly modified from the Python docs:
-    https://docs.python.org/2/library/csv.html#csv-examples
-
-    Args:
-        unicode_csv_data: unicode csv data (see example below)
-
-    Examples:
-        >>> from torchtext.utils import unicode_csv_reader
-        >>> import io
-        >>> with io.open(data_path, encoding="utf8") as f:
-        >>>     reader = unicode_csv_reader(f)
-
-    """
-
-    # Fix field larger than field limit error
-    maxInt = sys.maxsize
-    while True:
-        # decrease the maxInt value by factor 10
-        # as long as the OverflowError occurs.
-        try:
-            csv.field_size_limit(maxInt)
-            break
-        except OverflowError:
-            maxInt = int(maxInt / 10)
-    csv.field_size_limit(maxInt)
-
-    for line in csv.reader(unicode_csv_data, **kwargs):
-        yield line
-
-
-def utf_8_encoder(unicode_csv_data):
-    for line in unicode_csv_data:
-        yield line.encode("utf-8")
-
-
 def extract_archive(from_path, to_path=None, overwrite=False):
     """Extract archive.
-
     Args:
         from_path: the path of the archive.
         to_path: the root path of the extracted files (directory of from_path)
         overwrite: overwrite existing files (False)
-
     Returns:
         List of paths to extracted files even if not overwritten.
-
     Examples:
         >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
         >>> from_path = './validation.tar.gz'
         >>> to_path = '.data'
         >>> torchtext.utils.extract_archive(from_path)
         >>> ['.data/val.de', '.data/val.en']
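With `unicode_csv_reader` gone: as its dropped docstring notes, the wrapper only existed for Python 2's `csv` module, and on Python 3 the standard library reads unicode directly. A minimal stdlib replacement, keeping the large-field workaround; the `data.csv` path is illustrative:

```python
import csv
import io
import sys

# Python 3's csv module handles unicode natively, so no wrapper is needed.
# Raise the field size limit up front; min() sidesteps the OverflowError
# that the old while-loop in unicode_csv_reader was guarding against.
csv.field_size_limit(min(sys.maxsize, 2**31 - 1))

with io.open("data.csv", encoding="utf8") as f:
    for row in csv.reader(f):
        print(row)
```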
@@ -188,7 +141,6 @@ def extract_archive(from_path, to_path=None, overwrite=False):
         >>> torchtext.utils.download_from_url(url, from_path)
         >>> torchtext.utils.extract_archive(from_path, to_path)
         >>> ['.data/val.de', '.data/val.en']
-
     """
 
     if to_path is None:
@@ -256,12 +208,10 @@ def _log_class_usage(klass):
 
 def get_asset_local_path(asset_path: str) -> str:
     """Get local path for assets. Download if path does not exist locally
-
     Args:
         asset_path: Local path to asset or remote URL
     Returns:
         bool: local path of the asset after downloading or reading from cache
-
     Examples:
         >>> url = 'http:///file.txt'
         >>> torchtext.utils.get_asset_local_path(url)
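Closing note on `_RawTextIterableDataset`, removed from `datasets_utils.py` above: besides carrying a description string, it only capped an iterator at `num_lines` items. Where calling code still wants that cap, `itertools.islice` does the same job; a minimal sketch, not a torchtext API:

```python
from itertools import islice


def capped_lines(iterator, num_lines):
    # Yield at most num_lines items, mirroring the current_pos/num_lines
    # bookkeeping that _RawTextIterableDataset.__next__ implemented.
    yield from islice(iterator, num_lines)


# e.g. the first 100 lines of a text file:
# for line in capped_lines(open("train.txt", encoding="utf8"), 100):
#     ...
```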