Skip to content

Commit

Permalink
[microNPU] Merge LUT activation with binary elementwise operation (#13935)
Browse files Browse the repository at this point in the history

Add binary elementwise operator to OptimizeLUTs pass to merge LUT activation with elementwise operation.
  • Loading branch information
Aleksei-grovety authored Feb 9, 2023
1 parent 5cf3405 commit cff4568
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 25 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
}

def create_op_with_lut(self, call):
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def binary_elementwise_compute(
}
broadcast = [value == 1 for value in dmaed_ifm2.shape]

has_lut = activation in ("TANH", "LUT", "SIGMOID")
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if has_lut:
binary_elementwise_attrs["lut"] = lut

if reversed_operands:
binary_elementwise = te.compute(
(1, ofm_height, ofm_width, ifm_channels),
Expand All @@ -188,7 +196,7 @@ def binary_elementwise_compute(
0 if broadcast[2] else ww,
0 if broadcast[3] else cc,
).astype(ifm.dtype),
dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype) + lut_expr,
).astype(ofm_dtype),
name="ethosu_binary_elementwise",
attrs=binary_elementwise_attrs,
Expand All @@ -203,7 +211,8 @@ def binary_elementwise_compute(
0 if broadcast[1] else hh,
0 if broadcast[2] else ww,
0 if broadcast[3] else cc,
).astype(ifm.dtype),
).astype(ifm.dtype)
+ lut_expr,
).astype(ofm_dtype),
name="ethosu_binary_elementwise",
attrs=binary_elementwise_attrs,
Expand Down
27 changes: 5 additions & 22 deletions python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,12 @@
"""Extract information from the binary_elementwise operators in TIR."""
from typing import Tuple
import tvm
from .utils import get_outer_loops, get_op_attrs
from .utils import get_outer_loops, get_op_attrs, get_loads
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialActivation, SerialBinaryElementwise, SerialRescaleConfig
from .producers_consumers import ProducersConsumers


def ignore_cast(tir_load: tvm.tir.PrimExpr) -> tvm.tir.PrimExpr:
    """Return the expression wrapped by a cast, or the expression itself.

    When the datatype of the ifm, ifm2 and ofm do not match,
    casts are inserted in TE to handle the difference in these types.
    Since TIR is not directly run on the NPU we can simply ignore
    these, and allow the NPU to handle the difference in datatypes
    itself.

    Only a single outer ``tvm.tir.Cast`` is removed; a cast nested inside
    the returned expression is left in place.

    Parameters
    ----------
    tir_load : tvm.tir.PrimExpr
        The expression to inspect. It may or may not be a ``tvm.tir.Cast``.

    Returns
    -------
    tvm.tir.PrimExpr
        ``tir_load.value`` when ``tir_load`` is a ``tvm.tir.Cast``,
        otherwise ``tir_load`` unchanged.
    """
    if isinstance(tir_load, tvm.tir.Cast):
        return tir_load.value
    return tir_load


def get_binary_elementwise_params(
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
Expand Down Expand Up @@ -72,9 +54,10 @@ def get_binary_elementwise_params(
reversed_operands = attrs["reversed_operands"]

_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
op = ignore_cast(inner.value)
input_pointer = ignore_cast(op.a).buffer.data
input_pointer1 = ignore_cast(op.b).buffer.data
# loads = [input, input, LUT, LUT]
loads = get_loads(inner)
input_pointer = loads[0].buffer.data
input_pointer1 = loads[1].buffer.data

if reversed_operands:
input_pointer, input_pointer1 = input_pointer1, input_pointer
Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,12 @@ def make_ethosu_binary_elementwise(
use_rescale: bool = False,
rescale_scale: int = 0,
rescale_shift: int = 0,
lut=relay.const([], dtype="int8"),
):
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
ifm=ifm,
ifm2=ifm2,
lut=relay.const([], dtype="int8"),
lut=lut,
operator_type=operator_type,
ifm_scale=1,
ifm_zero_point=0,
Expand Down
19 changes: 19 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,5 +1313,24 @@ def fully_connected(x):
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
def test_tflite_subtract_sigmoid(accel_type):
    """Compare TVM with TFLite for a subtraction followed by a sigmoid.

    Both inputs share the same shape. The cascader is enabled only when
    ``is_u55_accel_type`` reports true for ``accel_type``.
    """
    # Seed NumPy's global RNG so the run is reproducible.
    np.random.seed(0)
    ifm_shape = [1, 6, 8, 4]

    @tf.function
    def subtract_sigmoid_function(lhs, rhs):
        difference = tf.math.subtract(lhs, rhs)
        return tf.nn.sigmoid(difference)

    infra.compare_tvm_with_tflite(
        subtract_sigmoid_function,
        [ifm_shape, ifm_shape],
        accel_type,
        enable_cascader=is_u55_accel_type(accel_type),
    )


if __name__ == "__main__":
tvm.testing.main()
42 changes: 42 additions & 0 deletions tests/python/contrib/test_ethosu/test_lut_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,48 @@ def after():
assert tvm.ir.structural_equal(mod, after())


def test_merge_lut_into_binary_elementwise():
    """If a binary elementwise operator is followed by an identity operator
    with LUT, we can merge the two operators."""

    shape = (1, 8, 8, 4)
    dtype = "int8"
    ifm = relay.var("x", shape=shape, dtype=dtype)
    ifm2 = relay.var("x", shape=shape, dtype=dtype)
    # Two different 256-entry tables so a swapped or dropped LUT would be detected.
    lut1 = relay.const(list(range(256)), dtype=dtype)
    lut2 = relay.const(list(reversed(range(256))), dtype=dtype)

    def build_module(expr):
        """Wrap ``expr`` in a function marked with the "ethos-u" compiler
        attribute and return it as an IRModule."""
        func = relay.Function(relay.analysis.free_vars(expr), expr)
        func = func.with_attr("Compiler", "ethos-u")
        return tvm.IRModule.from_expr(func)

    def before():
        # Each binary elementwise op is followed by an identity op carrying the LUT.
        sub = infra.make_ethosu_binary_elementwise(ifm, ifm2, shape[-1], shape[-1], "SUB", dtype)
        id1 = infra.make_ethosu_identity(sub, lut=lut1, activation="TANH")
        add = infra.make_ethosu_binary_elementwise(id1, ifm2, shape[-1], shape[-1], "ADD", dtype)
        id2 = infra.make_ethosu_identity(add, lut=lut2, activation="SIGMOID")
        return build_module(id2)

    def after():
        # Expected result: the LUT and activation live on the binary elementwise ops.
        sub = infra.make_ethosu_binary_elementwise(
            ifm, ifm2, shape[-1], shape[-1], "SUB", dtype, lut=lut1, activation="TANH"
        )
        add = infra.make_ethosu_binary_elementwise(
            sub, ifm2, shape[-1], shape[-1], "ADD", dtype, lut=lut2, activation="SIGMOID"
        )
        return relay.transform.InferType()(build_module(add))

    mod = LUTsOptimizer()(before())
    mod = relay.transform.InferType()(mod)

    assert tvm.ir.structural_equal(mod, after())


def test_multiple_luts():
"""Test that when an operation already has a LUT, we don't overwrite that LUT"""

Expand Down

0 comments on commit cff4568

Please sign in to comment.