From 53deebc9770addeb1f39af00d103c7a509320d8e Mon Sep 17 00:00:00 2001
From: Logan Weber
Date: Mon, 22 Jun 2020 17:22:04 -0700
Subject: [PATCH] Implement quant/dequant partitioning

---
 python/tvm/relay/analysis/analysis.py         |  47 +++
 .../relay/quantize/_partition_conversions.py  | 364 ++++++++++++++++++
 python/tvm/relay/quantize/quantize.py         |  16 +
 python/tvm/relay/scope_builder.py             |   2 +-
 src/relay/quantize/quantize.cc                |   3 +-
 src/relay/quantize/quantize.h                 |   2 +
 tests/python/relay/test_pass_auto_quantize.py | 208 ++++++++++
 7 files changed, 640 insertions(+), 2 deletions(-)
 create mode 100644 python/tvm/relay/quantize/_partition_conversions.py

diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 165e39a09c1b3..5471494a0043c 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -23,6 +23,9 @@
 from tvm.ir import IRModule
 from tvm.relay import transform, build_module
 from tvm.runtime.ndarray import cpu
+# TODO(weberlo) remove when we port dtype collectors to C++
+from tvm.relay.expr_functor import ExprVisitor
+from tvm.relay.type_functor import TypeVisitor
 
 from . import _ffi_api
 from .feature import Feature
@@ -236,6 +239,50 @@ def all_type_vars(expr, mod=None):
     return _ffi_api.all_type_vars(expr, use_mod)
 
 
+class TyDtypeCollector(TypeVisitor):
+    """Pass that collects data types used in the visited type."""
+
+    def __init__(self):
+        TypeVisitor.__init__(self)
+        self.dtypes = set()
+
+    def visit_tensor_type(self, tt):
+        self.dtypes.add(tt.dtype)
+
+
+class ExprDtypeCollector(ExprVisitor):
+    """Pass that collects data types used in all types in the visited expression."""
+
+    def __init__(self):
+        ExprVisitor.__init__(self)
+        self.ty_visitor = TyDtypeCollector()
+
+    def visit(self, expr):
+        if hasattr(expr, 'checked_type'):
+            self.ty_visitor.visit(expr.checked_type)
+        elif hasattr(expr, 'type_annotation'):
+            self.ty_visitor.visit(expr.type_annotation)
+        ExprVisitor.visit(self, expr)
+
+
+def all_dtypes(expr):
+    """Collect the set of all data types used in `expr`.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression
+
+    Returns
+    -------
+    ret : Set[String]
+        Set of data types used in the expression
+    """
+    dtype_collector = ExprDtypeCollector()
+    dtype_collector.visit(expr)
+    return dtype_collector.ty_visitor.dtypes
+
+
 def collect_device_info(expr):
     """Collect the device allocation map for the given expression. The device
     ids are propagated from the `device_copy` operators.
diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py
new file mode 100644
index 0000000000000..a71e8a1fd28a6
--- /dev/null
+++ b/python/tvm/relay/quantize/_partition_conversions.py
@@ -0,0 +1,364 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument, not-context-manager +"""Utilities for partitioning input quantization and output dequantization expressions.""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.type_functor import TypeMutator + +# operators that are allowed in prefix/suffix partitions, because they are used +# to quantize/dequantize +ALLOWED_CONVERSION_OPS = ['add', 'multiply', 'right_shift', 'clip', 'round', 'cast'] + +def partition_conversions(mod, quantized_dtypes): + """Partition mod into input quantization, core quantized inference, and output dequantization. + + The resulting module includes an additional `main` that fuses all three + partitions together. + + Parameters + ---------- + mod : tvm.IRModule + Quantized module to partition + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization (`quantize_inputs`), core + quantized inference (`quantized_main`), output dequantization + (`dequantize_outputs`), and full quantized inference functions + """ + # Partitioning is implemented as in the diagram below: + # + # +----------------------------+ + # |Quantized Inference Function| + # +--------------+-------------+ + # | + # partition_prefix + # | + # +-----+-------------------------+ + # | | + # +--------v---------+ +-----------------v------------------+ + # |Input Quantization| |Rest of Quantized Inference Function| + # +------------------+ +-----------------+------------------+ + # | + # partition_suffix + # | + # +------+---------------------+ + # | | + # +------------------+ +----------v------------+ +-----------v---------+ + # |Input Quantization| |Core Quantized Function| |Output Dequantization| + # +------------------+ +-----------------------+ +---------------------+ + # + # The final module contains all three partitions, as well as a + # `main` function that composes these three functions (depicted below). + # + # +--------------------+-------------------------+-----------------------+ + # | Input Quantization | Core Quantized Function | Output Dequantization | + # +--------------------+-------------------------+-----------------------+ + assert len(mod.functions) == 1 + pre_mod, mid_mod = partition_prefix(mod, quantized_dtypes) + mid_mod, post_mod = partition_suffix(mid_mod, quantized_dtypes) + assert has_only_conversion_ops(pre_mod['main']) + assert relay.analysis.all_dtypes(mid_mod['main']).issubset(quantized_dtypes) + assert has_only_conversion_ops(post_mod['main']) + return fuse_partitions(pre_mod, mid_mod, post_mod) + + +def fuse_partitions(pre_mod, mid_mod, post_mod): + """Combine prefix, middle, and suffix modules into a single module. + + The combined module includes an additional `main` that fuses all three + partitions together. 
+ + Parameters + ---------- + pre_mod : tvm.IRModule + Module containing an input quantization function + + mid_mod : tvm.IRModule + Module containing core of a quantized inference function + + post_mod : tvm.IRModule + Module containing an output dequantization function + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization, core quantized inference, + output dequantization, and full quantized inference functions + """ + pre_func = pre_mod['main'] + mid_func = mid_mod['main'] + post_func = post_mod['main'] + # create a module containing the prefix, middle, and suffix partitions + fused_mod = tvm.IRModule(functions={ + relay.GlobalVar('quantize_inputs'): pre_func, + relay.GlobalVar('quantized_main'): mid_func, + relay.GlobalVar('dequantize_outputs'): post_func, + }) + # construct a `main` that strings together the partitions, such that its + # behaviour is equivalent to `main` in an *unpartitioned* module + scope_builder = relay.ScopeBuilder() + fused_mod_main_params = [relay.Var(param.name_hint) for param in pre_func.params] + quantized_inputs = scope_builder.let('quantized_inputs', relay.Call( + fused_mod.get_global_var('quantize_inputs'), + fused_mod_main_params + )) + quantized_outputs = scope_builder.let('quantized_outputs', relay.Call( + fused_mod.get_global_var('quantized_main'), + [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))] + )) + dequantized_outputs = scope_builder.let('dequantized_outputs', relay.Call( + fused_mod.get_global_var('dequantize_outputs'), + [quantized_outputs] + )) + scope_builder.ret(dequantized_outputs) + fused_mod['main'] = relay.Function(fused_mod_main_params, scope_builder.get()) + return fused_mod + + +def with_dtype(typ, target_dtype): + """Generates a type from the given type where all dtypes are replaced with the target dtype. + + Parameters + ---------- + typ : relay.Type + Type whose dtypes are being replaced + + target_dtype : str + Target data type (e.g., 'int8') + + Returns + ------- + typ : relay.Type + Type with only `target_dtype` for dtypes + """ + class DtypeReplacer(TypeMutator): + def __init__(self, target_dtype): + TypeMutator.__init__(self) + self.target_dtype = target_dtype + + def visit_tensor_type(self, tt): + return relay.TensorType(tt.shape, self.target_dtype) + + return DtypeReplacer(target_dtype).visit(typ) + + +class PrefixCutter(ExprMutator): + """A mutator for extracting input quantization expressions from a function + + The result of `visit` is the core function, and the input quantization + expressions are stored in the `prefix_sb` scope builder. + """ + + def __init__(self, params, quantized_dtypes): + ExprMutator.__init__(self) + self.params = set(params) + self.quantized_dtypes = quantized_dtypes + self.subtree_params = set() + self.new_func_params = [] + self.prefix_sb = relay.ScopeBuilder() + self.prefix_binding_map = {} + + def visit_var(self, var): + if var in self.params: + self.subtree_params.add(var) + return var + + def visit_call(self, call): + # TODO(weberlo) use graph pattern matching? 
+ if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS: + new_args = [] + for arg in call.args: + new_arg = self.visit(arg) + if len(self.subtree_params) == 0: + new_args.append(new_arg) + else: + assert len(self.subtree_params) == 1 + param = next(iter(self.subtree_params)) + pre_param = self.prefix_sb.let(param.name_hint, new_arg) + self.subtree_params.clear() + mid_param = relay.Var( + param.name_hint, + with_dtype(param.type_annotation, arg.checked_type.dtype)) + self.prefix_binding_map[mid_param] = pre_param + # return new parameter, then we can use + # relay.analysis.free_vars at the end of the pass to generate + # new `mid_func` type signature + new_args.append(mid_param) + return relay.Call(call.op, new_args, call.attrs) + + return super().visit_call(call) + + +def partition_prefix(mod, quantized_dtypes): + """Extract input quantization expressions from `mod['main']`. + + Parameters + ---------- + mod : tvm.IRModule + Module containing a quantized inference function + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + pre_mod : tvm.IRModule + Module containing the input quantization function + + mid_mod : tvm.IRModule + Module containing a function with everything except for input quantization + """ + assert len(mod.functions) == 1 + func = mod['main'] + prefix_cutter = PrefixCutter(func.params, quantized_dtypes) + mid_body = prefix_cutter.visit(func.body) + assert not func.type_params, 'unimplemented' + assert func.attrs is None, 'unimplemented' + mid_func = relay.Function( + relay.analysis.free_vars(mid_body), + mid_body) + mid_mod = tvm.IRModule.from_expr(mid_func) + + scope_builder = prefix_cutter.prefix_sb + # make sure we pass through all inputs in the prefix function's return expr + # (even those that don't require quantization) + ret_expr = [] + for param in mid_func.params: + if param in prefix_cutter.prefix_binding_map: + # this param required a conversion, so we collected it in the + # prefix cutter pass, and we can use the pass's mapping from mid + # func params to pre func params + ret_expr.append(prefix_cutter.prefix_binding_map[param]) + else: + # there was no detected conversion for this argument, so we thread + # it through the prefix function untouched + ret_expr.append(relay.Var(param.name_hint, param.checked_type)) + ret_expr = relay.Tuple(ret_expr) + scope_builder.ret(ret_expr) + pre_func_body = scope_builder.get() + pre_func = relay.Function(relay.analysis.free_vars(pre_func_body), pre_func_body) + pre_mod = tvm.IRModule.from_expr(pre_func) + + return pre_mod, mid_mod + + +class SuffixCutter(ExprMutator): + """A mutator for extracting output dequantization expressions from a function + + The result of `visit` is a function containing the output dequantization + expressions, and the middle of the function is stored in `mid_body`. + """ + + def __init__(self, quantized_dtypes): + ExprMutator.__init__(self) + self.mid_body = None + self.quantized_dtypes = quantized_dtypes + + def visit(self, expr): + if hasattr(expr, 'checked_type') and expr.checked_type.dtype in self.quantized_dtypes: + self.mid_body = expr + return relay.Var('input', expr.checked_type) + + return super().visit(expr) + + +def partition_suffix(mod, quantized_dtypes): + """Extract output dequantization expressions from `mod['main']`. 
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        Module containing a quantized inference function
+
+    quantized_dtypes : Set[str]
+        Set of data types allowed in quantized operators
+
+    Returns
+    -------
+    mid_mod : tvm.IRModule
+        Module containing a function with everything except for output dequantization
+
+    post_mod : tvm.IRModule
+        Module containing the output dequantization function
+    """
+    assert len(mod.functions) == 1
+    func = mod['main']
+    suffix_cutter = SuffixCutter(quantized_dtypes)
+    post_body = suffix_cutter.visit(func.body)
+    assert not func.type_params, 'unimplemented'
+    assert func.attrs is None, 'unimplemented'
+    post_func = relay.Function(
+        relay.analysis.free_vars(post_body),
+        post_body,
+        func.ret_type)
+    post_mod = tvm.IRModule.from_expr(post_func)
+
+    mid_body = suffix_cutter.mid_body
+    if mid_body is None:
+        # The suffix contains the entire function, meaning there was no
+        # quantization boundary in the given mod.  In this case, we use the
+        # suffix mod as the middle mod and make the suffix an identity function.
+        mid_mod = post_mod
+        post_body = relay.Var('input', mid_mod['main'].ret_type)
+        post_func = relay.Function(
+            [post_body],
+            post_body)
+        post_mod = tvm.IRModule.from_expr(post_func)
+    else:
+        mid_func = relay.Function(
+            func.params,
+            mid_body)
+        mid_mod = tvm.IRModule.from_expr(mid_func)
+
+    return mid_mod, post_mod
+
+
+class ConversionOpChecker(ExprVisitor):
+    """A pass for checking that the visited function contains only conversion ops."""
+    def __init__(self):
+        ExprVisitor.__init__(self)
+        self.valid = True
+
+    def visit_call(self, call):
+        if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS:
+            self.valid = False
+        super().visit_call(call)
+
+
+def has_only_conversion_ops(func):
+    """Return True iff the given function contains only quantization/dequantization ops.
+
+    Parameters
+    ----------
+    func : relay.Function
+        Function being checked
+
+    Returns
+    -------
+    valid : bool
+        Whether the function contains only conversion ops
+    """
+    checker = ConversionOpChecker()
+    checker.visit(func)
+    return checker.valid
diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py
index 28ebf7f3032bd..d1e6b42a8a630 100644
--- a/python/tvm/relay/quantize/quantize.py
+++ b/python/tvm/relay/quantize/quantize.py
@@ -22,6 +22,7 @@
 from . import _quantize
 from ._calibrate import calibrate
+from ._partition_conversions import partition_conversions
 from .. import expr as _expr
 from .. import transform as _transform
@@ -85,6 +86,7 @@ class QConfig(Object):
         "debug_enabled_ops": None,
         "rounding": "UPWARD",
         "calibrate_chunk_by": -1,
+        "partition_conversions": False,
     }
 
     # pylint: disable=no-member
@@ -179,6 +181,15 @@ def qconfig(**kwargs):
     rounding: "UPWARD" or "TONEAREST"
         Rounding direction for fixed point multiplications.
 
+    partition_conversions: bool
+        Whether to partition a quantized result into a module containing
+        a prefix function (consisting of input conversion into the quantized data space),
+        a middle function (consisting of the core quantized network),
+        a suffix function (consisting of output dequantization),
+        and a main function (that calls the prefix, middle, and suffix functions in succession).
+        If there are unquantized operators in the core network, an exception is raised.
+        The default value is `False`.
+ Returns ------- config: QConfig @@ -359,4 +370,9 @@ def quantize(mod, params=None, dataset=None): with quantize_context(): mod = quantize_seq(mod) + q_cfg = current_qconfig() + if q_cfg.partition_conversions: + quantized_dtypes = {q_cfg.dtype_input, q_cfg.dtype_weight, q_cfg.dtype_activation} + return partition_conversions(mod, quantized_dtypes) + return mod diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index cd8dc8dcd3096..86ff805b1d34d 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -115,7 +115,7 @@ def let(self, var, value): The variable or name of variable. value: tvm.relay.Expr - The value to be binded + The value to be bound """ if isinstance(var, (tuple, list)): if len(var) > 2: diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 1bf858b43db02..28fc68ebf0a8a 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -125,7 +125,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; - p->stream << "rounding==" << op->rounding; + p->stream << "rounding==" << op->rounding << ", "; + p->stream << "partition_conversions==" << op->partition_conversions; p->stream << ")"; }); diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 86f8926c98ac7..9b9281e74ea5c 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -74,6 +74,7 @@ class QConfigNode : public Object { Array debug_enabled_ops = Array(ObjectPtr(nullptr)); std::string rounding = "UPWARD"; int calibrate_chunk_by = -1; + bool partition_conversions = false; void VisitAttrs(AttrVisitor* v) { v->Visit("nbit_input", &nbit_input); @@ -92,6 +93,7 @@ class QConfigNode : public Object { v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("rounding", &rounding); v->Visit("calibrate_chunk_by", &calibrate_chunk_by); + v->Visit("partition_conversions", &partition_conversions); } static constexpr const char* _type_key = "relay.quantize.QConfig"; diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index f7427974904ce..697b34f92ff7e 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.relay import testing from tvm.relay.expr import Call +from tvm.topi.util import get_const_tuple def quantize_and_build(out): @@ -103,9 +104,216 @@ def test_calibrate_memory_bound(): relay.quantize.quantize(mod, params, dataset) +#################################### +# Quant/Dequant Partitioning Tests # +#################################### + +BASE_CFG = { + 'skip_conv_layers': [], + 'skip_dense_layers': False, + 'dtype_input': "int8", + 'dtype_weight': "int8", + 'dtype_activation': "int32", +} + +def gen_rand_tvm(tt, low, high): + if 'int' in tt.dtype: + data_np = np.random.randint(low, high, size=get_const_tuple(tt.shape), dtype=tt.dtype) + elif 'float' in tt.dtype: + data_np = np.random.uniform(low, high, size=get_const_tuple(tt.shape)).astype(tt.dtype) + else: + assert False, 'unknown dtype' + return tvm.nd.array(data_np, ctx=tvm.cpu(0)) + + +def verify_partition_fails(mod, params): + try: + with relay.quantize.qconfig(**BASE_CFG, partition_conversions=True): + partitioned_mod = relay.quantize.quantize(mod, params) + raise 
RuntimeError('partitioning should have failed')
+    except AssertionError:
+        pass
+
+
+def verify_partition(mod, params):
+    with relay.quantize.qconfig(**BASE_CFG):
+        unpartitioned_mod = relay.quantize.quantize(mod, params)
+    with relay.quantize.qconfig(**BASE_CFG, partition_conversions=True):
+        partitioned_mod = relay.quantize.quantize(mod, params)
+
+    # ensure the quantized core indeed only consists of quantized dtypes
+    q_dtypes = set([
+        BASE_CFG['dtype_input'],
+        BASE_CFG['dtype_weight'],
+        BASE_CFG['dtype_activation']
+    ])
+    q_main_dtypes = relay.analysis.all_dtypes(partitioned_mod['quantized_main'])
+    assert q_main_dtypes.issubset(q_dtypes)
+
+    # ensure results of `partition_conversions=False` and `partition_conversions=True` agree
+    params = [
+        gen_rand_tvm(param.type_annotation, 0, 1)
+        for param in partitioned_mod['main'].params
+    ]
+    def _eval_mod(mod):
+        vm = relay.create_executor('vm', ctx=tvm.cpu(0), target='llvm', mod=mod)
+        return vm.evaluate()(*params)
+    partitioned_mod_result = _eval_mod(partitioned_mod)
+    unpartitioned_mod_result = _eval_mod(unpartitioned_mod)
+    tvm.testing.assert_allclose(
+        unpartitioned_mod_result.asnumpy(), partitioned_mod_result.asnumpy())
+
+
+def test_add_partition():
+    func = relay.fromtext("""
+    v0.0.4
+    fn (%x: Tensor[(10, 10), float32],
+        %y: Tensor[(10, 10), float32]) {
+      add(%x, %y)
+    }
+    """)
+    mod = tvm.IRModule.from_expr(func)
+    params = {}
+    verify_partition_fails(mod, params)
+
+
+def test_conv2d_partition():
+    func = relay.fromtext("""
+    v0.0.4
+    fn (%x: Tensor[(1, 4, 16, 16), float32],
+        %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] {
+      nn.conv2d(%x, %w,
+        padding=[1, 1, 1, 1],
+        channels=4,
+        kernel_size=[3, 3])
+    }
+    """)
+    mod = tvm.IRModule.from_expr(func)
+    weight_ty = mod['main'].params[1].checked_type
+    params = {
+        'w': gen_rand_tvm(weight_ty, 0, 1)
+    }
+    verify_partition(mod, params)
+
+
+def test_multiple_arg_conversions_partition():
+    mod = tvm.IRModule.from_expr(relay.fromtext("""
+    v0.0.4
+    fn (%x1: Tensor[(1, 4, 16, 16), float32],
+        %w1: Tensor[(4, 4, 3, 3), float32],
+        %x2: Tensor[(1, 4, 16, 16), float32],
+        %w2: Tensor[(4, 4, 3, 3), float32]
+        ) -> Tensor[(1, 4, 16, 16), float32] {
+      %0 = nn.conv2d(%x1, %w1,
+        padding=[1, 1, 1, 1],
+        channels=4,
+        kernel_size=[3, 3]);
+      %1 = nn.conv2d(%x2, %w2,
+        padding=[1, 1, 1, 1],
+        channels=4,
+        kernel_size=[3, 3]);
+      add(%0, %1)
+    }
+    """))
+
+    w1_ty = mod['main'].params[1].checked_type
+    w2_ty = mod['main'].params[3].checked_type
+    params = {
+        'w1': gen_rand_tvm(w1_ty, 0, 1),
+        'w2': gen_rand_tvm(w2_ty, 0, 1)
+    }
+    verify_partition(mod, params)
+
+
+def test_unquantizable_prefix_partition():
+    func = relay.fromtext("""
+    v0.0.4
+    fn (%x: Tensor[(1, 4, 16, 16), float32],
+        %b: Tensor[(4), float32],
+        %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] {
+      // NOTE bias_add isn't currently quantizable
+      %0 = nn.bias_add(%x, %b);
+      nn.conv2d(%0, %w,
+        padding=[1, 1, 1, 1],
+        channels=4,
+        kernel_size=[3, 3])
+    }
+    """)
+    mod = tvm.IRModule.from_expr(func)
+    bias_ty = mod['main'].params[1].checked_type
+    weight_ty = mod['main'].params[2].checked_type
+    params = {
+        'b': gen_rand_tvm(bias_ty, 0, 1),
+        'w': gen_rand_tvm(weight_ty, 0, 1)
+    }
+    verify_partition_fails(mod, params)
+
+
+def test_unquantizable_core_partition():
+    func = relay.fromtext("""
+    v0.0.4
+    fn (%x1: Tensor[(1, 4, 16, 16), float32],
+        %w1: Tensor[(4, 4, 3, 3), float32],
+        %b: Tensor[(4), float32],
+        %w2: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] {
+      %0 = 
nn.conv2d(%x1, %w1, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + %1 = nn.bias_add(%0, %b); + nn.conv2d(%1, %w2, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + mod = tvm.IRModule.from_expr(func) + w1_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + w2_ty = mod['main'].params[3].checked_type + params = { + 'w1': gen_rand_tvm(w1_ty, 0, 1), + 'w2': gen_rand_tvm(w2_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + +def test_unquantizable_suffix_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x: Tensor[(1, 4, 16, 16), float32], + %w: Tensor[(4, 4, 3, 3), float32], + %b: Tensor[(4), float32]) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + nn.bias_add(%0, %b) + } + """) + mod = tvm.IRModule.from_expr(func) + weight_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + params = { + 'w': gen_rand_tvm(weight_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + if __name__ == "__main__": test_mul_rewrite() test_batch_flatten_rewrite() test_calibrate_target(False) test_calibrate_target(True) test_calibrate_memory_bound() + + test_add_partition() + test_conv2d_partition() + test_multiple_arg_conversions_partition() + test_unquantizable_prefix_partition() + test_unquantizable_core_partition() + test_unquantizable_suffix_partition()
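
---

Usage sketch (not part of the patch): a minimal example of how the new `partition_conversions` option might be exercised, assuming only the API added above. The toy conv2d network and random weight are hypothetical placeholders lifted from `test_conv2d_partition`; `skip_conv_layers=[]` is set, as in the tests' `BASE_CFG`, so the lone conv layer actually gets quantized.

import numpy as np
import tvm
from tvm import relay

# Toy float32 model: a single conv2d, mirroring test_conv2d_partition above.
func = relay.fromtext("""
v0.0.4
fn (%x: Tensor[(1, 4, 16, 16), float32],
    %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] {
  nn.conv2d(%x, %w, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3])
}
""")
mod = tvm.IRModule.from_expr(func)
params = {'w': tvm.nd.array(np.random.uniform(size=(4, 4, 3, 3)).astype('float32'))}

# With partition_conversions=True, quantize() returns a module containing
# `quantize_inputs`, `quantized_main`, `dequantize_outputs`, and a fused `main`.
with relay.quantize.qconfig(skip_conv_layers=[], partition_conversions=True):
    partitioned_mod = relay.quantize.quantize(mod, params)

# The core quantized function should only use quantized dtypes.
print(relay.analysis.all_dtypes(partitioned_mod['quantized_main']))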