[microNPU] Fix output mismatch in Leaky ReLU (apache#11397)
* [microNPU] Fix output mismatch in Leaky ReLU

All codegen tests have been running with a representative dataset in the
range (0, 1), which masked an output mismatch in Leaky ReLU when compared
against the TFLite reference kernels. The issue can be reproduced by
changing the representative dataset range to something like (-1, 1).

To fix the mismatch, we calculate the LUT values with the same
implementation as Vela, which constrains the arithmetic to quantized
values, rather than with the floating-point calculations used previously.

Change-Id: I0ed52215acd27722873be609271971b6fc4aaef1

* fix lint

Change-Id: Ica7de0c000ee015e79fe10985b2ec7a9b341861f

* fix lint again

Change-Id: I005d90ad248bfff7090f99d161eefbdc962cba48
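
Note: why a (0, 1) representative dataset hides the bug. With asymmetric int8 quantization, the calibration range determines the zero point, and a strictly non-negative range pins the zero point at qmin, so every quantized input decodes to a non-negative real value and the alpha branch of Leaky ReLU is never exercised. The sketch below is a simplified version of the TFLite calibration scheme, not the actual calibration code:

import numpy as np

def quant_params(real_min, real_max, dtype=np.int8):
    # Nudge the range to include 0.0, then derive scale and zero point.
    qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
    real_min, real_max = min(real_min, 0.0), max(real_max, 0.0)
    scale = (real_max - real_min) / (qmax - qmin)
    zero_point = int(round(qmin - real_min / scale))
    return scale, zero_point

# Range (0, 1): the zero point is -128, so x_real = scale * (x - zp) >= 0
# for every int8 code and the LUT's alpha entries are never used.
print(quant_params(0.0, 1.0))   # (~0.0039, -128)

# Range (-1, 1): about half of the int8 codes decode to negative reals,
# so the alpha branch is actually tested.
print(quant_params(-1.0, 1.0))  # (~0.0078, 0)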
lhutton1 authored Jun 6, 2022
1 parent b555bf5 commit 609d6af
Showing 2 changed files with 62 additions and 32 deletions.
88 changes: 57 additions & 31 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -16,10 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
 """A set of passes to legalize some of operations for the NPU"""
-from typing import List, Type, Callable, Any, Dict
+from typing import List, Type, Callable
 import math
 
 import numpy as np  # type: ignore
+from ethosu.vela import scaling, fp_math
 
 import tvm  # type: ignore
 from tvm import relay
@@ -132,7 +133,6 @@ def get_lut_from_func(
     ofm_scale: float,
     ofm_zp: int,
     func: Callable[[float], float],
-    func_params: Dict[str, Any],
 ) -> List[int]:
     """Calculates the values of the lookup table based on the calculation function"""
 
@@ -142,7 +142,7 @@
     qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
     for x in range(qmin, qmax + 1):
         x_real = ifm_scale * (x - ifm_zp)
-        out_real = func(x_real, **func_params)
+        out_real = func(x_real)
         lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
         lut_result = min(qmax, max(qmin, lut_result))
         lut_values.append(lut_result)
@@ -165,29 +165,10 @@ def __init__(
         self.activation_type = activation_type
         self.calc_func = calc_func
 
-    def get_calc_func_params(self, expr: tvm.relay.Expr) -> Dict[str, Any]:
-        """
-        Overridable method that can be used to extract additional arguments
-        for passing to calc_func.
-        Parameters
-        ----------
-        expr : tvm.relay.Expr
-            The matched composite activation function.
-        Returns
-        -------
-        Dict[str, Any]
-            Maps argument name to argument value.
-        """
-        return {}
-
     def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
         params = self.params_class(post.op.body)
         params.ifm.tensor = post.args[0]
 
-        calc_func_params = self.get_calc_func_params(post.op)
-
         input_scale = float(params.ifm.q_params.scale_f32)
         input_zp = int(params.ifm.q_params.zero_point)
         output_scale = float(params.ofm.q_params.scale_f32)
@@ -199,7 +180,6 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
             output_scale,
             output_zp,
             self.calc_func,
-            calc_func_params,
         )
         lut = relay.const(lut_values, dtype=params.ifm.dtype)
 
@@ -257,19 +237,65 @@ def leaky_relu_calc_func(x: float, alpha: float) -> float:
     return x if x >= 0 else x * alpha
 
 
-class LeakyReLURewriter(LutActivationRewriter):
+class LeakyReLURewriter(DFPatternCallback):
     """This pass adds leaky relu as a LUT for identity op."""
 
     def __init__(self):
-        super().__init__(
-            params_class=ethosu_patterns.LeakyReLUParams,
-            activation_type="LUT",
-            calc_func=leaky_relu_calc_func,
+        super().__init__(require_type=True, rewrite_once=True)
+        self.params_class = ethosu_patterns.LeakyReLUParams
+        self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
+            wildcard()
         )
 
-    def get_calc_func_params(self, expr: tvm.relay.Expr) -> Dict[str, Any]:
-        params = ethosu_patterns.LeakyReLUParams(expr.body)
-        return {"alpha": params.alpha}
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        input_scale = np.double(float(params.ifm.q_params.scale_f32))
+        input_zp = int(params.ifm.q_params.zero_point)
+        output_scale = np.double(float(params.ofm.q_params.scale_f32))
+        output_zp = int(params.ofm.q_params.zero_point)
+
+        alpha = params.alpha
+
+        # The calculation of the LUT values is similar to that in Vela
+        # convert_lrelu_to_lut(op, arch)
+        # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#864)  # pylint: disable=line-too-long
+        alpha_scalar = 1
+        alpha_scale, alpha_shift = scaling.elementwise_mul_scale(input_scale, alpha, output_scale)
+        identity_scale, identity_shift = scaling.elementwise_mul_scale(input_scale, 1, output_scale)
+
+        dtype = params.ifm.dtype
+        qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+
+        def calculate_lut_value(i):
+            zp_shift = (
+                fp_math.multiply_by_quantized_multiplier(
+                    alpha_scalar * (i - input_zp), alpha_scale, alpha_shift
+                )
+                if i < input_zp
+                else fp_math.multiply_by_quantized_multiplier(
+                    i - input_zp, identity_scale, identity_shift
+                )
+            )
+
+            return min(qmax, max(qmin, output_zp + zp_shift))
+
+        values = list(map(calculate_lut_value, range(qmin, qmax + 1)))
+        lut = relay.const(values, dtype=dtype)
+
+        # We baked the requantization into the LUT, so we don't requantize the identity operator
+        identity = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="LUT",
+        )
+
+        return identity
 
 
 class Conv2DRewriter(DFPatternCallback):
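Note: the fixed-point helpers used above (scaling.elementwise_mul_scale and fp_math.multiply_by_quantized_multiplier) follow the gemmlowp convention of encoding a real scale as a normalized int32 multiplier plus a power-of-two shift. The pure-Python model below is an illustrative sketch of that convention, not Vela's code: quantise_scale is a hypothetical stand-in for Vela's scale-encoding helpers, saturation of the int32 overflow case is omitted, and rounding details may differ by one LSB from the reference implementations.

import math

def quantise_scale(s: float):
    # Encode a positive real scale s as s ~= (multiplier / 2**31) * 2**-shift,
    # with the int32 multiplier normalised into [2**30, 2**31).
    # (Overflow nudge for mantissa ~= 1.0 omitted for brevity.)
    mantissa, exponent = math.frexp(s)  # s = mantissa * 2**exponent, mantissa in [0.5, 1)
    return int(round(mantissa * (1 << 31))), -exponent

def rounding_doubling_high_mul(a: int, b: int) -> int:
    # Rounded (a * b) / 2**31, the gemmlowp primitive behind quantized
    # multipliers (the real implementations also saturate on overflow).
    ab = a * b
    nudge = (1 << 30) if ab >= 0 else 1 - (1 << 30)
    val = ab + nudge
    q = abs(val) >> 31          # truncate toward zero, as the C++ reference does
    return q if val >= 0 else -q

def rounding_right_shift(x: int, shift: int) -> int:
    # Arithmetic right shift with round-to-nearest.
    if shift <= 0:
        return x << -shift
    mask = (1 << shift) - 1
    remainder = x & mask
    threshold = (mask >> 1) + (1 if x < 0 else 0)
    return (x >> shift) + (1 if remainder > threshold else 0)

def multiply_by_quantized_multiplier(x: int, multiplier: int, shift: int) -> int:
    return rounding_right_shift(rounding_doubling_high_mul(x, multiplier), shift)

# Multiplying by a real scale of 0.35 in pure integer arithmetic:
multiplier, shift = quantise_scale(0.35)
print(multiply_by_quantized_multiplier(100, multiplier, shift))  # 35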
6 changes: 5 additions & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
@@ -1022,7 +1022,11 @@ def leaky_relu_func(x):
         return tf.nn.leaky_relu(x, alpha=alpha)
 
     infra.compare_tvm_with_tflite(
-        leaky_relu_func, [ifm_shape], accel_type, enable_cascader=is_u55_accel_type(accel_type)
+        leaky_relu_func,
+        [ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+        ranges=[(-1, 1)],
     )
 
 
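Note: the new ranges argument pins the representative dataset to (-1, 1) so the test actually exercises negative inputs. Assuming the test infrastructure uses each (min, max) pair to generate random calibration data per input, a hypothetical sketch of such a helper (not the actual infra code) looks like:

import numpy as np

def representative_dataset(shapes, ranges, samples=100):
    # One (low, high) pair per input tensor; yields float32 calibration
    # batches in the requested ranges for the TFLite converter.
    for _ in range(samples):
        yield [
            np.random.uniform(low, high, size=shape).astype(np.float32)
            for shape, (low, high) in zip(shapes, ranges)
        ]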
