[Requantize] Cleanup and Optimize Lowering (apache#5286)
* Adding Cast back to Int32 in FixedPointMultiply.

* Removing extra clip.

* Fix space.

* Retrigger.

* Retrigger.
anijain2305 authored and masahi committed Apr 12, 2020
1 parent 11f2826 commit 65c0db9
Showing 3 changed files with 20 additions and 27 deletions.
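In short: RequantizeLower now stays in int32 end to end, while the int64 widening needed to safely multiply two 32-bit values is confined to the FixedPointMultiply helpers, which cast in and out themselves. As a reference for the numbered steps annotated in the diffs below, here is a minimal scalar model of what qnn.requantize computes — hypothetical standalone code, not TVM's API, with a plain float multiply standing in for the fixed point routine:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of qnn.requantize for an int8 output:
// Q_out = clip(round((S_in / S_out) * (Q_in - zp_in)) + zp_out, -128, 127)
int8_t RequantizeScalar(int32_t q_in, float s_in, int32_t zp_in, float s_out, int32_t zp_out) {
  double multiplier = static_cast<double>(s_in) / static_cast<double>(s_out);
  int64_t centered = static_cast<int64_t>(q_in) - zp_in;  // 1) subtract input zero point
  int64_t scaled = std::llround(centered * multiplier);   // 2) fixed point multiply (float stand-in)
  int64_t shifted = scaled + zp_out;                      // 3) add output zero point
  return static_cast<int8_t>(                             // 4) clip to the out_dtype range
      std::min<int64_t>(std::max<int64_t>(shifted, -128), 127));
}
```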
36 changes: 13 additions & 23 deletions src/relay/qnn/op/requantize.cc
@@ -132,36 +132,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
-  DataType hp_dtype = DataType::Int(64);
-
-  auto tensor = Cast(input_tensor, hp_dtype);
+  auto tensor = Cast(input_tensor, DataType::Int(32));
   // 1) Subtract the input_zero_point
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
-    tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
+    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
   }
 
-  // Check if multiplier is greater than 1.
-  bool is_multiplier_gt_one = false;
-
   // 2) If the input and output scales are the same, we can skip the fixed point multiplication. Check
   //    if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single scale
   //    for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the
   //    input tensor. Depending on the quantization type, the fixed point multiplication routine is called.
-  auto scaled_int64_t = tensor;
+  auto scaled_int32_t = tensor;
   float output_scale_float = GetScalarFromConstant<float>(output_scale);
   if (IsConstScalar(input_scale)) {
     // This is per-tensor quantization. Single scale.
     float input_scale_float = GetScalarFromConstant<float>(input_scale);
     double double_multiplier =
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
-    if (double_multiplier > 1) {
-      is_multiplier_gt_one = true;
-    }
     // Skip if input and output scales are the same.
     if (!IsEqualScalar(input_scale, output_scale)) {
-      scaled_int64_t =
-          FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
+      scaled_int32_t =
+          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
     }
   } else {
     // This is per-channel (per-axis) quantization.
@@ -171,30 +163,28 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
       double multiplier =
           static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
       double_multipliers.push_back(multiplier);
-      if (multiplier > 1) {
-        is_multiplier_gt_one = true;
-      }
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
-    scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
+    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                   axis, param->rounding);
   }
 
   // 3) Add the output zero point.
-  auto shifted_int64_t = scaled_int64_t;
+  auto shifted_int32_t = scaled_int32_t;
   if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    shifted_int64_t = Add(Cast(output_zero_point, hp_dtype), scaled_int64_t);
+    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
   }
 
   // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
-  // multiplication keeps the value in int32 range if the requantize scale is less than 1.
-  if (out_dtype == DataType::Int(32) && !is_multiplier_gt_one) {
-    return Cast(shifted_int64_t, out_dtype);
+  // multiplication keeps the value in int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_int32_t;
   }
 
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
-  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
   return Cast(clipped_t, out_dtype);
 }
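To make the four numbered steps concrete, here is a worked trace with made-up quantization parameters (s_in = 0.5, zp_in = 10, s_out = 0.25, zp_out = 5, int8 output), following the scalar model sketched at the top of this page:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  int32_t q_in = 42;
  double multiplier = 0.5 / 0.25;                        // s_in / s_out = 2.0
  int64_t centered = q_in - 10;                          // 1) 42 - zp_in = 32
  int64_t scaled = std::llround(centered * multiplier);  // 2) 32 * 2.0 = 64
  int64_t shifted = scaled + 5;                          // 3) 64 + zp_out = 69
  assert(shifted >= -128 && shifted <= 127);             // 4) already in int8 range; clip is a no-op
  int8_t q_out = static_cast<int8_t>(shifted);
  assert(q_out == 69);
  return 0;
}
```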

8 changes: 6 additions & 2 deletions src/relay/qnn/util.cc
@@ -80,6 +80,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift
   int32_t fixed_point_multiplier, shift;
@@ -130,7 +131,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   tensor =
       RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
@@ -145,6 +147,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
+  tensor = Cast(tensor, hp_dtype);
 
   // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
   //    channel.
@@ -218,7 +221,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
   tensor = RightShift(tensor, exp_total_rshift_expr);
 
-  return tensor;
+  // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return Cast(tensor, DataType::Int(32));
 }
 
 }  // namespace qnn
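For context on the new step 6: these helpers split the double multiplier into a Q0.31 fixed-point mantissa and a power-of-two shift, take the product in int64, and round-shift back down, which is why the result can be narrowed to int32 at the end. Below is a minimal sketch of that scheme — hypothetical standalone code, not TVM's actual helper — assuming a positive multiplier small enough that the final right shift is at least one bit, and modeling only a round-half-up rounding mode:

```cpp
#include <cmath>
#include <cstdint>

// Split a positive `multiplier` into fixed_point_multiplier * 2^(shift - 31),
// with the mantissa stored as a Q0.31 integer.
void DecomposeMultiplier(double multiplier, int32_t* fixed_point_multiplier, int32_t* shift) {
  int exponent = 0;
  double significand = std::frexp(multiplier, &exponent);  // significand in [0.5, 1)
  int64_t q = static_cast<int64_t>(std::round(significand * (INT64_C(1) << 31)));
  if (q == (INT64_C(1) << 31)) {  // rounding can push the mantissa up to exactly 2^31
    q /= 2;
    ++exponent;
  }
  *fixed_point_multiplier = static_cast<int32_t>(q);
  *shift = exponent;
}

// Multiply an int32 value by a positive double multiplier in fixed point.
// The 31-bit mantissa times a 32-bit value needs up to 63 bits, hence int64;
// after the rounding right shift the result fits back into int32 for
// requantize-style multipliers.
int32_t FixedPointMultiplyScalar(int32_t value, double multiplier) {
  int32_t mantissa = 0, shift = 0;
  DecomposeMultiplier(multiplier, &mantissa, &shift);
  int64_t product = static_cast<int64_t>(value) * mantissa;
  int32_t total_right_shift = 31 - shift;                     // assumed >= 1 here
  int64_t round_add = INT64_C(1) << (total_right_shift - 1);  // round half up
  return static_cast<int32_t>((product + round_add) >> total_right_shift);
}
```

For instance, with multiplier = 2.0 the decomposition yields mantissa 2^30 and shift 2, and FixedPointMultiplyScalar(32, 2.0) returns 64.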
3 changes: 1 addition & 2 deletions src/relay/quantize/realize.cc
@@ -117,8 +117,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = qnn::FixedPointMultiply(
-        Cast(data, DataType::Int(64)), factor, data_shape, cfg->rounding);
+    data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
     return Cast(data, dtype);
   }
 }
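MulAndDiv scales data by s1/s2 during quantize realization: it returns the data untouched when the scales match, uses a plain integer multiply when the factor is exactly integral, and otherwise falls back to the fixed point path, which after this change performs the int64 widening itself. A scalar model of those three branches, reusing the hypothetical FixedPointMultiplyScalar sketched above (names are illustrative, not TVM's API):

```cpp
// Scalar model of the branches in realize.cc's MulAndDiv.
int32_t MulAndDivScalar(int32_t data, float s1, float s2) {
  float factor = s1 / s2;
  if (factor == 1.0f) {
    return data;                                    // scales already match
  } else if (static_cast<float>(static_cast<int>(factor)) == factor) {
    return data * static_cast<int>(factor);         // exact integer factor: plain multiply
  } else {
    return FixedPointMultiplyScalar(data, factor);  // general case: fixed point multiply
  }
}
```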