Add packing for int8 1x1 convolution and support the int8 group convolution on X86 #2991

Merged 15 commits on May 22, 2019
Changes from 11 commits
18 changes: 18 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -52,6 +52,24 @@ def schedule_conv2d_nchw(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_nhwc_pack(outs):
"""Schedule for conv2d_nhwc_pack

Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nhwc_pack
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_nhwc(outs):
"""Schedule for conv2d_nhwc
26 changes: 20 additions & 6 deletions topi/python/topi/nn/conv2d.py
@@ -28,8 +28,8 @@

# workload description of conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

@tvm.target.generic_func
def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
@@ -95,19 +95,33 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
return None


def _get_workload(data, kernel, stride, padding, out_dtype):
def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
if data_layout == 'NCHW':
_, CI, IH, IW = [x.value for x in data.shape]
elif data_layout == 'NHWC':
_, IH, IW, CI = [x.value for x in data.shape]
elif data_layout == 'HWCN':
IH, IW, CI, _ = [x.value for x in data.shape]
else:
raise ValueError("not support this layout {} yet".format(data_layout))


if data_layout == 'NHWC':
Review comment (Member):
Shall we use kernel_layout instead? data_layout might not necessarily be bound to kernel_layout.
I'm actually a bit confused by the int8 conv layout: for NHWC data, what kernel layout is it?

Reply (Contributor Author):
I think I was mainly following the data-layout/kernel-layout correspondence here: https://github.com/dmlc/tvm/blob/147ea3b0ca147b527086228d524a2f68f872112d/topi/python/topi/nn/conv2d.py#L284

KH, KW, CO, CIG = [x.value for x in kernel.shape]
else:
CO, CIG, KH, KW = [x.value for x in kernel.shape]

HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
GRPS = CI // CIG
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)


def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
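
For reference, a minimal sketch (not part of the PR, with made-up shapes) of how the extended Workload is populated for a grouped NCHW convolution: the group count is recovered as CI // CIG, where CIG is the per-group input-channel count taken from the OIHW kernel shape, mirroring the unpacking in _get_workload above.

from collections import namedtuple

Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

data_shape = (1, 32, 56, 56)    # N, CI, IH, IW            (NCHW data)
kernel_shape = (64, 8, 3, 3)    # CO, CI // groups, KH, KW (OIHW kernel)

_, CI, IH, IW = data_shape
CO, CIG, KH, KW = kernel_shape
GRPS = CI // CIG                # 32 // 8 == 4 groups

wkl = Workload('uint8', 'int32', IH, IW, CI, GRPS, CO, KH, KW, 1, 1, 1, 1)
print(wkl.groups)               # 4
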
99 changes: 92 additions & 7 deletions topi/python/topi/x86/conv2d.py
@@ -37,7 +37,8 @@

logger = logging.getLogger('topi')

def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
layout='NCHW'):
"""
Get default schedule config for the workload
"""
@@ -46,7 +47,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
from .depthwise_conv2d import _fallback_schedule
_fallback_schedule(cfg, wkl)
else:
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, wkl)
@@ -62,6 +63,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
if layout == 'NCHW':
n, ic, h, w = dshape
oc, _, kh, kw = kshape
elif layout == 'NHWC':
n, h, w, ic = dshape
oc, _, kh, kw = kshape
Review comment (Member):
kh, kw, oc, _ ?

elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
if data.dtype == 'uint8':
@@ -93,12 +97,14 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
cfg.define_knob("unroll_kw", [True, False])


@autotvm.register_topi_compute(conv2d, 'cpu', 'direct')
@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype
padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)

_, _, kh, kw = get_const_tuple(kernel.shape)
Review comment (Member):
So, for data=NHWC & fp32, kernel=HWIO, while for data=NHWC & int8, kernel=OIHW?

Reply (Contributor Author):
Thanks. I will spend some time this week unifying them.

if layout == 'NCHW':
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
if cfg.is_fallback:
@@ -107,7 +113,13 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
padding, dilation, layout, out_dtype)
if layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
if layout == 'NHWC':
elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
if cfg.is_fallback:
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
# specialize for INT8 1X1 conv on X86
return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
padding, dilation, out_dtype)
elif layout == 'NHWC':
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))

Expand Down Expand Up @@ -226,6 +238,58 @@ def traverse(op):
return s


@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
def schedule_conv2d_nhwc_pack(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
n, h, w, c = op.axis
fused = s[op].fuse(n, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_nhwc_pack_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]

args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
# the packed int8 1x1 conv kernel is 5-dim
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
else:
raise ValueError("Only support 1x1 kernel with "
"schedule template.")
Review comment (Member):
Please make the fatal error message more detailed, rather than just "schedule template".

else:
raise ValueError("Not support this data type {} with "
"schedule template.".format(data.dtype))

scheduled_ops.append(op)
traverse(output_op)
return s


@generic.schedule_conv2d_nhwc.register("cpu")
def schedule_conv2d_nhwc(outs):
"""Create schedule for tensors"""
@@ -422,10 +486,13 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8':
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
else:
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group

if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
@@ -449,7 +516,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')

if data.dtype == 'uint8':
if data.dtype == 'uint8' and groups == 1:
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
@@ -468,6 +535,24 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
if data.dtype == 'uint8':
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel//ic_bn
ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
(ic_chunk//groups)+ic_outer,
oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[occ, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")

# else: fp implementation
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
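
A tiny worked example (not part of the PR) of the group-offset indexing used in the int8 group-convolution compute above: an output-channel chunk occ belongs to group g = occ // (oc_chunk // groups), and its input starts at data chunk g * (ic_chunk // groups), which is what the expression occ*oc_bn//(oc_chunk*oc_bn//groups) reduces to.

oc_chunk, ic_chunk, oc_bn, groups = 8, 8, 16, 2

for occ in range(oc_chunk):
    g = occ * oc_bn // (oc_chunk * oc_bn // groups)   # expression as written in the compute
    assert g == occ // (oc_chunk // groups)           # it simplifies to the group id
    first_ic_chunk = g * (ic_chunk // groups)         # added to ic_outer in the data index
    print("occ =", occ, "-> group", g, ", input chunks start at", first_ic_chunk)
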
106 changes: 104 additions & 2 deletions topi/python/topi/x86/conv2d_avx_1x1.py
@@ -20,8 +20,9 @@
import tvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from ..nn.util import infer_pad
from ..util import get_const_tuple
from ..nn.pad import pad
from ..nn.util import infer_pad, get_pad_tuple
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len
@@ -251,3 +252,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
s[O].parallel(parallel_axis)

return s


def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
# more assertion for the shapes
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation

batch, in_height, in_width, in_channel = Input.shape
kernel_h, kernel_w, num_filter, channel = Filter.shape

# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# todo: pad the filter to accommodate the intrinsic

# packing the Filter to let memory access be consecutive for AVX512 intrinsic
# Done in pre-compute stage
packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
name="packed_filter")

rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
return Output
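
To make the packing rule above concrete, here is a small NumPy cross-check (illustrative only, not code from the PR): for a 1x1 kernel with stride 1 and no padding, the compute reduces to a matrix multiply over the input channels, and PackW[a, b, c, d, e] = Filter[a, b, c*16 + d % 16, d // 16 * 4 + e] is just a re-blocking of the filter with output channels in blocks of 16 and input channels in blocks of 4.

import numpy as np

n, h, w, ci, co = 1, 4, 4, 8, 32                       # toy shapes, ci % 4 == 0, co % 16 == 0
data = np.random.randint(0, 255, size=(n, h, w, ci)).astype(np.int32)
filt = np.random.randint(-128, 127, size=(1, 1, co, ci)).astype(np.int32)  # (kh, kw, O, I)

# pack exactly as in _declaration_conv_nhwc_pack
packw = np.empty((1, 1, co // 16, 16 * (ci // 4), 4), dtype=np.int32)
for c in range(co // 16):
    for d in range(16 * (ci // 4)):
        for e in range(4):
            packw[0, 0, c, d, e] = filt[0, 0, c * 16 + d % 16, d // 16 * 4 + e]

# evaluate the packed compute: Output[n, y, x, f] = sum_c data * PackW[0, 0, f//16, (c//4)*16 + f%16, c%4]
out_pack = np.zeros((n, h, w, co), dtype=np.int32)
for f in range(co):
    for c in range(ci):
        out_pack[..., f] += data[..., c] * packw[0, 0, f // 16, (c // 4) * 16 + f % 16, c % 4]

# reference: plain matmul over the channel dimension
out_ref = data.reshape(-1, ci) @ filt[0, 0].T
assert np.array_equal(out_pack.reshape(-1, co), out_ref)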


def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
"""
Defines the schedule for the int8 nhwc layout. For 1x1 conv, it
is a matrix-multiply operation by using nhwc layout. We will do
packing of weight to make the address access be friendly to int8
intrinsic
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1

# assertion to fail the unhandled case
_, _, _, ic_num = get_const_tuple(data.shape)
_, _, _, oc_num = get_const_tuple(conv_out.shape)
assert ic_num % 4 == 0
assert oc_num % 16 == 0

ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ih, iw, ic = s[A].op.axis
d_ic_chunk, d_ic_block = s[A].split(ic, factor=4)
s[A].vectorize(d_ic_block)

C, O = conv_out, last

batch, oh, ow, oc = s[C].op.axis
kh, kw, ic = s[C].op.reduce_axis
# match the x86 intrinsic
ic_outer, ic_inner = s[C].split(ic, factor=4)
oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes)

ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)

pc = dot_16x1x16_int8_int8_int32()
s[C].tensorize(oc_inner, pc)

if C != O:
batch, last_oh, last_ow, last_oc = s[O].op.axis
oc_chunk, oc_block = s[O].split(last_oc, 16)
# did not see a perf improvement from splitting oh/ow here
s[O].vectorize(oc_block)

return s
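
Finally, a rough NumPy model (my reading of the intrinsic, not code from this PR) of what a single dot_16x1x16_int8_int8_int32 call computes after the splits above: 16 int32 output lanes (the oc_inner axis), each accumulating a 4-element uint8 x int8 dot product (the ic_inner axis).

import numpy as np

data = np.random.randint(0, 256, size=4).astype(np.uint8)            # 4 input channels
kernel = np.random.randint(-128, 128, size=(16, 4)).astype(np.int8)  # 16 output x 4 input channels
acc = np.zeros(16, dtype=np.int32)

# one intrinsic call: each output lane i accumulates sum_k data[k] * kernel[i, k]
acc += (kernel.astype(np.int32) * data.astype(np.int32)).sum(axis=1)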