diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 3672a22847db..8c14f024fd5c 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -51,6 +51,18 @@ namespace relay { */ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); +/*! + * \brief Check whether an expression is constant. + * + * If the inputs of an expression are all constant, it means the expression + * itself is constant also. + * + * \param e the expression. + * + * \return whether the expression is constant. + */ +TVM_DLL bool ConstantCheck(const Expr& e); + /*! * \brief Compare two expressions for structural equivalence. * diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 29750c576b36..fd21db5a9c14 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -44,6 +44,19 @@ struct OnDeviceAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Annotate an expression to be cast into specific data type. + */ +struct CastHintAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") { + TVM_ATTR_FIELD(dtype) + .describe( + "The data type denoted to be cast."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ANNOTATION_H_ diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py index 91b53bb5f196..7372fcdadd17 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis.py @@ -91,6 +91,22 @@ def check_kind(t, mod=None): return _analysis.check_kind(t) +def check_constant(expr): + """Check whether an expression is constant + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + result : bool + Whether the expression is constant. + """ + return _analysis.check_constant(expr) + + def free_vars(expr): """Get free Vars from expression expr in Post DFS order. 
diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index a9e7b40b039e..29b68950fa42 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -19,5 +19,6 @@ from __future__ import absolute_import as _abs from .quantize import * +from ._partition import register_partition_function from ._annotate import register_annotate_function from .kl_divergence import kl_divergence_scale diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index e03eaab507ad..55f3597881e7 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -20,14 +20,15 @@ import warnings import topi -from . import _quantize -from .quantize import QAnnotateKind, current_qconfig -from .quantize import annotate_context +from ..._ffi.function import register_func from .. import expr as _expr +from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg from ..base import register_relay_node -from ..._ffi.function import register_func +from . 
import _quantize +from .quantize import QAnnotateKind, current_qconfig, quantize_context +from .quantize import _forward_op @_reg.register_compute("relay.op.annotation.simulated_quantize") @@ -75,12 +76,6 @@ def __init__(self, expr, kind): _quantize.make_annotate_expr, expr, kind) -def _forward_op(ref_call, args): - """forward the operator of ref_call with provided arguments""" - return _expr.Call( - ref_call.op, args, ref_call.attrs, ref_call.type_args) - - def _get_expr_kind(anno): """Get the expression and QAnnotateKind from QAnnotateExpr or Expr""" if isinstance(anno, QAnnotateExpr): @@ -113,7 +108,7 @@ def frewrite_with_guard(ref_call, new_args, ctx): if not current_qconfig().guard(ref_call): return default_rewrite(ref_call, new_args, ctx) return func(ref_call, new_args, ctx) - _op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) + _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) return frewrite_with_guard return _register(frewrite) if frewrite is not None else _register @@ -135,17 +130,17 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data - actx = annotate_context() + qctx = quantize_context() key = tuple([data, kind, sign, rounding]) - if key in actx.qnode_map: - return actx.qnode_map[key] + if key in qctx.qnode_map: + return qctx.qnode_map[key] dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") qnode = _quantize.simulated_quantize( data, dom_scale, clip_min, clip_max, kind, sign, rounding) - actx.qnode_map[key] = qnode + qctx.qnode_map[key] = qnode return qnode register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @@ -163,13 +158,8 @@ def conv2d_rewrite(ref_call, new_args, ctx): """Rewrite function for conv2d. Lhs of conv will be quantized to input field, and rhs of conv will be quantized to weight field. 
Output would be in activation field""" - actx = annotate_context() - if current_qconfig().skip_conv_layers is not None: - skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if actx.conv2d_counter() in skipped_indices: - actx.count_conv2d() - return None - actx.count_conv2d() + if quantize_context().check_to_skip(ref_call): + return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -185,21 +175,12 @@ def conv2d_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) -def check_to_skip(): - """Check the index of conv2d layer to decide whether to skip the current operator.""" - if current_qconfig().skip_conv_layers is not None: - skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if annotate_context().conv2d_counter() - 1 in skipped_indices: - return True - return False - - # TODO(tmoreau89,ziheng) need to include an option to turn off dense quant # @register_annotate_function("nn.dense") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. 
Output would be in activation field.""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -219,7 +200,7 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -243,13 +224,14 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) if lhs_kind is None and rhs_kind is None: + # trivial case return None if lhs_kind is None and rhs_kind is not None: @@ -260,11 +242,10 @@ def add_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.INPUT) if lhs_kind is not None and rhs_kind is None: - if isinstance(rhs_expr, _expr.Constant): - # quantize rhs to WEIGHT field if it is Constant + if _analysis.check_constant(rhs_expr): + # - introduced by batch_norm: add(out, const) rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) else: - # quantize rhs to INPUT field if it is not Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -274,7 +255,6 @@ def add_rewrite(ref_call, new_args, ctx): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION: - # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = 
_forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -285,24 +265,9 @@ def add_rewrite(ref_call, new_args, ctx): raise ValueError() -@register_annotate_function("stop_fusion") -def stop_fusion_rewrite(ref_call, new_args, ctx): - """Rewrite function for add.""" - if check_to_skip(): - return None - - x_expr, x_kind = _get_expr_kind(new_args[0]) - if x_kind is None: - return None - - ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT) - ret_expr = _forward_op(ref_call, [ret_expr]) - return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT) - - def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None x_expr, x_kind = _get_expr_kind(new_args[0]) @@ -322,7 +287,7 @@ def identity_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None expr, x_kind = _get_expr_kind(new_args[0]) @@ -339,14 +304,14 @@ def pool2d_rewrite(ref_call, new_args, ctx): register_annotate_function("nn.max_pool2d", pool2d_rewrite) -@register_annotate_function("annotation.force_cast") -def force_cast_rewrite(ref_call, new_args, ctx): +@register_annotate_function("annotation.cast_hint") +def cast_hint_rewrite(ref_call, new_args, ctx): """Rewrite function to force cast""" - if check_to_skip(): - return None - expr, x_kind = _get_expr_kind(new_args[0]) + if quantize_context().check_to_skip(ref_call): + return expr + if x_kind is None: return new_args[0] if x_kind == QAnnotateKind.ACTIVATION: @@ -359,7 +324,7 @@ def force_cast_rewrite(ref_call, new_args, ctx): @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - if check_to_skip(): + if quantize_context().check_to_skip(ref_call): return None input_tuple = 
new_args[0] @@ -377,69 +342,18 @@ def concatenate_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) -# Graph rewrite function registration for VTA target -def register_vta_rewrite(op_name, frewrite=None, level=10): - def _register(func): - return _op.op._Register(op_name, "FQVTARewrite", func, level) - return _register(frewrite) if frewrite is not None else _register +@register_annotate_function("nn.global_avg_pool2d") +def global_avg_pool2d_rewrite(ref_call, new_args, ctx): + """Rewrite function for global_avg_pool2d for stopping quantize""" + if quantize_context().check_to_skip(ref_call): + return None + expr, x_kind = _get_expr_kind(new_args[0]) -@register_relay_node -class QVTAExpr(_expr.TempExpr): - def __init__(self, expr): - self.__init_handle_by_constructor__( - _quantize.make_vta_expr, expr) - - def realize(self): - return _quantize.temp_expr_realize(self) - - -def vta_expr_check(expr): - if isinstance(expr, QVTAExpr): - return True, expr.expr - return False, expr - - -@register_vta_rewrite("nn.conv2d") -def conv2d_vta_rewrite(ref_call, new_args, ctx): - """Rewrite function for conv2d for VTA target""" - actx = annotate_context() - if current_qconfig().skip_conv_layers is not None: - skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] - if actx.conv2d_counter() in skipped_indices: - actx.count_conv2d() - return None - actx.count_conv2d() - - data_cond, data = vta_expr_check(new_args[0]) - kernel_cond, kernel = vta_expr_check(new_args[1]) - - assert not kernel_cond - if data_cond: - data = new_args[0].realize() - ret = _forward_op(ref_call, [data, kernel]) - return QVTAExpr(ret) - - -def identity_vta_rewrite(ref_call, new_args, ctx): - cond, expr = vta_expr_check(new_args[0]) - if cond: - return QVTAExpr(_forward_op(ref_call, [expr])) - return None - -register_vta_rewrite("nn.relu", identity_vta_rewrite) -register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite) - - -@register_vta_rewrite("add") -def 
add_vta_rewrite(ref_call, new_args, ctx): - """Rewrite function for ewise add for VTA target""" - lhs_cond, lhs = vta_expr_check(new_args[0]) - rhs_cond, rhs = vta_expr_check(new_args[1]) - if lhs_cond and rhs_cond: - lhs = new_args[0].realize() - rhs = new_args[1].realize() - return _forward_op(ref_call, [lhs, rhs]) - elif lhs_cond and not rhs_cond: - return QVTAExpr(_forward_op(ref_call, [lhs, rhs])) - return None + if x_kind is None: + return None + expr = _forward_op(ref_call, [new_args[0].realize()]) + + # stop quantize after global_avg_pool2d + quantize_context().stop_quantize() + return expr diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py new file mode 100644 index 000000000000..597c55c44481 --- /dev/null +++ b/python/tvm/relay/quantize/_partition.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument,inconsistent-return-statements +"""Internal module for registering attribute for annotation.""" +from __future__ import absolute_import + +from ... import target as _target +from .. import expr as _expr +from .. import analysis as _analysis +from ..base import register_relay_node +from ..op import op as _reg +from . 
import _quantize +from .quantize import _forward_op + +def register_partition_function(op_name, frewrite=None, level=10): + def _register(func): + return _reg._Register(op_name, "FQPartitionRewrite", func, level) + return _register(frewrite) if frewrite is not None else _register + + +@register_relay_node +class QPartitionExpr(_expr.TempExpr): + def __init__(self, expr): + self.__init_handle_by_constructor__( + _quantize.make_partition_expr, expr) + + +def partition_expr_check(expr): + if isinstance(expr, QPartitionExpr): + return True, expr.expr + return False, expr + + +@register_partition_function("nn.conv2d") +def conv2d_partition_function(ref_call, new_args, ctx): + """Rewrite function for conv2d for partition""" + data_cond, data = partition_expr_check(new_args[0]) + kernel_cond, kernel = partition_expr_check(new_args[1]) + + assert not kernel_cond + if data_cond: + data = new_args[0].realize() + ret = _forward_op(ref_call, [data, kernel]) + return QPartitionExpr(ret) + + +def identity_partition_function(ref_call, new_args, ctx): + cond, expr = partition_expr_check(new_args[0]) + if cond: + return QPartitionExpr(_forward_op(ref_call, [expr])) + return None + +register_partition_function("clip", identity_partition_function) +register_partition_function("nn.relu", identity_partition_function) +register_partition_function("nn.max_pool2d", identity_partition_function) + + +def add_partition_generic(ref_call, new_args, ctx): + """Rewrite function for ewise add for partition for generic devices""" + lhs_cond, lhs = partition_expr_check(new_args[0]) + rhs_cond, rhs = partition_expr_check(new_args[1]) + if lhs_cond and rhs_cond: + # - introduced by ResNet, when for the first residual connection + # ... + # %0 = nn.conv2d(%data, %meta[relay.Constant]) + # %1 = add(%0, %meta[relay.Constant]) + # %2 = nn.relu(%1) + # %3 = nn.max_pool2d(%2) + # ... 
+ # %9 = nn.conv2d(%8, %meta[relay.Constant]) + # %10 = add(%9, %meta[relay.Constant]) + # %11 = add(%3, %10) <- need to insert annotations for %3, %10 + # ... + lhs = new_args[0].realize() + rhs = new_args[1].realize() + return _forward_op(ref_call, [lhs, rhs]) + elif not lhs_cond and rhs_cond: + # - introduced by residual connection in ResNet + # ... + # %13 = nn.conv2d(%12, %meta[relay.Constant]) + # %14 = add(%13, %meta[relay.Constant]) + # %15 = annotation.cast_hint(%14, 'int8') + # %16 = annotation.stop_fusion(%15) + # %17 = add(%5, %16) + # %18 = nn.relu(%17) + # ... + # %24 = nn.conv2d(%23, %meta[relay.Constant]) + # %25 = add(%24, %meta[relay.Constant]) + # %26 = add(%18, %25) <- need to insert annotations for %25 + # ... + rhs = new_args[1].realize() + return _forward_op(ref_call, [lhs, rhs]) + elif lhs_cond and not rhs_cond: + if _analysis.check_constant(rhs): + # - introduced by batch_norm: add(out, bias) + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + # - introduced by residual connection in MobileNetV2 + # ... + # %81 = add(%80, %meta[relay.Constant]) + # %82 = annotation.cast_hint(%81, 'int8') + # %83 = annotation.stop_fusion(%82) + # %84 = add(%79, %83) + # ... + # %95 = nn.conv2d(%94, %meta[relay.Constant]) + # %96 = add(%95, %meta[relay.Constant]) + # %97 = add(%96, %84) <- need to insert annotations for %96 + # ... 
+ lhs = new_args[0].realize() + return _forward_op(ref_call, [lhs, rhs]) + elif not lhs_cond and not rhs_cond: + # trivial case + return None + else: + raise ValueError + + +# TODO(ziheng) enhance `register_partition_function` to dispatch +# for target automatically +@register_partition_function("add") +def add_partition_function(ref_call, new_args, ctx): + """Rewrite function for ewise add for partition""" + if 'cuda' in _target.current_target().keys: + #TODO(wuwei/ziheng) cuda specific rules + return add_partition_generic(ref_call, new_args, ctx) + return add_partition_generic(ref_call, new_args, ctx) + + +@register_partition_function("multiply") +def multiply_partition_function(ref_call, new_args, ctx): + """Rewrite function for ewise multiply for partition""" + lhs_cond, lhs = partition_expr_check(new_args[0]) + rhs_cond, rhs = partition_expr_check(new_args[1]) + if lhs_cond: + # introduced by bn: multiply(out, scale) + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + assert (not lhs_cond) and (not rhs_cond) + return None diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 07d4d9d25e01..adde2058267a 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -50,6 +50,12 @@ def kind2str(kind): return str_map[kind] +def _forward_op(ref_call, args): + """forward the operator of ref_call with provided arguments""" + return _expr.Call( + ref_call.op, args, ref_call.attrs, ref_call.type_args) + + @register_relay_node("relay.quantize.QConfig") class QConfig(NodeBase): """Configure the quantization behavior by setting config variables. 
@@ -74,8 +80,8 @@ class QConfig(NodeBase): "dtype_activation": "int32", "global_scale": 8.0, "skip_conv_layers": [0], + "do_simulation": False, "round_for_shift": True, - "store_lowbit_output": True, "debug_enabled_ops": None, } @@ -92,6 +98,7 @@ def __init__(self, handle): self.handle = handle def guard(self, ref_call): + """Return true if op is enabled, otherwise return false""" op_name = ref_call.op.name if self.debug_enabled_ops is not None: name_list = [x.value for x in self.debug_enabled_ops] @@ -126,9 +133,7 @@ def current_qconfig(): """Get the current quantization configuration.""" return _quantize._GetCurrentQConfig() -# TODO(tmoreau89, ZihengJiang) the skip parameters are -# hacky - we should explore a more future-proof way to -# skip operators based on pattern matching + def qconfig(**kwargs): """Configure the quantization behavior by setting config variables. @@ -142,15 +147,14 @@ def qconfig(**kwargs): skip_conv_layers: list Specifying which layers to be skipped. Provide a list of indices - that indicate which conv2d layers to leave untouched. + that indicate which conv2d layers to leave untouched. Start from 0. + + do_simulation: boolean + Whether to do simulation with float operation only. round_for_shift: boolean Whether to add bias for rounding during shift. - store_lowbit_output: boolean - Whether to store low-bit integer back as output before dequantizing. - Some accelerators need this, e.g. VTA. - debug_enabled_ops: None or list of str Partially quantize specified operators for debugging. 
The default value is None, which means will try to call all operartors' annotate rewrite @@ -166,35 +170,79 @@ def qconfig(**kwargs): return _make.node("relay.quantize.QConfig", **node_args) -class AnnotateContext(object): - """A global singleton annotate scope""" +class QuantizeContext(object): + """An internal used global context object for annotation, + for putting some state variables like `conv2d_counter`.""" Current = None def __init__(self): self.qnode_map = dict() self._conv2d_counter = 0 + self._stop_quantize = False + + def check_to_skip(self, ref_call): + """Check the index of conv2d layer to decide whether to + skip the current operator.""" + if self._stop_quantize: + return True + + if current_qconfig().skip_conv_layers is not None: + # check skip conv layers + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if self._conv2d_counter in skipped_indices: + if ref_call.op.name == 'nn.conv2d': + self._conv2d_counter += 1 + return True + if ref_call.op.name == 'nn.conv2d': + self._conv2d_counter += 1 + + return False + + def stop_quantize(self): + self._stop_quantize = True + + def reset(self): + self._conv2d_counter = 0 + self._stop_quantize = False def __enter__(self): - self._conv2d_counter = 0 + self.reset() return self - def conv2d_counter(self): - """Get the counter for conv2d.""" - return self._conv2d_counter - - def count_conv2d(self): - """Increase the value of the conv2d counter by one.""" - self._conv2d_counter += 1 - def __exit__(self, ptype, value, traceback): pass -def annotate_context(): +def quantize_context(): """Get the global singleton scope""" - if AnnotateContext.Current is None: - AnnotateContext.Current = AnnotateContext() - return AnnotateContext.Current + if QuantizeContext.Current is None: + QuantizeContext.Current = QuantizeContext() + return QuantizeContext.Current + + +def partition(): + """Partition graph into small low-precision sections by `cast_hint` and + `stop_fusion`. 
+ + Returns + ------- + ret: tvm.relay.Pass + The registered pass for quantization partition. + """ + return _quantize.QuantizePartition() + + +def annotate(): + """Given a float32 graph, this pass will rewrite the graph and return + a graph which simulates the error brought by the current quantization + scheme. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass for quantization annotation. + """ + return _quantize.QuantizeAnnotate() def collect_stats(graph): @@ -300,20 +348,8 @@ def _make_const(val): const_params[nclip_max] = _make_const((valid_range - 1)) _analysis.post_order_visit(graph, visit_func) - return _expr.bind(graph, const_params) - - -def annotate(): - """Given a float32 graph, this pass will rewrite the graph and return - a graph which simulates the error brought by the current quantization - scheme. - - Returns - ------- - ret: tvm.relay.Pass - The registered pass for quantization annotation. - """ - return _quantize.QuantizeAnnotate() + ret = _expr.bind(graph, const_params) + return ret def realize(): @@ -330,17 +366,6 @@ def realize(): return _quantize.QuantizeRealize() -def rewrite_for_vta(): - """Performs rewriting for VTA target. - - Returns - ------- - ret: tvm.relay.Pass - The registered pass for VTA rewrite. - """ - return _quantize.QuantizeRewriteForVTA() - - def _bind_params(func, params): """Bind the params to the expression. """ @@ -362,6 +387,25 @@ def _bind_params(func, params): return _expr.bind(func, bind_dict) +def prerequisite_optimize(graph, params=None): + """ Prerequisite optimization passes for quantization. Perform + "SimplifyInference", "FoldScaleAxis", "FoldConstant", and + "CanonicalizeOps" optimization before quantization. 
""" + optimize = _transform.Sequential([_transform.SimplifyInference(), + _transform.FoldConstant(), + _transform.FoldScaleAxis(), + _transform.CanonicalizeOps(), + _transform.FoldConstant()]) + + if params: + graph = _bind_params(graph, params) + + mod = _module.Module.from_expr(graph) + with _transform.PassContext(opt_level=3): + mod = optimize(mod) + return mod["main"] + + def quantize(graph, params=None, dataset=None): """ The quantization procedure. Before running the three main procedure of quantization, "annotate", "calibrate" and "realize" @@ -385,33 +429,23 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - if params: - graph = _bind_params(graph, params) + graph = prerequisite_optimize(graph, params) mod = _module.Module.from_expr(graph) - # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and - # "CanonicalizeOps" optimization before quantization. - optimize = _transform.Sequential([_transform.SimplifyInference(), - _transform.FoldConstant(), - _transform.FoldScaleAxis(), - _transform.CanonicalizeOps(), - _transform.FoldConstant()]) - calibrate_pass = _transform.function_pass(calibrate, opt_level=1, name="QuantizeCalibrate") - # Quantize pass list - quant_passes = [annotate(), - calibrate_pass, - realize(), - _transform.FoldConstant()] - if current_qconfig().store_lowbit_output: - quant_passes = [rewrite_for_vta()] + quant_passes + quant_passes = [partition(), + annotate(), + calibrate_pass] + if not current_qconfig().do_simulation: + quant_passes.append(realize()) + quant_passes.append(_transform.FoldConstant()) quantize_seq = _transform.Sequential(quant_passes) with _transform.PassContext(opt_level=3, required_pass=["QuantizeAnnotate", "QuantizeCalibrate", "QuantizeRealize"]): - mod = optimize(mod) - mod = quantize_seq(mod) + with quantize_context(): + mod = quantize_seq(mod) return mod["main"] diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 
a5ade5bde304..eeacc6cbf999 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -83,13 +83,18 @@ TVM_ADD_FILELINE) return {topi::identity(inputs[0])}; }); -Expr ForceCast(Expr data) { - static const Op& op = Op::Get("annotation.force_cast"); - return CallNode::make(op, {data}, Attrs{}, {}); +// relay.annotation.cast_hint +TVM_REGISTER_NODE_TYPE(CastHintAttrs); + +Expr CastHint(Expr data, DataType dtype) { + auto attrs = make_node(); + attrs->dtype = dtype; + static const Op& op = Op::Get("annotation.cast_hint"); + return CallNode::make(op, {data}, Attrs{attrs}, {}); } -RELAY_REGISTER_OP("annotation.force_cast") -.describe(R"code(Annotate an expression to force a cast.)code" +RELAY_REGISTER_OP("annotation.cast_hint") +.describe(R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input data.") diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 7b896a8d0f7f..eba77c7241a7 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -66,6 +66,13 @@ class ConstantChecker : private ExprVisitor { } }; +bool ConstantCheck(const Expr& e) { + return ConstantChecker().Check(e); +} + +TVM_REGISTER_API("relay._analysis.check_constant") +.set_body_typed(ConstantCheck); + // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. 
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 3ccfff0c3463..18e5df3e04df 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -420,7 +421,7 @@ Expr MakeStridedSlice(Expr data, Array begin, Array end, Array Expr StopFusion(Expr data); -Expr ForceCast(Expr data); +Expr CastHint(Expr data, DataType dtype); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc new file mode 100644 index 000000000000..d8a7a0f24818 --- /dev/null +++ b/src/relay/pass/quantize/annotate.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * + * \file annotate.cc + * + * \brief Annotating the graph with simulated quantize operators. 
+ */ + +#include +#include +#include "./quantize.h" + +namespace tvm { +namespace relay { +namespace quantize { + +using namespace relay::transform; + +class QAnnotateExpr; +class QAnnotateExprNode : public TempExprNode { + public: + Expr expr; + QAnnotateKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("expr", &expr); + v->Visit("kind", &kind); + } + + TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind); + + Expr Realize() const final; + + static constexpr const char* _type_key = "relay.QAnnotateExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); + + +Expr QAnnotateExprNode::Realize() const { + return expr; +} + +QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { + auto rnode = make_node(); + rnode->expr = expr; + rnode->kind = kind; + return QAnnotateExpr(rnode); +} + +TVM_REGISTER_API("relay._quantize.make_annotate_expr") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = QAnnotateExprNode::make(args[0], + static_cast(args[1].operator int())); + }); + + +Pass QuantizeAnnotate() { + // TODO(tvm-teams): since partition has added cast_hint in different + // branches, try to remove this in the future. 
+ std::function fmulti_ref = [](const Expr& e) { + if (e->derived_from()) { + const auto* n = e.as(); + CHECK(n); + const PackedFunc* f = + runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + Expr ret = (*f)(n->expr, static_cast(kQInput)); + return static_cast(QAnnotateExprNode::make(ret, kQInput)); + } + return e; + }; + + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); + auto new_params = func->params; + for (const auto& x : FreeVars(func)) { + new_params.push_back(x); + } + return FunctionNode::make(new_params, + func->body, + func->ret_type, + func->type_params, + func->attrs); + }; + return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); +} + +TVM_REGISTER_API("relay._quantize.QuantizeAnnotate") +.set_body_typed(QuantizeAnnotate); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc new file mode 100644 index 000000000000..3f46cf2f227e --- /dev/null +++ b/src/relay/pass/quantize/partition.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/*! + * Copyright (c) 2018 by Contributors + * + * \file partition.cc + * + * \brief Partition a graph into sections for quantization. + */ + +#include +#include "../pattern_util.h" +#include "./quantize.h" + +namespace tvm { +namespace relay { +namespace quantize { + +using namespace relay::transform; + +class QPartitionExpr; +class QPartitionExprNode : public TempExprNode { + public: + /*! \brief The original expression */ + Expr expr; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("expr", &expr); + } + + TVM_DLL static QPartitionExpr make(Expr expr); + + Expr Realize() const final; + + static constexpr const char* _type_key = "relay.QPartitionExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr); + + +Expr QPartitionExprNode::Realize() const { + // insert cast hint and stop fusion + const QConfig& cfg = QConfig::Current(); + Expr ret = CastHint(this->expr, cfg->dtype_input); + return StopFusion(ret); +} + +QPartitionExpr QPartitionExprNode::make(Expr expr) { + auto rnode = make_node(); + rnode->expr = expr; + return QPartitionExpr(rnode); +} + +TVM_REGISTER_API("relay._quantize.make_partition_expr") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = QPartitionExprNode::make(args[0]); + }); + +Pass QuantizePartition() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + auto ret = Downcast( + ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); + return ret; + }; + return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); +} + +TVM_REGISTER_API("relay._quantize.QuantizePartition") +.set_body_typed(QuantizePartition); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index 6cffc2053e5c..c6d71ba0ed32 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ 
-26,17 +26,9 @@ * for compression and acceleration. */ #include -#include -#include -#include #include #include -#include -#include -#include #include -#include -#include "../pattern_util.h" #include "./quantize.h" @@ -44,8 +36,6 @@ namespace tvm { namespace relay { namespace quantize { -using namespace relay::transform; - TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); bool SimulatedQuantizeRel(const Array& types, @@ -91,490 +81,6 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize") }); -// ============= -// annotate pass - -Expr QAnnotateExprNode::Realize() const { - const auto& cfg = QConfig::Current(); - if (cfg->store_lowbit_output) { - // store low bit output back for VTA - const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - return (*f)(this->expr, static_cast(kQInput)); - } else { - return expr; - } -} - -QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { - auto rnode = make_node(); - rnode->expr = expr; - rnode->kind = kind; - return QAnnotateExpr(rnode); -} - -TVM_REGISTER_API("relay._quantize.make_annotate_expr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = QAnnotateExprNode::make(args[0], - static_cast(args[1].operator int())); - }); - - -// ============= -// realize pass - -Expr QRealizeIntExprNode::Realize() const { - const auto& cfg = QConfig::Current(); - Expr data = this->data; - if (cfg->store_lowbit_output) { - data = Cast(data, cfg->dtype_input); - } - // dequantize - data = Cast(data, Float(32)); - data = Multiply(data, this->dom_scale); - return data; -} - -QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { - NodePtr n = make_node(); - n->data = std::move(data); - n->dom_scale = std::move(dom_scale); - n->dtype = std::move(dtype); - return QRealizeIntExpr(n); -} - - -inline Expr ForwardOp(const Call& ref_call, const Array& args) { - return CallNode::make(ref_call->op, - args, ref_call->attrs, ref_call->type_args); -} - - -/* 
calculate `data * s1 / s2`, use shift if possible */ -inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { - // here we assume the dtype of data is dtype activation - if (s1 == s2) return data; - - float factor = s1 / s2; - float shift_factor = std::log2(factor); - CHECK_GT(shift_factor, 0); - if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(dtype, - static_cast(shift_factor))); - } else if (static_cast(factor) == factor) { - return Multiply(data, MakeConstantScalar(dtype, factor)); - } else { - data = Cast(data, Float(32)); - data = Multiply(data, MakeConstantScalar(Float(32), factor)); - return Cast(Round(data), dtype); - } -} - -Expr QuantizeRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - // do not handle data type cast - const auto param = ref_call->attrs.as(); - CHECK_EQ(param->rounding, "round"); - - Expr dom_scale = new_args[1]; - Expr clip_min = new_args[2]; - Expr clip_max = new_args[3]; - - float dom_scale_imm = GetScalarFromConstant(dom_scale); - float clip_min_imm = GetScalarFromConstant(clip_min); - float clip_max_imm = GetScalarFromConstant(clip_max); - - // x * idom_scale = y * odom_scale - // => y = x * idom_scale / odom_scale - if (const auto* n = new_args[0].as()) { - // int32->int8 - Expr data = n->data; - float idom_scale_imm = GetScalarFromConstant(n->dom_scale); - float odom_scale_imm = GetScalarFromConstant(dom_scale); - if (idom_scale_imm == odom_scale_imm) { - // same domain scale, only clip - data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); - } - - float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - CHECK_NE(shift_nbit, 0); - if (static_cast(shift_nbit) == shift_nbit) { - if (shift_nbit > 0) { - // use right shift - if (cfg->round_for_shift) { - float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, 
MakeConstantScalar(cfg->dtype_activation, - static_cast(round_bias))); - } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); - } else { - data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); - } - data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); - } else { - // float computation - data = Cast(data, Float(32)); - Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); - Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); - } - } - - // quantize from real - CHECK(!new_args[0]->derived_from()); - Expr data = new_args[0]; - Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm)); - Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); -} - -Expr FoldConstantOpt(const Expr& expr) { - auto mod = ModuleNode::FromExpr(expr); - mod = transform::FoldConstant()(mod); - auto entry_func = mod->Lookup("main"); - return expr.as() == nullptr ? 
entry_func->body : entry_func; -} - -RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.set_attr("FQRealizeRewrite", QuantizeRealize); - - -Expr Conv2dRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 2); - if (!new_args[0]->derived_from() && !new_args[1]->derived_from()) { - return Expr(nullptr); - } - const auto* lhs = new_args[0].as(); - CHECK(lhs); - const auto* rhs = new_args[1].as(); - CHECK(rhs); - - Expr ldata = lhs->data; - if (lhs->dtype != cfg->dtype_input) { - ldata = Cast(ldata, cfg->dtype_input); - } - Expr rdata = Cast(rhs->data, cfg->dtype_weight); - - const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); - *attrs = *ref_attrs; - DataType out_dtype = cfg->dtype_activation; - attrs->out_dtype = out_dtype; - - Expr ret = CallNode::make(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); - Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); -} - -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FQRealizeRewrite", Conv2dRealize); - - -Expr DenseRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 2); - if (!new_args[0]->derived_from() || !new_args[1]->derived_from()) { - return Expr(nullptr); - } - const auto* lhs = new_args[0].as(); - const auto* rhs = new_args[1].as(); - - Expr ldata = lhs->data; - if (lhs->dtype != cfg->dtype_input) { - ldata = Cast(ldata, cfg->dtype_input); - } - Expr rdata = Cast(rhs->data, cfg->dtype_weight); - - const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); - *attrs = *ref_attrs; - DataType out_dtype = cfg->dtype_activation; - attrs->out_dtype = out_dtype; - - Expr ret = CallNode::make(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); - Expr mul = 
Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); -} - -RELAY_REGISTER_OP("nn.dense") -.set_attr("FQRealizeRewrite", DenseRealize); - - -Expr MulRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 2); - if (new_args[0].as() && new_args[1].as()) { - // execute the operation with activation data type. - const auto* lhs = new_args[0].as(); - const auto* rhs = new_args[1].as(); - Expr ldata = lhs->data; - Expr rdata = rhs->data; - - DataType dtype = cfg->dtype_activation; - if (lhs->dtype != dtype) { - ldata = Cast(ldata, dtype); - } - if (rhs->dtype != dtype) { - rdata = Cast(rdata, dtype); - } - - Expr ret = ForwardOp(ref_call, {ldata, rdata}); - Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); - } - CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("multiply") -.set_attr("FQRealizeRewrite", MulRealize); - - -float ChooseDomScale(const std::vector& nptrs) { - if (nptrs.size() == 2) { - // x = a * s1, y = b * s2 - // x + y = (a * s1 / s2 + b) * s2, if s1 > s2 - // = (a + b * s2 / s1) * s1, if s2 > s1 - float s1 = GetScalarFromConstant(nptrs[0]->dom_scale); - float s2 = GetScalarFromConstant(nptrs[1]->dom_scale); - return s1 > s2 ? 
s2 : s1; - } else { - const QConfig& cfg = QConfig::Current(); - float scale = cfg->global_scale; - return scale / std::pow(2.0, cfg->nbit_activation - 1); - } -} - - -/* \brief Unify the dom scale of arguments */ -Array UnifyDTypeScale(const Array& ref_args, const Array& args, - DataType* dtype_ptr, Expr* scale_ptr) { - static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); - const QConfig& cfg = QConfig::Current(); - - std::vector nptrs; - Array ret; - for (auto arg : args) { - const auto* nptr = arg.as(); - CHECK(nptr); - nptrs.push_back(nptr); - ret.push_back(nptr->data); - } - - // unify the data type - CHECK_EQ(ref_args.size(), args.size()); - DataType dtype; - if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { - dtype = cfg->dtype_input; - } else { - dtype = cfg->dtype_activation; - } - for (size_t i = 0; i < ret.size(); ++i) { - auto ref_arg = ref_args[i].as(); - if (nptrs[i]->dtype != dtype) { - ret.Set(i, Cast(ret[i], dtype)); - } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && - ref_arg->attrs.as()->kind == kQInput) { - auto new_arg = Cast(ret[i], cfg->dtype_input); - new_arg = StopFusion(new_arg); - ret.Set(i, Cast(new_arg, dtype)); - } - } - - // unify the dom_scale - float s = ChooseDomScale(nptrs); - Expr dom_scale = MakeConstantScalar(Float(32), s); - for (size_t i = 0; i < ret.size(); ++i) { - float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); - ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); - } - - *dtype_ptr = dtype; - *scale_ptr = dom_scale; - return ret; -} - -Expr AddRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - CHECK_EQ(new_args.size(), 2); - if (new_args[0].as() && new_args[1].as()) { - DataType dtype; - Expr dom_scale; - Array ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale); - Expr ret = ForwardOp(ref_call, ret_args); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); - } - - 
CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("add") -.set_attr("FQRealizeRewrite", AddRealize); - -Expr ClipRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); - double dom_scale = GetScalarFromConstant(n->dom_scale); - attrs->a_min = ref_attrs->a_min / dom_scale; - attrs->a_max = ref_attrs->a_max / dom_scale; - - Expr ret = CallNode::make(ref_call->op, - {n->data}, Attrs(attrs), ref_call->type_args); - return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("clip") -.set_attr("FQRealizeRewrite", ClipRealize); - - -Expr ConcatenateRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - CHECK_EQ(new_args.size(), 1); - CHECK_EQ(ref_call->args.size(), 1); - - const auto* tuple = new_args[0].as(); - const auto* ref_tuple = ref_call->args[0].as(); - CHECK(tuple); - CHECK(ref_tuple); - const Array& arr = tuple->fields; - const Array& ref_arr = ref_tuple->fields; - - if (arr[0].as()) { - DataType dtype; - Expr dom_scale; - Array ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale); - Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)}); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); - } else { - for (auto arg : new_args) { - CHECK(!arg->derived_from()); - } - return Expr(nullptr); - } -} - -RELAY_REGISTER_OP("concatenate") -.set_attr("FQRealizeRewrite", ConcatenateRealize); - - -/* \brief forward the original operator */ -Expr IdentityRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr ret = ForwardOp(ref_call, {n->data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, 
n->dtype); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("nn.relu") -.set_attr("FQRealizeRewrite", IdentityRealize); - -RELAY_REGISTER_OP("strided_slice") -.set_attr("FQRealizeRewrite", IdentityRealize); - -RELAY_REGISTER_OP("annotation.stop_fusion") -.set_attr("FQRealizeRewrite", IdentityRealize); - -/* \brief for unary operators which requantize its input to dtype_nbit */ -Expr CastDtypeInputRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr data = Cast(n->data, cfg->dtype_input); - Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("nn.max_pool2d") -.set_attr("FQRealizeRewrite", CastDtypeInputRealize); - - -Expr AvgPoolRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr data = n->data; - if (n->dtype != cfg->dtype_activation) { - data = Cast(n->data, cfg->dtype_activation); - } - Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} - -RELAY_REGISTER_OP("nn.avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); - -Expr ForceCastRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr ret = Cast(n->data, cfg->dtype_input); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} - 
-RELAY_REGISTER_OP("annotation.force_cast") -.set_attr("FQRealizeRewrite", ForceCastRealize); - -TVM_REGISTER_API("relay._quantize.realize") -.set_body_typed([](const Expr& e) { - Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr); - return ret; -}); - - -// ============= -// qconfig - -QConfig qconfig() { - return QConfig(make_node()); -} - /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMQConfigThreadLocalEntry { /*! \brief The default build config if the stack is empty */ @@ -584,7 +90,7 @@ struct TVMQConfigThreadLocalEntry { std::stack context_stack; TVMQConfigThreadLocalEntry() : - default_config(qconfig()) { + default_config(make_node()) { } }; @@ -620,8 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; - p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; p->stream << ")"; }); @@ -635,95 +141,6 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope") TVM_REGISTER_API("relay._quantize._ExitQConfigScope") .set_body_typed(QConfig::ExitQConfigScope); -Pass QuantizeAnnotate() { - std::function fmulti_ref = [](const Expr& e) { - if (e->derived_from()) { - const auto* n = e.as(); - CHECK(n); - const PackedFunc* f = - runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); - } - return e; - }; - - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); - auto new_params = func->params; - for (const auto& x 
: FreeVars(func)) { - new_params.push_back(x); - } - return FunctionNode::make(new_params, - func->body, - func->ret_type, - func->type_params, - func->attrs); - }; - return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); -} - -TVM_REGISTER_API("relay._quantize.QuantizeAnnotate") -.set_body_typed(QuantizeAnnotate); - -Pass QuantizeRealizePass() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); - }; - return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); -} - -TVM_REGISTER_API("relay._quantize.QuantizeRealize") -.set_body_typed(QuantizeRealizePass); - -Pass QuantizeRewriteForVTAPass() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr)); - }; - return CreateFunctionPass(pass_func, 1, "QuantizeRewriteForVTA", {}); -} - -TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA") -.set_body_typed(QuantizeRewriteForVTAPass); - -// ============= -// Insert stop_fusion for vta. 
- - -Expr QVTAExprNode::Realize() const { - Expr ret = ForceCast(this->expr); - return StopFusion(ret); -} - -QVTAExpr QVTAExprNode::make(Expr expr) { - auto rnode = make_node(); - rnode->expr = expr; - return QVTAExpr(rnode); -} - -TVM_REGISTER_API("relay._quantize.make_vta_expr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = QVTAExprNode::make(args[0]); - }); - -TVM_REGISTER_API("relay._quantize.make_stop_fusion") -.set_body_typed([] (const Expr& expr) { - return StopFusion(expr); -}); - -TVM_REGISTER_API("relay._quantize.temp_expr_realize") -.set_body_typed([] (const Expr& expr) { - const QVTAExprNode* n = expr.as(); - CHECK(n); - return n->Realize(); -}); - - } // namespace quantize } // namespace relay } // namespace tvm diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 4965a706b4b4..4c153d522d69 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -59,104 +59,8 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { } }; -/*! - * \brief TempExpr used during annotate forward rewrite. - */ -class QAnnotateExpr; -/*! - * \brief TempExprNode used during annotate forward rewrite. - */ -class QAnnotateExprNode : public TempExprNode { - public: - /*! \brief The original expression */ - Expr expr; - /*! \brief The kind of annotate field */ - QAnnotateKind kind; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("expr", &expr); - v->Visit("kind", &kind); - } - - TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind); - - Expr Realize() const final; - - static constexpr const char* _type_key = "relay.QAnnotateExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); -}; - -RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); - - -/*! - * \brief TempExpr used to insert `force_cast` for VTA. - */ -class QVTAExpr; -/*! - * \brief TempExprNode used to insert `force_cast` for VTA. 
- */ -class QVTAExprNode : public TempExprNode { - public: - /*! \brief The original expression */ - Expr expr; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("expr", &expr); - } - - TVM_DLL static QVTAExpr make(Expr expr); - - Expr Realize() const final; - - static constexpr const char* _type_key = "relay.QVTAExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode); -}; - -RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr); - - -/*! \brief TempExpr used during realize forward rewrite. */ -class QRealizeExpr; -/*! \brief TempExpr representing integer. */ -class QRealizeIntExpr; - -class QRealizeExprNode : public TempExprNode { - public: - /*! \brief The original expression */ - Expr data; - static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; - TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); -}; - -RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); - - -class QRealizeIntExprNode : public QRealizeExprNode { - public: - Expr dom_scale; - /*! \brief current data type */ - DataType dtype; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("data", &data); - v->Visit("dom_scale", &dom_scale); - v->Visit("dtype", &dtype); - } - - Expr Realize() const final; - - TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); - - static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); -}; - -RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); - class QConfig; - /*! 
* \brief Container for build configuration options */ @@ -170,8 +74,8 @@ class QConfigNode : public Node { DataType dtype_activation = Int(32); double global_scale = 8.0; Array skip_conv_layers = Array(NodePtr(nullptr)); + bool do_simulation = false; bool round_for_shift = true; - bool store_lowbit_output = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); void VisitAttrs(AttrVisitor* v) final { @@ -183,8 +87,8 @@ class QConfigNode : public Node { v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); v->Visit("skip_conv_layers", &skip_conv_layers); + v->Visit("do_simulation", &do_simulation); v->Visit("round_for_shift", &round_for_shift); - v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); } @@ -250,12 +154,6 @@ struct QConfigContext { } }; -/*! -* \brief Construct a BuildConfig containing a new BuildConfigNode -* \return The new BuildConfig -*/ -TVM_DLL QConfig qconfig(); - } // namespace quantize } // namespace relay } // namespace tvm diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc new file mode 100644 index 000000000000..e4bc63adc6a0 --- /dev/null +++ b/src/relay/pass/quantize/realize.cc @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * + * \file realize.cc + * + * \brief Realizing the simulated graph into real low-precision + * graph. + */ + +#include +#include +#include +#include "./quantize.h" +#include "../pattern_util.h" + +namespace tvm { +namespace relay { +namespace quantize { + +using namespace relay::transform; + +class QRealizeExpr; +class QRealizeIntExpr; + +class QRealizeExprNode : public TempExprNode { + public: + Expr data; + static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; + TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); + + +class QRealizeIntExprNode : public QRealizeExprNode { + public: + Expr dom_scale; + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("dom_scale", &dom_scale); + v->Visit("dtype", &dtype); + } + + Expr Realize() const final; + + TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); + + static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); +}; + +RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); + + +Expr QRealizeIntExprNode::Realize() const { + Expr data = this->data; + // dequantize + data = Cast(data, Float(32)); + data = Multiply(data, this->dom_scale); + return data; +} + +QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { + NodePtr n = make_node(); + n->data = std::move(data); + n->dom_scale = std::move(dom_scale); + n->dtype = std::move(dtype); + return QRealizeIntExpr(n); +} + + +inline Expr ForwardOp(const Call& ref_call, const Array& args) { + return CallNode::make(ref_call->op, + args, ref_call->attrs, ref_call->type_args); +} + + +/* calculate `data * s1 / 
s2`, use shift if possible */ +inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { + // here we assume the dtype of data is dtype activation + if (s1 == s2) return data; + + float factor = s1 / s2; + float shift_factor = std::log2(factor); + CHECK_GT(shift_factor, 0); + if (static_cast(shift_factor) == shift_factor) { + return LeftShift(data, MakeConstantScalar(dtype, + static_cast(shift_factor))); + } else if (static_cast(factor) == factor) { + return Multiply(data, MakeConstantScalar(dtype, factor)); + } else { + LOG(FATAL) << "fall back to float computation"; + data = Cast(data, Float(32)); + data = Multiply(data, MakeConstantScalar(Float(32), factor)); + return Cast(Round(data), dtype); + } +} + +Expr QuantizeRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + // do not handle data type cast + const auto param = ref_call->attrs.as(); + CHECK_EQ(param->rounding, "round"); + + Expr dom_scale = new_args[1]; + Expr clip_min = new_args[2]; + Expr clip_max = new_args[3]; + + float dom_scale_imm = GetScalarFromConstant(dom_scale); + float clip_min_imm = GetScalarFromConstant(clip_min); + float clip_max_imm = GetScalarFromConstant(clip_max); + + // x * idom_scale = y * odom_scale + // => y = x * idom_scale / odom_scale + if (const auto* n = new_args[0].as()) { + // int32->int8 + Expr data = n->data; + float idom_scale_imm = GetScalarFromConstant(n->dom_scale); + float odom_scale_imm = GetScalarFromConstant(dom_scale); + if (idom_scale_imm == odom_scale_imm) { + // same domain scale, only clip + data = Clip(data, clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + } + + float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); + CHECK_GT(shift_nbit, 0); + if (static_cast(shift_nbit) == shift_nbit) { + // use right shift + if (cfg->round_for_shift) { + float round_bias = std::pow(2.0, shift_nbit - 1); + data = Add(data, 
MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias)));
+      }
+      data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
+                                                 static_cast(shift_nbit)));
+      data = Clip(data, clip_min_imm, clip_max_imm);
+      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
+    } else {
+      // float computation
+      data = Cast(data, Float(32));
+      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
+      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+    }
+  }
+
+  // quantize from real
+  CHECK(!new_args[0]->derived_from());
+  Expr data = new_args[0];
+  Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
+  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
+  return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+}
+
+/*!
+ * \brief Run the FoldConstant pass over a single expression.
+ *
+ * The expression is wrapped into a module with ModuleNode::FromExpr, the
+ * module is folded, and the result is read back from its "main" entry.
+ *
+ * \param expr The expression to fold.
+ * \return Only the body of the folded "main" function when the expr.as()
+ *         cast of the input yields nullptr, otherwise the function itself.
+ */
+Expr FoldConstantOpt(const Expr& expr) {
+  auto mod = ModuleNode::FromExpr(expr);
+  mod = transform::FoldConstant()(mod);
+  auto entry_func = mod->Lookup("main");
+  return expr.as() == nullptr ? entry_func->body : entry_func;
+}
+
+RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
+.set_attr("FQRealizeRewrite", QuantizeRealize);
+
+
+/*!
+ * \brief FQRealizeRewrite rule registered for nn.conv2d.
+ *
+ * Casts the lhs data to cfg->dtype_input (unless it already has that dtype)
+ * and the rhs data to cfg->dtype_weight, re-issues the call with out_dtype set
+ * to cfg->dtype_activation, and uses the constant-folded product of both
+ * operand dom_scales as the scale of the result.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly two are required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when neither argument passes the
+ *         derived_from() test.
+ */
+Expr Conv2dRealize(const Call& ref_call,
+                   const Array& new_args,
+                   const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->derived_from() && !new_args[1]->derived_from()) {
+    return Expr(nullptr);
+  }
+  const auto* lhs = new_args[0].as();
+  CHECK(lhs);
+  const auto* rhs = new_args[1].as();
+  CHECK(rhs);
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as();
+  auto attrs = make_node();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = CallNode::make(ref_call->op,
+    {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr("FQRealizeRewrite", Conv2dRealize);
+
+
+/*!
+ * \brief FQRealizeRewrite rule registered for nn.dense.
+ *
+ * Same data flow as Conv2dRealize: lhs data is cast to cfg->dtype_input when
+ * needed, rhs data is cast to cfg->dtype_weight, the call is re-issued with
+ * out_dtype set to cfg->dtype_activation, and the result scale is the
+ * constant-folded product of both operand dom_scales.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly two are required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when at least one argument fails
+ *         the derived_from() test.
+ */
+Expr DenseRealize(const Call& ref_call,
+                  const Array& new_args,
+                  const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->derived_from() || !new_args[1]->derived_from()) {
+    return Expr(nullptr);
+  }
+  // Passing derived_from() does not by itself prove that the as() cast
+  // succeeds, so fail with a diagnostic (as Conv2dRealize does) instead of
+  // dereferencing a null pointer below.
+  const auto* lhs = new_args[0].as();
+  CHECK(lhs);
+  const auto* rhs = new_args[1].as();
+  CHECK(rhs);
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as();
+  auto attrs = make_node();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = CallNode::make(ref_call->op,
+    {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.dense")
+.set_attr("FQRealizeRewrite", DenseRealize);
+
+
+/*!
+ * \brief FQRealizeRewrite rule registered for multiply.
+ *
+ * When both arguments pass the as() test, both operands are brought to
+ * cfg->dtype_activation, the original op is forwarded on them, and the result
+ * scale is the constant-folded product of the operand dom_scales.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly two are required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when the arguments do not both
+ *         pass the as() test (in which case neither may pass derived_from()).
+ */
+Expr MulRealize(const Call& ref_call,
+                const Array& new_args,
+                const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (new_args[0].as() && new_args[1].as()) {
+    // execute the operation with activation data type.
+    const auto* lhs = new_args[0].as();
+    const auto* rhs = new_args[1].as();
+    Expr ldata = lhs->data;
+    Expr rdata = rhs->data;
+
+    // Only insert a cast for an operand that is not in the activation dtype yet.
+    DataType dtype = cfg->dtype_activation;
+    if (lhs->dtype != dtype) {
+      ldata = Cast(ldata, dtype);
+    }
+    if (rhs->dtype != dtype) {
+      rdata = Cast(rdata, dtype);
+    }
+
+    Expr ret = ForwardOp(ref_call, {ldata, rdata});
+    Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+    Expr dom_scale = FoldConstantOpt(mul);
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  }
+  CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("multiply")
+.set_attr("FQRealizeRewrite", MulRealize);
+
+
+/*!
+ * \brief Pick the common dom scale for a list of realized operands.
+ *
+ * \param nptrs The realized operands.
+ * \return For exactly two operands, the smaller of their two dom_scale
+ *         constants; otherwise cfg->global_scale / 2^(cfg->nbit_activation - 1).
+ */
+float ChooseDomScale(const std::vector& nptrs) {
+  if (nptrs.size() == 2) {
+    // x = a * s1, y = b * s2
+    // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
+    //       = (a + b * s2 / s1) * s1, if s2 > s1
+    float s1 = GetScalarFromConstant(nptrs[0]->dom_scale);
+    float s2 = GetScalarFromConstant(nptrs[1]->dom_scale);
+    return s1 > s2 ? s2 : s1;
+  } else {
+    const QConfig& cfg = QConfig::Current();
+    float scale = cfg->global_scale;
+    return scale / std::pow(2.0, cfg->nbit_activation - 1);
+  }
+}
+
+
+/*!
+ * \brief Unify the data type and the dom scale of arguments.
+ *
+ * The common dtype is cfg->dtype_input when there are exactly two arguments
+ * and the second one already has that dtype, otherwise cfg->dtype_activation.
+ * The common scale comes from ChooseDomScale; every argument is rescaled to it
+ * with MulAndDiv.
+ *
+ * \param ref_args The arguments of the original call (same length as args).
+ * \param args The rewritten arguments; each must pass the as() test.
+ * \param dtype_ptr Output: the unified data type.
+ * \param scale_ptr Output: the unified dom scale as a float32 constant.
+ * \return The argument data expressions after casting and rescaling.
+ */
+Array UnifyDTypeScale(const Array& ref_args, const Array& args,
+                      DataType* dtype_ptr, Expr* scale_ptr) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
+  const QConfig& cfg = QConfig::Current();
+
+  std::vector nptrs;
+  Array ret;
+  for (auto arg : args) {
+    const auto* nptr = arg.as();
+    CHECK(nptr);
+    nptrs.push_back(nptr);
+    ret.push_back(nptr->data);
+  }
+
+  // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
+  DataType dtype;
+
+  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+    dtype = cfg->dtype_input;
+  } else {
+    dtype = cfg->dtype_activation;
+  }
+  for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as();
+    if (nptrs[i]->dtype != dtype) {
+      ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as()->kind == kQInput) {
+      // The argument already has the unified dtype but originates from a
+      // kQInput simulated_quantize: route it through dtype_input and a
+      // StopFusion marker before casting back.
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      new_arg = StopFusion(new_arg);
+      ret.Set(i, Cast(new_arg, dtype));
+    }
+  }
+
+  // unify the dom_scale
+  float s = ChooseDomScale(nptrs);
+  Expr dom_scale = MakeConstantScalar(Float(32), s);
+  for (size_t i = 0; i < ret.size(); ++i) {
+    float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale);
+    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+  }
+
+  *dtype_ptr = dtype;
+  *scale_ptr = dom_scale;
+  return ret;
+}
+
+/*!
+ * \brief FQRealizeRewrite rule registered for add.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly two are required.
+ * \param ctx Not used by this rule.
+ * \return The original op forwarded on the arguments unified by
+ *         UnifyDTypeScale, or a null Expr when the arguments do not both pass
+ *         the as() test (in which case neither may pass derived_from()).
+ */
+Expr AddRealize(const Call& ref_call,
+                const Array& new_args,
+                const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 2);
+  if (new_args[0].as() && new_args[1].as()) {
+    DataType dtype;
+    Expr dom_scale;
+    Array ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    Expr ret = ForwardOp(ref_call, ret_args);
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  }
+
+  CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr("FQRealizeRewrite", AddRealize);
+
+/*!
+ * \brief FQRealizeRewrite rule registered for clip.
+ *
+ * Re-issues the clip on the realized data with a_min and a_max divided by the
+ * argument's dom scale; dom scale and dtype of the argument are kept.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when the argument fails the as()
+ *         test (in which case it may not pass derived_from() either).
+ */
+Expr ClipRealize(const Call& ref_call,
+                 const Array& new_args,
+                 const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as()) {
+    const auto ref_attrs = ref_call->attrs.as();
+    auto attrs = make_node();
+    double dom_scale = GetScalarFromConstant(n->dom_scale);
+    attrs->a_min = ref_attrs->a_min / dom_scale;
+    attrs->a_max = ref_attrs->a_max / dom_scale;
+
+    Expr ret = CallNode::make(ref_call->op,
+      {n->data}, Attrs(attrs), ref_call->type_args);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("clip")
+.set_attr("FQRealizeRewrite", ClipRealize);
+
+
+/*!
+ * \brief FQRealizeRewrite rule registered for concatenate.
+ *
+ * The single argument must be a tuple both in the original and the rewritten
+ * call. Whether to realize is decided from the first tuple field only.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The original op forwarded on a tuple of the fields unified by
+ *         UnifyDTypeScale, or a null Expr when the first field fails the as()
+ *         test.
+ */
+Expr ConcatenateRealize(const Call& ref_call,
+                        const Array& new_args,
+                        const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(ref_call->args.size(), 1);
+
+  const auto* tuple = new_args[0].as();
+  const auto* ref_tuple = ref_call->args[0].as();
+  CHECK(tuple);
+  CHECK(ref_tuple);
+  const Array& arr = tuple->fields;
+  const Array& ref_arr = ref_tuple->fields;
+
+  if (arr[0].as()) {
+    DataType dtype;
+    Expr dom_scale;
+    Array ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
+    Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  } else {
+    // NOTE(review): this loop inspects new_args (the tuple argument itself),
+    // not the tuple fields in arr - confirm that this is the intended check.
+    for (auto arg : new_args) {
+      CHECK(!arg->derived_from());
+    }
+    return Expr(nullptr);
+  }
+}
+
+RELAY_REGISTER_OP("concatenate")
+.set_attr("FQRealizeRewrite", ConcatenateRealize);
+
+
+/*!
+ * \brief Forward the original operator on the realized data.
+ *
+ * Dom scale and dtype of the argument are carried over unchanged.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when the argument fails the as()
+ *         test (in which case it may not pass derived_from() either).
+ */
+Expr IdentityRealize(const Call& ref_call,
+                     const Array& new_args,
+                     const NodeRef& ctx) {
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as()) {
+    Expr ret = ForwardOp(ref_call, {n->data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  }
+  CHECK(!new_args[0]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("strided_slice")
+.set_attr("FQRealizeRewrite", IdentityRealize);
+
+RELAY_REGISTER_OP("annotation.stop_fusion")
+.set_attr("FQRealizeRewrite", IdentityRealize);
+
+/*!
+ * \brief For unary operators that take their input in cfg->dtype_input.
+ *
+ * The realized data is cast to cfg->dtype_input before the original op is
+ * forwarded on it; the result keeps the argument's dom scale.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when the argument fails the as()
+ *         test (in which case it may not pass derived_from() either).
+ */
+Expr CastDtypeInputRealize(const Call& ref_call,
+                           const Array& new_args,
+                           const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as()) {
+    Expr data = Cast(n->data, cfg->dtype_input);
+    Expr ret = ForwardOp(ref_call, {data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
+  }
+  CHECK(!new_args[0]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.max_pool2d")
+.set_attr("FQRealizeRewrite", CastDtypeInputRealize);
+
+
+/*!
+ * \brief FQRealizeRewrite rule registered for nn.avg_pool2d.
+ *
+ * The realized data is cast to cfg->dtype_activation (unless it already has
+ * that dtype) before the original op is forwarded on it; the result keeps the
+ * argument's dom scale.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The realized call, or a null Expr when the argument fails the as()
+ *         test (in which case it may not pass derived_from() either).
+ */
+Expr AvgPoolRealize(const Call& ref_call,
+                    const Array& new_args,
+                    const NodeRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as()) {
+    Expr data = n->data;
+    if (n->dtype != cfg->dtype_activation) {
+      data = Cast(n->data, cfg->dtype_activation);
+    }
+    Expr ret = ForwardOp(ref_call, {data});
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
+  }
+  CHECK(!new_args[0]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("nn.avg_pool2d")
+.set_attr("FQRealizeRewrite", AvgPoolRealize);
+
+/*!
+ * \brief FQRealizeRewrite rule registered for annotation.cast_hint.
+ *
+ * Replaces the annotation by a cast of the realized data to the dtype stored
+ * in the call's attributes; the argument's dom scale is kept.
+ *
+ * \param ref_call The original call being rewritten.
+ * \param new_args The rewritten arguments; exactly one is required.
+ * \param ctx Not used by this rule.
+ * \return The realized cast, or a null Expr when the argument fails the as()
+ *         test (in which case it may not pass derived_from() either).
+ */
+Expr CastHintRealize(const Call& ref_call,
+                     const Array& new_args,
+                     const NodeRef& ctx) {
+  const auto param = ref_call->attrs.as();
+  CHECK_EQ(new_args.size(), 1);
+  if (const auto* n = new_args[0].as()) {
+    Expr ret = Cast(n->data, param->dtype);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype);
+  }
+  CHECK(!new_args[0]->derived_from());
+  return Expr(nullptr);
+}
+
+RELAY_REGISTER_OP("annotation.cast_hint")
+.set_attr("FQRealizeRewrite", CastHintRealize);
+
+/*!
+ * \brief Create the function pass that applies the "FQRealizeRewrite" rules.
+ *
+ * \return A function pass named "QuantizeRealize", created with opt level 1,
+ *         that runs ForwardRewrite with the "FQRealizeRewrite" attribute.
+ */
+Pass QuantizeRealizePass() {
+  runtime::TypedPackedFunc pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast(
+          ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
+  };
+  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
+}
+
+TVM_REGISTER_API("relay._quantize.QuantizeRealize")
+.set_body_typed(QuantizeRealizePass);
+
+}  // namespace quantize
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py
new file mode 100644
index 000000000000..f047952f3e6b
--- /dev/null
+++ b/tests/python/nightly/quantization/test_quantization_accuracy.py
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from collections import namedtuple
+import tvm
+from tvm import relay
+from tvm.relay import quantize as qtz
+import mxnet as mx
+from mxnet import gluon
+import logging
+import os
+
+logging.basicConfig(level=logging.INFO)
+
+Config = namedtuple('Config', ['model', 'nbit_input', 'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc'])
+
+
+def get_val_data(model_name,
+                 rec_val,
+                 batch_size,
+                 num_workers=4):
+    """Create the validation data iterator and a helper that splits its batches.
+
+    Parameters
+    ----------
+    model_name : str
+        'inceptionv3' selects 299x299 inputs; any other name selects 224x224.
+    rec_val : str
+        Path of the validation record file; a leading '~' is expanded.
+    batch_size : int
+        Batch size handed to the iterator.
+    num_workers : int
+        Handed to the iterator as preprocess_threads.
+
+    Returns
+    -------
+    val_data : mx.io.ImageRecordIter
+        The validation data iterator.
+    batch_fn : function
+        batch_fn(batch, ctx) returns (data, label), each split and loaded onto ctx.
+    """
+    rec_val = os.path.expanduser(rec_val)
+    mean_rgb = [123.68, 116.779, 103.939]
+    std_rgb = [58.393, 57.12, 57.375]
+    def batch_fn(batch, ctx):
+        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
+        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
+        return data, label
+
+    img_size = 299 if model_name == 'inceptionv3' else 224
+    val_data = mx.io.ImageRecordIter(
+        path_imgrec = rec_val,
+        preprocess_threads = num_workers,
+        shuffle = False,
+        batch_size = batch_size,
+        resize = 256,
+        data_shape = (3, img_size, img_size),
+        mean_r = mean_rgb[0],
+        mean_g = mean_rgb[1],
+        mean_b = mean_rgb[2],
+        std_r = std_rgb[0],
+        std_g = std_rgb[1],
+        std_b = std_rgb[2],
+    )
+    return val_data, batch_fn
+
+
+def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
+    """Import a pretrained Gluon model into Relay and optionally quantize it.
+
+    Parameters
+    ----------
+    model_name : str
+        Name handed to gluon.model_zoo.vision.get_model.
+    batch_size : int
+        Leading dimension of the input shape.
+    qconfig
+        Context manager entered around the call to qtz.quantize.
+    target
+        Unused.
+    original : bool
+        If true, return right after prerequisite_optimize, without quantizing.
+    simulated
+        Unused.
+
+    Returns
+    -------
+    qfunc
+        The optimized function, quantized unless original is true.
+    """
+    gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
+    img_size = 299 if model_name == 'inceptionv3' else 224
+    data_shape = (batch_size, 3, img_size, img_size)
+    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
+    net = mod['main']
+
+    with relay.build_config(opt_level=3):
+        qfunc = relay.quantize.prerequisite_optimize(net, params=params)
+    logging.debug('original')
+    logging.debug(qfunc.astext(show_meta_data=False))
+    if original:
+        return qfunc
+
+    with qconfig:
+        logging.debug('current quantize config')
+        logging.debug(qtz.current_qconfig())
+        qfunc = qtz.quantize(qfunc)
+        logging.debug('after quantize')
+        logging.debug(qfunc.astext(show_meta_data=False))
+    return qfunc
+
+
+def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
+    """Build the model, run it over the dataset and report top-1/top-5 accuracy.
+
+    Parameters
+    ----------
+    model
+        Passed to relay.build.
+    dataset
+        Iterable of batches; must provide reset() and batch_size.
+    batch_fn : function
+        batch_fn(batch, ctx) returning the (data, label) lists of a batch.
+    target
+        Build target passed to relay.build.
+    ctx
+        Context passed to the graph runtime.
+    log_interval : int
+        The running accuracy is logged every log_interval batches.
+
+    Returns
+    -------
+    top1 : float
+        Top-1 accuracy accumulated over all evaluated batches.
+    """
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(model, target)
+    # create runtime module
+    m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+
+    # setup evaluation metric
+    dataset.reset()
+    batch_size = dataset.batch_size
+    acc_top1 = mx.metric.Accuracy()
+    acc_top5 = mx.metric.TopKAccuracy(5)
+    acc_top1.reset()
+    acc_top5.reset()
+    # Execute
+    for i, batch in enumerate(dataset):
+        data, label = batch_fn(batch, [mx.cpu(0)])
+        m.run(data=data[0].asnumpy())
+        out_arr = m.get_output(0)
+        preds = [mx.nd.array(out_arr.asnumpy())]
+        acc_top1.update(label, preds)
+        acc_top5.update(label, preds)
+
+        if not (i + 1) % log_interval:
+            _, top1 = acc_top1.get()
+            _, top5 = acc_top5.get()
+            nsamples = (i + 1) * batch_size
+            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
+    # Re-read both metrics after the last batch: the values fetched inside the
+    # loop are refreshed only every log_interval batches, and are never assigned
+    # at all when the dataset has fewer than log_interval batches.
+    _, top1 = acc_top1.get()
+    _, top5 = acc_top5.get()
+    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
+    return top1
+
+def test_quantize_acc(cfg, rec_val):
+    """Quantize cfg.model, evaluate it on rec_val and check the accuracy.
+
+    Parameters
+    ----------
+    cfg : Config
+        Model name, bit widths and dtypes, global scale and expected accuracy.
+    rec_val : str
+        Path of the validation record file.
+
+    Returns
+    -------
+    acc : float
+        Measured top-1 accuracy; asserted to exceed cfg.expected_acc.
+    """
+    qconfig = qtz.qconfig(skip_conv_layers=[0],
+                          nbit_input=cfg.nbit_input,
+                          nbit_weight=cfg.nbit_input,
+                          global_scale=cfg.global_scale,
+                          dtype_input=cfg.dtype_input,
+                          dtype_weight=cfg.dtype_input,
+                          dtype_activation=cfg.dtype_output,
+                          debug_enabled_ops=None)
+
+    model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
+    val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)
+
+    acc = eval_acc(model, val_data, batch_fn)
+    assert acc > cfg.expected_acc
+    return acc
+
+
+if __name__ == "__main__":
+    #TODO(for user): replace the line with the path to imagenet validation dataset
+    rec_val = "/scratch/tqchen/imagenet/val.rec"
+
+    results = []
+    configs = [
+        Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666),
+
+        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692),
+        Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692),
+        Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733),
+        Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747),
+        Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.756),
+        # TODO: need to fix accuracy
+        # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
+    ]
+
+    for config in configs:
+        acc = test_quantize_acc(config, rec_val)
+        results.append((config, acc))
+    for res in results:
+        print(res)
diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py
deleted file mode 100644
index f6f67d6b6ac9..000000000000
--- a/tests/python/relay/test_pass_quantize.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import math -import numpy as np -import tvm -from tvm import relay -from tvm.relay import quantize as qtz -from tvm.relay import transform - - -def run_infer_type(expr): - mod = relay.Module.from_expr(expr) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body - - -def make_dataset(graph, size=100): - args = run_infer_type(graph).params - def create_arr(var): - ttype = var.type_annotation - np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype) - return tvm.ndarray.array(np_arr) - - params = {} - for arg in args: - if arg.name_hint == 'data': - dataset = [{'data': create_arr(arg)} for _ in range(size)] - else: - params[arg.name_hint] = create_arr(arg) - return dataset, params - - -def test_simulated_quantize(): - data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32")) - out = qtz._annotate.attach_simulated_quantize(data, 1) - out = run_infer_type(out) - assert out.checked_type == out.args[0].checked_type - assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32") - assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32") - assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32") - - -def test_quantize_pass(): - def quantize_weight(arr): - maximum = np.amax(np.abs(arr.asnumpy())) - scale = 2**math.ceil(math.log(maximum, 2)) - out = np.around(arr.asnumpy() / scale * 128).astype('int8') - out = np.clip(out, -127, 127) - return relay.const(out, 'int8') - - n, c, h, w = 1, 3, 224, 224 - def make_graph(data): - weight = relay.var("conv_weight") - out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) - out = relay.Function(relay.analysis.free_vars(out), out) - return out - - def make_qgraph(data, weight): - out = data * relay.const(32.0) - out = relay.round(out) - out = relay.clip(out, a_min=-127, a_max=127) - out = out.astype('int8') - - out = relay.nn.conv2d(out, weight, 
kernel_size=(3, 3), - padding=(1, 1), channels=c, out_dtype='int32') - out = out.astype('float32') - out = relay.multiply(out, relay.const(0.00024414062)) - out = relay.Function(relay.analysis.free_vars(out), out) - return out - - np.random.seed(42) - - data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) - graph = make_graph(data) - dataset, params = make_dataset(graph, 10) - - with qtz.qconfig(skip_conv_layers=None, global_scale=4.0, - round_for_shift=False, store_lowbit_output=False): - qgraph0 = qtz.quantize(graph, params) - qgraph0 = run_infer_type(qgraph0) - - conv_weight = quantize_weight(params['conv_weight']) - qgraph1 = make_qgraph(data, conv_weight) - qgraph1 = run_infer_type(qgraph1) - - graph = relay.create_executor('graph') - res0 = graph.evaluate(qgraph0)(dataset[0]['data']) - res1 = graph.evaluate(qgraph1)(dataset[0]['data']) - tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3) - - -if __name__ == "__main__": - test_simulated_quantize() - test_quantize_pass() diff --git a/tests/scripts/task_python_nightly.sh b/tests/scripts/task_python_nightly.sh new file mode 100755 index 000000000000..09f7e8a4c6e3 --- /dev/null +++ b/tests/scripts/task_python_nightly.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u + +export PYTHONPATH=python:topi/python + +# Rebuild cython +make cython3 + +rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc +rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc + +python3 -m nose -v topi/tests/python/nightly