diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 165e39a09c1b3..99f4252ac4f7d 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -236,6 +236,22 @@ def all_type_vars(expr, mod=None): return _ffi_api.all_type_vars(expr, use_mod) +def all_dtypes(expr): + """Collect set of all data types used in `expr`. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + ret : Set[String] + Set of data types used in the expression (e.g., `{'int8', 'int32'}`) + """ + return set(_ffi_api.all_dtypes(expr)) + + def collect_device_info(expr): """Collect the device allocation map for the given expression. The device ids are propagated from the `device_copy` operators. diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py new file mode 100644 index 0000000000000..d1c3b59118841 --- /dev/null +++ b/python/tvm/relay/quantize/_partition_conversions.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#pylint: disable=unused-argument, not-context-manager +"""Utilities for partitioning input quantization and output dequantization expressions.""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator, ExprVisitor + +# operators that are allowed in prefix/suffix partitions, because they are used +# to quantize/dequantize +ALLOWED_CONVERSION_OPS = ['add', 'multiply', 'right_shift', 'clip', 'round', 'cast'] + +def partition_conversions(mod, quantized_dtypes, ensure_fully_integral): + """Partition mod into input quantization, core quantized inference, and output dequantization. + + The resulting module includes an additional `main` that fuses all three + partitions together. + + Parameters + ---------- + mod : tvm.IRModule + Quantized module to partition + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + ensure_fully_integral : bool + Whether to raise an exception if there are unquantized operators in the result + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization (`quantize_inputs`), core + quantized inference (`quantized_main`), output dequantization + (`dequantize_outputs`), and full quantized inference functions + """ + # Partitioning is implemented as in the diagram below: + # + # +----------------------------+ + # |Quantized Inference Function| + # +--------------+-------------+ + # | + # partition_prefix + # | + # +-----+-------------------------+ + # | | + # +--------v---------+ +-----------------v------------------+ + # |Input Quantization| |Rest of Quantized Inference Function| + # +------------------+ +-----------------+------------------+ + # | + # partition_suffix + # | + # +------+---------------------+ + # | | + # +------------------+ +----------v------------+ +-----------v---------+ + # |Input Quantization| |Core Quantized Function| |Output Dequantization| + # +------------------+ +-----------------------+ +---------------------+ + # + # The final 
module contains all three partitions, as well as a + # `main` function that composes these three functions (depicted below). + # + # +--------------------+-------------------------+-----------------------+ + # | Input Quantization | Core Quantized Function | Output Dequantization | + # +--------------------+-------------------------+-----------------------+ + assert len(mod.functions) == 1 + pre_mod, mid_mod = partition_prefix(mod, quantized_dtypes) + mid_mod, post_mod = partition_suffix(mid_mod, quantized_dtypes) + if ensure_fully_integral: + assert has_only_conversion_ops(pre_mod['main']) + assert relay.analysis.all_dtypes(mid_mod['main']).issubset(quantized_dtypes) + assert has_only_conversion_ops(post_mod['main']) + return fuse_partitions(pre_mod, mid_mod, post_mod) + + +def fuse_partitions(pre_mod, mid_mod, post_mod): + """Combine prefix, middle, and suffix modules into a single module. + + The combined module includes an additional `main` that fuses all three + partitions together. 
+ + Parameters + ---------- + pre_mod : tvm.IRModule + Module containing an input quantization function + + mid_mod : tvm.IRModule + Module containing core of a quantized inference function + + post_mod : tvm.IRModule + Module containing an output dequantization function + + Returns + ------- + fused_mod : tvm.IRModule + Module containing the input quantization, core quantized inference, + output dequantization, and full quantized inference functions + """ + pre_func = pre_mod['main'] + mid_func = mid_mod['main'] + post_func = post_mod['main'] + # create a module containing the prefix, middle, and suffix partitions + fused_mod = tvm.IRModule(functions={ + relay.GlobalVar('quantize_inputs'): pre_func, + relay.GlobalVar('quantized_main'): mid_func, + relay.GlobalVar('dequantize_outputs'): post_func, + }) + # construct a `main` that strings together the partitions, such that its + # behaviour is equivalent to `main` in an *unpartitioned* module + scope_builder = relay.ScopeBuilder() + fused_mod_main_params = [relay.Var(param.name_hint) for param in pre_func.params] + quantized_inputs = scope_builder.let('quantized_inputs', relay.Call( + fused_mod.get_global_var('quantize_inputs'), + fused_mod_main_params + )) + quantized_outputs = scope_builder.let('quantized_outputs', relay.Call( + fused_mod.get_global_var('quantized_main'), + [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))] + )) + dequantized_outputs = scope_builder.let('dequantized_outputs', relay.Call( + fused_mod.get_global_var('dequantize_outputs'), + [quantized_outputs] + )) + scope_builder.ret(dequantized_outputs) + fused_mod['main'] = relay.Function(fused_mod_main_params, scope_builder.get()) + return fused_mod + + +class PrefixCutter(ExprMutator): + """A mutator for extracting input quantization expressions from a function + + The result of `visit` is the core function, and the input quantization + expressions are stored in the `prefix_sb` scope builder. 
+ """ + + def __init__(self, params, quantized_dtypes): + ExprMutator.__init__(self) + self.params = set(params) + self.quantized_dtypes = quantized_dtypes + self.subtree_params = set() + self.new_func_params = [] + self.prefix_sb = relay.ScopeBuilder() + self.prefix_binding_map = {} + + def visit_var(self, var): + if var in self.params: + self.subtree_params.add(var) + return var + + def visit_call(self, call): + # TODO(weberlo) use graph pattern matching? + if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS: + new_args = [] + for arg in call.args: + new_arg = self.visit(arg) + if len(self.subtree_params) == 0: + new_args.append(new_arg) + else: + assert len(self.subtree_params) == 1 + param = next(iter(self.subtree_params)) + pre_param = self.prefix_sb.let(param.name_hint, new_arg) + self.subtree_params.clear() + mid_param = relay.Var( + param.name_hint, + arg.checked_type) + self.prefix_binding_map[mid_param] = pre_param + # return new parameter, then we can use + # relay.analysis.free_vars at the end of the pass to generate + # new `mid_func` type signature + new_args.append(mid_param) + return relay.Call(call.op, new_args, call.attrs) + + return super().visit_call(call) + + +def partition_prefix(mod, quantized_dtypes): + """Extract input quantization expressions from `mod['main']`. 
+ + Parameters + ---------- + mod : tvm.IRModule + Module containing a quantized inference function + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + pre_mod : tvm.IRModule + Module containing the input quantization function + + mid_mod : tvm.IRModule + Module containing a function with everything except for input quantization + """ + assert len(mod.functions) == 1 + func = mod['main'] + prefix_cutter = PrefixCutter(func.params, quantized_dtypes) + mid_body = prefix_cutter.visit(func.body) + assert not func.type_params, 'unimplemented' + assert func.attrs is None, 'unimplemented' + mid_func = relay.Function( + relay.analysis.free_vars(mid_body), + mid_body) + mid_mod = tvm.IRModule.from_expr(mid_func) + + scope_builder = prefix_cutter.prefix_sb + # make sure we pass through all inputs in the prefix function's return expr + # (even those that don't require quantization) + ret_expr = [] + for param in mid_func.params: + if param in prefix_cutter.prefix_binding_map: + # this param required a conversion, so we collected it in the + # prefix cutter pass, and we can use the pass's mapping from mid + # func params to pre func params + ret_expr.append(prefix_cutter.prefix_binding_map[param]) + else: + # there was no detected conversion for this argument, so we thread + # it through the prefix function untouched + ret_expr.append(relay.Var(param.name_hint, param.checked_type)) + ret_expr = relay.Tuple(ret_expr) + scope_builder.ret(ret_expr) + pre_func_body = scope_builder.get() + pre_func = relay.Function(relay.analysis.free_vars(pre_func_body), pre_func_body) + pre_mod = tvm.IRModule.from_expr(pre_func) + + return pre_mod, mid_mod + + +class SuffixCutter(ExprMutator): + """A mutator for extracting output dequantization expressions from a function + + The result of `visit` is a function containing the output dequantization + expressions, and the middle of the function is stored in `mid_body`. 
+ """ + + def __init__(self, quantized_dtypes): + ExprMutator.__init__(self) + self.mid_body = None + self.quantized_dtypes = quantized_dtypes + + def visit(self, expr): + if hasattr(expr, 'checked_type') and expr.checked_type.dtype in self.quantized_dtypes: + self.mid_body = expr + return relay.Var('input', expr.checked_type) + + return super().visit(expr) + + +def partition_suffix(mod, quantized_dtypes): + """Extract output dequantization expressions from `mod['main']`. + + Parameters + ---------- + mod : tvm.IRModule + Module containing a quantized inference function + + quantized_dtypes : Set[str] + Set of data types allowed in quantized operators + + Returns + ------- + mid_mod : tvm.IRModule + Module containing a function with everything except for output dequantization + + post_mod : tvm.IRModule + Module containing the output dequantization function + """ + assert len(mod.functions) == 1 + func = mod['main'] + suffix_cutter = SuffixCutter(quantized_dtypes) + post_body = suffix_cutter.visit(func.body) + assert not func.type_params, 'unimplemented' + assert func.attrs is None, 'unimplemented' + post_func = relay.Function( + relay.analysis.free_vars(post_body), + post_body, + func.ret_type) + post_mod = tvm.IRModule.from_expr(post_func) + + mid_body = suffix_cutter.mid_body + if mid_body is None: + # The suffix contains the entire function, meaning there was no + # quantization boundary in the given mod. In this case, we use the + # suffix mod as the middle mod and make the suffix an identity function. 
+ mid_mod = post_mod + post_body = relay.Var('input', mid_mod['main'].ret_type) + post_func = relay.Function( + [post_body], + post_body) + post_mod = tvm.IRModule.from_expr(post_func) + else: + mid_func = relay.Function( + func.params, + mid_body) + mid_mod = tvm.IRModule.from_expr(mid_func) + + return mid_mod, post_mod + + +class ConversionOpChecker(ExprVisitor): + """A pass for checking that the visited function contains only conversion ops""" + def __init__(self): + ExprVisitor.__init__(self) + self.valid = True + + def visit_call(self, call): + if not hasattr(call.op, 'name') or call.op.name not in ALLOWED_CONVERSION_OPS: + self.valid = False + super().visit_call(call) + + +def has_only_conversion_ops(func): + """Return true iff the given function contains only quantization/dequantization ops. + + Parameters + ---------- + func : relay.Function + Function being checked + + Returns + ------- + valid : bool + Whether the function contains only conversion ops + """ + checker = ConversionOpChecker() + checker.visit(func) + return checker.valid diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 28ebf7f3032bd..8a8c82c7e6d69 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,6 +22,7 @@ from . import _quantize from ._calibrate import calibrate +from ._partition_conversions import partition_conversions from .. import expr as _expr from .. import transform as _transform @@ -85,6 +86,7 @@ class QConfig(Object): "debug_enabled_ops": None, "rounding": "UPWARD", "calibrate_chunk_by": -1, + "partition_conversions": "disabled", } # pylint: disable=no-member @@ -179,6 +181,17 @@ def qconfig(**kwargs): rounding: "UPWARD" or "TONEAREST" Rounding direction for fixed point multiplications. 
+ partition_conversions: 'disabled', 'enabled', or 'fully_integral' + If set to 'enabled' or 'fully_integral', partitions a quantized + result into a module containing + a prefix function (consisting of input conversion into the quantized data space), + a middle function (consisting of the core quantized network), + a suffix function (consisting of output dequantization), + and a main function (that calls the prefix, middle, and suffix functions in succession). + If set to 'fully_integral' and there are unquantized operators in the result, + an exception is raised. + The default value is 'disabled'. + Returns ------- config: QConfig @@ -359,4 +372,11 @@ def quantize(mod, params=None, dataset=None): with quantize_context(): mod = quantize_seq(mod) + q_cfg = current_qconfig() + assert q_cfg.partition_conversions in ['disabled', 'enabled', 'fully_integral'] + if q_cfg.partition_conversions != 'disabled': + quantized_dtypes = {q_cfg.dtype_input, q_cfg.dtype_weight, q_cfg.dtype_activation} + ensure_fully_integral = q_cfg.partition_conversions == 'fully_integral' + return partition_conversions(mod, quantized_dtypes, ensure_fully_integral) + return mod diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index cd8dc8dcd3096..86ff805b1d34d 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -115,7 +115,7 @@ def let(self, var, value): The variable or name of variable. 
value: tvm.relay.Expr - The value to be binded + The value to be bound """ if isinstance(var, (tuple, list)): if len(var) > 2: diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index b1c512478072f..b98106a091b31 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -307,6 +307,35 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TV } }); +class DtypeCollector : protected ExprVisitor, protected TypeVisitor { + public: + void VisitExpr(const Expr& expr) final { + if (expr->checked_type_.defined()) { + TypeVisitor::VisitType(expr->checked_type()); + } + ExprVisitor::VisitExpr(expr); + } + + void VisitType_(const TensorTypeNode* op) final { dtypes_.insert(DLDataType2String(op->dtype)); } + + Array<String> All(const Expr& expr) { + VisitExpr(expr); + + Array<String> res; + for (const auto& dtype : dtypes_) { + res.push_back(String(dtype)); + } + return res; + } + + private: + std::unordered_set<std::string> dtypes_; +}; + +tvm::Array<String> AllDtypes(const Expr& expr) { return DtypeCollector().All(expr); } + +TVM_REGISTER_GLOBAL("relay.analysis.all_dtypes").set_body_typed(AllDtypes); + /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. 
diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 1bf858b43db02..28fc68ebf0a8a 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -125,7 +125,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; - p->stream << "rounding==" << op->rounding; + p->stream << "rounding==" << op->rounding << ", "; + p->stream << "partition_conversions==" << op->partition_conversions; p->stream << ")"; }); diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 86f8926c98ac7..d5396dea00d18 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -74,6 +74,7 @@ class QConfigNode : public Object { Array debug_enabled_ops = Array(ObjectPtr(nullptr)); std::string rounding = "UPWARD"; int calibrate_chunk_by = -1; + std::string partition_conversions = "disabled"; void VisitAttrs(AttrVisitor* v) { v->Visit("nbit_input", &nbit_input); @@ -92,6 +93,7 @@ class QConfigNode : public Object { v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("rounding", &rounding); v->Visit("calibrate_chunk_by", &calibrate_chunk_by); + v->Visit("partition_conversions", &partition_conversions); } static constexpr const char* _type_key = "relay.quantize.QConfig"; diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index da5291f5e927a..bf9452a9db0ab 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.relay import testing from tvm.relay.expr import Call +from tvm.topi.util import get_const_tuple def quantize_and_build(out): @@ -104,9 +105,214 @@ def test_calibrate_memory_bound(): relay.quantize.quantize(mod, params, dataset) 
+#################################### +# Quant/Dequant Partitioning Tests # +#################################### + +BASE_CFG = { + 'skip_conv_layers': [], + 'skip_dense_layers': False, + 'dtype_input': "int8", + 'dtype_weight': "int8", + 'dtype_activation': "int32", +} + +def gen_rand_tvm(tt, low, high): + if 'int' in tt.dtype: + data_np = np.random.randint(low, high, size=get_const_tuple(tt.shape), dtype=tt.dtype) + elif 'float' in tt.dtype: + data_np = np.random.uniform(low, high, size=get_const_tuple(tt.shape)).astype(tt.dtype) + else: + assert False, 'unknown dtype' + return tvm.nd.array(data_np, ctx=tvm.cpu(0)) + + +def verify_partition_fails(mod, params): + # standard partition should always succeed + with relay.quantize.qconfig(**BASE_CFG, partition_conversions='enabled'): + partitioned_mod = relay.quantize.quantize(mod, params) + + try: + with relay.quantize.qconfig(**BASE_CFG, partition_conversions='fully_integral'): + partitioned_mod = relay.quantize.quantize(mod, params) + raise RuntimeError('partitioning should have failed') + except AssertionError: + pass + + +def verify_partition(mod, params): + with relay.quantize.qconfig(**BASE_CFG, partition_conversions='disabled'): + unpartitioned_mod = relay.quantize.quantize(mod, params) + assert len(unpartitioned_mod.get_global_vars()) == 1, \ + 'unpartitioned module should only have one function' + with relay.quantize.qconfig(**BASE_CFG, partition_conversions='fully_integral'): + partitioned_mod = relay.quantize.quantize(mod, params) + + # ensure partitioned and unpartitioned results agree + params = [ + gen_rand_tvm(param.type_annotation, 0, 1) + for param in partitioned_mod['main'].params + ] + def _eval_mod(mod): + vm = relay.create_executor('vm', ctx=tvm.cpu(0), target='llvm', mod=mod) + return vm.evaluate()(*params) + partitioned_mod_result = _eval_mod(partitioned_mod) + unpartitioned_mod_result = _eval_mod(unpartitioned_mod) + tvm.testing.assert_allclose( + unpartitioned_mod_result.asnumpy(), 
partitioned_mod_result.asnumpy()) + + +def test_add_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x: Tensor[(10, 10), float32], + %y: Tensor[(10, 10), float32]) { + add(%x, %y) + } + """) + params = {} + verify_partition_fails(mod, params) + + +def test_conv2d_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x: Tensor[(1, 4, 16, 16), float32], + %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + nn.conv2d(%x, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + weight_ty = mod['main'].params[1].checked_type + params = { + 'w': gen_rand_tvm(weight_ty, 0, 1) + } + verify_partition(mod, params) + + +def test_multiple_arg_conversions_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x1: Tensor[(1, 4, 16, 16), float32], + %w1: Tensor[(4, 4, 3, 3), float32], + %x2: Tensor[(1, 4, 16, 16), float32], + %w2: Tensor[(4, 4, 3, 3), float32] + ) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x1, %w1, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + %1 = nn.conv2d(%x2, %w2, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + add(%0, %1) + } + """) + + w1_ty = mod['main'].params[1].checked_type + w2_ty = mod['main'].params[3].checked_type + params = { + 'w1': gen_rand_tvm(w1_ty, 0, 1), + 'w2': gen_rand_tvm(w2_ty, 0, 1) + } + verify_partition(mod, params) + + +def test_unquantizable_prefix_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x: Tensor[(1, 4, 16, 16), float32], + %b: Tensor[(4), float32], + %w: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + // NOTE bias_add isn't currently quantizable + %0 = nn.bias_add(%x, %b); + nn.conv2d(%0, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + bias_ty = mod['main'].params[1].checked_type + weight_ty = mod['main'].params[2].checked_type + params = { + 'b': gen_rand_tvm(bias_ty, 0, 
1), + 'w': gen_rand_tvm(weight_ty, 0, 1) + } + verify_partition_fails(mod, params) + + +def test_unquantizable_core_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x1: Tensor[(1, 4, 16, 16), float32], + %w1: Tensor[(4, 4, 3, 3), float32], + %b: Tensor[(4), float32], + %w2: Tensor[(4, 4, 3, 3), float32]) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x1, %w1, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + %1 = nn.bias_add(%0, %b); + nn.conv2d(%1, %w2, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]) + } + """) + w1_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + w2_ty = mod['main'].params[3].checked_type + params = { + 'w1': gen_rand_tvm(w1_ty, 0, 1), + 'w2': gen_rand_tvm(w2_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + +def test_unquantizable_suffix_partition(): + mod = tvm.parser.parse(""" + #[version = "0.0.5"] + def @main( + %x: Tensor[(1, 4, 16, 16), float32], + %w: Tensor[(4, 4, 3, 3), float32], + %b: Tensor[(4), float32]) -> Tensor[(1, 4, 16, 16), float32] { + %0 = nn.conv2d(%x, %w, + padding=[1, 1, 1, 1], + channels=4, + kernel_size=[3, 3]); + // NOTE bias_add isn't currently quantizable + nn.bias_add(%0, %b) + } + """) + weight_ty = mod['main'].params[1].checked_type + bias_ty = mod['main'].params[2].checked_type + params = { + 'w': gen_rand_tvm(weight_ty, 0, 1), + 'b': gen_rand_tvm(bias_ty, 0, 1) + } + verify_partition_fails(mod, params) + + if __name__ == "__main__": test_mul_rewrite() test_batch_flatten_rewrite() test_calibrate_target(False) test_calibrate_target(True) test_calibrate_memory_bound() + + test_add_partition() + test_conv2d_partition() + test_multiple_arg_conversions_partition() + test_unquantizable_prefix_partition() + test_unquantizable_core_partition() + test_unquantizable_suffix_partition()