[microNPU] Move optimization passes to be a module pass and ensure they are running (#9831)

Moves the LayoutOptimizer and LUTsOptimizer passes from function passes
to module passes, since it was found that these passes were not running
in the NPU compilation flow. In addition, a test has been added for both
LayoutOptimizer and LUTsOptimizer to check that the passes run in the
NPU compilation pipeline.

Change-Id: I5145c6f02eeb0daea3cdba56198e0804ec32f351
lhutton1 authored Jan 20, 2022
1 parent 589fc01 commit e390d9e
Showing 3 changed files with 114 additions and 38 deletions.
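
For context, here is a minimal sketch of the two pass-registration styles involved in this
change, based on the TVM APIs used in the diff below. The commit converts LUTsOptimizer and
LayoutOptimizer from the first form to the second so they can be applied directly to an
IRModule; the pass names and no-op bodies here are illustrative only, not part of the commit:

    from tvm import ir, relay


    @relay.transform.function_pass(opt_level=1, name="ExampleFunctionPass")
    class ExampleFunctionPass:
        """Function pass: transform_function is invoked per Relay function."""

        def transform_function(self, func, mod, ctx):
            # Illustrative no-op: return the function unchanged.
            return func


    @ir.transform.module_pass(opt_level=1, name="ExampleModulePass")
    class ExampleModulePass:
        """Module pass: transform_module is invoked once on the whole IRModule."""

        def transform_module(self, mod, ctx):
            # Illustrative no-op: return the module unchanged.
            return mod
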
34 changes: 21 additions & 13 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -18,13 +18,13 @@

import tvm
from tvm import relay
+from tvm import ir
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr_functor import ExprMutator
-from tvm.ir.transform import Pass

# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -109,13 +109,11 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
return new_call


-@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
-class LUTsOptimizer(Pass):
+@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer")
+class LUTsOptimizer:
"""Register LUTsOptimizer as a relay pass."""

-    def transform_function(
-        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
-    ) -> tvm.IRModule:
+    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
"""Visit relay nodes in the given module.
Parameters
@@ -131,7 +129,13 @@ def transform_function(
New module with optimized LUTs.
"""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
-        return OptimizeLUTs().visit(func)
+        global_var, func = mod.functions.items()[0]
+        optimized_func = OptimizeLUTs().visit(func)
+        mod.update_func(global_var, optimized_func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass


class LayoutOptimization(ExprMutator):
@@ -247,19 +251,23 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
return super().visit_call(call)


-@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
-class LayoutOptimizer(Pass):
+@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
+class LayoutOptimizer:
"""Register LayoutOptimizer as a Relay pass."""

-    def transform_function(
-        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
-    ) -> tvm.IRModule:
+    def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
"""A pass to optimize the layout of NPU operations. If both the
producer and consumer of a tensor are NPU operators, then the
layout is converted from NHWC to NHCWB16 as this is the layout NPU
uses internally."""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
-        return LayoutOptimization().visit(func)
+        global_var, func = mod.functions.items()[0]
+        optimized_func = LayoutOptimization().visit(func)
+        mod.update_func(global_var, optimized_func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass


@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
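
As a usage sketch (mirroring how the updated tests below drive the passes), the converted
passes are now applied directly to an IRModule. This assumes the Ethos-U components of TVM
(e.g. ethosu.vela) are available; the toy module here is illustrative only:

    import tvm
    from tvm import relay
    from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer, LUTsOptimizer

    # Toy single-function module; in practice this would contain NPU operators.
    x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8")
    mod = tvm.IRModule.from_expr(relay.Function([x], x))

    mod = LUTsOptimizer()(mod)              # apply the LUT optimization (OptimizeLUTs)
    mod = LayoutOptimizer()(mod)            # convert NHWC to NHCWB16 between NPU operators
    mod = relay.transform.InferType()(mod)  # re-run type inference afterwards
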
80 changes: 55 additions & 25 deletions tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -33,14 +33,17 @@
from tvm import relay
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func

from . import infra


-def _run_pass(expr, relay_pass):
-    """Create IRModule and run Relay pass."""
+def _optimize(expr, optimize=True):
+    """Create IRModule and run layout optimizer pass."""
mod = tvm.IRModule.from_expr(expr)
-    mod = relay_pass(mod)
+    mod = relay.transform.InferType()(mod)
+    if optimize:
+        mod = LayoutOptimizer()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body

@@ -111,8 +114,8 @@ def get_graph():
)
return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
_assert_structural_equal(a, b)


@@ -144,8 +147,8 @@ def get_graph(get_expected=False):
)
return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -176,8 +179,8 @@ def get_graph(get_expected=False):
)
return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -222,8 +225,8 @@ def get_graph():
)
return relay.Function(relay.analysis.free_vars(conv_2), conv_2)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
_assert_structural_equal(a, b)


@@ -268,8 +271,8 @@ def get_graph():
)
return relay.Function(relay.analysis.free_vars(conv_2), conv_2)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
_assert_structural_equal(a, b)


@@ -322,8 +325,8 @@ def get_graph():
)
return relay.Function(relay.analysis.free_vars(pool_3), pool_3)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
_assert_structural_equal(a, b)


@@ -368,8 +371,8 @@ def get_graph():
)
return relay.Function(relay.analysis.free_vars(conv), conv)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(), optimize=False)
_assert_structural_equal(a, b)


@@ -413,8 +416,8 @@ def get_graph(get_expected=False):
concat = relay.concatenate(poolings, axis=0)
return relay.Function(relay.analysis.free_vars(concat), concat)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -467,8 +470,8 @@ def get_graph(get_expected=False):
)
return relay.Function(relay.analysis.free_vars(add_3), add_3)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -500,8 +503,8 @@ def get_graph(get_expected=False):
)
return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -530,8 +533,8 @@ def get_graph(get_expected=False):
)
return relay.Function(relay.analysis.free_vars(x), x)

-    a = _run_pass(get_graph(), LayoutOptimizer())
-    b = _run_pass(get_graph(get_expected=True), relay.transform.InferType())
+    a = _optimize(get_graph())
+    b = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(a, b)


@@ -619,5 +622,32 @@ def representative_dataset():
_compile_and_compare_model(create_model(), ifm_shape, dtype)


def test_layout_optimizer_runs_in_compilation_pipeline():
"""Checks that the layout optimization pass runs as part of the NPU compilation
pipeline."""

def get_graph():
x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8")
for _ in range(2):
x = relay.nn.max_pool2d(x, layout="NHWC")

func = relay.Function(relay.analysis.free_vars(x), x)
return tvm.IRModule.from_expr(func)

mod = get_graph()
mod = partition_for_ethosu(mod)

external_gv_name = mod["main"].body.op.name_hint
external_func = mod[external_gv_name]
prim_func = relay_to_tir_func(external_func)

    # Check for hints in the TIR prim func that the layout optimization pass has run
ops = prim_func.body.body.seq
max_pool1, max_pool2 = ops

assert str(max_pool1.value.args[31]) == '"NHCWB16"'
assert str(max_pool2.value.args[14]) == '"NHCWB16"'


if __name__ == "__main__":
pytest.main([__file__] + sys.argv[1:])
38 changes: 38 additions & 0 deletions tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -21,9 +21,16 @@

pytest.importorskip("ethosu.vela")

+import tensorflow as tf
+import numpy as np

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer
+from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu

+from .test_codegen import _get_tflite_graph
from . import infra


Expand Down Expand Up @@ -59,6 +66,7 @@ def after():
return mod

mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)

assert tvm.ir.structural_equal(mod, after())

@@ -91,5 +99,35 @@ def after():
return mod

mod = LUTsOptimizer()(before())
+    mod = relay.transform.InferType()(mod)

assert tvm.ir.structural_equal(mod, after())


def test_lut_optimizer_runs_in_compilation_pipeline():
"""Test that the LUT optimization pass runs as part of the NPU compilation pipeline."""
ifm_shape = (1, 4, 4, 4)

@tf.function
def get_graph(x):
weight1 = tf.constant(np.random.uniform(size=(1, 1, 4, 4)), dtype=tf.float32)
op = tf.nn.conv2d(x, weight1, (1, 1), "VALID")
op = tf.nn.tanh(op)
weight2 = tf.constant(np.random.uniform(size=(1, 1, 4, 1)), dtype=tf.float32)
op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID")
return tf.nn.tanh(op)

mod, _ = _get_tflite_graph(get_graph, [ifm_shape])
mod = partition_for_ethosu(mod)

external_gv_name = mod["main"].body.op.name_hint
external_func = mod[external_gv_name]
prim_func = relay_to_tir_func(external_func)

    # Check for hints in the TIR prim func that the LUT optimization pass has run.
# If the module was optimized, there should be no identity operations.
def check_identity(stmt):
if isinstance(stmt, tvm.tir.expr.Call):
assert stmt.args[0] != "ethosu_identity"

tvm.tir.stmt_functor.post_order_visit(prim_func.body, check_identity)
