From 3f3005824a78979da3ef87d8e5d11d7ed7f334aa Mon Sep 17 00:00:00 2001 From: Moto Hira Date: Thu, 21 Jan 2021 08:10:20 -0800 Subject: [PATCH] Make download_url() follow redirects (#3235) (#3236) Summary: * Make download_url() follow redirects Fix bug related to the incorrect processing of redirects. Follow the redirect chain until the destination is reached or the number of redirects exceeds the max allowed value (by default 10). * Parametrize value of max allowed redirect number Make max number of hops a function argument and assign its default value to 10 * Propagate the max number of hops to download_url() Add the maximum number of redirect hops parameter to download_url() * check file existence before redirect * remove print * remove recursion * add tests * Reducing max_redirect_hops Reviewed By: datumbox Differential Revision: D25954556 fbshipit-source-id: 3b2c64592d5882b98e87acdb5efd95e9283d2862 Co-authored-by: Vasilis Vryniotis Co-authored-by: Philip Meier Co-authored-by: Vasilis Vryniotis --- test/test_datasets_utils.py | 12 ++++++++ torchvision/datasets/utils.py | 58 ++++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 2c6599ce497..cb4a0e8ba8f 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -36,6 +36,18 @@ def test_check_integrity(self): self.assertTrue(utils.check_integrity(existing_fpath)) self.assertFalse(utils.check_integrity(nonexisting_fpath)) + def test_get_redirect_url(self): + url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" + expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" + + actual = utils._get_redirect_url(url) + assert actual == expected + + def test_get_redirect_url_max_hops_exceeded(self): + url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" + with self.assertRaises(RecursionError): + 
utils._get_redirect_url(url, max_hops=0) + def test_download_url(self): with get_tmp_dir() as temp_dir: url = "http://github.com/pytorch/vision/archive/master.zip" diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 6a32cc07fd5..9490f8e972b 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -42,7 +42,23 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: return check_md5(fpath, md5) -def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None: +def _get_redirect_url(url: str, max_hops: int = 10) -> str: + import requests + + for hop in range(max_hops + 1): + response = requests.get(url) + + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError(f"Too many redirects: {max_hops + 1}") + + +def download_url( + url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 +) -> None: """Download a file from a url and place it in root. Args: @@ -50,6 +66,7 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio root (str): Directory to place downloaded file in filename (str, optional): Name to save the file under. If None, use the basename of the URL md5 (str, optional): MD5 checksum of the download. 
If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed """ import urllib @@ -63,27 +80,32 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio # check if file is already present locally if check_integrity(fpath, md5): print('Using downloaded and verified file: ' + fpath) - else: # download the file - try: - print('Downloading ' + url + ' to ' + fpath) + return + + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # download the file + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + ' Downloading ' + url + ' to ' + fpath) urllib.request.urlretrieve( url, fpath, reporthook=gen_bar_updater() ) - except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] - if url[:5] == 'https': - url = url.replace('https:', 'http:') - print('Failed download. Trying https -> http instead.' - ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve( - url, fpath, - reporthook=gen_bar_updater() - ) - else: - raise e - # check integrity of downloaded file - if not check_integrity(fpath, md5): - raise RuntimeError("File not found or corrupted.") + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") def list_dir(root: str, prefix: bool = False) -> List[str]: