diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 3b412cb646ca..002cb4b6be9b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -134,6 +134,134 @@ def transform_function(
         return OptimizeLUTs().visit(func)
 
 
+class LayoutOptimization(ExprMutator):
+    """A pass to optimize the layout of NPU operations. If both the
+    producer and consumer of a tensor are NPU operators, then the
+    layout is converted from NHWC to NHCWB16.
+
+    Attributes
+    ----------
+    children : Dict[tvm.relay.expr.Call, List[tvm.relay.expr.Call]]
+        A map from current call to a list of calls that rely on the current
+        call. This allows the graph to be traversed backwards, which is useful
+        for checking whether the output layouts can be rewritten.
+    optimize_op : Dict[str, Callable]
+        A map from NPU op name to function that creates NPU op.
+    """
+
+    def __init__(self):
+        self.children = {}
+        self.optimize_op = {
+            "contrib.ethosu.conv2d": op.ethosu_conv2d,
+            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
+            "contrib.ethosu.pooling": op.ethosu_pooling,
+            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
+            "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
+        }
+
+        super().__init__()
+
+    def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        """Alter the input and output layouts of an NPU operation if needed.
+        The input layout is only altered if the producing operation is an NPU
+        operation. Likewise, the output layout is only altered if the consuming
+        operation is an NPU operation.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The call pointing to an NPU operation that will be checked if
+            the layout needs altering.
+
+        Returns
+        -------
+        new_call : tvm.relay.expr.Call
+            New call with altered layouts.
+        """
+        assert isinstance(call.attrs, tvm.ir.Attrs), (
+            f"The attributes for operator '{call.op.name}' could not be "
+            "found. Did you register the relay.attrs.EthosuAttrs "
+            "object in the Python API?"
+        )
+
+        new_attrs = dict(call.attrs)
+        parents = []
+
+        # Check if we can rewrite the input layouts
+        input_count = 0
+        for arg in call.args:
+            input_count += 1
+            if not isinstance(arg, tvm.relay.expr.Call):
+                continue
+            if isinstance(arg.op, tvm.ir.op.Op) and arg.op.name in self.optimize_op:
+                layout_string = "ifm_layout" if input_count <= 1 else f"ifm{input_count}_layout"
+                new_attrs[layout_string] = "NHCWB16"
+                parents.append(arg)
+
+        # Check if we can rewrite the output layouts
+        if call in self.children:
+            children = self.children[call]
+            if all(
+                isinstance(child, tvm.relay.expr.Call)
+                and isinstance(child.op, tvm.ir.op.Op)
+                and child.op.name in self.optimize_op
+                and child.attrs["ifm_layout"] == "NHCWB16"
+                for child in children
+            ):
+                new_attrs["ofm_layout"] = "NHCWB16"
+
+        name = call.op.name
+        assert name in self.optimize_op, (
+            f"Could not create operator '{name}' as the creation function "
+            "is unknown. Please provide a mapping."
+        )
+        new_call = self.optimize_op[name](*call.args, **new_attrs)
+
+        # Update map of children
+        for input_arg in parents:
+            if input_arg in self.children:
+                self.children[input_arg].append(new_call)
+            else:
+                self.children[input_arg] = [new_call]
+
+        return super().visit_call(new_call)
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        """Recursively visit call nodes in the input graph and alter the
+        layout of an op if needed.
+
+        Parameters
+        ----------
+        call : tvm.relay.expr.Call
+            The current call node being visited.
+
+        Returns
+        -------
+        tvm.relay.expr.Call
+            The input call node in the case that the current call node does
+            not refer to an Op. Else, a new call node with altered Op
+            attributes.
+        """
+        if isinstance(call.op, tvm.ir.op.Op) and call.op.name in self.optimize_op:
+            return self.alter_ethosu_op_layout(call)
+        return super().visit_call(call)
+
+
+@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
+class LayoutOptimizer(Pass):
+    """Register LayoutOptimizer as a Relay pass."""
+
+    def transform_function(
+        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
+    ) -> tvm.IRModule:
+        """A pass to optimize the layout of NPU operations. If both the
+        producer and consumer of a tensor are NPU operators, then the
+        layout is converted from NHWC to NHCWB16 as this is the layout the
+        NPU uses internally."""
+        assert len(mod.functions.items()) == 1, "Module can only contain one function."
+        return LayoutOptimization().visit(func)
+
+
 @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
 def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     """
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
index a52736fe3964..c421788bcacf 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
@@ -37,3 +37,13 @@ class EthosuDepthwiseConv2DAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
 class EthosuPooling2DAttrs(Attrs):
     """Attributes for contrib.ethosu.pooling."""
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuBinaryElementwiseAttrs")
+class EthosuBinaryElementwiseAttrs(Attrs):
+    """Attributes for contrib.ethosu.binary_elementwise"""
+
+
+@tvm._ffi.register_object("relay.attrs.EthosuUnaryElementwiseAttrs")
+class EthosuUnaryElementwiseAttrs(Attrs):
+    """Attributes for contrib.ethosu.unary_elementwise"""
diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
new file mode 100644
index 000000000000..aafae1497ea4
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -0,0 +1,623 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test the layout optimization pass. This pass is used to
+convert subgraphs to the preferred layout of NHCWB16.
+"""
+
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import sys
+
+import numpy as np
+import tensorflow as tf
+import tflite.Model
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer
+
+from . import infra
+
+
+def _run_pass(expr, relay_pass):
+    """Create IRModule and run Relay pass."""
+    mod = tvm.IRModule.from_expr(expr)
+    mod = relay_pass(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def _assert_structural_equal(a, b):
+    """Check structural equality of two Relay expressions."""
+    reason = (
+        "Actual and expected relay functions are not equal. "
+        "LayoutOptimizer is not correctly converting layouts."
+    )
+    assert tvm.ir.structural_equal(a, b), reason
+
+
+def _compile_and_compare_model(tflite_graph, ifm_shape, dtype):
+    """Compare the running result of the compilation against TFLite."""
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={
+            "ifm": ifm_shape,
+        },
+        dtype_dict={
+            "ifm": dtype,
+        },
+    )
+    mod = partition_for_ethosu(mod, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        "ethos-u55-256",
+        output_tolerance=0,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+
+    # Verify generated C source
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, "ethos-u55-256")
+
+
+def test_single_convolution():
+    """Test a single convolution to make sure the layouts remain
+    unaltered.
+    """
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        x = infra.make_ethosu_conv2d(
+            ifm=x,
+            ifm_channels=8,
+            ofm_channels=8,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_convolution():
+    """Test layout optimization pass on a linear chain of convolutions, i.e.,
+
+    conv_1
+       |
+    conv_2
+       |
+    conv_3
+    """
+
+    def get_graph(get_expected=False):
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        for i in range(3):
+            ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
+            ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
+            x = infra.make_ethosu_conv2d(
+                ifm=x,
+                ifm_channels=8,
+                ofm_channels=8,
+                kernel_shape=(1, 1),
+                padding=(0, 0),
+                strides=(1, 1),
+                dilation=(1, 1),
+                ifm_layout=ifm_layout,
+                ofm_layout=ofm_layout,
+            )
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_depthwise_convolution():
+    """Test layout optimization pass on multiple depthwise convolutions.
+
+    depthwise_conv_1
+            |
+    depthwise_conv_2
+            |
+    depthwise_conv_3
+    """
+
+    def get_graph(get_expected=False):
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        for i in range(3):
+            ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
+            ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
+            x = infra.make_ethosu_depthwise_conv2d(
+                ifm=x,
+                channels=4,
+                kernel_shape=(1, 1),
+                padding=(0, 0),
+                strides=(1, 1),
+                dilation=(1, 1),
+                ifm_layout=ifm_layout,
+                ofm_layout=ofm_layout,
+            )
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_ignore_transform_operations():
+    """Test layout optimization pass ignores transform operations
+    such as reshape and strided slice.
+
+        conv_1
+          |
+       reshape
+          |
+    strided_slice
+          |
+        conv_2
+    """
+
+    def get_graph():
+        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype="int8")
+        conv_1 = infra.make_ethosu_conv2d(
+            ifm=in_1,
+            ifm_channels=8,
+            ofm_channels=8,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        reshape = relay.reshape(conv_1, (1, 16, 16, 8))
+        strided_slice = relay.strided_slice(reshape, (0, 0, 0, 0), (1, 16, 16, 8))
+        conv_2 = infra.make_ethosu_conv2d(
+            ifm=strided_slice,
+            ifm_channels=8,
+            ofm_channels=8,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(conv_2), conv_2)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_ignore_concatenate():
+    """Test layout optimization pass ignores the concatenate operation,
+    when layout transformation cannot occur.
+
+    in_1      in_2
+      \\       /
+       \\   conv_1
+        \\   /
+        concat
+          |
+        conv_2
+    """
+
+    def get_graph():
+        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype="int8")
+        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype="int8")
+        conv_1 = infra.make_ethosu_conv2d(
+            ifm=in_2,
+            ifm_channels=8,
+            ofm_channels=8,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        concat = relay.concatenate([in_1, conv_1], axis=1)
+        conv_2 = infra.make_ethosu_conv2d(
+            ifm=concat,
+            ifm_channels=8,
+            ofm_channels=4,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(conv_2), conv_2)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_ignore_concatenate_with_layout_transform():
+    """Test that the layout optimization pass ignores the concatenate
+    operation, so no layout transformation is performed around it.
+
+    in_1       in_2
+      \\         /
+    pool_1   pool_2
+       \\       /
+        concat
+          |
+        pool_3
+    """
+
+    def get_graph():
+        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype="int8")
+        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype="int8")
+        pool_1 = infra.make_ethosu_pooling(
+            in_1,
+            "MAX",
+            (1, 1),
+            ofm_channels=8,
+            strides=(1, 1),
+            padding=(0, 0),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        pool_2 = infra.make_ethosu_pooling(
+            in_2,
+            "MAX",
+            (1, 1),
+            ofm_channels=8,
+            strides=(1, 1),
+            padding=(0, 0),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        concat = relay.concatenate([pool_1, pool_2], axis=1)
+        pool_3 = infra.make_ethosu_pooling(
+            concat,
+            "MAX",
+            (1, 1),
+            ofm_channels=8,
+            strides=(1, 1),
+            padding=(0, 0),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(pool_3), pool_3)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_inputs():
+    """Test the layout optimization pass works as expected when there
+    are multiple inputs in the graph.
+
+    pool_1  pool_2  pool_3
+        \\     |     /
+         \\    |    /
+           concat
+             |
+            conv
+    """
+
+    def get_graph():
+        poolings = []
+        for _ in range(3):
+            inp = relay.var("x", shape=(1, 3, 3, 4), dtype="int8")
+            pool = infra.make_ethosu_pooling(
+                inp,
+                "MAX",
+                (1, 1),
+                ofm_channels=4,
+                strides=(1, 1),
+                padding=(0, 0),
+                ifm_layout="NHWC",
+                ofm_layout="NHWC",
+            )
+            poolings.append(pool)
+        concat = relay.concatenate(poolings, axis=0)
+        conv = infra.make_ethosu_conv2d(
+            ifm=concat,
+            ifm_channels=8,
+            ofm_channels=4,
+            kernel_shape=(1, 1),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(conv), conv)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_outputs():
+    """Test the layout optimization pass works as expected when there
+    are multiple outputs in the graph.
+
+            pool_1
+           /   |   \\
+    pool_2  pool_3  pool_4
+           \\   |   /
+            concat
+    """
+
+    def get_graph(get_expected=False):
+        in_1 = relay.var("x", shape=(1, 4, 4, 8), dtype="int8")
+        pool_1 = infra.make_ethosu_pooling(
+            in_1,
+            "MAX",
+            (1, 1),
+            ofm_channels=4,
+            strides=(1, 1),
+            padding=(0, 0),
+            ifm_layout="NHWC",
+            ofm_layout="NHCWB16" if get_expected else "NHWC",
+        )
+        poolings = []
+        for _ in range(3):
+            poolings.append(
+                infra.make_ethosu_pooling(
+                    pool_1,
+                    "MAX",
+                    (1, 1),
+                    ofm_channels=4,
+                    strides=(1, 1),
+                    padding=(0, 0),
+                    ifm_layout="NHCWB16" if get_expected else "NHWC",
+                    ofm_layout="NHWC",
+                )
+            )
+        concat = relay.concatenate(poolings, axis=0)
+        return relay.Function(relay.analysis.free_vars(concat), concat)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_binary_elementwise():
+    """Test the layout optimization pass works as expected for
+    binary elementwise operations.
+
+    add_1   add_2
+       \\     /
+        \\   /
+        add_3
+    """
+
+    def get_graph(get_expected=False):
+        in_1 = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
+        in_2 = relay.var("y", shape=(1, 2, 2, 2), dtype="int8")
+        in_3 = relay.var("z", shape=(1, 2, 2, 2), dtype="int8")
+        add_1 = infra.make_ethosu_binary_elementwise(
+            in_1,
+            in_2,
+            ifm_channels=2,
+            ifm2_channels=2,
+            operator_type="ADD",
+            ofm_dtype="int8",
+            ifm_layout="NHWC",
+            ifm2_layout="NHWC",
+            ofm_layout="NHCWB16" if get_expected else "NHWC",
+        )
+        add_2 = infra.make_ethosu_binary_elementwise(
+            in_2,
+            in_3,
+            ifm_channels=2,
+            ifm2_channels=2,
+            operator_type="ADD",
+            ofm_dtype="int8",
+            ifm_layout="NHWC",
+            ifm2_layout="NHWC",
+            ofm_layout="NHCWB16" if get_expected else "NHWC",
+        )
+        add_3 = infra.make_ethosu_binary_elementwise(
+            add_1,
+            add_2,
+            ifm_channels=2,
+            ifm2_channels=2,
+            operator_type="ADD",
+            ofm_dtype="int8",
+            ifm_layout="NHCWB16" if get_expected else "NHWC",
+            ifm2_layout="NHCWB16" if get_expected else "NHWC",
+            ofm_layout="NHWC",
+        )
+        return relay.Function(relay.analysis.free_vars(add_3), add_3)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_pooling():
+    """Test the layout optimization pass works as expected for
+    multiple pooling operations.
+
+    pool_1
+       |
+    pool_2
+       |
+    pool_3
+    """
+
+    def get_graph(get_expected=False):
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        for i in range(3):
+            ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
+            ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
+            x = infra.make_ethosu_pooling(
+                x,
+                "MAX",
+                (1, 1),
+                ofm_channels=4,
+                strides=(1, 1),
+                padding=(0, 0),
+                ifm_layout=ifm_layout,
+                ofm_layout=ofm_layout,
+            )
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_multiple_unary_elementwise():
+    """Test the layout optimization pass works as expected for multiple
+    unary elementwise operations.
+
+    abs_1
+      |
+    abs_2
+      |
+    abs_3
+    """
+
+    def get_graph(get_expected=False):
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        for i in range(3):
+            ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
+            ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
+            x = infra.make_ethosu_unary_elementwise(
+                x,
+                ofm_channels=4,
+                operator_type="ABS",
+                ifm_layout=ifm_layout,
+                ofm_layout=ofm_layout,
+            )
+        return relay.Function(relay.analysis.free_vars(x), x)
+
+    a = _run_pass(get_graph(), LayoutOptimizer())
+    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    _assert_structural_equal(a, b)
+
+
+def test_same_output_multiple_convolutions():
+    """Test that running the layout optimization pass with multiple convolutions
+    gives the same output as TFLite."""
+
+    np.random.seed(0)
+    dtype = "int8"
+    ifm_shape = (1, 8, 8, 32)
+    kernel_shape = (1, 1, 32, 32)
+
+    def create_model():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                for _ in range(3):
+                    x = tf.nn.conv2d(
+                        x,
+                        filters=tf.constant(np.random.uniform(size=kernel_shape), dtype=tf.float32),
+                        strides=(1, 1),
+                        padding="SAME",
+                        data_format="NHWC",
+                        dilations=1,
+                    )
+                return x
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        return converter.convert()
+
+    _compile_and_compare_model(create_model(), ifm_shape, dtype)
+
+
+def test_same_output_multiple_pooling():
+    """Test that running the layout optimization pass with multiple pooling
+    operations gives the same output as TFLite."""
+
+    np.random.seed(0)
+    dtype = "int8"
+    ifm_shape = (1, 4, 2, 7)
+
+    def create_model():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                for _ in range(2):
+                    x = tf.nn.max_pool2d(x, (1, 1), (1, 1), "SAME", "NHWC")
+                return x
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        return converter.convert()
+
+    _compile_and_compare_model(create_model(), ifm_shape, dtype)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__] + sys.argv[1:])
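
Note (not part of the diff): a minimal sketch of how the new pass is expected to be driven, mirroring the _run_pass helper in the tests above. It assumes the Ethos-U test environment (ethosu.vela installed and the test infra helpers importable; the "tests.python.contrib.test_ethosu" import path below is an assumption, the tests themselves use a relative "from . import infra"). The two-convolution graph is purely illustrative.

    import pytest

    pytest.importorskip("ethosu.vela")

    import tvm
    from tvm import relay
    from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer
    # Assumed import path for the shared test helpers used throughout the new test file.
    from tests.python.contrib.test_ethosu import infra

    # Two chained NPU convolutions: the tensor between them is produced and
    # consumed by NPU operators, so the pass should rewrite it to NHCWB16.
    x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
    for _ in range(2):
        x = infra.make_ethosu_conv2d(
            ifm=x,
            ifm_channels=8,
            ofm_channels=8,
            kernel_shape=(1, 1),
            padding=(0, 0),
            strides=(1, 1),
            dilation=(1, 1),
            ifm_layout="NHWC",
            ofm_layout="NHWC",
        )
    func = relay.Function(relay.analysis.free_vars(x), x)

    # LayoutOptimizer is registered as a Relay function pass, so it is applied
    # by calling it on an IRModule (the pass asserts the module holds exactly
    # one function).
    mod = tvm.IRModule.from_expr(func)
    mod = LayoutOptimizer()(mod)

    # The first convolution should now have ofm_layout="NHCWB16" and the second
    # ifm_layout="NHCWB16", while the graph's external input and output stay NHWC.
    print(mod["main"])

This matches the behaviour exercised by test_multiple_convolution: only tensors that are both produced and consumed by NPU operators are converted, so the boundary layouts of the offloaded function are left untouched.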