Skip to content

Update download #922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 57 additions & 58 deletions test/experimental/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,61 +203,60 @@ def test_fast_text(self):
self.assertEqual(vectors_obj[word][:3], expected_fasttext_simple_en[word])
self.assertEqual(jit_vectors_obj[word][:3], expected_fasttext_simple_en[word])

# TODO: re-enable test once the GloVe dataset URL starts working
# def test_glove(self):
# # copy the asset file into the expected download location
# # note that this is just a zip file with the first 100 entries of the GloVe 840B dataset
# asset_name = 'glove.840B.300d.zip'
# asset_path = get_asset_path(asset_name)

# with tempfile.TemporaryDirectory() as dir_name:
# data_path = os.path.join(dir_name, asset_name)
# shutil.copy(asset_path, data_path)
# vectors_obj = GloVe(root=dir_name, validate_file=False)
# jit_vectors_obj = torch.jit.script(vectors_obj)

# # The first 3 entries in each vector.
# expected_glove = {
# 'the': [0.27204, -0.06203, -0.1884],
# 'people': [-0.19686, 0.11579, -0.41091],
# }

# for word in expected_glove.keys():
# self.assertEqual(vectors_obj[word][:3], expected_glove[word])
# self.assertEqual(jit_vectors_obj[word][:3], expected_glove[word])

# def test_glove_different_dims(self):
# # copy the asset file into the expected download location
# # note that this is just a zip file with 1 line txt files used to test that the
# # correct files are being loaded
# asset_name = 'glove.6B.zip'
# asset_path = get_asset_path(asset_name)

# with tempfile.TemporaryDirectory() as dir_name:
# data_path = os.path.join(dir_name, asset_name)
# shutil.copy(asset_path, data_path)

# glove_50d = GloVe(name='6B', dim=50, root=dir_name, validate_file=False)
# glove_100d = GloVe(name='6B', dim=100, root=dir_name, validate_file=False)
# glove_200d = GloVe(name='6B', dim=200, root=dir_name, validate_file=False)
# glove_300d = GloVe(name='6B', dim=300, root=dir_name, validate_file=False)
# vectors_objects = [glove_50d, glove_100d, glove_200d, glove_300d]

# # The first 3 entries in each vector.
# expected_glove_50d = {
# 'the': [0.418, 0.24968, -0.41242],
# }
# expected_glove_100d = {
# 'the': [-0.038194, -0.24487, 0.72812],
# }
# expected_glove_200d = {
# 'the': [-0.071549, 0.093459, 0.023738],
# }
# expected_glove_300d = {
# 'the': [0.04656, 0.21318, -0.0074364],
# }
# expected_gloves = [expected_glove_50d, expected_glove_100d, expected_glove_200d, expected_glove_300d]

# for vectors_obj, expected_glove in zip(vectors_objects, expected_gloves):
# for word in expected_glove.keys():
# self.assertEqual(vectors_obj[word][:3], expected_glove[word])
def test_glove(self):
    """Compare eager and TorchScript GloVe lookups against known values.

    The bundled asset is a zip holding only the first 100 entries of the
    GloVe 840B dataset. It is copied into a temporary directory that is
    then passed to ``GloVe`` as ``root`` (the expected download location).
    """
    archive_name = 'glove.840B.300d.zip'
    archive_src = get_asset_path(archive_name)

    # Leading three components of each word's vector.
    reference = {
        'the': [0.27204, -0.06203, -0.1884],
        'people': [-0.19686, 0.11579, -0.41091],
    }

    with tempfile.TemporaryDirectory() as tmp_root:
        # Stage the archive where GloVe is told to look for it.
        shutil.copy(archive_src, os.path.join(tmp_root, archive_name))

        eager_vectors = GloVe(root=tmp_root, validate_file=False)
        scripted_vectors = torch.jit.script(eager_vectors)

        for token, head in reference.items():
            self.assertEqual(eager_vectors[token][:3], head)
            self.assertEqual(scripted_vectors[token][:3], head)

def test_glove_different_dims(self):
    """Check that each requested ``dim`` of GloVe '6B' loads its own file.

    The bundled ``glove.6B.zip`` asset contains one-line txt files, so a
    distinct leading triple for 'the' per dimension shows that the correct
    file was picked up.
    """
    archive_name = 'glove.6B.zip'
    archive_src = get_asset_path(archive_name)

    # dim -> {word: first three vector components}; kept in ascending order
    # so the vectors are constructed in the sequence 50, 100, 200, 300.
    reference_by_dim = [
        (50, {'the': [0.418, 0.24968, -0.41242]}),
        (100, {'the': [-0.038194, -0.24487, 0.72812]}),
        (200, {'the': [-0.071549, 0.093459, 0.023738]}),
        (300, {'the': [0.04656, 0.21318, -0.0074364]}),
    ]

    with tempfile.TemporaryDirectory() as tmp_root:
        # Stage the archive where GloVe is told to look for it.
        shutil.copy(archive_src, os.path.join(tmp_root, archive_name))

        # Build every vectors object first, then run the comparisons.
        loaded = [
            GloVe(name='6B', dim=dim, root=tmp_root, validate_file=False)
            for dim, _ in reference_by_dim
        ]

        for vectors, (_, reference) in zip(loaded, reference_by_dim):
            for token, head in reference.items():
                self.assertEqual(vectors[token][:3], head)
14 changes: 14 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
from torchtext import utils
from .common.torchtext_test_case import TorchtextTestCase
from test.common.assets import get_asset_path
import shutil


def conditional_remove(f):
Expand Down Expand Up @@ -104,6 +106,18 @@ def test_download_extract_zip(self):
os.rmdir(os.path.join(root, 'en-ud-v2'))
conditional_remove(archive_path)

def test_no_download(self):
    """Check that a file already present under ``.data`` is reused.

    A bundled asset is staged at ``.data/<asset_name>`` and
    ``download_from_url`` is then called with a URL whose last path
    component is that same file name. The test expects the call to return
    the staged path.
    """
    asset_name = 'glove.840B.300d.zip'
    asset_path = get_asset_path(asset_name)
    root = '.data'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(root, exist_ok=True)
    data_path = os.path.join(root, asset_name)
    shutil.copy(asset_path, data_path)
    try:
        file_path = utils.download_from_url('fakedownload/{}'.format(asset_name))
        self.assertEqual(file_path, data_path)
    finally:
        # Always remove the staged copy, even when the call above raises or
        # the assertion fails, so a stale file cannot leak into other tests.
        conditional_remove(data_path)

def test_download_extract_to_path(self):
# create root directory for downloading data
root = '.data'
Expand Down
9 changes: 9 additions & 0 deletions torchtext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ def _process_response(r, root, filename):
print("Can't create the download directory {}.".format(root))
raise

if filename is not None:
path = os.path.join(root, filename)
# skip requests.get if path exists and not overwrite.
if os.path.exists(path):
logging.info('File %s already exists.' % path)
if not overwrite:
_check_hash(path)
return path

if 'drive.google.com' not in url:
response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
return _process_response(response, root, filename)
Expand Down