[microNPU][2b] Create CascaderGraphs from TE graphs (apache#9471)
The first step in the cascader is to create a
CascaderGraph from a TE graph. To do this, every
operator in the TE graph must be 'matched' by a
Part matcher, which converts TE operations into
cascader Parts by augmenting them with Propagators.

In this initial commit, we include basic Part
matchers for ethosu_conv2d and some transform
operators so that the graph creation can be tested.
mbaret authored and AndrewZhaoLuo committed Jan 7, 2022
1 parent 711f083 commit fcaa30c
Showing 16 changed files with 987 additions and 37 deletions.
12 changes: 10 additions & 2 deletions python/tvm/contrib/ethosu/cascader/__init__.py
@@ -21,5 +21,13 @@
"""
from .stripe_config import StripeConfig
from .propagator import Propagator
from .graph import PerformanceInfo, Tensor, Part, TESubgraph, CascaderGraph
from .parts import InlinePart
from .graph import (
PerformanceInfo,
Tensor,
Part,
TESubgraph,
CascaderGraph,
register_matcher,
create_cascader_graph,
)
from .parts import InlinePart, EthosuPart
87 changes: 85 additions & 2 deletions python/tvm/contrib/ethosu/cascader/graph.py
@@ -15,16 +15,22 @@
# specific language governing permissions and limitations
# under the License.
"""Graph objects to define compute graphs for the NPU cascader."""
from typing import List
from typing import List, Dict
from collections import namedtuple
import tvm._ffi
import numpy as np

import tvm._ffi
from tvm import te
from tvm.runtime import Object

from .stripe_config import StripeConfig
from . import _ffi_api


# A global store to register matching functions
REGISTERED_MATCHERS = []


TESubgraph = namedtuple("TESubgraph", ["input_tensors", "output_tensor"])


@@ -168,3 +174,80 @@ def tensor_order(self):
@property
def part_order(self):
return list(self._part_order)


def register_matcher(matcher):
"""Register a match function to the frontend.
A match function takes a te.Tensor and checks whether it matches
a known operator/operator sequence. If it does, it returns a Part
which models the behaviour of that operator sequence. Otherwise,
it returns None.
"""
REGISTERED_MATCHERS.append(matcher)
return matcher


def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray]) -> CascaderGraph:
"""Create a CascaderGraph from a Tensor Expression graph and constant dictionary.
Parameters
----------
te_graph : TESubgraph
The Tensor Expression graph.
const_dict : Dict[int, np.ndarray]
The constant dictionary.
Returns
-------
CascaderGraph
The CascaderGraph.
"""
tensor_map = {}

def _visit_tensor(tensor):
if tensor not in tensor_map:
is_const = False
# Logic to determine if the tensor is constant
if tensor in list(te_graph.inputs):
i = list(te_graph.inputs).index(tensor)
if i in const_dict:
is_const = True

# TODO(@mbaret): Calculate the compression ratio
plan_tensor = Tensor(
tensor.shape,
tensor.dtype,
is_constant=is_const,
)
tensor_map[tensor] = plan_tensor
if isinstance(tensor.op, te.PlaceholderOp) or tensor in te_graph.inputs:
return

input_tensors = []
# Check whether any of the registered matchers match the current tensor
for matcher in REGISTERED_MATCHERS:
part = matcher(tensor)
if part:
input_tensors = part.subgraph.input_tensors
break

assert part is not None, f"The tensor {tensor} doesn't match any part."
part.set_output(plan_tensor)
plan_tensor.add_producer(part)
for i, input_tensor in enumerate(input_tensors):
_visit_tensor(input_tensor)
part.set_input(i, tensor_map[input_tensor])
tensor_map[input_tensor].add_consumer(part)

for output in te_graph.outputs:
_visit_tensor(output)

input_tensors = []
for t in te_graph.inputs:
# This is needed because sometimes there are orphaned constants
if t in tensor_map:
input_tensors.append(tensor_map[t])

output_tensors = [tensor_map[t] for t in te_graph.outputs]
return CascaderGraph(input_tensors, output_tensors)
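
A minimal end-to-end sketch of the new API (illustrative only; the "identity" operator, the matcher and the TEGraph container below are hypothetical stand-ins, not part of this commit). A matcher inspects a te.Tensor and, if it recognises the operation, wraps it in a Part with Propagators; create_cascader_graph then walks the TE graph and stitches the matched Parts together:

from collections import namedtuple

from tvm import te
from tvm.contrib.ethosu.cascader import (
    TESubgraph,
    InlinePart,
    Propagator,
    register_matcher,
    create_cascader_graph,
)


@register_matcher
def match_identity(tensor):
    # Only match compute ops named "identity"
    if tensor.op.name != "identity":
        return None
    subgraph = TESubgraph([tensor.op.input_tensors[0]], tensor)
    # 2-D identity propagator: the input stripe required is exactly the output stripe
    propagator = Propagator([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 0])
    return InlinePart(subgraph, [propagator])


# create_cascader_graph reads te_graph.inputs and te_graph.outputs, so pass a
# simple container with those attributes as a stand-in for the lowered TE graph
TEGraph = namedtuple("TEGraph", ["inputs", "outputs"])

a = te.placeholder((12, 16), dtype="int8", name="a")
out = te.compute((12, 16), lambda i, j: a[i, j], name="identity")
graph = create_cascader_graph(TEGraph([a], [out]), const_dict={})
assert len(graph.part_order) == 1  # the single matched identity Part
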
24 changes: 24 additions & 0 deletions python/tvm/contrib/ethosu/cascader/parts.py
@@ -38,3 +38,27 @@ def __init__(
te_subgraph.output_tensor,
propagators,
)


@tvm._ffi.register_object("contrib.ethosu.cascader.EthosuPart")
class EthosuPart(Part):
"""A class to describe a Part to be executed on an Arm(R) Ethos(TM)-U NPU.
EthosuParts must be provided with an output quantum and the cycles taken to
compute an output quantum which depend on the operator the NPU is computing."""

def __init__(
self,
te_subgraph: TESubgraph,
propagators: List[Propagator],
output_quantum: List[int],
quantum_cycles: int,
):
self.__init_handle_by_constructor__(
_ffi_api.EthosuPart,
te_subgraph.input_tensors,
te_subgraph.output_tensor,
propagators,
output_quantum,
quantum_cycles,
)
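
A minimal construction sketch (illustrative shapes and values, not taken from this commit): per the docstring above, output_quantum describes the block of output produced in one step and quantum_cycles the cycles taken to compute it; the concrete values here are placeholders.

import numpy as np

from tvm import te
from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator

ifm = te.placeholder((1, 8, 8, 16), dtype="int8", name="ifm")
ofm = te.compute((1, 8, 8, 16), lambda *i: ifm(*i), name="ethosu_identity")
subgraph = TESubgraph([ifm], ofm)
# Identity propagator over a 4-D tensor, expressed in homogeneous coordinates
propagator = Propagator(np.identity(5).tolist(), [0, 0, 0, 0])
part = EthosuPart(subgraph, [propagator], output_quantum=[1, 2, 2, 1], quantum_cycles=1000)
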
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/te/__init__.py
@@ -22,3 +22,4 @@
from .binary_elementwise import *
from .identity import *
from .unary_elementwise import *
from .inline import *
133 changes: 129 additions & 4 deletions python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -17,8 +17,11 @@
# pylint: disable=invalid-name,unused-argument
"""Tensor Expressions for convolutions for the NPU"""
from typing import Tuple, Union, List
import numpy as np # type: ignore

from tvm import te # type: ignore
from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher

from .dma import dma_ofm_compute, dma_ifm_compute


@@ -108,9 +111,10 @@ def conv2d_compute(
assert ifm_layout in {"NHWC", "NHCWB16"}
assert ofm_layout in {"NHWC", "NHCWB16"}

stride_h, stride_w = strides
dilation_h, dilation_w = dilation
ofm_channels, kernel_h, kernel_w, ifm_channels = weight.shape
padding = [int(v) for v in padding]
stride_h, stride_w = [int(v) for v in strides]
dilation_h, dilation_w = [int(v) for v in dilation]
ofm_channels, kernel_h, kernel_w, ifm_channels = [int(v) for v in weight.shape]

# Compute operation for the IFM DMA pipeline
dmaed_ifm = dma_ifm_compute(
@@ -164,5 +168,126 @@ def conv2d_compute(
attrs=conv2d_attrs,
)

nhwc_to_nhcwb16 = [
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 1 / 16, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 16],
[0, 0, 0, 0, 1],
]
nhcwb16_to_nhwc = [
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 16, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
]
ifm_matrix = [
[1, 0, 0, 0, 0],
[0, stride_h, 0, 0, (dilated_kernel_h - stride_h)],
[0, 0, stride_w, 0, (dilated_kernel_w - stride_w)],
[0, 0, 0, 0, ifm_channels],
[0, 0, 0, 0, 1],
]
weights_matrix = [
[0, 0, 0, 1, 0],
[0, 0, 0, 0, kernel_h],
[0, 0, 0, 0, kernel_w],
[0, 0, 0, 0, ifm_channels],
[0, 0, 0, 0, 1],
]
bias_matrix = [
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 10],
[0, 0, 0, 0, 1],
]
if ofm_layout == "NHCWB16":
ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist()
weights_matrix = np.matmul(weights_matrix, nhcwb16_to_nhwc).tolist()
bias_matrix = np.matmul(bias_matrix, nhcwb16_to_nhwc).tolist()
if ifm_layout == "NHCWB16":
ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist()
ifm_propagator = Propagator(
ifm_matrix,
[0, -padding[0], -padding[1], 0]
if ifm_layout == "NHWC"
else [0, -padding[0], 0, -padding[1], 0],
)
weights_propagator = Propagator(
weights_matrix,
[0, 0, 0, 0],
)
bias_propagator = Propagator(
bias_matrix,
[0, 0],
)
propagator_attrs = {
"ifm_propagator": ifm_propagator,
"weights_propagator": weights_propagator,
"bias_propagator": bias_propagator,
}

# Compute operation for the OFM DMA pipeline
return dma_ofm_compute(conv, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels)
dma_ofm = dma_ofm_compute(
conv, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels, attrs=propagator_attrs
)
return dma_ofm
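
To illustrate what ifm_matrix encodes, a small standalone check (the values here are illustrative assumptions, not taken from the commit): in homogeneous coordinates [n, h, w, c, 1], the matrix maps an OFM stripe shape to the IFM stripe shape needed to compute it.

import numpy as np

stride_h = stride_w = 1
dilated_kernel_h = dilated_kernel_w = 3  # 3x3 kernel with dilation 1
ifm_channels = 16

ifm_matrix = [
    [1, 0, 0, 0, 0],
    [0, stride_h, 0, 0, dilated_kernel_h - stride_h],
    [0, 0, stride_w, 0, dilated_kernel_w - stride_w],
    [0, 0, 0, 0, ifm_channels],
    [0, 0, 0, 0, 1],
]
ofm_stripe = [1, 4, 4, 8, 1]  # N, H, W, C plus the trailing homogeneous 1
ifm_stripe = np.matmul(ifm_matrix, ofm_stripe)
print(ifm_stripe)  # [1 6 6 16 1]: a 1x4x4x8 output stripe needs a 1x6x6x16 input stripe
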


@register_matcher
def match_ethosu_conv2d(output_tensor):
"""Match a Tensor Expression corresponding to an NPU Conv2D.
If the Tensor Expression matches, an EthosuPart will be created that models the
matched Tensor Expression. Otherwise, None will be returned.
Parameters
----------
output_tensor : tvm.te.Tensor
The tensor to attempt to match with.
Returns
-------
Union[None, EthosuPart]
The created EthosuPart if there was a match, otherwise None.
"""
write = output_tensor
if write.op.name != "ethosu_write":
return None
convert_to_nhcwb16 = write.op.input_tensors[0]
if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16":
return None
conv2d = convert_to_nhcwb16.op.input_tensors[0]
if conv2d.op.name != "ethosu_conv2d":
return None
pad = conv2d.op.input_tensors[0]
if pad.op.name != "ethosu_pad":
return None
convert_to_nhwc = pad.op.input_tensors[0]
if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc":
return None
read = convert_to_nhwc.op.input_tensors[0]
if read.op.name != "ethosu_read":
return None

input_tensors = [
read.op.input_tensors[0],
conv2d.op.input_tensors[1],
conv2d.op.input_tensors[2],
]
subgraph = TESubgraph(input_tensors, output_tensor)
propagators = [
write.op.attrs["ifm_propagator"],
write.op.attrs["weights_propagator"],
write.op.attrs["bias_propagator"],
]
# TODO(@jacobbohlin) Both the output_quantum and quantum_cycles here are placeholders;
# they need a real implementation.
if convert_to_nhcwb16.op.attrs["layout"] == "NHWC":
output_quantum = [1, 2, 2, 1]
else:
output_quantum = [1, 2, 1, 2, 1]
quantum_cycles = 1000
return EthosuPart(subgraph, propagators, output_quantum, quantum_cycles)
21 changes: 18 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/te/dma.py
@@ -103,7 +103,11 @@ def read_compute(


def write_compute(
tensor: te.Tensor, zero_point: int, scale: float, layout: Optional[str] = None
tensor: te.Tensor,
zero_point: int,
scale: float,
layout: Optional[str] = None,
attrs: dict = None,
) -> te.Tensor:
"""A tensor expression which represents a write.
@@ -117,6 +121,8 @@ def write_compute(
The scale of the tensor.
layout : Optional[str]
The layout of the tensor, either NHWC or NHCWB16.
attrs : dict, optional
Additional attributes to add to the compute op.
Returns
-------
@@ -125,6 +131,9 @@
"""

if not attrs:
attrs = {}

write_attrs = {
"op": "ethosu_write",
"zero_point": zero_point,
@@ -135,6 +144,7 @@ def write_compute(
assert layout in {"NHWC", "NHCWB16"}
write_attrs["layout"] = layout

write_attrs = {**write_attrs, **attrs}
return te.compute(
tensor.shape,
lambda *i: tensor(*i),
@@ -304,7 +314,7 @@ def dma_ifm_compute(


def dma_ofm_compute(
ofm: te.Tensor, layout: str, zero_point: int, scale: float, channels: int
ofm: te.Tensor, layout: str, zero_point: int, scale: float, channels: int, attrs: dict = None
) -> te.Tensor:
"""A sequence of compute operators representing the DMA capabilities for an OFM.
@@ -320,12 +330,17 @@ def dma_ofm_compute(
The scale of the data.
channels : int
The number of valid channels for the data.
attrs : dict, optional
Additional attributes to add to the write compute op.
Returns
-------
te.Tensor
The dma-ed OFM tensor.
"""
if not attrs:
attrs = {}
convert_to_nhcwb16_ofm = convert_to_nhcwb16_compute(ofm, layout, channels)
return write_compute(convert_to_nhcwb16_ofm, zero_point, scale, layout=layout)
return write_compute(convert_to_nhcwb16_ofm, zero_point, scale, layout=layout, attrs=attrs)
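
An illustrative check of the attrs plumbing (not part of the commit): attributes passed to dma_ofm_compute are forwarded to write_compute and end up on the resulting ethosu_write operation, which is where match_ethosu_conv2d later reads the propagators back from.

import numpy as np

from tvm import te
from tvm.contrib.ethosu.cascader import Propagator
from tvm.relay.backend.contrib.ethosu.te.dma import write_compute

t = te.placeholder((1, 8, 8, 16), dtype="int8", name="t")
prop = Propagator(np.identity(5).tolist(), [0, 0, 0, 0])
w = write_compute(t, zero_point=0, scale=1.0, attrs={"ifm_propagator": prop})
assert w.op.name == "ethosu_write"
assert "ifm_propagator" in w.op.attrs
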