Skip to content

Commit

Permalink
fixing resnet18 tutorial to work with TOPI
Browse files Browse the repository at this point in the history
  • Loading branch information
tmoreau89 committed May 10, 2019
1 parent 85c468d commit f0a5472
Showing 1 changed file with 64 additions and 71 deletions.
135 changes: 64 additions & 71 deletions vta/tutorials/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import numpy as np
import requests
#from matplotlib import pyplot as plt
from matplotlib import pyplot as plt
from PIL import Image

import tvm
Expand Down Expand Up @@ -107,10 +107,6 @@ def classify(m, image):
# Read in ImageNet Categories
synset = eval(open(os.path.join(data_dir, categ_fn)).read())

# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
# autotvm.tophub.check_backend('vta')


######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
Expand Down Expand Up @@ -152,16 +148,13 @@ def classify(m, image):
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.

# Set ``device=arm_cpu`` to run inference on the CPU
# Set ``device=vtacpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"

# Derive the TVM target
if device == "vta":
target = env.target
elif device == "arm_cpu":
target = env.target_vta_cpu
ctx = remote.context(str(target))
# TVM target and context
target = tvm.target.create("llvm -device={}".format(device))
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

# TVM module
m = None
Expand Down Expand Up @@ -233,8 +226,8 @@ def classify(m, image):
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).resize((224, 224))
# Show Image
# plt.imshow(image)
# plt.show()
plt.imshow(image)
plt.show()
# Set the input
image = process_image(image)
m.set_input('data', image)
Expand Down Expand Up @@ -263,60 +256,60 @@ def classify(m, image):
# Comment the `if False:` out to run the demo

# Early exit - remove for Demo
# if False:

# import cv2
# import pafy
# from IPython.display import clear_output

# # Helper to crop an image to a square (224, 224)
# # Takes in an Image object, returns an Image object
# def thumbnailify(image, pad=15):
# w, h = image.size
# crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
# image = image.crop(crop)
# image = image.resize((224, 224))
# return image

# # 16:16 inches
# plt.rcParams['figure.figsize'] = [16, 16]

# # Stream the video in
# url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
# video = pafy.new(url)
# best = video.getbest(preftype="mp4")
# cap = cv2.VideoCapture(best.url)

# # Process one frame out of every 48 for variety
# count = 0
# guess = ""
# while(count<2400):

# # Capture frame-by-frame
# ret, frame = cap.read()

# # Process one every 48 frames
# if count % 48 == 1:
# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# frame = Image.fromarray(frame)
# # Crop and resize
# thumb = np.array(thumbnailify(frame))
# image = process_image(thumb)
# guess = classify(m, image)

# # Insert guess in frame
# frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50)
# cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA)

# plt.imshow(thumb)
# plt.axis('off')
# plt.show()
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# clear_output(wait=True)

# count += 1

# # When everything done, release the capture
# cap.release()
# cv2.destroyAllWindows()
if False:

import cv2
import pafy
from IPython.display import clear_output

# Helper to crop an image to a square (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
w, h = image.size
crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
image = image.crop(crop)
image = image.resize((224, 224))
return image

# 16:16 inches
plt.rcParams['figure.figsize'] = [16, 16]

# Stream the video in
url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
video = pafy.new(url)
best = video.getbest(preftype="mp4")
cap = cv2.VideoCapture(best.url)

# Process one frame out of every 48 for variety
count = 0
guess = ""
while(count<2400):

# Capture frame-by-frame
ret, frame = cap.read()

# Process one every 48 frames
if count % 48 == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
# Crop and resize
thumb = np.array(thumbnailify(frame))
image = process_image(thumb)
guess = classify(m, image)

# Insert guess in frame
frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50)
cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA)

plt.imshow(thumb)
plt.axis('off')
plt.show()
if cv2.waitKey(1) & 0xFF == ord('q'):
break
clear_output(wait=True)

count += 1

# When everything done, release the capture
cap.release()
cv2.destroyAllWindows()

0 comments on commit f0a5472

Please sign in to comment.