From 304a225d9c4327206bce5e9abf30e4b1d92efa30 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 26 Dec 2022 17:18:16 +0300 Subject: [PATCH] [QNN] Change in Pass Context for lookup table calculation Motivation: It is possible to disable specific passes through the "disabled_pass" parameter in the Pass Context. These "disabled" passes can be optional for one target and mandatory for another one. Since lookup table for some QNN operations (tanh, round and etc.) is calculated on the host and some of disabled passes can be required for the host, no need to disable these passes. This constant calculation/ evaluation is orthogonal to the compilation process for specific target. What was changed: This commit creates its own compilation Pass Context for lookup table calculation and evaluation (for elemwise QNN ops: tanh, sqrt ...). --- python/tvm/relay/qnn/op/canonicalizations.py | 23 ++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/qnn/op/canonicalizations.py b/python/tvm/relay/qnn/op/canonicalizations.py index 1f2c57c6da34..6bfcd34aba90 100644 --- a/python/tvm/relay/qnn/op/canonicalizations.py +++ b/python/tvm/relay/qnn/op/canonicalizations.py @@ -23,10 +23,25 @@ def run_const_expr(expr: "relay.Expr") -> np.ndarray: - """Evaluate a const expression, receiving result as np array.""" - mod = tvm.IRModule.from_expr(expr) - vm_exe = relay.create_executor("vm", mod=mod) - return vm_exe.evaluate()().asnumpy() + """Evaluate a const expression, receiving result as np array. + + If a number of passes are disabled in the current Pass Context, then there is no need to disable + these passes for const expression evaluation as well. That's why we use empty list + "disabled_pass=[]", all other arguments are inherited from the current Pass Context. + """ + curr_pass_ctx = tvm.ir.transform.PassContext.current() + with tvm.ir.transform.PassContext( + opt_level=curr_pass_ctx.opt_level, + required_pass=curr_pass_ctx.required_pass, + disabled_pass=[], + instruments=curr_pass_ctx.instruments, + config=curr_pass_ctx.config, + ): + mod = tvm.IRModule.from_expr(expr) + vm_exe = relay.create_executor("vm", mod=mod) + output = vm_exe.evaluate()().asnumpy() + + return output def create_integer_lookup_table(