diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h new file mode 100644 index 000000000000..2a7376b72e64 --- /dev/null +++ b/include/tvm/relay/attrs/bitserial.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/bitserial.h + * \brief Auxiliary attributes for bitserial operators. + */ + +#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_ +#define TVM_RELAY_ATTRS_BITSERIAL_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! 
\brief Attributes used in bitpack operators */ +struct BitPackAttrs : public tvm::AttrsNode { + int bits; + int pack_axis; + int bit_axis; + DataType pack_type; + std::string name; + + TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") { + TVM_ATTR_FIELD(bits).set_default(1).describe("Number of bits to quantize with."); + TVM_ATTR_FIELD(pack_axis).set_default(1).describe( + "Axis that should be compressed, typically channels."); + TVM_ATTR_FIELD(bit_axis).set_default(-1).describe("New axis for packed bits."); + TVM_ATTR_FIELD(pack_type) + .set_default(NullValue()) + .describe("Type of int to pack bits into."); + TVM_ATTR_FIELD(name).set_default("BitPack").describe("Name of operation."); + } +}; + +/*! \brief Attributes used in bitserial convolution operators */ +struct BinaryConv2DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + IndexExpr channels; + Array kernel_size; + int activation_bits; + int weight_bits; + std::string data_layout; + std::string kernel_layout; + DataType pack_dtype; + DataType out_dtype; + bool unipolar; + + TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") { + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero the input is implicitly zero-padded " + "on both sides for padding number of points."); + TVM_ATTR_FIELD(kernel_size) + .set_default(Array({3, 3})) + .describe("Specifies the dimensions of the convolution window."); + TVM_ATTR_FIELD(channels) + .set_default(NullValue()) + .describe("Number of output channels, needed for shape inference."); + TVM_ATTR_FIELD(activation_bits) + .set_default(1) + .describe("Number of bits activation should be packed with."); + TVM_ATTR_FIELD(weight_bits) + .set_default(1) + .describe("Number of bits kernel should be packed with."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe("Dimension 
ordering of input data, can be 'NCHW' or NHWC'."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe("Dimension ordering of kernel data, can be 'OIHW' or HWIO'."); + TVM_ATTR_FIELD(pack_dtype) + .set_default(NullValue()) + .describe("Datatype to pack bits into."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output datatype."); + TVM_ATTR_FIELD(unipolar).set_default(true).describe( + "Whether to use unipolar or bipolar quantization."); + } +}; + +/*! \brief Attributes for bitserial dense operator */ +struct BinaryDenseAttrs : public tvm::AttrsNode { + IndexExpr units; + int data_bits; + int weight_bits; + DataType pack_dtype; + DataType out_dtype; + bool unipolar; + + TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") { + TVM_ATTR_FIELD(units) + .describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(data_bits) + .set_default(1) + .describe("Number of bits to pack for incoming tensor."); + TVM_ATTR_FIELD(weight_bits) + .set_default(1) + .describe("Number of bits to pack for weight tensor."); + TVM_ATTR_FIELD(pack_dtype) + .set_default(NullValue()) + .describe("Datatype to pack bits into before computation."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type."); + TVM_ATTR_FIELD(unipolar) + .set_default(true) + .describe("Whether to use unipolar or bipolar quantization for inputs."); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_BITSERIAL_H_ diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 03a04c951d59..d652977924ca 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -600,3 +600,120 @@ def schedule_deformable_conv2d(attrs, outs, target): reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +@reg.register_compute("nn.bitpack") +def compute_bitpack(attrs, inputs, out_dtype, target): + """Compute definition for bitpack""" + bits = 
attrs.bits + pack_axis = attrs.pack_axis + bit_axis = attrs.bit_axis + pack_type = attrs.pack_type + name = attrs.name + with target: + out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, + name) + return [out] + +@reg.register_schedule("nn.bitpack") +def schedule_bitpack(attrs, outs, target): + with target: + return topi.generic.schedule_bitpack(outs) + +reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE) + + +@reg.register_compute("nn.bitserial_conv2d") +def compute_bitserial_conv2d(attrs, inputs, out_dtype, target): + """Compute definition for bitserial conv2d.""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + activation_bits = attrs.activation_bits + weight_bits = attrs.weight_bits + layout = attrs.data_layout + pack_dtype = attrs.pack_dtype + out_dtype = attrs.out_dtype + unipolar = attrs.unipolar + if layout == 'NCHW': + with target: + out = topi.nn.bitserial_conv2d_nchw( + inputs[0], inputs[1], strides, padding, activation_bits, + weight_bits, pack_dtype, out_dtype, unipolar) + elif layout == 'NHWC': + with target: + out = topi.nn.bitserial_conv2d_nhwc( + inputs[0], inputs[1], strides, padding, activation_bits, + weight_bits, pack_dtype, out_dtype, unipolar) + else: + raise ValueError("Data layout not supported.") + + return [out] + + +@reg.register_schedule("nn.bitserial_conv2d") +def schedule_bitserial_conv2d(attrs, outs, target): + """Schedule definition for bitserial conv2d.""" + layout = attrs.data_layout + if layout == 'NCHW': + with target: + return topi.generic.schedule_bitserial_conv2d_nchw(outs) + elif layout == 'NHWC': + with target: + return topi.generic.schedule_bitserial_conv2d_nhwc(outs) + else: + raise ValueError("Data layout not supported.") + +@reg.register_legalize("nn.bitserial_conv2d") +def legalize_bitserial_conv2d(attrs, inputs, types): + """Legalize bitserial_conv2d op. 
+ + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types) + + +reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# bitserial_dense +@reg.register_compute("nn.bitserial_dense") +def compute_bitserial_dense(attrs, inputs, out_type, target): + """Compute definition of bitserial_dense""" + data_bits = attrs.data_bits + weight_bits = attrs.weight_bits + pack_dtype = attrs.pack_dtype + out_dtype = attrs.out_dtype + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + unipolar = attrs.unipolar + return [ + topi.nn.bitserial_dense( + inputs[0], + inputs[1], + data_bits, + weight_bits, + pack_dtype, + out_dtype, + unipolar) + ] + + +@reg.register_schedule("nn.bitserial_dense") +def schedule_bitserial_dense(attrs, outputs, target): + """Schedule definition of bitserial_dense""" + with target: + return topi.generic.schedule_bitserial_dense(outputs) + + +reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 946ea335e0db..19c50d6dc700 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1459,3 +1459,165 @@ def deformable_conv2d(data, return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) + + +def bitpack(data, + bits=1, + pack_axis=1, + bit_axis=2, + pack_type="uint32", + name="BitPack"): + r"""Tensor packing for bitserial operations. + The values along the input tensor's pack_axis are quantized + and packed together into the specified pack_type in a new + bit axis. 
+ + For example, consider bitpacking with data to be a tensor with shape [1, 64, 128, 128], + pack_axis=1, bit_axis=4, pack_type=uint8, and bits=2. The output in this case will + be of shape [1, 8, 128, 128, 2]. The dimension of axis 1 has been reduced by a factor + of 8 since each value is packed into an 8-bit uint8. Axis 4 is now two bitplanes + representing the quantized value of the incoming data. The output tensor is now + ready to be used in a bitserial operation. + + Parameters + ---------- + data : tvm.relay.expr + The incoming tensor to be packed. + + bits : int + Number of bits that should be packed. + + pack_axis : int + Axis that should be decomposed and packed. + + bit_axis : int + New axis containing bitplane. + + pack_type : str + Datatype to pack bits into. + + name : str, optional + Name of the operation. + + Returns + ------- + result : tvm.relay.Expr + The packed tensor. + """ + return _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name) + + +def bitserial_conv2d(data, + weight, + strides=(1, 1), + padding=(0, 0), + channels=None, + kernel_size=(3, 3), + activation_bits=1, + weight_bits=1, + data_layout='NCHW', + kernel_layout='OIHW', + pack_dtype='uint32', + out_dtype='int16', + unipolar=True): + r"""2D convolution using bitserial computation. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + activation_bits : int + Number of bits to pack for activations. + + weight_bits : int + Number of bits to pack for weights. + + data_layout : str, optional + Layout of the input. 
+ + kernel_layout : str, optional + Layout of the kernel + + pack_dtype: str, optional + Datatype to pack bits into. + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + return _make.bitserial_conv2d(data, weight, strides, padding, channels, + kernel_size, activation_bits, weight_bits, + data_layout, kernel_layout, pack_dtype, + out_dtype, unipolar) + + +def bitserial_dense(data, + weight, + units=None, + data_bits=1, + weight_bits=1, + pack_dtype='uint32', + out_dtype='int16', + unipolar=True): + """Bitserial Dense operator. + Applies matrix multiplication of two quantized matrices + using a fast bitserial algorithm. + + .. math:: + + `Y = X * W` + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + units : int, optional + Number of hidden units of the dense transformation. + + data_bits : int + Number of bits incoming tensor should be packed with. + + weight_bits : int + Number of bits weight tensor should be packed with. + + pack_dtype : str, optional + Datatype to pack individual bits into before computation. + + out_dtype : str, optional + Specifies the output data type for mixed precision dense. + + unipolar : bool, optional + Whether to use unipolar or bipolar quantization for inputs. + + Returns + ------- + result : tvm.relay.Expr + The computed result. 
+ """ + return _make.bitserial_dense(data, weight, units, data_bits, weight_bits, + pack_dtype, out_dtype, unipolar) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 48d3d2032f80..11f8ad1611cd 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -264,3 +264,18 @@ class MaxPool2DAttrs(Attrs): @register_relay_attr_node class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" + + +@register_relay_attr_node +class BitPackAttrs(Attrs): + """Attributes used in bitpack operator""" + + +@register_relay_attr_node +class BinaryConv2DAttrs(Attrs): + """Attributes used in bitserial conv2d operators""" + + +@register_relay_attr_node +class BinaryDenseAttrs(Attrs): + """Attributes used in bitserial dense operators""" diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc new file mode 100644 index 000000000000..6ee1ee675c06 --- /dev/null +++ b/src/relay/op/nn/bitserial.cc @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file bitserial.cc + * \brief Property def of bitserial operators. 
+ */ + +#include +#include +#include + +#include "../../pass/alter_op_layout.h" + +namespace tvm { +namespace relay { + +// relay.nn.bitpack +TVM_REGISTER_NODE_TYPE(BitPackAttrs); + +template +Array> BinaryConv2DInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array>& old_in_shapes) { + const T* params = attrs.as(); + + // We always make other operators to fit the layouts of convolution layers + // So this inference ignores all inputs + return Array>{{params->data_layout, params->kernel_layout}, {params->data_layout}}; +} + +bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const BitPackAttrs* param = attrs.as(); + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + int ndim = data->shape.size(); + int bits = param->bits; + int pack_axis = param->pack_axis; + int bit_axis = param->bit_axis; + DataType pack_type = param->pack_type; + + int pack_bits = pack_type.bits(); + + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i == bit_axis) { + out_shape.push_back(bits); + if (i == pack_axis) { + out_shape.push_back(data->shape[i] / pack_bits); + } else { + out_shape.push_back(data->shape[i]); + } + } else if (i == pack_axis) { + out_shape.push_back(data->shape[i] / pack_bits); + } else { + out_shape.push_back(data->shape[i]); + } + } + // Add extra check for last axis expansion. 
+ if (bit_axis == ndim) { + out_shape.push_back(bits); + } + + reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type)); + return true; +} + +Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, + std::string name) { + auto attrs = make_node(); + attrs->bits = bits; + attrs->pack_axis = pack_axis; + attrs->bit_axis = bit_axis; + attrs->pack_type = pack_type; + attrs->name = name; + static const Op& op = Op::Get("nn.bitpack"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack); + +RELAY_REGISTER_OP("nn.bitpack") + .describe(R"code(Bitpack layer that prepares data for bitserial operations. + +This layer backs the bits of an input into a single datatype, allowing +efficient implementation of bitserial operations. + +- **data**: Input tensor of any shape, dimension that is to be + packed must be divisible by number of bits. +- **out**: Packed tensor with shape appropriately compressed. 
+)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_attrs_type_key("relay.attrs.BitPackAttrs") + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("BitPack", BitPackRel); + +// relay.nn.bitserial_conv2d +TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs); + +bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const BinaryConv2DAttrs* param = attrs.as(); + CHECK(param != nullptr); + + static const Layout kNCHW("NCHW"); + + const Layout in_layout(param->data_layout); + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + CHECK(param->channels.defined()); + CHECK(param->kernel_size.defined()); + Array oshape({dshape_nchw[0], param->channels, 0, 0}); + oshape.Set( + 2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set( + 3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1); + DataType out_dtype = param->out_dtype; + oshape = trans_in_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +// Positional relay function to create binaryconv2d operator +// used by frontend FFI. 
+Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array padding, + IndexExpr channels, Array kernel_size, int activation_bits, + int weight_bits, std::string data_layout, std::string kernel_layout, + DataType pack_dtype, DataType out_dtype, bool unipolar) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->activation_bits = activation_bits; + attrs->weight_bits = weight_bits; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->pack_dtype = std::move(pack_dtype); + attrs->out_dtype = std::move(out_dtype); + attrs->unipolar = unipolar; + static const Op& op = Op::Get("nn.bitserial_conv2d"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D); + +RELAY_REGISTER_OP("nn.bitserial_conv2d") + .describe(R"code(2D convolution using packed binary computation. + +This layer creates a convolution kernel that is convolved with the +layer input using bitserial computation. This enables faster processing +on some platforms. + +- **data**: 4D input tensor that can be either `NCHW` or `NHWC` layout. + +- **weight**: Weight tensor that can either be prepacked (5D) or unpacked (4D). + When data is NCHW, weight is expected to be OIHW or OIHWi. + When data is NHWC weight is expected to be HWIO or HWIOi. + +- **out**: Output with same layout as input. 
+)code" TVM_ADD_FILELINE) + .set_attrs_type_key("relay.attrs.BinaryConv2DAttrs") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("BinaryConv2D", BinaryConv2DRel) + .set_attr("FInferCorrectLayout", + BinaryConv2DInferCorrectLayout); + +// relay.nn.bitserial_dense +TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs); + +bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const BinaryDenseAttrs* param = attrs.as(); + CHECK(param != nullptr); + + CHECK(static_cast(data->shape.size()) != 0); + CHECK(param->units.defined()); + + Array oshape = data->shape; + oshape.Set((oshape.size() - 1), param->units); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // Assign output type. + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +// Positional relay function to create bitserial dense operator used by frontend FFI. +Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits, + DataType pack_dtype, DataType out_dtype, bool unipolar) { + auto attrs = make_node(); + attrs->units = units; + attrs->data_bits = data_bits; + attrs->weight_bits = weight_bits; + attrs->pack_dtype = pack_dtype; + attrs->out_dtype = out_dtype; + attrs->unipolar = unipolar; + static const Op& op = Op::Get("nn.bitserial_dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense); + +RELAY_REGISTER_OP("nn.bitserial_dense") + .describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`. 
+ +- **data**: `(x1, x2, ..., xn, input_dim)` +- **weight**: `(units, input_dim)` +- **out**: `(x1, x2, ..., xn, units)`. + +)code" TVM_ADD_FILELINE) + .set_attrs_type_key("relay.attrs.BinaryDenseAttrs") + .set_num_inputs(2) + .add_argument("data", "2D Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("BinaryDense", BinaryDenseRel); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 66e65c5fd409..c25393cf4026 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -337,6 +337,16 @@ def test_dense(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +def test_bitserial_dense(): + m, k = tvm.var("m"), tvm.var("k") + x = relay.var("x", relay.TensorType((m, k), "int16")) + w = relay.var("w", relay.TensorType((k, 32), "int16")) + y = relay.nn.bitserial_dense(x, w, units=32) + assert "units=32" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((m, 32), "int16") + + if __name__ == "__main__": test_concatenate() test_bias_add() @@ -349,3 +359,4 @@ def test_dense(): test_dropout() test_batch_norm() test_dense() + test_bitserial_dense() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 5e9abdf0faf4..a94a203f4d79 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -105,8 +105,8 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, except_targets=None, **attrs): if except_targets is None: - except_targets = [] - + except_targets = [] + x = relay.var("x", shape=dshape, dtype=dtype) w = relay.var("w", dtype=dtype) y = relay.nn.conv2d(x, w, @@ -599,12 +599,35 @@ def _compile(input_dtype, weight_dtype, output_dtype, target): assert "vpmulld" in asm and "vpadd" in asm +def test_bitserial_conv2d_infer_type(): + # Basic shape test with ambiguous 
batch. + n, c, h, w = tvm.var("n"), 32, 224, 224 + x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16")) + w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16")) + y = relay.nn.bitserial_conv2d( + x, w, kernel_size=(3, 3), padding=(0, 0), channels=32) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 32, 222, 222), "int16") + + +def test_bitpack_infer_type(): + # Test axis packing shape inference. + o, i, h, w = 32, 32, 128, 128 + x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16")) + y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (32, 2, 128, 128, 1), "uint16") + + if __name__ == "__main__": test_pool2d() test_avg_pool2d_no_count_pad() test_lrn() test_l2_normalize() test_conv2d_infer_type() + test_bitpack_infer_type() test_upsampling_infer_type() test_flatten_infer_type() test_pad_infer_type() @@ -612,6 +635,7 @@ def _compile(input_dtype, weight_dtype, output_dtype, target): test_conv2d_transpose_infer_type() test_conv2d_transpose_run() test_conv2d_run() + test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() test_conv2d_int8_intrinsics() diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index 4198267cac60..af9c5bebb998 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,unused-variable,invalid-name +# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument """Bitserial conv2d schedule on arm cpu""" from __future__ import absolute_import as _abs import tvm from tvm import autotvm +from tvm import relay from .. 
import tag from ..nn.pad import pad -from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc +from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize from ..nn.bitserial_util import bitpack, binary_op_multiplier from ..nn.util import get_pad_tuple from ..util import get_const_int, get_const_tuple @@ -350,3 +351,40 @@ def traverse(op): traverse(outs[0].op) return s + +@bitserial_conv2d_legalize.register("arm_cpu") +def _bitserial_conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Bitserial Conv2D op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + arg_types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + + # Fix different kernel layouts where possible. + if attrs['data_layout'] == 'NHWC': + data, kernel = inputs + if len(arg_types[1].shape) == 4: + # HWIO layout is expected for NHWC input. + if attrs['kernel_layout'] == 'HWOI': + # Handle HWOI layout. This is common in TF depthwise conv2d graph. + kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) + elif attrs['kernel_layout'] == 'OIHW': + kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) + ## Set new attrs for the transposed conv. 
+ new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['kernel_layout'] = 'HWIO' + + conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs) + return conv + return None diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 38b66320b428..8fbedec3fef1 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -470,6 +470,23 @@ def schedule_binarize_pack(outs): return _default_schedule(outs, False) +@tvm.target.override_native_generic_func("schedule_bitpack") +def schedule_bitpack(outs): + """Schedule for bitpack + Parameters + ---------- + outs: Array of Tensor + The computation graph description of bitpack + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.override_native_generic_func("schedule_binary_dense") def schedule_binary_dense(outs): """Schedule for binary_dense diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 99cac889deea..21abdf0de1ec 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=invalid-name, too-many-locals, too-many-arguments +# pylint: disable=unused-argument, redefined-builtin """Bitserial Conv2D operators""" from __future__ import absolute_import as _abs import tvm @@ -65,7 +66,10 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight """ assert isinstance(stride, int) or len(stride) == 2 Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype) - Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + if len(kernel.shape) == 4: + Filter_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + else: + Filter_q = kernel batch, in_channel, activation_bits, in_height, in_width = Input_q.shape num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape @@ -414,3 +418,24 @@ def _conv(n, h, w, co, vh, vw, vc): return tvm.compute(oshape, lambda n, h, w, co: conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], name='output_unpack', tag='spatial_bitserial_conv_nhwc') + +@tvm.target.generic_func +def bitserial_conv2d_legalize(attrs, inputs, types): + """Legalizes Bitserial Conv2D op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + return None