diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 165e39a09c1b3..5471494a0043c 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -23,6 +23,9 @@ from tvm.ir import IRModule from tvm.relay import transform, build_module from tvm.runtime.ndarray import cpu +# TODO(weberlo) remove when we port dtype collectors to C++ +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.type_functor import TypeVisitor from . import _ffi_api from .feature import Feature @@ -236,6 +239,50 @@ def all_type_vars(expr, mod=None): return _ffi_api.all_type_vars(expr, use_mod) +class TyDtypeCollector(TypeVisitor): + """Pass that collects data types used in the visited type.""" + + def __init__(self): + TypeVisitor.__init__(self) + self.dtypes = set() + + def visit_tensor_type(self, tt): + self.dtypes.add(tt.dtype) + + +class ExprDtypeCollector(ExprVisitor): + """Pass that collects data types used in all types in the visited expression.""" + + def __init__(self): + ExprVisitor.__init__(self) + self.ty_visitor = TyDtypeCollector() + + def visit(self, expr): + if hasattr(expr, 'checked_type'): + self.ty_visitor.visit(expr.checked_type) + elif hasattr(expr, 'type_annotation'): + self.ty_visitor.visit(expr.type_annotation) + ExprVisitor.visit(self, expr) + + +def all_dtypes(expr): + """Collect set of all data types used in `expr`. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + ret : Set[String] + Set of data types used in the expression + """ + dtype_collector = ExprDtypeCollector() + dtype_collector.visit(expr) + return dtype_collector.ty_visitor.dtypes + + def collect_device_info(expr): """Collect the device allocation map for the given expression. The device ids are propagated from the `device_copy` operators. diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py new file mode 100644 index 0000000000000..a71e8a1fd28a6 --- /dev/null +++ b/python/tvm/relay/quantize/_partition_conversions.py @@ -0,0 +1,364 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument, not-context-manager +"""Utilities for partitioning input quantization and output dequantization expressions.""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.type_functor import TypeMutator + +# operators that are allowed in prefix/suffix partitions, because they are used +# to quantize/dequantize +ALLOWED_CONVERSION_OPS = ['add', 'multiply', 'right_shift', 'clip', 'round', 'cast'] + +def partition_conversions(mod, quantized_dtypes): + """Partition mod into input quantization, core quantized inference, and output dequantization. + + The resulting module includes an additional `main` that fuses all three + partitions together. + + Parameters + ---------- + mod : tvm.IRModule + Quantized module to partition + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization (`quantize_inputs`), core + quantized inference (`quantized_main`), output dequantization + (`dequantize_outputs`), and full quantized inference functions + """ + # Partitioning is implemented as in the diagram below: + # + # +----------------------------+ + # |Quantized Inference Function| + # +--------------+-------------+ + # | + # partition_prefix + # | + # +-----+-------------------------+ + # | | + # +--------v---------+ +-----------------v------------------+ + # |Input Quantization| |Rest of Quantized Inference Function| + # +------------------+ +-----------------+------------------+ + # | + # partition_suffix + # | + # +------+---------------------+ + # | | + # +------------------+ +----------v------------+ +-----------v---------+ + # |Input Quantization| |Core Quantized Function| |Output Dequantization| + # +------------------+ +-----------------------+ +---------------------+ + # + # The final module contains all three partitions, as well as a + # `main` function that composes these three functions (depicted below). + # + # +--------------------+-------------------------+-----------------------+ + # | Input Quantization | Core Quantized Function | Output Dequantization | + # +--------------------+-------------------------+-----------------------+ + assert len(mod.functions) == 1 + pre_mod, mid_mod = partition_prefix(mod, quantized_dtypes) + mid_mod, post_mod = partition_suffix(mid_mod, quantized_dtypes) + assert has_only_conversion_ops(pre_mod['main']) + assert relay.analysis.all_dtypes(mid_mod['main']).issubset(quantized_dtypes) + assert has_only_conversion_ops(post_mod['main']) + return fuse_partitions(pre_mod, mid_mod, post_mod) + + +def fuse_partitions(pre_mod, mid_mod, post_mod): + """Combine prefix, middle, and suffix modules into a single module. + + The combined module includes an additional `main` that fuses all three + partitions together. + + Parameters + ---------- + pre_mod : tvm.IRModule + Module containing an input quantization function + + mid_mod : tvm.IRModule + Module containing core of a quantized inference function + + post_mod : tvm.IRModule + Module containing an output dequantization function + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization, core quantized inference, + output dequantization, and full quantized inference functions + """ + pre_func = pre_mod['main'] + mid_func = mid_mod['main'] + post_func = post_mod['main'] + # create a module containing the prefix, middle, and suffix partitions + fused_mod = tvm.IRModule(functions={ + relay.GlobalVar('quantize_inputs'): pre_func, + relay.GlobalVar('quantized_main'): mid_func, + relay.GlobalVar('dequantize_outputs'): post_func, + }) + # construct a `main` that strings together the partitions, such that its + # behaviour is equivalent to `main` in an *unpartitioned* module + scope_builder = relay.ScopeBuilder() + fused_mod_main_params = [relay.Var(param.name_hint) for param in pre_func.params] + quantized_inputs = scope_builder.let('quantized_inputs', relay.Call( + fused_mod.get_global_var('quantize_inputs'), + fused_mod_main_params + )) + quantized_outputs = scope_builder.let('quantized_outputs', relay.Call( + fused_mod.get_global_var('quantized_main'), + [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))] + )) + dequantized_outputs = scope_builder.let('dequantized_outputs', relay.Call( + fused_mod.get_global_var('dequantize_outputs'), + [quantized_outputs] + )) + scope_builder.ret(dequantized_outputs) + fused_mod['main'] = relay.Function(fused_mod_main_params, scope_builder.get()) + return fused_mod + + +def with_dtype(typ, target_dtype): + """Generates a type from the given type where all dtypes are replaced with the target dtype. + + Parameters + ---------- + typ : relay.Type + Type whose dtypes are being replaced + + target_dtype : str + Target data type (e.g., 'int8') + + Returns + ------- + typ : relay.Type + Type with only `target_dtype` for dtypes + """ + class DtypeReplacer(TypeMutator): + def __init__(self, target_dtype): + TypeMutator.__init__(self) + self.target_dtype = target_dtype + + def visit_tensor_type(self, tt): + return relay.TensorType(tt.shape, self.target_dtype) + + return DtypeReplacer(target_dtype).visit(typ) + + +class PrefixCutter(ExprMutator): + """A mutator for extracting input quantization expressions from a function + + The result of `visit` is the core function, and the input quantization + expressions are stored in the `prefix_sb` scope builder. + """ + + def __init__(self, params, quantized_dtypes): + ExprMutator.__init__(self) + self.params = set(params) + self.quantized_dtypes = quantized_dtypes + self.subtree_params = set() + self.new_func_params = [] + self.prefix_sb = relay.ScopeBuilder() + self.prefix_binding_map = {} + + def visit_var(self, var): + if var in self.params: + self.subtree_params.add(var) + return var + + def visit_call(self, call): + # TODO(weberlo) use graph pattern matching? + if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS: + new_args = [] + for arg in call.args: + new_arg = self.visit(arg) + if len(self.subtree_params) == 0: + new_args.append(new_arg) + else: + assert len(self.subtree_params) == 1 + param = next(iter(self.subtree_params)) + pre_param = self.prefix_sb.let(param.name_hint, new_arg) + self.subtree_params.clear() + mid_param = relay.Var( + param.name_hint, + with_dtype(param.type_annotation, arg.checked_type.dtype)) + self.prefix_binding_map[mid_param] = pre_param + # return new parameter, then we can use + # relay.analysis.free_vars at the end of the pass to generate + # new `mid_func` type signature + new_args.append(mid_param) + return relay.Call(call.op, new_args, call.attrs) + + return super().visit_call(call) + + +def partition_prefix(mod, quantized_dtypes): + """Extract input quantization expressions from `mod['main']`. + + Parameters + ---------- + mod : tvm.IRModule + Module containing a quantized inference function + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + pre_mod : tvm.IRModule + Module containing the input quantization function + + mid_mod : tvm.IRModule + Module containing a function with everything except for input quantization + """ + assert len(mod.functions) == 1 + func = mod['main'] + prefix_cutter = PrefixCutter(func.params, quantized_dtypes) + mid_body = prefix_cutter.visit(func.body) + assert not func.type_params, 'unimplemented' + assert func.attrs is None, 'unimplemented' + mid_func = relay.Function( + relay.analysis.free_vars(mid_body), + mid_body) + mid_mod = tvm.IRModule.from_expr(mid_func) + + scope_builder = prefix_cutter.prefix_sb + # make sure we pass through all inputs in the prefix function's return expr + # (even those that don't require quantization) + ret_expr = [] + for param in mid_func.params: + if param in prefix_cutter.prefix_binding_map: + # this param required a conversion, so we collected it in the + # prefix cutter pass, and we can use the pass's mapping from mid + # func params to pre func params + ret_expr.append(prefix_cutter.prefix_binding_map[param]) + else: + # there was no detected conversion for this argument, so we thread + # it through the prefix function untouched + ret_expr.append(relay.Var(param.name_hint, param.checked_type)) + ret_expr = relay.Tuple(ret_expr) + scope_builder.ret(ret_expr) + pre_func_body = scope_builder.get() + pre_func = relay.Function(relay.analysis.free_vars(pre_func_body), pre_func_body) + pre_mod = tvm.IRModule.from_expr(pre_func) + + return pre_mod, mid_mod + + +class SuffixCutter(ExprMutator): + """A mutator for extracting output dequantization expressions from a function + + The result of `visit` is a function containing the output dequantization + expressions, and the middle of the function is stored in `mid_body`. + """ + + def __init__(self, quantized_dtypes): + ExprMutator.__init__(self) + self.mid_body = None + self.quantized_dtypes = quantized_dtypes + + def visit(self, expr): + if hasattr(expr, 'checked_type') and expr.checked_type.dtype in self.quantized_dtypes: + self.mid_body = expr + return relay.Var('input', expr.checked_type) + + return super().visit(expr) + + +def partition_suffix(mod, quantized_dtypes): + """Extract output dequantization expressions from `mod['main']`. + + Parameters + ---------- + mod : tvm.IRModule + Module containing a quantized inference function + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + pre_mod : tvm.IRModule + Module containing the input quantization function + + mid_mod : tvm.IRModule + Module containing a function with everything except for input quantization + """ + assert len(mod.functions) == 1 + func = mod['main'] + suffix_cutter = SuffixCutter(quantized_dtypes) + post_body = suffix_cutter.visit(func.body) + assert not func.type_params, 'unimplemented' + assert func.attrs is None, 'unimplemented' + post_func = relay.Function( + relay.analysis.free_vars(post_body), + post_body, + func.ret_type) + post_mod = tvm.IRModule.from_expr(post_func) + + mid_body = suffix_cutter.mid_body + if mid_body is None: + # The suffix contains the entire function, meaning there was no + # quantization boundary in the given mod. In this case, we use the + # suffix mod as the middle mod and make the suffix an identity function. + mid_mod = post_mod + post_body = relay.Var('input', mid_mod['main'].ret_type) + post_func = relay.Function( + [post_body], + post_body) + post_mod = tvm.IRModule.from_expr(post_func) + else: + mid_func = relay.Function( + func.params, + mid_body) + mid_mod = tvm.IRModule.from_expr(mid_func) + + return mid_mod, post_mod + + +class ConversionOpChecker(ExprVisitor): + """A pass for checking that the visited function contains only conversion ops""" + def __init__(self): + ExprVisitor.__init__(self) + self.valid = True + + def visit_call(self, call): + if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS: + self.valid = False + super().visit_call(call) + + +def has_only_conversion_ops(func): + """Return true iff the given function contains only quantization/dequantization ops. + + Parameters + ---------- + func : relay.Function + Function being checked + + Returns + ------- + valid : bool + Whether the function contains only conversion ops + """ + checker = ConversionOpChecker() + checker.visit(func) + return checker.valid diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 28ebf7f3032bd..d1e6b42a8a630 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,6 +22,7 @@ from . import _quantize from ._calibrate import calibrate +from ._partition_conversions import partition_conversions from .. import expr as _expr from .. import transform as _transform @@ -85,6 +86,7 @@ class QConfig(Object): "debug_enabled_ops": None, "rounding": "UPWARD", "calibrate_chunk_by": -1, + "partition_conversions": False, } # pylint: disable=no-member @@ -179,6 +181,15 @@ def qconfig(**kwargs): rounding: "UPWARD" or "TONEAREST" Rounding direction for fixed point multiplications. + partition_conversions: bool + Whether to partition a quantized result into a module containing + a prefix function (consisting of input conversion into the quantized data space), + a middle function (consisting of the core quantized network), + a suffix function (consisting of output dequantization), + and a main function (that calls the prefix, middle, and suffix functions in succession). + If there are unquantized operators in the core network, an exception is raised. + The default value is `False`. + Returns ------- config: QConfig @@ -359,4 +370,9 @@ def quantize(mod, params=None, dataset=None): with quantize_context(): mod = quantize_seq(mod) + q_cfg = current_qconfig() + if q_cfg.partition_conversions: + quantized_dtypes = {q_cfg.dtype_input, q_cfg.dtype_weight, q_cfg.dtype_activation} + return partition_conversions(mod, quantized_dtypes) + return mod diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index cd8dc8dcd3096..86ff805b1d34d 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -115,7 +115,7 @@ def let(self, var, value): The variable or name of variable. value: tvm.relay.Expr - The value to be binded + The value to be bound """ if isinstance(var, (tuple, list)): if len(var) > 2: diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 1bf858b43db02..28fc68ebf0a8a 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -125,7 +125,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; - p->stream << "rounding==" << op->rounding; + p->stream << "rounding==" << op->rounding << ", "; + p->stream << "partition_conversions==" << op->partition_conversions; p->stream << ")"; }); diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 86f8926c98ac7..9b9281e74ea5c 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -74,6 +74,7 @@ class QConfigNode : public Object { Array debug_enabled_ops = Array(ObjectPtr(nullptr)); std::string rounding = "UPWARD"; int calibrate_chunk_by = -1; + bool partition_conversions = false; void VisitAttrs(AttrVisitor* v) { v->Visit("nbit_input", &nbit_input); @@ -92,6 +93,7 @@ class QConfigNode : public Object { v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("rounding", &rounding); v->Visit("calibrate_chunk_by", &calibrate_chunk_by); + v->Visit("partition_conversions", &partition_conversions); } static constexpr const char* _type_key = "relay.quantize.QConfig"; diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index f7427974904ce..697b34f92ff7e 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.relay import testing from tvm.relay.expr import Call +from tvm.topi.util import get_const_tuple def quantize_and_build(out): @@ -103,9 +104,216 @@ def test_calibrate_memory_bound(): relay.quantize.quantize(mod, params, dataset) +#################################### +# Quant/Dequant Partitioning Tests # +#################################### + +BASE_CFG = { + 'skip_conv_layers': [], + 'skip_dense_layers': False, + 'dtype_input': "int8", + 'dtype_weight': "int8", + 'dtype_activation': "int32", +} + +def gen_rand_tvm(tt, low, high): + if 'int' in tt.dtype: + data_np = np.random.randint(low, high, size=get_const_tuple(tt.shape), dtype=tt.dtype) + elif 'float' in tt.dtype: + data_np = np.random.uniform(low, high, size=get_const_tuple(tt.shape)).astype(tt.dtype) + else: + assert False, 'unknown dtype' + return tvm.nd.array(data_np, ctx=tvm.cpu(0)) + + +def verify_partition_fails(mod, params): + try: + with relay.quantize.qconfig(**BASE_CFG, partition_conversions=True): + partitioned_mod = relay.quantize.quantize(mod, params) + raise RuntimeError('partitioning should have failed') + except AssertionError: + pass + + +def verify_partition(mod, params): + with relay.quantize.qconfig(**BASE_CFG): + unpartitioned_mod = relay.quantize.quantize(mod, params) + with relay.quantize.qconfig(**BASE_CFG, partition_conversions=True): + partitioned_mod = relay.quantize.quantize(mod, params) + + # ensure the quantized core indeed only consists of quantized dtypes + q_dtypes = set([ + BASE_CFG['dtype_input'], + BASE_CFG['dtype_weight'], + BASE_CFG['dtype_activation'] + ]) + q_main_dtypes = relay.analysis.all_dtypes(partitioned_mod['quantized_main']) + assert q_main_dtypes.issubset(q_dtypes) + + # ensure results of `partition_result=False` and `partition_result=True` agree + params = [ + gen_rand_tvm(param.type_annotation, 0, 1) + for param in partitioned_mod['main'].params + ] + def _eval_mod(mod): + vm = relay.create_executor('vm', ctx=tvm.cpu(0), target='llvm', mod=mod) + return vm.evaluate()(*params) + partitioned_mod_result = _eval_mod(partitioned_mod) + unpartitioned_mod_result = _eval_mod(unpartitioned_mod) + tvm.testing.assert_allclose( + unpartitioned_mod_result.asnumpy(), partitioned_mod_result.asnumpy()) + + +def test_add_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x: Tensor[(10, 10), float32], + %y: Tensor[(10, 10), float32]) { + add(%x, %y) + } + """) + mod = tvm.IRModule.from_expr(func) + params = {} + verify_partition_fails(mod, params) + + +def test_conv2d_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x: Tensor[(1, 4, 16, 16), float32], + %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + nn.conv2d(%x, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + mod = tvm.IRModule.from_expr(func) + weight_ty = mod['main'].params[1].checked_type + params = { + 'w': gen_rand_tvm(weight_ty, 0, 1) + } + verify_partition(mod, params) + + +def test_multiple_arg_conversions_partition(): + mod = tvm.IRModule.from_expr(relay.fromtext(""" + v0.0.4 + fn (%x1: Tensor[(1, 4, 16, 16), float32], + %w1: Tensor[(4, 4, 3, 3), float32], + %x2: Tensor[(1, 4, 16, 16), float32], + %w2: Tensor[(4, 4, 3, 3), float32] + ) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x1, %w1, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + %1 = nn.conv2d(%x2, %w2, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + add(%0, %1) + } + """)) + + w1_ty = mod['main'].params[1].checked_type + w2_ty = mod['main'].params[3].checked_type + params = { + 'w1': gen_rand_tvm(w1_ty, 0, 1), + 'w2': gen_rand_tvm(w2_ty, 0, 1) + } + verify_partition(mod, params) + + +def test_unquantizable_prefix_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x: Tensor[(1, 4, 16, 16), float32], + %b: Tensor[(4), float32], + %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + // NOTE bias_add isn't currently quantizable + %0 = nn.bias_add(%x, %b); + nn.conv2d(%0, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + mod = tvm.IRModule.from_expr(func) + bias_ty = mod['main'].params[1].checked_type + weight_ty = mod['main'].params[2].checked_type + params = { + 'b': gen_rand_tvm(bias_ty, 0, 1), + 'w': gen_rand_tvm(weight_ty, 0, 1) + } + verify_partition_fails(mod, params) + + +def test_unquantizable_core_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x1: Tensor[(1, 4, 16, 16), float32], + %w1: Tensor[(4, 4, 3, 3), float32], + %b: Tensor[(4), float32], + %w2: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x1, %w1, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + %1 = nn.bias_add(%0, %b); + nn.conv2d(%1, %w2, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + mod = tvm.IRModule.from_expr(func) + w1_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + w2_ty = mod['main'].params[3].checked_type + params = { + 'w1': gen_rand_tvm(w1_ty, 0, 1), + 'w2': gen_rand_tvm(w2_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + +def test_unquantizable_suffix_partition(): + func = relay.fromtext(""" + v0.0.4 + fn (%x: Tensor[(1, 4, 16, 16), float32], + %w: Tensor[(4, 4, 3, 3), float32], + %b: Tensor[(4), float32]) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + nn.bias_add(%0, %b) + } + """) + mod = tvm.IRModule.from_expr(func) + weight_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + params = { + 'w': gen_rand_tvm(weight_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + if __name__ == "__main__": test_mul_rewrite() test_batch_flatten_rewrite() test_calibrate_target(False) test_calibrate_target(True) test_calibrate_memory_bound() + + test_add_partition() + test_conv2d_partition() + test_multiple_arg_conversions_partition() + test_unquantizable_prefix_partition() + test_unquantizable_core_partition() + test_unquantizable_suffix_partition()