Skip to content

Commit

Permalink
[microNPU] Add NHWC -> NHCWB16 layout transformation pass (apache#9561)
Browse files Browse the repository at this point in the history
Adds a layout optimization pass that modifies the ifm/ofm layout
of an operation to NHCWB16 where possible. This can occur when the
producer or consumer of a tensor is also an NPU operator.
  • Loading branch information
lhutton1 authored and baoxinqi committed Dec 27, 2021
1 parent ab959ff commit 4058da2
Show file tree
Hide file tree
Showing 3 changed files with 761 additions and 0 deletions.
128 changes: 128 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,134 @@ def transform_function(
return OptimizeLUTs().visit(func)


class LayoutOptimization(ExprMutator):
    """Mutator that switches NPU operator layouts from NHWC to NHCWB16.

    An input layout attribute is set to NHCWB16 when the argument feeding
    that input is itself a call to one of the operators in ``optimize_op``.
    The output layout attribute is set to NHCWB16 when every consumer
    recorded for the call in ``children`` is such an operator and already
    has ``ifm_layout`` equal to NHCWB16.

    Attributes
    ----------
    children : Dict[tvm.relay.expr.Call, List[tvm.relay.expr.Call]]
        Maps a call (as it appears in a consumer's argument list) to the
        rewritten consumer calls that take it as an NPU input. It lets a
        producer look "forward" at its consumers when deciding whether its
        output layout can be rewritten.
    optimize_op : Dict[str, Callable]
        Maps an NPU operator name to the function that creates that
        operator.
    """

    def __init__(self):
        # Producer call -> list of rewritten NPU consumer calls.
        self.children = {}
        # Operator name -> constructor used to rebuild the call with new
        # attributes. Membership in this map is also what marks an operator
        # as an NPU operator for the purposes of this pass.
        self.optimize_op = {
            "contrib.ethosu.conv2d": op.ethosu_conv2d,
            "contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
            "contrib.ethosu.pooling": op.ethosu_pooling,
            "contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
            "contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
        }

        super().__init__()

    def alter_ethosu_op_layout(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Rebuild an NPU call with NHCWB16 layouts where that is allowed.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            Call to an NPU operator (its op name must be a key of
            ``optimize_op``) whose layout attributes are to be checked.

        Returns
        -------
        new_call : tvm.relay.expr.Call
            The result of handing the rebuilt call, carrying the possibly
            updated layout attributes, to the base class ``visit_call``.
        """
        assert isinstance(call.attrs, tvm.ir.Attrs), (
            f"The attributes for operator '{call.op.name}' could not be "
            "found. Did you register the relay.attrs.Ethosu<opname>Attrs "
            "object in python api?"
        )

        known_ops = self.optimize_op
        updated_attrs = dict(call.attrs)
        npu_producers = []

        # Input side: every argument produced by an NPU operator gets its
        # matching layout attribute switched. Positions are 1-based; the
        # first input uses "ifm_layout", later ones "ifm<N>_layout".
        for position, producer in enumerate(call.args, start=1):
            if not isinstance(producer, tvm.relay.expr.Call):
                continue
            if not isinstance(producer.op, tvm.ir.op.Op):
                continue
            if producer.op.name not in known_ops:
                continue
            attr_key = "ifm_layout" if position <= 1 else f"ifm{position}_layout"
            updated_attrs[attr_key] = "NHCWB16"
            npu_producers.append(producer)

        def reads_brick_format(consumer):
            """Return whether ``consumer`` is an NPU call with an NHCWB16 ifm."""
            # NOTE(review): only "ifm_layout" is inspected, even when this
            # call feeds the consumer's second input ("ifm2_layout") --
            # confirm this is intended.
            return (
                isinstance(consumer, tvm.relay.expr.Call)
                and isinstance(consumer.op, tvm.ir.op.Op)
                and consumer.op.name in known_ops
                and consumer.attrs["ifm_layout"] == "NHCWB16"
            )

        # Output side: only rewrite when consumers have been recorded for
        # this call and all of them read NHCWB16.
        consumers = self.children.get(call)
        if consumers is not None and all(map(reads_brick_format, consumers)):
            updated_attrs["ofm_layout"] = "NHCWB16"

        name = call.op.name
        assert name in known_ops, (
            f"Could not create operator '{name}' as the creation function "
            "is unknown. Please provide a mapping."
        )
        rewritten = known_ops[name](*call.args, **updated_attrs)

        # Register the rewritten call as a consumer of each NPU producer so
        # the producer can consult it when its own output layout is decided.
        # NOTE(review): this presumably relies on the base mutator visiting
        # the arguments of the rewritten call afterwards (consumers before
        # producers) -- confirm against ExprMutator.
        for producer in npu_producers:
            self.children.setdefault(producer, []).append(rewritten)

        return super().visit_call(rewritten)

    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
        """Visit a call node, altering its layouts if it is an NPU operator.

        Parameters
        ----------
        call : tvm.relay.expr.Call
            The call node currently being visited.

        Returns
        -------
        tvm.relay.expr.Call
            The result of ``alter_ethosu_op_layout`` when ``call.op`` is an
            ``Op`` whose name is a key of ``optimize_op``; otherwise the
            result of the base class ``visit_call``.
        """
        is_npu_op = isinstance(call.op, tvm.ir.op.Op) and call.op.name in self.optimize_op
        if not is_npu_op:
            return super().visit_call(call)
        return self.alter_ethosu_op_layout(call)


@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer")
class LayoutOptimizer(Pass):
    """Relay function pass that applies ``LayoutOptimization`` to a function."""

    def transform_function(
        self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
    ) -> tvm.IRModule:
        """Optimize the layout of NPU operations in ``func``.

        If both the producer and consumer of a tensor are NPU operators,
        the layout is converted from NHWC to NHCWB16, as this is the layout
        the NPU uses internally.

        Parameters
        ----------
        func : tvm.relay.function.Function
            The function to rewrite.
        mod : tvm.IRModule
            The module containing ``func``; it must hold exactly one
            function.
        _ :
            Unused third positional argument.

        Returns
        -------
        The value returned by ``LayoutOptimization().visit(func)``.
        NOTE(review): the annotation says ``tvm.IRModule`` but the visit
        result is presumably a relay function -- confirm.
        """
        function_count = len(mod.functions.items())
        assert function_count == 1, "Module can only contain one function."
        layout_mutator = LayoutOptimization()
        return layout_mutator.visit(func)


@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
def constant_updater(expr, symbol): # pylint: disable=unused-argument
"""
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@ class EthosuDepthwiseConv2DAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.EthosuPoolingAttrs")
class EthosuPooling2DAttrs(Attrs):
    """Attribute node for the ``contrib.ethosu.pooling`` operator.

    Registered under the object key ``relay.attrs.EthosuPoolingAttrs``.
    """


@tvm._ffi.register_object("relay.attrs.EthosuBinaryElementwiseAttrs")
class EthosuBinaryElementwiseAttrs(Attrs):
    """Attribute node for the ``contrib.ethosu.binary_elementwise`` operator.

    Registered under the object key
    ``relay.attrs.EthosuBinaryElementwiseAttrs``.
    """


@tvm._ffi.register_object("relay.attrs.EthosuUnaryElementwiseAttrs")
class EthosuUnaryElementwiseAttrs(Attrs):
    """Attribute node for the ``contrib.ethosu.unary_elementwise`` operator.

    Registered under the object key
    ``relay.attrs.EthosuUnaryElementwiseAttrs``.
    """
Loading

0 comments on commit 4058da2

Please sign in to comment.