From 46b460225e01bb58e3ed75346f0e91b217d307bd Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 26 Nov 2021 11:35:08 +0000 Subject: [PATCH] Responding to the reviews --- .../relay/backend/contrib/ethosu/codegen.py | 20 ++++++++++++------- .../relay/backend/contrib/ethosu/legalize.py | 8 ++------ .../contrib/test_ethosu/test_lookup_table.py | 2 -- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 8f193d48ad6b..1f331822e1ac 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -46,9 +46,12 @@ def __init__(self): def create_op_with_lut(self, call): """Extract the parameters and attributes from the NPU operator and create a new operator with LUT. + + Parameters ---------- call : tvm.relay.expr.Call The current call node being visited. + Returns ------- tvm.relay.expr.Call @@ -63,8 +66,7 @@ def create_op_with_lut(self, call): new_attrs["activation"] = activation # Assume that LUT is always the last argument - new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)] - new_args.append(lut) + new_args = ethosu_op.args[:-1] + [lut] assert ethosu_op.op.name in self.lut_ops.keys() return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs) @@ -73,10 +75,12 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: """Recursively visit call nodes in the input graph and if an ethosu.identity operator with LUT is found and the preceding operator has a LUT attribute, create a new NPU operator. + Parameters ---------- call : tvm.relay.expr.Call The current call node being visited. + Returns ------- tvm.relay.expr.Call @@ -104,24 +108,26 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@relay.transform.function_pass(opt_level=1, name="LutOptimizer") +@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") class LUTsOptimizer(Pass): - """Register LutOptimizer as a relay pass.""" + """Register LUTsOptimizer as a relay pass.""" def transform_function( self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ ) -> tvm.IRModule: """Visit relay nodes in the given module. + Parameters ---------- func : tvm.relay.function.Function - The function to apply the layout optimization pass to. + The function to apply the optimization pass for multiple LUTs to. mod : tvm.IRModule - The module to apply the layout optimization pass to. + The module to apply the optimization pass for multiple LUTs to. + Returns ------- mod : tvm.IRModule - New module with augmented layouts. + New module with optimized LUTs. """ assert len(mod.functions.items()) == 1, "Module can only contain one function." return OptimizeLUTs().visit(func) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 3d4f8b71cfbb..5613d613f984 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -32,6 +32,7 @@ from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu import util from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore @@ -124,11 +125,6 @@ def __call__(self, *args, **kwargs): pass -def round_away_zero(f): - r = -0.5 if (f < 0) else 0.5 - return np.trunc(f + r) - - def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp): """Method to calculate the values of the tanh lookup table""" lut_values = list() @@ -138,7 +134,7 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp): for x in range(qmin, qmax + 1): x_real = ifm_scale * (x - ifm_zp) out_real = math.tanh(x_real) - lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale)) + lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale)) lut_result = min(qmax, max(qmin, lut_result)) lut_values.append(lut_result) diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py index 67870bd6472e..d32b441fd2eb 100644 --- a/tests/python/contrib/test_ethosu/test_lookup_table.py +++ b/tests/python/contrib/test_ethosu/test_lookup_table.py @@ -40,8 +40,6 @@ def test_tflite_lut_activations(accel_type): ifm_shape = (1, 55, 55, 3) def create_tflite_graph(): - tf.config.run_functions_eagerly(True) - class Model(tf.Module): @tf.function def tf_func(self, x):