diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py index 7930bfe750c6e..13161586480ea 100644 --- a/vta/tutorials/resnet.py +++ b/vta/tutorials/resnet.py @@ -37,7 +37,7 @@ import numpy as np import requests -#from matplotlib import pyplot as plt +from matplotlib import pyplot as plt from PIL import Image import tvm @@ -107,10 +107,6 @@ def classify(m, image): # Read in ImageNet Categories synset = eval(open(os.path.join(data_dir, categ_fn)).read()) -# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA -# autotvm.tophub.check_backend('vta') - - ###################################################################### # Setup the Pynq Board's RPC Server # --------------------------------- @@ -152,16 +148,13 @@ def classify(m, image): # ------------------------ # Build the ResNet graph runtime, and configure the parameters. -# Set ``device=arm_cpu`` to run inference on the CPU +# Set ``device=vtacpu`` to run inference on the CPU # or ``device=vta`` to run inference on the FPGA. device = "vta" -# Derive the TVM target -if device == "vta": - target = env.target -elif device == "arm_cpu": - target = env.target_vta_cpu -ctx = remote.context(str(target)) +# TVM target and context +target = tvm.target.create("llvm -device={}".format(device)) +ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) # TVM module m = None @@ -233,8 +226,8 @@ def classify(m, image): response = requests.get(image_url) image = Image.open(BytesIO(response.content)).resize((224, 224)) # Show Image -# plt.imshow(image) -# plt.show() +plt.imshow(image) +plt.show() # Set the input image = process_image(image) m.set_input('data', image) @@ -263,60 +256,60 @@ def classify(m, image): # Comment the `if False:` out to run the demo # Early exit - remove for Demo -# if False: - -# import cv2 -# import pafy -# from IPython.display import clear_output - -# # Helper to crop an image to a square (224, 224) -# # Takes in an Image object, returns an Image object -# def thumbnailify(image, pad=15): -# w, h = image.size -# crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad) -# image = image.crop(crop) -# image = image.resize((224, 224)) -# return image - -# # 16:16 inches -# plt.rcParams['figure.figsize'] = [16, 16] - -# # Stream the video in -# url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s" -# video = pafy.new(url) -# best = video.getbest(preftype="mp4") -# cap = cv2.VideoCapture(best.url) - -# # Process one frame out of every 48 for variety -# count = 0 -# guess = "" -# while(count<2400): - -# # Capture frame-by-frame -# ret, frame = cap.read() - -# # Process one every 48 frames -# if count % 48 == 1: -# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) -# frame = Image.fromarray(frame) -# # Crop and resize -# thumb = np.array(thumbnailify(frame)) -# image = process_image(thumb) -# guess = classify(m, image) - -# # Insert guess in frame -# frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50) -# cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA) - -# plt.imshow(thumb) -# plt.axis('off') -# plt.show() -# if cv2.waitKey(1) & 0xFF == ord('q'): -# break -# clear_output(wait=True) - -# count += 1 - -# # When everything done, release the capture -# cap.release() -# cv2.destroyAllWindows() +if False: + + import cv2 + import pafy + from IPython.display import clear_output + + # Helper to crop an image to a square (224, 224) + # Takes in an Image object, returns an Image object + def thumbnailify(image, pad=15): + w, h = image.size + crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad) + image = image.crop(crop) + image = image.resize((224, 224)) + return image + + # 16:16 inches + plt.rcParams['figure.figsize'] = [16, 16] + + # Stream the video in + url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s" + video = pafy.new(url) + best = video.getbest(preftype="mp4") + cap = cv2.VideoCapture(best.url) + + # Process one frame out of every 48 for variety + count = 0 + guess = "" + while(count<2400): + + # Capture frame-by-frame + ret, frame = cap.read() + + # Process one every 48 frames + if count % 48 == 1: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + # Crop and resize + thumb = np.array(thumbnailify(frame)) + image = process_image(thumb) + guess = classify(m, image) + + # Insert guess in frame + frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50) + cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA) + + plt.imshow(thumb) + plt.axis('off') + plt.show() + if cv2.waitKey(1) & 0xFF == ord('q'): + break + clear_output(wait=True) + + count += 1 + + # When everything done, release the capture + cap.release() + cv2.destroyAllWindows()