Skip to content

Commit

Permalink
[CONTRIB] TVM download utility based on urllib2/urlib.request (apache…
Browse files Browse the repository at this point in the history
…#1313)

moving nnvm/testing/download.py to python/tvm/contrib/download.py to be used as a general TVM download utility
  • Loading branch information
tmoreau89 authored and sergei-mironov committed Aug 8, 2018
1 parent 0c69be9 commit b0576ab
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,34 +1,17 @@
# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except
"""Helper utility for downloading"""
from __future__ import print_function
from __future__ import absolute_import as _abs

import os
import sys
import time
import urllib
import requests

if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2

def _download_progress(count, block_size, total_size):
"""Show the download progress.
"""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()

def download(url, path, overwrite=False, size_compare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Expand Down Expand Up @@ -62,8 +45,29 @@ def download(url, path, overwrite=False, size_compare=False):
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=_download_progress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=_download_progress)

# Stateful start time
start_time = time.time()

def _download_progress(count, block_size, total_size):
#pylint: disable=unused-argument
"""Show the download progress.
"""
if count == 0:
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = min(int(count * block_size * 100 / total_size), 100)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()

if sys.version_info >= (3,):
urllib2.urlretrieve(url, path, reporthook=_download_progress)
print("")
else:
f = urllib2.urlopen(url)
data = f.read()
with open(path, "wb") as code:
code.write(data)
2 changes: 1 addition & 1 deletion tutorials/nnvm/deploy_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.download import download
from tvm.contrib.download import download
from tvm.contrib import graph_runtime
from mxnet.model import load_checkpoint

Expand Down
2 changes: 1 addition & 1 deletion tutorials/nnvm/from_darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import os

from ctypes import *
from nnvm.testing.download import download
from tvm.contrib.download import download
from nnvm.testing.darknet import __darknetffi__

######################################################################
Expand Down

0 comments on commit b0576ab

Please sign in to comment.