[QUANTIZE] Quantization implementation. (apache#32)
ZihengJiang authored and tmoreau89 committed Mar 20, 2019
1 parent 11a7245 commit fad10c4
Showing 7 changed files with 242 additions and 4 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
@@ -10,7 +10,6 @@
from . import adt
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from . import prelude
from . import parser
from . import debug
from . import param_dict
4 changes: 1 addition & 3 deletions python/tvm/relay/build_module.py
@@ -20,9 +20,7 @@
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
"AlterOpLayout": 4,
}
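
Per the diff above, AlterOpLayout now appears to sit at optimization level 4 in this table, so it would be skipped at the usual opt_level=3. A minimal, hedged sketch of how that interacts with build_config (func, target, and params are placeholders; the build call is assumed from the relay API already imported in this file):

from tvm import relay

# With AlterOpLayout listed at level 4, a build at the common opt_level=3
# would skip it while still running FoldConstant, FoldScaleAxis, etc.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target="llvm", params=params)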


25 changes: 25 additions & 0 deletions python/tvm/relay/quantize/_quantize.py
@@ -2,5 +2,30 @@
"""Internal module for quantization."""
from __future__ import absolute_import
from tvm._ffi.function import _init_api
import topi
from ..op import op as _reg


@_reg.register_compute("simulated_quantize")
def simulated_quantize_compute(attrs, inputs, output_type, target):
"""Compiler for simulated_quantize."""
assert len(inputs) == 5
assert attrs.sign
assert attrs.rounding == "round"

data, scale, bit, clip_min, clip_max = inputs

# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)

# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]


_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)

_init_api("relay._quantize", __name__)
96 changes: 96 additions & 0 deletions python/tvm/relay/quantize/annotate_ops.py
@@ -0,0 +1,96 @@
from __future__ import absolute_import
from .. import expr as _expr
from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite
from .quantize import attach_simulated_quantize, get_current_qconfig


def _forward_op(ref_call, args):
return _expr.Call(ref_call.op, args,
ref_call.attrs, ref_call.type_args)


@register_qfield_rewrite("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
cfg = get_current_qconfig()
if cfg.counter < cfg.skip_k_conv:
cfg.counter += 1
return None
cfg.counter += 1

lhs, rhs = new_args
if isinstance(lhs, QFieldExpr):
lhs_expr = lhs.expr
if lhs.kind != QFieldKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QFieldKind.INPUT)
else:
lhs_expr = attach_simulated_quantize(lhs, QFieldKind.INPUT)

assert not isinstance(rhs, QFieldExpr)
rhs_expr = attach_simulated_quantize(rhs, QFieldKind.WEIGHT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)


@register_qfield_rewrite("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
cfg = get_current_qconfig()
if cfg.counter <= cfg.skip_k_conv:
return None

lhs, rhs = new_args
if not isinstance(lhs, QFieldExpr) and not isinstance(rhs, QFieldExpr):
return None
elif lhs.kind == QFieldKind.ACTIVATION and not isinstance(rhs, QFieldExpr):
lhs_expr = attach_simulated_quantize(lhs.expr, QFieldKind.INPUT)
        rhs_expr = attach_simulated_quantize(rhs, QFieldKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
else:
raise ValueError


@register_qfield_rewrite("add")
def add_rewrite(ref_call, new_args, ctx):
cfg = get_current_qconfig()
if cfg.counter <= cfg.skip_k_conv:
return None

lhs, rhs = new_args
if not isinstance(lhs, QFieldExpr) and not isinstance(rhs, QFieldExpr):
# on float domain
return None
elif not isinstance(lhs, QFieldExpr) and rhs.kind == QFieldKind.ACTIVATION:
        # addition for residual, but lhs is computed in the real domain
lhs_expr = attach_simulated_quantize(lhs, QFieldKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs.expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
elif lhs.kind == QFieldKind.ACTIVATION and not isinstance(rhs, QFieldExpr):
# the most common situation, e.g. bias add in bn
rhs_expr = attach_simulated_quantize(rhs, QFieldKind.WEIGHT)
expr = _forward_op(ref_call, [lhs.expr, rhs_expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
elif lhs.kind == QFieldKind.INPUT and rhs.kind == QFieldKind.ACTIVATION:
        # addition for residual, but lhs is multi-referenced
expr = _forward_op(ref_call, [lhs.expr, rhs.expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
elif lhs.kind == QFieldKind.ACTIVATION and rhs.kind == QFieldKind.ACTIVATION:
# addition for residual
expr = _forward_op(ref_call, [lhs.expr, rhs.expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
else:
raise ValueError


@register_qfield_rewrite("nn.relu")
def relu_rewrite(ref_call, new_args, ctx):
cfg = get_current_qconfig()
if cfg.counter <= cfg.skip_k_conv:
return None

x = new_args[0]
if isinstance(x, QFieldExpr):
expr = _forward_op(ref_call, [x.expr])
return QFieldExpr(expr, x.kind)
else:
return None
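
Taken together, these rewrite rules thread simulated_quantize calls between operators. A rough sketch of what the conv2d rule produces, written directly against the public API exercised in the test below (the shared scale/bit/clip variables and the kind values 0/1 are placeholders for whatever attach_simulated_quantize actually creates, not this module's internals):

from tvm import relay
from tvm.relay import quantize as qtz

data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("conv_weight")
scale, bit = relay.var("scale"), relay.var("bit")
clip_min, clip_max = relay.var("clip_min"), relay.var("clip_max")

# conv2d_rewrite quantizes the data operand as INPUT and the weight operand as
# WEIGHT, then replays the original conv2d call on the quantized operands.
qdata = qtz.simulated_quantize(data, scale, bit, clip_min, clip_max,
                               sign=True, rounding='round', kind=0)
qweight = qtz.simulated_quantize(weight, scale, bit, clip_min, clip_max,
                                 sign=True, rounding='round', kind=1)
out = relay.nn.conv2d(qdata, qweight, kernel_size=(3, 3), padding=(1, 1), channels=3)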
36 changes: 36 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -59,6 +59,24 @@ namespace relay {
LOG(FATAL) << "unknown data type " << type; \
}

/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
inline GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}

/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
@@ -327,6 +345,24 @@ inline Expr LeftShift(Expr x, Expr nbit) {
}


inline Expr Power(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("power");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}


inline Expr RightShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("right_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}


inline Expr LeftShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("left_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}


inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
70 changes: 70 additions & 0 deletions tests/python/quantize/test_pass_quantize.py
@@ -0,0 +1,70 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz

def test_simulated_quantize():
data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
scale = relay.var("scale")
bit = relay.var("bit")
clip_min = relay.var("clip_min")
clip_max = relay.var("clip_max")
out = qtz.simulated_quantize(data, scale, bit, clip_min, clip_max, sign=True, rounding='round', kind=0)
out = relay.ir_pass.infer_type(out)
assert out.checked_type == out.args[0].checked_type
assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "int32")
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[4].checked_type == relay.ty.TensorType(tuple(), "float32")

def test_annotate_pass():
n, c, h, w = 1, 3, 224, 224
def residual_block(data, cnt):
# conv
weight = relay.var("conv_weight" + str(cnt))
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
scale = relay.var('bn_scale' + str(cnt), relay.TensorType((1, c, 1, 1)))
bias = relay.var('bn_bias' + str(cnt), relay.TensorType((1, c, 1, 1)))
bn = conv * scale + bias
relu = relay.nn.relu(bn)
return relu

data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
out = data
for i in range(1):
out = residual_block(out, i)

out = relay.ir_pass.infer_type(out)
out = relay.ir_pass.simplify_inference(out)

def make_dataset(args, size=100):
def create_arr(var):
ttype = var.type_annotation
np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
return tvm.ndarray.array(np_arr)

params = {}
for arg in args:
if arg.name_hint == 'data':
dataset = [{'data': create_arr(arg)} for _ in range(size)]
else:
params[arg.name_hint] = create_arr(arg)
return dataset, params

args = relay.ir_pass.free_vars(out)
graph = relay.Function(args, out)
dataset, params = make_dataset(args, 10)

with qtz.qconfig(skip_k_conv=0, global_scale=4.0):
print('before:')
print(graph.astext(show_meta_data=False))

qgraph = qtz.quantize(graph, params)
print('after quantize:')
print(qgraph.astext(show_meta_data=False))
print('\n')


if __name__ == "__main__":
test_simulated_quantize()
test_annotate_pass()
14 changes: 14 additions & 0 deletions topi/python/topi/util.py
@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from numbers import Integral

import ctypes
import tvm
from . import tag

@@ -285,3 +286,16 @@ def get_max_power2_factor(n, max_value=None):
x *= 2
n /= 2
return x


@tvm.register_func("print_tensor")
def print_tensor_impl(x, y, msg):
print(ctypes.string_at(msg))
print(x.asnumpy())
x.copyto(y)


def print_tensor(x, msg=""):
return tvm.extern(x.shape, [x], lambda ins, outs:
tvm.call_packed("print_tensor", ins[0], outs[0], msg),
name='print_tensor')
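
A small usage sketch for the debugging helper above (standalone and hedged: the placeholder shape and the old-style tvm.placeholder / create_schedule / build calls are assumptions about the host API of this era, not part of this commit):

import numpy as np
import tvm
from topi.util import print_tensor

x = tvm.placeholder((4,), name="x", dtype="float32")
y = print_tensor(x, "x is:")              # extern op: prints x, then copies it into y
s = tvm.create_schedule(y.op)
f = tvm.build(s, [x, y], "llvm")

a = tvm.ndarray.array(np.arange(4, dtype="float32"))
b = tvm.ndarray.empty((4,), "float32")
f(a, b)                                   # prints "x is:" and the tensor values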
