From dcc019892002a6467e2847efbfe60d1a1c4dbd22 Mon Sep 17 00:00:00 2001
From: Thierry Moreau
Date: Tue, 3 Jul 2018 09:47:31 -0700
Subject: [PATCH] [TUTORIAL] Resnet-18 end to end tutorial example (#55)

---
 vta/tutorials/resnet.py | 326 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 326 insertions(+)
 create mode 100644 vta/tutorials/resnet.py

diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py
new file mode 100644
index 000000000000..72ed5d7d2184
--- /dev/null
+++ b/vta/tutorials/resnet.py
@@ -0,0 +1,326 @@
+"""
+ResNet Inference Example
+========================
+**Author**: `Thierry Moreau `_
+
+This tutorial provides an end-to-end demo of how to run ResNet-18 inference
+on the VTA accelerator design to perform ImageNet classification tasks.
+"""
+
+
+######################################################################
+# Import Libraries
+# ----------------
+# We start by importing the tvm, vta, and nnvm libraries needed to run
+# this example.
+
+from __future__ import absolute_import, print_function
+
+import os
+import sys
+import nnvm
+import nnvm.compiler
+import tvm
+import vta
+import vta.testing
+import numpy as np
+import json
+import requests
+import time
+
+from nnvm.compiler import graph_attr
+from tvm.contrib import graph_runtime, rpc, util
+from tvm.contrib.download import download
+from vta.testing import simulator
+
+from io import BytesIO
+from matplotlib import pyplot as plt
+from PIL import Image
+
+# Load VTA parameters from the config.json file
+env = vta.get_env()
+
+# Helper to crop an image to a square (224, 224)
+# Takes in an Image object, returns an Image object
+def thumbnailify(image, pad=15):
+    w, h = image.size
+    crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
+    image = image.crop(crop)
+    image = image.resize((224, 224))
+    return image
+
+# Helper function to read in an image
+# Takes in an Image object, returns an ND array
+def process_image(image):
+    # Convert to the neural network input format:
+    # mean-subtract and scale each channel,
+    # then transpose from HWC to NCHW layout
+    image = np.array(image) - np.array([123., 117., 104.])
+    image /= np.array([58.395, 57.12, 57.375])
+    image = image.transpose((2, 0, 1))
+    image = image[np.newaxis, :]
+
+    return tvm.nd.array(image.astype("float32"))
+
+# Classification helper function
+# Takes in the graph runtime and an image, and returns the top
+# prediction along with the inference time
+def classify(m, image):
+    m.set_input('data', image)
+    timer = m.module.time_evaluator("run", ctx, number=1)
+    tcost = timer()
+    tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
+    top = np.argmax(tvm_output.asnumpy())
+    tcost = "t={0:.2f}s".format(tcost.mean)
+    return tcost + " {}".format(synset[top])
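+# As a quick, purely illustrative sanity check (an addition, not part of the
+# original tutorial flow), we can verify that the preprocessing helper
+# produces the (1, 3, 224, 224) NCHW tensor that ResNet-18 expects,
+# using a synthetic all-black 224x224 RGB image:
+_dummy = Image.fromarray(np.zeros((224, 224, 3), dtype="uint8"))
+assert process_image(_dummy).shape == (1, 3, 224, 224)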
+# Helper function to compile the NNVM graph
+# Takes in a path to a graph file, params file, and device target
+# Returns the NNVM graph object, a compiled library object, and the params dict
+def generate_graph(graph_fn, params_fn, device="vta"):
+
+    # Measure build start time
+    build_start = time.time()
+
+    # Derive the TVM target
+    target = tvm.target.create("llvm -device={}".format(device))
+
+    # Derive the LLVM compiler flags
+    # When targeting the Pynq, cross-compile to the ARMv7 ISA
+    if env.TARGET == "sim":
+        target_host = "llvm"
+    elif env.TARGET == "pynq":
+        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
+
+    # Load the ResNet-18 graph and parameters
+    sym = nnvm.graph.load_json(open(graph_fn).read())
+    params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())
+
+    # Populate the shape and data type dictionaries
+    shape_dict = {"data": (1, 3, 224, 224)}
+    dtype_dict = {"data": 'float32'}
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Create the NNVM graph and infer shapes and types
+    graph = nnvm.graph.create(sym)
+    graph_attr.set_shape_inputs(sym, shape_dict)
+    graph_attr.set_dtype_inputs(sym, dtype_dict)
+    graph = graph.apply("InferShape").apply("InferType")
+
+    # Apply NNVM graph optimization passes
+    sym = vta.graph.clean_cast(sym)
+    sym = vta.graph.clean_conv_fuse(sym)
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
+        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
+
+    # Compile the NNVM graph
+    with nnvm.compiler.build_config(opt_level=3):
+        if target.device_name != "vta":
+            graph, lib, params = nnvm.compiler.build(
+                sym, target_host, shape_dict, dtype_dict,
+                params=params)
+        else:
+            with vta.build_config():
+                graph, lib, params = nnvm.compiler.build(
+                    sym, target, shape_dict, dtype_dict,
+                    params=params, target_host=target_host)
+
+    # Save the compiled inference graph library
+    assert tvm.module.enabled("rpc")
+    temp = util.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
+
+    # Send the inference library over to the remote RPC server
+    remote.upload(temp.relpath("graphlib.o"))
+    lib = remote.load_module("graphlib.o")
+
+    # Measure build time
+    build_time = time.time() - build_start
+    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))
+
+    return graph, lib, params
+
+
+######################################################################
+# Download ResNet Model
+# ---------------------
+# Download the necessary files to run ResNet-18.
+#
+
+# Obtain the ResNet model files and download them into the _data directory
+url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+categ_fn = 'synset.txt'
+graph_fn = 'resnet18_qt8.json'
+params_fn = 'resnet18_qt8.params'
+
+# Create the data directory
+data_dir = "_data/"
+if not os.path.exists(data_dir):
+    os.makedirs(data_dir)
+
+# Download the files if they are not already present
+for file in [categ_fn, graph_fn, params_fn]:
+    if not os.path.isfile(os.path.join(data_dir, file)):
+        download(os.path.join(url, file), os.path.join(data_dir, file))
+
+# Read in the ImageNet categories
+synset = eval(open(os.path.join(data_dir, categ_fn)).read())
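+# As an added illustration (not part of the original tutorial), the synset is
+# assumed here to be a dictionary mapping each of the 1000 ImageNet class
+# indices to a human-readable category name, which we can quickly verify:
+assert len(synset) == 1000
+print("Sample category:", synset[0])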
+######################################################################
+# Setup the Pynq Board's RPC Server
+# ---------------------------------
+# Build the RPC server's VTA runtime and program the Pynq FPGA.
+
+# Measure reconfiguration start time
+reconfig_start = time.time()
+
+# We read the Pynq RPC host IP address and port number from the OS environment
+host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
+
+# We configure both the bitstream and the runtime system on the Pynq
+# to match the VTA configuration specified by the config.json file.
+if env.TARGET == "pynq":
+
+    # Make sure that TVM was compiled with RPC=1
+    assert tvm.module.enabled("rpc")
+    remote = rpc.connect(host, port)
+
+    # Reconfigure the JIT runtime
+    vta.reconfig_runtime(remote)
+
+    # Program the FPGA with a pre-compiled VTA bitstream.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    vta.program_fpga(remote, bitstream=None)
+
+    # Report on reconfiguration time
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+elif env.TARGET == "sim":
+    remote = rpc.LocalSession()
+
+
+######################################################################
+# Build the ResNet Runtime
+# ------------------------
+# Build the ResNet graph runtime, and configure the parameters.
+
+# Set ``device=vta`` to run inference on the FPGA,
+# or ``device=vtacpu`` to run inference on the CPU.
+device = "vta"
+
+# Device context
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+# Build the graph runtime
+graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
+                                    os.path.join(data_dir, params_fn),
+                                    device)
+m = graph_runtime.create(graph, lib, ctx)
+
+# Set the parameters
+m.set_input(**params)
+
+
+######################################################################
+# Run ResNet-18 inference on a sample image
+# -----------------------------------------
+# Perform image classification on a test image.
+# You can change the test image URL to any image of your choosing.
+
+# Read in the test image
+image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+response = requests.get(image_url)
+image = Image.open(BytesIO(response.content)).resize((224, 224))
+
+# Show the image
+plt.imshow(image)
+plt.show()
+
+# Set the input
+image = process_image(image)
+m.set_input('data', image)
+
+# Perform inference
+timer = m.module.time_evaluator("run", ctx, number=1)
+tcost = timer()
+
+# Get the classification results
+tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
+top_categories = np.argsort(tvm_output.asnumpy())
+
+# Report the top-5 classification results
+print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
+print("                     #2:", synset[top_categories[-2]])
+print("                     #3:", synset[top_categories[-3]])
+print("                     #4:", synset[top_categories[-4]])
+print("                     #5:", synset[top_categories[-5]])
+print("Performed inference in {0:.2f}s".format(tcost.mean))
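+# As a small illustrative extension (not part of the original tutorial), the
+# raw output scores can be turned into confidence estimates with a softmax.
+# This assumes the output vector holds unnormalized per-class scores:
+scores = tvm_output.asnumpy()
+probs = np.exp(scores - np.max(scores))
+probs /= probs.sum()
+print("Top-1 confidence: {0:.1f}%".format(100 * probs[top_categories[-1]]))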
+######################################################################
+# Run a Youtube Video Image Classifier
+# ------------------------------------
+# Perform image classification on a video stream,
+# classifying one frame out of every 48.
+# To run the demo, change the ``if False:`` guard below to ``if True:``.
+
+# Early exit - change to True to run the demo
+if False:
+
+    import cv2
+    import pafy
+    from IPython.display import clear_output
+
+    # Use a 16x16 inch figure
+    plt.rcParams['figure.figsize'] = [16, 16]
+
+    # Stream the video in
+    url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
+    video = pafy.new(url)
+    best = video.getbest(preftype="mp4")
+    cap = cv2.VideoCapture(best.url)
+
+    # Process one frame out of every 48 for variety
+    count = 0
+    guess = ""
+    while count < 2400:
+
+        # Capture the video frame-by-frame
+        ret, frame = cap.read()
+
+        # Process one frame every 48 frames
+        if count % 48 == 1:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            # Crop and resize with the thumbnailify helper defined above
+            thumb = np.array(thumbnailify(frame))
+            image = process_image(thumb)
+            guess = classify(m, image)
+
+            # Insert the guess into the frame: draw a black banner at the top
+            # of the thumbnail (cv2.rectangle modifies thumb in place and
+            # returns the same buffer), then overlay the prediction text
+            frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
+            cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
+                        0.5, (255, 255, 255), 1, cv2.LINE_AA)
+
+            plt.imshow(thumb)
+            plt.axis('off')
+            plt.show()
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+            clear_output(wait=True)
+
+        count += 1
+
+    # When everything is done, release the capture
+    cap.release()
+    cv2.destroyAllWindows()
\ No newline at end of file