From fcaa30ca541f7f5fe8901f5c2aaf3198c980c408 Mon Sep 17 00:00:00 2001 From: Matthew Barrett <55580676+mbaret@users.noreply.github.com> Date: Fri, 7 Jan 2022 16:06:04 +0000 Subject: [PATCH] [microNPU][2b] Create CascaderGraphs from TE graphs (#9471) The first step in the cascader is to create a CascaderGraph from a TE graph. To do this, every operator in the TE graph must get 'matched' by a Part matcher. This converts TE operations into cascader Parts by augmenting them with Propagators. In this initial commit, we include basic Part matchers for ethosu_conv2d and some transform operators so that the graph creation can be tested. --- .../tvm/contrib/ethosu/cascader/__init__.py | 12 +- python/tvm/contrib/ethosu/cascader/graph.py | 87 +++++++++- python/tvm/contrib/ethosu/cascader/parts.py | 24 +++ .../backend/contrib/ethosu/te/__init__.py | 1 + .../backend/contrib/ethosu/te/convolution.py | 133 +++++++++++++- .../relay/backend/contrib/ethosu/te/dma.py | 21 ++- .../relay/backend/contrib/ethosu/te/inline.py | 71 ++++++++ .../backend/contrib/ethosu/tir/scheduler.py | 3 +- src/contrib/ethosu/cascader/parts/ethosu.cc | 120 +++++++++++++ src/contrib/ethosu/cascader/parts/ethosu.h | 99 +++++++++++ .../contrib/test_ethosu/cascader/conftest.py | 73 ++++++++ .../contrib/test_ethosu/cascader/infra.py | 27 +++ .../cascader/test_ethosu_conv2d_matcher.py | 163 ++++++++++++++++++ .../cascader/test_ethosu_inline_matcher.py | 49 ++++++ .../test_ethosu/cascader/test_ethosu_part.py | 46 +++++ .../test_ethosu/cascader/test_graph.py | 95 +++++++--- 16 files changed, 987 insertions(+), 37 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/inline.py create mode 100644 src/contrib/ethosu/cascader/parts/ethosu.cc create mode 100644 src/contrib/ethosu/cascader/parts/ethosu.h create mode 100644 tests/python/contrib/test_ethosu/cascader/conftest.py create mode 100644 tests/python/contrib/test_ethosu/cascader/infra.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py index bf06d00566ba..72f1667c6151 100644 --- a/python/tvm/contrib/ethosu/cascader/__init__.py +++ b/python/tvm/contrib/ethosu/cascader/__init__.py @@ -21,5 +21,13 @@ """ from .stripe_config import StripeConfig from .propagator import Propagator -from .graph import PerformanceInfo, Tensor, Part, TESubgraph, CascaderGraph -from .parts import InlinePart +from .graph import ( + PerformanceInfo, + Tensor, + Part, + TESubgraph, + CascaderGraph, + register_matcher, + create_cascader_graph, +) +from .parts import InlinePart, EthosuPart diff --git a/python/tvm/contrib/ethosu/cascader/graph.py b/python/tvm/contrib/ethosu/cascader/graph.py index 001bbbf907b7..9b22e632ff89 100644 --- a/python/tvm/contrib/ethosu/cascader/graph.py +++ b/python/tvm/contrib/ethosu/cascader/graph.py @@ -15,16 +15,22 @@ # specific language governing permissions and limitations # under the License. """Graph objects to define compute graphs for the NPU cascader.""" -from typing import List +from typing import List, Dict from collections import namedtuple -import tvm._ffi +import numpy as np +import tvm._ffi +from tvm import te from tvm.runtime import Object from .stripe_config import StripeConfig from . 
import _ffi_api +# A global store to register matching functions +REGISTERED_MATCHERS = [] + + TESubgraph = namedtuple("TESubgraph", ["input_tensors", "output_tensor"]) @@ -168,3 +174,81 @@ def tensor_order(self): @property def part_order(self): return list(self._part_order) + + +def register_matcher(matcher): + """Register a match function to the frontend. + + A match function takes a te.Tensor and checks whether it matches + a known operator/operator sequence. If it does, it returns a Part + which models the behaviour of that operator sequence. Otherwise, + it returns None. + """ + REGISTERED_MATCHERS.append(matcher) + return matcher + + +def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray]) -> CascaderGraph: + """Create a CascaderGraph from a Tensor Expression graph and constant dictionary. + + Parameters + ---------- + te_graph : TESubgraph + The Tensor Expression graph. + const_dict : Dict[int, np.ndarray] + The constant dictionary. + + Returns + ------- + CascaderGraph + The CascaderGraph. + """ + tensor_map = {} + + def _visit_tensor(tensor): + if tensor not in tensor_map: + is_const = False + # Logic to determine if the tensor is constant + if tensor in list(te_graph.inputs): + i = list(te_graph.inputs).index(tensor) + if i in const_dict: + is_const = True + + # TODO(@mbaret): Calculate the compression ratio + plan_tensor = Tensor( + tensor.shape, + tensor.dtype, + is_constant=is_const, + ) + tensor_map[tensor] = plan_tensor + if isinstance(tensor.op, te.PlaceholderOp) or tensor in te_graph.inputs: + return + + input_tensors = [] + # Check whether any of the registered matchers match the current tensor + part = None + for matcher in REGISTERED_MATCHERS: + part = matcher(tensor) + if part: + input_tensors = part.subgraph.input_tensors + break + + assert part is not None, f"The tensor {tensor} doesn't match any part." + part.set_output(plan_tensor) + plan_tensor.add_producer(part) + for i, input_tensor in enumerate(input_tensors): + _visit_tensor(input_tensor) + part.set_input(i, tensor_map[input_tensor]) + tensor_map[input_tensor].add_consumer(part) + + for output in te_graph.outputs: + _visit_tensor(output) + + input_tensors = [] + for t in te_graph.inputs: + # This is needed because sometimes there are orphaned constants + if t in tensor_map: + input_tensors.append(tensor_map[t]) + + output_tensors = [tensor_map[t] for t in te_graph.outputs] + return CascaderGraph(input_tensors, output_tensors) diff --git a/python/tvm/contrib/ethosu/cascader/parts.py b/python/tvm/contrib/ethosu/cascader/parts.py index 48d2d77deb1f..9cc67d5760dd 100644 --- a/python/tvm/contrib/ethosu/cascader/parts.py +++ b/python/tvm/contrib/ethosu/cascader/parts.py @@ -38,3 +38,27 @@ def __init__( te_subgraph.output_tensor, propagators, ) + + +@tvm._ffi.register_object("contrib.ethosu.cascader.EthosuPart") +class EthosuPart(Part): + """A class to describe a Part to be executed on an Arm(R) Ethos(TM)-U NPU.
+ + EthosuParts must be provided with an output quantum and the cycles taken to + compute an output quantum which depend on the operator the NPU is computing.""" + + def __init__( + self, + te_subgraph: TESubgraph, + propagators: List[Propagator], + output_quantum: List[int], + quantum_cycles: int, + ): + self.__init_handle_by_constructor__( + _ffi_api.EthosuPart, + te_subgraph.input_tensors, + te_subgraph.output_tensor, + propagators, + output_quantum, + quantum_cycles, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index 21261521ac5f..2ede967a036c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -22,3 +22,4 @@ from .binary_elementwise import * from .identity import * from .unary_elementwise import * +from .inline import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 6e50c6ff3b0b..766af0dbbeef 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -17,8 +17,11 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for convolutions for the NPU""" from typing import Tuple, Union, List +import numpy as np # type: ignore from tvm import te # type: ignore +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -108,9 +111,10 @@ def conv2d_compute( assert ifm_layout in {"NHWC", "NHCWB16"} assert ofm_layout in {"NHWC", "NHCWB16"} - stride_h, stride_w = strides - dilation_h, dilation_w = dilation - ofm_channels, kernel_h, kernel_w, ifm_channels = weight.shape + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + dilation_h, dilation_w = [int(v) for v in dilation] + ofm_channels, kernel_h, kernel_w, ifm_channels = [int(v) for v in weight.shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute( @@ -164,5 +168,126 @@ def conv2d_compute( attrs=conv2d_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 0, 0], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weights_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weights_matrix = np.matmul(weights_matrix, nhcwb16_to_nhwc).tolist() + bias_matrix = np.matmul(bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + weights_propagator = Propagator( + weights_matrix, + [0, 0, 0, 0], + ) + bias_propagator = Propagator( + bias_matrix, + [0, 0], + ) + propagator_attrs = { + "ifm_propagator": 
ifm_propagator, + "weights_propagator": weights_propagator, + "bias_propagator": bias_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(conv, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + dma_ofm = dma_ofm_compute( + conv, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels, attrs=propagator_attrs + ) + return dma_ofm + + +@register_matcher +def match_ethosu_conv2d(output_tensor): + """Match a Tensor Expression corresponding to an NPU Conv2D. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + conv2d = convert_to_nhcwb16.op.input_tensors[0] + if conv2d.op.name != "ethosu_conv2d": + return None + pad = conv2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + conv2d.op.input_tensors[1], + conv2d.op.input_tensors[2], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["weights_propagator"], + write.op.attrs["bias_propagator"], + ] + # TODO(@jacobbohlin) Both the output_quantum and quantum_cycles here are placeholders; + # a true implementation is needed. + if convert_to_nhcwb16.op.attrs["layout"] == "NHWC": + output_quantum = [1, 2, 2, 1] + else: + output_quantum = [1, 2, 1, 2, 1] + quantum_cycles = 1000 + return EthosuPart(subgraph, propagators, output_quantum, quantum_cycles)
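As a worked example of how the IFM propagator matrix above is used (illustrative only: it assumes a 3x3 kernel, unit strides and dilation, 8 IFM channels and NHWC layouts on both sides), the matrix maps an output stripe, expressed in homogeneous form, to the IFM stripe required to compute it:

    import numpy as np

    # ifm_matrix for kernel=(3, 3), strides=(1, 1), dilation=(1, 1), ifm_channels=8:
    # dilated_kernel - stride = 2 in both spatial axes.
    ifm_matrix = np.array(
        [
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 2],
            [0, 0, 1, 0, 2],
            [0, 0, 0, 0, 8],  # every output element reads all 8 IFM channels
            [0, 0, 0, 0, 1],
        ]
    )
    output_stripe = np.array([1, 4, 4, 8, 1])  # an (N, H, W, C) stripe in homogeneous form
    input_stripe = ifm_matrix @ output_stripe
    assert input_stripe.tolist() == [1, 6, 6, 8, 1]  # a (1, 6, 6, 8) IFM stripe is needed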
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index 5d51c7bfae20..14aa67bb37d3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -103,7 +103,11 @@ def read_compute( def write_compute( - tensor: te.Tensor, zero_point: int, scale: float, layout: Optional[str] = None + tensor: te.Tensor, + zero_point: int, + scale: float, + layout: Optional[str] = None, + attrs: dict = None, ) -> te.Tensor: """A tensor expression which represents a write. @@ -117,6 +121,8 @@ def write_compute( The scale of the tensor. layout : Optional[str] The layout of the tensor, either NHWC or NHCWB16. + attrs : dict, optional + Additional attributes to add to the compute op. Returns ------- @@ -125,6 +131,9 @@ """ + if not attrs: + attrs = {} + write_attrs = { "op": "ethosu_write", "zero_point": zero_point, @@ -135,6 +144,7 @@ assert layout in {"NHWC", "NHCWB16"} write_attrs["layout"] = layout + write_attrs = {**write_attrs, **attrs} return te.compute( tensor.shape, lambda *i: tensor(*i), @@ -304,7 +314,7 @@ def dma_ifm_compute( def dma_ofm_compute( - ofm: te.Tensor, layout: str, zero_point: int, scale: float, channels: int + ofm: te.Tensor, layout: str, zero_point: int, scale: float, channels: int, attrs: dict = None ) -> te.Tensor: """A sequence of compute operators representing the DMA capabilities for an OFM. @@ -320,6 +330,9 @@ The scale of the data. channels : int The number of valid channels for the data. + attrs : dict, optional + Additional attributes to add to the write compute op. + Returns ------- @@ -327,5 +340,7 @@ The dma-ed OFM tensor. """ + if not attrs: + attrs = {} convert_to_nhcwb16_ofm = convert_to_nhcwb16_compute(ofm, layout, channels) - return write_compute(convert_to_nhcwb16_ofm, zero_point, scale, layout=layout) + return write_compute(convert_to_nhcwb16_ofm, zero_point, scale, layout=layout, attrs=attrs) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/inline.py b/python/tvm/relay/backend/contrib/ethosu/te/inline.py new file mode 100644 index 000000000000..95e7342d5e82 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/inline.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tensor Expressions for operations that will be inlined""" +import numpy as np # type: ignore + +from tvm.contrib.ethosu.cascader import TESubgraph, InlinePart, Propagator, register_matcher + + +INLINE_OPS = {"T_reshape", "T_strided_slice"} + + +@register_matcher +def match_ethosu_inline(output_tensor): + """Match a Tensor Expression corresponding to an operator that will be inlined. + + If the Tensor Expression matches, an InlinePart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. This matcher is + naive and assumes nothing about the compute of the Tensor Expression. Therefore, + the resulting InlinePart will have full-tensor dependencies (i.e. each output + element depends on every input element). + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + + Returns + ------- + Union[None, InlinePart] + The created InlinePart if there was a match, otherwise None.
+ + """ + if output_tensor.op.name not in INLINE_OPS: + return None + + input_tensors = output_tensor.op.input_tensors + propagators = [] + output_dims = len(output_tensor.shape) + for input_tensor in input_tensors: + input_dims = len(input_tensor.shape) + transform_matrix = np.zeros((input_dims + 1, output_dims + 1)) + for i, axis in enumerate(input_tensor.shape): + transform_matrix[i, output_dims] = int(axis) + transform_matrix[input_dims, output_dims] = 1 + offset_vector = np.zeros(input_dims, dtype="int64") + propagators.append( + Propagator( + transform_matrix.tolist(), + offset_vector.tolist(), + ) + ) + + subgraph = TESubgraph(input_tensors, output_tensor) + return InlinePart( + subgraph, + propagators, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 3b20e783eb1a..572057452602 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Scheduling for Arm(R) Ethos(TM)-U NPU.""" import tvm +from tvm.contrib.ethosu.cascader import Propagator def schedule(cached_func, const_dict, cascader=None): @@ -203,7 +204,7 @@ def _add_pragmas(stage, ax): if "op" in [attr for attr, val in stage.op.attrs.items()]: stage.pragma(ax, "op", stage.op.attrs["op"]) for attr, val in stage.op.attrs.items(): - if attr not in ("op", "lut"): + if attr not in ("op", "lut") and not isinstance(val, Propagator): stage.pragma(ax, str(attr), val) for stage in sch.stages: diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc new file mode 100644 index 000000000000..29b43269c7b6 --- /dev/null +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "ethosu.h" + +#include <tvm/runtime/registry.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "../common.h" +#include "../stripe_config.h" + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +const std::vector<int> EthosuPartNode::GetBlockShape(const StripeConfig& output_stripe_config, + bool is_rolling) { + std::vector<int> block_shape; + for (int axis : output_stripe_config->GetShape()) { + block_shape.push_back(std::min(axis, 4)); + } + return block_shape; +} + +const std::vector<int> EthosuPartNode::GetBlockInputBytes_(const std::vector<int>& block_shape) { + std::vector<int> bytes_per_input; + std::vector<float> strides; + std::vector<int> order; + std::vector<int> stripes; + std::vector<int> offset; + for (size_t i = 0; i < block_shape.size(); i++) { + strides.push_back(1.0); + order.push_back(1); + stripes.push_back(1); + offset.push_back(0); + } + StripeConfig output_block_config(block_shape, block_shape, strides, order, stripes, offset); + auto input_block_configs = CalculateInputStripeConfigs(output_block_config); + for (const auto& input_block_config : input_block_configs) { + bytes_per_input.push_back(mul_reduce(input_block_config->GetShape())); + } + return bytes_per_input; +} + +const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config, + bool is_rolling) { + std::vector<int> block_shape = GetBlockShape(output_stripe_config, is_rolling); + std::vector<int> bytes_per_input = GetBlockInputBytes_(block_shape); + int bytes_per_output = mul_reduce(block_shape); + int num_blocks = 1; + for (size_t i = 0; i < block_shape.size(); i++) { + if (!is_rolling) { + num_blocks *= output_stripe_config->GetShape()[i] * output_stripe_config->GetStripes()[i] / + block_shape[i]; + } else { + num_blocks *= output_stripe_config->GetExtent()[i] / block_shape[i]; + } + } + int num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1; + std::vector<int> read_bytes; + for (int block_bytes : bytes_per_input) { + read_bytes.push_back((num_blocks + num_stripes) * block_bytes); + } + int write_bytes = (num_blocks + num_stripes) * bytes_per_output; + auto shape = output_stripe_config->GetShape(); + PerformanceInfo info(0, read_bytes, write_bytes); + return info; +} + +EthosuPart::EthosuPart(const TESubgraph& subgraph, const std::vector<Propagator> propagators, + const std::vector<int> output_quantum, int quantum_cycles) { + auto n = make_object<EthosuPartNode>(); + ICHECK_GT(propagators.size(), 0) << "The Part must include at least one Propagator."; + n->subgraph_ = subgraph; + n->propagators_ = std::move(propagators); + n->in_line_ = false; + n->input_tensors_.resize(propagators.size()); + n->output_quantum_ = output_quantum; + n->quantum_cycles_ = quantum_cycles; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPart") + .set_body_typed([](Array<te::Tensor> subgraph_inputs, te::Tensor subgraph_output, + Array<Propagator> propagators, Array<Integer> output_quantum, + int quantum_cycles) { + std::vector<te::Tensor> vsubgraph_inputs(subgraph_inputs.begin(), subgraph_inputs.end()); + std::vector<Propagator> vpropagators(propagators.begin(), propagators.end()); + TESubgraph subgraph; + subgraph.input_tensors = vsubgraph_inputs; + subgraph.output_tensor = subgraph_output; + std::vector<int> voutput_quantum = make_vector<int, Integer>(output_quantum); + return EthosuPart(subgraph, vpropagators, voutput_quantum, quantum_cycles); + }); + +TVM_REGISTER_NODE_TYPE(EthosuPartNode); + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm
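To make the byte-count model in GetPerformanceInfo easier to follow, here is a sketch of the non-rolling arithmetic in Python (illustrative only; the stripe config values are taken from test_ethosu_part.py below, and the block shape mirrors the per-axis cap of 4 in GetBlockShape):

    import numpy as np

    shape = [1, 4, 4, 16]     # output stripe shape
    stripes = [1, 16, 13, 6]  # number of stripes per axis
    block = [min(ax, 4) for ax in shape]  # GetBlockShape caps each axis at 4
    # Non-rolling: every stripe recomputes its blocks from scratch
    num_blocks = int(np.prod([s * st // b for s, st, b in zip(shape, stripes, block)]))
    num_stripes = int(np.prod(stripes)) - 1
    bytes_per_output = int(np.prod(block))
    write_bytes = (num_blocks + num_stripes) * bytes_per_output
    # read_bytes scales each input's per-block bytes by the same (num_blocks + num_stripes)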
diff --git a/src/contrib/ethosu/cascader/parts/ethosu.h b/src/contrib/ethosu/cascader/parts/ethosu.h new file mode 100644 index 000000000000..ab3ca69d2717 --- /dev/null +++ b/src/contrib/ethosu/cascader/parts/ethosu.h @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/ethosu/cascader/parts/ethosu.h + * \brief Arm(R) Ethos(TM)-U NPU Part object + */ +#ifndef TVM_CONTRIB_ETHOSU_CASCADER_PARTS_ETHOSU_H_ +#define TVM_CONTRIB_ETHOSU_CASCADER_PARTS_ETHOSU_H_ + +#include <tvm/runtime/object.h> + +#include <vector> + +#include "../graph.h" + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +/*! \brief Node to represent an EthosuPart */ +class EthosuPartNode : public PartNode { + public: + /*! + * \brief Get the optimal block shape to use. + * \param output_stripe_config The output StripeConfig. + * \param is_rolling Whether the output config should be computed as a rolling buffer. + */ + const std::vector<int> GetBlockShape(const StripeConfig& output_stripe_config, bool is_rolling); + /*! + * \brief Get the preferred alignment in each axis for a stripe of the Part. + * \note This is used to bias the selection of StripeConfigs towards those that are integer + * multiples of a tensor intrinsic used to compute the Part. + */ + const std::vector<int> GetStripeAlignHint() const final { return output_quantum_; } + /*! + * \brief Get the performance information for a given output stripe config. + * \param output_stripe_config The output stripe config to compute the performance for. + * \param is_rolling Whether the output config should be computed as a rolling buffer. + * \return The performance information containing the compute cycles and read/write bytes. + */ + const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, + bool is_rolling) final; + + static constexpr const char* _type_key = "contrib.ethosu.cascader.EthosuPart"; + TVM_DECLARE_FINAL_OBJECT_INFO(EthosuPartNode, PartNode); + + protected: + friend class EthosuPart; + + /*! + * \brief Get the size of input required (per input tensor) to compute a block. + * \param block_shape The shape of the block to compute. + * \return The bytes required per input tensor. + */ + const std::vector<int> GetBlockInputBytes_(const std::vector<int>& block_shape); + + /*! \brief The output volume that is atomically computed */ + std::vector<int> output_quantum_; + /*! \brief The cycles taken to compute a single output quantum */ + int quantum_cycles_; +}; + +/*! + * \brief A class to describe a Part to be executed on an Arm(R) Ethos(TM)-U NPU. + * \note EthosuParts must be provided with an output quantum and the cycles taken to + * compute an output quantum which depend on the operator the NPU is computing.
+ */ +class EthosuPart : public Part { + public: + EthosuPart(const TESubgraph& subgraph, const std::vector<Propagator> propagators, + const std::vector<int> output_quantum, int quantum_cycles); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EthosuPart, Part, EthosuPartNode); +}; + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_ETHOSU_CASCADER_PARTS_ETHOSU_H_ diff --git a/tests/python/contrib/test_ethosu/cascader/conftest.py b/tests/python/contrib/test_ethosu/cascader/conftest.py new file mode 100644 index 000000000000..58ffb51a5967 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/conftest.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +ethosu_enabled = True +try: + import ethosu.vela +except ImportError: + ethosu_enabled = False + + +if ethosu_enabled: + import tvm + from tvm import relay + from tvm.relay.testing import run_opt_pass + + from .infra import create_te_graph + from ..infra import make_ethosu_conv2d + + def make_TwoConv2DWithSliceTE(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 12, 12, 8), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm=ifm, + ifm_channels=8, + ofm_channels=64, + kernel_shape=(1, 1), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + ) + strided_slice = relay.strided_slice(conv1, [0, 0, 0, 0], [1, 6, 6, 128]) + conv2 = make_ethosu_conv2d( + ifm=strided_slice, + ifm_channels=64, + ofm_channels=16, + kernel_shape=(3, 3), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHCWB16", + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + te_graph, const_dict = create_te_graph(func) + sch = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + return sch, te_graph, const_dict + + @pytest.fixture + def TwoConv2DWithSliceTE(): + return make_TwoConv2DWithSliceTE() diff --git a/tests/python/contrib/test_ethosu/cascader/infra.py b/tests/python/contrib/test_ethosu/cascader/infra.py new file mode 100644 index 000000000000..baf398dc3602 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/infra.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import extract_constants, lower_to_te + + +def create_te_graph(func): + func, consts = extract_constants(func) + mod = tvm.IRModule.from_expr(func) + func = relay.transform.InferType()(mod)["main"] + te_graph = lower_to_te(func) + return te_graph, consts diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py new file mode 100644 index 000000000000..79a139594b3e --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +pytest.importorskip("ethosu.vela") + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.convolution import match_ethosu_conv2d, conv2d_compute + +import numpy as np + + +def _make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 0, 0], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + scale_bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weight_matrix = np.matmul(weight_matrix, nhcwb16_to_nhwc).tolist() + scale_bias_matrix = np.matmul(scale_bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_offset = ( + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0] + ) + weight_offset = [0, 0, 0, 0] + scale_bias_offset = [0, 0] + return ( + ifm_matrix, + ifm_offset, + weight_matrix, + weight_offset, + scale_bias_matrix, + scale_bias_offset, + ) + + +@pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("dilation", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_channels", [8, 57]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_conv2d_matcher( + kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout +): + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ifm_channels) + else: + ifm_shape = (1, 12, 1 + ((ifm_channels - 1) // 16), 15, 16) + ofm_channels = 8 + kernel_h, kernel_w = kernel + ifm = te.placeholder(ifm_shape, dtype="int8") + weight = te.placeholder((ofm_channels, kernel_h, kernel_w, ifm_channels), dtype="int8") + scale_bias = te.placeholder((ofm_channels, 10), dtype="uint8") + lut = te.placeholder((), dtype="uint8") + out = conv2d_compute( + ifm=ifm, + weight=weight, + scale_bias=scale_bias, + lut=lut, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + weight_zero_point=0, + strides=stride, + padding=padding, + dilation=dilation, + activation="NONE", + clip_min=0, + clip_max=0, + upscale="NONE", + rounding_mode="TFL", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + ( + ifm_transform, + ifm_offset, + weight_transform, + weight_offset, + scale_bias_transform, + scale_bias_offset, + ) = _make_matrices( + kernel, + stride, + dilation, + padding, + ifm_channels, + ifm_layout, + ofm_layout, + ) + + part = match_ethosu_conv2d(out) + + assert isinstance(part, 
cs.EthosuPart) + assert len(part.propagators) == 3 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + assert part.propagators[1].transform == weight_transform + assert part.propagators[1].offset == weight_offset + assert part.propagators[2].transform == scale_bias_transform + assert part.propagators[2].offset == scale_bias_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py new file mode 100644 index 000000000000..a3639ba03077 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +from tvm import te +from tvm.topi.transform import reshape +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.inline import match_ethosu_inline + + +def test_ethosu_inline_matcher(): + ifm_shape = (2, 5, 6) + new_shape = (2, 30) + ifm = te.placeholder(ifm_shape, dtype="int8") + out = reshape(ifm, new_shape) + ifm_transform = [ + [0, 0, ifm_shape[0]], + [0, 0, ifm_shape[1]], + [0, 0, ifm_shape[2]], + [0, 0, 1], + ] + ifm_offset = [0, 0, 0] + + part = match_ethosu_inline(out) + + assert isinstance(part, cs.InlinePart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py new file mode 100644 index 000000000000..ef449a49976c --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +pytest.importorskip("ethosu.vela") + +import tvm.contrib.ethosu.cascader as pl +from tvm.contrib.ethosu.cascader.parts import EthosuPart + + +def test_ethosu_part(): + te_subgraph = pl.TESubgraph([], None) + output_quantum = [1, 2, 2, 8] + quantum_cycles = 32 + propagator = pl.Propagator( + [[1, 0, 0, 0, 2], [0, 1, 0, 0, 2], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], + [0, 0, 0, 0], + ) + stripe_config = pl.StripeConfig( + [1, 4, 4, 16], [1, 64, 72, 96], [1, 4, 4, 16], [1, 2, 3, 4], [1, 16, 13, 6], [0, 0, 0, 0] + ) + + part = EthosuPart(te_subgraph, [propagator], output_quantum, quantum_cycles) + + assert part.get_stripe_align_hint() == output_quantum + # Check that the performance model runs, don't verify output + part.get_performance_info(stripe_config, False) + part.get_performance_info(stripe_config, True) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py index f00eb96251d5..3bab83f24143 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_graph.py +++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py @@ -16,14 +16,9 @@ # under the License. import pytest -from tvm.contrib.ethosu.cascader import ( - StripeConfig, - Propagator, - Tensor, - InlinePart, - TESubgraph, - CascaderGraph, -) +pytest.importorskip("ethosu.vela") + +import tvm.contrib.ethosu.cascader as cs def test_tensor(): @@ -32,7 +27,7 @@ def test_tensor(): is_constant = True compression_ratio = 0.5 size = 6 - tensor = Tensor(shape, dtype, is_constant, compression_ratio) + tensor = cs.Tensor(shape, dtype, is_constant, compression_ratio) assert tensor.shape == shape assert tensor.dtype == dtype assert tensor.is_constant == is_constant @@ -41,18 +36,18 @@ def test_tensor(): def test_inline_part(): - subgraph = TESubgraph([], None) - part = InlinePart( + subgraph = cs.TESubgraph([], None) + part = cs.InlinePart( subgraph, [ - Propagator( + cs.Propagator( [[0, 1, 0], [1, 0, 0], [0, 0, 1]], [0, 0], ), ], ) - output_stripe_config = StripeConfig([2, 4], [8, 8], [2, 4], [1, 2], [4, 2], [0, 0]) - input_stripe_config = StripeConfig([4, 2], [8, 8], [4, 2], [2, 1], [2, 4], [0, 0]) + output_stripe_config = cs.StripeConfig([2, 4], [8, 8], [2, 4], [1, 2], [4, 2], [0, 0]) + input_stripe_config = cs.StripeConfig([4, 2], [8, 8], [4, 2], [2, 1], [2, 4], [0, 0]) assert part.input_tensors == [None] assert part.output_tensor == None @@ -69,33 +64,33 @@ def test_inline_part(): def test_small_graph(): - subgraph = TESubgraph([], None) - part_a = InlinePart( + subgraph = cs.TESubgraph([], None) + part_a = cs.InlinePart( subgraph, [ - Propagator( + cs.Propagator( [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 0], ), - Propagator( + cs.Propagator( [[0, 1, 0], [1, 0, 0], [0, 0, 1]], [-1, -1], ), ], ) - part_b = InlinePart( + part_b = cs.InlinePart( subgraph, [ - Propagator( + cs.Propagator( [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 0], ), ], ) - tensor_1 = Tensor([10, 10], "uint8") - tensor_2 = Tensor([9, 9], "uint8") - tensor_3 = Tensor([10, 10], "uint8") - tensor_4 = Tensor([10, 10], "uint8") + tensor_1 = cs.Tensor([10, 10], "uint8") + tensor_2 = cs.Tensor([9, 9], "uint8") + tensor_3 = cs.Tensor([10, 10], "uint8") + tensor_4 = cs.Tensor([10, 10], "uint8") part_a.set_input(0, tensor_1) part_a.set_input(1, tensor_2) @@ -122,7 +117,7 @@ def test_small_graph(): assert tensor_4.producers == [part_b] assert tensor_4.consumers == [] - graph = CascaderGraph([tensor_1, tensor_2], [tensor_4]) + 
graph = cs.CascaderGraph([tensor_1, tensor_2], [tensor_4]) assert graph.input_tensors == [tensor_1, tensor_2] assert graph.output_tensors == [tensor_4] assert graph.part_order == [part_b, part_a] @@ -130,5 +125,55 @@ def test_small_graph(): assert graph.get_part_id(part) == i +def test_create_cascader_graph(TwoConv2DWithSliceTE): + _, te_graph, const_dict = TwoConv2DWithSliceTE + graph = cs.create_cascader_graph(te_graph, const_dict) + + output_tensor = graph.output_tensors[0] + assert output_tensor.shape == [1, 6, 1, 6, 16] + assert len(output_tensor.producers) == 1 + assert not output_tensor.is_constant + + conv2_part = output_tensor.producers[0] + assert isinstance(conv2_part, cs.EthosuPart) + assert len(conv2_part.input_tensors) == 3 + + assert conv2_part.input_tensors[0].shape == [1, 6, 6, 64] + assert len(conv2_part.input_tensors[0].producers) == 1 + assert not conv2_part.input_tensors[0].is_constant + + assert conv2_part.input_tensors[1].shape == [16, 3, 3, 64] + assert len(conv2_part.input_tensors[1].producers) == 0 + assert conv2_part.input_tensors[1].is_constant + + assert conv2_part.input_tensors[2].shape == [16, 10] + assert len(conv2_part.input_tensors[2].producers) == 0 + assert conv2_part.input_tensors[2].is_constant + + slice_part = conv2_part.input_tensors[0].producers[0] + assert isinstance(slice_part, cs.InlinePart) + assert len(slice_part.input_tensors) == 1 + + assert slice_part.input_tensors[0].shape == [1, 12, 12, 64] + assert len(slice_part.input_tensors[0].producers) == 1 + assert not slice_part.input_tensors[0].is_constant + + conv1_part = slice_part.input_tensors[0].producers[0] + assert isinstance(conv1_part, cs.EthosuPart) + assert len(conv1_part.input_tensors) == 3 + + assert conv1_part.input_tensors[0].shape == [1, 12, 12, 8] + assert len(conv1_part.input_tensors[0].producers) == 0 + assert not conv1_part.input_tensors[0].is_constant + + assert conv1_part.input_tensors[1].shape == [64, 1, 1, 8] + assert len(conv1_part.input_tensors[1].producers) == 0 + assert conv1_part.input_tensors[1].is_constant + + assert conv1_part.input_tensors[2].shape == [64, 10] + assert len(conv1_part.input_tensors[2].producers) == 0 + assert conv1_part.input_tensors[2].is_constant + + if __name__ == "__main__": pytest.main([__file__])
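For reference, a minimal end-to-end sketch of the new flow (illustrative only: it assumes ethosu.vela is installed and relies on the fact that importing the NPU TE modules is what populates REGISTERED_MATCHERS; the TEGraph namedtuple is a hypothetical stand-in for the object produced by lower_to_te, since create_cascader_graph only needs inputs/outputs attributes):

    from collections import namedtuple

    from tvm import te
    from tvm.topi.transform import reshape
    import tvm.contrib.ethosu.cascader as cs
    import tvm.relay.backend.contrib.ethosu.te  # noqa: F401 (importing registers the matchers)

    # Hypothetical stand-in for the graph object returned by lower_to_te
    TEGraph = namedtuple("TEGraph", ["inputs", "outputs"])

    ifm = te.placeholder((2, 5, 6), dtype="int8")
    ofm = reshape(ifm, (2, 30))  # T_reshape is picked up by match_ethosu_inline
    graph = cs.create_cascader_graph(TEGraph([ifm], [ofm]), {})
    assert isinstance(graph.output_tensors[0].producers[0], cs.InlinePart)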