[TOPI] Setting up AutoTVM template for Intel Int8 conv2D #3955

Merged · 1 commit · Sep 16, 2019
3 changes: 3 additions & 0 deletions python/tvm/autotvm/task/topi_integration.py
@@ -85,6 +85,7 @@ def __init__(self, allow_duplicate=False):
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
@@ -100,6 +101,7 @@ def __init__(self, allow_duplicate=False):
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
@@ -111,6 +113,7 @@
self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
topi.nn.conv2d_NCHWc_int8: lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
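
Taken together, the three additions wire `conv2d_NCHWc_int8` into AutoTVM's task-extraction machinery: the first entry maps the TOPI symbol to the task name `topi_x86_conv2d_NCHWc_int8`, the second pairs it with its generic schedule, and the reflection lambda lets `TaskExtractEnv` temporarily swap the symbol for a tracing wrapper during extraction. A simplified standalone sketch of that pattern (hypothetical code, not TVM's actual implementation):

```python
import topi
from tvm import autotvm

# Hypothetical, simplified version of what TaskExtractEnv does with the
# registries added above: swap each TOPI compute for a tracer that records
# a (task_name, workload) pair every time the compute is invoked.
task_name = {topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8"}
reflection = {
    topi.nn.conv2d_NCHWc_int8:
        lambda f: setattr(topi.nn, 'conv2d_NCHWc_int8', f),
}

def make_tracer(func, name, log):
    def wrapper(*args, **kwargs):
        # Record the workload so a tuning task can be created for it later.
        log.append((name, autotvm.task.args_to_workload(args, func)))
        return func(*args, **kwargs)
    return wrapper

workloads = []
for func, name in task_name.items():
    reflection[func](make_tracer(func, name, workloads))
```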
1 change: 0 additions & 1 deletion topi/python/topi/nn/conv2d.py
@@ -669,7 +669,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")



def conv2d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for winograd

44 changes: 25 additions & 19 deletions topi/python/topi/x86/conv2d.py
@@ -27,7 +27,7 @@
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple, get_shape
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
@@ -77,7 +77,6 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
else:
conv2d_avx_common._fallback_schedule(cfg, wkl)


def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape)
@@ -92,19 +91,15 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*k_ic_s
else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk
assert ic_bn == k_ic_bn
ic = ic_chunk*ic_bn
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk
assert ic_bn == k_ic_bn
ic = ic_chunk*ic_bn
oc = oc_chunk*oc_bn
else:
raise ValueError("Not support this layout {} with "
"schedule template.".format(layout))

is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
@@ -444,14 +439,25 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
in_channel//ic_bn, ic_bn//n_elems, n_elems))
kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
copy_inputs = [data_expr, kernel_OIHWioe]
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn//n_elems,
n_elems))
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation,
new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
conv2d_NCHWc)

# Store altered operator's config. New kernel layout OIHWio4
new_kernel = tvm.placeholder((out_channel // oc_bn,
in_channel // ic_bn,
kh,
kw,
ic_bn // n_elems,
oc_bn,
n_elems), dtype=kernel.dtype)

new_workload = autotvm.task.args_to_workload([new_data,
new_kernel,
strides,
padding,
dilation,
new_attrs[layout_name],
new_attrs['out_layout'],
out_dtype],
conv2d_NCHWc_int8)
dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for int8 convolution on NNVM.")
104 changes: 91 additions & 13 deletions topi/python/topi/x86/conv2d_int8.py
@@ -17,30 +17,108 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on x86"""

import re
import tvm
from tvm import autotvm
from tvm.autotvm.task import get_config
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn
from .conv2d import _get_default_config
from . import conv2d_avx_1x1, conv2d_avx_common

def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
pat = re.compile(r'NCHW.+(\d+)c')
if layout == 'NCHW':
n, ic, h, w = dshape
oc, _, kh, kw = kshape
elif layout == 'NHWC':
n, h, w, ic = dshape
kh, kw, oc, _ = kshape
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False)
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk * ic_bn
assert ic == k_ic * k_ic_f * k_ic_s
oc = oc_chunk*oc_bn
else:
raise ValueError("Not support this layout {} with "
"schedule template.".format(layout))

is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
oh = (h - kh + 2 * ph) // sh + 1
ow = (w - kw + 2 * pw) // sw + 1

# Create schedule config
cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
cfg.define_split('tile_oc', oc, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
if is_kernel_1x1:
cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
else:
cfg.define_knob("unroll_kw", [True, False])


# Define the template function for the autotvm task.
# The schedule template is defined in this function rather than in the
# declaration function, because the actual input arguments need to be
# altered according to the selected schedule.
@autotvm.task.register("topi_x86_conv2d_NCHWc_int8")
def _topi_nn_conv2d_NCHWc_int8(*args, **kwargs):
    assert not kwargs, "kwargs are not supported in the template function call"
args = deserialize_args(args)

if len(args) == 7:
data, kernel, strides, padding, dilation, origin_layout, dtype = args
else:
assert len(args) == 8
data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args

raw_data_shape = get_const_tuple(data.shape)
raw_kernel_shape = get_const_tuple(kernel.shape)

# get config here
cfg = get_config()
_create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, origin_layout)

    # Change the shapes according to the values chosen in the config
ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
cfg["tile_ow"].size[-1])

data_layout = "NCHW%dc" % ic_bn
out_layout = "NCHW%dc" % oc_bn

    # Set up the new shapes for the data and kernel
new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
raw_data_shape[2], raw_data_shape[3], ic_bn)
n_elems = 4
new_kernel_shape = (raw_kernel_shape[0] // oc_bn,
raw_kernel_shape[1] // ic_bn,
raw_kernel_shape[2],
raw_kernel_shape[3],
ic_bn // n_elems,
oc_bn,
n_elems)

new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

C = _declaration_conv_NCHWc_int8(cfg, new_data, new_kernel, strides, padding, dilation,
data_layout, out_layout, dtype)
s = _schedule_conv2d_NCHWc_int8(cfg, [C])
return s, [new_data, new_kernel, C]
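
With the template registered under `topi_x86_conv2d_NCHWc_int8`, a task can also be created and tuned by hand. A minimal sketch assuming the TVM 0.6-era API and hypothetical shapes; `serialize_args` is the counterpart of the `deserialize_args` call inside the template:

```python
import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import serialize_args

# Hypothetical workload: 1x64x56x56 uint8 activations, 128 3x3 int8 filters,
# unit stride/padding/dilation, NCHW input layout, int32 accumulation.
data = tvm.placeholder((1, 64, 56, 56), name='data', dtype='uint8')
kernel = tvm.placeholder((128, 64, 3, 3), name='kernel', dtype='int8')

task = autotvm.task.create(
    'topi_x86_conv2d_NCHWc_int8',
    args=serialize_args([data, kernel, (1, 1), (1, 1), (1, 1),
                         'NCHW', 'int32']),
    target='llvm -mcpu=skylake-avx512')
print(task.config_space)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=20, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('conv2d_int8.log')])
```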


@autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn

# If config is not set, we can reuse the default config for NCHW.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_int8_compute(data,
kernel,
strides,
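
When no tuned entry is found (`cfg.is_fallback`), the branch above reconstructs the unblocked NCHW/OIHW shapes so the default x86 heuristic can seed the config. A sketch of that shape arithmetic with hypothetical blocking values:

```python
# Blocked shapes as unpacked above (hypothetical values): data is
# (n, ic_chunk, ih, iw, ic_bn); the int8 kernel is
# (oc_chunk, ic_chunk, kh, kw, ic_bn//n_elems, oc_bn, n_elems).
n, ic_chunk, ih, iw, ic_bn = 1, 4, 56, 56, 16
oc_chunk, _, kh, kw, _, oc_bn, _ = 8, 4, 3, 3, 4, 16, 4

in_channel = ic_chunk * ic_bn                     # 64
num_filter = oc_chunk * oc_bn                     # 128
data_shape = (n, in_channel, ih, iw)              # (1, 64, 56, 56), plain NCHW
kernel_shape = (num_filter, in_channel, kh, kw)   # (128, 64, 3, 3), plain OIHW
```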