Skip to content

Commit

Permalink
[QNN] Quantize - Fixing the sequence of lowering. (apache#4316)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and Xingyu Zhou committed Nov 15, 2019
1 parent 9f2a7d0 commit 9ab329e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/relay/qnn/op/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ bool QuantizeRel(const Array<Type>& types,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8))
<< "Output type should be one of [int8, uint8] but was " << out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true;
Expand All @@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data,
Expr QuantizeLower(const Expr& input_tensor,
const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype;
const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point);
const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32));
auto add_zero_point = Add(scale_data, output_zero_point);
auto scale_data = Divide(input_tensor, scale);
auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype;
Expand Down

0 comments on commit 9ab329e

Please sign in to comment.