diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 0884b249df488..7666691aa19f4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -18,13 +18,13 @@ import tvm from tvm import relay +from tvm import ir from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator from tvm.relay.backend.contrib.ethosu import util from tvm.relay.expr_functor import ExprMutator -from tvm.ir.transform import Pass # pylint: disable=unused-import from tvm.relay.backend.contrib.ethosu.op import op_attrs @@ -109,13 +109,11 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") -class LUTsOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer") +class LUTsOptimizer: """Register LUTsOptimizer as a relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """Visit relay nodes in the given module. Parameters @@ -131,7 +129,13 @@ def transform_function( New module with optimized LUTs. """ assert len(mod.functions.items()) == 1, "Module can only contain one function." - return OptimizeLUTs().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = OptimizeLUTs().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass class LayoutOptimization(ExprMutator): @@ -247,19 +251,23 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return super().visit_call(call) -@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer") -class LayoutOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer") +class LayoutOptimizer: """Register LayoutOptimizer as a Relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """A pass to optimize the layout of NPU operations. If both the producer and consumer of a tensor are NPU operators, then the layout is converted from NHWC to NHCWB16 as this is the layout NPU uses internally.""" assert len(mod.functions.items()) == 1, "Module can only contain one function." - return LayoutOptimization().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = LayoutOptimization().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index aafae1497ea40..62a1fabe0b98f 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -33,14 +33,17 @@ from tvm import relay from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func from . import infra -def _run_pass(expr, relay_pass): - """Create IRModule and run Relay pass.""" +def _optimize(expr, optimize=True): + """Create IRModule and run layout optimizer pass.""" mod = tvm.IRModule.from_expr(expr) - mod = relay_pass(mod) + mod = relay.transform.InferType()(mod) + if optimize: + mod = LayoutOptimizer()(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -111,8 +114,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -144,8 +147,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -176,8 +179,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -222,8 +225,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -268,8 +271,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -322,8 +325,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(pool_3), pool_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -368,8 +371,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv), conv) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -413,8 +416,8 @@ def get_graph(get_expected=False): concat = relay.concatenate(poolings, axis=0) return relay.Function(relay.analysis.free_vars(concat), concat) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -467,8 +470,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(add_3), add_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -500,8 +503,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -530,8 +533,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -619,5 +622,32 @@ def representative_dataset(): _compile_and_compare_model(create_model(), ifm_shape, dtype) +def test_layout_optimizer_runs_in_compilation_pipeline(): + """Checks that the layout optimization pass runs as part of the NPU compilation + pipeline.""" + + def get_graph(): + x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8") + for _ in range(2): + x = relay.nn.max_pool2d(x, layout="NHWC") + + func = relay.Function(relay.analysis.free_vars(x), x) + return tvm.IRModule.from_expr(func) + + mod = get_graph() + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the layout optimization pass has ran + ops = prim_func.body.body.seq + max_pool1, max_pool2 = ops + + assert str(max_pool1.value.args[31]) == '"NHCWB16"' + assert str(max_pool2.value.args[14]) == '"NHCWB16"' + + if __name__ == "__main__": pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py index 16835ce94ed77..d9a543c1a7716 100644 --- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -21,9 +21,16 @@ pytest.importorskip("ethosu.vela") +import tensorflow as tf +import numpy as np + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.op.contrib.ethosu import partition_for_ethosu + +from .test_codegen import _get_tflite_graph from . import infra @@ -59,6 +66,7 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) @@ -91,5 +99,35 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) + + +def test_lut_optimizer_runs_in_compilation_pipeline(): + """Test that the LUT optimization pass runs as part of the NPU compilation pipeline.""" + ifm_shape = (1, 4, 4, 4) + + @tf.function + def get_graph(x): + weight1 = tf.constant(np.random.uniform(size=(1, 1, 4, 4)), dtype=tf.float32) + op = tf.nn.conv2d(x, weight1, (1, 1), "VALID") + op = tf.nn.tanh(op) + weight2 = tf.constant(np.random.uniform(size=(1, 1, 4, 1)), dtype=tf.float32) + op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID") + return tf.nn.tanh(op) + + mod, _ = _get_tflite_graph(get_graph, [ifm_shape]) + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the LUT optimization pass has ran. + # If the module was optimized, there should be no identity operations. + def check_identity(stmt): + if isinstance(stmt, tvm.tir.expr.Call): + assert stmt.args[0] != "ethosu_identity" + + tvm.tir.stmt_functor.post_order_visit(prim_func.body, check_identity)