diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index abe63e583ddc..5f1ee2f31cc5 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -187,6 +187,17 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { + int tile_rows; + int tile_cols; + + TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") { + TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm."); + TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm."); + } +}; + /*! \brief Attributes used in convolution operators with winograd algorithm */ struct Conv2DWinogradAttrs : public tvm::AttrsNode { int tile_size; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 1c76f57a6343..564d6f762b3f 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -446,6 +446,23 @@ def compute_mirror_pad(attrs, inputs, out_dtype): reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) +# conv2d_gemm related operators +reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform", + strategy.conv2d_gemm_without_weight_transform_strategy) +reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform") +def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype): + """Compute definition of contrib_conv2d_gemm_weight_transform""" + out = topi.nn.conv2d_gemm_weight_transform( + inputs[0], attrs.tile_rows, attrs.tile_cols) + return [out] + +reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform", + strategy.schedule_conv2d_gemm_weight_transform) +reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 34d07dce2863..3c47cf7919b5 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2046,6 +2046,74 @@ def contrib_conv2d_winograd_without_weight_transform(data, kernel_layout, out_layout, out_dtype) +def contrib_conv2d_gemm_without_weight_transform(data, + weight, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="", + out_dtype=""): + r"""2D convolution with gemm algorithm. + + The basic parameters are the same as the ones in vanilla conv2d. + It assumes the weight is pre-transformed by nn.contrib_conv2d_gemm_weight_transform + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + dilation : tuple of int, optional + Specifies the dilation rate to be used for dilated convolution. + + groups : int, optional + Number of groups for grouped convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : str, optional + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # convert 2-way padding to 4-way padding + padding = get_pad_tuple2d(padding) + return _make.contrib_conv2d_gemm_without_weight_transform( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + def contrib_conv2d_nchwc(data, kernel, strides=(1, 1), @@ -2204,6 +2272,29 @@ def contrib_conv2d_winograd_weight_transform(weight, return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) +def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols): + r"""Weight Transformation part for 2D convolution with gemm algorithm. + + We separate this as a single op to enable pre-compute for inference. + Use this together with nn.contrib_conv2d_gemm_without_weight_transform + + Parameters + ---------- + weights : tvm.relay.Expr + The weight expressions. + tile_rows: int + Tile rows of the weight transformation for ConvGemm. + tile_cols: int + Tile columns of the weight transformation for ConvGemm. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols) + + def contrib_conv3d_winograd_weight_transform(weight, tile_size): r"""Weight Transformation part for 3D convolution with winograd algorithm. diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 6bdec67617e1..d682aad63bec 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -112,6 +112,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd), name='conv2d_direct_simd.micro_dev') elif kernel_layout == "HWIO": + is_aarch64 = "aarch64" in str(isa.target) + + if is_aarch64 and data.dtype in ["int8", "uint8"]: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized), + name="conv2d_NHWC_quantized.arm_cpu") + strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), @@ -246,6 +254,40 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out format(layout)) return strategy +def wrap_compute_conv2d_gemm(topi_compute): + """wrap topi compute for conv2d_gemm""" + + def _compute_conv2d_gemm(attrs, inputs, out_type): + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + out_dtype = attrs.get_str("out_dtype") + channels = attrs['channels'] + kernel_size = attrs['kernel_size'] + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype + return [topi_compute(inputs[0], inputs[1], strides, padding, + dilation, out_dtype, kernel_size, channels)] + + return _compute_conv2d_gemm + +@conv2d_gemm_without_weight_transform_strategy.register("arm_cpu") +def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target): + """conv2d_winograd_without_weight_transfrom arm cpu strategy""" + layout = attrs.data_layout + data = inputs[0] + strategy = _op.OpStrategy() + + if layout == "NHWC" and data.dtype in ['int8', 'uint8']: + strategy.add_implementation( + wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized), + name="conv2d_NHWC_quantized_without_transform.arm_cpu") + else: + raise RuntimeError( + "Unsupported conv2d_gemm_without_weight_transform layout {0} with datatype {1}". + format(layout, data.dtype)) + return strategy + @conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"]) def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d_transpose arm cpu strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index b1fb421c3e2e..a0dd6bfe7b15 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -266,6 +266,12 @@ def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, t """conv2d_winograd_without_weight_transfrom generic strategy""" raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform") +# conv2d_gemm_without_weight_transform +@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy") +def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target): + """conv2d_gemm_without_weight_transfrom generic strategy""" + raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform") + # conv2d_winograd_weight_transform @generic_func def schedule_conv2d_winograd_weight_transform(attrs, outs, target): @@ -280,6 +286,13 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): with target: return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs) +# conv2d_gemm_weight_transform +@generic_func +def schedule_conv2d_gemm_weight_transform(attrs, outs, target): + """Schedule conv2d_gemm_weight_transform""" + with target: + return topi.generic.schedule_conv2d_gemm_weight_transform(outs) + # deformable_conv2d def wrap_compute_deformable_conv2d(topi_compute): """wrap deformable_conv2d topi compute""" diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index d3b0e44a1a13..72462141258c 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -237,6 +237,11 @@ def is_fast_int8_on_arm(): target = tvm.target.Target.current(allow_none=False) return '+v8.2a,+dotprod' in ' '.join(target.options) +def is_aarch64_arm(): + """ Checks whether we are compiling for an AArch64 target. """ + target = tvm.target.Target.current(allow_none=False) + return 'aarch64' in ' '.join(target.options) + ######################## # ARM CPU legalizations. ######################## @@ -244,10 +249,11 @@ def is_fast_int8_on_arm(): @qnn_conv2d_legalize.register('arm_cpu') def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): # ARM prefers the dtypes to be same. - if is_fast_int8_on_arm(): + if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm(): return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) + @qnn_dense_legalize.register('arm_cpu') def _qnn_dense_legalize_arm_cpu(attrs, inputs, types): # ARM prefers the dtypes to be same. diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 6c6eb1ecb8b2..f63c48915f25 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -77,6 +77,26 @@ Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array st return Call(op, {data, weight}, Attrs(attrs), {}); } +template +Expr MakeConvGemm(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, std::string kernel_layout, + std::string out_layout, DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; @@ -84,6 +104,14 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_ return Call(op, {weight}, Attrs(attrs), {}); } +Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) { + auto attrs = make_object(); + attrs->tile_rows = tile_rows; + attrs->tile_cols = tile_cols; + const Op& op = Op::Get(op_name); + return Call(op, {weight}, Attrs(attrs), {}); +} + template Expr MakeConvTranspose(Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, @@ -504,6 +532,60 @@ weight transformation in advance. .set_support_level(10) .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); +// relay.nn.contrib_conv2d_gemm_without_weight_transform +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform") + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConvGemm( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform"); + }); + +RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform") + .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout. + This operator assumes the weight tensor is already pre-transformed by + nn.contrib_conv2d_gemm_weight_transform. + +- **data**: Input is 4D array of shape (batch_size, height, width, in_channels) +- **weight**: Any shape + We do not check the shape for this input tensor. Since different backend + has different layout strategy. + +- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DGemm", Conv2DGemmRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + +// relay.nn.contrib_conv2d_gemm_weight_transform + +TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs); + +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform") + .set_body_typed([](Expr weights, int tile_rows, int tile_cols) { + return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols, + "nn.contrib_conv2d_gemm_weight_transform"); + }); + +RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform") + .describe(R"code(Weight transformation of GEMM convolution algorithm. + +Separate this into another operator in order to enable Precompute Pass to compute the +weight transformation in advance. + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weights", "Tensor", "The weights tensor.") + .set_support_level(10) + .add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel); + // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 0c5b20a153cf..f53f4e0454a4 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -383,6 +383,65 @@ inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_i return true; } +// Gemm convolution shape relations +// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. +// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and +// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] +// matrix that we call W_interleaved_t. +// +// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed +// for tile_rows = 4 and tile_cols = 16 +// +// W[0,0,:,:] W_interleaved_t[0,0,:,:] +// +-------------------------------+ +----------------------------------- + +// |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]| +// |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]| +// |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]| +// | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]| +// | ... ... ... ... | +------------------------------------+ +// |W[15,0] W[15,1] W[15,2] W[15,3]| +// +-------------------------------+ +// +// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements +// at the time, we should set tile_cols = k. +// Tile rows is connected with the number of registers available for the given target. +// +inline bool Conv2DGemmWeightTransformRel(const Array& types, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* weight = types[0].as(); + if (weight == nullptr) return false; + + const ConvGemmWeightTransformAttrs* param = attrs.as(); + CHECK(param != nullptr); + int n = param->tile_rows; + int k = param->tile_cols; + + CHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout"; + + const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2]; + const auto N = weight->shape[3]; + + auto K_mod_k = indexmod(K, k); + auto N_mod_n = indexmod(N, n); + + auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32))); + auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32))); + + const auto N_padded = N + pad_N; + const auto K_padded = K + pad_K; + + Array oshape{ + indexdiv(N_padded, n), + indexdiv(K_padded, k), + n, + k, + }; + + reporter->Assign(types[1], TensorType(oshape, weight->dtype)); + return true; +} + inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); @@ -519,6 +578,78 @@ bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& at return true; } +template +bool Conv2DGemmRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + static const Layout kNHWC("NHWC"); + static const Layout kHWIO("HWIO"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NHWC." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from HWIO." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NHWC." + << " But got " << out_layout; + + Array dshape_nhwc = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + CHECK(param->kernel_size.defined() && param->channels.defined()) + << "The kernel size and channels of a Conv must be set or inferred by previous pass"; + + CHECK_EQ(param->kernel_size.size(), 2); + CHECK_EQ(param->dilation.size(), 2); + + channels = param->channels; + dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + + // NOTE: Do not check weight shape here! + + // dilation + Array oshape({dshape_nhwc[0], 0, 0, channels}); + + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + if (!dshape_nhwc[2].as()) { + oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1); + } else { + oshape.Set(1, dshape_nhwc[1]); + } + if (!dshape_nhwc[3].as()) { + oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1); + } else { + oshape.Set(2, dshape_nhwc[2]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + template bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py index 3206168d51bd..99fdf21d5bc0 100644 --- a/topi/python/topi/arm_cpu/conv2d_alter_op.py +++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py @@ -59,10 +59,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - # We only perform layout alteration for NCHW data layout. - if data_layout == "NHWC": - return None - # Extract data types data_tensor, kernel_tensor = tinfos data_dtype = data_tensor.dtype @@ -70,6 +66,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): idxd = tvm.tir.indexdiv + # We don't perform layout alteration for NHWC layout with real data types + if data_layout == "NHWC" and data_dtype not in ['uint8', 'int8']: + return None + if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu": assert data_layout == "NCHW" and kernel_layout == "OIHW" N, CI, H, W = get_const_tuple(data.shape) @@ -88,21 +88,27 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.conv2d(*inputs, **new_attrs) if topi_tmpl == "conv2d_nhwc_spatial_pack.arm_cpu": + assert (data.dtype == 'int8' and kernel.dtype == 'int8' or + data.dtype == 'uint8' and kernel.dtype == 'uint8') + assert data_layout == "NHWC" and kernel_layout == "HWIO" - N, H, W, CI = get_const_tuple(data.shape) - KH, KW, _, CO = get_const_tuple(kernel.shape) - VC = cfg['tile_co'].size[-1] - new_attrs['kernel_layout'] = 'OHWI%do' % VC + data_expr, kernel_expr = inputs + + data_int16 = relay.cast(data_expr, dtype='int16') + kernel_int16 = relay.cast(kernel_expr, dtype='int16') + + new_attrs = {k : attrs[k] for k in attrs.keys()} + + new_data = te.placeholder(data.shape, 'int16') + new_kernel = te.placeholder(kernel.shape, 'int16') - new_data = data - new_kernel = te.placeholder((idxd(CO, VC), KH, KW, CI, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, out_dtype], - "conv2d_nhwc_spatial_pack.arm_cpu") + 'conv2d_nhwc_spatial_pack.arm_cpu') dispatch_ctx.update(target, new_workload, cfg) - return relay.nn.conv2d(*inputs, **new_attrs) + return relay.nn.conv2d(data_int16, kernel_int16, **new_attrs) if topi_tmpl == "conv2d_nchw_winograd.arm_cpu": assert data_layout == "NCHW" and kernel_layout == "OIHW" @@ -235,5 +241,40 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): new_attrs['out_layout'], out_dtype], topi_tmpl) dispatch_ctx.update(target, new_workload, cfg) return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs) + if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu": + assert (data.dtype == 'int8' and kernel.dtype == 'int8' or + data.dtype == 'uint8' and kernel.dtype == 'uint8') + assert data_layout == "NHWC" and kernel_layout == "HWIO" + CO, IC, KH, KW = get_const_tuple(kernel.shape) + K = KH * KW * IC + N = CO + + tile_rows = 4 + tile_cols = 16 + pad_K = 0 + pad_N = 0 + + if N % tile_rows != 0: + pad_N = tile_rows - (N % tile_rows) + if K % tile_cols != 0: + pad_k = tile_cols - (K % tile_cols) + + N_padded = N + pad_N + K_padded = K + pad_K + kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols) + new_kernel = te.placeholder((N_padded // tile_rows, + K_padded // tile_cols, + tile_rows, + tile_cols), kernel.dtype) + + new_workload = autotvm.task.args_to_workload([data, new_kernel, + strides, padding, dilation, + out_dtype, (KH, KW), CO], + "conv2d_NHWC_int8_without_tranform.arm_cpu") + dispatch_ctx.update(target, new_workload, cfg) + + return relay.nn.contrib_conv2d_gemm_without_weight_transform(inputs[0], + kernel_expr, + **new_attrs) return None diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py new file mode 100644 index 000000000000..2b6122919d85 --- /dev/null +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=unused-argument, redefined-builtin +"""GEMM Convolution schedule on ARM""" +import tvm +from tvm import te +from topi import nn +from ..util import get_const_tuple +from ..nn.util import get_pad_tuple +from .tensor_intrin import gemv_quantized, gemv_quantized_impl + + +# Compute function +def compute_conv2d_gemm_without_weight_transform(cfg, + data, B_interleaved_t, strides, padding, dilation, + out_dtype, kernel_size, output_channels): + """Compute conv2d by transforming the input, + executing GEMM and transforming the output back""" + batches, IH, IW, IC = get_const_tuple(data.shape) + + KH, KW = kernel_size + OC = output_channels + + K_AREA = KH * KW + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = \ + get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w)) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + + OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 + OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 + if pad_top or pad_left: + data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], + name="data_pad") + else: + data_pad = data + + # --- Im2col + M = OH * OW + K = IC * K_AREA + N = OC + + A_shape = (batches, M, K) + if K_AREA == 1: + A = te.compute(A_shape, lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y], + name='data_flatten') + else: + A = te.compute(A_shape, lambda n, x, y: + data_pad[n, + HSTR * (x // OW) + dilation_h * (y // IC) // KW, + WSTR * (x % OW) + dilation_w * (y // IC) % KW, y % IC], + name='data_im2col') + N_transformed = B_interleaved_t.shape[0] + + # --- Pad if necessary + idxm = tvm.tir.indexmod + + pad_m = 0 + pad_k = 0 + + if M % 4 != 0: + pad_m = 4 - (M % 4) + + if K % 16 != 0: + pad_k = 16 - (K % 16) + + M_padded = M + pad_m + K_padded = K + pad_k + + pad_before = (0, 0, 0) + pad_after = (0, pad_m, pad_k) + + if pad_m != 0 or pad_k != 0: + A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded") + + # --- GEMM: A*B' + k = te.reduce_axis((0, K_padded), "k") + + A_interleaved = te.compute((batches, M_padded // 4, K_padded // 16, 4, 16), + lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y], + name='A_interleaved') + + C_interleaved = te.compute((batches, M_padded // 4, N_transformed, 4, 4), + lambda b, x, y, w, z: + te.sum(A_interleaved[b, x, k//16, w, idxm(k, 16)].astype(out_dtype)* + B_interleaved_t[y, k//16, z, idxm(k, 16)].astype(out_dtype), + axis=k), + name='C_interleaved') + + # --- Unpack C + C = te.compute((batches, M, N), + lambda b, x, y: + C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)], + name="C", tag='injective') + + # --- Produce the conv output + out_shape = (batches, OH, OW, OC) + out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), + name='conv2d_gemm_output') + + return out + +# Schedules +def schedule_conv2d_gemm(cfg, s, out): + """Create schedule for tensors""" + C = out.op.input_tensors[0] + C_interleaved = C.op.input_tensors[0] + A_interleaved = C_interleaved.op.input_tensors[0] + + # Input transform + A_interleaved_input = A_interleaved.op.input_tensors[0] + if A_interleaved_input.op.name == "A_padded": + s[A_interleaved_input].compute_at(s[A_interleaved], A_interleaved.op.axis[3]) + s[A_interleaved_input].vectorize(A_interleaved_input.op.axis[2]) + s[A_interleaved_input].compute_inline() + data_im2col = A_interleaved_input.op.input_tensors[0] + else: + data_im2col = A_interleaved_input + + b, m, n = data_im2col.op.axis + if data_im2col.op.name == "data_im2col": + n_outer, n_inner = s[data_im2col].split(n, 16) + s[data_im2col].unroll(n_outer) + s[data_im2col].vectorize(n_inner) + else: + s[data_im2col].compute_inline() + + # Computation(through tensorize) + b, xo, yo, xi, yi = C_interleaved.op.axis + s[C_interleaved].reorder(xo, yo, yi, xi) + s[C_interleaved].parallel(xo) + s[A_interleaved].compute_at(s[C_interleaved], xo) + s[A_interleaved].vectorize(A_interleaved.op.axis[4]) + + in_type = A_interleaved.dtype + out_type = C.dtype + if out_type == 'int32': + K = A_interleaved_input.shape[2] + _, M, N = C.shape + assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported" + + gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type) + s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type)) + s[C_interleaved].tensorize(yi, gem_v_dotprod) + + # Output transform + N, OH, OW, OC = out.shape + s[C].split(C.op.axis[1], OW) + s[C].compute_at(s[out], out.op.axis[3]) + + return s diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py index 06412b656b4b..5a895c084c06 100644 --- a/topi/python/topi/arm_cpu/conv2d_int8.py +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -19,11 +19,12 @@ from tvm import te from tvm import autotvm from .. import tag -from ..util import get_const_tuple +from ..util import traverse_inline, get_const_tuple from ..generic import conv2d as conv2d_generic from .. import nn from ..nn.conv2d import _get_workload as _get_conv2d_workload from .tensor_intrin import dot_int8_int8_int32 +from .conv2d_gemm import compute_conv2d_gemm_without_weight_transform, schedule_conv2d_gemm def _get_default_config(cfg, data, kernel, strides, padding, out_dtype): @@ -109,3 +110,38 @@ def traverse(op): traverse(outs[0].op) return s + + +@autotvm.register_topi_compute("conv2d_NHWC_quantized.arm_cpu") +def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation, out_dtype): + N, IH, IW, IC = get_const_tuple(data.shape) + KH, KW, _, OC = get_const_tuple(kernel.shape) + tile_rows = 4 + tile_cols = 16 + kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols) + return compute_conv2d_gemm_without_weight_transform(cfg, + data, kernel, strides, padding, + dilation, out_dtype, (KH, KW), OC) + + +@autotvm.register_topi_compute("conv2d_NHWC_quantized_without_transform.arm_cpu") +def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, padding, + dilation, out_dtype, kernel_size=None, + output_channels=None): + return compute_conv2d_gemm_without_weight_transform(cfg, data, B, strides, padding, + dilation, out_dtype, kernel_size, + output_channels) + + +@autotvm.register_topi_schedule("conv2d_NHWC_quantized.arm_cpu") +def schedule_conv2d_NHWC_quantized(cfg, outs): + """Create schedule for tensors""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + """Traverse operators from computation graph""" + if op.name == "conv2d_gemm_output": + schedule_conv2d_gemm(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index bab91578e77e..cf56a06c326a 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -19,6 +19,345 @@ import tvm from tvm import te +from tvm.contrib import util, clang + +def gemv_quantized_impl(M, N, data_type='uint8'): + """ Assembly implementation of a blocked gemv. Given + a block a of shape (4, k) and a block b' of shape (4, k) + produces the output block c = a*b of shape (4,4) """ + + stepA = min(4, M) + stepB = min(4, N) + assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation' + + cc_code = """ + extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer, + unsigned char *a_buffer, + unsigned char *b_buffer, + int K, int m, int n) + """.format(data_type, stepA, stepB) + + cc_code += """ + { + unsigned char * a_ptr = a_buffer; + unsigned char * b_ptr = b_buffer; + int * c_ptr = c_buffer; + + int k = K / 16; + + __asm__ __volatile__ ( + "movi v16.4s, #0\\n" + "movi v17.4s, #0\\n" + "movi v18.4s, #0\\n" + "movi v19.4s, #0\\n" + "movi v20.4s, #0\\n" + "movi v21.4s, #0\\n" + "movi v22.4s, #0\\n" + "movi v23.4s, #0\\n" + "movi v24.4s, #0\\n" + "movi v25.4s, #0\\n" + "movi v26.4s, #0\\n" + "movi v27.4s, #0\\n" + "movi v28.4s, #0\\n" + "movi v29.4s, #0\\n" + "movi v30.4s, #0\\n" + "movi v31.4s, #0\\n" + "1:" + """ + + cc_code += ' "ldr q0, [%[a_ptr]]\\n" ' + + if M > 1: + cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" ' + else: + cc_code += ' "movi v1.4s, #0\\n" ' + + if M > 2: + cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" ' + else: + cc_code += ' "movi v2.4s, #0\\n" ' + + if M > 3: + cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" ' + else: + cc_code += ' "movi v3.4s, #0\\n" ' + + cc_code += ' "ldr q4, [%[b_ptr]]\\n" ' + + if N > 1: + cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" ' + + if N > 2: + cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" ' + + if N > 3: + cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" ' + + cc_code += """ + // First half + // Higher part of a0 * {b0,b1,b2,b3} + "umull v8.8h, v0.8b, v4.8b\\n" + "umull v9.8h, v0.8b, v5.8b\\n" + "umull v10.8h, v0.8b, v6.8b\\n" + "umull v11.8h, v0.8b, v7.8b\\n" + + // Higher part of a1 * {b0,b1,b2,b3} + "umull v12.8h, v1.8b, v4.8b\\n" + "umull v13.8h, v1.8b, v5.8b\\n" + "umull v14.8h, v1.8b, v6.8b\\n" + "umull v15.8h, v1.8b, v7.8b\\n" + + // Accumulate + "uadalp v16.4s, v8.8h\\n" + "uadalp v17.4s, v9.8h\\n" + "uadalp v18.4s, v10.8h\\n" + "uadalp v19.4s, v11.8h\\n" + "uadalp v20.4s, v12.8h\\n" + "uadalp v21.4s, v13.8h\\n" + "uadalp v22.4s, v14.8h\\n" + "uadalp v23.4s, v15.8h\\n" + + // Lower part of a0 * {b0,b1,b2,b3} + "umull2 v8.8h, v0.16b, v4.16b\\n" + "umull2 v9.8h, v0.16b, v5.16b\\n" + "umull2 v10.8h, v0.16b, v6.16b\\n" + "umull2 v11.8h, v0.16b, v7.16b\\n" + + // Lower part of a1 * {b0,b1,b2,b3} + "umull2 v12.8h, v1.16b, v4.16b\\n" + "umull2 v13.8h, v1.16b, v5.16b\\n" + "umull2 v14.8h, v1.16b, v6.16b\\n" + "umull2 v15.8h, v1.16b, v7.16b\\n" + + // Accumulate again + "uadalp v16.4s, v8.8h\\n" + "uadalp v17.4s, v9.8h\\n" + "uadalp v18.4s, v10.8h\\n" + "uadalp v19.4s, v11.8h\\n" + "uadalp v20.4s, v12.8h\\n" + "uadalp v21.4s, v13.8h\\n" + "uadalp v22.4s, v14.8h\\n" + "uadalp v23.4s, v15.8h\\n" + + // Second half + + // Lower part of a2 * {b0,b1,b2,b3} + "umull v8.8h, v2.8b, v4.8b\\n" + "umull v9.8h, v2.8b, v5.8b\\n" + "umull v10.8h, v2.8b, v6.8b\\n" + "umull v11.8h, v2.8b, v7.8b\\n" + + // Lower part of a3 * {b0,b1,b2,b3} + "umull v12.8h, v3.8b, v4.8b\\n" + "umull v13.8h, v3.8b, v5.8b\\n" + "umull v14.8h, v3.8b, v6.8b\\n" + "umull v15.8h, v3.8b, v7.8b\\n" + + // Accumulate + "uadalp v24.4s, v8.8h\\n" + "uadalp v25.4s, v9.8h\\n" + "uadalp v26.4s, v10.8h\\n" + "uadalp v27.4s, v11.8h\\n" + "uadalp v28.4s, v12.8h\\n" + "uadalp v29.4s, v13.8h\\n" + "uadalp v30.4s, v14.8h\\n" + "uadalp v31.4s, v15.8h\\n" + + // Higher part of a2 * {b0,b1,b2,b3} + "umull2 v8.8h, v2.16b, v4.16b\\n" + "umull2 v9.8h, v2.16b, v5.16b\\n" + "umull2 v10.8h, v2.16b, v6.16b\\n" + "umull2 v11.8h, v2.16b, v7.16b\\n" + + // Higher part of a3 * {b0,b1,b2,b3} + "umull2 v12.8h, v3.16b, v4.16b\\n" + "umull2 v13.8h, v3.16b, v5.16b\\n" + "umull2 v14.8h, v3.16b, v6.16b\\n" + "umull2 v15.8h, v3.16b, v7.16b\\n" + + // Accumulate again + "uadalp v24.4s, v8.8h\\n" + "uadalp v25.4s, v9.8h\\n" + "uadalp v26.4s, v10.8h\\n" + "uadalp v27.4s, v11.8h\\n" + "uadalp v28.4s, v12.8h\\n" + "uadalp v29.4s, v13.8h\\n" + "uadalp v30.4s, v14.8h\\n" + "uadalp v31.4s, v15.8h\\n" + """ + blockA = min(64, M * 16) + blockB = min(64, N * 16) + + cc_code += """ + // Increment pointers and decrement k + "add %[a_ptr], %[a_ptr], #{0}\\n" + "add %[b_ptr], %[b_ptr], #{1}\\n" + "subs %w[k], %w[k], #1\\n" + """.format(blockA, blockB) + + stepC = min(4, N) + + cc_code += """ + "cbnz %w[k], 1b\\n" + + // Final additions + + // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d) + // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h) + // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l) + // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p) + "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h) + "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p) + "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d) + // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h) + // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l) + // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p) + "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h) + "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p) + "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d) + // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h) + // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l) + // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p) + "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b, c+d, e+f, g+h) + "addp v25.4s, v26.4s, v27.4s\\n" // v25 = (i+j, k+l, m+n, o+p) + "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d) + // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h) + // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l) + // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p) + "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h) + "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p) + "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) + + "str q16, [%[c_ptr]]\\n" + """ + + if M > 1: + cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4) + + if M > 2: + cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8) + + if M > 3: + cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12) + + cc_code += """ + : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k) + : + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31" + ); + return 0; + } + """ + + if data_type == 'int8': + cc_code = cc_code.replace('unsigned char', 'char') + cc_code = cc_code.replace('umull', 'smull') + cc_code = cc_code.replace('uadalp', 'sadalp') + + temp = util.tempdir() + ll_path = temp.relpath("temp.ll") + # Create LLVM ir from c source code + ll_code = clang.create_llvm(cc_code, + options=["--target=aarch64-linux-gnu -mattr=+neon"], + output=ll_path) + return ll_code + + +def gemv_quantized(M, N, K, in_type, out_type): + """ + Use integer ARM v8 instructions in order to produce a block c of 4x4 elements + given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final + result is c = a*b (where '*' indicates the matrix product) + + Every row of the matrix c is obtained (for uint8) by a sequence of + + umull -> uadalp -> umull2 -> uadalp + + The block size is constrained by the number of registers available in arvm8. This + function returns a TensorIntrin that can be used to tensorize + a schedule. + + Parameters + ---------- + M: int + rows of the matrix A + N: int + columns of the matrix B + K: int + columns of matrix A + in_type: str, {'uint8', 'int8'} + out_type: str, {'uint32', 'int32'} + + Returns + ------- + intrin : TensorIntrin + The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule + """ + A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name='A') + B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name='B') + + idxm = tvm.tir.indexmod + + k = te.reduce_axis((0, K), "k") + + C = te.compute((te.var("m"), te.var("n")), + lambda x, y: te.sum(A[k // 16, x, idxm(k, 16)].astype(out_type) * + B[k // 16, y, idxm(k, 16)].astype(out_type), + axis=k), name="C") + + a_buffer = tvm.tir.decl_buffer(A.shape, dtype=in_type, name="a_buffer", + offset_factor=1, strides=[te.var('sa_1'), te.var('sa_2'), 1]) + + b_buffer = tvm.tir.decl_buffer(B.shape, dtype=in_type, name="b_buffer", + offset_factor=1, strides=[te.var('sb_1'), te.var('sb_2'), 1]) + + c_buffer = tvm.tir.decl_buffer(C.shape, dtype=out_type, name="c_buffer", + offset_factor=1, strides=[te.var('sc'), 1]) + + def _intrin_func(ins, outs): + + def _instr(): + ib = tvm.tir.ir_builder.create() + aa, bb = ins + cc = outs[0] + stepA = min(4, M) + stepB = min(4, N) + + if in_type == 'int8': + ib.emit(tvm.tir.call_extern("int32", + "gemv_int8_int8_int32_{0}_{1}".format(stepA, stepB), + outs[0].access_ptr("w"), + a_buffer.access_ptr("r"), + b_buffer.access_ptr("r"), + K)) + else: + ib.emit(tvm.tir.call_extern("int32", + "gemv_uint8_uint8_int32_{0}_{1}".format(stepA, stepB), + c_buffer.access_ptr("w"), + a_buffer.access_ptr("r"), + b_buffer.access_ptr("r"), + K, + C.shape[0], # m, very useful for debug + C.shape[1])) # n, very useful for debug + return ib.get() + + # body, reset, update + return _instr() + + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin(C.op, _intrin_func, + binds={A:a_buffer, B:b_buffer, C:c_buffer}, + default_buffer_params=buffer_params) + def dot_int8_int8_int32(int32_lanes, dtype='uint'): """ diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 767087b0d4f0..7645588f2d35 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -187,6 +187,25 @@ def schedule_conv2d_winograd_weight_transform(outs): return s +def schedule_conv2d_gemm_weight_transform(outs): + """Schedule for weight transformation of gemm + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of this operator + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + # Typically this is computed in PreCompute pass + s = te.create_schedule([x.op for x in outs]) + return s + + def schedule_conv3d_winograd_weight_transform(outs): """Schedule for weight transformation of 3D winograd diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 4c7941b49692..59288892ebaa 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -590,6 +590,55 @@ def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layou name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") +def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): + """Weight transformation for winograd + + Parameters + ---------- + kernel: Tensor + The raw kernel tensor with layout "NHWC". + tile_rows: int + Tile rows of the weight transformation for ConvGemm. + tile_cols: int + Tile columns of the weight transformation for ConvGemm. + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [CI*KH*KW,CO] + """ + KH, KW, IC, OC = get_const_tuple(kernel.shape) + K = KH * KW * IC + N = OC + + kernel_flat = te.compute((K, N), lambda x, y: + kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], + 'weight_flatten') + + pad_K = 0 + pad_N = 0 + + if N % tile_rows != 0: + pad_N = tile_rows - (N % tile_rows) + + if K % tile_cols != 0: + pad_k = tile_cols - (K % tile_cols) + + N_padded = N + pad_N + K_padded = K + pad_K + + if pad_K != 0 or pad_N != 0: + kernel_flat = pad(kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), + name='weight_padding') + + return te.compute((N_padded // tile_rows, + K_padded // tile_cols, + tile_rows, + tile_cols), lambda x, y, z, w: + kernel_flat[w + tile_cols * y, z + tile_rows * x], + name='weight_block_reshape') + + def conv2d_winograd_weight_transform(kernel, tile_size): """Weight transformation for winograd