From 53748eed862f38b852fcdb5a0381eed15fc2e9d8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 1 Aug 2019 20:55:27 -0700 Subject: [PATCH] [Relay][Quantization] KL-divergence-based per-layer calibration (#3538) * [Relay][Quantization] Support floating-point scale * [Relay][Quantization] KL-divergence calibration on dataset * Fix unhandled LeftShift case in QuantizeRealize * Fix lint * drop QBias * fix lint * address comments * address comments * Update comments * address comments * lint * kQIdentity = 0 --- python/tvm/relay/quantize/__init__.py | 1 + python/tvm/relay/quantize/_annotate.py | 10 +- python/tvm/relay/quantize/kl_divergence.py | 124 +++++++++++++++++++++ python/tvm/relay/quantize/quantize.py | 57 +++++++++- src/relay/pass/quantize/calibrate.cc | 99 ++++++++++++++++ src/relay/pass/{ => quantize}/quantize.cc | 61 ++++------ src/relay/pass/{ => quantize}/quantize.h | 27 ++++- 7 files changed, 329 insertions(+), 50 deletions(-) create mode 100644 python/tvm/relay/quantize/kl_divergence.py create mode 100644 src/relay/pass/quantize/calibrate.cc rename src/relay/pass/{ => quantize}/quantize.cc (93%) rename src/relay/pass/{ => quantize}/quantize.h (89%) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 45bb62e66853..a9e7b40b039e 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -20,3 +20,4 @@ from .quantize import * from ._annotate import register_annotate_function +from .kl_divergence import kl_divergence_scale diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 7b7f9c42f2f1..e03eaab507ad 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): data, scale, clip_min, clip_max = inputs + if attrs.kind == QAnnotateKind.IDENTITY: + return [topi.identity(data)] + # simulate rounding error scaled_data = topi.divide(data, scale) clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) @@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): _reg.register_schedule("relay.op.annotation.simulated_quantize", _reg.schedule_injective) _reg.register_pattern("relay.op.annotation.simulated_quantize", - _reg.OpPattern.OPAQUE) + _reg.OpPattern.ELEMWISE) @register_relay_node @@ -251,7 +254,7 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind is None and rhs_kind is not None: # quantize lhs to INPUT field if it is normal expression - assert rhs_kind == QAnnotateKind.INPUT + assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION] lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) @@ -275,7 +278,8 @@ def add_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) - if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT: + if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \ + (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError() diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py new file mode 100644 index 000000000000..bce45dca6f1c --- /dev/null +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Find optimal scale for quantization by minimizing KL-divergence""" + +try: + from scipy import stats +except ImportError: + stats = None + +import numpy as np + + +def _smooth_distribution(p, eps=0.0001): + """Given a discrete distribution (may have not been normalized to 1), + smooth it by replacing zeros with eps multiplied by a scaling factor and taking the + corresponding amount off the non-zero values. + Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf + """ + is_zeros = (p == 0).astype(np.float32) + is_nonzeros = (p != 0).astype(np.float32) + n_zeros = is_zeros.sum() + n_nonzeros = p.size - n_zeros + if not n_nonzeros: + raise ValueError('The discrete probability distribution is malformed. All entries are 0.') + eps1 = eps * float(n_zeros) / float(n_nonzeros) + assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) + hist = p.astype(np.float32) + hist += eps * is_zeros + (-eps1) * is_nonzeros + assert (hist <= 0).sum() == 0 + return hist + + +# pylint: disable=invalid-name +def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): + """Given a tensor, find the optimal threshold for quantizing it. + The reference distribution is `q`, and the candidate distribution is `p`. + `q` is a truncated version of the original distribution. + + Ref: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + """ + assert isinstance(arr, np.ndarray) + + min_val = np.min(arr) + max_val = np.max(arr) + th = max(abs(min_val), abs(max_val)) + + if min_val >= 0 and quantized_dtype in ['uint8']: + # We need to move negative bins to positive bins to fit uint8 range. + num_quantized_bins = num_quantized_bins * 2 + 1 + + hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) + zero_bin_idx = num_bins // 2 + num_half_quantized_bins = num_quantized_bins // 2 + + thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) + divergence = np.zeros_like(thresholds) + quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) + # i means the number of bins on half axis excluding the zero bin. + for i in range(num_quantized_bins // 2, + num_bins // 2 + 1): + p_bin_idx_start = zero_bin_idx - i + p_bin_idx_stop = zero_bin_idx + i + 1 + thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop] + sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] + + # generate reference distribution p + p = sliced_nd_hist.copy() + assert p.size % 2 == 1 + assert p.size >= num_quantized_bins + # put left outlier count in p[0] + left_outlier_count = np.sum(hist[0:p_bin_idx_start]) + p[0] += left_outlier_count + # put right outlier count in p[-1] + right_outlier_count = np.sum(hist[p_bin_idx_stop:]) + p[-1] += right_outlier_count + # is_nonzeros[k] indicates whether hist[k] is nonzero + is_nonzeros = (p != 0).astype(np.int32) + + # calculate how many bins should be merged to generate quantized distribution q + num_merged_bins = sliced_nd_hist.size // num_quantized_bins + # merge hist into num_quantized_bins bins + for j in range(num_quantized_bins): + start = j * num_merged_bins + stop = start + num_merged_bins + quantized_bins[j] = sliced_nd_hist[start:stop].sum() + quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() + # expand quantized_bins into p.size bins + q = np.zeros(sliced_nd_hist.size, dtype=np.float32) + for j in range(num_quantized_bins): + start = j * num_merged_bins + if j == num_quantized_bins - 1: + stop = len(is_nonzeros) + else: + stop = start + num_merged_bins + norm = is_nonzeros[start:stop].sum() + if norm != 0: + q[start:stop] = float(quantized_bins[j]) / float(norm) + q[p == 0] = 0 + p = _smooth_distribution(p) + # There is a chance that q is an invalid probability distribution. + try: + q = _smooth_distribution(q) + except ValueError: + divergence[i - num_half_quantized_bins] = float("inf") + divergence[i - num_half_quantized_bins] = stats.entropy(p, q) + + min_divergence_idx = np.argmin(divergence) + opt_th = thresholds[min_divergence_idx] + return opt_th diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index beebceaf8590..07d4d9d25e01 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -32,6 +32,7 @@ class QAnnotateKind(object): """Denote the kind of annotation field, corresponding to different nbit configure.""" + IDENTITY = 0 INPUT = 1 WEIGHT = 2 ACTIVATION = 3 @@ -43,6 +44,7 @@ def kind2str(kind): QAnnotateKind.INPUT: "input", QAnnotateKind.WEIGHT: "weight", QAnnotateKind.ACTIVATION: "activation", + QAnnotateKind.IDENTITY: "identity" } assert kind in str_map return str_map[kind] @@ -195,7 +197,26 @@ def annotate_context(): return AnnotateContext.Current -def calibrate(graph, mod=None, ctx=None): +def collect_stats(graph): + """Given an annotated graph, create a profile graph to collect profile data from the + calibration dataset. This pass collects simulated_quantize op input into a tuple. + Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile + graph. + + Parameters + ---------- + graph: Function + The simulation graph after annotation. + + Returns + ------- + ret: Function + The profile graph which outputs a tuple of profile data. + """ + return _quantize.CollectStats(graph) + + +def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. @@ -211,6 +232,16 @@ def calibrate(graph, mod=None, ctx=None): ctx: tvm.relay.PassContext The pass context used for calibration. + weight_scales: 'power2' or 'max'. + The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). + power2: Find the maximum of the absolute value of the tensor, and then round up to power + of two. + max: Find the maximum of the absolute value of the tensor. + + scales: List[float] + Pre-calculated scales for input and activations. Length and the order of elements of the + scales list should match the output tuple of the profile graph created by collect_stats. + Returns ------- ret: Function @@ -221,12 +252,20 @@ def power2_scale(arr): val = np.amax(np.abs(arr.asnumpy())) return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + def max_scale(arr): + """calculate weight scale with maximum absolute value""" + val = np.amax(np.abs(arr.asnumpy())) + return val + + scale_idx = 0 + cfg = current_qconfig() const_params = {} quantize_op = _op.get("relay.op.annotation.simulated_quantize") def visit_func(expr): """Internal visit function""" + nonlocal scale_idx if isinstance(expr, _expr.Call) and expr.op == quantize_op: _, ndom_scale, nclip_min, nclip_max = expr.args attrs = expr.attrs @@ -234,11 +273,21 @@ def visit_func(expr): nbit = cfg.get_nbit_by_kind(kind) valid_bit = nbit - attrs.sign - - if kind == QAnnotateKind.WEIGHT: + if kind in [QAnnotateKind.WEIGHT]: + if all([isinstance(arg, _expr.Constant) + for arg in [ndom_scale, nclip_min, nclip_max]]): + return var = expr.args[0] assert isinstance(var, _expr.Constant) - scale = power2_scale(var.data) + if weight_scales == 'max': + scale = max_scale(var.data) + elif weight_scales == 'power2': + scale = power2_scale(var.data) + else: + raise ValueError('{} not supported'.format(weight_scales)) + elif scales is not None: + scale = scales[scale_idx] + scale_idx += 1 else: scale = cfg.global_scale diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc new file mode 100644 index 000000000000..30b47ba69a6e --- /dev/null +++ b/src/relay/pass/quantize/calibrate.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file calibrate.cc + * + * \brief Create profile graph and calibrate on dataset + */ +#include +#include +#include "./quantize.h" + + +namespace tvm { +namespace relay { +namespace quantize { + +class StatsCollector : private ExprMutator { + public: + Expr Collect(const Expr& expr) { + auto new_e = this->Mutate(expr); + const FunctionNode* func = new_e.as(); + CHECK(func) << "Input shoule be Function"; + Expr new_body = TupleNode::make(std::move(profile_data_)); + return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, + func->attrs); + } + + private: + Array profile_data_; + + Expr VisitExpr_(const CallNode* call) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); + Expr new_e = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_e.as(); + CHECK(new_call); + if (new_call->op.same_as(simulated_quantize)) { + auto attrs = new_call->attrs.as(); + // rewrite the annotation + auto new_attrs = make_node(); + const Expr& quantize_input = new_call->args[0]; // expression being quantized + auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument + Array new_args{quantize_input, placeholder, placeholder, placeholder}; + new_attrs->kind = QAnnotateKind::kQIdentity; + new_attrs->sign = attrs->sign; + new_attrs->rounding = attrs->rounding; + Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {}); + + // add non-const expressions to profile data + if (attrs->kind != QAnnotateKind::kQWeight) { + CHECK(!quantize_input.as()); + profile_data_.push_back(identity_quantize); + } + return identity_quantize; + } else { + return new_e; + } + } +}; + +/* + * \brief Given an annotated graph, create a profile graph to collect profile data from the + * calibration dataset. + * + * This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to + * identity mode. The tuple is the output of the profile graph. Both input and output of this pass + * are relay::Function. + * + * \param expr The simulation graph after annotation. + * \return The profile graph. + */ +Expr CollectStats(const Expr& expr) { + return StatsCollector().Collect(expr); +} + +TVM_REGISTER_API("relay._quantize.CollectStats") +.set_body_typed(CollectStats); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize/quantize.cc similarity index 93% rename from src/relay/pass/quantize.cc rename to src/relay/pass/quantize/quantize.cc index 83d9220ccf79..6cffc2053e5c 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -36,8 +36,8 @@ #include #include #include -#include "pattern_util.h" -#include "quantize.h" +#include "../pattern_util.h" +#include "./quantize.h" namespace tvm { @@ -46,22 +46,6 @@ namespace quantize { using namespace relay::transform; -/*! \brief Attribute for simulated quantize operator */ -struct SimulatedQuantizeAttrs : public tvm::AttrsNode { - int kind; - bool sign; - std::string rounding; - - TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); - } -}; - TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); bool SimulatedQuantizeRel(const Array& types, @@ -166,23 +150,22 @@ inline Expr ForwardOp(const Call& ref_call, const Array& args) { /* calculate `data * s1 / s2`, use shift if possible */ -inline Expr MulAndDiv(Expr data, float s1, float s2) { +inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { // here we assume the dtype of data is dtype activation - const QConfig& cfg = QConfig::Current(); if (s1 == s2) return data; float factor = s1 / s2; float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + return LeftShift(data, MakeConstantScalar(dtype, static_cast(shift_factor))); } else if (static_cast(factor) == factor) { - return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor)); + return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - LOG(FATAL) << "fall back to float computation"; data = Cast(data, Float(32)); - return Multiply(data, MakeConstantScalar(Float(32), factor)); + data = Multiply(data, MakeConstantScalar(Float(32), factor)); + return Cast(Round(data), dtype); } } @@ -216,15 +199,21 @@ Expr QuantizeRealize(const Call& ref_call, } float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - CHECK_GT(shift_nbit, 0); + CHECK_NE(shift_nbit, 0); if (static_cast(shift_nbit) == shift_nbit) { - // use right shift - if (cfg->round_for_shift) { - float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); + if (shift_nbit > 0) { + // use right shift + if (cfg->round_for_shift) { + float round_bias = std::pow(2.0, shift_nbit - 1); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(round_bias))); + } + data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); + } else { + data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } else { @@ -338,15 +327,11 @@ Expr MulRealize(const Call& ref_call, Expr rdata = rhs->data; DataType dtype = cfg->dtype_activation; - if (lhs->dtype == Float(32)) { + if (lhs->dtype != dtype) { ldata = Cast(ldata, dtype); - } else { - CHECK_EQ(lhs->dtype, dtype); } - if (rhs->dtype == Float(32)) { + if (rhs->dtype != dtype) { rdata = Cast(rdata, dtype); - } else { - CHECK_EQ(rhs->dtype, dtype); } Expr ret = ForwardOp(ref_call, {ldata, rdata}); @@ -418,7 +403,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args Expr dom_scale = MakeConstantScalar(Float(32), s); for (size_t i = 0; i < ret.size(); ++i) { float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); - ret.Set(i, MulAndDiv(ret[i], cur_s, s)); + ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); } *dtype_ptr = dtype; diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize/quantize.h similarity index 89% rename from src/relay/pass/quantize.h rename to src/relay/pass/quantize/quantize.h index 262d420acf97..4965a706b4b4 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -23,13 +23,13 @@ * \file tvm/relay/pass/quantize.h * \brief Header of definitions for quantization */ -#ifndef TVM_RELAY_PASS_QUANTIZE_H_ -#define TVM_RELAY_PASS_QUANTIZE_H_ +#ifndef TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ +#define TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ #include #include #include -#include "pattern_util.h" +#include "../pattern_util.h" namespace tvm { namespace relay { @@ -37,9 +37,26 @@ namespace quantize { /*! \brief Kind of annotate field */ enum QAnnotateKind : int { + kQIdentity = 0, kQInput = 1, kQWeight = 2, - kQActivation = 3, + kQActivation = 3 +}; + +/*! \brief Attribute for simulated quantize operator */ +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + int kind; + bool sign; + std::string rounding; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(kind) + .describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true) + .describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round") + .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + } }; /*! @@ -242,4 +259,4 @@ TVM_DLL QConfig qconfig(); } // namespace quantize } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_QUANTIZE_H_ +#endif // TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_