diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index d652977924ca..70594678c08e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -548,6 +548,34 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target): OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_compute("nn.contrib_conv2d_NCHWc_int8") +def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target): + """Compute definition of contrib_conv2d_NCHWc_int8""" + # pylint: disable=assignment-from-no-return + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + data_layout = attrs.get_str("data_layout") + out_layout = attrs.get_str("out_layout") + out_dtype = attrs.get_str("out_dtype") + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + + out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation, + data_layout, out_layout, out_dtype) + return [out] + + +@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8") +def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target): + """Schedule definition of contrib_conv2d_NCHWc_int8""" + with target: + return topi.generic.schedule_conv2d_NCHWc_int8(outs) + + +reg.register_pattern("nn.contrib_conv2d_NCHWc_int8", + OpPattern.OUT_ELEMWISE_FUSABLE) + + @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc") def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): """Compute definition of depthwise conv2d NCHWc""" diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 19c50d6dc700..6d48f678f134 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1340,6 +1340,72 @@ def contrib_depthwise_conv2d_nchwc(data, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) +def contrib_conv2d_nchwc_int8(data, + kernel, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, 
data_layout="NCHW8c", + kernel_layout="OIHW", + out_layout="", + out_dtype=""): + r"""Variant of 2D convolution. It handles only int8 inputs. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output, following a specialized + NCHWc data layout. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + kernel : tvm.relay.Expr + The kernel expression. + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + dilation : tuple of int, optional + Specifies the dilation rate to be used for dilated convolution. + + groups : int, optional + Number of groups for grouped convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial dimensions of the convolution kernel. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : str, optional + Layout of the output. By default, out_layout is the same as data_layout. + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.contrib_conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + def contrib_conv2d_winograd_weight_transform(weight, tile_size): r"""Weight Transformation part for 2D convolution with winograd algorithm. diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 2f59fb9db19c..270451dc31c4 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -570,6 +570,54 @@ weight transformation in advance. 
.set_support_level(10) .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); +// Positional relay function to create conv2d NCHWc int8 operator +// used by frontend FFI. +Expr MakeConv2DNCHWcInt8(Expr data, + Expr kernel, + Array<IndexExpr> strides, + Array<IndexExpr> padding, + Array<IndexExpr> dilation, + int groups, + IndexExpr channels, + Array<IndexExpr> kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node<Conv2DAttrs>(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = channels; + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc_int8"); + return CallNode::make(op, {data, kernel}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc_int8") +.set_body_typed(MakeConv2DNCHWcInt8); + + +RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8") +.describe(R"code(Compute conv2d with NCHWc data layout with int8 inputs. +- **data**: Input is 5D packed tensor. +- **weight**: 7D packed tensor. + +- **out**: Output is 5D packed tensor +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.Conv2D") +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(10) +.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>) +.set_attr<FInferCorrectLayout>("FInferCorrectLayout", + Conv2DInferCorrectLayout<Conv2DAttrs>); // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. 
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 8fbedec3fef1..c2cb2b27c5f1 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -107,6 +107,25 @@ def schedule_conv2d_NCHWc(outs): return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_conv2d_NCHWc_int8(outs): + """Schedule for conv2d_NCHW[x]c_int8 + + Parameters + ---------- + outs : Array of Tensor + The computation graph description of conv2d_NCHWc_int8 + in the format of an array of tensors. + The number of filter, i.e., the output channel. + + Returns + ------- + sch : Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.generic_func def schedule_conv2d_winograd_weight_transform(outs): """Schedule for weight transformation of winograd diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 05580c894b53..e2a4720f8b7e 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -398,27 +398,75 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou output : tvm.Tensor 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] """ - # search platform specific declaration first - # default declaration + + return conv2d_NCHWc_compute(data, + kernel, + stride, + padding, + dilation, + layout, + out_layout, + out_dtype) + + +def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype): + """Conv2D operator compute for nChw[x]c layout. 
+ + Parameters + ---------- + data : tvm.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, + in_channel_block, num_filter_block] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload - pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, - (dilated_kernel_h, - dilated_kernel_w)) - HPAD = pad_top + pad_down - WPAD = pad_left + pad_right - HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) - dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) - assert (dh, dw) == (1, 1), "Does not support dilation" + HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ + else (dilation, dilation) n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + target = tvm.target.current_target(allow_none=False) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ + get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group + + 
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 # output shape - out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1 - out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1 + out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) # DOPAD @@ -433,13 +481,194 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou kw = tvm.reduce_axis((0, kernel_width), name='kw') return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, - ic%ic_bn].astype(out_dtype) * - kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block], + tvm.sum(data_pad[n, + ic // ic_bn, + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + ic % ic_bn].astype(out_dtype) + * kernel[oc_chunk, + ic // ic_bn, + kh, + kw, + ic % ic_bn, + oc_block], axis=[ic, kh, kw]), name='conv2d_NCHWc', tag="conv2d_NCHWc") +@tvm.target.generic_func +def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layout, + out_dtype='int32'): + """Conv2D operator for nChw[x]c layout. 
+ + Parameters + ---------- + data : tvm.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.Tensor + 7-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4, + num_filter_block, 4] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + return conv2d_NCHWc_int8_compute(data, + kernel, + strides, + padding, + dilation, + layout, + out_layout, + out_dtype) + + +def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, out_layout, + out_dtype='int32'): + """Conv2D operator for nChw[x]c layout. 
+ + Parameters + ---------- + data : tvm.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.Tensor + 7-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4, + num_filter_block, 4] + + strides : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + + dilation : int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ + else (dilation, dilation) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + target = tvm.target.current_target(allow_none=False) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ + get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group + + # Since the weight is 7-D and the last element size is 4, we have to + # check ic_bn should be a multiple of 4. + # Similarly, oc_bn has to be a multiple of 16. 
+ + assert ic_bn % 4 == 0 + assert oc_bn % 16 == 0 + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + # output shape + out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + + # DOPAD + DOPAD = (HPAD != 0 or WPAD != 0) + if DOPAD: + data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") + else: + data_pad = data + + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') + + if groups == 1: + n_elems = 4 + ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: + tvm.sum(data_pad[n, + ic_outer, + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) + * kernel[oc_chunk, + ic_outer, + kh, + kw, + ic_f_inner, + oc_block, + ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # for int8 group conv support + n_elems = 4 + ic_chunk = in_channel//ic_bn + ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: + tvm.sum(data_pad[n, + (occ * oc_bn // (oc_chunk * oc_bn // groups)) + * (ic_chunk // groups) + ic_outer, + oh * HSTR + kh, + ow * WSTR + kw, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) + * kernel[occ, + ic_outer, + kh, + kw, + ic_f_inner, + 
oc_block, + ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + + def conv2d_winograd_weight_transform(kernel, tile_size): """Weight transformation for winograd diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 85798270ba7f..8594ac3431c9 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -6,6 +6,7 @@ from .binarize_pack import schedule_binarize_pack from .binary_dense import schedule_binary_dense from .nn import * +from .conv2d_int8 import * from .injective import * from .pooling import schedule_pool, schedule_adaptive_pool from .bitserial_conv2d import schedule_bitserial_conv2d diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index a5b9cc99c2ec..766fc34d3883 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -263,58 +263,6 @@ def traverse(op): traverse(outs[0].op) return s - -@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct']) -def schedule_conv2d_nhwc_pack(cfg, outs): - """Create schedule for tensors""" - s = tvm.create_schedule([x.op for x in outs]) - output_op = outs[0].op - scheduled_ops = [] - - def traverse(op): - """Traverse operators from computation graph""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(op.tag): - if op not in s.outputs: - s[op].compute_inline() - else: # inject custom schedule - if len(op.axis) == 4: # schedule bias + bn + relu - n, h, w, c = op.axis - fused = s[op].fuse(n, h, w) - s[op].parallel(fused) - s[op].vectorize(c) - for tensor in op.input_tensors: - if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: - traverse(tensor.op) - - if 'conv2d_nhwc_pack_int8' in op.tag: - conv_out = op.output(0) - kernel = conv_out.op.input_tensors[1] - data_vec = conv_out.op.input_tensors[0] - data = 
data_vec.op.input_tensors[0] \ - if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ - else data_vec - if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: - data_pad = data - data = data_pad.op.input_tensors[0] - - args = [s, cfg, data_vec, conv_out, outs[0]] - if data.dtype == 'uint8': - kh, kw, _, _, _ = get_const_tuple(kernel.shape) - if kh == 1 and kw == 1: - conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) - else: - raise ValueError("Only support 1x1 kernel with " - "schedule_conv2d_nhwc_pack.") - else: - raise ValueError("Not support this data type {} with " - "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype)) - - scheduled_ops.append(op) - traverse(output_op) - return s - - @generic.schedule_conv2d_nhwc.register("cpu") def schedule_conv2d_nhwc(outs): """Create schedule for tensors""" @@ -482,51 +430,53 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc) dispatch_ctx.update(target, new_workload, cfg) - else: - if _is_int8_hw_support(data.dtype, kernel.dtype, target): - # Convert kernel data layout from 4D to 7D - n_elems = 4 - out_channel, _, kh, kw = get_const_tuple(kernel.shape) - data_expr, kernel_expr = inputs - kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0)) - kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) - kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) - kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn, - in_channel//ic_bn, ic_bn)) - kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn, - in_channel//ic_bn, ic_bn//n_elems, n_elems)) - kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) - copy_inputs = [data_expr, kernel_OIHWioe] - # Store altered operator's config - new_kernel = 
tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn, - in_channel//ic_bn, ic_bn//n_elems, - n_elems)) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, - new_attrs[layout_name], new_attrs['out_layout'], out_dtype], - conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) - else: - out_channel, _, kh, kw = get_const_tuple(kernel.shape) - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - # Store altered operator's config - new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], - new_attrs['out_layout'], out_dtype], conv2d_NCHWc) - dispatch_ctx.update(target, new_workload, cfg) - - if is_depthwise: if F.__name__ == 'nnvm.symbol': logging.warning("Use native layout for depthwise convolution on NNVM.") return None return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs) - else: + + if _is_int8_hw_support(data.dtype, kernel.dtype, target): + # Convert kernel data layout from 4D to 7D + n_elems = 4 + out_channel, _, kh, kw = get_const_tuple(kernel.shape) + data_expr, kernel_expr = inputs + kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0)) + kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) + kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) + kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn)) + kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn//n_elems, n_elems)) + kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) + copy_inputs = [data_expr, kernel_OIHWioe] + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, 
ic_bn//n_elems, + n_elems)) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, + new_attrs[layout_name], new_attrs['out_layout'], out_dtype], + conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) if F.__name__ == 'nnvm.symbol': - return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) - return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) + logging.warning("Use native layout for int8 convolution on NNVM.") + return None + return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs) + + out_channel, _, kh, kw = get_const_tuple(kernel.shape) + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) + + if F.__name__ == 'nnvm.symbol': + return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) @conv2d_infer_layout.register("cpu") @@ -549,95 +499,27 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype): # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload - HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) - HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ - else (dilation, dilation) - n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - target = tvm.target.current_target(allow_none=False) - if _is_int8_hw_support(data.dtype, 
kernel.dtype, target): - oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ - get_const_tuple(kernel.shape) - else: - oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn - groups = ic_chunk // ic_chunk_group - - dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + # If no config was set, we can fallback to NCHW config. if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), dtype=kernel.dtype), strides, padding, out_dtype) - # output shape - out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 - out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 - oshape = (n, oc_chunk, out_height, out_width, oc_bn) - - # DOPAD - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - else: - data_pad = data - - ic = tvm.reduce_axis((0, in_channel), name='ic') - kh = tvm.reduce_axis((0, kernel_height), name='kh') - kw = tvm.reduce_axis((0, kernel_width), name='kw') - - if _is_int8_hw_support(data.dtype, kernel.dtype, target) and groups == 1: - assert out_dtype == "int32", \ - "INT8 convolution requires input dtype = uint8 and output dtype=int32" - # Intel performs dot product of 2 "4" Int8 values - # Current implementation requires ic_bn to be a multiple of 4 - n_elems = 4 - assert ic_bn % n_elems == 0 - - ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer') - ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') - ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') - return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh*dilation_h, - ow*WSTR+kw*dilation_w, - ic_f_inner * n_elems 
+ ic_s_inner] - .astype(out_dtype) * - kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, - oc_block, ic_s_inner].astype(out_dtype), - axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), - name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") - - if _is_int8_hw_support(data.dtype, kernel.dtype, target): - # for int8 group conv support - n_elems = 4 - ic_chunk = in_channel//ic_bn - ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') - ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') - ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') - oshape = (n, oc_chunk, out_height, out_width, oc_bn) - return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: - tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\ - (ic_chunk//groups)+ic_outer, - oh*HSTR+kh, ow*WSTR+kw, - ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * - kernel[occ, ic_outer, kh, kw, ic_f_inner, - oc_block, ic_s_inner].astype(out_dtype), - axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), - name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") - - # else: fp implementation - return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh*dilation_h, - ow*WSTR+kw*dilation_w, - ic%ic_bn].astype(out_dtype) * - kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block], - axis=[ic, kh, kw]), - name='conv2d_NCHWc', tag="conv2d_NCHWc") + return nn.conv2d_NCHWc_compute(data, + kernel, + strides, + padding, + dilation, + layout, + out_layout, + out_dtype) @autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct']) @@ -669,19 +551,11 @@ def traverse(op): args = [s, cfg, data_vec, conv_out, outs[0]] target = tvm.target.current_target(allow_none=False) - if _is_int8_hw_support(data.dtype, kernel.dtype, target): - # int8 conv kernel is 7-dim - _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) - if kh == 1 and kw == 1: - conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args) - else: - 
conv2d_avx_common._schedule_conv_NCHWc_int8(*args) + _, _, kh, kw, _, _, = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_NCHWc(*args) else: - _, _, kh, kw, _, _, = get_const_tuple(kernel.shape) - if kh == 1 and kw == 1: - conv2d_avx_1x1._schedule_conv_NCHWc(*args) - else: - conv2d_avx_common._schedule_conv_NCHWc(*args) + conv2d_avx_common._schedule_conv_NCHWc(*args) scheduled_ops.append(op) diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py new file mode 100644 index 000000000000..d02ddaef389a --- /dev/null +++ b/topi/python/topi/x86/conv2d_int8.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on x86""" + +import tvm +from tvm import autotvm +from .. import generic, tag +from ..util import get_const_tuple +from ..nn.conv2d import conv2d_NCHWc_int8 +from .. import nn +from .conv2d import _get_default_config +from . 
import conv2d_avx_1x1, conv2d_avx_common + +@autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct') +def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides, + padding, dilation, layout, out_layout, out_dtype): + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \ + get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + + # If config is not set, we can reuse the default config for NCHW. + if cfg.is_fallback: + _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), + dtype=kernel.dtype), + strides, padding, out_dtype) + return nn.conv2d_NCHWc_int8_compute(data, + kernel, + strides, + padding, + dilation, + layout, + out_layout, + out_dtype) + + +@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, 'cpu', ['direct']) +def _schedule_conv2d_NCHWc_int8(cfg, outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv2d_NCHWc_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + target = 
tvm.target.current_target(allow_none=False) + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args) + else: + conv2d_avx_common._schedule_conv_NCHWc_int8(*args) + + scheduled_ops.append(op) + + traverse(outs[0].op) + return s + +@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct']) +def schedule_conv2d_nhwc_pack(cfg, outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 4: # schedule bias + bn + relu + n, h, w, c = op.axis + fused = s[op].fuse(n, h, w) + s[op].parallel(fused) + s[op].vectorize(c) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv2d_nhwc_pack_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + if data.dtype == 'uint8': + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) + else: + raise ValueError("Only support 1x1 kernel with " + "schedule_conv2d_nhwc_pack.") + else: + raise ValueError("Not support this data type {} with " + "schedule_conv2d_nhwc_pack. 
Only support int8".format(data.dtype)) + + scheduled_ops.append(op) + traverse(output_op) + return s