diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 4ceb3597fc91..a2a46497e197 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  DataType hp_dtype = DataType::Int(64);
-
-  auto tensor = Cast(input_tensor, hp_dtype);
+  auto tensor = Cast(input_tensor, DataType::Int(32));
 
   // 1) Subtract the input_zero_point
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
-    tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
+    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
   }
 
-  // Check if multiplier is greater than 1.
-  bool is_multiplier_gt_one = false;
-
   // 2) If the input and output scales are the same, we can skip the fixed point multiplication.
   // Check if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
   // scale for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for
   // the input tensor. Depending on the quantization type, the fixed point multiplication routine is called.
-  auto scaled_int64_t = tensor;
+  auto scaled_int32_t = tensor;
   float output_scale_float = GetScalarFromConstant<float>(output_scale);
   if (IsConstScalar(input_scale)) {
     // This is per-tensor quantization. Single scale.
     float input_scale_float = GetScalarFromConstant<float>(input_scale);
     double double_multiplier =
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
-    if (double_multiplier > 1) {
-      is_multiplier_gt_one = true;
-    }
     // Skip if input and output scales are the same.
     if (!IsEqualScalar(input_scale, output_scale)) {
-      scaled_int64_t =
-          FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+      scaled_int32_t =
+          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
     }
   } else {
     // This is per-channel (per-axis) quantization.
@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
       double multiplier =
           static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
       double_multipliers.push_back(multiplier);
-      if (multiplier > 1) {
-        is_multiplier_gt_one = true;
-      }
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
-    scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
+    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                   axis, param->rounding);
   }
 
   // 3) Add the output zero point.
-  auto shifted_int64_t = scaled_int64_t;
+  auto shifted_int32_t = scaled_int32_t;
   if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
+    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
   }
 
-  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
-  // multiplication keeps the value in int32 range if the requantize scale is less than 1.
-  if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) {
-    return Cast(shifted_int64_t, out_dtype);
+  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
+  // multiplication keeps the value in int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_int32_t;
   }
+
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
-  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
   return Cast(clipped_t, out_dtype);
 }
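For reference, the arithmetic this lowering emits corresponds to the scalar sketch below. It is an illustration, not TVM code: `RequantizeScalar` and `QuantizeMultiplier` are hypothetical names, the rounding is simplified to round-half-up (roughly the "UPWARD" mode selected by `param->rounding`), and int8 stands in for an arbitrary `out_dtype`. The point to notice is that the value is int32 both before and after the fixed point multiply; the transient int64 widening happens only inside it, which is exactly what this patch moves into `FixedPointMultiply`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a positive double into a Q31 significand and a shift so that
// multiplier ~= fp_multiplier * 2^(shift - 31).
void QuantizeMultiplier(double multiplier, std::int32_t* fp_multiplier, int* shift) {
  double significand = std::frexp(multiplier, shift);  // significand in [0.5, 1)
  auto q = static_cast<std::int64_t>(std::round(significand * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the significand up to 1.0
    q /= 2;
    ++*shift;
  }
  *fp_multiplier = static_cast<std::int32_t>(q);
}

// Requantize one int32 value from (input_scale, input_zp) to
// (output_scale, output_zp), clipping to int8.
std::int8_t RequantizeScalar(std::int32_t value, double input_scale, std::int32_t input_zp,
                             double output_scale, std::int32_t output_zp) {
  std::int32_t fp_multiplier;
  int shift;
  QuantizeMultiplier(input_scale / output_scale, &fp_multiplier, &shift);

  // 1) Subtract the input zero point (still in int32 range).
  std::int64_t scaled = static_cast<std::int64_t>(value) - input_zp;

  // 2) Fixed point multiply: widen to int64 for the product, then shift back
  //    down with round-half-up; the result fits in int32 again.
  scaled *= fp_multiplier;
  int total_right_shift = 31 - shift;
  std::int64_t rounder = total_right_shift > 0 ? (1LL << (total_right_shift - 1)) : 0;
  scaled = (scaled + rounder) >> total_right_shift;

  // 3) Add the output zero point; 4) clip to the out_dtype min/max.
  std::int64_t shifted = scaled + output_zp;
  shifted = std::min<std::int64_t>(std::max<std::int64_t>(shifted, -128), 127);
  return static_cast<std::int8_t>(shifted);
}

int main() {
  // (200 - 128) * (0.5 / 0.25) = 144, then clipped to the int8 max of 127.
  std::printf("%d\n", RequantizeScalar(200, 0.5, 128, 0.25, 0));
}
```

Steps 1) through 4) in the comments mirror the numbered steps of `RequantizeLower` above; the early `return shifted_int32_t` for an int32 `out_dtype` corresponds to skipping the clip at step 4.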
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
index 648de5349ce1..91fe3ca2a948 100644
--- a/src/relay/qnn/util.cc
+++ b/src/relay/qnn/util.cc
@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift
   int32_t fixed_point_multiplier, shift;
@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
 
   tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
   // channel.
@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
   tensor = RightShift(tensor, exp_total_rshift_expr);
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 }  // namespace qnn
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 8e04a99e2813..6d56e19d229c 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = qnn::FixedPointMultiply(
-        Cast(data, DataType::Int(64)), factor, data_shape, cfg->rounding);
+    data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
    return Cast(data, dtype);
   }
 }
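To see why the widening cast now lives inside `FixedPointMultiply` rather than at every call site, consider the standalone demonstration below. The numbers are made up for illustration: multiplying an int32 value by a Q31 fixed point significand needs up to about 62 bits transiently, but because the significand is below 1.0, the rounding right shift by 31 brings the result back into int32 range, so callers such as `RequantizeLower` and `MulAndDiv` can pass and receive plain int32 tensors.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  std::int32_t x = 2000000000;    // near INT32_MAX
  std::int32_t q31 = 1518500250;  // ~0.7071 expressed in Q31 fixed point

  // A 32-bit product would overflow (undefined behavior in C++):
  //   std::int32_t bad = x * q31;

  // Widen to 64 bits for the product; the rounding right shift by 31
  // then brings the value back into int32 range.
  std::int64_t wide = static_cast<std::int64_t>(x) * q31;  // needs ~62 bits
  auto result = static_cast<std::int32_t>((wide + (1LL << 30)) >> 31);
  std::printf("%d\n", result);  // 1414213562 ~= 2000000000 * 0.7071
}
```

This is the contract the new `// 6)` comments document: int64 is an internal detail of the multiply, and the cast on return restores the int32 datatype that the surrounding graph works in.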