[Quantization] Make calibration faster and more memory usage friendly #4589

Merged · 11 commits · Jan 3, 2020
85 changes: 49 additions & 36 deletions python/tvm/relay/quantize/_calibrate.py
@@ -33,24 +33,7 @@
from .kl_divergence import _find_scale_by_kl


def collect_stats(mod, dataset):
"""Given an annotated graph, create a profile graph to collect profile data from the
calibration dataset. This pass collects simulated_quantize op input into a tuple.
Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
graph.

Parameters
----------
mod: Module
The simulation graph after annotation.

Returns
-------
ret: list of ndarray
List of output data of each layer
"""

logging.info("collecting statistics for calibration...")
def _get_profile_runtime(mod):
func = mod['main']
func = _quantize.CreateStatsCollector(func)

@@ -63,30 +46,61 @@ def collect_stats(mod, dataset):

with _transform.build_config(opt_level=3):
graph, lib, params = _build_module.build(func, target=target)
outputs = []
runtime = graph_runtime.create(graph, lib, ctx)
runtime.set_input(**params)

return runtime


def collect_stats(mod, dataset, chunk_by=-1):
"""Given an annotated graph, create a profile graph to collect profile data from the
calibration dataset. This pass collects simulated_quantize op input into a tuple.
Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
graph.

Parameters
----------
mod: Module
The simulation graph after annotation.

dataset: Iterable[NDArray]
The calibration dataset.

chunk_by: int, optional
The number of layer outputs returned per iteration, meant to bound
memory usage. If not specified, samples for all layers are returned
in one chunk.

Returns
-------
ret: Iterable[list of ndarray]
Output data of each layer, yielded in chunks of at most chunk_by layers
"""
logging.info("collecting statistics for calibration...")
runtime = _get_profile_runtime(mod)
num_outputs = runtime.get_num_outputs()
outputs = [[] for i in range(num_outputs)]
chunk_by = num_outputs if chunk_by == -1 else chunk_by

for batch in dataset:
runtime.set_input(**batch)
runtime.run()
for i in range(num_outputs):
output = runtime.get_output(i).asnumpy()
outputs[i].append(output)
for i in range(num_outputs):
outputs[i] = np.concatenate(outputs[i]).reshape(-1)
return outputs
for i in range(0, num_outputs, chunk_by):
outputs = [[] for i in range(min(chunk_by, num_outputs - i))]
for batch in dataset:
runtime.set_input(**batch)
runtime.run()
for j in range(i, min(i+chunk_by, num_outputs)):
outputs[j-i].append(runtime.get_output(j).asnumpy())
yield [np.concatenate(output).reshape(-1) for output in outputs]
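
Note the trade-off in the loop above: the calibration dataset is re-executed once per chunk, trading extra forward passes for a bounded peak memory footprint. A minimal usage sketch of the generator, assuming annotated_mod and calib_dataset come from the usual relay quantization flow:

# Hedged sketch: consume collect_stats chunk by chunk so that only
# `chunk_by` flattened layer outputs are resident in memory at once.
# `annotated_mod` and `calib_dataset` are assumed inputs, not defined here.
from tvm.relay.quantize._calibrate import collect_stats

for samples in collect_stats(annotated_mod, calib_dataset, chunk_by=16):
    for arr in samples:  # at most 16 flattened per-layer output arrays
        print(arr.shape, arr.dtype)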


def _kl_scale(stats):
with mp.Pool() as pool:
def _kl_scale(mod, dataset):
cfg = quantize.current_qconfig()
chunk_by = cfg.calibrate_chunk_by
scales = []
for samples in collect_stats(mod, dataset, chunk_by):
logging.info("finding threshold with kl for calibration...")
scales = list(pool.map(_find_scale_by_kl, stats))
with mp.Pool() as pool:
scales += list(pool.map(_find_scale_by_kl, samples))

def func(sq_call): # pylint: disable=unused-argument
def func(_):
scale = scales[func.scale_idx]
func.scale_idx += 1
return scale
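
The returned func uses a function attribute as a mutable counter, so each simulated_quantize call site receives the next precomputed scale in order. A self-contained sketch of that pattern (names are illustrative; the real initialization of func.scale_idx lives in lines collapsed from this diff):

def make_scale_func(scales):
    def func(_):
        scale = scales[func.scale_idx]
        func.scale_idx += 1
        return scale
    func.scale_idx = 0
    return func

scale_func = make_scale_func([0.1, 0.2, 0.4])
assert scale_func(None) == 0.1  # first call gets the first scale
assert scale_func(None) == 0.2  # subsequent calls advance the counter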
@@ -168,13 +182,12 @@ def calibrate(dataset=None):
ret: Function
The module pass function.
"""
def wrapped_func(mod, ctx): # pylint: disable=unused-argument
def wrapped_func(mod, _):
"""make transform.module pass happy"""
cfg = quantize.current_qconfig()

if cfg.calibrate_mode == 'kl_divergence':
stats = collect_stats(mod, dataset)
input_scale_func = _kl_scale(stats)
input_scale_func = _kl_scale(mod, dataset)
elif cfg.calibrate_mode == 'global_scale':
input_scale_func = _global_scale
else:
100 changes: 13 additions & 87 deletions python/tvm/relay/quantize/kl_divergence.py
@@ -16,36 +16,14 @@
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

try:
from scipy import stats
except ImportError:
stats = None

import ctypes
import numpy as np


def _smooth_distribution(p, eps=0.0001):
"""Given a discrete distribution (may have not been normalized to 1),
smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
corresponding amount off the non-zero values.
Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
"""
is_zeros = (p == 0).astype(np.float32)
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()
n_nonzeros = p.size - n_zeros
if not n_nonzeros:
raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
hist = p.astype(np.float32)
hist += eps * is_zeros + (-eps1) * is_nonzeros
assert (hist <= 0).sum() == 0
return hist
from . import _quantize


# pylint: disable=invalid-name
def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
def _find_scale_by_kl(arr, quantized_dtype='int8',
num_bins=8001, num_quantized_bins=255):
"""Given a tensor, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.
@@ -54,73 +32,21 @@ def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
assert isinstance(arr, np.ndarray)
assert stats is not None, "scipy needs to be installed for \
utilizing kl calibration during quantization"

min_val = np.min(arr)
max_val = np.max(arr)
th = max(abs(min_val), abs(max_val))
thres = max(abs(min_val), abs(max_val))

if min_val >= 0 and quantized_dtype in ['uint8']:
# We need to move negative bins to positive bins to fit uint8 range.
num_quantized_bins = num_quantized_bins * 2 + 1

hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2

thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
divergence = np.zeros_like(thresholds)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
# i means the number of bins on half axis excluding the zero bin.
for i in range(num_quantized_bins // 2,
num_bins // 2 + 1):
p_bin_idx_start = zero_bin_idx - i
p_bin_idx_stop = zero_bin_idx + i + 1
thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]

# generate reference distribution p
p = sliced_nd_hist.copy()
assert p.size % 2 == 1
assert p.size >= num_quantized_bins
# put left outlier count in p[0]
left_outlier_count = np.sum(hist[0:p_bin_idx_start])
p[0] += left_outlier_count
# put right outlier count in p[-1]
right_outlier_count = np.sum(hist[p_bin_idx_stop:])
p[-1] += right_outlier_count
# is_nonzeros[k] indicates whether hist[k] is nonzero
is_nonzeros = (p != 0).astype(np.int32)
def get_pointer(arr, ctypes_type):
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
return ctypes.cast(ptr, ctypes.c_void_p)

# calculate how many bins should be merged to generate quantized distribution q
num_merged_bins = sliced_nd_hist.size // num_quantized_bins
# merge hist into num_quantized_bins bins
for j in range(num_quantized_bins):
start = j * num_merged_bins
stop = start + num_merged_bins
quantized_bins[j] = sliced_nd_hist[start:stop].sum()
quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
# expand quantized_bins into p.size bins
q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
for j in range(num_quantized_bins):
start = j * num_merged_bins
if j == num_quantized_bins - 1:
stop = len(is_nonzeros)
else:
stop = start + num_merged_bins
norm = is_nonzeros[start:stop].sum()
if norm != 0:
q[start:stop] = float(quantized_bins[j]) / float(norm)
q[p == 0] = 0
p = _smooth_distribution(p)
# There is a chance that q is an invalid probability distribution.
try:
q = _smooth_distribution(q)
except ValueError:
divergence[i - num_half_quantized_bins] = float("inf")
divergence[i - num_half_quantized_bins] = stats.entropy(p, q)
hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-thres, thres))
# Bind the converted arrays to names so the buffers stay alive across the
# FFI call, and cast the float64 edges from np.histogram to float32 to
# match the float* expected on the C++ side.
hist = hist.astype(np.int32)
hist_edges = hist_edges.astype(np.float32)
hist_ptr = get_pointer(hist, ctypes.c_int)
hist_edges_ptr = get_pointer(hist_edges, ctypes.c_float)

min_divergence_idx = np.argmin(divergence)
opt_th = thresholds[min_divergence_idx]
return opt_th
return _quantize.FindScaleByKLMinimization(hist_ptr, hist_edges_ptr,
num_bins, num_quantized_bins)
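
A hedged example of calling the rewritten helper directly on a calibration tensor; it is normally invoked via the calibrate pass, and requires TVM built with the C++ changes below:

import numpy as np
from tvm.relay.quantize.kl_divergence import _find_scale_by_kl

arr = np.random.normal(scale=2.0, size=(64, 128)).astype('float32')
# Returns the clipping threshold that minimizes the KL divergence
# between the original and the quantized distribution of `arr`.
threshold = _find_scale_by_kl(arr, quantized_dtype='int8')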
3 changes: 2 additions & 1 deletion python/tvm/relay/quantize/quantize.py
@@ -81,7 +81,8 @@ class QConfig(NodeBase):
"do_simulation": False,
"round_for_shift": True,
"debug_enabled_ops": None,
"rounding": "UPWARD"
"rounding": "UPWARD",
"calibrate_chunk_by": -1,
}

# pylint: disable=no-member
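
A hedged usage sketch of the new knob; mod, params, and calib_dataset are assumed inputs and the chunk size is illustrative:

from tvm import relay

# Bound peak calibration memory by processing 16 layer outputs per chunk.
with relay.quantize.qconfig(calibrate_mode='kl_divergence',
                            calibrate_chunk_by=16):
    qmod = relay.quantize.quantize(mod, params, dataset=calib_dataset)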
122 changes: 122 additions & 0 deletions src/relay/pass/quantize/calibrate.cc
@@ -26,12 +26,122 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <numeric>
#include "./quantize.h"

namespace tvm {
namespace relay {
namespace quantize {

// KL divergence minimization code is adapted from MXNet.
// The original one is in incubator-mxnet/src/operator/quantization/calibrate.cc
static std::vector<float> SmoothDistribution(const std::vector<float>& p,
const float eps = 0.0001) {
std::vector<size_t> is_zeros(p.size());
std::vector<size_t> is_nonzeros(p.size());
{
auto it = p.begin();
std::generate(is_zeros.begin(), is_zeros.end(),
[&it]() { return static_cast<size_t>(*(it++) == 0.f); });
}
{
auto it = p.begin();
std::generate(is_nonzeros.begin(), is_nonzeros.end(),
[&it]() { return static_cast<size_t>(*(it++) != 0.f); });
}
size_t n_zeros = std::accumulate(is_zeros.begin(), is_zeros.end(), 0);
size_t n_nonzeros = p.size() - n_zeros;
if (!n_nonzeros) {
// The discrete probability distribution is malformed. All entries are 0.
return std::vector<float>();
}
float eps1 = eps * static_cast<float>(n_zeros) / static_cast<float>(n_nonzeros);
if (eps1 >= 1.0) return std::vector<float>();
auto ret = p;
for (size_t i = 0; i < p.size(); i++) {
ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i];
}
return ret;
}
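
For reference, this implements the same mass-preserving smoothing as the Python _smooth_distribution it replaces: every zero bin gains eps and every nonzero bin gives up eps1, chosen so the total mass is unchanged:

p'_i = p_i + \epsilon \cdot \mathbf{1}[p_i = 0] - \epsilon_1 \cdot \mathbf{1}[p_i \neq 0],
\qquad \epsilon_1 = \epsilon \cdot \frac{n_{\text{zeros}}}{n_{\text{nonzeros}}}

so that \sum_i p'_i = \sum_i p_i, because \epsilon \, n_{\text{zeros}} = \epsilon_1 \, n_{\text{nonzeros}}.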

static float ComputeEntropy(float* p, float* q, size_t size) {
float p_sum = std::accumulate(p, p+size, 0.f);
float q_sum = std::accumulate(q, q+size, 0.f);
float ret = 0;
for (size_t i = 0; i < size; i++) {
CHECK(p[i] > 0 && q[i] > 0);
p[i] /= p_sum;
q[i] /= q_sum;
if (p[i] && q[i]) ret += p[i] * std::log(p[i] / q[i]);
}
return ret;
}
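
ComputeEntropy normalizes both histograms in place and returns the discrete KL divergence, matching the scipy.stats.entropy(p, q) call removed from the Python side:

\mathrm{KL}(p \,\|\, q) = \sum_i \hat{p}_i \log \frac{\hat{p}_i}{\hat{q}_i},
\qquad \hat{p}_i = \frac{p_i}{\sum_j p_j}, \quad \hat{q}_i = \frac{q_i}{\sum_j q_j}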

float MinimizeKL(const std::vector<int>& hist,
const std::vector<float>& hist_edges,
int num_bins, int num_quantized_bins) {
const int zero_bin_idx = num_bins / 2;
const int num_half_quantized_bins = num_quantized_bins / 2;
std::vector<float> thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f);
std::vector<float> divergence(thresholds.size(), 0.f);
std::vector<float> quantized_bins(num_quantized_bins, 0);
for (int i = num_quantized_bins / 2; i < zero_bin_idx + 1; ++i) {
const int p_bin_idx_start = zero_bin_idx - i;
const int p_bin_idx_stop = zero_bin_idx + i + 1;
thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop];

std::vector<int> sliced_nd_hist(p_bin_idx_stop - p_bin_idx_start);
std::vector<float> p(sliced_nd_hist.size());
p[0] = 0;
p.back() = 0;
for (int j = 0; j < num_bins; j++) {
if (j <= p_bin_idx_start) {
p[0] += hist[j];
} else if (j >= p_bin_idx_stop) {
p.back() += hist[j];
} else {
sliced_nd_hist[j - p_bin_idx_start] = hist[j];
p[j - p_bin_idx_start] = hist[j];
}
}
// calculate how many bins should be merged to generate quantized distribution q
const auto num_merged_bins = sliced_nd_hist.size() / num_quantized_bins;
for (int j = 0; j < num_quantized_bins; j++) {
const int start = j * num_merged_bins;
const int stop = (j + 1) * num_merged_bins;
quantized_bins[j] =
std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0);
}
quantized_bins.back() += std::accumulate(
sliced_nd_hist.begin() + static_cast<int>(num_quantized_bins * num_merged_bins),
sliced_nd_hist.end(), 0);
// expand quantized_bins into p.size bins
std::vector<float> q(sliced_nd_hist.size(), 0);
for (int j = 0; j < num_quantized_bins; j++) {
const int start = j * num_merged_bins;
const int stop = (j == num_quantized_bins - 1) ? q.size() : ((j + 1) * num_merged_bins);
int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop,
[](size_t i) { return i != 0; });
if (norm) {
for (int k = start; k < stop; k++) {
if (p[k]) q[k] = quantized_bins[j] / norm;
}
}
}
p = SmoothDistribution(p);
q = SmoothDistribution(q);

if (!q.size()) {
divergence[i - num_half_quantized_bins] = std::numeric_limits<float>::infinity();
} else {
divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size());
}
}
auto min_divergence_idx = std::distance(divergence.begin(),
std::min_element(divergence.begin(), divergence.end()));
return thresholds[min_divergence_idx];
}
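
With the defaults passed from Python (num_bins = 8001, num_quantized_bins = 255), this loop evaluates 8001/2 + 1 - 255/2 = 3874 candidate thresholds, computing one smoothed histogram pair and one KL divergence per candidate; porting this search from per-slice NumPy operations to C++ is the main source of the speedup this PR targets.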

class StatsCollector : private ExprMutator {
public:
StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}
@@ -95,6 +205,18 @@ Expr CreateStatsCollector(const Expr& expr) {
TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
.set_body_typed(CreateStatsCollector);


TVM_REGISTER_API("relay._quantize.FindScaleByKLMinimization")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));
float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));
int num_bins = args[2];
int num_quantized_bins = args[3];
std::vector<int> hist(hist_ptr, hist_ptr + num_bins);
std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);
ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);
});
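
Note that the packed function receives the histogram and its edges as raw, untyped pointers (args[0] and args[1]) with no lifetime information: the Python caller must pass int32 and float32 buffers respectively and keep the backing NumPy arrays alive until the call returns.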

} // namespace quantize
} // namespace relay
} // namespace tvm