[Relay][Transform] quantize opt passes to pass manager (#3289)
zhiics authored and tqchen committed Jun 13, 2019
1 parent 579e96d commit 6e2c7ed
Showing 3 changed files with 107 additions and 130 deletions.
173 changes: 68 additions & 105 deletions python/tvm/relay/quantize/quantize.py
@@ -21,7 +21,9 @@

from . import _quantize
from .. import expr as _expr
from .. import module as _module
from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER = n


def annotate(graph):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)


def calibrate(graph, dataset=None):
def calibrate(graph, mod=None, ctx=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
The calibration dataset.
mod: tvm.relay.Module
The module on which calibration is performed.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns
-------
@@ -253,93 +239,52 @@ def _make_const(val):
return _expr.bind(graph, const_params)


def realize(graph):
"""The realize pass will transform the simulated quantized
graph, which computes with float32 actually, to a real low-bit
integer graph. It will replace the simulated_quantize with
several fine-grained operators like add, multiply, and shift
as much as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: Function
The graph after realization
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.realize(graph)
return _quantize.QuantizeAnnotate()


def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
def realize():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
ret: tvm.relay.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()

opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]

if params:
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)

if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)

if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)

if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)

if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)

if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)

return func
def _bind_params(func, params):
"""Bind the params to the expression.
"""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
return _expr.bind(func, bind_dict)


def quantize(graph, params=None, dataset=None):
@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
# TODO(zhiics) Move this to the pass manager.
graph = optimize(graph, params)

graph = annotate(graph)
graph = calibrate(graph, dataset)
graph = realize(graph)
graph = _ir_pass.fold_constant(graph)
return graph
if params:
graph = _bind_params(graph, params)

mod = _module.Module.from_expr(graph)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
# "CanonicalizeOps" optimization before quantization.
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])

calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
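
Taken together, the quantize.py changes above replace the hand-rolled optimize/annotate/calibrate/realize calls with pass objects driven by the pass manager. The following is a minimal usage sketch of the post-commit API, assuming annotate(), realize(), quantize(), and qconfig() are exported from tvm.relay.quantize as defined in this file; the toy conv2d workload, its shapes, and the variable names are illustrative only.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz

# A tiny float32 workload; any Relay Function plus its parameters would do.
data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 16, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
func = relay.Function([data, weight], relay.nn.relu(conv))
params = {"weight": tvm.nd.array(
    np.random.uniform(-1, 1, size=(16, 16, 3, 3)).astype("float32"))}

# quantize() now binds the params, runs the SimplifyInference / FoldScaleAxis /
# FoldConstant / CanonicalizeOps Sequential, and then applies
# annotate -> calibrate -> realize under a single PassContext.
with qtz.qconfig():
    qfunc = qtz.quantize(func, params=params)
print(qfunc)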
1 change: 1 addition & 0 deletions src/relay/pass/pass_manager.cc
@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

Module updated_mod = mod;
// Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
63 changes: 38 additions & 25 deletions src/relay/pass/quantize.cc
@@ -43,6 +43,8 @@ namespace tvm {
namespace relay {
namespace quantize {

using namespace relay::transform;

/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int()));
});


TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});


// =============
// realize pass

@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);


TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});


// =============
// qconfig

@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);

Pass QuantizeAnnotate() {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f =
runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};

runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);

Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);

} // namespace quantize
} // namespace relay
} // namespace tvm
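
With QuantizeAnnotate and QuantizeRealize now built through CreateFunctionPass and exposed over the relay._quantize FFI boundary, the Python side can compose them with any stock relay.transform pass. Below is a short sketch of that composition, assuming annotate() and realize() are exported from tvm.relay.quantize as defined in quantize.py above. The no-op calibration stand-in is hypothetical; it only shows how function_pass lifts a plain Python function into a pass, and a real flow needs the calibrate() from quantize.py to bind the dom_scale/clip constants before realize can lower the graph.

from tvm import relay
from tvm.relay import quantize as qtz

# Hypothetical no-op stand-in, wrapped the same way quantize.py wraps
# calibrate(): function_pass turns a (func, mod, ctx) callable into a pass.
def fake_calibrate(func, mod=None, ctx=None):
    return func

calibrate_pass = relay.transform.function_pass(
    fake_calibrate, opt_level=1, name="QuantizeCalibrate")

# Compose the C++-registered passes with ordinary transform passes.
quantize_seq = relay.transform.Sequential([
    qtz.annotate(),      # backed by relay._quantize.QuantizeAnnotate
    calibrate_pass,
    qtz.realize(),       # backed by relay._quantize.QuantizeRealize
    relay.transform.FoldConstant(),
])

def apply_quantize_passes(mod):
    # required_pass keeps the opt_level=1 quantize passes from being skipped
    # when a lower overall opt_level is configured.
    with relay.transform.PassContext(opt_level=3,
                                     required_pass=["QuantizeAnnotate",
                                                    "QuantizeCalibrate",
                                                    "QuantizeRealize"]):
        return quantize_seq(mod)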
