Skip to content

Commit

Permalink
[VTA] Support for batched inference (#3661)
Browse files Browse the repository at this point in the history
* fix in IR pass to support padding on 6-d tensors

* support for both N>1 and N==1 for padding

* batch size > 1 tuning and base config

* output formatting

* batch conv2d

* print all category results

* revert to single-batch config

* pick record best

* fix conv test

* improving reporting

* address batching bug in fast simulator

* fix
  • Loading branch information
tmoreau89 authored and jroesch committed Jul 30, 2019
1 parent 9b355fc commit 6c7f0c4
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 80 deletions.
37 changes: 22 additions & 15 deletions vta/python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,22 +524,29 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
if pad_before:
assert pad_after
ndim = len(pad_before)
if ndim <= 2 or ndim > 4:
if ndim <= 2 or ndim > 5:
raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
if ndim > 2:
if not util.equal_const_int(pad_before[ndim - 1], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 1], 0):
raise ValueError("Do not support pad on the innermost block")
if ndim > 3:
if not util.equal_const_int(pad_before[ndim - 2], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 2], 0):
raise ValueError("Do not support pad on the innermost block")
y_pad_before = pad_before[0]
x_pad_before = pad_before[1]
y_pad_after = pad_after[0]
x_pad_after = pad_after[1]
if ndim == 5:
# This case occurs when batch size N > 1
y_pad_before = pad_before[1]
x_pad_before = pad_before[2]
y_pad_after = pad_after[1]
x_pad_after = pad_after[2]
for dim in range(3, ndim):
if not util.equal_const_int(pad_before[dim], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[dim], 0):
raise ValueError("Do not support pad on the innermost block")
else:
y_pad_before = pad_before[0]
x_pad_before = pad_before[1]
y_pad_after = pad_after[0]
x_pad_after = pad_after[1]
for dim in range(2, ndim):
if not util.equal_const_int(pad_before[dim], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[dim], 0):
raise ValueError("Do not support pad on the innermost block")
allow_fold = False
else:
x_pad_before = 0
Expand Down
69 changes: 43 additions & 26 deletions vta/scripts/tune_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,17 @@

resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C1', Workload(env.BATCH, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
('resnet-18.C3', Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C4', Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C5', Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C6', Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C7', Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C8', Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C9', Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C10', Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C11', Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]

@tvm.tag_scope(tag=topi.tag.ELEMWISE)
Expand Down Expand Up @@ -87,16 +86,25 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dt

# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# logging.getLogger('autotvm').setLevel(logging.DEBUG)

# Tuning log files
log_file = "%s.conv2d.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)

# Get tracker info from env
tracket_host = os.environ.get("TVM_TRACKER_HOST", None)
tracket_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracket_host or not tracket_port:
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()

for wl_name, wl in resnet_wkls:
for idx, (wl_name, wl) in enumerate(resnet_wkls):

prefix = "[Task %2d/%2d] " % (idx, len(resnet_wkls))

# Workload parameters
N = wl.batch
Expand All @@ -116,15 +124,24 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dt
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)

# Tune
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracket_host, int(tracket_port), number=4, repeat=3, timeout=10000,
check_correctness=True))
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET, host=tracker_host, port=int(tracker_port),
number=5, timeout=60,
check_correctness=True))

# Run Tuner
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('conv2d.log')])

print("\nBest tuner config:")
print(tuner.best_config)
tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])

# Pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)
2 changes: 1 addition & 1 deletion vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ class Device {
src_index += y * op->src_factor_out + x * op->src_factor_in;
BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
for (int k = 0; k < VTA_BATCH * VTA_BLOCK_OUT; ++k) {
if (use_imm) {
dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
} else {
Expand Down
32 changes: 18 additions & 14 deletions vta/tests/python/integration/test_benchmark_topi_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,23 @@
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

# Get batch info from env
env = vta.get_env()

# ResNet18 workloads
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C1', Workload(env.BATCH, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
('resnet-18.C3', Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C4', Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C5', Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C6', Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C7', Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C8', Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C9', Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C10', Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C11', Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]

# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
Expand Down Expand Up @@ -143,7 +145,7 @@ def get_ref_data():
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_np = bias_np.reshape(
wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT,
wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
1, 1, env.BATCH, env.BLOCK_OUT)

# Build
Expand Down Expand Up @@ -201,8 +203,10 @@ def get_ref_data():
if data_pack:
res_orig = res_orig.transpose(
(0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
bias_np = bias_np.transpose(
(0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
res_ref = res_ref >> 8
res_ref += bias_np.reshape(wl.out_filter, 1, 1)
res_ref += bias_np
res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)
Expand Down
14 changes: 13 additions & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,19 @@ def tune_and_evaluate(tuning_opt):
assert len(tasks) == 10
print("Extracted {} conv2d tasks:".format(len(tasks)))
for tsk in tasks:
print("\t{}".format(tsk))
inp = tsk.args[0][1]
wgt = tsk.args[1][1]
batch = inp[0]*inp[4]
in_filter = inp[1]*inp[5]
out_filter = wgt[0]*wgt[4]
height, width = inp[2], inp[3]
hkernel, wkernel = wgt[2], wgt[3]
hstride, wstride = tsk.args[2][0], tsk.args[2][1]
hpad, wpad = tsk.args[3][0], tsk.args[3][1]
print("({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
batch, height, width, in_filter, out_filter,
hkernel, wkernel, hpad, wpad, hstride, wstride
))

# We do not run the tuning in our webpage server since it takes too long.
# Comment the following line to run it by yourself.
Expand Down
48 changes: 25 additions & 23 deletions vta/tutorials/frontend/deploy_resnet_on_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,29 +247,31 @@
print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
else:
tcost = timer()
std = np.std(tcost.results) * 1000 / env.BATCH
mean = tcost.mean * 1000 / env.BATCH
print("\nPerformed inference in %.2fms/sample (std = %.2f)" % (mean, std))
std = np.std(tcost.results) * 1000
mean = tcost.mean * 1000
print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
print("Average per sample inference time: %.2fms" % (mean/env.BATCH))

# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
top_categories = np.argsort(tvm_output.asnumpy()[0])

# Report top-5 classification results
print("\n%s prediction" % model)
print("\t#1:", synset[top_categories[-1]])
print("\t#2:", synset[top_categories[-2]])
print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]])

# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
# accuracy but is meant to catch changes to the
# quantization pass that would affect accuracy in the CI.
cat_detected = False
for k in top_categories[-5:]:
if "cat" in synset[k]:
cat_detected = True
assert(cat_detected)
for b in range(env.BATCH):
top_categories = np.argsort(tvm_output.asnumpy()[b])

# Report top-5 classification results
print("\n{} prediction for sample {}".format(model, b))
print("\t#1:", synset[top_categories[-1]])
print("\t#2:", synset[top_categories[-2]])
print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]])

# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
# accuracy but is meant to catch changes to the
# quantization pass that would affect accuracy in the CI.
cat_detected = False
for k in top_categories[-5:]:
if "cat" in synset[k]:
cat_detected = True
assert(cat_detected)

0 comments on commit 6c7f0c4

Please sign in to comment.