Skip to content

Commit

Permalink
[Relay][Quantization] Fix add_rewrite and UnifyDTypeScale (apache#3534)
Browse files Browse the repository at this point in the history
* [Relay][Quantization] Fix issue introduced in apache#3135

* Recover StopFusion

* Fix fmultiref

* Fix lint
  • Loading branch information
vinx13 authored and wweic committed Aug 9, 2019
1 parent e5939c0 commit aeea3a4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 32 deletions.
10 changes: 6 additions & 4 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,11 @@ def add_rewrite(ref_call, new_args, ctx):
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert lhs_kind == QAnnotateKind.ACTIVATION
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
raise ValueError
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
Expand All @@ -277,6 +275,10 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()


@register_annotate_function("stop_fusion")
Expand Down
47 changes: 19 additions & 28 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
});


TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
});


// =============
// realize pass

Expand Down Expand Up @@ -395,10 +379,9 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {


/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
DataType* dtype_ptr, Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();

std::vector<const QRealizeIntExprNode*> nptrs;
Expand All @@ -413,14 +396,21 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype;
if (nptrs[0]->dtype == cfg->dtype_activation) {
DataType dtype = cfg->dtype_activation;
ret.Set(1, Cast(ret[1], dtype));
} else if (nptrs[1]->dtype == cfg->dtype_input) {
DataType dtype = cfg->dtype_input;
ret.Set(0, Cast(ret[0], dtype));
if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
} else {
LOG(FATAL) << "should not touch here.";
dtype = cfg->dtype_activation;
}
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
new_arg = StopFusion(new_arg);
ret.Set(i, Cast(new_arg, dtype));
}
}

// unify the dom_scale
Expand All @@ -447,6 +437,7 @@ Expr AddRealize(const Call& ref_call,
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}

CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
Expand Down Expand Up @@ -674,7 +665,7 @@ Pass QuantizeAnnotate() {

runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
Expand Down

0 comments on commit aeea3a4

Please sign in to comment.