Skip to content

Commit

Permalink
[VTA] Bringing group convolution support (apache#4421)
Browse files Browse the repository at this point in the history
* group conv operator support for VTA

* autotvm tuning script for group conv2d

* lint fix

* lint fix

* lint fix

* addressing comments
  • Loading branch information
tmoreau89 authored and Xingyu Zhou committed Dec 13, 2019
1 parent 76f741f commit 5c0b608
Show file tree
Hide file tree
Showing 4 changed files with 595 additions and 0 deletions.
1 change: 1 addition & 0 deletions vta/python/vta/top/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
from . import op
from . import vta_conv2d
from . import vta_conv2d_transpose
from . import vta_group_conv2d
from . import vta_dense
from . import util
199 changes: 199 additions & 0 deletions vta/python/vta/top/vta_group_conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Group conv2D operator declaration and schedule registration for VTA."""

import numpy as np

import tvm
from tvm import autotvm
import topi

from ..environment import get_env

@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct')
def packed_group_conv2d(cfg,
data,
kernel,
strides,
padding,
dilation,
group,
out_dtype):
""" Packed group conv2d nchw function."""
assert dilation == (1, 1)

if padding[0]:
pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data")
else:
pad_data = data
assert len(data.shape) == 6
assert len(kernel.shape) == 6
assert data.dtype == "int8", data.dtype
assert kernel.dtype == "int8", kernel.dtype
assert out_dtype == "int32", out_dtype

oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4])

ishape = topi.util.get_const_tuple(data.shape)
kshape = topi.util.get_const_tuple(kernel.shape)
assert group * kshape[1] == ishape[1]
assert kshape[0] % group == 0
d_i = tvm.reduce_axis((0, kshape[2]), name='d_i')
d_j = tvm.reduce_axis((0, kshape[3]), name='d_j')
k_o = tvm.reduce_axis((0, kshape[1]), name='k_o')
k_i = tvm.reduce_axis((0, kshape[-1]), name='k_i')
hstride, wstride = strides
out = tvm.compute(
oshape,
lambda b_o, c_o, i, j, b_i, c_i: tvm.sum(
pad_data[b_o, c_o // (kshape[0] // group) * kshape[1] + k_o, i * hstride + d_i,
j * wstride + d_j, b_i, k_i].astype(out_dtype) *
kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
axis=[k_o, d_i, d_j, k_i]),
name="res", tag="packed_group_conv2d")

cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
kshape[2] * kshape[3] * ishape[1] * kshape[-1])

return out


@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct')
def schedule_packed_group_conv2d(cfg, outs):
""" Schedule the packed conv2d.
"""
assert len(outs) == 1
output = outs[0]
const_ops = []
ewise_inputs = []
ewise_ops = []
conv2d_res = []
assert output.dtype == "int8"
assert output.op.input_tensors[0].dtype == "int32"

def _traverse(op):
if topi.tag.is_broadcast(op.tag):
if not op.same_as(output.op):
if not op.axis:
const_ops.append(op)
else:
ewise_ops.append(op)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
ewise_inputs.append((op, tensor))
else:
_traverse(tensor.op)
else:
assert op.tag == "packed_group_conv2d"
conv2d_res.append(op)

_traverse(output.op)
assert len(conv2d_res) == 1
conv2d_stage = conv2d_res[0].output(0)
s = tvm.create_schedule(output.op)

##### space definition begin #####
b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis
c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
cfg.define_split('tile_b', b, num_outputs=2)
cfg.define_split('tile_h', x_i, num_outputs=2)
cfg.define_split('tile_w', x_j, num_outputs=2)
cfg.define_split('tile_ci', c_i, num_outputs=2)
cfg.define_split('tile_co', c_o, num_outputs=2)
cfg.define_knob('oc_nthread', [1, 2])
cfg.define_knob('h_nthread', [1, 2])
###### space definition end ######

data, kernel = conv2d_stage.op.input_tensors
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
temp = data.op.input_tensors[0]
pad_data = data
data = temp
else:
pad_data = None

env = get_env()

# setup pad
if pad_data is not None:
cdata = pad_data
s[pad_data].set_scope(env.inp_scope)
else:
cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
s[conv2d_stage].set_scope(env.acc_scope)

# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, env.acc_scope, [consumer]))

# set ewise scope
for op in ewise_ops:
s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], env.alu)

for op in const_ops:
s[op].compute_inline()

# tile
x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
store_pt = x_j0

# set all compute scopes
s[conv2d_stage].compute_at(s[output], store_pt)
for op in ewise_ops:
s[op].compute_at(s[output], store_pt)

for tensor in cache_read_ewise:
s[tensor].compute_at(s[output], store_pt)
s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)

# virtual threading along output channel axes
if cfg['oc_nthread'].val > 1:
_, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))

# virtual threading along spatial rows
if cfg['h_nthread'].val > 1:
_, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))

x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis
k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)

k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
s[cdata].compute_at(s[conv2d_stage], k_o)
s[ckernel].compute_at(s[conv2d_stage], k_o)

# Use VTA instructions
s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy)
s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy)
s[conv2d_stage].tensorize(x_bi, env.gemm)
s[output].pragma(x_co1, env.dma_copy)

return s
155 changes: 155 additions & 0 deletions vta/scripts/tune_group_conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tuning a single group conv2d operator"""

from collections import namedtuple
import logging
import os

import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing

env = vta.get_env()

Workload = namedtuple("GroupConv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter', 'groups',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

# Mobilenet (grouped variant) workloads
mobilenet_wkls = [
('mobilenet.D1', Workload(env.BATCH, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)),
('mobilenet.D2', Workload(env.BATCH, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)),
('mobilenet.D3', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)),
('mobilenet.D4', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)),
('mobilenet.D5', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)),
('mobilenet.D6', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)),
('mobilenet.D7', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)),
('mobilenet.D8', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)),
('mobilenet.D9', Workload(env.BATCH, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
]

@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x

def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, group):

CI_G = CI // groups
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI_G//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)

data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

with tvm.target.vta():
res = topi.nn.group_conv2d_nchw(
data,
kernel,
strides,
padding,
dilation,
groups,
env.acc_dtype)
res = topi.right_shift(res, env.WGT_WIDTH)
res = topi.add(res, bias)
res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
res = topi.cast(res, env.out_dtype)

if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_group_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])

return s, [data, kernel, bias, res]

if __name__ == '__main__':

# Logging config (for printing tuning log to the screen)
logging.basicConfig()

# Tuning log files
log_file = "%s.group_conv2d.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)

# Get tracker info from env
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()

for idx, (wl_name, wl) in enumerate(mobilenet_wkls):
prefix = "[Task %2d/%2d] " % (idx, len(mobilenet_wkls))

# Read in workload parameters
N = wl.batch
CI = wl.in_filter
H = wl.height
W = wl.width
CO = wl.out_filter
KH = wl.hkernel
KW = wl.wkernel
strides = (wl.hstride, wl.wstride)
padding = (wl.hpad, wl.wpad)
dilation = (1, 1)
groups = wl.groups

# Create task
task = autotvm.task.create(
group_conv2d,
args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups),
target=tvm.target.vta(),
target_host=env.target_host,
template_key='direct')
print(task.config_space)

# Tune
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET, host=tracker_host, port=int(tracker_port),
number=5, timeout=60,
check_correctness=True))

# Run Tuner
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])

# Pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)
Loading

0 comments on commit 5c0b608

Please sign in to comment.