Skip to content

Commit

Permalink
Refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Aug 12, 2019
1 parent d482512 commit 46c9667
Show file tree
Hide file tree
Showing 18 changed files with 1,249 additions and 996 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ namespace relay {
*/
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);

TVM_DLL bool ConstantCheck(const Expr& e);

/*!
* \brief Compare two expressions for structural equivalence.
*
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
}
};

/*!
* \brief Annotate an expression to be cast into specific data type.
*/
struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The data type denoted to be cast.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
4 changes: 4 additions & 0 deletions python/tvm/relay/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def check_kind(t, mod=None):
return _analysis.check_kind(t)


def check_constant(expr):
return _analysis.check_constant(expr)


def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@
from __future__ import absolute_import as _abs

from .quantize import *
from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
164 changes: 39 additions & 125 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import warnings

import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import annotate_context
from ..._ffi.function import register_func
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..._ffi.function import register_func
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op


@_reg.register_compute("relay.op.annotation.simulated_quantize")
Expand Down Expand Up @@ -75,12 +76,6 @@ def __init__(self, expr, kind):
_quantize.make_annotate_expr, expr, kind)


def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)


def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
Expand Down Expand Up @@ -113,7 +108,7 @@ def frewrite_with_guard(ref_call, new_args, ctx):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
_reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard

return _register(frewrite) if frewrite is not None else _register
Expand All @@ -135,17 +130,17 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
return data

actx = annotate_context()
qctx = quantize_context()
key = tuple([data, kind, sign, rounding])
if key in actx.qnode_map:
return actx.qnode_map[key]
if key in qctx.qnode_map:
return qctx.qnode_map[key]

dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
qnode = _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
actx.qnode_map[key] = qnode
qctx.qnode_map[key] = qnode
return qnode

register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
Expand All @@ -163,13 +158,8 @@ def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
return None
actx.count_conv2d()
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
Expand All @@ -185,21 +175,12 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


def check_to_skip():
"""Check the index of conv2d layer to decide whether to skip the current operator."""
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if annotate_context().conv2d_counter() - 1 in skipped_indices:
return True
return False


# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
Expand All @@ -219,7 +200,7 @@ def dense_rewrite(ref_call, new_args, ctx):
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
Expand All @@ -243,13 +224,14 @@ def multiply_rewrite(ref_call, new_args, ctx):
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None and rhs_kind is None:
# trivial case
return None

if lhs_kind is None and rhs_kind is not None:
Expand All @@ -260,11 +242,10 @@ def add_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
if _analysis.check_constant(rhs_expr):
# - introduced by batch_norm: add(out, const)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
Expand All @@ -274,7 +255,6 @@ def add_rewrite(ref_call, new_args, ctx):
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
Expand All @@ -285,24 +265,9 @@ def add_rewrite(ref_call, new_args, ctx):
raise ValueError()


@register_annotate_function("stop_fusion")
def stop_fusion_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if check_to_skip():
return None

x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None

ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
ret_expr = _forward_op(ref_call, [ret_expr])
return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)


def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

x_expr, x_kind = _get_expr_kind(new_args[0])
Expand All @@ -322,7 +287,7 @@ def identity_rewrite(ref_call, new_args, ctx):

def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

expr, x_kind = _get_expr_kind(new_args[0])
Expand All @@ -339,14 +304,14 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)


@register_annotate_function("annotation.force_cast")
def force_cast_rewrite(ref_call, new_args, ctx):
@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
if check_to_skip():
return None

expr, x_kind = _get_expr_kind(new_args[0])

if quantize_context().check_to_skip(ref_call):
return expr

if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
Expand All @@ -359,7 +324,7 @@ def force_cast_rewrite(ref_call, new_args, ctx):
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None

input_tuple = new_args[0]
Expand All @@ -377,69 +342,18 @@ def concatenate_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


# Graph rewrite function registration for VTA target
def register_vta_rewrite(op_name, frewrite=None, level=10):
def _register(func):
return _op.op._Register(op_name, "FQVTARewrite", func, level)
return _register(frewrite) if frewrite is not None else _register
@register_annotate_function("nn.global_avg_pool2d")
def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for global_avg_pool2d for stopping quantize"""
if quantize_context().check_to_skip(ref_call):
return None

expr, x_kind = _get_expr_kind(new_args[0])

@register_relay_node
class QVTAExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
_quantize.make_vta_expr, expr)

def realize(self):
return _quantize.temp_expr_realize(self)


def vta_expr_check(expr):
if isinstance(expr, QVTAExpr):
return True, expr.expr
return False, expr


@register_vta_rewrite("nn.conv2d")
def conv2d_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d for VTA target"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
return None
actx.count_conv2d()

data_cond, data = vta_expr_check(new_args[0])
kernel_cond, kernel = vta_expr_check(new_args[1])

assert not kernel_cond
if data_cond:
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QVTAExpr(ret)


def identity_vta_rewrite(ref_call, new_args, ctx):
cond, expr = vta_expr_check(new_args[0])
if cond:
return QVTAExpr(_forward_op(ref_call, [expr]))
return None

register_vta_rewrite("nn.relu", identity_vta_rewrite)
register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)


@register_vta_rewrite("add")
def add_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for ewise add for VTA target"""
lhs_cond, lhs = vta_expr_check(new_args[0])
rhs_cond, rhs = vta_expr_check(new_args[1])
if lhs_cond and rhs_cond:
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif lhs_cond and not rhs_cond:
return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
return None
if x_kind is None:
return None
expr = _forward_op(ref_call, [new_args[0].realize()])

# stop quantize after global_avg_pool2d
quantize_context().stop_quantize()
return expr
Loading

0 comments on commit 46c9667

Please sign in to comment.