From 220f6658c9d2792bc0a9d0f5fa8115b6f6cea54e Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 14 Jan 2022 18:59:02 +0000 Subject: [PATCH] [microNPU] Removing constant args from PrimFunc Before this commit, microNPU created PrimFunc as if it accepts constants from the caller. This commit changes the PrimFunc to remove the constants as an argument to PrimFunc as they are not provided by the main function. Change-Id: If1fe2b8bcd9daf73ecabbb7930451de81e6f7e3b --- .../relay/backend/contrib/ethosu/codegen.py | 14 ++--- .../backend/contrib/ethosu/tir/compiler.py | 1 + .../backend/contrib/ethosu/tir/passes.py | 37 +++++++++++++ .../contrib/ethosu/tir_to_cs_translator.py | 6 +- .../contrib/test_ethosu/test_compiler.py | 55 +++++++++++++------ 5 files changed, 83 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 98ee41f428b2f..54312f6c8d6fe 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -309,8 +309,8 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: # scratch memory size. 
tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants()) - for idx in const_dict.keys(): - const_dict[idx] = tvm.nd.array(const_dict[idx]) + for param in const_dict.keys(): + const_dict[param] = tvm.nd.array(const_dict[param]) primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) @@ -341,11 +341,9 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact tir_mod = tvm.IRModule() tir_mod[symbol] = primfunc - const_dict_with_int_keys = dict() - for idx in const_dict.keys(): - const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() + const_dict_np = dict() + for buffer_var in const_dict.keys(): + const_dict_np[buffer_var] = const_dict[buffer_var].numpy() - cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate( - tir_mod, const_dict_with_int_keys - ) + cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np) return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index bcd785ddbbd8d..ee35da4cab61b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -90,6 +90,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) mod = ethosu_passes.AnnotateAllocates()(mod) + mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod) return mod, const_dict diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index fbc9bf3ff41c6..c2fff8abb9b04 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -687,3 +687,40 @@ def _ftransform(f, mod, ctx): return 
tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.ethosu.remove_concatenates" ) + + +def CreatePrimFuncWithoutConstants(const_dict): + """ + This pass will remove arguments that are constants + from PrimFunc Args. These should be replaced properly + with tir.allocate_const when it becomes available. + + It also modifies the constant dictionary to + rewrite the keys as the buffer vars (tir.Var) of the constants + rather than the param index because this pass removes PrimFunc + arguments that represent constants. + """ + + new_const_dict = dict() + + def _ftransform(f, mod, ctx): + new_params = list() + new_buffer_map = dict() + for param_idx in const_dict.keys(): + # We are using buffer_var to key the constants as + # PrimFunc params of constants will be removed. + new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx] + for i in range(len(f.params)): + if i not in const_dict.keys(): + new_params.append(f.params[i]) + new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] + return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.attrs, f.span) + + def _create_primfunc_without_constants(mod): + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants" + ) + mod = transform_func(mod) + return mod, new_const_dict + + return _create_primfunc_without_constants diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index d7254511ebfc5..ecea6eb28f098 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -208,7 +208,7 @@ def extract_buffer_info( ---------- mod : tvm.IRModule The NPU TIR IRModule. 
- param_dict : Dict[int, np.ndarray] + param_dict : Dict[tvm.tir.Var, np.ndarray] A dictionary containing param idx --> const numpy.NDArray Returns @@ -222,8 +222,7 @@ def extract_buffer_info( assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] - for idx, const_data in param_dict.items(): - param = primfunc.params[idx] + for param, const_data in param_dict.items(): buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) @@ -257,7 +256,6 @@ def populate_allocate_buffer_info(stmt): ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) - return buffer_info diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py index e1688b8aa512e..0e31be86becba 100644 --- a/tests/python/contrib/test_ethosu/test_compiler.py +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -20,27 +20,46 @@ import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from . 
import infra -def test_lower_to_tir(): - data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") - weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") - p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") - conv = relay.nn.conv2d( - data, - weight, - kernel_size=(1, 1), - data_layout="NHWC", - kernel_layout="HWIO", - out_dtype="int32", - ) - tile = relay.tile(p2, reps=(1, 1, 1, 1001)) - subtract = relay.subtract(conv, tile) - func = subtract - expr = relay.Function(relay.analysis.free_vars(func), func) - mod = tvm.IRModule.from_expr(expr) +def _create_single_conv2d(): + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv1), conv1) + return func + + +def _create_double_conv2d(): + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + conv2 = infra.make_ethosu_conv2d(conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + return func + + +def _create_non_linear_conv2d(): + shape = (1, 8, 8, 4) + ifm1 = relay.var("x", shape=shape, dtype="int8") + ifm2 = relay.var("y", shape=shape, dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm1, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + conv2 = infra.make_ethosu_conv2d(ifm2, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + add = infra.make_ethosu_binary_elementwise(conv1, conv2, shape[3], shape[3], "ADD", "int8") + func = relay.Function(relay.analysis.free_vars(add), add) + return func + + +@pytest.mark.parametrize( + "relay_function, arg_count", + [(_create_single_conv2d, 2), (_create_double_conv2d, 2), (_create_non_linear_conv2d, 3)], +) +def test_lower_to_tir_arg_count(relay_function, arg_count): + mod = tvm.IRModule() + mod["main"] = relay_function() mod = relay.transform.InferType()(mod) - lower_to_tir(mod["main"]) + tir_mod = 
lower_to_tir(mod["main"])[0] + primfunc = tir_mod["main"] + assert len(primfunc.params) == arg_count if __name__ == "__main__":