[Relay][Quantization] KL-divergence-based per-layer calibration (#3538)
* [Relay][Quantization] Support floating-point scale

* [Relay][Quantization] KL-divergence calibration on dataset

* Fix unhandled LeftShift case in QuantizeRealize

* Fix lint

* drop QBias

* fix lint

* address comments

* address comments

* Update comments

* address comments

* lint

* kQIdentity = 0
vinx13 authored and ZihengJiang committed Aug 2, 2019
1 parent 5357f49 commit 33ab3c6
Showing 7 changed files with 329 additions and 50 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/__init__.py
@@ -20,3 +20,4 @@

from .quantize import *
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
10 changes: 7 additions & 3 deletions python/tvm/relay/quantize/_annotate.py
@@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):

    data, scale, clip_min, clip_max = inputs

    if attrs.kind == QAnnotateKind.IDENTITY:
        return [topi.identity(data)]

    # simulate rounding error
    scaled_data = topi.divide(data, scale)
    clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
@@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
_reg.register_schedule("relay.op.annotation.simulated_quantize",
                       _reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
-                     _reg.OpPattern.OPAQUE)
+                     _reg.OpPattern.ELEMWISE)
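Changing the op pattern from OPAQUE to ELEMWISE lets the simulated quantize op fuse with adjacent elementwise ops. For reference, a minimal numpy sketch of the semantics the compute rule simulates outside identity mode (the function name and the final round/multiply steps are reconstructed assumptions, not shown in this hunk):

import numpy as np

def simulated_quantize_ref(data, scale, clip_min, clip_max):
    # identity mode (QAnnotateKind.IDENTITY) would return data unchanged
    scaled = data / scale                          # move into the quantized domain
    clipped = np.clip(scaled, clip_min, clip_max)  # saturate to the nbit range
    rounded = np.round(clipped)                    # the only source of simulated error
    return rounded * scale                         # map back to the float domain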


@register_relay_node
@@ -251,7 +254,7 @@ def add_rewrite(ref_call, new_args, ctx):

    if lhs_kind is None and rhs_kind is not None:
        # quantize lhs to INPUT field if it is normal expression
-        assert rhs_kind == QAnnotateKind.INPUT
+        assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.INPUT)
@@ -275,7 +278,8 @@ def add_rewrite(ref_call, new_args, ctx):
        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
-    if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
+    if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
+            (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
    raise ValueError()
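Restated as a sketch, the two hunks above relax add_rewrite's kind handling in two ways (a condensed reading of the patch, not new behavior):

# lhs_kind is None (plain expression added to a quantized one):
#   before: rhs_kind had to be INPUT
#   after:  rhs_kind may be INPUT or ACTIVATION; lhs is still quantized as INPUT
#
# mixed kinds:
#   before: only (lhs=ACTIVATION, rhs=INPUT) produced an ACTIVATION result
#   after:  (lhs=INPUT, rhs=ACTIVATION) is accepted symmetrically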
124 changes: 124 additions & 0 deletions python/tvm/relay/quantize/kl_divergence.py
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

try:
    from scipy import stats
except ImportError:
    stats = None

import numpy as np


def _smooth_distribution(p, eps=0.0001):
    """Given a discrete distribution (which may not have been normalized to 1),
    smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
    corresponding amount off the non-zero values.
    Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
    """
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist
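A small worked example of the smoothing (the numbers follow directly from the formula above; the call itself is illustrative):

import numpy as np

p = np.array([3, 0, 5, 0])
# n_zeros = 2, n_nonzeros = 2, so eps1 = 0.0001 * 2 / 2 = 0.0001
smoothed = _smooth_distribution(p)
# zeros become 0.0001 and each nonzero entry gives up eps1:
# [2.9999, 0.0001, 4.9999, 0.0001] -- the total mass (8.0) is preserved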


# pylint: disable=invalid-name
def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal threshold for quantizing it.
    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.
    Ref:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    """
    assert isinstance(arr, np.ndarray)
    assert stats is not None, "scipy is required for KL-divergence calibration"

    min_val = np.min(arr)
    max_val = np.max(arr)
    th = max(abs(min_val), abs(max_val))

    if min_val >= 0 and quantized_dtype in ['uint8']:
        # We need to move negative bins to positive bins to fit uint8 range.
        num_quantized_bins = num_quantized_bins * 2 + 1

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
    zero_bin_idx = num_bins // 2
    num_half_quantized_bins = num_quantized_bins // 2

    thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
    divergence = np.zeros_like(thresholds)
    quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
    # i means the number of bins on half axis excluding the zero bin.
    for i in range(num_quantized_bins // 2,
                   num_bins // 2 + 1):
        p_bin_idx_start = zero_bin_idx - i
        p_bin_idx_stop = zero_bin_idx + i + 1
        thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
        sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]

        # generate reference distribution p
        p = sliced_nd_hist.copy()
        assert p.size % 2 == 1
        assert p.size >= num_quantized_bins
        # put left outlier count in p[0]
        left_outlier_count = np.sum(hist[0:p_bin_idx_start])
        p[0] += left_outlier_count
        # put right outlier count in p[-1]
        right_outlier_count = np.sum(hist[p_bin_idx_stop:])
        p[-1] += right_outlier_count
        # is_nonzeros[k] indicates whether hist[k] is nonzero
        is_nonzeros = (p != 0).astype(np.int32)

        # calculate how many bins should be merged to generate the quantized distribution q
        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
        # merge hist into num_quantized_bins bins
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
        # expand quantized_bins into p.size bins
        q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            if j == num_quantized_bins - 1:
                stop = len(is_nonzeros)
            else:
                stop = start + num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        q[p == 0] = 0
        p = _smooth_distribution(p)
        # There is a chance that q is an invalid probability distribution.
        try:
            q = _smooth_distribution(q)
        except ValueError:
            divergence[i - num_half_quantized_bins] = float("inf")
            continue
        divergence[i - num_half_quantized_bins] = stats.entropy(p, q)

    min_divergence_idx = np.argmin(divergence)
    opt_th = thresholds[min_divergence_idx]
    return opt_th
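A hedged usage sketch for a single layer; the random data stands in for real profile outputs, and how the returned threshold is turned into dom_scale is up to the caller (see the calibrate changes below):

import numpy as np

# fake activations for one layer, standing in for collected profile data
acts = np.random.normal(loc=0.0, scale=1.0, size=100000).astype(np.float32)

threshold = kl_divergence_scale(acts, quantized_dtype='int8')
# threshold is the clipping range [-threshold, threshold] whose quantized
# histogram stays closest (in KL divergence) to the original distribution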
57 changes: 53 additions & 4 deletions python/tvm/relay/quantize/quantize.py
@@ -32,6 +32,7 @@
class QAnnotateKind(object):
    """Denote the kind of annotation field, corresponding
    to different nbit configurations."""
    IDENTITY = 0
    INPUT = 1
    WEIGHT = 2
    ACTIVATION = 3
@@ -43,6 +44,7 @@ def kind2str(kind):
        QAnnotateKind.INPUT: "input",
        QAnnotateKind.WEIGHT: "weight",
        QAnnotateKind.ACTIVATION: "activation",
        QAnnotateKind.IDENTITY: "identity"
    }
    assert kind in str_map
    return str_map[kind]
@@ -195,7 +197,26 @@ def annotate_context():
    return AnnotateContext.Current


-def calibrate(graph, mod=None, ctx=None):
def collect_stats(graph):
    """Given an annotated graph, create a profile graph to collect profile data from the
    calibration dataset. This pass collects the inputs of simulated_quantize ops into a tuple.
    The simulated_quantize ops are rewritten to identity mode, and the tuple becomes the
    output of the profile graph.

    Parameters
    ----------
    graph: Function
        The simulation graph after annotation.

    Returns
    -------
    ret: Function
        The profile graph, which outputs a tuple of profile data.
    """
    return _quantize.CollectStats(graph)


+def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
    """The calibrate procedure will try to calculate the content of
    dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
    operator.
@@ -211,6 +232,16 @@ def calibrate(graph, mod=None, ctx=None):
    ctx: tvm.relay.PassContext
        The pass context used for calibration.

    weight_scales: 'power2' or 'max'
        The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
        power2: Find the maximum of the absolute value of the tensor, then round it up to a
        power of two.
        max: Find the maximum of the absolute value of the tensor.

    scales: List[float]
        Pre-calculated scales for inputs and activations. The length and order of the elements
        of the scales list should match the output tuple of the profile graph created by
        collect_stats.

    Returns
    -------
    ret: Function
@@ -221,24 +252,42 @@ def power2_scale(arr):
        val = np.amax(np.abs(arr.asnumpy()))
        return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0

    def max_scale(arr):
        """Calculate the weight scale as the maximum absolute value."""
        val = np.amax(np.abs(arr.asnumpy()))
        return val

    scale_idx = 0

    cfg = current_qconfig()
    const_params = {}
    quantize_op = _op.get("relay.op.annotation.simulated_quantize")

    def visit_func(expr):
        """Internal visit function"""
        nonlocal scale_idx
        if isinstance(expr, _expr.Call) and expr.op == quantize_op:
            _, ndom_scale, nclip_min, nclip_max = expr.args
            attrs = expr.attrs
            kind = attrs.kind
            nbit = cfg.get_nbit_by_kind(kind)

            valid_bit = nbit - attrs.sign

-            if kind == QAnnotateKind.WEIGHT:
+            if kind in [QAnnotateKind.WEIGHT]:
                if all([isinstance(arg, _expr.Constant)
                        for arg in [ndom_scale, nclip_min, nclip_max]]):
                    return
                var = expr.args[0]
                assert isinstance(var, _expr.Constant)
-                scale = power2_scale(var.data)
+                if weight_scales == 'max':
+                    scale = max_scale(var.data)
+                elif weight_scales == 'power2':
+                    scale = power2_scale(var.data)
+                else:
+                    raise ValueError('{} not supported'.format(weight_scales))
+            elif scales is not None:
+                scale = scales[scale_idx]
+                scale_idx += 1
+            else:
+                scale = cfg.global_scale
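Putting the pieces together, a hedged end-to-end sketch of per-layer calibration on a dataset. Only collect_stats, kl_divergence_scale, and calibrate come from this patch; annotated_fn, calib_dataset, and the executor plumbing are assumptions for illustration:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.quantize import kl_divergence_scale
from tvm.relay.quantize.quantize import calibrate, collect_stats

# annotated_fn: a relay.Function already processed by the annotate pass (assumed given)
profile_fn = collect_stats(annotated_fn)  # tuple-output profile graph

# run the profile graph over the calibration set (executor choice is an assumption)
runner = relay.create_executor('graph', ctx=tvm.cpu(0), target='llvm').evaluate(profile_fn)
samples = [[] for _ in range(len(profile_fn.body.fields))]
for batch in calib_dataset:  # hypothetical iterable of input batches
    for i, out in enumerate(runner(batch)):
        samples[i].append(out.asnumpy().ravel())

# one KL-divergence threshold per profiled tensor, in profile-graph output order
scales = [kl_divergence_scale(np.concatenate(s)) for s in samples]

calibrated_fn = calibrate(annotated_fn, weight_scales='max', scales=scales)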
99 changes: 99 additions & 0 deletions src/relay/pass/quantize/calibrate.cc
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
*
* \file calibrate.cc
*
* \brief Create profile graph and calibrate on dataset
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"


namespace tvm {
namespace relay {
namespace quantize {

class StatsCollector : private ExprMutator {
 public:
  Expr Collect(const Expr& expr) {
    auto new_e = this->Mutate(expr);
    const FunctionNode* func = new_e.as<FunctionNode>();
    CHECK(func) << "Input should be Function";
    Expr new_body = TupleNode::make(std::move(profile_data_));
    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                              func->attrs);
  }

 private:
  Array<Expr> profile_data_;

  Expr VisitExpr_(const CallNode* call) {
    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call);
    if (new_call->op.same_as(simulated_quantize)) {
      auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
      // rewrite the annotation
      auto new_attrs = make_node<SimulatedQuantizeAttrs>();
      const Expr& quantize_input = new_call->args[0];        // expression being quantized
      auto placeholder = MakeConstantScalar(Float(32), 0.);  // unused argument
      Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
      new_attrs->kind = QAnnotateKind::kQIdentity;
      new_attrs->sign = attrs->sign;
      new_attrs->rounding = attrs->rounding;
      Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});

      // add non-const expressions to profile data
      if (attrs->kind != QAnnotateKind::kQWeight) {
        CHECK(!quantize_input.as<ConstantNode>());
        profile_data_.push_back(identity_quantize);
      }
      return identity_quantize;
    } else {
      return new_e;
    }
  }
};

/*
 * \brief Given an annotated graph, create a profile graph to collect profile data from the
 * calibration dataset.
 *
 * This pass collects the simulated_quantize ops into a tuple. The simulated_quantize ops are
 * rewritten to identity mode. The tuple is the output of the profile graph. Both the input and
 * the output of this pass are relay::Function.
 *
 * \param expr The simulation graph after annotation.
 * \return The profile graph.
 */
Expr CollectStats(const Expr& expr) {
  return StatsCollector().Collect(expr);
}

TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

} // namespace quantize
} // namespace relay
} // namespace tvm
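Conceptually, the pass turns an annotated graph into a tuple-producing profile graph. A textual sketch of the rewrite (Relay-like pseudocode; sim_q abbreviates simulated_quantize and all names are invented):

# before (annotated graph):
#   fn (%x, %w) {
#     %0 = sim_q(%x, s0, cmin, cmax, kind=INPUT)
#     %1 = nn.conv2d(%0, sim_q(%w, s1, cmin, cmax, kind=WEIGHT))
#     sim_q(%1, s2, cmin, cmax, kind=ACTIVATION)
#   }
#
# after CollectStats (every kind becomes IDENTITY, the scale/clip arguments are
# replaced by a dummy 0.0 constant, and the non-weight sim_q results form the
# output tuple):
#   fn (%x, %w) {
#     %0 = sim_q(%x, 0., 0., 0., kind=IDENTITY)
#     %1 = nn.conv2d(%0, sim_q(%w, 0., 0., 0., kind=IDENTITY))
#     %2 = sim_q(%1, 0., 0., 0., kind=IDENTITY)
#     (%0, %2)
#   }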