""" | ||
ResNet Inference Example | ||
======================== | ||
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_ | ||
This tutorial provides an end-to-end demo, on how to run ResNet-18 inference | ||
onto the VTA accelerator design to perform ImageNet classification tasks. | ||
""" | ||
|
||
|
||
######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, and nnvm libraries needed to run
# this example.

from __future__ import absolute_import, print_function

import os
import sys
import nnvm
import nnvm.compiler
import tvm
import vta
import vta.testing
import numpy as np
import json
import requests
import time

from nnvm.compiler import graph_attr
from tvm.contrib import graph_runtime, rpc, util
from tvm.contrib.download import download
from vta.testing import simulator

from io import BytesIO
from matplotlib import pyplot as plt
from PIL import Image

# Load VTA parameters from the config.json file
env = vta.get_env()
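
# As a quick sanity check (an optional sketch), print a few of the VTA
# parameters that this script relies on further below:
print("VTA target: {}".format(env.TARGET))
print("Tensor intrinsic: BATCH={}, BLOCK_IN={}, BLOCK_OUT={}".format(
    env.BATCH, env.BLOCK_IN, env.BLOCK_OUT))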

# Helper to center-crop an image to a square and resize it to (224, 224)
# Takes in an Image object, returns an Image object; assumes a
# landscape input (width >= height)
def thumbnailify(image, pad=15):
    w, h = image.size
    crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
    image = image.crop(crop)
    image = image.resize((224, 224))
    return image
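
# For example (a hedged sketch), a landscape 640x480 frame center-crops
# down to the 224x224 thumbnail the classifier expects:
#
#   assert thumbnailify(Image.new("RGB", (640, 480))).size == (224, 224)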

# Helper function to read in an image
# Takes in an Image object, returns an ND array
def process_image(image):
    # Convert to neural network input format: subtract the per-channel
    # ImageNet mean and divide by the per-channel ImageNet standard
    # deviation (RGB order)
    image = np.array(image) - np.array([123., 117., 104.])
    image /= np.array([58.395, 57.12, 57.375])
    # Reorder from HWC to CHW and add a batch dimension
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]

    return tvm.nd.array(image.astype("float32"))
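
# A minimal shape check (an optional sketch): the helper should produce
# the single-batch CHW tensor shape that the compiled graph expects.
assert process_image(Image.new("RGB", (224, 224))).shape == (1, 3, 224, 224)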

# Classification helper function
# Takes in the graph runtime module and an image, and returns the top
# result and the inference time. Note that it relies on the globals
# ``ctx``, ``remote``, and ``synset`` defined further below.
def classify(m, image):
    m.set_input('data', image)
    timer = m.module.time_evaluator("run", ctx, number=1)
    tcost = timer()
    tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
    top = np.argmax(tvm_output.asnumpy())
    tcost = "t={0:.2f}s".format(tcost.mean)
    return tcost + " {}".format(synset[top])

# Helper function to compile the NNVM graph
# Takes in a path to a graph file, a params file, and a device target
# Returns the NNVM graph object, a compiled library object, and the
# params dict
def generate_graph(graph_fn, params_fn, device="vta"):

    # Measure build start time
    build_start = time.time()

    # Derive the TVM target
    target = tvm.target.create("llvm -device={}".format(device))

    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to the ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"

    # Load the ResNet-18 graph and parameters
    sym = nnvm.graph.load_json(open(graph_fn).read())
    params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())

    # Populate the shape and data type dictionary
    shape_dict = {"data": (1, 3, 224, 224)}
    dtype_dict = {"data": 'float32'}
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Create NNVM graph
    graph = nnvm.graph.create(sym)
    graph_attr.set_shape_inputs(sym, shape_dict)
    graph_attr.set_dtype_inputs(sym, dtype_dict)
    graph = graph.apply("InferShape").apply("InferType")

    # Apply NNVM graph optimization passes
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

    # Compile NNVM graph
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name != "vta":
            # When not targeting VTA, build directly for the host CPU
            graph, lib, params = nnvm.compiler.build(
                sym, target_host, shape_dict, dtype_dict,
                params=params)
        else:
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(
                    sym, target, shape_dict, dtype_dict,
                    params=params, target_host=target_host)

    # Save the compiled inference graph library
    assert tvm.module.enabled("rpc")
    temp = util.tempdir()
    lib.save(temp.relpath("graphlib.o"))

    # Send the inference library over to the remote RPC server
    remote.upload(temp.relpath("graphlib.o"))
    lib = remote.load_module("graphlib.o")

    # Measure build time
    build_time = time.time() - build_start
    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))

    return graph, lib, params

######################################################################
# Download ResNet Model
# ---------------------
# Download the necessary files to run ResNet-18.

# Obtain the ResNet model files and download them into the _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'

# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download files, skipping any file that is already present locally
for file in [categ_fn, graph_fn, params_fn]:
    if not os.path.isfile(os.path.join(data_dir, file)):
        download(url + file, os.path.join(data_dir, file))

# Read in the ImageNet categories (the synset file stores a Python
# dict literal mapping class indices to labels)
synset = eval(open(os.path.join(data_dir, categ_fn)).read())
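
# Sanity check (a small sketch): the synset is expected to cover each
# of the network's 1000 output indices.
print("Loaded {} ImageNet categories".format(len(synset)))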


######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.

# Measure reconfiguration start time
reconfig_start = time.time()

# Read the Pynq RPC host IP address and port number from the OS
# environment
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
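
# For the Pynq target, point the script at your board before running
# it, e.g. from your shell (the address below is only a placeholder
# for a typical Pynq setup):
#
#   export VTA_PYNQ_RPC_HOST=192.168.2.99
#   export VTA_PYNQ_RPC_PORT=9091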

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the config.json file.
if env.TARGET == "pynq":

    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

    # Report on reconfiguration time
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
    remote = rpc.LocalSession()


######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.

# Set ``device=vta`` to run inference on the FPGA,
# or ``device=vtacpu`` to run inference on the CPU.
device = "vta"

# Device context
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

# Build the graph runtime
graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
                                    os.path.join(data_dir, params_fn),
                                    device)
m = graph_runtime.create(graph, lib, ctx)

# Set the parameters
m.set_input(**params)
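
# Optionally warm up the runtime on random data (a minimal sketch), so
# that the timed run further below does not include one-time setup costs.
warmup = tvm.nd.array(
    np.random.uniform(size=(1, 3, 224, 224)).astype("float32"))
m.set_input('data', warmup)
m.run()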


######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on a test image.
# You can change the test image URL to any image of your choosing.

# Read in the test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).resize((224, 224))

# Show the image
plt.imshow(image)
plt.show()

# Set the input
image = process_image(image)
m.set_input('data', image)

# Perform inference
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()

# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
top_categories = np.argsort(tvm_output.asnumpy())

# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
print("                     #2:", synset[top_categories[-2]])
print("                     #3:", synset[top_categories[-3]])
print("                     #4:", synset[top_categories[-4]])
print("                     #5:", synset[top_categories[-5]])
print("Performed inference in {0:.2f}s".format(tcost.mean))
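
# To turn the raw 1000-way scores into probabilities, apply a
# numerically stable softmax (a small numpy sketch over the output
# obtained above).
scores = tvm_output.asnumpy()
probs = np.exp(scores - np.max(scores))
probs /= probs.sum()
print("Top-1 probability: {0:.3f}".format(probs[top_categories[-1]]))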


######################################################################
# Run a YouTube Video Image Classifier
# ------------------------------------
# Perform image classification on a test stream, on one frame out of
# every 48 frames.
# Change the ``if False:`` guard below to ``if True:`` to run the demo.

# Early exit - flip to True to run the demo
if False:

    import cv2
    import pafy
    from IPython.display import clear_output

    # 16x16 inches
    plt.rcParams['figure.figsize'] = [16, 16]

    # Stream the video in
    url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
    video = pafy.new(url)
    best = video.getbest(preftype="mp4")
    cap = cv2.VideoCapture(best.url)

    # Process one frame out of every 48 for variety
    count = 0
    guess = ""
    while count < 2400:

        # Capture frame-by-frame
        ret, frame = cap.read()

        # Process one frame out of every 48
        if count % 48 == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            # Crop and resize with the thumbnailify helper defined above
            thumb = np.array(thumbnailify(frame))
            image = process_image(thumb)
            guess = classify(m, image)

            # Insert the guess into the frame as a black banner with text
            frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
            cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 255), 1, cv2.LINE_AA)

            plt.imshow(thumb)
            plt.axis('off')
            plt.show()
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            clear_output(wait=True)

        count += 1

    # When everything is done, release the capture
    cap.release()
    cv2.destroyAllWindows()