Skip to content

Commit

Permalink
Review comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Jul 19, 2018
1 parent 6e5ba73 commit 56b1418
Showing 1 changed file with 4 additions and 60 deletions.
64 changes: 4 additions & 60 deletions tutorials/nnvm/nlp/from_darknet_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@
by the script.
"""
import random
import os
import sys
import time
import urllib
import requests
import numpy as np
import urllib.request as urllib2
from mxnet.gluon.utils import download
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.darknet import __darknetffi__
Expand All @@ -43,51 +38,13 @@
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true'

def _dl_progress(count, block_size, total_size):
"""Show the download progress."""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()

def _download(url, path, overwrite=False, sizecompare=False):
"""Downloads the file from the internet.
"""
if os.path.isfile(path) and not overwrite:
if sizecompare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
url_file_size = int(res_get.headers['Content-Length'])
if url_file_size != file_size:
print("exist file got corrupted, downloading", path, " file freshly")
_download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=_dl_progress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=_dl_progress)

_download(CFG_URL, CFG_NAME)
_download(WEIGHTS_URL, WEIGHTS_NAME)
download(CFG_URL, CFG_NAME)
download(WEIGHTS_URL, WEIGHTS_NAME)

# Download and Load darknet library
DARKNET_LIB = 'libdarknet.so'
DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
_download(DARKNET_URL, DARKNET_LIB)
download(DARKNET_URL, DARKNET_LIB)
DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
cfg = "./" + str(CFG_NAME)
weights = "./" + str(WEIGHTS_NAME)
Expand All @@ -110,19 +67,6 @@ def _download(url, path, overwrite=False, sizecompare=False):
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params)

# Save the json
def _save_lib():
'''Save the graph, params and .so to the current directory'''
print("Saving the compiled output...")
path_name = 'nnvm_darknet_' + MODEL_NAME
path_lib = path_name + '_deploy_lib.so'
lib.export_library(path_lib)
with open(path_name + "deploy_graph.json", "w") as fo:
fo.write(graph.json())
with open(path_name + "deploy_param.params", "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
#_save_lib()

# Execute on TVM
ctx = tvm.cpu(0)

Expand Down

0 comments on commit 56b1418

Please sign in to comment.