Skip to content

Commit

Permalink
[microNPU] Add the infrastructure for lookup table and TANH (apache#9547
Browse files Browse the repository at this point in the history
)

Some activation functions like TANH and SIGMOID are implemented
by calculating the values based on the QNN parameters and
recording the values into a lookup table (LUT).

This patch adds the LUT functionality alongside with the TANH
activation function and the tests.
  • Loading branch information
ekalda authored and yangulei committed Jan 11, 2022
1 parent b6fce0e commit 3a8d181
Show file tree
Hide file tree
Showing 21 changed files with 832 additions and 42 deletions.
110 changes: 110 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,115 @@
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr_functor import ExprMutator
from tvm.ir.transform import Pass

# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
from tvm.relay.backend.contrib.ethosu import op


class OptimizeLUTs(ExprMutator):
"""A pass to merge an identity operator with a LUT based activation function with
a preceding operator provided that operator can do a table lookup for the activation
in the hardware"""

def __init__(self):
super().__init__()
self.lut_ops = {
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
}

def create_op_with_lut(self, call):
"""Extract the parameters and attributes from the NPU operator and create
a new operator with LUT.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
The new operator with LUT.
"""
identity = call
ethosu_op = call.args[0]
lut = identity.args[1]
activation = identity.attrs.activation

new_attrs = dict(ethosu_op.attrs)
new_attrs["activation"] = activation

# Assume that LUT is always the last argument
new_args = ethosu_op.args[:-1] + [lut]
assert ethosu_op.op.name in self.lut_ops.keys()

return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)

def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
"""Recursively visit call nodes in the input graph and if an ethosu.identity
operator with LUT is found and the preceding operator has a LUT attribute, create
a new NPU operator.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
The input call node in the case the current call node does
not refer to an Op. Else, a new call node with a new operator.
"""
new_call = call
lut_activations = ["TANH", "LUT"]

if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
producer_op = call.args[0]
# Check if the producer can do a LUT operation
if (
producer_op.op.name in self.lut_ops.keys()
and call.op.name == "contrib.ethosu.identity"
and call.attrs.activation in lut_activations
):
# Check the producer doesn't already have a LUT
has_lut = producer_op.attrs.activation in lut_activations
if not has_lut:
new_call = self.create_op_with_lut(call)

new_call = super().visit_call(new_call)

return new_call


@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
class LUTsOptimizer(Pass):
"""Register LUTsOptimizer as a relay pass."""

def transform_function(
self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
) -> tvm.IRModule:
"""Visit relay nodes in the given module.
Parameters
----------
func : tvm.relay.function.Function
The function to apply the optimization pass for multiple LUTs to.
mod : tvm.IRModule
The module to apply the optimization pass for multiple LUTs to.
Returns
-------
mod : tvm.IRModule
New module with optimized LUTs.
"""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
return OptimizeLUTs().visit(func)


@tvm._ffi.register_func("relay.ext.ethos-u")
Expand Down Expand Up @@ -74,6 +183,7 @@ def _compile(ext_func):
mod = tvm.IRModule()
mod["main"] = ext_func
mod = LegalizeEthosU()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
# We are currently using copy_constants scheduler In the long run,
# this should be a single intelligent and a composite scheduler
Expand Down
72 changes: 72 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
"""A set of passes to legalize some of operations for the NPU"""
from typing import List, Type
import math

import numpy as np # type: ignore

Expand All @@ -31,6 +32,7 @@
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore
from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore
from tvm.relay.backend.contrib.ethosu import vela_api
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore


Expand Down Expand Up @@ -123,6 +125,75 @@ def __call__(self, *args, **kwargs):
pass


def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
"""Method to calculate the values of the tanh lookup table"""
lut_values = list()
# Only int8 is currently supported
dtype = np.int8
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = math.tanh(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

return lut_values


class TanhRewriter(DFPatternCallback):
"""This pass adds tanh as a LUT to the identity operator"""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
)(wildcard())

def callback(self, pre, post, node_map):
id_input = post.args[0]

quantize_args = post.op.body.args
output_scale = float(quantize_args[1].data.asnumpy())
output_zp = int(quantize_args[2].data.asnumpy())

dequantize_args = quantize_args[0].args[0].args
input_scale = float(dequantize_args[1].data.asnumpy())
input_zp = int(dequantize_args[2].data.asnumpy())

lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
lut = relay.const(lut_values, dtype="uint8")

# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
ifm=id_input,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation="TANH",
)

return identity


@ir.transform.module_pass(opt_level=1)
class LegalizeTanh:
"""This is the pass that wraps TanhRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(TanhRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""

Expand Down Expand Up @@ -915,6 +986,7 @@ def transform_module(
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeTanh()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Arm(R) Ethos(TM)-U NPU Relay operators."""
from tvm.ir import Attrs
import tvm._ffi


@tvm._ffi.register_object("relay.attrs.EthosuConv2DAttrs")
class EthosuConv2DAttrs(Attrs):
"""Attributes for contrib.ethosu.conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuIdentityAttrs")
class EthosuIdentityAttrs(Attrs):
"""Attributes for contrib.ethosu.identity."""


@tvm._ffi.register_object("relay.attrs.EthosuDepthwiseConv2DAttrs")
class EthosuDepthwiseConv2DAttrs(Attrs):
"""Attributes for contrib.ethosu.depthwise_conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
class EthosuPooling2DAttrs(Attrs):
"""Attributes for contrib.ethosu.pooling."""
9 changes: 8 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ def conv2d_compute(
"dilation_w": dilation_w,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
conv2d_attrs["lut"] = lut

conv = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.sum(
Expand All @@ -148,7 +155,7 @@ def conv2d_compute(
).astype(ifm.dtype)
* weight[cc, rh, rw, rc].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
axis=[rh, rw, rc],
),
name="ethosu_conv2d",
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def depthwise_conv2d_compute(
"dilation_w": dilation_w,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
depthwise_conv2d_attrs["lut"] = lut

depthwise = te.compute(
(1, ofm_height, ofm_width, channels),
lambda nn, hh, ww, cc: te.sum(
Expand All @@ -144,7 +151,7 @@ def depthwise_conv2d_compute(
).astype(ifm.dtype)
* weight[cc, rh, rw, 0].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
axis=[rh, rw],
),
name="ethosu_depthwise_conv2d",
Expand Down
13 changes: 10 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/te/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,21 @@ def identity_compute(
The Output Feature Map tensor.
"""

dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
id_attrs = {"op": "ethosu_identity", "activation": activation}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
id_attrs["lut"] = lut

identity = te.compute(
ifm.shape,
lambda *i: dmaed_ifm(*i).astype(ifm.dtype),
lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype),
name="ethosu_identity",
attrs={"op": "ethosu_identity", "activation": activation},
attrs=id_attrs,
)

dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale)
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,19 @@ def pooling_compute(
"upscale": upscale,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
pooling_attrs["lut"] = lut

pooling = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.max(
dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype),
(dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
ifm.dtype
),
axis=[rh, rw],
),
name="ethosu_pooling",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_conv2d_params(stmt, producers, consumers):
rh = inner
rw = rh.body
rc = rw.body
# loads = [output, input, weights, scale_bias, scale_bias]
# loads = [output, input, weights, scale_bias, scale_bias, LUT, LUT]
loads = get_loads(rc.body)
# stores = [output]
stores = get_stores(rc.body)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Dict, Tuple
import tvm
from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap
from .utils import get_op_attrs, get_base_address, get_strides
from .utils import get_op_attrs, get_base_address, get_strides, get_loads


def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]:
Expand Down Expand Up @@ -123,7 +123,10 @@ def get_identity_params(
while hasattr(stmt, "body"):
stmt = stmt.body

input_pointer = stmt.value.buffer_var
# loads = [input, LUT, LUT]
loads = get_loads(stmt)

input_pointer = loads[0].buffer_var
output_pointer = stmt.buffer_var

read = producers[input_pointer]
Expand Down
Loading

0 comments on commit 3a8d181

Please sign in to comment.