Make download_url() follow redirects (#3235) (#3236)

* Make download_url() follow redirects

Fix a bug in the handling of redirects: follow the redirect chain
until the destination is reached or the number of redirects exceeds
the maximum allowed value (10 by default).
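
For context, "following the redirect chain" means repeatedly requesting the URL until the server stops redirecting. A minimal illustrative sketch with the requests library (not part of the patch; response.history and response.url are standard requests attributes):

    import requests

    # Illustrative only: let requests follow redirects itself. `history`
    # holds the intermediate 3xx responses; `url` is the final destination.
    response = requests.get(
        "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
    )
    for hop in response.history:
        print(hop.status_code, hop.url)
    print("final:", response.url)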

* Parametrize the maximum allowed number of redirects

Make the maximum number of hops a function argument with a default value of 10.

* Propagate the max number of hops to download_url()

Add the maximum number of redirect hops parameter to download_url()
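
As a usage sketch (a hypothetical call site; download_url lives in torchvision.datasets.utils, and the URL is the one used in the tests below):

    from torchvision.datasets.utils import download_url

    # Resolve up to 3 redirect hops (the new default) before downloading.
    download_url(
        "http://github.com/pytorch/vision/archive/master.zip",
        root="./data",
        max_redirect_hops=3,
    )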

* Check file existence before redirect

* Remove print

* Remove recursion

* Add tests

* Reduce max_redirect_hops

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
4 people authored Jan 15, 2021
1 parent 3b19d6f commit 0985533
Showing 2 changed files with 52 additions and 18 deletions.
12 changes: 12 additions & 0 deletions test/test_datasets_utils.py
@@ -36,6 +36,18 @@ def test_check_integrity(self):
         self.assertTrue(utils.check_integrity(existing_fpath))
         self.assertFalse(utils.check_integrity(nonexisting_fpath))
 
+    def test_get_redirect_url(self):
+        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
+        expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
+
+        actual = utils._get_redirect_url(url)
+        assert actual == expected
+
+    def test_get_redirect_url_max_hops_exceeded(self):
+        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
+        with self.assertRaises(RecursionError):
+            utils._get_redirect_url(url, max_hops=0)
+
     def test_download_url(self):
         with get_tmp_dir() as temp_dir:
             url = "http://github.com/pytorch/vision/archive/master.zip"
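For a sense of the failure mode the second test pins down: with max_hops=0 even a single redirect exhausts the budget, so _get_redirect_url() raises instead of looping. A hypothetical interactive check (any URL that redirects at least once would do):

    from torchvision.datasets import utils

    try:
        # max_hops=0 allows exactly one request and no hops.
        utils._get_redirect_url(
            "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
            max_hops=0,
        )
    except RecursionError as e:
        print(e)  # "Too many redirects: 1"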
58 changes: 40 additions & 18 deletions torchvision/datasets/utils.py
@@ -42,14 +42,31 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
     return check_md5(fpath, md5)
 
 
-def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
+def _get_redirect_url(url: str, max_hops: int = 10) -> str:
+    import requests
+
+    for hop in range(max_hops + 1):
+        response = requests.get(url)
+
+        if response.url == url or response.url is None:
+            return url
+
+        url = response.url
+    else:
+        raise RecursionError(f"Too many redirects: {max_hops + 1}")
+
+
+def download_url(
+    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
+) -> None:
     """Download a file from a url and place it in root.
     Args:
         url (str): URL to download file from
         root (str): Directory to place downloaded file in
         filename (str, optional): Name to save the file under. If None, use the basename of the URL
         md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
     """
     import urllib
 
@@ -63,27 +80,32 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio
     # check if file is already present locally
     if check_integrity(fpath, md5):
         print('Using downloaded and verified file: ' + fpath)
-    else:  # download the file
-        try:
-            print('Downloading ' + url + ' to ' + fpath)
+        return
+
+    # expand redirect chain if needed
+    url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+    # download the file
+    try:
+        print('Downloading ' + url + ' to ' + fpath)
         urllib.request.urlretrieve(
             url, fpath,
             reporthook=gen_bar_updater()
         )
-        except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
-            if url[:5] == 'https':
-                url = url.replace('https:', 'http:')
-                print('Failed download. Trying https -> http instead.'
-                      ' Downloading ' + url + ' to ' + fpath)
-                urllib.request.urlretrieve(
-                    url, fpath,
-                    reporthook=gen_bar_updater()
-                )
+    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
+        if url[:5] == 'https':
+            url = url.replace('https:', 'http:')
+            print('Failed download. Trying https -> http instead.'
+                  ' Downloading ' + url + ' to ' + fpath)
+            urllib.request.urlretrieve(
+                url, fpath,
+                reporthook=gen_bar_updater()
+            )
-            else:
-                raise e
-        # check integrity of downloaded file
-        if not check_integrity(fpath, md5):
-            raise RuntimeError("File not found or corrupted.")
+        else:
+            raise e
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
 
 
 def list_dir(root: str, prefix: bool = False) -> List[str]:
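One non-obvious detail in _get_redirect_url() above is the else clause on the for loop: it runs only when the loop exhausts all max_hops + 1 requests without returning, i.e. when every response still points somewhere new. A toy sketch of the same control flow, with a dict standing in for the network (purely illustrative):

    def resolve(url: str, redirects: dict, max_hops: int = 10) -> str:
        # `redirects` maps a URL to its redirect target; a missing key means
        # the URL does not redirect, mirroring response.url == url above.
        for _ in range(max_hops + 1):
            target = redirects.get(url, url)
            if target == url:
                return url  # destination reached
            url = target
        else:  # loop finished without returning: every hop redirected again
            raise RecursionError(f"Too many redirects: {max_hops + 1}")

    assert resolve("a", {"a": "b"}) == "b"  # one hop, well within the budget
    # resolve("a", {"a": "b", "b": "a"}, max_hops=1) would raise RecursionError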
