diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 448395a16b3a..cf5f316a1c97 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -129,48 +129,55 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, tensor = Subtract(tensor, input_zp); } - // 3) Multiply the integer multiplier - if (left_shift != 0) { - tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift)); - } - // Perform the multiplication in higher precision. - // The scalar is a fixed point value of int32 where the decimal point is - // between bits 31 and 30. After multiplying with input_tensor, the result is - // in int64 where the decimal point is sitting between bits 31 and 30 (from - // the right, rightmost bit is bit 0). The computation is performed in higher - // precision to avoid overflow in multiplying two int32 values. - Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier); - auto multiplied_t = Multiply(tensor, scalar); + // If the input and output scales are same, we can skip the fixed point multiplication. + auto scaled_int64_t = tensor; + if (param->input_scale != param->output_scale) { + // 3) Multiply the integer multiplier + if (left_shift != 0) { + tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift)); + } + // Perform the multiplication in higher precision. + // The scalar is a fixed point value of int32 where the decimal point is + // between bits 31 and 30. After multiplying with input_tensor, the result is + // in int64 where the decimal point is sitting between bits 31 and 30 (from + // the right, rightmost bit is bit 0). The computation is performed in higher + // precision to avoid overflow in multiplying two int32 values. + Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier); + auto multiplied_t = Multiply(tensor, scalar); - // 4) Find the rounding scalar. This depends on where the final decimal point - // sits. As we will be right shifting the multiplied_t, we need to first - // calculate the total_right_shift. - int total_right_shift = right_shift + 31; - int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); + // 4) Find the rounding scalar. This depends on where the final decimal point + // sits. As we will be right shifting the multiplied_t, we need to first + // calculate the total_right_shift. + int total_right_shift = right_shift + 31; + int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); - tensor = multiplied_t; - Expr round_scalar; - if (param->rounding == "UPWARD") { - round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); - } else if (param->rounding == "TONEAREST") { - auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); - auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); - auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); - auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); + tensor = multiplied_t; + Expr round_scalar; + if (param->rounding == "UPWARD") { + round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); + } else if (param->rounding == "TONEAREST") { + auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); + auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); + auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); + auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); - auto zero = MakeConstantScalar(hp_dtype, 0); - auto zero_t = Full(zero, input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); - } - // Add the rounding scalar. - tensor = Add(tensor, round_scalar); + auto zero = MakeConstantScalar(hp_dtype, 0); + auto zero_t = Full(zero, input_shape, hp_dtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); - // 5) Simply right shift the result to get the final output. - auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + // 5) Simply right shift the result to get the final output. + scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + } // 6) Add the output zero point. - auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); - auto shifted_int64_t = Add(output_zp, scaled_int64_t); + auto shifted_int64_t = scaled_int64_t; + if (param->output_zero_point != 0) { + auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point); + shifted_int64_t = Add(output_zp, scaled_int64_t); + } // 7) Clip to the out_dtype min/max. auto q_min = GetQmin(out_dtype); diff --git a/tests/python/relay/test_qnn_requantize.py b/tests/python/relay/test_qnn_requantize.py index 2afa7d9e7ea6..131500c094bf 100644 --- a/tests/python/relay/test_qnn_requantize.py +++ b/tests/python/relay/test_qnn_requantize.py @@ -64,6 +64,7 @@ def same_scale_test(): input_scale=0.5, output_scale=0.5, rounding=rounding) + assert 'right_shift' not in mod.astext() verify(mod, (golden_data, golden_output)) def downscale_test():