-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OPT] Low-bit Quantization #2116
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#pylint: disable=wildcard-import, redefined-builtin | ||
"""Automatic quantization utilities.""" | ||
from __future__ import absolute_import as _abs | ||
|
||
from .quantize import * | ||
from ._annotate import register_annotate_function |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
#pylint: disable=unused-argument | ||
"""Internal module for registering attribute for annotation.""" | ||
from __future__ import absolute_import | ||
|
||
from . import _quantize | ||
from .quantize import QAnnotateKind, current_qconfig | ||
from .quantize import _conv_counter, _set_conv_counter | ||
from .. import expr as _expr | ||
from .. import op as _op | ||
from ..base import register_relay_node | ||
from ..._ffi.function import register_func | ||
|
||
|
||
@register_relay_node | ||
class QAnnotateExpr(_expr.TempExpr): | ||
"""A special kind of Expr for Annotating. | ||
|
||
Parameters | ||
--------- | ||
expr: Expr | ||
the original relay ir expr. | ||
|
||
kind: QAnnotateKind | ||
the kind of annotation field. | ||
""" | ||
def __init__(self, expr, kind): | ||
self.__init_handle_by_constructor__( | ||
_quantize.make_annotate_expr, expr, kind) | ||
|
||
|
||
def _forward_op(ref_call, args): | ||
"""forward the operator of ref_call with provided arguments""" | ||
return _expr.Call( | ||
ref_call.op, args, ref_call.attrs, ref_call.type_args) | ||
|
||
|
||
def _get_expr_kind(anno): | ||
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr""" | ||
if isinstance(anno, QAnnotateExpr): | ||
return anno.expr, anno.kind | ||
return anno, None | ||
|
||
|
||
def register_annotate_function(op_name, frewrite=None, level=10): | ||
"""register a rewrite function for operator, used by annotation. | ||
|
||
Parameters | ||
--------- | ||
op_name: str | ||
The name of operation | ||
|
||
frewrite : function, optional | ||
The function to be registered. | ||
|
||
level : int, optional | ||
The priority level | ||
""" | ||
def default_rewrite(ref_call, new_args, ctx): | ||
# recover from QAnnotateExpr | ||
args = [_get_expr_kind(x)[0] for x in new_args] | ||
return _forward_op(ref_call, args) | ||
|
||
def _register(func): | ||
"""internal register function""" | ||
def frewrite_with_guard(ref_call, new_args, ctx): | ||
if not current_qconfig().guard(ref_call): | ||
return default_rewrite(ref_call, new_args, ctx) | ||
return func(ref_call, new_args, ctx) | ||
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) | ||
return frewrite_with_guard | ||
|
||
return _register(frewrite) if frewrite is not None else _register | ||
|
||
|
||
@register_func("relay.quantize.attach_simulated_quantize") | ||
def attach_simulated_quantize(data, kind, sign=True, rounding="round"): | ||
"""Attach a simulated quantize operation after input data expr. | ||
|
||
Parameters | ||
--------- | ||
data: Expr | ||
the original data expr. | ||
|
||
kind: QAnnotateKind | ||
the kind of annotation field. | ||
""" | ||
dom_scale = _expr.var("dom_scale") | ||
clip_min = _expr.var("clip_min") | ||
clip_max = _expr.var("clip_max") | ||
return _quantize.simulated_quantize( | ||
data, dom_scale, clip_min, clip_max, kind, sign, rounding) | ||
|
||
|
||
@register_annotate_function("nn.conv2d") | ||
def conv2d_rewrite(ref_call, new_args, ctx): | ||
"""Rewrite function for conv2d. Lhs of conv will be quantized to | ||
input field, and rhs of conv will be quantized to weight field. | ||
Output would be in activation field""" | ||
cnt = _conv_counter() | ||
if cnt < current_qconfig().skip_k_conv: | ||
_set_conv_counter(cnt + 1) | ||
return None | ||
_set_conv_counter(cnt + 1) | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) | ||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) | ||
|
||
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT: | ||
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can / should we avoid duplicated quantization in parallel branches? e.g
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should and we can, possibly via memoization. In theory forward rewrite already memoize, if there is any problem, please provide a minimum test case and let us double check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a test case, to reproduce, you need to set opt level
Result:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ZihengJiang @merrymercy @vinx13 can you look into this? let us open this testcase as an issue to be fixed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is because quantization of data happens during rewrite of conv2d, so this won't be memorized. We need some message passing to quantize data during forward rewrite of data. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not have things to do with ANF. The problem is that if two conv refers to the same input and they want to run the same transformation f on that input, there will be two such f. One solution is to build a generic common subexpression combination(elimination) path to create a concise dag There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we have seen in @vinx13 's test case, there're three There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as i understand, this PR already do that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought so, and i think it should be configured by disabling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ZihengJiang What's the status of this issue of parallel branches? Will it be future work? |
||
|
||
assert rhs_kind is None | ||
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) | ||
|
||
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) | ||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@register_annotate_function("multiply") | ||
def multiply_rewrite(ref_call, new_args, ctx): | ||
"""Rewrite function for multiply.""" | ||
if _conv_counter() <= current_qconfig().skip_k_conv: | ||
return None | ||
|
||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) | ||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) | ||
|
||
if lhs_kind is None and rhs_kind is None: | ||
return None | ||
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None: | ||
# quantize lhs to INPUT field | ||
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) | ||
# quantize rhs to WEIGHT field | ||
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) | ||
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) | ||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) | ||
raise ValueError | ||
|
||
|
||
@register_annotate_function("add") | ||
def add_rewrite(ref_call, new_args, ctx): | ||
"""Rewrite function for add.""" | ||
if _conv_counter() <= current_qconfig().skip_k_conv: | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return None | ||
|
||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) | ||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) | ||
|
||
if lhs_kind is None and rhs_kind is None: | ||
return None | ||
if lhs_kind is None and rhs_kind is not None: | ||
# quantize lhs to INPUT field if it is normal expression | ||
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) | ||
if lhs_kind is not None and rhs_kind is None: | ||
if isinstance(rhs_expr, _expr.Constant): | ||
# quantize rhs to WEIGHT field if it is Constant | ||
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) | ||
else: | ||
# quantize rhs to INPUT field if it is not Constant | ||
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) | ||
|
||
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) | ||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) | ||
|
||
|
||
def identity_rewrite(ref_call, new_args, ctx): | ||
"""Simply forward the original operation""" | ||
if _conv_counter() <= current_qconfig().skip_k_conv: | ||
return None | ||
|
||
x_expr, x_kind = _get_expr_kind(new_args[0]) | ||
if x_kind is None: | ||
return None | ||
|
||
ret_expr = _forward_op(ref_call, [x_expr]) | ||
return QAnnotateExpr(ret_expr, x_kind) | ||
|
||
|
||
register_annotate_function("nn.relu", identity_rewrite) | ||
register_annotate_function("strided_slice", identity_rewrite) | ||
register_annotate_function("nn.avg_pool2d", identity_rewrite) | ||
|
||
|
||
def pool2d_rewrite(ref_call, new_args, ctx): | ||
"""Rewrite function for max pool2d""" | ||
if _conv_counter() <= current_qconfig().skip_k_conv: | ||
return None | ||
expr, x_kind = _get_expr_kind(new_args[0]) | ||
|
||
if x_kind is None: | ||
return None | ||
if x_kind == QAnnotateKind.ACTIVATION: | ||
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT) | ||
expr = _forward_op(ref_call, [expr]) | ||
return QAnnotateExpr(expr, QAnnotateKind.INPUT) | ||
|
||
|
||
register_annotate_function("nn.max_pool2d", pool2d_rewrite) | ||
|
||
|
||
@register_annotate_function("concatenate") | ||
def concatenate_rewrite(ref_call, new_args, ctx): | ||
"""Rewrite function for concatenate""" | ||
if _conv_counter() <= current_qconfig().skip_k_conv: | ||
return None | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
input_tuple = new_args[0] | ||
expr_list = [_get_expr_kind(x)[0] for x in input_tuple] | ||
kind_list = [_get_expr_kind(x)[1] for x in input_tuple] | ||
|
||
# make sure the inputs of concatenate are all normal | ||
# expression or annotate expression | ||
if kind_list[0] is None: | ||
for k in kind_list: | ||
assert k is None | ||
return None | ||
for k in kind_list: | ||
assert k is not None | ||
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)]) | ||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#pylint: disable=unused-argument | ||
"""Internal module for quantization.""" | ||
from __future__ import absolute_import | ||
import topi | ||
from tvm._ffi.function import _init_api | ||
from ..op import op as _reg | ||
|
||
|
||
@_reg.register_compute("simulated_quantize") | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def simulated_quantize_compute(attrs, inputs, out_type, target): | ||
"""Compiler for simulated_quantize.""" | ||
assert len(inputs) == 4 | ||
assert attrs.sign | ||
assert attrs.rounding == "round" | ||
|
||
data, scale, clip_min, clip_max = inputs | ||
|
||
# simulate rounding error | ||
scaled_data = topi.divide(data, scale) | ||
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) | ||
round_data = topi.round(clipped_data) | ||
|
||
# recover data | ||
rdata = topi.multiply(round_data, scale) | ||
return [rdata] | ||
|
||
|
||
_reg.register_schedule("simulated_quantize", _reg.schedule_injective) | ||
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE) | ||
|
||
_init_api("relay._quantize", __name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems this API changes recently? It breaks some codes @tqchen