[QUANTIZE] Memorizing the quantize node mapping #3233

Merged · 9 commits · Jun 22, 2019
Changes from all commits

5 changes: 2 additions & 3 deletions python/tvm/relay/backend/_backend.py
@@ -17,7 +17,6 @@
 """The interface of expr function exposed from C++."""
 from __future__ import absolute_import
 
-import logging
 from ... import build_module as _build
 from ... import container as _container
 from ..._ffi.function import _init_api, register_func
@@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func):
     # pylint: disable=broad-except
     try:
         f = _build.lower(sch, inputs, name=func_name)
-        logging.debug("lower function %s", func_name)
-        logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
+        # logging.debug("lower function %s", func_name)
+        # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
     except Exception:
         msg = traceback.format_exc()
         msg += "Error during compile function\n"
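
Review note: besides dropping the now-unused import, commenting these out removes real overhead: the second logging.debug call re-lowers the schedule just to format a message, and that work ran even with DEBUG disabled. If the output is still wanted, a guarded variant keeps it without the cost (a sketch only; lower_and_log and the import path are assumptions, not part of this PR):

    import logging

    from tvm import build_module as _build  # assumed import path

    logger = logging.getLogger("relay.backend")

    def lower_and_log(sch, inputs, func_name):
        f = _build.lower(sch, inputs, name=func_name)
        # Pay for the second lowering only when DEBUG is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("lower function %s", func_name)
            logger.debug("%s", _build.lower(sch, inputs, simple_mode=True))
        return f
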
96 changes: 47 additions & 49 deletions python/tvm/relay/quantize/_annotate.py
@@ -22,7 +22,7 @@
 import topi
 from . import _quantize
 from .quantize import QAnnotateKind, current_qconfig
-from .quantize import _conv_counter, _set_conv_counter
+from .quantize import annotate_context
 from .. import expr as _expr
 from .. import op as _op
 from ..op import op as _reg
@@ -116,7 +116,6 @@ def frewrite_with_guard(ref_call, new_args, ctx):
     return _register(frewrite) if frewrite is not None else _register
 
 
-@register_func("relay.quantize.attach_simulated_quantize")
 def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
     """Attach a simulated quantize operation after input data expr.
 
@@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
     if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
         return data
 
+    actx = annotate_context()
+    key = tuple([data, kind, sign, rounding])
+    if key in actx.qnode_map:
+        return actx.qnode_map[key]
+
     dom_scale = _expr.var("dom_scale")
     clip_min = _expr.var("clip_min")
     clip_max = _expr.var("clip_max")
-    return _quantize.simulated_quantize(
+    qnode = _quantize.simulated_quantize(
         data, dom_scale, clip_min, clip_max, kind, sign, rounding)
+    actx.qnode_map[key] = qnode
+    return qnode
+
+register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
 
 
 @register_annotate_function("nn.contrib_conv2d_NCHWc")
@@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for conv2d. Lhs of conv will be quantized to
     input field, and rhs of conv will be quantized to weight field.
     Output would be in activation field"""
-    cnt = _conv_counter()
-    if cnt < current_qconfig().skip_k_conv:
-        _set_conv_counter(cnt + 1)
-        return None
-
+    actx = annotate_context()
     if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt in leave_alone_indices:
-            _set_conv_counter(cnt + 1)
+        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if actx.conv2d_counter() in skipped_indices:
+            actx.count_conv2d()
             return None
-
-    _set_conv_counter(cnt + 1)
+    actx.count_conv2d()
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+def check_to_skip():
+    """Check the index of conv2d layer to decide whether to skip the current operator."""
+    if current_qconfig().skip_conv_layers is not None:
+        skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if annotate_context().conv2d_counter() - 1 in skipped_indices:
+            return True
+    return False
+
+
 @register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
     dense will be quantized to weight field. Output would be in activation field."""
-    cnt = _conv_counter()
-    if cnt < current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
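
Review note: the `- 1` in check_to_skip is worth a comment in the code itself: conv2d_rewrite increments the counter as soon as it meets a conv2d, so when the ops downstream of that conv are annotated, conv2d_counter() is already one past the layer they belong to. A runnable walk-through of the indexing, assuming conv2d is always rewritten before the ops that consume it (helper names hypothetical):

    skip = [0]   # mirrors skip_conv_layers=[0]
    counter = 0

    def on_conv2d():
        global counter
        skipped = counter in skip
        counter += 1        # bump immediately, as conv2d_rewrite does
        return skipped

    def on_follow_on_op():
        return (counter - 1) in skip  # index of the conv layer we sit in

    print(on_conv2d())        # True:  conv layer 0 is skipped
    print(on_follow_on_op())  # True:  its relu/add/... stay float too
    print(on_conv2d())        # False: conv layer 1 is quantized
    print(on_follow_on_op())  # False
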
@@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+@register_annotate_function("stop_fusion")
+def stop_fusion_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for stop_fusion."""
+    if check_to_skip():
+        return None
+
+    x_expr, x_kind = _get_expr_kind(new_args[0])
+    if x_kind is None:
+        return None
+
+    ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
+    ret_expr = _forward_op(ref_call, [ret_expr])
+    return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
+
+
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     x_expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
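
Review note: stop_fusion now takes part in annotation: its input is re-quantized to the INPUT field, so the value crossing the fusion barrier is already expressed in the low-bit domain. A toy of the rewritten shape, with tuples standing in for Relay calls (names hypothetical):

    def rewrite_stop_fusion(x_expr):
        # Insert a simulated INPUT quantize in front of the barrier.
        sq = ("simulated_quantize", x_expr, "INPUT")
        return ("stop_fusion", sq)

    print(rewrite_stop_fusion(("conv2d", "x", "w")))
    # ('stop_fusion', ('simulated_quantize', ('conv2d', 'x', 'w'), 'INPUT'))
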
@@ -283,20 +290,16 @@ def identity_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(ret_expr, x_kind)
 
 
+register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
 register_annotate_function("nn.avg_pool2d", identity_rewrite)
 
 
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     expr, x_kind = _get_expr_kind(new_args[0])
 
@@ -314,13 +317,8 @@ def pool2d_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    cnt = _conv_counter()
-    if cnt <= current_qconfig().skip_k_conv:
+    if check_to_skip():
         return None
-    if current_qconfig().skip_conv_layers is not None:
-        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
-        if cnt - 1 in leave_alone_indices:
-            return None
 
     input_tuple = new_args[0]
     expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
63 changes: 38 additions & 25 deletions python/tvm/relay/quantize/quantize.py
@@ -71,12 +71,10 @@ class QConfig(NodeBase):
         "dtype_weight": "int8",
         "dtype_activation": "int32",
         "global_scale": 8.0,
-        "skip_k_conv": 1,
-        "skip_conv_layers": None,
+        "skip_conv_layers": [0],
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
-        "use_stop_fusion": True
     }
 
     # pylint: disable=no-member
@@ -138,11 +136,8 @@ def qconfig(**kwargs):
     global_scale: float
         The global scale for calibration.
 
-    skip_k_conv: int
-        The number of skipped conv2d.
-
     skip_conv_layers: list
-        Different way of specifying which layers to avoid. Provide a list of indices
+        Specifies which layers to skip. Provide a list of indices
         that indicate which conv2d layers to leave untouched.
 
     round_for_shift: boolean
@@ -152,9 +147,10 @@ def qconfig(**kwargs):
         Whether to store low-bit integer back as output before dequantizing.
         Some accelerators need this, e.g. VTA.
 
-    use_stop_fusion: boolean
-        Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit
-        results to be stored in memory.
+    debug_enabled_ops: None or list of str
+        Partially quantize specified operators for debugging. The default value
+        is None, which means we will try to call all operators' annotate
+        rewrite functions.
 
     Returns
     -------
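
For anyone trying the new surface: skip_conv_layers=[0], the new default, reproduces what skip_k_conv=1 used to do, i.e. the first conv2d layer and the ops hanging off it stay in float. A sketch of the call pattern at this revision (net and params are placeholders):

    from tvm import relay

    def quantize_model(net, params):
        with relay.quantize.qconfig(global_scale=8.0,
                                    skip_conv_layers=[0],
                                    debug_enabled_ops=None):
            return relay.quantize.quantize(net, params=params)
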
@@ -166,18 +162,35 @@ def qconfig(**kwargs):
     return _make.node("relay.quantize.QConfig", **node_args)
 
 
-CONV_COUNTER = 0
-
-
-def _conv_counter():
-    """Get the global counter for conv2d."""
-    return CONV_COUNTER
-
-
-def _set_conv_counter(n):
-    """Set the value of the global conv2d counter."""
-    global CONV_COUNTER
-    CONV_COUNTER = n
+class AnnotateContext(object):
+    """A global singleton annotate scope"""
+    Current = None
+
+    def __init__(self):
+        self.qnode_map = dict()
+        self._conv2d_counter = 0
+
+    def __enter__(self):
+        self._conv2d_counter = 0
+        return self
+
+    def conv2d_counter(self):
+        """Get the counter for conv2d."""
+        return self._conv2d_counter
+
+    def count_conv2d(self):
+        """Increase the value of the conv2d counter by one."""
+        self._conv2d_counter += 1
+
+    def __exit__(self, ptype, value, traceback):
+        pass
+
+
+def annotate_context():
+    """Get the global singleton scope"""
+    if AnnotateContext.Current is None:
+        AnnotateContext.Current = AnnotateContext()
+    return AnnotateContext.Current
 
 
 def calibrate(graph, mod=None, ctx=None):
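
Review note: AnnotateContext doubles as a context manager, and __enter__ resets only the conv2d counter; qnode_map survives across runs because the singleton instance is reused. Whether the node cache should also be cleared per run seems worth confirming. A standalone toy of the with-semantics (mirrors the class above, no TVM needed):

    class ToyContext(object):
        Current = None

        def __init__(self):
            self.qnode_map = dict()
            self._conv2d_counter = 0

        def __enter__(self):
            self._conv2d_counter = 0   # fresh counter per run
            return self

        def __exit__(self, ptype, value, traceback):
            pass

    def toy_context():
        if ToyContext.Current is None:
            ToyContext.Current = ToyContext()
        return ToyContext.Current

    with toy_context() as ctx:
        ctx.qnode_map["k"] = "node"
        ctx._conv2d_counter = 5

    with toy_context() as ctx:
        assert ctx._conv2d_counter == 0  # reset on entry
        assert "k" in ctx.qnode_map      # cache persists across scopes
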
@@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None):
 
     calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
                                               name="QuantizeCalibrate")
-    _set_conv_counter(0)  # reset counter
     quantize_seq = _transform.Sequential([annotate(),
                                           calibrate_pass,
                                           realize(),
                                           _transform.FoldConstant()])
-    with _transform.PassContext(opt_level=3,
-                                required_pass=["QuantizeAnnotate",
-                                               "QuantizeCalibrate",
-                                               "QuantizeRealize"]):
-        mod = optimize(mod)
-        mod = quantize_seq(mod)
+    with annotate_context():
+        with _transform.PassContext(opt_level=3,
+                                    required_pass=["QuantizeAnnotate",
+                                                   "QuantizeCalibrate",
+                                                   "QuantizeRealize"]):
+            mod = optimize(mod)
+            mod = quantize_seq(mod)
     return mod[mod.entry_func.name_hint]
32 changes: 26 additions & 6 deletions src/relay/pass/quantize.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -393,7 +393,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
   } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
              ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
     auto new_arg = Cast(ret[i], cfg->dtype_input);
-    if (cfg->use_stop_fusion) {
+    if (cfg->store_lowbit_output) {
       new_arg = StopFusion(new_arg);
     }
     ret.Set(i, Cast(new_arg, dtype));
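
Review note: with use_stop_fusion gone from QConfig, store_lowbit_output does double duty on this path: it keeps low-bit outputs and it now also gates inserting the StopFusion barrier when a kQInput argument is cast to dtype_input.
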
@@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call,
 RELAY_REGISTER_OP("add")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
 
+Expr ClipRealize(const Call& ref_call,
+                 const Array<Expr>& new_args,
+                 const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
+    const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
+    auto attrs = make_node<ClipAttrs>();
+    double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
+    attrs->a_min = ref_attrs->a_min / dom_scale;
+    attrs->a_max = ref_attrs->a_max / dom_scale;
+
+    Expr ret = CallNode::make(ref_call->op,
+      {n->data}, Attrs(attrs), ref_call->type_args);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from<TempExprNode>());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("clip")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
+
+
 Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
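
Review note: ClipRealize rescales the clip bounds rather than dequantizing the data: a realized tensor holds q with real value q * dom_scale, so clipping real values to [a_min, a_max] equals clipping q to [a_min / dom_scale, a_max / dom_scale]. A numeric sanity check of that identity in plain Python (values hypothetical):

    dom_scale = 0.0625            # e.g. a power-of-two scale
    a_min, a_max = 0.0, 6.0       # a relu6-style clip

    q = 150                       # quantized value; real = 9.375 > 6
    real = q * dom_scale

    q_min, q_max = a_min / dom_scale, a_max / dom_scale   # 0.0, 96.0
    clipped_q = min(max(q, q_min), q_max)                 # 96.0

    assert clipped_q * dom_scale == min(max(real, a_min), a_max)  # both 6.0
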
@@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_weight=" << op->nbit_weight << ", ";
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
-  p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
-  p->stream << "use_stop_fusion==" << op->use_stop_fusion;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
   p->stream << ")";
 });
