Add packing for int8 1x1 convolution and support the int8 group convolution on X86 #2991
@@ -37,7 +37,8 @@
 logger = logging.getLogger('topi')

-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
     """
     Get default schedule config for the workload
     """
@@ -46,7 +47,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
         from .depthwise_conv2d import _fallback_schedule
         _fallback_schedule(cfg, wkl)
     else:
-        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
+        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
         is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
         if is_kernel_1x1:
             conv2d_avx_1x1._fallback_schedule(cfg, wkl)
@@ -62,6 +63,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     if layout == 'NCHW':
         n, ic, h, w = dshape
         oc, _, kh, kw = kshape
+    elif layout == 'NHWC':
+        n, h, w, ic = dshape
+        oc, _, kh, kw = kshape
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
         if data.dtype == 'uint8':
@@ -93,12 +97,14 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     cfg.define_knob("unroll_kw", [True, False])


-@autotvm.register_topi_compute(conv2d, 'cpu', 'direct')
+@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
 def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)

+    _, _, kh, kw = get_const_tuple(kernel.shape)
Review comment: So, for data = NHWC & fp32 the kernel is HWIO, while for data = NHWC & int8 the kernel is OIHW?
Reply: Thanks. I will spend some time this week unifying them.
     if layout == 'NCHW':
         _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
         if cfg.is_fallback:
@@ -107,7 +113,13 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
                             padding, dilation, layout, out_dtype)
     if layout == 'HWCN':
         return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
+    elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
+        if cfg.is_fallback:
+            _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
+        # specialize for INT8 1X1 conv on X86
+        return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
+                                                          padding, dilation, out_dtype)
+    elif layout == 'NHWC':
         return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
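For reference, here is a minimal sketch of how the new NHWC int8 1x1 path could be exercised. The shapes, the target string, and the 4-D int8 kernel with the 1x1 spatial dims in its last two axes (matching the `_, _, kh, kw` unpacking above) are assumptions for illustration, not taken from this PR:

import tvm
import topi

# Hypothetical shapes: NHWC uint8 activations, 4-D int8 kernel whose last
# two axes are the 1x1 spatial dims.
data = tvm.placeholder((1, 56, 56, 64), name='data', dtype='uint8')
kernel = tvm.placeholder((64, 64, 1, 1), name='kernel', dtype='int8')

with tvm.target.create('llvm -mcpu=skylake-avx512'):
    # With a 1x1 int8 kernel this should dispatch to
    # conv2d_avx_1x1._declaration_conv_nhwc_pack via the branch above.
    out = topi.nn.conv2d(data, kernel, strides=(1, 1), padding=(0, 0),
                         dilation=(1, 1), layout='NHWC', out_dtype='int32')
    s = topi.generic.schedule_conv2d_nhwc_pack([out])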
@@ -226,6 +238,58 @@ def traverse(op):
     return s


+@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
+def schedule_conv2d_nhwc_pack(cfg, outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+    output_op = outs[0].op
+    scheduled_ops = []
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            else:  # inject custom schedule
+                if len(op.axis) == 4:  # schedule bias + bn + relu
+                    n, h, w, c = op.axis
+                    fused = s[op].fuse(n, h, w)
+                    s[op].parallel(fused)
+                    s[op].vectorize(c)
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+
+        if 'conv2d_nhwc_pack_int8' in op.tag:
+            conv_out = op.output(0)
+            kernel = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            data = data_vec.op.input_tensors[0] \
+                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+                else data_vec
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            args = [s, cfg, data_vec, conv_out, outs[0]]
+            if data.dtype == 'uint8':
+                # int8 conv kernel is 7-dim
+                kh, kw, _, _, _ = get_const_tuple(kernel.shape)
+                if kh == 1 and kw == 1:
+                    conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
+                else:
+                    raise ValueError("Only support 1x1 kernel with "
+                                     "schedule template.")
Review comment: Please make the fatal error message more detailed than just "schedule template". (A possible wording is sketched after this hunk.)
+            else:
+                raise ValueError("Not support this data type {} with "
+                                 "schedule template.".format(data.dtype))
+
+        scheduled_ops.append(op)
+
+    traverse(output_op)
+    return s
+
+
 @generic.schedule_conv2d_nhwc.register("cpu")
 def schedule_conv2d_nhwc(outs):
     """Create schedule for tensors"""
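Responding to the review comment above, a more descriptive wording for the two error messages might look like the hypothetical helper below (an illustration only, not the code that was merged; the helper name is made up):

def _check_nhwc_pack_int8(kh, kw, dtype):
    """Hypothetical helper sketching more detailed error messages."""
    if dtype != 'uint8':
        raise ValueError("conv2d_nhwc_pack_int8 schedule expects uint8 input "
                         "data, got dtype %s" % dtype)
    if kh != 1 or kw != 1:
        raise ValueError("conv2d_nhwc_pack_int8 schedule only supports 1x1 "
                         "kernels, got a %dx%d kernel; use the NCHWc int8 "
                         "schedule for other kernel sizes" % (kh, kw))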
@@ -422,10 +486,13 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     if data.dtype == 'uint8':
-        oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
+            get_const_tuple(kernel.shape)
     else:
-        oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
+            get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
+    groups = ic_chunk // ic_chunk_group

     if cfg.is_fallback:
         _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
@@ -449,7 +516,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')

-    if data.dtype == 'uint8':
+    if data.dtype == 'uint8' and groups == 1:
         assert out_dtype == "int32", \
             "INT8 convolution requires input dtype = uint8 and output dtype=int32"
         # Intel performs dot product of 2 "4" Int8 values
@@ -468,6 +535,24 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                        oc_block, ic_s_inner].astype(out_dtype),
                                    axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                            name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+    if data.dtype == 'uint8':
+        # for int8 group conv support
+        n_elems = 4
+        ic_chunk = in_channel//ic_bn
+        ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
+        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
+        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+        oshape = (n, oc_chunk, out_height, out_width, oc_bn)
+        return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
+                           tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
+                                            (ic_chunk//groups)+ic_outer,
+                                            oh*HSTR+kh, ow*WSTR+kw,
+                                            ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
+                                   kernel[occ, ic_outer, kh, kw, ic_f_inner,
+                                          oc_block, ic_s_inner].astype(out_dtype),
+                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")

     # else: fp implementation
     return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                        tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
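In the group-convolution compute above, output-channel chunk occ is mapped to the input-channel chunks of its group by (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups) + ic_outer, with groups derived from the packed kernel as groups = ic_chunk // ic_chunk_group. A small standalone check of that arithmetic, using hypothetical sizes not taken from the PR, is sketched below:

# Sanity-check the group-offset arithmetic with hypothetical sizes:
# 4 output-channel chunks of 16 lanes, 4 input-channel chunks, 2 groups
# (e.g. a packed kernel with ic_chunk_group = 2 gives groups = 4 // 2 = 2).
oc_chunk, oc_bn = 4, 16
ic_chunk, groups = 4, 2

for occ in range(oc_chunk):
    group = occ * oc_bn // (oc_chunk * oc_bn // groups)   # group of this oc chunk
    ic_base = group * (ic_chunk // groups)                # first ic chunk it reads
    print("occ=%d -> group=%d, ic chunks %d..%d"
          % (occ, group, ic_base, ic_base + ic_chunk // groups - 1))

# occ=0 -> group=0, ic chunks 0..1
# occ=1 -> group=0, ic chunks 0..1
# occ=2 -> group=1, ic chunks 2..3
# occ=3 -> group=1, ic chunks 2..3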
Review comment: Shall we use kernel_layout instead, since data_layout is not necessarily bound to kernel_layout? I'm actually a bit confused by the int8 conv layout: for NHWC data, what kernel layout is it?
Reply: I think I was mainly following the data layout / kernel layout correspondence here: https://github.com/dmlc/tvm/blob/147ea3b0ca147b527086228d524a2f68f872112d/topi/python/topi/nn/conv2d.py#L284
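For readers following this thread, the data-layout / kernel-layout pairing conventionally used by the TOPI conv2d frontends is sketched below; this is an illustration of the convention under discussion, an assumption on my part rather than a quote from the linked file:

# Conventional TOPI pairing of conv2d data layout and kernel layout.
# The NHWC int8 1x1 path in this PR unpacks kh/kw from the last two kernel
# axes (OIHW-style), which is what the review comment above questions.
DATA_TO_KERNEL_LAYOUT = {
    'NCHW': 'OIHW',   # kernel shape (out_channel, in_channel, kh, kw)
    'NHWC': 'HWIO',   # kernel shape (kh, kw, in_channel, out_channel)
}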