diff --git a/reco_utils/dataset/download_utils.py b/reco_utils/dataset/download_utils.py index 21cd0dd909..3f0f08ffef 100644 --- a/reco_utils/dataset/download_utils.py +++ b/reco_utils/dataset/download_utils.py @@ -2,41 +2,24 @@ # Licensed under the MIT License. import os -from urllib.request import urlretrieve import logging +import requests +import math from contextlib import contextmanager from tempfile import TemporaryDirectory from tqdm import tqdm - log = logging.getLogger(__name__) -class TqdmUpTo(tqdm): - """Wrapper class for the progress bar tqdm to get `update_to(n)` functionality""" - - def update_to(self, b=1, bsize=1, tsize=None): - """A progress bar showing how much is left to finish the operation - - Args: - b (int): Number of blocks transferred so far. - bsize (int): Size of each block (in tqdm units). - tsize (int): Total size (in tqdm units). - """ - if tsize is not None: - self.total = tsize - self.update(b * bsize - self.n) # will also set self.n = b * bsize - - def maybe_download(url, filename=None, work_directory=".", expected_bytes=None): """Download a file if it is not already downloaded. - + Args: filename (str): File name. work_directory (str): Working directory. url (str): URL of the file to download. expected_bytes (int): Expected file size in bytes. - Returns: str: File path of the file downloaded. """ @@ -44,8 +27,20 @@ def maybe_download(url, filename=None, work_directory=".", expected_bytes=None): filename = url.split("/")[-1] filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): - with TqdmUpTo(unit="B", unit_scale=True) as t: - filepath, _ = urlretrieve(url, filepath, reporthook=t.update_to) + + r = requests.get(url, stream=True) + total_size = int(r.headers.get("content-length", 0)) + block_size = 1024 + num_iterables = math.ceil(total_size / block_size) + + with open(filepath, "wb") as file: + for data in tqdm( + r.iter_content(block_size), + total=num_iterables, + unit="KB", + unit_scale=True, + ): + file.write(data) else: log.debug("File {} already downloaded".format(filepath)) if expected_bytes is not None: @@ -82,5 +77,3 @@ def download_path(path=None): else: path = os.path.realpath(path) yield path - - \ No newline at end of file