clean-up stale code #1654

Merged: 3 commits, Mar 13, 2022
5 changes: 0 additions & 5 deletions docs/source/utils.rst
@@ -17,11 +17,6 @@ torchtext.utils

.. autofunction:: download_from_url

:hidden:`unicode_csv_reader`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: unicode_csv_reader

:hidden:`extract_archive`
~~~~~~~~~~~~~~~~~~~~~~~~~

26 changes: 0 additions & 26 deletions test/data/test_utils.py
@@ -1,9 +1,5 @@
import io

from torchtext.data import get_tokenizer
from torchtext.utils import unicode_csv_reader

from ..common.assets import get_asset_path
from ..common.torchtext_test_case import TorchtextTestCase


@@ -37,25 +33,3 @@ def test_get_tokenizer_toktokt(self):
get_tokenizer(1)
with self.assertRaises(ValueError):
get_tokenizer("some other string")

def test_text_nomalize_function(self):
# Test text_nomalize function in torchtext.datasets.text_classification
ref_lines = []
test_lines = []

tokenizer = get_tokenizer("basic_english")
data_path = get_asset_path("text_normalization_ag_news_test.csv")
with io.open(data_path, encoding="utf8") as f:
reader = unicode_csv_reader(f)
for row in reader:
test_lines.append(tokenizer(" , ".join(row)))

data_path = get_asset_path("text_normalization_ag_news_ref_results.test")
with io.open(data_path, encoding="utf8") as ref_data:
for line in ref_data:
line = line.split()
self.assertEqual(line[0][:9], "__label__")
line[0] = line[0][9:] # remove '__label__'
ref_lines.append(line)

self.assertEqual(ref_lines, test_lines)
129 changes: 0 additions & 129 deletions torchtext/data/datasets_utils.py
@@ -1,13 +1,10 @@
import codecs
import functools
import inspect
import io
import os

import torch
from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import StreamWrapper
from torchtext.utils import download_from_url, extract_archive, validate_file

try:
import defusedxml.ElementTree as ET
@@ -100,12 +97,6 @@ def _clean_files(outfile, fname, stream):
return _rewrite_text_file(outfile, stream)


def _read_text_iterator(path):
with io.open(path, encoding="utf8") as f:
for row in f:
yield row
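
For reference, this helper is redundant in Python 3: a text-mode file object already yields one line per iteration. A minimal stand-in using only the standard library (function name illustrative):

import io

def read_text_iterator(path):
    # A text-mode file handle is itself an iterator over lines,
    # so the wrapper reduces to delegating to it.
    with io.open(path, encoding="utf8") as f:
        yield from f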


def _check_default_set(split, target_select, dataset_name):
# Check whether given object split is either a tuple of strings or string
# and represents a valid selection of options given by the tuple of strings
@@ -135,65 +126,6 @@ def _wrap_datasets(datasets, split):
return datasets


def _dataset_docstring_header(fn, num_lines=None, num_classes=None):
"""
Returns docstring for a dataset based on function arguments.

Assumes function signature of form (root='.data', split=<some tuple of strings>, **kwargs)
"""
argspec = inspect.getfullargspec(fn)
if not (argspec.args[0] == "root" and argspec.args[1] == "split"):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
default_split = argspec.defaults[1]

if not (isinstance(default_split, tuple) or isinstance(default_split, str)):
raise ValueError("default_split type expected to be of string or tuple but got {}".format(type(default_split)))

header_s = fn.__name__ + " dataset\n"

if isinstance(default_split, tuple):
header_s += "\nSeparately returns the {} split".format("/".join(default_split))

if isinstance(default_split, str):
header_s += "\nOnly returns the {} split".format(default_split)

if num_lines is not None:
header_s += "\n\nNumber of lines per split:"
for k, v in num_lines.items():
header_s += "\n {}: {}\n".format(k, v)

if num_classes is not None:
header_s += "\n\nNumber of classes"
header_s += "\n {}\n".format(num_classes)

args_s = "\nArgs:"
args_s += "\n root: Directory where the datasets are saved."
args_s += "\n Default: .data"

if isinstance(default_split, tuple):
args_s += "\n split: split or splits to be returned. Can be a string or tuple of strings."
args_s += "\n Default: {}" "".format(str(default_split))

if isinstance(default_split, str):
args_s += "\n split: Only {default_split} is available."
args_s += "\n Default: {default_split}.format(default_split=default_split)"

return "\n".join([header_s, args_s]) + "\n"


def _add_docstring_header(docstring=None, num_lines=None, num_classes=None):
def docstring_decorator(fn):
old_doc = fn.__doc__
fn.__doc__ = _dataset_docstring_header(fn, num_lines, num_classes)
if docstring is not None:
fn.__doc__ += docstring
if old_doc is not None:
fn.__doc__ += old_doc
return fn

return docstring_decorator
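
Before this clean-up, the decorator was applied to dataset builder functions whose first two arguments were root and split, as _dataset_docstring_header requires. A sketch of the usage pattern (dataset name and counts are illustrative, not taken from this diff):

@_add_docstring_header(num_lines={"train": 120000, "test": 7600}, num_classes=4)
def AG_NEWS(root=".data", split=("train", "test")):
    """Extra notes appended after the generated header."""
    ...

The generated header then lists the train/test splits, the per-split line counts, and the root and split arguments.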


def _wrap_split_argument_with_fn(fn, splits):
"""
Wraps given function of specific signature to extend behavior of split
@@ -265,67 +197,6 @@ def wrapper(root=_CACHE_DIR, *args, **kwargs):
return decorator


def _download_extract_validate(
root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5, hash_type="sha256"
):
root = os.path.abspath(root)
downloaded_file = os.path.abspath(downloaded_file)
extracted_file = os.path.abspath(extracted_file)
if os.path.exists(extracted_file):
with open(os.path.join(root, extracted_file), "rb") as f:
if validate_file(f, extracted_file_md5, hash_type):
return extracted_file

dataset_tar = download_from_url(
url, path=os.path.join(root, downloaded_file), hash_value=url_md5, hash_type=hash_type
)
extracted_files = extract_archive(dataset_tar)
assert os.path.exists(extracted_file), "extracted_file [{}] was not found in the archive [{}]".format(
extracted_file, extracted_files
)

return extracted_file
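
A typical call site for this helper looked roughly like the sketch below (URL and hash values are hypothetical placeholders):

train_path = _download_extract_validate(
    root=".data",
    url="https://example.com/corpus.tar.gz",  # hypothetical URL
    url_md5="<md5-of-archive>",  # hypothetical digest
    downloaded_file="corpus.tar.gz",
    extracted_file="corpus/train.txt",
    extracted_file_md5="<md5-of-train.txt>",  # hypothetical digest
    hash_type="md5",
)

If the extracted file already exists and its hash validates, the download is skipped entirely.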


class _RawTextIterableDataset(torch.utils.data.IterableDataset):
"""Defines an abstraction for raw text iterable datasets."""

def __init__(self, description, full_num_lines, iterator):
"""Initiate the dataset abstraction."""
super(_RawTextIterableDataset, self).__init__()
self.description = description
self.full_num_lines = full_num_lines
self._iterator = iterator
self.num_lines = full_num_lines
self.current_pos = None

def __iter__(self):
return self

def __next__(self):
if self.current_pos == self.num_lines - 1:
raise StopIteration
item = next(self._iterator)
if self.current_pos is None:
self.current_pos = 0
else:
self.current_pos += 1
return item

def __len__(self):
return self.num_lines

def pos(self):
"""
Returns current position of the iterator. This returns None
if the iterator hasn't been used yet.
"""
return self.current_pos

def __str__(self):
return self.description
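
Callers wrapped a plain generator together with a description and an expected line count; a minimal sketch (names and count illustrative):

def _read_lines(path):
    with open(path, encoding="utf8") as f:
        yield from f

dataset = _RawTextIterableDataset("MyDataset", 10000, _read_lines("train.txt"))
for line in dataset:
    # Iteration stops after full_num_lines items, even if the
    # underlying iterator could produce more.
    pass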


def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, valid_set, test_set):
train_filenames = (
"train.{}-{}.{}".format(src_language, tgt_language, src_language),
50 changes: 0 additions & 50 deletions torchtext/utils.py
@@ -1,9 +1,7 @@
import csv
import gzip
import hashlib
import logging
import os
import sys
import tarfile
import zipfile

@@ -38,14 +36,12 @@ def inner(b=1, bsize=1, tsize=None):

def validate_file(file_obj, hash_value, hash_type="sha256"):
"""Validate a given file object with its hash.

Args:
file_obj: File object to read from.
hash_value (str): Hash for url.
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
Returns:
bool: True if the file's hash matches hash_value, else False.

"""

if hash_type == "sha256":
@@ -76,22 +72,19 @@ def _check_hash(path, hash_value, hash_type):
def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=None, hash_type="sha256"):
"""Download file, with logic (from tensor2tensor) for Google Drive. Returns
the path to the downloaded file.

Args:
url: the URL of the file to download.
path: path where file will be saved
root: download folder used to store the file in (.data)
overwrite: overwrite existing files (False)
hash_value (str, optional): hash for url (Default: ``None``).
hash_type (str, optional): hash type, among "sha256" and "md5" (Default: ``"sha256"``).

Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> torchtext.utils.download_from_url(url)
>>> '.data/validation.tar.gz'

"""
# figure out filename and root
if path is None:
@@ -130,54 +123,14 @@ def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=
return path
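
With hash checking enabled, a call looks like this sketch (the digest is a hypothetical placeholder):

path = download_from_url(
    "http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz",
    root=".data",
    hash_value="<sha256-of-archive>",  # hypothetical; supply the real digest
    hash_type="sha256",
)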


def unicode_csv_reader(unicode_csv_data, **kwargs):
r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
Borrowed and slightly modified from the Python docs:
https://docs.python.org/2/library/csv.html#csv-examples

Args:
unicode_csv_data: unicode csv data (see example below)

Examples:
>>> from torchtext.utils import unicode_csv_reader
>>> import io
>>> with io.open(data_path, encoding="utf8") as f:
>>> reader = unicode_csv_reader(f)

"""

# Fix field larger than field limit error
maxInt = sys.maxsize
while True:
# decrease the maxInt value by factor 10
# as long as the OverflowError occurs.
try:
csv.field_size_limit(maxInt)
break
except OverflowError:
maxInt = int(maxInt / 10)
csv.field_size_limit(maxInt)

for line in csv.reader(unicode_csv_data, **kwargs):
yield line


def utf_8_encoder(unicode_csv_data):
for line in unicode_csv_data:
yield line.encode("utf-8")
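
Both helpers are Python 2 relics: in Python 3 the standard csv module consumes unicode text streams directly, so call sites reduce to a plain csv.reader. A minimal sketch (file name hypothetical):

import csv
import io

# Python 3: csv.reader accepts text-mode (unicode) file objects directly.
with io.open("data.csv", encoding="utf8") as f:
    for row in csv.reader(f):
        print(row)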


def extract_archive(from_path, to_path=None, overwrite=False):
"""Extract archive.

Args:
from_path: the path of the archive.
to_path: the root path of the extracted files (directory of from_path)
overwrite: overwrite existing files (False)

Returns:
List of paths to extracted files even if not overwritten.

Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> from_path = './validation.tar.gz'
@@ -188,7 +141,6 @@ def extract_archive(from_path, to_path=None, overwrite=False):
>>> torchtext.utils.download_from_url(url, from_path)
>>> torchtext.utils.extract_archive(from_path, to_path)
>>> ['.data/val.de', '.data/val.en']

"""

if to_path is None:
@@ -256,12 +208,10 @@ def _log_class_usage(klass):

def get_asset_local_path(asset_path: str) -> str:
"""Get local path for assets. Download if path does not exost locally

Args:
asset_path: Local path to asset or remote URL
Returns:
str: local path of the asset after downloading or reading from cache

Examples:
>>> url = 'http://<HOST>/file.txt'
>>> torchtext.utils.get_asset_local_path(url)