[VTA][Relay] Extending Vision model coverage compilation for VTA (apache#3740)

* adding support for graphpack over multiply op

* increasing resnet model coverage

* fix indentation

* lint

* moving recursion limit fix into graphpack pass

* moving recursionlimit to relay init

* pooling on NCHWnc format

* adding more models

* deploy_resnet_on_vta.py

* trailing line

* generalizing to vision models

* merge conflicts

* fix, apply quantization to VTA only

* improving comments

* trimming models that have runtime issues for the moment

* lint

* lint

* lint
tmoreau89 authored and wweic committed Sep 16, 2019
1 parent c119257 commit 5b831ab
Showing 4 changed files with 66 additions and 38 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/__init__.py
@@ -17,6 +17,7 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from __future__ import absolute_import
+from sys import setrecursionlimit
 from ..api import register_func
 from . import base
 from . import ty
@@ -59,6 +60,9 @@

 from .scope_builder import ScopeBuilder
 
+# Required to traverse large programs
+setrecursionlimit(10000)
+
 # Span
 Span = base.Span

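For context: Relay's expression visitors recurse once per nested call node, so a deep enough model (e.g. ResNet-101 after quantization rewrites) can blow past CPython's default limit of 1000 frames. A minimal sketch of the failure mode (not part of the diff; shapes and depth are illustrative):

import sys
from tvm import relay

sys.setrecursionlimit(10000)    # what relay/__init__.py now does on import

x = relay.var("x", shape=(1, 16))
body = x
for _ in range(3000):           # a call chain deeper than the default limit
    body = relay.add(body, relay.const(1.0))
func = relay.Function([x], body)
print(func)                     # printing traverses the AST recursively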
7 changes: 5 additions & 2 deletions src/relay/op/nn/pooling.cc
@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
     << "max_pool2d does not support input split on width";
 
-  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
+  CHECK(inputs[0].ndim() == 4U ||
+        inputs[0].ndim() == 5U ||
+        inputs[0].ndim() == 6U)
     << "Pool2D only supports 4-D input (e.g., NCHW)"
-    << " or 5-D input (last dimension is a split of channel)";
+    << " or 5-D input (e.g. NCHWc for vector instructions)"
+    << " or 6-D input (e.g. NCHWnc for tensor accelerators)";
 
   if (param->padding.size() == 1) {
     padding.push_back(padding[0]);
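The new 6-D case corresponds to VTA's blocked NCHWnc layout, in which graph packing splits both the batch and channel axes by the accelerator's tensorization factors. A quick shape sketch, assuming VTA's default factors of 1 (batch) and 16 (channel):

N, C, H, W = 1, 64, 56, 56              # logical NCHW activation shape
n, c = 1, 16                            # assumed VTA batch/channel packing factors
packed = (N // n, C // c, H, W, n, c)   # NCHWnc
print(packed)                           # (1, 4, 56, 56, 1, 16)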
43 changes: 28 additions & 15 deletions vta/python/vta/top/graphpack.py
@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
     return data
 
 
-def _pack_bias(data, dshape, dtype, bfactor, cfactor):
-    """Pack the bias parameter.
+def _pack_const(data, dshape, dtype, bfactor, cfactor):
+    """Pack a constant parameter.
     """
     dshape = _to_shape(dshape)
     assert len(dshape) == 3
@@ -124,6 +124,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
         self.conv2d = op.op.get("nn.conv2d")
         self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
         self.add = op.op.get("add")
+        self.multiply = op.op.get("multiply")
         self.bias_add = op.op.get("nn.bias_add")
         self.number_of_conv2d = 0
         super().__init__()
@@ -203,23 +204,35 @@ def visit_call(self, call):
                 output_padding=call.attrs.output_padding,
                 out_dtype=call.attrs.out_dtype)
             return conv2d
-        elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
+        elif call.op == self.add and \
+                tuple(input_types[0].shape) == tuple(input_types[1].shape):
             pass
         elif call.op == self.add and len(input_types[1].shape) == 3:
-            data, bias = args
-            bias = _pack_bias(bias,
-                              _to_shape(input_types[1].shape),
-                              input_types[1].dtype,
-                              self.bfactor,
-                              self.cfactor)
-            return relay.Call(self.add, [data, bias])
+            data, const = args
+            const = _pack_const(const,
+                                _to_shape(input_types[1].shape),
+                                input_types[1].dtype,
+                                self.bfactor,
+                                self.cfactor)
+            return relay.Call(self.add, [data, const])
+        elif call.op == self.multiply and \
+                tuple(input_types[0].shape) == tuple(input_types[1].shape):
+            pass
+        elif call.op == self.multiply and len(input_types[1].shape) == 3:
+            data, const = args
+            const = _pack_const(const,
+                                _to_shape(input_types[1].shape),
+                                input_types[1].dtype,
+                                self.bfactor,
+                                self.cfactor)
+            return relay.Call(self.multiply, [data, const])
         elif self.start_pack and call.op == self.bias_add:
             data, bias = args
-            bias = _pack_bias(bias,
-                              _to_shape(input_types[1].shape),
-                              input_types[1].dtype,
-                              self.bfactor,
-                              self.cfactor)
+            bias = _pack_const(bias,
+                               _to_shape(input_types[1].shape),
+                               input_types[1].dtype,
+                               self.bfactor,
+                               self.cfactor)
             return relay.Call(self.add, [data, bias])
         elif self.start_pack and call.op == op.op.get('cast') and \
                 input_types[0].dtype == 'int32':
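To make the new _pack_const concrete: a 3-D constant such as a per-channel batch-norm scale of shape (C, 1, 1) has its channel axis split into C // cfactor blocks of cfactor lanes and is broadcast across the batch-packing axis, so it lines up element-wise with 6-D NCHWnc activations. A NumPy reference of the intended transform, a sketch under those assumptions rather than the Relay implementation itself:

import numpy as np

def pack_const_ref(data, bfactor, cfactor):
    """NumPy sketch of _pack_const: (C, 1, 1) -> (1, C//cfactor, 1, 1, bfactor, cfactor)."""
    c = data.shape[0]
    assert c % cfactor == 0
    out = data.reshape(1, c // cfactor, cfactor, 1, 1)  # split channel axis
    out = out.transpose(0, 1, 3, 4, 2)                  # move lanes innermost
    out = np.broadcast_to(out[:, :, :, :, np.newaxis, :],
                          (1, c // cfactor, 1, 1, bfactor, cfactor))
    return out

scale = np.arange(64, dtype="float32").reshape(64, 1, 1)
print(pack_const_ref(scale, bfactor=1, cfactor=16).shape)  # (1, 4, 1, 1, 1, 16)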
50 changes: 29 additions & 21 deletions vta/tutorials/frontend/deploy_resnet_on_vta.py
@@ -15,12 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Deploy Pretrained ResNet Model from MxNet on VTA
+Deploy Pretrained Vision Model from MxNet on VTA
 ================================================
 **Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
 
-This tutorial provides an end-to-end demo, on how to run ResNet-18 inference
-onto the VTA accelerator design to perform ImageNet classification tasks.
+This tutorial provides an end-to-end demo of running ImageNet classification
+inference on the VTA accelerator design.
 It showcases Relay as a front end compiler that can perform quantization (VTA
 only supports int8/32 inference) as well as graph packing (in order to enable
 tensorization in the core) to massage the compute graph for the hardware target.
@@ -40,7 +40,7 @@

 from __future__ import absolute_import, print_function
 
-import argparse, json, os, requests, time
+import argparse, json, os, requests, sys, time
 from io import BytesIO
 from os.path import join, isfile
 from PIL import Image
@@ -53,6 +53,7 @@
 from tvm import rpc, autotvm, relay
 from tvm.contrib import graph_runtime, util, download
 from tvm.contrib.debugger import debug_runtime
+from tvm.relay import transform
 
 import vta
 from vta.testing import simulator
@@ -61,7 +62,6 @@
 # Make sure that TVM was compiled with RPC=1
 assert tvm.module.enabled("rpc")
 
-
 ######################################################################
 # Define the platform and model targets
 # -------------------------------------
@@ -75,13 +75,22 @@
 device = "vta"
 target = env.target if device == "vta" else env.target_vta_cpu
 
+# Dictionary lookup for when to start/end bit packing
+pack_dict = {
+    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+}
+
 # Name of Gluon model to compile
 # The ``start_pack`` and ``stop_pack`` labels indicate where
 # to start and end the graph packing relay pass: in other words
 # where to start and finish offloading to VTA.
 model = "resnet18_v1"
-start_pack="nn.max_pool2d"
-stop_pack="nn.global_avg_pool2d"
+assert model in pack_dict
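The dictionary generalizes the previously hard-coded start_pack/stop_pack pair: each supported model maps to the two operators that bracket the region offloaded to VTA. A lookup sketch, mirroring how the tutorial consumes pack_dict further down:

start_pack, stop_pack = pack_dict["resnet50_v2"]
print(start_pack, stop_pack)    # nn.max_pool2d nn.global_avg_pool2d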

######################################################################
# Obtain an execution remote
@@ -125,7 +134,7 @@
 ######################################################################
 # Build the inference graph runtime
 # ---------------------------------
-# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
+# Grab vision model from Gluon model zoo and compile with Relay.
 # The compilation steps are:
 # 1) Front end translation from MxNet into Relay module.
 # 2) Apply 8-bit quantization: here we skip the first conv layer,
@@ -140,7 +149,7 @@
 # Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
 
-    # Populate the shape and data type dictionary for ResNet input
+    # Populate the shape and data type dictionary for ImageNet classifier input
     dtype_dict = {"data": 'float32'}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
@@ -157,21 +166,22 @@
     shape_dict.update({k: v.shape for k, v in params.items()})
     dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
 
-    # Perform quantization in Relay
-    with relay.quantize.qconfig(global_scale=8.0,
-                                skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod["main"], params=params)
-
-    # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
+        # Perform quantization in Relay
+        with relay.quantize.qconfig(global_scale=8.0,
+                                    skip_conv_layers=[0]):
+            relay_prog = relay.quantize.quantize(mod["main"], params=params)
+        # Perform graph packing and constant folding for VTA target
         assert env.BLOCK_IN == env.BLOCK_OUT
         relay_prog = graph_pack(
             relay_prog,
             env.BATCH,
             env.BLOCK_OUT,
             env.WGT_WIDTH,
-            start_name=start_pack,
-            stop_name=stop_pack)
+            start_name=pack_dict[model][0],
+            stop_name=pack_dict[model][1])
+    else:
+        relay_prog = mod["main"]
 
     # Compile Relay program with AlterOpLayout disabled
     with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
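The build call itself is collapsed in the hunk above; it looks roughly like the following sketch, assuming the TVM 0.6-era relay.build signature and VTA's env.target_host:

    # Sketch of the elided compile step: AlterOpLayout stays disabled so
    # the NCHWnc layout chosen by graph packing is preserved.
    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
        graph, lib, params = relay.build(
            relay_prog, target=target,
            params=params, target_host=env.target_host)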
@@ -199,8 +209,8 @@
 m = graph_runtime.create(graph, lib, ctx)
 
 ######################################################################
-# Perform ResNet-18 inference
-# ---------------------------
+# Perform image classification inference
+# --------------------------------------
 # We run classification on an image sample from ImageNet
 # We just need to download the categories files, `synset.txt`
 # and an input test image.
@@ -256,15 +266,13 @@
 tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
 for b in range(env.BATCH):
     top_categories = np.argsort(tvm_output.asnumpy()[b])
-
     # Report top-5 classification results
     print("\n{} prediction for sample {}".format(model, b))
     print("\t#1:", synset[top_categories[-1]])
     print("\t#2:", synset[top_categories[-2]])
     print("\t#3:", synset[top_categories[-3]])
     print("\t#4:", synset[top_categories[-4]])
     print("\t#5:", synset[top_categories[-5]])
-
     # This just checks that one of the 5 top categories
     # is one variety of cat; this is by no means an accurate
     # assessment of how quantization affects classification
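The five print calls above index the ascending argsort one rank at a time; an equivalent compact form of the loop body (illustrative NumPy, reusing the tutorial's top_categories and synset):

    # Same top-5 report via a single slice of the ascending argsort.
    for rank, idx in enumerate(top_categories[-5:][::-1], start=1):
        print("\t#{}:".format(rank), synset[idx])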
