From cff4568b8a3bf500178cad58c0aea574ea97cc41 Mon Sep 17 00:00:00 2001
From: Alexey Yazev <113356454+Alexey-Yazev@users.noreply.github.com>
Date: Thu, 9 Feb 2023 21:34:58 +0400
Subject: [PATCH] [microNPU] Merge LUT activation with binary elementwise
 operation (#13935)

Add binary elementwise operator to OptimizeLUTs pass to merge LUT
activation with elementwise operation.
---
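Note (below the fold, not part of the commit): the rewrite that this patch
teaches the OptimizeLUTs/LUTsOptimizer pass can be sketched roughly as below,
reusing the make_ethosu_binary_elementwise / make_ethosu_identity helpers from
tests/python/contrib/test_ethosu/infra.py further down in this patch. The
names ifm, ifm2 and lut are placeholders standing in for the Relay expressions
built in the new test, not definitions added by the commit.

    # Before the pass: the LUT-based activation lives in a separate
    # ethosu_identity operator that follows the elementwise operation.
    sub = infra.make_ethosu_binary_elementwise(ifm, ifm2, 4, 4, "SUB", "int8")
    out = infra.make_ethosu_identity(sub, lut=lut, activation="SIGMOID")

    # After the pass: the LUT tensor and the activation are folded into the
    # binary elementwise operation itself, so the identity op disappears.
    out = infra.make_ethosu_binary_elementwise(
        ifm, ifm2, 4, 4, "SUB", "int8", lut=lut, activation="SIGMOID"
    )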
 .../relay/backend/contrib/ethosu/codegen.py   |  1 +
 .../contrib/ethosu/te/binary_elementwise.py   | 13 +++++-
 .../contrib/ethosu/tir/binary_elementwise.py  | 27 +++---------
 tests/python/contrib/test_ethosu/infra.py     |  3 +-
 .../contrib/test_ethosu/test_codegen.py       | 19 +++++++++
 .../contrib/test_ethosu/test_lut_optimizer.py | 42 +++++++++++++++++++
 6 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 5119c04edba4..b07b260f1965 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -51,6 +51,7 @@ def __init__(self):
             "contrib.ethosu.conv2d": op.ethosu_conv2d,
             "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
             "contrib.ethosu.pooling": op.ethosu_pooling,
+            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
         }
 
     def create_op_with_lut(self, call):
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
index caa6656fa07c..86fdb958fd53 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
@@ -178,6 +178,14 @@ def binary_elementwise_compute(
     }
 
     broadcast = [value == 1 for value in dmaed_ifm2.shape]
 
+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+    # This is a trick to insert the LUT tensor into the TE graph if LUT is present
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
+
+    # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
+    if has_lut:
+        binary_elementwise_attrs["lut"] = lut
+
     if reversed_operands:
         binary_elementwise = te.compute(
             (1, ofm_height, ofm_width, ifm_channels),
@@ -188,7 +196,7 @@ def binary_elementwise_compute(
                     0 if broadcast[2] else ww,
                     0 if broadcast[3] else cc,
                 ).astype(ifm.dtype),
-                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype) + lut_expr,
             ).astype(ofm_dtype),
             name="ethosu_binary_elementwise",
             attrs=binary_elementwise_attrs,
@@ -203,7 +211,8 @@ def binary_elementwise_compute(
                     0 if broadcast[1] else hh,
                     0 if broadcast[2] else ww,
                     0 if broadcast[3] else cc,
-                ).astype(ifm.dtype),
+                ).astype(ifm.dtype)
+                + lut_expr,
             ).astype(ofm_dtype),
             name="ethosu_binary_elementwise",
             attrs=binary_elementwise_attrs,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
index ad780ab2b90b..91f5512453fb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
@@ -18,30 +18,12 @@
 """Extract information from the binary_elementwise operators in TIR."""
 from typing import Tuple
 import tvm
-from .utils import get_outer_loops, get_op_attrs
+from .utils import get_outer_loops, get_op_attrs, get_loads
 from .dma import get_ifm_params, get_ofm_params
 from .spec import SerialActivation, SerialBinaryElementwise, SerialRescaleConfig
 from .producers_consumers import ProducersConsumers
 
 
-def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
-    """When the datatype of the ifm, ifm2 and ofm do not match,
-    casts are inserted in TE to handle the difference in these types.
-    Since TIR is not directly run on the NPU we can simply ignore
-    these, and allow the NPU to handle the difference in datatypes
-    itself.
-
-    Parameters
-    ----------
-    tir_load : tvm.tir.expr.Load
-
-    Returns
-    -------
-    tvm.tir.Var
-    """
-    return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load
-
-
 def get_binary_elementwise_params(
     stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
 ) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
@@ -72,9 +54,10 @@ def get_binary_elementwise_params(
     reversed_operands = attrs["reversed_operands"]
 
     _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
-    op = ignore_cast(inner.value)
-    input_pointer = ignore_cast(op.a).buffer.data
-    input_pointer1 = ignore_cast(op.b).buffer.data
+    # loads = [input, input, LUT, LUT]
+    loads = get_loads(inner)
+    input_pointer = loads[0].buffer.data
+    input_pointer1 = loads[1].buffer.data
 
     if reversed_operands:
         input_pointer, input_pointer1 = input_pointer1, input_pointer
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index abddaf47c169..844d08c66e03 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -694,11 +694,12 @@ def make_ethosu_binary_elementwise(
     use_rescale: bool = False,
     rescale_scale: int = 0,
     rescale_shift: int = 0,
+    lut=relay.const([], dtype="int8"),
 ):
     ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
         ifm=ifm,
         ifm2=ifm2,
-        lut=relay.const([], dtype="int8"),
+        lut=lut,
         operator_type=operator_type,
         ifm_scale=1,
         ifm_zero_point=0,
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index bf00f0897476..f07fd0463b5a 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1313,5 +1313,24 @@ def fully_connected(x):
     )
 
 
+@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
+def test_tflite_subtract_sigmoid(accel_type):
+    np.random.seed(0)
+    ifm_shape = [1, 6, 8, 4]
+
+    @tf.function
+    def subtract_sigmoid_function(lhs, rhs):
+        op = tf.math.subtract(lhs, rhs)
+        op = tf.nn.sigmoid(op)
+        return op
+
+    infra.compare_tvm_with_tflite(
+        subtract_sigmoid_function,
+        [ifm_shape, ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index 12b6ed70d8ed..dc3dd59a5a93 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -72,6 +72,48 @@ def after():
     assert tvm.ir.structural_equal(mod, after())
 
 
+def test_merge_lut_into_binary_elementwise():
+    """If a binary elementwise operator is followed by an identity operator
+    with LUT, we can merge the two operators."""
+
+    shape = (1, 8, 8, 4)
+    dtype = "int8"
+    ifm = relay.var("x", shape=shape, dtype=dtype)
+    ifm2 = relay.var("x", shape=shape, dtype=dtype)
+    lut1 = relay.const([i for i in range(256)], dtype=dtype)
+    lut2 = relay.const([i for i in reversed(range(256))], dtype=dtype)
+
+    def before():
+        sub = infra.make_ethosu_binary_elementwise(ifm, ifm2, shape[-1], shape[-1], "SUB", dtype)
+        id1 = infra.make_ethosu_identity(sub, lut=lut1, activation="TANH")
+        add = infra.make_ethosu_binary_elementwise(id1, ifm2, shape[-1], shape[-1], "ADD", dtype)
+        id2 = infra.make_ethosu_identity(add, lut=lut2, activation="SIGMOID")
+
+        func = relay.Function(relay.analysis.free_vars(id2), id2)
+        func = func.with_attr("Compiler", "ethos-u")
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def after():
+        sub = infra.make_ethosu_binary_elementwise(
+            ifm, ifm2, shape[-1], shape[-1], "SUB", dtype, lut=lut1, activation="TANH"
+        )
+        add = infra.make_ethosu_binary_elementwise(
+            sub, ifm2, shape[-1], shape[-1], "ADD", dtype, lut=lut2, activation="SIGMOID"
+        )
+
+        func = relay.Function(relay.analysis.free_vars(add), add)
+        func = func.with_attr("Compiler", "ethos-u")
+        mod = tvm.IRModule.from_expr(func)
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)
+
+    assert tvm.ir.structural_equal(mod, after())
+
+
 def test_multiple_luts():
     """Test that when an operation already has a LUT, we don't overwrite that LUT"""