AVX schedule for conv_NCHW[x]c (apache#1143)
* add conv2d_NCHWc compute template

* add conv_NCHWc compute decl and schedules

* allow default avx schedule

* fix lint

* remove unused schedule

* remove import nnvm.reg

* pass schedule object to compute and schedule
yzhliu authored and sergei-mironov committed Aug 8, 2018
1 parent 6baccb0 commit e0ccb78
Showing 4 changed files with 451 additions and 42 deletions.
55 changes: 55 additions & 0 deletions topi/python/topi/nn/conv2d.py
@@ -112,6 +112,23 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
raise ValueError("not support this layout {} yet".format(layout))


@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos):
"""Change Conv2D layout.
Parameters
----------
attrs : nnvm.top.AttrDict
Attributes of current convolution
inputs : nnvm.symbol
Grouped input symbols
tinfos : list
Input shape and dtype
"""
# not to change by default
return None
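# Illustrative sketch, not part of this commit: because this is a
# tvm.target.generic_func, the body above is only the default. A target
# backend overrides it like so (hypothetical skeleton):
#
#     @conv2d_alter_layout.register("cpu")
#     def _alter_conv2d_layout(attrs, inputs, tinfos):
#         ...  # return a replacement nnvm symbol, or None to keep the layout
#
# The x86 backend registers exactly such an override later in this diff.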


def _get_workload(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
@@ -425,6 +442,44 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
name="Conv2dOutput", tag="conv2d_nhwc")
return Output

@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
"""Conv2D operator for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width,
in_channel_block, num_filter_block]
num_filter : int
number of filters, i.e., output channel size
kernel_size : tuple of two ints
[kernel_height, kernel_width]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# search platform specific declaration first
# default declaration
raise ValueError("missing register for topi.nn.conv2d_NCHWc")
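# Illustrative sketch, not part of this commit: with hypothetical block sizes
# ic_bn = oc_bn = 8, a 64->64 channel 3x3 convolution on a 56x56 image would
# be declared as
#
#     data = tvm.placeholder((1, 8, 56, 56, 8), name='data')       # NCHW8c
#     kernel = tvm.placeholder((8, 8, 3, 3, 8, 8), name='kernel')  # OIHW8i8o
#     out = conv2d_NCHWc(data, kernel, num_filter=64, kernel_size=(3, 3),
#                        stride=1, padding=1)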

# map from schedule type to declaration function
_SCH_TO_DECL_FUNC = {
SpatialPack: _spatial_pack,
221 changes: 181 additions & 40 deletions topi/python/topi/x86/conv2d.py
@@ -4,71 +4,112 @@
from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, _get_workload, _get_schedule, _WORKLOADS
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
_get_workload, _get_schedule, Workload

from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
from .conv2d_avx_1x1 import AVXConv1x1Fwd

_AVX_SCH_TO_DECL_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
}

_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
}

@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
_WORKLOADS_AVX = [
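# field order, per the Workload namedtuple imported from topi.nn.conv2d:
# (in_dtype, out_dtype, height, width, in_filter, out_filter,
#  hkernel, wkernel, hpad, wpad, hstride, wstride); e.g. the first entry is
# ResNet's 224x224 stem conv, 3->64 channels, 7x7 kernel, pad 3, stride 2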
# workloads of resnet18_v1 on imagenet
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
Workload('float32', 'float32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
Workload('float32', 'float32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]

fp32_vec_len = 8
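# 8 fp32 lanes fill one 256-bit AVX/AVX2 register; the skylake-avx512 check
# below widens this to 16 for 512-bit registers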
target = tvm.target.current_target(allow_none=False)
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16

_SCHEDULES_AVX_NCHW = [
# float32 resnet-18
_SCHEDULES_AVX = [
# workloads of resnet18_v1 on imagenet
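# schedule fields, per the namedtuple definitions in conv2d_avx_common.py and
# conv2d_avx_1x1.py: AVXConvCommonFwd(ic_bn, oc_bn, reg_n, unroll_kw) and
# AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor)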
AVXConvCommonFwd(3, fp32_vec_len, 28, False),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConvCommonFwd(16, fp32_vec_len, 28, False),
AVXConvCommonFwd(16, fp32_vec_len, 14, False),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConvCommonFwd(16, fp32_vec_len, 14, True),
AVXConvCommonFwd(16, 32, 7, True),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
AVXConvCommonFwd(16, fp32_vec_len, 7, True),
# float32 mobilenet
AVXConvCommonFwd(3, fp32_vec_len, 28, False),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]

sch = _SCHEDULES_AVX_NCHW[idx]
if wkl not in _WORKLOADS_AVX:
if wkl.hkernel == 1 and wkl.wkernel == 1:
return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
idx = _WORKLOADS_AVX.index(wkl)
sch = _SCHEDULES_AVX[idx]
return sch


@conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
_AVX_SCH_TO_DECL_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
}
out_dtype = data.dtype if out_dtype is None else out_dtype
target = tvm.target.current_target(allow_none=False)
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW':
if 'avx' in str(target) and layout == 'NCHW':
sch = _get_schedule(wkl)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
elif layout == 'NCHW':
@@ -81,9 +122,63 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
raise ValueError("not support this layout {} yet".format(layout))


@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfos):
import nnvm.symbol as sym
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
# only optimize for NCHW, groups=1 conv
if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
return None

data = tinfos[0]
kernel = tinfos[1]

import ast
padding = ast.literal_eval(attrs['padding'])
stride = ast.literal_eval(attrs['strides'])

wkl = _get_workload(data, kernel, stride, padding, data.dtype)
sch = _get_schedule_conv(wkl)
is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
ic_bn, oc_bn = sch.ic_bn, sch.oc_bn

new_attrs['layout'] = 'NCHW%dc' % ic_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn

if is_kernel_1x1:
# (oc, ic, h, w) -> (OC, IC, ic, oc, h, w)
new_attrs['kernel_layout'] = 'OI%di%doHW' % (ic_bn, oc_bn)
else:
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
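# Illustrative sketch, not part of this commit: for a (64, 64, 3, 3) OIHW kernel
# with hypothetical ic_bn = oc_bn = 8, the non-1x1 branch above requests layout
# 'OIHW8i8o', i.e. a (8, 8, 3, 3, 8, 8) tensor. In NumPy terms the repacking is
#
#     packed = w.reshape(8, 8, 8, 8, 3, 3).transpose(0, 2, 4, 5, 3, 1)
#
# split oc and ic into (chunk, block) pairs, then move both block axes innermost
# so the vectorized inner loops touch contiguous memory.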


@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype):
_AVX_SCH_TO_DECL_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
}
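# rebuild the equivalent un-packed NCHW workload from the NCHWc shapes so the
# existing schedule lookup (_get_schedule) can be reused unchanged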
n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
kh, kw = kernel_size
wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
stride, padding, out_dtype)
sch = _get_schedule(wkl)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)


@generic.schedule_conv2d_nchw.register(["cpu"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
}
s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target(allow_none=False)

@@ -213,3 +308,49 @@ def traverse(op):

traverse(output_op)
return s


@generic.schedule_conv2d_NCHWc.register(["cpu"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
"""Create schedule for tensors"""
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}
s = tvm.create_schedule([x.op for x in outs])

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)

if 'conv2d_NCHWc' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
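# the lines below recover the pre-packing NCHW data/kernel shapes to rebuild
# the workload that keys the schedule lookup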

n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)

kh, kw = kernel_size
original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)

wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
sch = _get_schedule(wkl)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
kernel, conv_out, outs[0])

traverse(outs[0].op)
return s
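
A minimal end-to-end sketch of the new entry points (not part of this commit; the target string, shapes, and block size of 8 are assumptions for an AVX2 CPU):

import tvm
import topi

# NCHW8c data and OIHW8i8o kernel for a 64->64 channel 3x3 conv on 56x56 input
data = tvm.placeholder((1, 8, 56, 56, 8), name='data')
kernel = tvm.placeholder((8, 8, 3, 3, 8, 8), name='kernel')

with tvm.target.create('llvm -mcpu=core-avx2'):
    conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=64, kernel_size=(3, 3),
                                stride=(1, 1), padding=(1, 1))
    s = topi.generic.schedule_conv2d_NCHWc(64, (3, 3), (1, 1), (1, 1), [conv])
    func = tvm.build(s, [data, kernel, conv])

This workload (56x56, 64->64 channels, 3x3, pad 1, stride 1) appears in _WORKLOADS_AVX, so _get_schedule_conv resolves it to a pre-tuned AVXConvCommonFwd entry; any workload not in the table falls back to _get_default_schedule.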