diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 452744bec8b3b..ff6c8ea5c1875 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -198,14 +198,20 @@ Expr QuantizeRealize(const Call& ref_call, // x * idom_scale = y * odom_scale // => y = x * idom_scale / odom_scale if (const auto* n = new_args[0].as()) { + // int32->int8 Expr data = n->data; float idom_scale_imm = GetScalarFromConstant(n->dom_scale); float odom_scale_imm = GetScalarFromConstant(dom_scale); + if (idom_scale_imm == odom_scale_imm) { + // same domain scale, only clip + data = Clip(data, clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + } + float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - // int32->int8 CHECK_GT(shift_nbit, 0); if (static_cast(shift_nbit) == shift_nbit) { - // use shift + // use right shift if (cfg->round_for_shift) { float round_bias = std::pow(2.0, shift_nbit - 1); data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias)));