Skip to content

Commit

Permalink
Merge pull request #813 from microsoft/abhiram-requests-fix
Browse files Browse the repository at this point in the history
Change url download util to use requests
  • Loading branch information
gramhagen authored Jun 6, 2019
2 parents beab2f7 + 1d8eb08 commit a229c94
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions reco_utils/dataset/download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,45 @@
# Licensed under the MIT License.

import os
from urllib.request import urlretrieve
import logging
import requests
import math
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from tqdm import tqdm


log = logging.getLogger(__name__)


class TqdmUpTo(tqdm):
"""Wrapper class for the progress bar tqdm to get `update_to(n)` functionality"""

def update_to(self, b=1, bsize=1, tsize=None):
"""A progress bar showing how much is left to finish the operation
Args:
b (int): Number of blocks transferred so far.
bsize (int): Size of each block (in tqdm units).
tsize (int): Total size (in tqdm units).
"""
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n) # will also set self.n = b * bsize


def maybe_download(url, filename=None, work_directory=".", expected_bytes=None):
"""Download a file if it is not already downloaded.
Args:
filename (str): File name.
work_directory (str): Working directory.
url (str): URL of the file to download.
expected_bytes (int): Expected file size in bytes.
Returns:
str: File path of the file downloaded.
"""
if filename is None:
filename = url.split("/")[-1]
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
with TqdmUpTo(unit="B", unit_scale=True) as t:
filepath, _ = urlretrieve(url, filepath, reporthook=t.update_to)

r = requests.get(url, stream=True)
total_size = int(r.headers.get("content-length", 0))
block_size = 1024
num_iterables = math.ceil(total_size / block_size)

with open(filepath, "wb") as file:
for data in tqdm(
r.iter_content(block_size),
total=num_iterables,
unit="KB",
unit_scale=True,
):
file.write(data)
else:
log.debug("File {} already downloaded".format(filepath))
if expected_bytes is not None:
Expand Down Expand Up @@ -82,5 +77,3 @@ def download_path(path=None):
else:
path = os.path.realpath(path)
yield path


0 comments on commit a229c94

Please sign in to comment.