[microNPU] Add the infrastructure for lookup table and TANH (#9547)
Some activation functions, such as TANH and SIGMOID, are implemented
by calculating the output values from the QNN parameters and
recording them into a lookup table (LUT).

This patch adds the LUT functionality along with the TANH
activation function and its tests.
ekalda authored Dec 1, 2021
1 parent 3047709 commit 9ada371
Showing 21 changed files with 832 additions and 42 deletions.
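
Conceptually, applying a LUT-based activation to int8 data is just a gather: each of the 256 possible input values maps to a precomputed output value, and the table itself is built from the quantization (QNN) parameters. A minimal NumPy sketch of the lookup step (illustrative only, not part of the patch; the table construction itself is shown by find_tanh_values in legalize.py below):

import numpy as np

def apply_lut(ifm: np.ndarray, lut: np.ndarray) -> np.ndarray:
    # ifm holds int8 activations; lut holds one precomputed output per
    # possible int8 input value, ordered from -128 up to 127.
    assert ifm.dtype == np.int8 and lut.shape == (256,)
    return lut[ifm.astype(np.int32) + 128]  # shift [-128, 127] -> [0, 255]
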
110 changes: 110 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -22,6 +22,115 @@
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr_functor import ExprMutator
from tvm.ir.transform import Pass

# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
from tvm.relay.backend.contrib.ethosu import op


class OptimizeLUTs(ExprMutator):
"""A pass to merge an identity operator with a LUT based activation function with
a preceding operator provided that operator can do a table lookup for the activation
in the hardware"""

def __init__(self):
super().__init__()
self.lut_ops = {
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
}

def create_op_with_lut(self, call):
"""Extract the parameters and attributes from the NPU operator and create
a new operator with LUT.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
The new operator with LUT.
"""
identity = call
ethosu_op = call.args[0]
lut = identity.args[1]
activation = identity.attrs.activation

new_attrs = dict(ethosu_op.attrs)
new_attrs["activation"] = activation

# Assume that LUT is always the last argument
new_args = ethosu_op.args[:-1] + [lut]
assert ethosu_op.op.name in self.lut_ops.keys()

return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)

def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
"""Recursively visit call nodes in the input graph and if an ethosu.identity
operator with LUT is found and the preceding operator has a LUT attribute, create
a new NPU operator.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
The input call node in the case the current call node does
not refer to an Op. Else, a new call node with a new operator.
"""
new_call = call
lut_activations = ["TANH", "LUT"]

if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
producer_op = call.args[0]
# Check if the producer can do a LUT operation
if (
producer_op.op.name in self.lut_ops.keys()
and call.op.name == "contrib.ethosu.identity"
and call.attrs.activation in lut_activations
):
# Check the producer doesn't already have a LUT
has_lut = producer_op.attrs.activation in lut_activations
if not has_lut:
new_call = self.create_op_with_lut(call)

new_call = super().visit_call(new_call)

return new_call


@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
class LUTsOptimizer(Pass):
"""Register LUTsOptimizer as a relay pass."""

def transform_function(
self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
) -> tvm.IRModule:
"""Visit relay nodes in the given module.
Parameters
----------
func : tvm.relay.function.Function
The function to apply the optimization pass for multiple LUTs to.
mod : tvm.IRModule
The module to apply the optimization pass for multiple LUTs to.
Returns
-------
mod : tvm.IRModule
New module with optimized LUTs.
"""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
return OptimizeLUTs().visit(func)


@tvm._ffi.register_func("relay.ext.ethos-u")
@@ -74,6 +183,7 @@ def _compile(ext_func):
mod = tvm.IRModule()
mod["main"] = ext_func
mod = LegalizeEthosU()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
# We are currently using the copy_constants scheduler. In the long run,
# this should be a single intelligent and composite scheduler
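
In essence, OptimizeLUTs walks the graph and, when it finds a contrib.ethosu.identity carrying a LUT-based activation whose producer can perform the lookup itself, rebuilds the producer call with the LUT and activation attached. A toy analogue of that visit_call merge pattern on plain Relay ops (a sketch only; FoldIdentityAdd is a made-up name and the real pass merges attributes into a new NPU operator rather than simply dropping a node):

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class FoldIdentityAdd(ExprMutator):
    """Fold `x + 0` into its producer, mimicking how OptimizeLUTs absorbs identity ops."""

    def visit_call(self, call):
        if (
            isinstance(call.op, tvm.ir.Op)
            and call.op.name == "add"
            and isinstance(call.args[1], relay.Constant)
            and call.args[1].data.asnumpy().item() == 0
        ):
            # Skip the redundant node and keep visiting its producer.
            return self.visit(call.args[0])
        return super().visit_call(call)

x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.add(relay.nn.relu(x), relay.const(0.0)))
print(FoldIdentityAdd().visit(func))
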
72 changes: 72 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
"""A set of passes to legalize some of operations for the NPU"""
from typing import List, Type
import math

import numpy as np # type: ignore

@@ -31,6 +32,7 @@
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore
from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore
from tvm.relay.backend.contrib.ethosu import vela_api
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore


@@ -123,6 +125,75 @@ def __call__(self, *args, **kwargs):
pass


def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
"""Method to calculate the values of the tanh lookup table"""
lut_values = list()
# Only int8 is currently supported
dtype = np.int8
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = math.tanh(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

return lut_values


class TanhRewriter(DFPatternCallback):
"""This pass adds tanh as a LUT to the identity operator"""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
)(wildcard())

def callback(self, pre, post, node_map):
id_input = post.args[0]

quantize_args = post.op.body.args
output_scale = float(quantize_args[1].data.asnumpy())
output_zp = int(quantize_args[2].data.asnumpy())

dequantize_args = quantize_args[0].args[0].args
input_scale = float(dequantize_args[1].data.asnumpy())
input_zp = int(dequantize_args[2].data.asnumpy())

lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
lut = relay.const(lut_values, dtype="uint8")

# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
ifm=id_input,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation="TANH",
)

return identity


@ir.transform.module_pass(opt_level=1)
class LegalizeTanh:
"""This is the pass that wraps TanhRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(TanhRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""

@@ -915,6 +986,7 @@ def transform_module(
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeTanh()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
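
As a rough usage sketch of the legalization step above (standalone, with illustrative quantization parameters rather than values from the patch, and with round_away_zero assumed to mirror util.round_away_zero), the generated table maps the input zero point to the output zero point, since tanh(0) = 0:

import math

def round_away_zero(x):
    # Assumed behaviour of util.round_away_zero: ties round away from zero.
    return math.floor(x + 0.5) if x >= 0 else math.ceil(x - 0.5)

def tanh_lut(ifm_scale, ifm_zp, ofm_scale, ofm_zp, qmin=-128, qmax=127):
    values = []
    for x in range(qmin, qmax + 1):
        out = math.tanh(ifm_scale * (x - ifm_zp))
        values.append(min(qmax, max(qmin, int(round_away_zero(ofm_zp + out / ofm_scale)))))
    return values

lut = tanh_lut(ifm_scale=0.02, ifm_zp=10, ofm_scale=1 / 128, ofm_zp=-1)
assert lut[10 + 128] == -1  # the input zero point maps to the output zero point
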
39 changes: 39 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Arm(R) Ethos(TM)-U NPU Relay operators."""
from tvm.ir import Attrs
import tvm._ffi


@tvm._ffi.register_object("relay.attrs.EthosuConv2DAttrs")
class EthosuConv2DAttrs(Attrs):
"""Attributes for contrib.ethosu.conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuIdentityAttrs")
class EthosuIdentityAttrs(Attrs):
"""Attributes for contrib.ethosu.identity."""


@tvm._ffi.register_object("relay.attrs.EthosuDepthwiseConv2DAttrs")
class EthosuDepthwiseConv2DAttrs(Attrs):
"""Attributes for contrib.ethosu.depthwise_conv2d."""


@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
class EthosuPooling2DAttrs(Attrs):
"""Attributes for contrib.ethosu.pooling."""
9 changes: 8 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -140,6 +140,13 @@ def conv2d_compute(
"dilation_w": dilation_w,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
conv2d_attrs["lut"] = lut

conv = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.sum(
@@ -148,7 +155,7 @@
).astype(ifm.dtype)
* weight[cc, rh, rw, rc].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
axis=[rh, rw, rc],
),
name="ethosu_conv2d",
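
The lut_expr term above exists only to create a data dependency on the LUT tensor so that it survives scheduling and shows up in the lowered TIR; the addition is not meaningful arithmetic, since the actual lookup is performed by the NPU. A standalone TE sketch of the same trick (shapes and names are illustrative, assuming the tvm.te API of this TVM generation):

import tvm
from tvm import te

ifm = te.placeholder((1, 8, 8, 4), dtype="int8", name="ifm")
lut = te.placeholder((256,), dtype="uint8", name="lut")

# Reading two LUT elements ties the LUT buffer into the compute expression,
# so it is kept as an input of the lowered function even though its value
# does not matter for the final command stream.
lut_expr = (lut[0] + lut[255]).astype("int8")
out = te.compute(ifm.shape, lambda *i: ifm(*i) + lut_expr, name="identity_with_lut")

sched = te.create_schedule(out.op)
print(tvm.lower(sched, [ifm, lut, out]))
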
9 changes: 8 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -136,6 +136,13 @@ def depthwise_conv2d_compute(
"dilation_w": dilation_w,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
depthwise_conv2d_attrs["lut"] = lut

depthwise = te.compute(
(1, ofm_height, ofm_width, channels),
lambda nn, hh, ww, cc: te.sum(
@@ -144,7 +151,7 @@
).astype(ifm.dtype)
* weight[cc, rh, rw, 0].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
axis=[rh, rw],
),
name="ethosu_depthwise_conv2d",
13 changes: 10 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/te/identity.py
@@ -58,14 +58,21 @@ def identity_compute(
The Output Feature Map tensor.
"""

dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
id_attrs = {"op": "ethosu_identity", "activation": activation}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
id_attrs["lut"] = lut

identity = te.compute(
ifm.shape,
lambda *i: dmaed_ifm(*i).astype(ifm.dtype),
lambda *i: (dmaed_ifm(*i) + lut_expr).astype(ifm.dtype),
name="ethosu_identity",
attrs={"op": "ethosu_identity", "activation": activation},
attrs=id_attrs,
)

dmaed_ofm = write_compute(identity, ofm_zero_point, ofm_scale)
11 changes: 10 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -123,10 +123,19 @@ def pooling_compute(
"upscale": upscale,
}

# This is a trick to insert the LUT tensor into the TE graph if LUT is present
lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0

# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
if activation in ("TANH", "LUT"):
pooling_attrs["lut"] = lut

pooling = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.max(
dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype),
(dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
ifm.dtype
),
axis=[rh, rw],
),
name="ethosu_pooling",
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
@@ -53,7 +53,7 @@ def get_conv2d_params(stmt, producers, consumers):
rh = inner
rw = rh.body
rc = rw.body
# loads = [output, input, weights, scale_bias, scale_bias]
# loads = [output, input, weights, scale_bias, scale_bias, LUT, LUT]
loads = get_loads(rc.body)
# stores = [output]
stores = get_stores(rc.body)
7 changes: 5 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/identity.py
@@ -19,7 +19,7 @@
from typing import Dict, Tuple
import tvm
from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap
from .utils import get_op_attrs, get_base_address, get_strides
from .utils import get_op_attrs, get_base_address, get_strides, get_loads


def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]:
@@ -123,7 +123,10 @@ def get_identity_params(
while hasattr(stmt, "body"):
stmt = stmt.body

input_pointer = stmt.value.buffer_var
# loads = [input, LUT, LUT]
loads = get_loads(stmt)

input_pointer = loads[0].buffer_var
output_pointer = stmt.buffer_var

read = producers[input_pointer]