diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index afbe1f343bb8..5726e9eec1f0 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -71,17 +71,20 @@ class TensorAffineTypeNode : public AffineTypeNode { RelayExpr zero_point; /*! \brief The data type of this type */ DataType dtype; + /*! \brief The axis for per-channel quantization */ + int axis; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("scale", &scale); v->Visit("zero_point", &zero_point); v->Visit("dtype", &dtype); + v->Visit("axis", &axis); } bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(scale, other->scale) && equal(zero_point, other->zero_point) && - equal(dtype, other->dtype); + equal(dtype, other->dtype) && equal(axis, other->axis); } void SHashReduce(SHashReducer hash_reduce) const { @@ -89,6 +92,7 @@ class TensorAffineTypeNode : public AffineTypeNode { hash_reduce(scale); hash_reduce(zero_point); hash_reduce(dtype); + hash_reduce(axis); } static constexpr const char* _type_key = "TensorAffineType"; @@ -101,7 +105,7 @@ class TensorAffineTypeNode : public AffineTypeNode { */ class TensorAffineType : public AffineType { public: - TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype); + TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis); TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); }; diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index a1ce08017b1b..bd77c187af40 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -48,10 +48,15 @@ class TensorAffineType(AffineType): dtype : str The content data type. + + axis : int + The axis for per-channel quantization. """ - def __init__(self, scale, zero_point, dtype): - self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype) + def __init__(self, scale, zero_point, dtype, axis=-1): + self.__init_handle_by_constructor__( + _ffi_api.TensorAffineType, scale, zero_point, dtype, axis + ) @tvm._ffi.register_object("TupleAffineType") diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4202374f8da8..fd4bca843687 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -490,7 +490,7 @@ def _impl_v1(cls, inputs, attr, params): attr["dilations"] = [1] + list(attr["dilations"]) if "pads" in attr: attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]] - + attr["channels"] = kernel_shapes[0][0] out = AttrCvt( op_name=dimension_picker("conv"), transforms={ diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3226240fbe39..fd3a1686f5a8 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,6 +20,7 @@ import tvm from tvm import relay +from tvm._ffi.base import TVMError from .. import op as reg ################################################# @@ -139,11 +140,35 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs shift_data = relay.subtract( - relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") - ) - shift_kernel = relay.subtract( - relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") + relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16") ) + # If kernel zero point is a scalar we can directly subtract it. + if len(types[3].shape) == 0: + shift_kernel = relay.subtract( + relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16") + ) + # Otherwise it needs to be broadcast. + else: + # Determine output axis of kernel for spatial operations. + if hasattr(attrs, "kernel_layout"): + output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + # For dense operations, broadcast to [N, K] layout. + elif isinstance(attrs, relay.op.op_attrs.DenseAttrs): + output_axis = 0 + # For matrix multiplication instead expand to [K, N] layout. + elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs): + output_axis = 1 + else: + raise TVMError( + "Legalization of %s is not yet supported with per channel parameters" + % str(type(attrs)) + ) + + shift_kernel = relay.nn.bias_add( + relay.cast(kernel, dtype="int16"), + relay.cast(kernel_zero_point, dtype="int16"), + output_axis, + ) new_attrs = {k: attrs[k] for k in attrs.keys()} return relay_op(shift_data, shift_kernel, **new_attrs) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index e74256ec74c3..83b5cf0a831c 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -276,8 +276,10 @@ def conv2d( ): r"""Quantized 2D convolution. - This operator convolves quantized data with quantized kernel. The scale of - the output quantized tensor is the product of the kernel_scale and + This operator convolves quantized data with quantized kernel. + If doing Per-channel quantization, qnn expects the kernel_zero_scale + and optionally the kernel_zero_point will be 1-D vectors instead of scalars. + The scale of the output quantized tensor is the product of the kernel_scale and input_scale of the input quantized tensors. The zero point of the output quantized tensor is 0. By default, the dtype of output is int32. Please also refer to Requantize operator to understand how to scale back the int32 @@ -544,6 +546,9 @@ def dense( `Y = X * W` + If doing Per-channel quantization, qnn expects the kernel_zero_scale + and optionally the kernel_zero_point will be 1-D vectors instead of scalars. + Parameters ---------- data : tvm.relay.Expr diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index cf55c67c8083..6032dbf92dbc 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -18,6 +18,7 @@ import tvm from tvm import relay from tvm.ir import TensorAffineType, TupleAffineType +from tvm.tir import bijective_layout from ..op import register_fake_quantization_to_integer @@ -25,6 +26,14 @@ def fold_constant(expr): return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) +def get_zeros(scale): + return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32")) + + +def infer_shape(expr): + return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape + + @register_fake_quantization_to_integer("qnn.dequantize") def dequantize(expr, type_map): """Remove dequantize op""" @@ -52,8 +61,13 @@ def quantize(expr, type_map): expr.args[1], expr.args[2], out_dtype=expr.attrs.out_dtype, + axis=t.axis, ) - return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] + + return [ + out, + TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), + ] def register_unary_identity(op_name): @@ -94,7 +108,11 @@ def bias_add(expr, type_map): b_t = type_map[b] in_scale = fold_constant(x_t.scale) in_zero_point = fold_constant(x_t.zero_point) - if not tvm.ir.structural_equal(x_t, b_t): + if not ( + tvm.ir.structural_equal(x_t.scale, b_t.scale) + and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point) + and tvm.ir.structural_equal(x_t.dtype, b_t.dtype) + ): b = relay.qnn.op.requantize( b, b_t.scale, @@ -102,6 +120,7 @@ def bias_add(expr, type_map): in_scale, in_zero_point, out_dtype=x_t.dtype, + axis=0, ) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -116,11 +135,13 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - conv_zp = relay.const(0) + conv_zp = get_zeros(conv_scale) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)] + out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"] + out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] + return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)] @register_fake_quantization_to_integer("nn.dense") @@ -132,11 +153,11 @@ def dense(expr, type_map): x_t = type_map[x] w_t = type_map[weight] dense_scale = fold_constant(x_t.scale * w_t.scale) - dense_zp = relay.const(0) + dense_zp = get_zeros(dense_scale) out = relay.qnn.op.dense( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)] @register_fake_quantization_to_integer("nn.batch_matmul") @@ -148,7 +169,7 @@ def batch_matmul(expr, type_map): matmul_scale = fold_constant(x_t.scale * y_t.scale) matmul_zp = relay.const(0) out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale) - return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)] @register_fake_quantization_to_integer("concatenate") @@ -198,19 +219,52 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) - if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant): + if ( + isinstance(scale, relay.expr.Constant) + and scale.data.numpy().size == 1 + and isinstance(z_p, relay.expr.Constant) + and z_p.data.numpy().size == 1 + ): scale = scale.data.numpy().item() z_p = z_p.data.numpy().item() new_min = int(amin / scale + z_p) new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.round(relay.op.const(amin) / scale + z_p) - amax = relay.op.round(relay.op.const(amax) / scale + z_p) - out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + if not isinstance(amin, relay.expr.Constant): + amin = relay.op.const(amin) + if not isinstance(amax, relay.expr.Constant): + amax = relay.op.const(amax) + + scale_shape = infer_shape(scale) + if len(scale_shape) > 0 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape) + amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape) + amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype) + amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype) + out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax)) + return [out, t] +@register_fake_quantization_to_integer("nn.relu") +def relu(expr, type_map): + """Rewrite a relu op""" + arg = expr.args[0] + t = type_map[arg] + scale_shape = infer_shape(t.scale) + z_p = t.zero_point + assert len(scale_shape) <= 1 + if len(scale_shape) == 1 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape) + zero = relay.op.cast(z_p, t.dtype) + return [relay.op.maximum(arg, fold_constant(zero)), t] + + @register_fake_quantization_to_integer("nn.pad") def pad(expr, type_map): """Rewite an nn.pad op""" @@ -231,6 +285,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, + axis=pad_t.axis, ) else: ## If the pad-value is a constant, we need to quantize it @@ -319,6 +374,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=left_t.axis, ) if right_t != out_t: @@ -329,6 +385,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=right_t.axis, ) out = op(left, right) return [out, out_t] diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc index 3454b6011c9b..87235fe20ade 100644 --- a/src/ir/affine_type.cc +++ b/src/ir/affine_type.cc @@ -30,26 +30,28 @@ namespace tvm { using tvm::ReprPrinter; using namespace tvm::runtime; -TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) { +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, + int axis) { ObjectPtr n = make_object(); n->scale = std::move(scale); n->zero_point = std::move(zero_point); n->dtype = std::move(dtype); + n->axis = std::move(axis); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); TVM_REGISTER_GLOBAL("ir.TensorAffineType") - .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) { - return TensorAffineType(scale, zero_point, dtype); + .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { + return TensorAffineType(scale, zero_point, dtype, axis); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " - << node->dtype << ")"; + << node->dtype << ", " << node->axis << ")"; }); TupleAffineType::TupleAffineType(Array types) { diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index cf5266485f2e..5782f1f6b4d1 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -495,7 +495,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, * \param input_zero_point The input zero point expr. * \param param The qnn conv2d attributes. * \param out_channels The number of output channels. - * \return The sequence of Relay operatos for term3. + * \return The sequence of Relay operators for term3. * \note The term3 looks like this * * Sigma(c,r,s) zp_a * QW(k, c, r, s) @@ -625,7 +625,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * \node Lowering of the qnn.conv2d operator * A quantized tensor is represented in following manner * A = scale_a x (QA - zp_A) - * where QA is quantized tensor, scale_a and zp_A are quantizations + * where QA is quantized tensor, scale_a and zp_A are quantization * params. * * Quantized convolution will convolve two quantized tensors and returns a @@ -662,8 +662,8 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * a workaround, we fall back to simpler lowering using int32 conv if * the conv is dilated. We fallback also in case of grouped conv. * - * For depthwise, we can similarly unroll the computation. The intial compute is as follows - * wehere cm = channel_multiplier + * For depthwise, we can similarly unroll the computation. The initial compute is as follows + * where cm = channel_multiplier * * Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w) * * (Qa(n, oc/cm, oh + r, ow + s) - zp_a) @@ -693,12 +693,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto* param = attrs.as(); ICHECK(param != nullptr); - // Assertion checks for exisiing support. + // Assertion checks for existing support. ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") << "qnn.conv2d supports only NCHW/NHWC input data layout."; ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" || param->kernel_layout == "HWOI") << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout."; + ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified."; int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 592fa77aed77..7b733d4777ec 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -61,8 +61,8 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, } } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale + // weight_zero_point can be a scalar or a vector of the same shape as the weight_scale AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; @@ -89,10 +89,17 @@ Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype); } -Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point) { +Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point, + const int out_dim_size) { Array axes = {1}; - return Multiply(kernel_zero_point, - Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false)); + Expr reduced_t2 = Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false); + Expr multiplied_t2; + if (!IsConstScalar(kernel_zero_point)) { + multiplied_t2 = Multiply(kernel_zero_point, MakeRepeat(reduced_t2, out_dim_size, 1)); + } else { + multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + } + return multiplied_t2; } Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& input_zero_point) { @@ -159,25 +166,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto in_shape = get_shape(arg_types[0]); + const auto w_shape = get_shape(arg_types[1]); const int reduction_dim_size = get_const_int(in_shape[1]); + const int out_dim_size = get_const_int(w_shape[0]); const auto* qnn_dense_attrs = attrs.as(); auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); - auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point); + auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point, out_dim_size); auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point); // Extract the integer zero points. - auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); - if (!IsConstScalar(input_zero_point)) { - if (kernel_zero_point_int == 0) { - return Subtract(term1, term3); - } + if (!IsConstScalar(input_zero_point) || !IsConstScalar(kernel_zero_point)) { auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size); return DenseCombineTerms(term1, term2, term3, term4); } + auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); auto input_zero_point_int = GetScalarFromConstant(input_zero_point); // Get all the terms as described in the comments. diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7af5c2ac1c33..c843eb3f544e 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -55,8 +55,15 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, int axis = dequantize_attrs->axis; auto rank = static_cast(data->shape.size()); axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 2f1d7d8da16c..b116eb9da034 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -53,8 +53,15 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, int axis = quantize_attrs->axis; auto rank = static_cast(data->shape.size()); axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 46de3522061b..a7d214761b9b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -136,10 +136,17 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { auto tensor = Cast(input_tensor, DataType::Int(32)); - // 1) Subtract the input_zero_point auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { - tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32))); + // Broadcast input zero point if needed. + int rank = static_cast(input_shape.size()); + int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis; + Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, + { + -1, + }), + rank, {axis}); + tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index b5f434e74c43..77d18d7556f2 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace tvm { @@ -109,18 +110,23 @@ class SubgraphExtractor : public ExprVisitor { protected: void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Only look at arg0 for quantize VisitExpr(call_node->args[0]); // Collect type of quantize ops - affine_types_.Set(GetRef(call_node), - TensorAffineType(call_node->args[1], call_node->args[2], - call_node->checked_type().as()->dtype)); + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); } else if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Collect type of dequantize ops affine_types_.Set( GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype)); + call_node->args[0]->checked_type().as()->dtype, + attrs->axis)); } else { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 2bc2e4e635f0..7ede17d07d99 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -34,7 +34,6 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): .evaluate()(*args) .numpy() ) - result_int = ( relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") .evaluate()(*args) @@ -42,7 +41,7 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): ) if allow_rounding_error: - assert np.all(np.abs(result - result_int) <= 1) + assert np.all(np.abs(result.astype("int32") - result_int.astype("int32")) <= 1) else: assert np.array_equal(result, result_int) @@ -57,6 +56,7 @@ def test_fake_quantize_conv(): op = relay.op.nn.conv2d( relay.qnn.op.dequantize(x, relay.const(2.0), zero), relay.qnn.op.dequantize(w, relay.const(0.5), zero), + kernel_size=[5, 5], ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) @@ -66,6 +66,29 @@ def test_fake_quantize_conv(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_conv_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const([1.0] * 16) + zero = relay.const([0] * 16) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0)), + relay.qnn.op.dequantize( + w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0 + ), + kernel_size=[5, 5], + channels=16, + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) + + def test_fake_quantize_dense(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[128, 64], dtype="int8") @@ -85,6 +108,31 @@ def test_fake_quantize_dense(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_dense_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize( + w, + relay.const(np.random.random([256]).astype("float32")), + relay.const([0] * 256), + axis=0, + ), + units=256, + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) + + def test_fake_quantize_batch_matmul(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 128, 64], dtype="int8") @@ -112,7 +160,9 @@ def test_fake_transpose_quantize_conv(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.qnn.op.quantize(op, one, zero) x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") @@ -130,7 +180,9 @@ def test_fake_transpose_quantize_conv_bias_add(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) @@ -141,6 +193,32 @@ def test_fake_transpose_quantize_conv_bias_add(): compare_fq_to_int(op, [x_np, w_np, bias_np]) +def test_fake_transpose_quantize_conv_bias_add_per_channel(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5 + w_zp = relay.const([0] * 16) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(w_scale), w_zp, axis=0), kernel_size=[5, 5] + ) + op = relay.op.nn.bias_add( + op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale), w_zp, axis=0) + ) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np], allow_rounding_error=True) + + def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") @@ -151,7 +229,9 @@ def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) op = relay.qnn.op.quantize(op, one, zero) @@ -318,6 +398,50 @@ def test_fake_quantize_clip(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_clip_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) + op = relay.op.clip(x, 0, 6) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_relu(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114)) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_relu_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py index 03944c85ade5..acfadc9d51ad 100644 --- a/tests/python/unittest/test_target_texture_codegen_opencl.py +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -514,7 +514,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -694,7 +694,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -879,7 +879,7 @@ def copy_to_texture(stage): def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( Input, Filter, stride, padding, dilation, out_dtype=None ): - """Depthwise convolution operator in NCHWc layout. """ + """Depthwise convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2