Responding to the reviews
ekalda committed Nov 29, 2021
1 parent 7678745 commit 46b4602
Showing 3 changed files with 15 additions and 15 deletions.
20 changes: 13 additions & 7 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -46,9 +46,12 @@ def __init__(self):
def create_op_with_lut(self, call):
"""Extract the parameters and attributes from the NPU operator and create
a new operator with LUT.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
@@ -63,8 +66,7 @@ def create_op_with_lut(self, call):
new_attrs["activation"] = activation

# Assume that LUT is always the last argument
- new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)]
- new_args.append(lut)
+ new_args = ethosu_op.args[:-1] + [lut]
assert ethosu_op.op.name in self.lut_ops.keys()

return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)
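
The rewrite above relies on the stated assumption that the LUT is always the last argument: the slice drops the old LUT and the concatenation appends the new one. A minimal, self-contained illustration with plain Python lists standing in for the relay call arguments (the argument names are invented for the example):

# Plain-Python stand-in for ethosu_op.args; the names are illustrative only.
old_args = ["ifm", "weights", "scale_bias", "old_lut"]
lut = "new_lut"

# Old form: copy everything but the last argument, then append the new LUT.
new_args_loop = [old_args[n] for n in range(len(old_args) - 1)]
new_args_loop.append(lut)

# New form: the same result in a single expression.
new_args = old_args[:-1] + [lut]
assert new_args == new_args_loop == ["ifm", "weights", "scale_bias", "new_lut"]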
@@ -73,10 +75,12 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
"""Recursively visit call nodes in the input graph and if an ethosu.identity
operator with LUT is found and the preceding operator has a LUT attribute, create
a new NPU operator.
Parameters
----------
call : tvm.relay.expr.Call
The current call node being visited.
Returns
-------
tvm.relay.expr.Call
@@ -104,24 +108,26 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
return new_call
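
The body of visit_call is largely elided in this hunk, but the docstring describes the rewrite: when an ethosu identity operator carrying a LUT follows an operator that can accept a LUT, the pair is collapsed into a single NPU operator. A rough sketch of that dispatch shape is given below; the operator names, the lut_ops mapping and the rebuild callback are assumptions for illustration, not the actual pass body:

from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class IdentityLutFolder(ExprMutator):
    """Illustrative sketch: fold an identity-with-LUT into its producer."""

    def __init__(self, lut_ops, rebuild_with_lut):
        super().__init__()
        self.lut_ops = lut_ops                    # hypothetical {op name: constructor} mapping
        self.rebuild_with_lut = rebuild_with_lut  # hypothetical callback, like create_op_with_lut

    def visit_call(self, call: relay.expr.Call) -> relay.expr.Call:
        new_call = super().visit_call(call)  # rewrite the children first
        op_name = getattr(new_call.op, "name", None)
        # Hypothetical check: an identity whose first argument is produced by
        # an NPU operator that can absorb the LUT as its activation.
        if op_name == "contrib.ethosu.ethosu_identity":
            producer = new_call.args[0]
            if isinstance(producer, relay.Call) and getattr(producer.op, "name", None) in self.lut_ops:
                return self.rebuild_with_lut(new_call)
        return new_call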


- @relay.transform.function_pass(opt_level=1, name="LutOptimizer")
+ @relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
class LUTsOptimizer(Pass):
"""Register LutOptimizer as a relay pass."""
"""Register LUTsOptimizer as a relay pass."""

def transform_function(
self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
) -> tvm.IRModule:
"""Visit relay nodes in the given module.
Parameters
----------
func : tvm.relay.function.Function
- The function to apply the layout optimization pass to.
+ The function to apply the optimization pass for multiple LUTs to.
mod : tvm.IRModule
- The module to apply the layout optimization pass to.
+ The module to apply the optimization pass for multiple LUTs to.
Returns
-------
mod : tvm.IRModule
- New module with augmented layouts.
+ New module with optimized LUTs.
"""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
return OptimizeLUTs().visit(func)
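
Because LUTsOptimizer is registered as a relay function pass, it can be applied to a module like any other pass. A hypothetical usage sketch, assuming the Ethos-U codegen and its dependencies are importable (the input module here is a placeholder rather than anything from this commit):

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer

# Placeholder module; a real one would come out of the Ethos-U partitioning flow.
x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
mod = tvm.IRModule.from_expr(relay.Function([x], x))

mod = LUTsOptimizer()(mod)  # apply the pass directly...

# ...or compose it with other passes.
with tvm.transform.PassContext(opt_level=3):
    mod = tvm.transform.Sequential([LUTsOptimizer()])(mod)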
8 changes: 2 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -32,6 +32,7 @@
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore
from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore
from tvm.relay.backend.contrib.ethosu import vela_api
+ from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore


@@ -124,11 +125,6 @@ def __call__(self, *args, **kwargs):
pass


- def round_away_zero(f):
- r = -0.5 if (f < 0) else 0.5
- return np.trunc(f + r)


def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
"""Method to calculate the values of the tanh lookup table"""
lut_values = list()
@@ -138,7 +134,7 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = math.tanh(x_real)
- lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale))
+ lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

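The loop above builds the lookup table one quantized value at a time: dequantize with the input scale and zero point, apply tanh, requantize with the output parameters, round away from zero and clamp to the quantized range. A small worked example for a single entry (the quantization parameters and the int8 range are assumptions for illustration; round_away_zero is reproduced from the helper this commit moves into the shared util module):

import math
import numpy as np

def round_away_zero(f):
    # Same helper as removed above, now expected under util.round_away_zero.
    r = -0.5 if (f < 0) else 0.5
    return np.trunc(f + r)

# Hypothetical int8 quantization parameters, purely for illustration.
ifm_scale, ifm_zp = 0.02, 0
ofm_scale, ofm_zp = 1.0 / 128, 0
qmin, qmax = -128, 127

x = 100                                   # one quantized input value
x_real = ifm_scale * (x - ifm_zp)         # dequantize -> 2.0
out_real = math.tanh(x_real)              # ~0.9640
lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
print(lut_result)                         # 123 with these parameters

Note that rounding away from zero differs from Python's built-in round, which rounds halves to even: round_away_zero(2.5) gives 3.0, while round(2.5) gives 2.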
2 changes: 0 additions & 2 deletions tests/python/contrib/test_ethosu/test_lookup_table.py
@@ -40,8 +40,6 @@ def test_tflite_lut_activations(accel_type):
ifm_shape = (1, 55, 55, 3)

def create_tflite_graph():
- tf.config.run_functions_eagerly(True)

class Model(tf.Module):
@tf.function
def tf_func(self, x):
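The rest of the test is elided here. For context, a minimal sketch of how such a TFLite graph with a LUT-backed activation might be built and converted; the choice of tanh and the float32 input are assumptions, the shape comes from ifm_shape above, and a representative dataset would still be needed for full integer quantization:

import tensorflow as tf

ifm_shape = (1, 55, 55, 3)

class Model(tf.Module):
    @tf.function
    def tf_func(self, x):
        return tf.math.tanh(x)  # an activation the NPU implements via a LUT

concrete_func = Model().tf_func.get_concrete_function(
    tf.TensorSpec(ifm_shape, dtype=tf.float32)
)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = ...   # required for full int8 quantization
tflite_model = converter.convert()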
