diff --git a/torchtext/utils.py b/torchtext/utils.py index 21f90c414b..83ea6dad10 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -1,6 +1,7 @@ import six import requests import csv +from tqdm import tqdm def reporthook(t): @@ -25,11 +26,22 @@ def inner(b=1, bsize=1, tsize=None): def download_from_url(url, path): """Download file, with logic (from tensor2tensor) for Google Drive""" - if 'drive.google.com' not in url: - r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) + def process_response(r): + chunk_size = 16 * 1024 + total_size = int(r.headers.get('Content-length', 0)) with open(path, "wb") as file: - file.write(r.content) + with tqdm(total=total_size, unit='B', + unit_scale=1, desc=path.split('/')[-1]) as t: + for chunk in r.iter_content(chunk_size): + if chunk: + file.write(chunk) + t.update(len(chunk)) + + if 'drive.google.com' not in url: + response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) + process_response(response) return + print('downloading from Google Drive; may take a few minutes') confirm_token = None session = requests.Session() @@ -42,11 +54,7 @@ def download_from_url(url, path): url = url + "&confirm=" + confirm_token response = session.get(url, stream=True) - chunk_size = 16 * 1024 - with open(path, "wb") as f: - for chunk in response.iter_content(chunk_size): - if chunk: - f.write(chunk) + process_response(response) def unicode_csv_reader(unicode_csv_data, **kwargs):