diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 90bb2d08a8ed..7b7f9c42f2f1 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -260,13 +260,11 @@ def add_rewrite(ref_call, new_args, ctx): if isinstance(rhs_expr, _expr.Constant): # quantize rhs to WEIGHT field if it is Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) - assert lhs_kind == QAnnotateKind.ACTIVATION - expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) else: # quantize rhs to INPUT field if it is not Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) - raise ValueError + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) if lhs_kind is not None and rhs_kind is not None: if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: @@ -277,6 +275,10 @@ def add_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT: + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + raise ValueError() @register_annotate_function("stop_fusion") diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 8220ca6b3bab..83d9220ccf79 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -135,22 +135,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") }); -TVM_REGISTER_API("relay._quantize.annotate") -.set_body_typed([] (const Expr& expr) { - std::function fmulti_ref = [](const Expr& e) { - if (e->derived_from()) { - const auto* n = e.as(); - CHECK(n); - const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); - } - return e; - }; - return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr); -}); - - // ============= // realize pass @@ -395,10 +379,9 @@ float ChooseDomScale(const std::vector& nptrs) { /* \brief Unify the dom scale of arguments */ -Array UnifyDTypeScale(const Array& ref_args, - const Array& args, - DataType* dtype_ptr, - Expr* scale_ptr) { +Array UnifyDTypeScale(const Array& ref_args, const Array& args, + DataType* dtype_ptr, Expr* scale_ptr) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); const QConfig& cfg = QConfig::Current(); std::vector nptrs; @@ -413,14 +396,21 @@ Array UnifyDTypeScale(const Array& ref_args, // unify the data type CHECK_EQ(ref_args.size(), args.size()); DataType dtype; - if (nptrs[0]->dtype == cfg->dtype_activation) { - DataType dtype = cfg->dtype_activation; - ret.Set(1, Cast(ret[1], dtype)); - } else if (nptrs[1]->dtype == cfg->dtype_input) { - DataType dtype = cfg->dtype_input; - ret.Set(0, Cast(ret[0], dtype)); + if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { + dtype = cfg->dtype_input; } else { - LOG(FATAL) << "should not touch here."; + dtype = cfg->dtype_activation; + } + for (size_t i = 0; i < ret.size(); ++i) { + auto ref_arg = ref_args[i].as(); + if (nptrs[i]->dtype != dtype) { + ret.Set(i, Cast(ret[i], dtype)); + } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && + ref_arg->attrs.as()->kind == kQInput) { + auto new_arg = Cast(ret[i], cfg->dtype_input); + new_arg = StopFusion(new_arg); + ret.Set(i, Cast(new_arg, dtype)); + } } // unify the dom_scale @@ -447,6 +437,7 @@ Expr AddRealize(const Call& ref_call, Expr ret = ForwardOp(ref_call, ret_args); return QRealizeIntExprNode::make(ret, dom_scale, dtype); } + CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); return Expr(nullptr); } @@ -674,7 +665,7 @@ Pass QuantizeAnnotate() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref)); + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); auto new_params = func->params; for (const auto& x : FreeVars(func)) { new_params.push_back(x);