From 0d0334f50bbce18f61d482c3e9192bb5c98f6144 Mon Sep 17 00:00:00 2001
From: Thierry Moreau
Date: Mon, 10 Jun 2019 14:34:12 -0700
Subject: [PATCH] updated tutorial to use Relay

---
 vta/tutorials/resnet.py | 324 ++++++++++++++--------------------------
 1 file changed, 116 insertions(+), 208 deletions(-)

diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py
index 13161586480ea..d3ed0cebe79d8 100644
--- a/vta/tutorials/resnet.py
+++ b/vta/tutorials/resnet.py
@@ -24,292 +24,200 @@
 """
+
 ######################################################################
 # Import Libraries
 # ----------------
-# We start by importing the tvm, vta, nnvm libraries to run this example.
+# We start by importing the libraries needed to run this example.

 from __future__ import absolute_import, print_function

-import os
-import time
+import argparse, json, os, requests, time
 from io import BytesIO
+from os.path import join, isfile
+from PIL import Image
+from mxnet.gluon.model_zoo import vision
 import numpy as np
-import requests
 from matplotlib import pyplot as plt
-from PIL import Image

 import tvm
-from tvm import rpc, autotvm
-from tvm.contrib import graph_runtime, util
-from tvm.contrib.download import download
-import nnvm.compiler
-import vta
-import vta.testing
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, util, download
+from tvm.contrib.debugger import debug_runtime

-# Load VTA parameters from the vta/config/vta_config.json file
-env = vta.get_env()
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack

-# Helper to crop an image to a square (224, 224)
-# Takes in an Image object, returns an Image object
-def thumbnailify(image, pad=15):
-    w, h = image.size
-    crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
-    image = image.crop(crop)
-    image = image.resize((224, 224))
-    return image
-
-# Helper function to read in image
-# Takes in Image object, returns an ND array
-def process_image(image):
-    # Convert to neural network input format
-    image = np.array(image) - np.array([123., 117., 104.])
-    image /= np.array([58.395, 57.12, 57.375])
-    image = image.transpose((2, 0, 1))
-    image = image[np.newaxis, :]
-
-    return tvm.nd.array(image.astype("float32"))
-
-# Classification helper function
-# Takes in the graph runtime, and an image, and returns top result and time
-def classify(m, image):
-    m.set_input('data', image)
-    timer = m.module.time_evaluator("run", ctx, number=1)
-    tcost = timer()
-    tvm_output = m.get_output(0)
-    top = np.argmax(tvm_output.asnumpy()[0])
-    tcost = "t={0:.2f}s".format(tcost.mean)
-    return tcost + " {}".format(synset[top])
+# Make sure that TVM was compiled with RPC=1
+assert tvm.module.enabled("rpc")
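+
+# A hedged aside, not part of the original tutorial flow:
+# ``tvm.module.enabled`` accepts other runtime feature keys besides "rpc",
+# so a short sketch like this reports which runtimes the tutorial relies on
+# were compiled in. It can be removed without affecting the rest.
+for feature in ["rpc", "llvm"]:
+    print("Runtime support for '%s': %s" % (feature, tvm.module.enabled(feature)))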

 ######################################################################
-# Download ResNet Model
-# --------------------------------------------
-# Download the necessary files to run ResNet-18.
-#
-
-# Obtain ResNet model and download them into _data dir
-url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
-categ_fn = 'synset.txt'
-graph_fn = 'resnet18_qt8.json'
-params_fn = 'resnet18_qt8.params'
+# Define the platform and model targets
+# -------------------------------------
+# Execute on CPU vs. VTA, and define the model.

-# Create data dir
-data_dir = "_data/"
-if not os.path.exists(data_dir):
-    os.makedirs(data_dir)
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()

-# Download files
-for file in [categ_fn, graph_fn, params_fn]:
-    download(os.path.join(url, file), os.path.join(data_dir, file))
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu

-# Read in ImageNet Categories
-synset = eval(open(os.path.join(data_dir, categ_fn)).read())
+# Name of the Gluon model to compile. ``start_pack`` and ``stop_pack``
+# indicate where the graph-packing Relay pass starts and ends, in other
+# words which section of the graph gets offloaded to VTA.
+model = "resnet18_v1"
+start_pack = "nn.max_pool2d"
+stop_pack = "nn.global_avg_pool2d"

 ######################################################################
-# Setup the Pynq Board's RPC Server
+# Obtain an execution remote
 # ---------------------------------
-# Build the RPC server's VTA runtime and program the Pynq FPGA.
-
-# Measure build start time
-reconfig_start = time.time()
-
-# We read the Pynq RPC host IP address and port number from the OS environment
-host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
-port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
-
-# We configure both the bitstream and the runtime system on the Pynq
-# to match the VTA configuration specified by the vta_config.json file.
-if env.TARGET == "pynq":
-    # Make sure that TVM was compiled with RPC=1
-    assert tvm.module.enabled("rpc")
-    remote = rpc.connect(host, port)
-
-    # Reconfigure the JIT runtime
-    vta.reconfig_runtime(remote)
-
-    # Program the FPGA with a pre-compiled VTA bitstream.
+# When the target is 'pynq', reconfigure the FPGA and runtime;
+# otherwise, in 'sim' mode, execute locally.

+if env.TARGET != "sim":
+
+    # Get the remote from a tracker node if the environment variables
+    # TVM_TRACKER_HOST and TVM_TRACKER_PORT are set; otherwise connect
+    # to the device directly.
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+    device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
+    if not tracker_host or not tracker_port:
+        remote = rpc.connect(device_host, device_port)
+    else:
+        remote = autotvm.measure.request_remote(env.TARGET,
+                                                tracker_host,
+                                                int(tracker_port),
+                                                timeout=10000)
+
+    # Reconfigure the JIT runtime and FPGA.
     # You can program the FPGA with your own custom bitstream
     # by passing the path to the bitstream file instead of None.
+    reconfig_start = time.time()
+    vta.reconfig_runtime(remote)
     vta.program_fpga(remote, bitstream=None)
-
-    # Report on reconfiguration time
     reconfig_time = time.time() - reconfig_start
     print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

 # In simulation mode, host the RPC server locally.
-elif env.TARGET == "sim":
+else:
     remote = rpc.LocalSession()

+# Get execution context from remote
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
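+
+# A minimal round-trip check (an illustrative sketch, safe to remove):
+# copy a small array to the remote CPU context and read it back, to
+# confirm the RPC session is usable before the lengthy build step below.
+check = tvm.nd.array(np.arange(4, dtype="float32"), remote.cpu(0))
+assert np.array_equal(check.asnumpy(), np.arange(4, dtype="float32"))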

 ######################################################################
-# Build the ResNet Runtime
-# ------------------------
-# Build the ResNet graph runtime, and configure the parameters.
+# Build the inference runtime
+# ---------------------------
+# Build ResNet from Gluon with Relay.

-# Set ``device=vtacpu`` to run inference on the CPU
-# or ``device=vta`` to run inference on the FPGA.
-device = "vta"
-
-# TVM target and context
-target = tvm.target.create("llvm -device={}".format(device))
-ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
-
-# TVM module
-m = None
+# Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
-    graph_fn = os.path.join(data_dir, graph_fn)
-    params_fn= os.path.join(data_dir, params_fn)
+
+    # Populate the shape and data type dictionary for the ResNet input
+    dtype_dict = {"data": 'float32'}
+    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
+
+    # Get the off-the-shelf Gluon model and convert it to Relay
+    gluon_model = vision.get_model(model, pretrained=True)

     # Measure build start time
     build_start = time.time()

-    # Load the ResNet-18 graph and parameters
-    sym = nnvm.graph.load_json(open(graph_fn).read())
-    params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())
+    # Start front-end compilation
+    relay_prog, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

-    # Populate the shape and data type dictionary
-    shape_dict = {"data": (1, 3, 224, 224)}
-    dtype_dict = {"data": 'float32'}
+    # Update the shape and type dictionary
     shape_dict.update({k: v.shape for k, v in params.items()})
     dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

-    # Apply NNVM graph optimization passes
-    sym = vta.graph.clean_cast(sym)
-    sym = vta.graph.clean_conv_fuse(sym)
+    # Perform quantization in Relay
+    with relay.quantize.qconfig(global_scale=8.0, skip_k_conv=1):
+        relay_prog = relay.quantize.quantize(relay_prog, params=params)
+
+    # Perform graph packing and constant folding for the VTA target
     if target.device_name == "vta":
         assert env.BLOCK_IN == env.BLOCK_OUT
-        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
-
-    # Compile NNVM graph
-    with nnvm.compiler.build_config(opt_level=3):
+        relay_prog = graph_pack(
+            relay_prog,
+            env.BATCH,
+            env.BLOCK_OUT,
+            env.WGT_WIDTH,
+            start_name=start_pack,
+            stop_name=stop_pack)
+        relay_prog = relay.ir_pass.fold_constant(relay_prog)
+
+    # Compile the Relay program with AlterOpLayout disabled
+    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
         if target.device_name != "vta":
-            graph, lib, params = nnvm.compiler.build(
-                sym, target, shape_dict, dtype_dict,
+            graph, lib, params = relay.build(
+                relay_prog, target=target,
                 params=params, target_host=env.target_host)
         else:
             with vta.build_config():
-                graph, lib, params = nnvm.compiler.build(
-                    sym, target, shape_dict, dtype_dict,
+                graph, lib, params = relay.build(
+                    relay_prog, target=target,
                     params=params, target_host=env.target_host)

-    # Save the compiled inference graph library
-    assert tvm.module.enabled("rpc")
-    temp = util.tempdir()
-    lib.save(temp.relpath("graphlib.o"))
+    # Measure Relay build time
+    build_time = time.time() - build_start
+    print(model + " inference graph built in {0:.2f}s!".format(build_time))

     # Send the inference library over to the remote RPC server
+    temp = util.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
     remote.upload(temp.relpath("graphlib.o"))
     lib = remote.load_module("graphlib.o")

-    # Measure build time
-    build_time = time.time() - build_start
-    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))
-
+    # Create the graph runtime
     m = graph_runtime.create(graph, lib, ctx)

-    # Set the parameters
-    m.set_input(**params)
-
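+
+# A hedged aside: the compiled artifacts can also be persisted so that
+# later runs skip the build step. The file names below are arbitrary
+# illustrative choices, not part of the original tutorial.
+with open(temp.relpath("graph.json"), "w") as fo:
+    fo.write(graph)
+with open(temp.relpath("params.params"), "wb") as fo:
+    fo.write(relay.save_param_dict(params))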

 ######################################################################
-# Run ResNet-18 inference on a sample image
-# -----------------------------------------
-# Perform image classification on test image.
-# You can change the test image URL to any image of your choosing.
+# Perform ResNet-18 inference
+# ---------------------------
+# We run classification on an image sample from ImageNet.

-# Read in test image
+# Download ImageNet categories
+categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+categ_fn = "synset.txt"
+download.download(join(categ_url, categ_fn), categ_fn)
+synset = eval(open(categ_fn).read())
+
+# Download test image
 image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
-# Read in test image
 response = requests.get(image_url)
+
+# Prepare test image for inference
 image = Image.open(BytesIO(response.content)).resize((224, 224))
-# Show Image
 plt.imshow(image)
 plt.show()
-# Set the input
-image = process_image(image)
+image = np.array(image) - np.array([123., 117., 104.])
+image /= np.array([58.395, 57.12, 57.375])
+image = image.transpose((2, 0, 1))
+image = image[np.newaxis, :]
+image = np.repeat(image, env.BATCH, axis=0)
+
+# Set the network parameters and inputs
+m.set_input(**params)
 m.set_input('data', image)

 # Perform inference
-timer = m.module.time_evaluator("run", ctx, number=1)
+timer = m.module.time_evaluator("run", ctx, number=4, repeat=3)
 tcost = timer()

 # Get classification results
-tvm_output = m.get_output(0)
+tvm_output = m.get_output(0,
+    tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
 top_categories = np.argsort(tvm_output.asnumpy()[0])

 # Report top-5 classification results
+# (per-sample latency statistics are derived from the repeated timer runs)
+std = np.std(tcost.results) * 1000 / env.BATCH
+mean = tcost.mean * 1000 / env.BATCH
+print("%s prediction" % model)
+print("     #1:", synset[top_categories[-1]])
-print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
 print("     #2:", synset[top_categories[-2]])
 print("     #3:", synset[top_categories[-3]])
 print("     #4:", synset[top_categories[-4]])
 print("     #5:", synset[top_categories[-5]])
-print("Performed inference in {0:.2f}s".format(tcost.mean))
-
-
-######################################################################
-# Run a Youtube Video Image Classifier
-# ------------------------------------
-# Perform image classification on test stream on 1 frame every 48 frames.
-# Comment the `if False:` out to run the demo - -# Early exit - remove for Demo -if False: - - import cv2 - import pafy - from IPython.display import clear_output - - # Helper to crop an image to a square (224, 224) - # Takes in an Image object, returns an Image object - def thumbnailify(image, pad=15): - w, h = image.size - crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad) - image = image.crop(crop) - image = image.resize((224, 224)) - return image - - # 16:16 inches - plt.rcParams['figure.figsize'] = [16, 16] - - # Stream the video in - url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s" - video = pafy.new(url) - best = video.getbest(preftype="mp4") - cap = cv2.VideoCapture(best.url) - - # Process one frame out of every 48 for variety - count = 0 - guess = "" - while(count<2400): - - # Capture frame-by-frame - ret, frame = cap.read() - - # Process one every 48 frames - if count % 48 == 1: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - # Crop and resize - thumb = np.array(thumbnailify(frame)) - image = process_image(thumb) - guess = classify(m, image) - - # Insert guess in frame - frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50) - cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA) - - plt.imshow(thumb) - plt.axis('off') - plt.show() - if cv2.waitKey(1) & 0xFF == ord('q'): - break - clear_output(wait=True) - - count += 1 - - # When everything done, release the capture - cap.release() - cv2.destroyAllWindows() +print("Performed inference in %.2fms/sample (std = %.2f)" % (mean, std))
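+
+# A derived figure under the same timing convention (``tcost.mean`` is the
+# average wall time of one inference run over a batch of ``env.BATCH``
+# samples); this throughput line is an illustrative addition.
+print("Throughput: %.1f samples/s" % (env.BATCH / tcost.mean))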