From 4337d58e1ba0b9a6ca58a03506c7058603333337 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 20 May 2019 14:52:24 -0700 Subject: [PATCH 1/7] [QUANTIZE] Support for clip operator --- python/tvm/relay/backend/_backend.py | 4 ++-- python/tvm/relay/quantize/_annotate.py | 1 + src/relay/pass/quantize.cc | 26 ++++++++++++++++++++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 50e9694b40df..2cd7320e4046 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -50,8 +50,8 @@ def lower(sch, inputs, func_name, source_func): # pylint: disable=broad-except try: f = _build.lower(sch, inputs, name=func_name) - logging.debug("lower function %s", func_name) - logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) + # logging.debug("lower function %s", func_name) + # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) except Exception: msg = traceback.format_exc() msg += "Error during compile function\n" diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 2fe1cb81675b..62a9b81bd488 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -235,6 +235,7 @@ def identity_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(ret_expr, x_kind) +register_annotate_function("clip", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite) register_annotate_function("nn.avg_pool2d", identity_rewrite) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index cb0f9d9c5acb..808aa0a91778 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -413,6 +413,28 @@ Expr AddRealize(const Call& ref_call, RELAY_REGISTER_OP("add") .set_attr("FQRealizeRewrite", AddRealize); +Expr ClipRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + if (const auto* n = new_args[0].as()) { + const auto ref_attrs = ref_call->attrs.as(); + auto attrs = make_node(); + double dom_scale = GetScalarFromConstant(n->dom_scale); + attrs->a_min = ref_attrs->a_min / dom_scale; + attrs->a_max = ref_attrs->a_max / dom_scale; + + Expr ret = CallNode::make(ref_call->op, + {n->data}, Attrs(attrs), ref_call->type_args); + return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + } + CHECK(!new_args[0]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("clip") +.set_attr("FQRealizeRewrite", ClipRealize); + Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, From 2e725e5ef4b0fd24e6c6abc8b71cb19575f61e88 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 22 May 2019 22:06:22 -0700 Subject: [PATCH 2/7] [QUANTIZE] Memorizing the quantize node mapping. 
--- python/tvm/relay/quantize/_annotate.py | 95 +++++++++++++------------- python/tvm/relay/quantize/quantize.py | 39 ++++++++--- 2 files changed, 74 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index a33c3645b85f..61e895ac7efb 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -22,7 +22,7 @@ import topi from . import _quantize from .quantize import QAnnotateKind, current_qconfig -from .quantize import _conv_counter, _set_conv_counter +from .quantize import annotate_context from .. import expr as _expr from .. import op as _op from ..op import op as _reg @@ -116,7 +116,6 @@ def frewrite_with_guard(ref_call, new_args, ctx): return _register(frewrite) if frewrite is not None else _register -@register_func("relay.quantize.attach_simulated_quantize") def attach_simulated_quantize(data, kind, sign=True, rounding="round"): """Attach a simulated quantize operation after input data expr. @@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data + actx = annotate_context() + key = tuple([data, kind, sign, rounding]) + if key in actx.qnode_map: + return actx.qnode_map[key] + dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") - return _quantize.simulated_quantize( + qnode = _quantize.simulated_quantize( data, dom_scale, clip_min, clip_max, kind, sign, rounding) + actx.qnode_map[key] = qnode + return qnode + +register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx): """Rewrite function for conv2d. Lhs of conv will be quantized to input field, and rhs of conv will be quantized to weight field. Output would be in activation field""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: - _set_conv_counter(cnt + 1) - return None - + actx = annotate_context() if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt in leave_alone_indices: - _set_conv_counter(cnt + 1) + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if actx.conv2d_counter() in skipped_indices: + actx.count_conv2d() return None - - _set_conv_counter(cnt + 1) + actx.count_conv2d() lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +def check_to_skip(): + """Check the index of conv2d layer to decide whether to skip the current operator.""" + if current_qconfig().skip_conv_layers is not None: + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if annotate_context().conv2d_counter() - 1 in skipped_indices: + return True + return False + + @register_annotate_function("nn.dense") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. 
Output would be in activation field.""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +@register_annotate_function("stop_fusion") +def stop_fusion_rewrite(ref_call, new_args, ctx): + """Rewrite function for add.""" + if check_to_skip(): + return None + + x_expr, x_kind = _get_expr_kind(new_args[0]) + if x_kind is None: + return None + + ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT) + ret_expr = _forward_op(ref_call, [ret_expr]) + return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT) + + def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None x_expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -291,13 +298,8 @@ def identity_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -315,13 +317,8 @@ def pool2d_rewrite(ref_call, new_args, ctx): @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - cnt = _conv_counter() - if cnt <= current_qconfig().skip_k_conv: + if check_to_skip(): return None - if current_qconfig().skip_conv_layers is not None: - leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if cnt - 1 in leave_alone_indices: - return None input_tuple = new_args[0] expr_list = [_get_expr_kind(x)[0] for x in 
input_tuple] diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7fd0099e64a2..7e35b56cad08 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -75,7 +75,7 @@ class QConfig(NodeBase): "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, - "use_stop_fusion": True + "use_stop_fusion": False } # pylint: disable=no-member @@ -165,18 +165,35 @@ def qconfig(**kwargs): return _make.node("relay.quantize.QConfig", **node_args) -CONV_COUNTER = 0 +class AnnotateContext(object): + # a global singleton annotate scope + Current = None + def __init__(self): + self.qnode_map = dict() + self._conv2d_counter = 0 -def _conv_counter(): - """Get the global counter for conv2d.""" - return CONV_COUNTER + def __enter__(self): + self._conv2d_counter = 0 + return self + + def conv2d_counter(self): + """Get the counter for conv2d.""" + return self._conv2d_counter + + def count_conv2d(self): + """Increase the value of the conv2d counter by one.""" + self._conv2d_counter += 1 + + def __exit__(self, ptype, value, traceback): + pass -def _set_conv_counter(n): - """Set the value of the global conv2d counter.""" - global CONV_COUNTER - CONV_COUNTER = n +def annotate_context(): + """Get the global singleton scope""" + if AnnotateContext.Current is None: + AnnotateContext.Current = AnnotateContext() + return AnnotateContext.Current def annotate(graph): @@ -194,8 +211,8 @@ def annotate(graph): ret: Function The graph after annotation """ - _set_conv_counter(0) # reset counter - return _quantize.annotate(graph) + with annotate_context(): + return _quantize.annotate(graph) def calibrate(graph, dataset=None): From 733560cc0e93a8133e36fb180952d99801954c09 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 22 May 2019 22:18:23 -0700 Subject: [PATCH 3/7] [QUANTIZE] Remove use_stop_fusion and skip_k_conv in qconfig --- python/tvm/relay/quantize/quantize.py | 16 ++++++---------- src/relay/pass/quantize.cc | 6 ++---- src/relay/pass/quantize.h | 8 ++------ 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7e35b56cad08..3d3f8479a955 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -70,12 +70,10 @@ class QConfig(NodeBase): "dtype_weight": "int8", "dtype_activation": "int32", "global_scale": 8.0, - "skip_k_conv": 1, - "skip_conv_layers": None, + "skip_conv_layers": [0], "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, - "use_stop_fusion": False } # pylint: disable=no-member @@ -137,11 +135,8 @@ def qconfig(**kwargs): global_scale: float The global scale for calibration. - skip_k_conv: int - The number of skipped conv2d. - skip_conv_layers: list - Different way of specifying which layers to avoid. Provide a list of indices + Specifying which layers to be skipped. Provide a list of indices that indicate which conv2d layers to leave untouched. round_for_shift: boolean @@ -151,9 +146,10 @@ def qconfig(**kwargs): Whether to store low-bit integer back as output before dequantizing. Some accelerators need this, e.g. VTA. - use_stop_fusion: boolean - Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit - results to be stored in memory. + debug_enabled_ops: None or list of str + Partially quantize specified operators for debugging. 
The default value + is None, which means will try to call all operartors' annotate rewrite + function. Returns ------- diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 77f0fe934961..d9513fd35405 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -408,7 +408,7 @@ Array UnifyDTypeScale(const Array& ref_args, } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && ref_arg->attrs.as()->kind == kQInput) { auto new_arg = Cast(ret[i], cfg->dtype_input); - if (cfg->use_stop_fusion) { + if (cfg->store_lowbit_output) { new_arg = StopFusion(new_arg); } ret.Set(i, Cast(new_arg, dtype)); @@ -617,12 +617,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "global_scale=" << op->global_scale << ", "; - p->stream << "skip_k_conv==" << op->skip_k_conv << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; - p->stream << "use_stop_fusion==" << op->use_stop_fusion; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; p->stream << ")"; }); diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 2c70da177199..da95a6c2134a 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -125,12 +125,10 @@ class QConfigNode : public Node { DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); double global_scale = 8.0; - int skip_k_conv = 1; Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; bool store_lowbit_output = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); - bool use_stop_fusion = true; void VisitAttrs(AttrVisitor* v) final { v->Visit("nbit_input", &nbit_input); @@ -140,12 +138,10 @@ class QConfigNode : public Node { v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); - v->Visit("skip_k_conv", &skip_k_conv); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); - v->Visit("use_stop_fusion", &use_stop_fusion); } static constexpr const char* _type_key = "relay.quantize.QConfig"; From d15b89e48cda1b775cfd83a7d10adb11551fd0b0 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 22 Jun 2019 09:13:41 -0700 Subject: [PATCH 4/7] update --- python/tvm/relay/quantize/quantize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 5c80ae682d5c..a045d6dab81a 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -163,7 +163,7 @@ def qconfig(**kwargs): class AnnotateContext(object): - # a global singleton annotate scope + """A 
global singleton annotate scope""" Current = None def __init__(self): @@ -192,7 +192,7 @@ def annotate_context(): AnnotateContext.Current = AnnotateContext() return AnnotateContext.Current - + def calibrate(graph, mod=None, ctx=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` From d9f1dd1b60281a40716cd95d611bd37e38b3dc16 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 22 Jun 2019 09:18:01 -0700 Subject: [PATCH 5/7] update --- python/tvm/relay/backend/_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 2cd7320e4046..860788a4e5d0 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,7 +17,6 @@ """The interface of expr function exposed from C++.""" from __future__ import absolute_import -import logging from ... import build_module as _build from ... import container as _container from ..._ffi.function import _init_api, register_func From 668e09be52c38cd2483e4192a255abb7890a5906 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 22 Jun 2019 09:34:50 -0700 Subject: [PATCH 6/7] update --- python/tvm/relay/quantize/quantize.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index a045d6dab81a..a7749d4892fb 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -337,16 +337,15 @@ def quantize(graph, params=None, dataset=None): calibrate_pass = _transform.function_pass(calibrate, opt_level=1, name="QuantizeCalibrate") - _set_conv_counter(0) # reset counter quantize_seq = _transform.Sequential([annotate(), calibrate_pass, realize(), _transform.FoldConstant()]) - with _transform.PassContext(opt_level=3, - required_pass=["QuantizeAnnotate", - "QuantizeCalibrate", - "QuantizeRealize"]): - mod = optimize(mod) - with annotate_context(): + with annotate_context(): + with _transform.PassContext(opt_level=3, + required_pass=["QuantizeAnnotate", + "QuantizeCalibrate", + "QuantizeRealize"]): + mod = optimize(mod) mod = quantize_seq(mod) return mod[mod.entry_func.name_hint] From 2907e97e1bfe518055aaf6105313329ff63133ee Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sat, 22 Jun 2019 09:44:55 -0700 Subject: [PATCH 7/7] update --- tests/python/relay/test_pass_quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py index e02601e926f2..fe62c3b5cea4 100644 --- a/tests/python/relay/test_pass_quantize.py +++ b/tests/python/relay/test_pass_quantize.py @@ -81,7 +81,7 @@ def make_qgraph(data, weight): graph = make_graph(data) dataset, params = make_dataset(graph, 10) - with qtz.qconfig(skip_k_conv=0, global_scale=4.0, + with qtz.qconfig(skip_conv_layers=None, global_scale=4.0, round_for_shift=False, store_lowbit_output=False): qgraph0 = qtz.quantize(graph, params) qgraph0 = relay.ir_pass.infer_type(qgraph0)
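
---
Note (illustrative, not part of the patches): the new `ClipRealize` rule in patch 1 rescales the clip bounds into the quantized domain. A quantized tensor stores `real_value / dom_scale`, so `a_min`/`a_max` are divided by `dom_scale` before clipping the integer values. The rule itself lives in C++ in `src/relay/pass/quantize.cc`; the following is only a minimal NumPy sketch of the arithmetic, with made-up values:

```python
import numpy as np

def realize_clip(q_data, dom_scale, a_min, a_max):
    """Clip quantized data so that, after dequantization, values lie in
    [a_min, a_max].  q_data holds real_value / dom_scale, so the bounds are
    rescaled by the same factor (mirrors what ClipRealize computes)."""
    return np.clip(q_data, a_min / dom_scale, a_max / dom_scale)

# e.g. a clip(0, 6) (ReLU6-style) on data with dom_scale = 0.047:
q = np.array([-20, 50, 200], dtype=np.int32)
print(realize_clip(q, 0.047, 0.0, 6.0))   # -> [0., 50., ~127.66] before casting back
```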
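
Note (illustrative): patch 2 memoizes `attach_simulated_quantize` through the new `AnnotateContext`, so annotating the same expression twice with the same `(kind, sign, rounding)` returns the cached `simulated_quantize` node instead of creating a duplicate, and calibration sees a single set of `dom_scale`/`clip_min`/`clip_max` variables per tensor. A standalone sketch of the caching pattern; the tuple stands in for the real Relay op, and the context is passed explicitly rather than fetched via `annotate_context()`:

```python
class AnnotateContext:
    """Singleton-style annotate scope that caches inserted quantize nodes."""
    Current = None

    def __init__(self):
        self.qnode_map = {}        # (data, kind, sign, rounding) -> quantize node
        self._conv2d_counter = 0

def attach_simulated_quantize(actx, data, kind, sign=True, rounding="round"):
    key = (data, kind, sign, rounding)
    if key in actx.qnode_map:      # reuse the node created earlier
        return actx.qnode_map[key]
    qnode = ("simulated_quantize", data, kind, sign, rounding)  # stand-in for the real op
    actx.qnode_map[key] = qnode
    return qnode

actx = AnnotateContext()
a = attach_simulated_quantize(actx, "conv_out", kind=1)
b = attach_simulated_quantize(actx, "conv_out", kind=1)
assert a is b                      # the second call hits the cache
```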
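
Note (illustrative): with `skip_k_conv` and `use_stop_fusion` removed in patch 3, skipping convolutions is controlled only by `skip_conv_layers`, which now defaults to `[0]` (leave the first conv2d in float). A usage sketch based on the updated test in `test_pass_quantize.py`, assuming the usual `from tvm.relay import quantize as qtz` import and an existing Relay function `graph` with its `params`:

```python
from tvm import relay
from tvm.relay import quantize as qtz

def quantize_model(graph, params):
    """Quantize a Relay function.  skip_conv_layers=None quantizes every
    conv2d; the new default [0] would keep the first conv2d unquantized."""
    with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
                     round_for_shift=False, store_lowbit_output=False):
        return qtz.quantize(graph, params)
```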