From 9984f6516ba78d066892c7e9988d541f040992ca Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 12 Nov 2019 19:59:34 +0000 Subject: [PATCH] [QNN] Quantize - Fixing the sequence of lowering. --- src/relay/qnn/op/quantize.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 1f7dbc1b6bb6..6df9b433560a 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -48,8 +48,8 @@ bool QuantizeRel(const Array& types, const auto* quantize_attrs = attrs.as(); const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; - CHECK(out_dtype == Int(8) || out_dtype == UInt(8)) - << "Output type should be one of [int8, unit8 ] but was " << out_dtype; + CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32)) + << "Output type should be one of [int8, unit8, int32] but was " << out_dtype; // assign output type reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype)); return true; @@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data, Expr QuantizeLower(const Expr& input_tensor, const QuantizeAttrs* attrs) { const auto out_dtype = attrs->out_dtype; - const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point); + const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point); const auto scale = MakeConstantScalar(Float(32), attrs->output_scale); const int32_t min_val = GetQmin(out_dtype); const int32_t max_val = GetQmax(out_dtype); - auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32)); - auto add_zero_point = Add(scale_data, output_zero_point); + auto scale_data = Divide(input_tensor, scale); + auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32)); auto clamped_output = Clip(add_zero_point, min_val, max_val); auto clamp_out_dtype = Cast(clamped_output, out_dtype); return clamp_out_dtype;