forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[QUANTIZE] Quantization implementation. (apache#32)
- Loading branch information
1 parent
11a7245
commit fad10c4
Showing
7 changed files
with
242 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from __future__ import absolute_import | ||
from .. import expr as _expr | ||
from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite | ||
from .quantize import attach_simulated_quantize, get_current_qconfig | ||
|
||
|
||
def _forward_op(ref_call, args):
    """Rebuild ``ref_call`` with a replacement argument list.

    The operator, attributes and type arguments are taken from ``ref_call``;
    only the arguments are swapped for ``args``.
    """
    op, attrs, type_args = ref_call.op, ref_call.attrs, ref_call.type_args
    return _expr.Call(op, args, attrs, type_args)
|
||
|
||
@register_qfield_rewrite("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
    """Rewrite rule registered for ``nn.conv2d``.

    Every invocation increments ``counter`` on the current qconfig. If the
    counter was still below ``skip_k_conv`` before the increment, the call is
    left untouched (None is returned). Otherwise the first operand is wrapped
    with a simulated quantize of kind INPUT (unless it already is an INPUT
    field), the second operand is wrapped with kind WEIGHT, and the rebuilt
    call is returned as an ACTIVATION field.
    """
    qconfig = get_current_qconfig()
    skip_this_conv = qconfig.counter < qconfig.skip_k_conv
    qconfig.counter += 1
    if skip_this_conv:
        return None

    data, weight = new_args
    if isinstance(data, QFieldExpr):
        data_expr = data.expr
        needs_input_quantize = data.kind != QFieldKind.INPUT
    else:
        data_expr = data
        needs_input_quantize = True
    if needs_input_quantize:
        data_expr = attach_simulated_quantize(data_expr, QFieldKind.INPUT)

    # The second operand is expected to be a plain (un-annotated) expression.
    assert not isinstance(weight, QFieldExpr)
    weight_expr = attach_simulated_quantize(weight, QFieldKind.WEIGHT)

    return QFieldExpr(_forward_op(ref_call, [data_expr, weight_expr]),
                      QFieldKind.ACTIVATION)
|
||
|
||
@register_qfield_rewrite("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
    """Rewrite rule registered for ``multiply``.

    Returns None (no rewrite) while the qconfig ``counter`` has not exceeded
    ``skip_k_conv``, or when neither operand is a ``QFieldExpr``.

    When the left operand is an ACTIVATION field and the right operand is a
    plain expression, the left side is wrapped with a simulated quantize of
    kind INPUT, the right side with one of kind WEIGHT, and the rebuilt call
    is returned as an ACTIVATION field.

    Raises
    ------
    ValueError
        For every other combination of operands.
    """
    cfg = get_current_qconfig()
    if cfg.counter <= cfg.skip_k_conv:
        return None

    lhs, rhs = new_args
    lhs_is_field = isinstance(lhs, QFieldExpr)
    rhs_is_field = isinstance(rhs, QFieldExpr)

    if not lhs_is_field and not rhs_is_field:
        return None

    # Only look at ``lhs.kind`` once lhs is known to be a QFieldExpr, so an
    # unsupported combination reaches the ValueError below.
    if lhs_is_field and lhs.kind == QFieldKind.ACTIVATION and not rhs_is_field:
        lhs_expr = attach_simulated_quantize(lhs.expr, QFieldKind.INPUT)
        # rhs is a plain expression in this branch, so it is passed as-is
        # (same as the plain-operand handling in the add and conv2d rules).
        rhs_expr = attach_simulated_quantize(rhs, QFieldKind.WEIGHT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QFieldExpr(expr, QFieldKind.ACTIVATION)

    raise ValueError("multiply: unsupported combination of operand field kinds")
|
||
|
||
@register_qfield_rewrite("add")
def add_rewrite(ref_call, new_args, ctx):
    """Rewrite rule registered for ``add``.

    Returns None (no rewrite) while the qconfig ``counter`` has not exceeded
    ``skip_k_conv``, or when neither operand is a ``QFieldExpr``.

    Supported operand combinations, all producing an ACTIVATION field:

    * plain lhs, ACTIVATION rhs: lhs is wrapped with a simulated quantize of
      kind INPUT.
    * ACTIVATION lhs, plain rhs: rhs is wrapped with a simulated quantize of
      kind WEIGHT.
    * INPUT or ACTIVATION lhs, ACTIVATION rhs: both inner expressions are
      forwarded unchanged.

    Raises
    ------
    ValueError
        For every other combination of operands.
    """
    cfg = get_current_qconfig()
    if cfg.counter <= cfg.skip_k_conv:
        return None

    lhs, rhs = new_args
    lhs_is_field = isinstance(lhs, QFieldExpr)
    rhs_is_field = isinstance(rhs, QFieldExpr)

    if not lhs_is_field and not rhs_is_field:
        # on float domain
        return None

    # ``.kind`` is only read on operands known to be QFieldExpr, so that
    # unsupported combinations reach the ValueError at the end.
    lhs_is_activation = lhs_is_field and lhs.kind == QFieldKind.ACTIVATION
    rhs_is_activation = rhs_is_field and rhs.kind == QFieldKind.ACTIVATION

    if not lhs_is_field and rhs_is_activation:
        # addition for residual, but lhs is calculated on the real domain
        lhs_expr = attach_simulated_quantize(lhs, QFieldKind.INPUT)
        rhs_expr = rhs.expr
    elif lhs_is_activation and not rhs_is_field:
        # the most common situation, e.g. bias add in bn
        lhs_expr = lhs.expr
        rhs_expr = attach_simulated_quantize(rhs, QFieldKind.WEIGHT)
    elif lhs_is_field and lhs.kind == QFieldKind.INPUT and rhs_is_activation:
        # addition for residual, but lhs is multi-referred
        lhs_expr, rhs_expr = lhs.expr, rhs.expr
    elif lhs_is_activation and rhs_is_activation:
        # addition for residual
        lhs_expr, rhs_expr = lhs.expr, rhs.expr
    else:
        raise ValueError("add: unsupported combination of operand field kinds")

    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
    return QFieldExpr(expr, QFieldKind.ACTIVATION)
|
||
|
||
@register_qfield_rewrite("nn.relu")
def relu_rewrite(ref_call, new_args, ctx):
    """Rewrite rule registered for ``nn.relu``.

    Returns None (no rewrite) while the qconfig ``counter`` has not exceeded
    ``skip_k_conv``, or when the operand is not a ``QFieldExpr``. Otherwise
    the call is rebuilt on the operand's inner expression and the result
    keeps the operand's field kind.
    """
    qconfig = get_current_qconfig()
    if qconfig.counter <= qconfig.skip_k_conv:
        return None

    operand = new_args[0]
    if not isinstance(operand, QFieldExpr):
        return None
    return QFieldExpr(_forward_op(ref_call, [operand.expr]), operand.kind)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
import tvm | ||
from tvm import relay | ||
from tvm.relay import quantize as qtz | ||
|
||
def test_simulated_quantize():
    """Type-infer a ``simulated_quantize`` call and check the inferred types.

    The output type must equal the type of the data argument, and the four
    remaining arguments must be inferred as scalars of the expected dtype.
    """
    data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
    scale = relay.var("scale")
    bit = relay.var("bit")
    clip_min = relay.var("clip_min")
    clip_max = relay.var("clip_max")

    call = qtz.simulated_quantize(data, scale, bit, clip_min, clip_max,
                                  sign=True, rounding='round', kind=0)
    call = relay.ir_pass.infer_type(call)

    assert call.checked_type == call.args[0].checked_type
    # Expected scalar dtypes for scale, bit, clip_min and clip_max, in order.
    expected_dtypes = ("float32", "int32", "float32", "float32")
    for position, dtype in enumerate(expected_dtypes, start=1):
        assert call.args[position].checked_type == relay.ty.TensorType(tuple(), dtype)
|
||
def test_annotate_pass():
    """Build a small conv/scale/bias/relu graph and print it before and after
    ``qtz.quantize``.

    The test makes no assertions on the quantized graph; it only checks that
    the pipeline runs and prints both textual forms.
    """
    batch, channels, height, width = 1, 3, 224, 224
    num_blocks = 1

    def residual_block(data, cnt):
        # conv followed by an explicit scale * x + bias and a relu
        suffix = str(cnt)
        weight = relay.var("conv_weight" + suffix)
        conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3),
                               padding=(1, 1), channels=channels)
        scale = relay.var('bn_scale' + suffix,
                          relay.TensorType((1, channels, 1, 1)))
        bias = relay.var('bn_bias' + suffix,
                         relay.TensorType((1, channels, 1, 1)))
        scaled = conv * scale
        shifted = scaled + bias
        return relay.nn.relu(shifted)

    def make_dataset(args, size=100):
        def random_array(var):
            # Uniform samples in [-1, 1) with the variable's annotated
            # shape and dtype.
            annotation = var.type_annotation
            samples = np.random.uniform(-1.0, 1.0,
                                        size=annotation.concrete_shape)
            return tvm.ndarray.array(samples.astype(annotation.dtype))

        params = {}
        for arg in args:
            name = arg.name_hint
            if name != 'data':
                params[name] = random_array(arg)
                continue
            dataset = [{'data': random_array(arg)} for _ in range(size)]
        return dataset, params

    out = relay.var("data",
                    relay.TensorType((batch, channels, height, width), "float32"))
    for block_index in range(num_blocks):
        out = residual_block(out, block_index)

    out = relay.ir_pass.infer_type(out)
    out = relay.ir_pass.simplify_inference(out)

    args = relay.ir_pass.free_vars(out)
    graph = relay.Function(args, out)
    dataset, params = make_dataset(args, 10)

    with qtz.qconfig(skip_k_conv=0, global_scale=4.0):
        print('before:')
        print(graph.astext(show_meta_data=False))

        qgraph = qtz.quantize(graph, params)
        print('after quantize:')
        print(qgraph.astext(show_meta_data=False))
        print('\n')
|
||
|
||
if __name__ == "__main__":
    # Run every test in this file, in order, when executed as a script.
    for _run_test in (test_simulated_quantize, test_annotate_pass):
        _run_test()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters