From 56b1418885470bb4f891f67d1dd8c60695985602 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 19 Jul 2018 08:54:50 +0530 Subject: [PATCH] Review comments addressed --- tutorials/nnvm/nlp/from_darknet_rnn.py | 64 ++------------------------ 1 file changed, 4 insertions(+), 60 deletions(-) diff --git a/tutorials/nnvm/nlp/from_darknet_rnn.py b/tutorials/nnvm/nlp/from_darknet_rnn.py index 1d9ecb73ea265..118295f23186a 100644 --- a/tutorials/nnvm/nlp/from_darknet_rnn.py +++ b/tutorials/nnvm/nlp/from_darknet_rnn.py @@ -19,13 +19,8 @@ by the script. """ import random -import os -import sys -import time -import urllib -import requests import numpy as np -import urllib.request as urllib2 +from mxnet.gluon.utils import download import tvm from tvm.contrib import graph_runtime from nnvm.testing.darknet import __darknetffi__ @@ -43,51 +38,13 @@ CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true' WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true' -def _dl_progress(count, block_size, total_size): - """Show the download progress.""" - global start_time - if count == 0: - start_time = time.time() - return - duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) - percent = int(count * block_size * 100 / total_size) - sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % - (percent, progress_size / (1024 * 1024), speed, duration)) - sys.stdout.flush() - -def _download(url, path, overwrite=False, sizecompare=False): - """Downloads the file from the internet. - """ - if os.path.isfile(path) and not overwrite: - if sizecompare: - file_size = os.path.getsize(path) - res_head = requests.head(url) - res_get = requests.get(url, stream=True) - if 'Content-Length' not in res_head.headers: - res_get = urllib2.urlopen(url) - url_file_size = int(res_get.headers['Content-Length']) - if url_file_size != file_size: - print("exist file got corrupted, downloading", path, " file freshly") - _download(url, path, True, False) - return - print('File {} exists, skip.'.format(path)) - return - print('Downloading from url {} to {}'.format(url, path)) - try: - urllib.request.urlretrieve(url, path, reporthook=_dl_progress) - print('') - except: - urllib.urlretrieve(url, path, reporthook=_dl_progress) - -_download(CFG_URL, CFG_NAME) -_download(WEIGHTS_URL, WEIGHTS_NAME) +download(CFG_URL, CFG_NAME) +download(WEIGHTS_URL, WEIGHTS_NAME) # Download and Load darknet library DARKNET_LIB = 'libdarknet.so' DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' -_download(DARKNET_URL, DARKNET_LIB) +download(DARKNET_URL, DARKNET_LIB) DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB) cfg = "./" + str(CFG_NAME) weights = "./" + str(WEIGHTS_NAME) @@ -110,19 +67,6 @@ def _download(url, path, overwrite=False, sizecompare=False): with nnvm.compiler.build_config(opt_level=2): graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params) -# Save the json -def _save_lib(): - '''Save the graph, params and .so to the current directory''' - print("Saving the compiled output...") - path_name = 'nnvm_darknet_' + MODEL_NAME - path_lib = path_name + '_deploy_lib.so' - lib.export_library(path_lib) - with open(path_name + "deploy_graph.json", "w") as fo: - fo.write(graph.json()) - with open(path_name + "deploy_param.params", "wb") as fo: - fo.write(nnvm.compiler.save_param_dict(params)) -#_save_lib() - # Execute on TVM ctx = tvm.cpu(0)