From 50a7c6747352380745e64b4326d4048425b47cb2 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 19 Nov 2019 17:54:57 -0500 Subject: [PATCH] [Relay][Quantize] Integrate data-aware calibration into quantization (#4295) * [Relay][Quantize] Integrate data-aware calibration into quantization * Update _calibrate.py * trigger ci * Address comments * address comments --- python/tvm/relay/quantize/__init__.py | 1 - python/tvm/relay/quantize/_annotate.py | 1 + python/tvm/relay/quantize/_calibrate.py | 184 +++++++++++++++++++++ python/tvm/relay/quantize/kl_divergence.py | 4 +- python/tvm/relay/quantize/quantize.py | 151 +++-------------- src/relay/pass/quantize/calibrate.cc | 6 +- src/relay/pass/quantize/quantize.cc | 2 + src/relay/pass/quantize/quantize.h | 4 + 8 files changed, 222 insertions(+), 131 deletions(-) create mode 100644 python/tvm/relay/quantize/_calibrate.py diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 29b68950fa42..09dfa8f515eb 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -21,4 +21,3 @@ from .quantize import * from ._partition import register_partition_function from ._annotate import register_annotate_function -from .kl_divergence import kl_divergence_scale diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 55f3597881e7..9d679d206508 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -57,6 +57,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): _reg.schedule_injective) _reg.register_pattern("relay.op.annotation.simulated_quantize", _reg.OpPattern.ELEMWISE) +_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective) @register_relay_node diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py new file mode 100644 index 000000000000..aae50519b132 --- /dev/null +++ b/python/tvm/relay/quantize/_calibrate.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Find scales for quantization on the dataset.""" +from __future__ import absolute_import +import logging +import multiprocessing as mp +import numpy as np +import tvm + +from . import _quantize +from . import quantize +from .. import op as _op +from .. import expr as _expr +from .. import module as _module +from .. import analysis as _analysis +from .. import transform as _transform +from .. import build_module as _build_module +from ...contrib import graph_runtime +from .kl_divergence import _find_scale_by_kl + + +def collect_stats(mod, dataset): + """Given an annotated graph, create a profile graph to collect profile data from the + calibration dataset. This pass collects simulated_quantize op input into a tuple. + Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile + graph. + + Parameters + ---------- + mod: Module + The simulation graph after annotation. + + Returns + ------- + ret: list of ndarray + List of output data of each layer + """ + + logging.info("collecting statistics for calibration...") + func = mod['main'] + func = _quantize.CreateStatsCollector(func) + target = tvm.target.current_target() or 'llvm' + with _transform.build_config(opt_level=3): + graph, lib, params = _build_module.build(func, target=target) + outputs = [] + runtime = graph_runtime.create(graph, lib, tvm.context(target)) + runtime.set_input(**params) + + num_outputs = runtime.get_num_outputs() + outputs = [[] for i in range(num_outputs)] + + for batch in dataset: + runtime.set_input(**batch) + runtime.run() + for i in range(num_outputs): + output = runtime.get_output(i).asnumpy() + outputs[i].append(output) + for i in range(num_outputs): + outputs[i] = np.concatenate(outputs[i]).reshape(-1) + return outputs + + +def _kl_scale(stats): + with mp.Pool() as pool: + logging.info("finding threshold with kl for calibration...") + scales = list(pool.map(_find_scale_by_kl, stats)) + + def func(sq_call): # pylint: disable=unused-argument + scale = scales[func.scale_idx] + func.scale_idx += 1 + return scale + func.scale_idx = 0 + + return func + + +def _set_params(mod, input_scale_func, weight_scale_func): + quantize_op = _op.get("relay.op.annotation.simulated_quantize") + cfg = quantize.current_qconfig() + const_params = {} + + def visit_func(expr): + '''visitor function for traverse''' + if isinstance(expr, _expr.Call) and expr.op == quantize_op: + _, ndom_scale, nclip_min, nclip_max = expr.args + attrs = expr.attrs + kind = attrs.kind + nbit = cfg.get_nbit_by_kind(kind) + valid_bit = nbit - attrs.sign + + # set scale + if kind == quantize.QAnnotateKind.WEIGHT: + assert isinstance(expr.args[0], _expr.Constant) + scale = weight_scale_func(expr) + else: + scale = input_scale_func(expr) + + def _make_const(val): + return _expr.const(val, 'float32') + + valid_range = 2**valid_bit + const_params[ndom_scale] = _make_const(scale / valid_range) + const_params[nclip_min] = _make_const(- (valid_range - 1)) + const_params[nclip_max] = _make_const((valid_range - 1)) + + func = mod['main'] + _analysis.post_order_visit(func, visit_func) + func = _expr.bind(func, const_params) + return _module.Module.from_expr(func) + + +# weight scale functions +def _power2_scale(sq_call): # pylint: disable=unused-argument + """calculate weight scale with nearest mode-2 scale""" + var = sq_call.args[0] + assert isinstance(var, _expr.Constant) + val = np.amax(np.abs(var.data.asnumpy())) + return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + + +def _max_scale(sq_call): + """calculate weight scale with maximum absolute value""" + var = sq_call.args[0] + assert isinstance(var, _expr.Constant) + val = np.amax(np.abs(var.data.asnumpy())) + return val + + +# input scale functions +def _global_scale(sq_call): # pylint: disable=unused-argument + cfg = quantize.current_qconfig() + return cfg.global_scale + + +def calibrate(dataset=None): + """The calibrate procedure will try to calculate the content of + dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` + operator. + + Parameters + --------- + dataset: Optional[Iterable[NDArray]] + The calibration dataset. + + Returns + ------- + ret: Function + The module pass function. + """ + def wrapped_func(mod, ctx): # pylint: disable=unused-argument + """make transform.module pass happy""" + cfg = quantize.current_qconfig() + + if cfg.calibrate_mode == 'kl_divergence': + stats = collect_stats(mod, dataset) + input_scale_func = _kl_scale(stats) + elif cfg.calibrate_mode == 'global_scale': + input_scale_func = _global_scale + else: + raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode)) + + if cfg.weight_scale == 'max': + weight_scale_func = _max_scale + elif cfg.weight_scale == 'power2': + weight_scale_func = _power2_scale + else: + raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale)) + + return _set_params(mod, input_scale_func, weight_scale_func) + return wrapped_func diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py index bce45dca6f1c..2feb514e8611 100644 --- a/python/tvm/relay/quantize/kl_divergence.py +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -45,7 +45,7 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=invalid-name -def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): +def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): """Given a tensor, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. `q` is a truncated version of the original distribution. @@ -54,6 +54,8 @@ def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantize http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ assert isinstance(arr, np.ndarray) + assert stats is not None, "scipy needs to be installed for \ + utilizing kl calibration during quantization" min_val = np.min(arr) max_val = np.max(arr) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 7fa8a66878bc..1d60145bb133 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,14 +17,10 @@ #pylint: disable=unused-argument """Automatic quantization toolkit.""" from __future__ import absolute_import -import numpy as np - from . import _quantize +from ._calibrate import calibrate from .. import expr as _expr -from .. import module as _module -from .. import analysis as _analysis from .. import transform as _transform -from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -78,7 +74,9 @@ class QConfig(NodeBase): "dtype_input": "int8", "dtype_weight": "int8", "dtype_activation": "int32", + "calibrate_mode": "global_scale", "global_scale": 8.0, + "weight_scale": "power2", "skip_conv_layers": [0], "do_simulation": False, "round_for_shift": True, @@ -143,9 +141,20 @@ def qconfig(**kwargs): nbit_dict: dict of QAnnotateKind -> int Number of bit for every kind of annotate field. + calibrate_mode: str + The calibration mode. 'global_scale' or 'kl_divergence'. + global_scale: use global scale + kl_divergence: find scales by kl divergence on the dataset. + global_scale: float The global scale for calibration. + weight_scale: str + The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). + power2: Find the maximum of the absolute value of the tensor, and then round up to power + of two. + max: Find the maximum of the absolute value of the tensor + skip_conv_layers: list Specifying which layers to be skipped. Provide a list of indices that indicate which conv2d layers to leave untouched. Start from 0. @@ -249,113 +258,6 @@ def annotate(): return _quantize.QuantizeAnnotate() -def collect_stats(graph): - """Given an annotated graph, create a profile graph to collect profile data from the - calibration dataset. This pass collects simulated_quantize op input into a tuple. - Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile - graph. - - Parameters - ---------- - graph: Function - The simulation graph after annotation. - - Returns - ------- - ret: Function - The profile graph which outputs a tuple of profile data. - """ - return _quantize.CollectStats(graph) - - -def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): - """The calibrate procedure will try to calculate the content of - dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` - operator. - - Parameters - --------- - graph: Function - The simulation graph after annotation. - - mod: tvm.relay.Module - The module where calibration happens on. - - ctx: tvm.relay.PassContext - The pass context used for calibration. - - weight_scales: 'power2' or 'max'. - The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). - power2: Find the maximum of the absolute value of the tensor, and then round up to power - of two. - max: Find the maximum of the absolute value of the tensor. - - scales: List[float] - Pre-calculated scales for input and activations. Length and the order of elements of the - scales list should match the output tuple of the profile graph created by collect_stats. - - Returns - ------- - ret: Function - The graph after calibration - """ - def power2_scale(arr): - """calculate weight scale with nearest mode-2 scale""" - val = np.amax(np.abs(arr.asnumpy())) - return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 - - def max_scale(arr): - """calculate weight scale with maximum absolute value""" - val = np.amax(np.abs(arr.asnumpy())) - return val - - scale_idx = 0 - - cfg = current_qconfig() - const_params = {} - quantize_op = _op.get("relay.op.annotation.simulated_quantize") - - def visit_func(expr): - """Internal visit function""" - nonlocal scale_idx - if isinstance(expr, _expr.Call) and expr.op == quantize_op: - _, ndom_scale, nclip_min, nclip_max = expr.args - attrs = expr.attrs - kind = attrs.kind - nbit = cfg.get_nbit_by_kind(kind) - - valid_bit = nbit - attrs.sign - if kind in [QAnnotateKind.WEIGHT]: - if all([isinstance(arg, _expr.Constant) - for arg in [ndom_scale, nclip_min, nclip_max]]): - return - var = expr.args[0] - assert isinstance(var, _expr.Constant) - if weight_scales == 'max': - scale = max_scale(var.data) - elif weight_scales == 'power2': - scale = power2_scale(var.data) - else: - raise ValueError('{} not supported'.format(weight_scales)) - elif scales is not None: - scale = scales[scale_idx] - scale_idx += 1 - else: - scale = cfg.global_scale - - def _make_const(val): - return _expr.const(val, 'float32') - - valid_range = 2**valid_bit - const_params[ndom_scale] = _make_const(scale / valid_range) - const_params[nclip_min] = _make_const(- (valid_range - 1)) - const_params[nclip_max] = _make_const((valid_range - 1)) - - _analysis.post_order_visit(graph, visit_func) - ret = _expr.bind(graph, const_params) - return ret - - def realize(): """The realize pass will transform the simulated quantized graph, which actually computes with float32, to a real low-bit integer graph. It will @@ -391,7 +293,7 @@ def _bind_params(func, params): return _expr.bind(func, bind_dict) -def prerequisite_optimize(graph, params=None): +def prerequisite_optimize(mod, params=None): """ Prerequisite optimization passes for quantization. Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and "CanonicalizeOps" optimization before quantization. """ @@ -402,15 +304,13 @@ def prerequisite_optimize(graph, params=None): _transform.FoldConstant()]) if params: - graph = _bind_params(graph, params) + mod['main'] = _bind_params(mod['main'], params) - mod = _module.Module.from_expr(graph) - with _transform.PassContext(opt_level=3): - mod = optimize(mod) - return mod["main"] + mod = optimize(mod) + return mod -def quantize(graph, params=None, dataset=None): +def quantize(mod, params=None, dataset=None): """ The quantization procedure. Before running the three main procedure of quantization, "annotate", "calibrate" and "realize" , we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant" @@ -418,8 +318,8 @@ def quantize(graph, params=None, dataset=None): Parameters --------- - graph: Function - The original graph. + mod: Module + The original module. params : dict of str to NDArray Input parameters to the graph that do not change @@ -433,11 +333,10 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - graph = prerequisite_optimize(graph, params) + mod = prerequisite_optimize(mod, params) - mod = _module.Module.from_expr(graph) - calibrate_pass = _transform.function_pass(calibrate, opt_level=1, - name="QuantizeCalibrate") + calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1, + name="QuantizeCalibrate") quant_passes = [partition(), annotate(), calibrate_pass] @@ -452,4 +351,4 @@ def quantize(graph, params=None, dataset=None): with quantize_context(): mod = quantize_seq(mod) - return mod["main"] + return mod diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index 30b47ba69a6e..9757e58922dd 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -87,12 +87,12 @@ class StatsCollector : private ExprMutator { * \param expr The simulation graph after annotation. * \return The profile graph. */ -Expr CollectStats(const Expr& expr) { +Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } -TVM_REGISTER_API("relay._quantize.CollectStats") -.set_body_typed(CollectStats); +TVM_REGISTER_API("relay._quantize.CreateStatsCollector") +.set_body_typed(CreateStatsCollector); } // namespace quantize } // namespace relay diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index 2793577cfee2..be24ad7404e0 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -123,7 +123,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_input=" << op->nbit_input << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; p->stream << "global_scale=" << op->global_scale << ", "; + p->stream << "weight_scale=" << op->weight_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 8a0282ab4929..3af13a97b578 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -70,7 +70,9 @@ class QConfigNode : public Node { DataType dtype_input = Int(8); DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); + std::string calibrate_mode = "global_scale"; double global_scale = 8.0; + std::string weight_scale = "power2"; Array skip_conv_layers = Array(NodePtr(nullptr)); bool do_simulation = false; bool round_for_shift = true; @@ -84,7 +86,9 @@ class QConfigNode : public Node { v->Visit("dtype_input", &dtype_input); v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); + v->Visit("calibrate_mode", &calibrate_mode); v->Visit("global_scale", &global_scale); + v->Visit("weight_scale", &weight_scale); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("do_simulation", &do_simulation); v->Visit("round_for_shift", &round_for_shift);