class OptimizeLUTs(ExprMutator):
    """A pass to merge an identity operator with a LUT based activation function with
    a preceding operator provided that operator can do a table lookup for the activation
    in the hardware"""

    def __init__(self):
        super().__init__()
        # Operators that accept a LUT, keyed by Relay operator name and mapped to the
        # function used to rebuild them with new arguments and attributes.
        self.lut_ops = {
            "contrib.ethosu.conv2d": op.ethosu_conv2d,
            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
            "contrib.ethosu.pooling": op.ethosu_pooling,
        }

    def create_op_with_lut(self, call):
        """Extract the parameters and attributes from the NPU operator and create
        a new operator with LUT.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The identity call being visited. Its first argument is the producing
            NPU operator and its second argument is the LUT.

        Returns
        -------
        tvm.relay.expr.Call
            The producing operator rebuilt with the identity's LUT and activation.
        """
        identity = call
        ethosu_op = call.args[0]
        # Internal invariant: visit_call only gets here for LUT capable producers.
        assert ethosu_op.op.name in self.lut_ops

        lut = identity.args[1]
        activation = identity.attrs.activation

        new_attrs = dict(ethosu_op.attrs)
        new_attrs["activation"] = activation

        # Assume that LUT is always the last argument
        new_args = ethosu_op.args[:-1] + [lut]

        return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)

    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Recursively visit call nodes in the input graph and if an ethosu.identity
        operator with LUT is found and the preceding operator has a LUT attribute, create
        a new NPU operator.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The current call node being visited.

        Returns
        -------
        tvm.relay.expr.Call
            The result of the base class visit, applied either to the input call or,
            when the identity could be folded into its producer, to the new operator.
        """
        new_call = call
        lut_activations = ["TANH", "LUT"]

        # Guard every lookup below: the call may have no arguments, and the producer
        # may be a call to something other than an operator (which may not expose a
        # ``name``).
        if (
            isinstance(call.op, tvm.ir.Op)
            and len(call.args) > 0
            and isinstance(call.args[0], tvm.relay.expr.Call)
            and isinstance(call.args[0].op, tvm.ir.Op)
        ):
            producer_op = call.args[0]
            # Check if the producer can do a LUT operation
            if (
                producer_op.op.name in self.lut_ops
                and call.op.name == "contrib.ethosu.identity"
                and call.attrs.activation in lut_activations
            ):
                # Check the producer doesn't already have a LUT
                has_lut = producer_op.attrs.activation in lut_activations
                if not has_lut:
                    new_call = self.create_op_with_lut(call)

        new_call = super().visit_call(new_call)

        return new_call


@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
class LUTsOptimizer(Pass):
    """Register LUTsOptimizer as a relay pass."""

    def transform_function(
        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
    ) -> tvm.relay.function.Function:
        """Visit relay nodes in the given function.

        Parameters
        ----------
        func : tvm.relay.function.Function
            The function to apply the optimization pass for multiple LUTs to.
        mod : tvm.IRModule
            The module containing the function; it must hold exactly one function.

        Returns
        -------
        func : tvm.relay.function.Function
            New function with optimized LUTs.
        """
        assert len(mod.functions.items()) == 1, "Module can only contain one function."
        return OptimizeLUTs().visit(func)


def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
    """Method to calculate the values of the tanh lookup table

    Parameters
    ----------
    ifm_scale : float
        Quantization scale of the input.
    ifm_zp : int
        Quantization zero point of the input.
    ofm_scale : float
        Quantization scale of the output.
    ofm_zp : int
        Quantization zero point of the output.

    Returns
    -------
    list of int
        One quantized tanh result per int8 input value, ordered from the smallest
        int8 value to the largest. Every entry is clamped to the int8 range.
    """
    # Only int8 is currently supported
    dtype = np.int8
    qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max

    lut_values = []
    for x in range(qmin, qmax + 1):
        # Dequantize, apply tanh in real numbers, then quantize with the output params.
        x_real = ifm_scale * (x - ifm_zp)
        out_real = math.tanh(x_real)
        lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
        lut_result = min(qmax, max(qmin, lut_result))
        lut_values.append(lut_result)

    return lut_values


class TanhRewriter(DFPatternCallback):
    """This pass adds tanh as a LUT to the identity operator"""

    def __init__(self):
        super().__init__(require_type=True, rewrite_once=True)
        # Match a call, with one argument, to a function carrying the tanh
        # composite name.
        self.pattern = (
            wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
        )(wildcard())

    def callback(self, pre, post, node_map):
        """Replace the matched composite call with an identity operator carrying a
        tanh lookup table.

        The quantization parameters are read from the constant arguments of the
        outermost call in the composite body (output) and of the call two argument
        levels below it (input).
        """
        id_input = post.args[0]

        quantize_args = post.op.body.args
        output_scale = float(quantize_args[1].data.asnumpy())
        output_zp = int(quantize_args[2].data.asnumpy())

        dequantize_args = quantize_args[0].args[0].args
        input_scale = float(dequantize_args[1].data.asnumpy())
        input_zp = int(dequantize_args[2].data.asnumpy())

        lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
        # The table holds signed int8 values but the constant is typed uint8.
        # Reinterpret the bytes rather than casting negative Python ints, which
        # newer NumPy releases reject with OverflowError; the resulting bytes are
        # the same as the old wrap-around conversion produced.
        lut_bytes = np.array(lut_values, dtype=np.int8).view(np.uint8)
        lut = relay.const(lut_bytes, dtype="uint8")

        # We baked the requantization into the LUT, so we don't requantize the identity operator
        identity = ethosu_ops.ethosu_identity(
            ifm=id_input,
            lut=lut,
            ifm_scale=input_scale,
            ifm_zero_point=input_zp,
            ofm_scale=input_scale,
            ofm_zero_point=input_zp,
            activation="TANH",
        )

        return identity


@ir.transform.module_pass(opt_level=1)
class LegalizeTanh:
    """This is the pass that wraps TanhRewriter"""

    def transform_module(
        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
    ) -> tvm.ir.IRModule:
        """Apply TanhRewriter to every function in the module and return the module."""
        for global_var, func in mod.functions.items():
            func = rewrite(TanhRewriter(), func)
            mod.update_func(global_var, func)
        return mod

    def __call__(self, *args, **kwargs):
        pass
LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py new file mode 100644 index 000000000000..a52736fe3964 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
@tvm._ffi.register_object("relay.attrs.EthosuConv2DAttrs")
class EthosuConv2DAttrs(Attrs):
    """Attributes for contrib.ethosu.conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuIdentityAttrs")
class EthosuIdentityAttrs(Attrs):
    """Attributes for contrib.ethosu.identity."""


@tvm._ffi.register_object("relay.attrs.EthosuDepthwiseConv2DAttrs")
class EthosuDepthwiseConv2DAttrs(Attrs):
    """Attributes for contrib.ethosu.depthwise_conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
class EthosuPoolingAttrs(Attrs):
    """Attributes for contrib.ethosu.pooling."""


# Backward-compatible alias: the class was previously named differently from the
# key it is registered under.
EthosuPooling2DAttrs = EthosuPoolingAttrs
b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index 664a3f489fb5..05b2993f5857 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -136,6 +136,13 @@ def depthwise_conv2d_compute( "dilation_w": dilation_w, } + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + depthwise_conv2d_attrs["lut"] = lut + depthwise = te.compute( (1, ofm_height, ofm_width, channels), lambda nn, hh, ww, cc: te.sum( @@ -144,7 +151,7 @@ def depthwise_conv2d_compute( ).astype(ifm.dtype) * weight[cc, rh, rw, 0].astype(ifm.dtype) # This is a trick to load 10 elements of the scale_bias at once, not accurate maths - + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), + + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype), axis=[rh, rw], ), name="ethosu_depthwise_conv2d", diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py index f26179422b4b..574fc661599f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py @@ -58,14 +58,21 @@ def identity_compute( The Output Feature Map tensor. 
""" - dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale) + id_attrs = {"op": "ethosu_identity", "activation": activation} + + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + id_attrs["lut"] = lut identity = te.compute( ifm.shape, - lambda *i: dmaed_ifm(*i).astype(ifm.dtype), + lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype), name="ethosu_identity", - attrs={"op": "ethosu_identity", "activation": activation}, + attrs=id_attrs, ) dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index bf35479d7556..2ab0844b1622 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -123,10 +123,19 @@ def pooling_compute( "upscale": upscale, } + # This is a trick to insert the LUT tensor into the TE graph if LUT is present + lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0 + + # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT + if activation in ("TANH", "LUT"): + pooling_attrs["lut"] = lut + pooling = te.compute( (1, ofm_height, ofm_width, ofm_channels), lambda nn, hh, ww, cc: te.max( - dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype), + (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype( + ifm.dtype + ), axis=[rh, rw], ), name="ethosu_pooling", diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 5e8ea002783f..254f92a30c32 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ 
b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -53,7 +53,7 @@ def get_conv2d_params(stmt, producers, consumers): rh = inner rw = rh.body rc = rw.body - # loads = [output, input, weights, scale_bias, scale_bias] + # loads = [output, input, weights, scale_bias, scale_bias, LUT, LUT] loads = get_loads(rc.body) # stores = [output] stores = get_stores(rc.body) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 7a81a702f019..23fc31efbfac 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -19,7 +19,7 @@ from typing import Dict, Tuple import tvm from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap -from .utils import get_op_attrs, get_base_address, get_strides +from .utils import get_op_attrs, get_base_address, get_strides, get_loads def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]: @@ -123,7 +123,10 @@ def get_identity_params( while hasattr(stmt, "body"): stmt = stmt.body - input_pointer = stmt.value.buffer_var + # loads = [input, LUT, LUT] + loads = get_loads(stmt) + + input_pointer = loads[0].buffer_var output_pointer = stmt.buffer_var read = producers[input_pointer] diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 33dcb36fbbb6..b19ec034e7d4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -18,7 +18,7 @@ """Extract information from the pooling operators in TIR.""" from typing import Dict, Tuple import tvm -from .utils import get_outer_loops, get_op_attrs +from .utils import get_outer_loops, get_op_attrs, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialActivation, SerialPooling 
def copy_luts():
    """A scheduler that copies LUTs to SHRAM.

    Returns
    -------
    planner : callable
        The planning function. It takes the TE graph, the constant dictionary and
        the schedule, and adds a "local" cache read for every LUT placeholder whose
        input index is present in the constant dictionary.
    """

    def _planner(te_graph, const_dict, sch):
        planned = set()  # type: ignore

        def _visit(tensor, reader, lut):
            # Membership test (not identity against the set itself) so that a tensor
            # reachable along several paths is only planned once.
            if tensor not in planned:
                planned.add(tensor)
                if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor == lut:
                    index = list(te_graph.inputs).index(tensor)
                    if index in const_dict:
                        sch.cache_read(tensor, "local", [reader])

                elif isinstance(tensor.op, tvm.te.ComputeOp):
                    # A compute op advertises its LUT tensor through its attributes;
                    # pass it down so the matching placeholder can be recognised.
                    if "lut" in tensor.op.attrs.keys():
                        lut = tensor.op.attrs["lut"]
                    for input_tensor in tensor.op.input_tensors:
                        _visit(input_tensor, tensor, lut)

        for output_tensor in te_graph.outputs:
            _visit(output_tensor, None, None)

    return _planner


class AcceleratorArchConfig:
    """SHRAM layout figures derived from the number of SHRAM banks.

    Parameters
    ----------
    total_shram_banks : int
        Number of SHRAM banks of the accelerator variant.
    """

    def __init__(self, total_shram_banks):
        # Size of a single bank in bytes.
        self.shram_bank_size = 1024
        self.total_shram_banks = total_shram_banks
        self.shram_size_bytes = self.shram_bank_size * self.total_shram_banks
        # The LUT occupies the last lut_size_bytes of SHRAM.
        self.lut_size_bytes = 2048
        self.lut_start_address = self.shram_size_bytes - self.lut_size_bytes


def get_accelerator_arch_config(accel_type):
    """Return the architecture configuration for an accelerator variant.

    Parameters
    ----------
    accel_type : str
        One of "ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32".

    Returns
    -------
    AcceleratorArchConfig
        A newly created configuration for the requested variant.

    Raises
    ------
    KeyError
        If the accelerator variant is not known.
    """
    shram_banks_by_accel = {
        "ethos-u55-256": 48,
        "ethos-u55-128": 24,
        "ethos-u55-64": 16,
        "ethos-u55-32": 16,
    }
    # Only build the configuration that was asked for.
    return AcceleratorArchConfig(shram_banks_by_accel[accel_type])
class TanhParams:
    """
    This class will parse a call to a ethos-u.tanh composite function
    and extract the parameter information.
    """

    composite_name = "ethos-u.tanh"

    def __init__(self, func_body: Call):
        # The output is described by the outermost call of the composite body; the
        # input is found three argument levels below it.
        self.ofm = TensorParams(func_body)
        self.ifm = TensorParams(func_body.args[0].args[0].args[0])

    def is_valid(self):
        """
        This function checks whether tanh has compatible attributes with the NPU
        """
        # Only int8 input and output are accepted.
        return bool(check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]))


def tanh_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """Create pattern for tanh

    The pattern is qnn.dequantize -> tanh -> qnn.quantize, where the scale and zero
    point arguments of both quantization operators are constants.
    """
    dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
    tanh = is_op("tanh")(dequant)
    quant = is_op("qnn.quantize")(tanh, is_constant(), is_constant())
    return quant
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_tanh(accel_type):
    """Build a quantized TFLite tanh model, compile it for the NPU and verify the
    generated source against the TFLite reference data."""
    dtype = "int8"
    ifm_shape = [1, 115, 32, 7]

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tanh_function(self, x):
                return tf.nn.tanh(x)

        # Calibration data for the post-training quantization
        def representative_dataset():
            for _ in range(100):
                sample = np.random.rand(*tuple(ifm_shape))
                yield [sample.astype(np.float32)]

        model = Model()
        input_spec = tf.TensorSpec(ifm_shape, dtype=tf.float32)
        concrete_func = model.tanh_function.get_concrete_function(input_spec)

        # Convert the model
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    relay_module, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = partition_for_ethosu(relay_module, params)

    # Generate reference data
    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

    compiled_models = infra.build_source(mod, input_data, output_data, accel_type)

    # Assumes only two runtime.Modules are created -- i.e. single offload module
    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
    assert len(imported_modules) == 2
    ethosu_module = imported_modules[0]

    # Verify generated C source
    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
    cmms = bytes.fromhex(get_cs(ethosu_module))

    infra.print_payload(cmms)
    infra.verify_source(compiled_models, accel_type)


def test_tflite_tanh_legalize():
    """Check that TanhRewriter turns the partitioned tanh composite into an identity
    operator with a TANH activation and a 256 entry LUT."""
    dtype = "int8"
    ifm_shape = (1, 241, 132, 7)
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tanh_func(self, x):
                return tf.math.tanh(x)

        # Calibration data for the post-training quantization
        def representative_dataset():
            for _ in range(100):
                sample = np.random.rand(*tuple(ifm_shape))
                yield [sample.astype(np.float32)]

        model = Model()
        input_spec = tf.TensorSpec(ifm_shape, dtype=tf.float32)
        concrete_func = model.tanh_func.get_concrete_function(input_spec)

        # Convert the model
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod = ethosu.partition_for_ethosu(mod, params)
    mod[ext_func_name] = dataflow_pattern.rewrite(legalize.TanhRewriter(), mod[ext_func_name])
    mod = relay.transform.InferType()(mod)

    func_body = mod[ext_func_name].body
    ifm_arg, lut_arg = func_body.args[0], func_body.args[1]
    assert func_body.op.name == "contrib.ethosu.identity"
    assert func_body.attrs.activation == "TANH"
    assert tuple(ifm_arg.checked_type.shape) == ifm_shape
    assert tuple(lut_arg.checked_type.shape) == (256,)
ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"]


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_lut_activations(accel_type):
    """Compile a quantized TFLite model that interleaves conv2d, depthwise conv2d and
    max pool with tanh activations, and verify the generated source against the
    TFLite reference data."""

    dtype = "int8"
    ifm_shape = (1, 55, 55, 3)

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_func(self, x):
                weight_shape = (3, 3, ifm_shape[3], 4)
                weight = tf.constant(
                    np.random.uniform(low=0, high=0.3, size=weight_shape), dtype=tf.float32
                )
                # The input strides to the TensorFlow API needs to be of shape 1x4
                op = tf.nn.conv2d(x, weight, strides=(1, 2, 2, 1), padding="SAME", dilations=(1, 1))
                op = tf.nn.tanh(op)
                op = tf.nn.tanh(op)

                weight_shape2 = (2, 3, 4, 1)
                weight2 = tf.constant(
                    np.random.uniform(low=0, high=0.3, size=weight_shape2), dtype=tf.float32
                )
                op = tf.nn.depthwise_conv2d(
                    op, weight2, strides=(1, 1, 1, 1), padding="VALID", dilations=(2, 2)
                )
                op = tf.nn.tanh(op)
                op = tf.nn.max_pool(op, (1, 1), strides=(1, 1, 1, 1), padding="SAME")
                op = tf.nn.tanh(op)
                return op

        model = Model()
        concrete_func = model.tf_func.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )
        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = 0.7 * np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
    relay_module, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = partition_for_ethosu(relay_module, params)

    # Generate reference data
    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

    compiled_models = infra.build_source(
        mod,
        input_data,
        output_data,
        accel_type,
    )

    # Assumes only two runtime.Modules are created -- i.e. single offload module
    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
    assert len(imported_modules) == 2
    ethosu_module = imported_modules[0]

    # Verify generated C source
    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
    cmms = get_cs(ethosu_module)
    cmms = bytes.fromhex(cmms)
    infra.print_payload(cmms)

    infra.verify_source(compiled_models, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_random_lut(accel_type):
    """Compile an identity operator with a random LUT and verify the generated source
    against the element-wise table lookup computed with NumPy."""

    dtype = "int8"
    ifm_shape = (1, 55, 55, 3)

    # randint's upper bound is exclusive, so use 128 to cover the full int8 range
    # (including 127) for both the table entries and the input data.
    lut_data = np.random.randint(-128, high=128, size=[256])
    in_data = np.random.randint(-128, high=128, size=ifm_shape, dtype=dtype)

    # Entry ``v + 128`` of the table is the expected output for input value ``v``.
    out_data = lut_data[in_data.astype("int32") + 128].astype(dtype)

    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
    ifm0 = relay.var("ifm0", shape=ifm_shape, dtype=dtype)
    lut1 = relay.var("lut1", shape=(256,), dtype="uint8")

    identity = infra.make_ethosu_identity(ifm0, lut=lut1, activation="LUT")
    glb_ethosu = relay.GlobalVar("tvmgen_default_ethos_u_main_0")

    func = (
        relay.Function([ifm0, lut1], identity)
        .with_attr("Inline", 1)
        .with_attr("Compiler", "ethos-u")
        .with_attr("global_symbol", "tvmgen_default_ethos_u_main_0")
        .with_attr("Primitive", 1)
    )

    # Bind the LUT as a constant of the external function
    params = {"lut1": tvm.nd.array(lut_data.astype("uint8"))}
    func = bind_params_by_name(func, params)

    mod = tvm.IRModule()
    mod[glb_ethosu] = func
    mod = relay.transform.InferType()(mod)

    call = relay.Call(glb_ethosu, [ifm])
    mod["main"] = relay.Function([ifm], call)
    mod = relay.transform.InferType()(mod)

    compiled_models = infra.build_source(
        mod,
        {"ifm": in_data},
        out_data,
        accel_type,
    )

    # Assumes only two runtime.Modules are created -- i.e. single offload module
    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
    assert len(imported_modules) == 2
    ethosu_module = imported_modules[0]

    # Verify generated C source
    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
    cmms = get_cs(ethosu_module)
    cmms = bytes.fromhex(cmms)
    infra.print_payload(cmms)

    infra.verify_source(compiled_models, accel_type)
import infra + + +def test_merge_lut_into_conv(): + """If an operator that has a LUT attribute is followed by an identity operator + with LUT, we can merge the two operataors.""" + + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + lut1 = relay.const([i for i in range(256)], dtype="int8") + lut2 = relay.const([i for i in reversed(range(256))], dtype="int8") + + def before(): + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH") + conv2 = infra.make_ethosu_conv2d(id1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1)) + id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + conv1 = infra.make_ethosu_conv2d( + ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH" + ) + conv2 = infra.make_ethosu_conv2d( + conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="TANH" + ) + + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod + + mod = LUTsOptimizer()(before()) + + assert tvm.ir.structural_equal(mod, after()) + + +def test_multiple_luts(): + """Test that when an operation already has a LUT, we don't overwrite that LUT""" + + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + lut1 = relay.const([i for i in range(256)], dtype="int8") + lut2 = relay.const([i for i in reversed(range(256))], dtype="int8") + + def before(): + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH") + id2 = infra.make_ethosu_identity(id1, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + conv1 = infra.make_ethosu_conv2d( + ifm, 4, 4, 
(3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH" + ) + id2 = infra.make_ethosu_identity(conv1, lut=lut2, activation="TANH") + + func = relay.Function(relay.analysis.free_vars(id2), id2) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod + + mod = LUTsOptimizer()(before()) + + assert tvm.ir.structural_equal(mod, after()) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 7992f421a5bd..1d3afec30cbc 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize( "trial", [ - [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC", "TFL"], + [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"], [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TRUNCATE"], [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "TFL"], @@ -124,12 +124,10 @@ def _get_func( padding, strides, dilation, - activation, - ifm_layout, - ofm_layout, - "int8", - "uint8", - rounding_mode, + activation=activation, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + rounding_mode=rounding_mode, ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) @@ -409,9 +407,9 @@ def _get_func( padding, strides, dilation, - "NONE", - layout, - layout, + activation="NONE", + ifm_layout=layout, + ofm_layout=layout, ) conv2 = make_ethosu_conv2d( conv1, @@ -421,9 +419,9 @@ def _get_func( padding, strides, dilation, - "NONE", - layout, - layout, + activation="NONE", + ifm_layout=layout, + ofm_layout=layout, ) func = relay.Function(relay.analysis.free_vars(conv2), conv2) func = run_opt_pass(func, 
relay.transform.InferType()) @@ -577,7 +575,15 @@ def _get_func(ifm_shape, reshaped, ifm_layout): ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") ifm_reshaped = relay.reshape(ifm, reshaped) conv = make_ethosu_conv2d( - ifm_reshaped, reshaped[3], 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", ifm_layout + ifm_reshaped, + reshaped[3], + 16, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + activation="NONE", + ifm_layout=ifm_layout, ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) @@ -598,7 +604,9 @@ def test_conv2d_big_pad(): def _get_func(): ifm_shape = (1, 2, 2, 8) ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") - conv = make_ethosu_conv2d(ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), "NHWC") + conv = make_ethosu_conv2d( + ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), ifm_layout="NHWC" + ) func = relay.Function(relay.analysis.free_vars(conv), conv) func = run_opt_pass(func, relay.transform.InferType()) return func diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py index cf2ac147759c..afd632cf355e 100644 --- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py @@ -30,7 +30,7 @@ "trial", [ [(1, 8, 8, 3), 3, (3, 2), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], - [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC", "NATURAL"], + [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"], [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "TRUNCATE"], [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"], [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "NATURAL"], diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py 
b/tests/python/contrib/test_ethosu/test_scheduler.py index b04059011e8e..cd84449c4a1b 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -27,9 +27,10 @@ total_cascader, copy_constants, schedule_cache_reads, + copy_luts, ) from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants -from .infra import AttachType, make_ethosu_conv2d +from .infra import AttachType, make_ethosu_conv2d, make_ethosu_identity class TestTEGraph: @@ -126,6 +127,31 @@ def test_copy_constants(): assert ".global" in sch.stages[17].op.name +# This test makes sure that constants and LUTs have a correct storage scope +def test_copy_luts(): + ifm_shape = (1, 33, 33, 11) + ifm = relay.var("IFM", shape=ifm_shape, dtype="int8") + lut = relay.const([i for i in range(256)], dtype="int8") + conv = make_ethosu_conv2d( + ifm, ifm_shape[3], 8, (3, 3), (0, 0), (1, 1), (1, 1), lut=lut, activation="TANH" + ) + identity = make_ethosu_identity(conv, lut=lut, activation="TANH") + func = relay.Function(relay.analysis.free_vars(identity), identity) + func = run_opt_pass(func, relay.transform.InferType()) + + func, const_dict = extract_constants(func) + te_graph = lower_to_te(func) + + sch = te.create_schedule([te_graph.outputs[0].op]) + copy_constants()(te_graph, const_dict, sch) + copy_luts()(te_graph, const_dict, sch) + assert len(sch.stages) == 17 + assert ".global" in sch.stages[5].op.name + assert ".global" in sch.stages[7].op.name + assert ".local" in sch.stages[9].op.name + assert ".local" in sch.stages[10].op.name + + def test_schedule_cache_reads(): a = te.placeholder((12, 12), dtype="uint8", name="a") b = te.placeholder((12, 12), dtype="uint8", name="b")