diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 2c6599ce497..cb4a0e8ba8f 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -36,6 +36,18 @@ def test_check_integrity(self): self.assertTrue(utils.check_integrity(existing_fpath)) self.assertFalse(utils.check_integrity(nonexisting_fpath)) + def test_get_redirect_url(self): + url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" + expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" + + actual = utils._get_redirect_url(url) + assert actual == expected + + def test_get_redirect_url_max_hops_exceeded(self): + url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" + with self.assertRaises(RecursionError): + utils._get_redirect_url(url, max_hops=0) + def test_download_url(self): with get_tmp_dir() as temp_dir: url = "http://github.com/pytorch/vision/archive/master.zip" diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 6a32cc07fd5..9490f8e972b 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -42,7 +42,23 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: return check_md5(fpath, md5) -def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None: +def _get_redirect_url(url: str, max_hops: int = 10) -> str: + import requests + + for hop in range(max_hops + 1): + response = requests.get(url) + + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError(f"Too many redirects: {max_hops + 1}") + + +def download_url( + url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 +) -> None: """Download a file from a url and place it in root. 
Args: @@ -50,6 +66,7 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio root (str): Directory to place downloaded file in filename (str, optional): Name to save the file under. If None, use the basename of the URL md5 (str, optional): MD5 checksum of the download. If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed """ import urllib @@ -63,27 +80,32 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio # check if file is already present locally if check_integrity(fpath, md5): print('Using downloaded and verified file: ' + fpath) - else: # download the file - try: - print('Downloading ' + url + ' to ' + fpath) + return + + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # download the file + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + ' Downloading ' + url + ' to ' + fpath) urllib.request.urlretrieve( url, fpath, reporthook=gen_bar_updater() ) - except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] - if url[:5] == 'https': - url = url.replace('https:', 'http:') - print('Failed download. Trying https -> http instead.' - ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve( - url, fpath, - reporthook=gen_bar_updater() - ) - else: - raise e - # check integrity of downloaded file - if not check_integrity(fpath, md5): - raise RuntimeError("File not found or corrupted.") + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") def list_dir(root: str, prefix: bool = False) -> List[str]: