Skip to content

Commit

Permalink
[microNPU] Removing constant args from PrimFunc
Browse files Browse the repository at this point in the history
Before this commit, microNPU created the PrimFunc as if
it accepted constants from the caller. This commit
changes the PrimFunc to remove the constants from the
arguments of the PrimFunc, as they are not provided by
the main function.

Change-Id: If1fe2b8bcd9daf73ecabbb7930451de81e6f7e3b
  • Loading branch information
manupak committed Jan 24, 2022
1 parent 4e60749 commit 220f665
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 30 deletions.
14 changes: 6 additions & 8 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
# scratch memory size.
tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())

for idx in const_dict.keys():
const_dict[idx] = tvm.nd.array(const_dict[idx])
for param in const_dict.keys():
const_dict[param] = tvm.nd.array(const_dict[param])

primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
Expand Down Expand Up @@ -341,11 +341,9 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
tir_mod = tvm.IRModule()
tir_mod[symbol] = primfunc

const_dict_with_int_keys = dict()
for idx in const_dict.keys():
const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy()
const_dict_np = dict()
for buffer_var in const_dict.keys():
const_dict_np[buffer_var] = const_dict[buffer_var].numpy()

cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(
tir_mod, const_dict_with_int_keys
)
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np)
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod = tvm.tir.transform.StorageRewrite()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = ethosu_passes.AnnotateAllocates()(mod)
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
return mod, const_dict


Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,40 @@ def _ftransform(f, mod, ctx):
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.ethosu.remove_concatenates"
)


def CreatePrimFuncWithoutConstants(const_dict):
    """
    Build a pass that removes the constant arguments from the PrimFunc params.

    These should be replaced properly with tir.allocate_const when it
    becomes available. Because the params that represent constants are
    removed, the constant dictionary is re-keyed: instead of the param
    index, each constant is keyed by the data var of the buffer that its
    param was mapped to in the PrimFunc buffer map.

    Parameters
    ----------
    const_dict : dict
        Maps the index of a constant param of the PrimFunc to the constant.

    Returns
    -------
    callable
        A function that takes an IRModule and returns a tuple of the
        transformed IRModule and the re-keyed constant dictionary.
    """

    rekeyed_consts = {}

    def _ftransform(f, mod, ctx):
        # Key the constants by buffer var, as the PrimFunc params of the
        # constants are about to be removed.
        for const_idx, const_value in const_dict.items():
            const_buffer = f.buffer_map[f.params[const_idx]]
            rekeyed_consts[const_buffer.data] = const_value

        # Keep only the params (and their buffers) that are not constants.
        kept_params = [
            param for position, param in enumerate(f.params) if position not in const_dict
        ]
        kept_buffer_map = {param: f.buffer_map[param] for param in kept_params}
        return tvm.tir.PrimFunc(kept_params, f.body, f.ret_type, kept_buffer_map, f.attrs, f.span)

    def _create_primfunc_without_constants(mod):
        strip_constants = tvm.tir.transform.prim_func_pass(
            _ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants"
        )
        return strip_constants(mod), rekeyed_consts

    return _create_primfunc_without_constants
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def extract_buffer_info(
----------
mod : tvm.IRModule
The NPU TIR IRModule.
param_dict : Dict[int, np.ndarray]
param_dict : Dict[tvm.tir.Var, np.ndarray]
A dictionary containing buffer var --> const numpy.NDArray
Returns
Expand All @@ -222,8 +222,7 @@ def extract_buffer_info(
assert len(mod.functions.items()) == 1
primfunc = mod.functions.items()[0][1]

for idx, const_data in param_dict.items():
param = primfunc.params[idx]
for param, const_data in param_dict.items():
buffer_info[param] = BufferInfo(
const_data, const_data.shape, const_data.dtype, BufferType.constant
)
Expand Down Expand Up @@ -257,7 +256,6 @@ def populate_allocate_buffer_info(stmt):
)

tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info)

return buffer_info


Expand Down
55 changes: 37 additions & 18 deletions tests/python/contrib/test_ethosu/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,46 @@
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
from . import infra


def test_lower_to_tir():
data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8")
weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8")
p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32")
conv = relay.nn.conv2d(
data,
weight,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32",
)
tile = relay.tile(p2, reps=(1, 1, 1, 1001))
subtract = relay.subtract(conv, tile)
func = subtract
expr = relay.Function(relay.analysis.free_vars(func), func)
mod = tvm.IRModule.from_expr(expr)
def _create_single_conv2d():
    """Return a Relay function whose body is one ethosu conv2d applied to the int8 input "x"."""
    input_var = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
    body = infra.make_ethosu_conv2d(input_var, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    return relay.Function(relay.analysis.free_vars(body), body)


def _create_double_conv2d():
    """Return a Relay function that chains two ethosu conv2d ops on the int8 input "x"."""
    input_var = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
    first_conv = infra.make_ethosu_conv2d(input_var, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    # The second convolution consumes the output of the first one.
    body = infra.make_ethosu_conv2d(first_conv, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1))
    return relay.Function(relay.analysis.free_vars(body), body)


def _create_non_linear_conv2d():
    """Return a Relay function that adds the results of two ethosu conv2d ops.

    Each convolution reads its own int8 input ("x" and "y"), so the
    resulting graph has two branches that are joined by an "ADD"
    binary elementwise op.
    """
    ifm_shape = (1, 8, 8, 4)
    channels = ifm_shape[3]
    lhs_input = relay.var("x", shape=ifm_shape, dtype="int8")
    rhs_input = relay.var("y", shape=ifm_shape, dtype="int8")
    lhs_conv = infra.make_ethosu_conv2d(lhs_input, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    rhs_conv = infra.make_ethosu_conv2d(rhs_input, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
    body = infra.make_ethosu_binary_elementwise(
        lhs_conv, rhs_conv, channels, channels, "ADD", "int8"
    )
    return relay.Function(relay.analysis.free_vars(body), body)


@pytest.mark.parametrize(
"relay_function, arg_count",
[(_create_single_conv2d, 2), (_create_double_conv2d, 2), (_create_non_linear_conv2d, 3)],
)
def test_lower_to_tir_arg_count(relay_function, arg_count):
mod = tvm.IRModule()
mod["main"] = relay_function()
mod = relay.transform.InferType()(mod)
lower_to_tir(mod["main"])
tir_mod = lower_to_tir(mod["main"])[0]
primfunc = tir_mod["main"]
assert len(primfunc.params) == arg_count


if __name__ == "__main__":
Expand Down

0 comments on commit 220f665

Please sign in to comment.