Commit
[Relay][Quantization] KL-divergence-based per-layer calibration (#3538)
* [Relay][Quantization] Support floating-point scale
* [Relay][Quantization] KL-divergence calibration on dataset
* Fix unhandled LeftShift case in QuantizeRealize
* Fix lint
* drop QBias
* fix lint
* address comments
* address comments
* Update comments
* address comments
* lint
* kQIdentity = 0
1 parent 5357f49, commit 33ab3c6
Showing 7 changed files with 329 additions and 50 deletions.
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

try:
    from scipy import stats
except ImportError:
    stats = None

import numpy as np


def _smooth_distribution(p, eps=0.0001):
    """Given a discrete distribution (which may not have been normalized to 1),
    smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
    corresponding amount off the non-zero values.
    Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
    """
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist


# pylint: disable=invalid-name
def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal clipping threshold for quantizing it.
    The reference distribution is `p` (the original histogram clipped to a candidate
    threshold, with outliers folded into the edge bins), and the candidate distribution
    is `q` (a `num_quantized_bins`-level approximation of `p`). The threshold that
    minimizes the KL-divergence between the two is returned.
    Ref:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    """
    assert isinstance(arr, np.ndarray)
    # stats.entropy below requires scipy; fail early if it is unavailable.
    assert stats is not None, 'scipy is required for KL-divergence calibration'

    min_val = np.min(arr)
    max_val = np.max(arr)
    th = max(abs(min_val), abs(max_val))

    if min_val >= 0 and quantized_dtype in ['uint8']:
        # We need to move negative bins to positive bins to fit uint8 range.
        num_quantized_bins = num_quantized_bins * 2 + 1

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
    zero_bin_idx = num_bins // 2
    num_half_quantized_bins = num_quantized_bins // 2

    thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
    divergence = np.zeros_like(thresholds)
    quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
    # i is the number of bins on the half axis excluding the zero bin.
    for i in range(num_quantized_bins // 2,
                   num_bins // 2 + 1):
        p_bin_idx_start = zero_bin_idx - i
        p_bin_idx_stop = zero_bin_idx + i + 1
        thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
        sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]

        # generate reference distribution p
        p = sliced_nd_hist.copy()
        assert p.size % 2 == 1
        assert p.size >= num_quantized_bins
        # put left outlier count in p[0]
        left_outlier_count = np.sum(hist[0:p_bin_idx_start])
        p[0] += left_outlier_count
        # put right outlier count in p[-1]
        right_outlier_count = np.sum(hist[p_bin_idx_stop:])
        p[-1] += right_outlier_count
        # is_nonzeros[k] indicates whether hist[k] is nonzero
        is_nonzeros = (p != 0).astype(np.int32)

        # calculate how many bins should be merged to generate the quantized distribution q
        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
        # merge hist into num_quantized_bins bins
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
        # expand quantized_bins into p.size bins
        q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            if j == num_quantized_bins - 1:
                stop = len(is_nonzeros)
            else:
                stop = start + num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        q[p == 0] = 0
        p = _smooth_distribution(p)
        # There is a chance that q is an invalid probability distribution.
        try:
            q = _smooth_distribution(q)
        except ValueError:
            divergence[i - num_half_quantized_bins] = float("inf")
            continue
        divergence[i - num_half_quantized_bins] = stats.entropy(p, q)

    min_divergence_idx = np.argmin(divergence)
    opt_th = thresholds[min_divergence_idx]
    return opt_th
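
Editor's note, for illustration only (not part of this commit): a minimal sketch of calling the new kl_divergence_scale helper on a synthetic activation tensor. It assumes the function above is in scope; dividing the returned threshold by 127 to obtain a signed int8 scale is a common convention and an assumption here, not code from this diff.

import numpy as np

# Synthetic "activations": mostly small values plus a handful of large outliers,
# the case where KL-based clipping tends to beat plain max-abs scaling.
np.random.seed(0)
act = np.concatenate([np.random.normal(0.0, 1.0, 100000),
                      np.random.normal(0.0, 8.0, 200)]).astype('float32')

threshold = kl_divergence_scale(act)   # optimal clipping threshold for |x|
scale = threshold / 127.0              # assumed int8 mapping: q = round(x / scale)
print('threshold = %.4f, scale = %.6f' % (threshold, scale))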
@@ -0,0 +1,99 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 *
 * \file calibrate.cc
 *
 * \brief Create profile graph and calibrate on dataset
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"


namespace tvm {
namespace relay {
namespace quantize {

class StatsCollector : private ExprMutator {
 public:
  Expr Collect(const Expr& expr) {
    auto new_e = this->Mutate(expr);
    const FunctionNode* func = new_e.as<FunctionNode>();
    CHECK(func) << "Input should be Function";
    Expr new_body = TupleNode::make(std::move(profile_data_));
    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                              func->attrs);
  }

 private:
  Array<Expr> profile_data_;

  Expr VisitExpr_(const CallNode* call) {
    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call);
    if (new_call->op.same_as(simulated_quantize)) {
      auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
      // rewrite the annotation
      auto new_attrs = make_node<SimulatedQuantizeAttrs>();
      const Expr& quantize_input = new_call->args[0];        // expression being quantized
      auto placeholder = MakeConstantScalar(Float(32), 0.);  // unused argument
      Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
      new_attrs->kind = QAnnotateKind::kQIdentity;
      new_attrs->sign = attrs->sign;
      new_attrs->rounding = attrs->rounding;
      Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});

      // add non-const expressions to profile data
      if (attrs->kind != QAnnotateKind::kQWeight) {
        CHECK(!quantize_input.as<ConstantNode>());
        profile_data_.push_back(identity_quantize);
      }
      return identity_quantize;
    } else {
      return new_e;
    }
  }
};

/*
 * \brief Given an annotated graph, create a profile graph to collect profile data from the
 * calibration dataset.
 *
 * This pass collects the simulated_quantize ops into a tuple; each simulated_quantize op is
 * rewritten to identity mode. The tuple is the output of the profile graph. Both the input and
 * the output of this pass are relay::Function.
 *
 * \param expr The simulation graph after annotation.
 * \return The profile graph.
 */
Expr CollectStats(const Expr& expr) {
  return StatsCollector().Collect(expr);
}

TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm
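
Editor's note, not part of the diff: a hedged sketch of how the profile graph produced by CollectStats might be consumed. The pass returns a function whose output is a tuple with one entry per non-weight simulated_quantize annotation; running that function over a calibration dataset yields per-layer samples, and each layer's KL-optimal threshold can then be turned into a scale. The helper name calibrate_scales, its signature, and the threshold-to-scale formula below are illustrative assumptions, not APIs introduced by this commit.

import numpy as np

def calibrate_scales(profile_outputs, kl_divergence_scale, num_bits=8):
    """Hypothetical helper: derive one scale per annotated layer.

    profile_outputs: list over calibration batches; each entry is the tuple
    produced by the profile graph (one np.ndarray per simulated_quantize).
    """
    num_layers = len(profile_outputs[0])
    scales = []
    for layer in range(num_layers):
        # Pool this layer's samples across every calibration batch.
        samples = np.concatenate([np.asarray(batch[layer]).ravel()
                                  for batch in profile_outputs])
        threshold = kl_divergence_scale(samples)
        # Assumed mapping from clipping threshold to quantization scale.
        scales.append(threshold / (2 ** (num_bits - 1) - 1))
    return scales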