diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 1f7dbc1b6bb6..6df9b433560a 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -48,8 +48,8 @@ bool QuantizeRel(const Array& types, const auto* quantize_attrs = attrs.as(); const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; - CHECK(out_dtype == Int(8) || out_dtype == UInt(8)) - << "Output type should be one of [int8, unit8 ] but was " << out_dtype; + CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32)) + << "Output type should be one of [int8, uint8, int32] but was " << out_dtype; // assign output type reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype)); return true; @@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data, Expr QuantizeLower(const Expr& input_tensor, const QuantizeAttrs* attrs) { const auto out_dtype = attrs->out_dtype; - const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point); + const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point); const auto scale = MakeConstantScalar(Float(32), attrs->output_scale); const int32_t min_val = GetQmin(out_dtype); const int32_t max_val = GetQmax(out_dtype); - auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32)); - auto add_zero_point = Add(scale_data, output_zero_point); + auto scale_data = Divide(input_tensor, scale); + auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32)); auto clamped_output = Clip(add_zero_point, min_val, max_val); auto clamp_out_dtype = Cast(clamped_output, out_dtype); return clamp_out_dtype;