diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py index 72f1667c6151..03753b4049bb 100644 --- a/python/tvm/contrib/ethosu/cascader/__init__.py +++ b/python/tvm/contrib/ethosu/cascader/__init__.py @@ -20,6 +20,7 @@ for both performance and memory usage on Arm(R) Ethos(TM)-U NPUs. """ from .stripe_config import StripeConfig +from .block_config import BlockConfig from .propagator import Propagator from .graph import ( PerformanceInfo, @@ -27,7 +28,9 @@ Part, TESubgraph, CascaderGraph, + BufferMode, register_matcher, create_cascader_graph, ) from .parts import InlinePart, EthosuPart +from .device_config import EthosuDeviceConfig diff --git a/python/tvm/contrib/ethosu/cascader/block_config.py b/python/tvm/contrib/ethosu/cascader/block_config.py new file mode 100644 index 000000000000..3281b8a3606f --- /dev/null +++ b/python/tvm/contrib/ethosu/cascader/block_config.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Block config to hold an output block shape and a corresponding input block shape""" +from typing import List +import tvm._ffi + +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("contrib.ethosu.cascader.BlockConfig") +class BlockConfig(Object): + """BlockConfig class""" + + def __init__(self, output_shape: List[int], compute_cycles: int, output_cycles: int): + self.__init_handle_by_constructor__( + _ffi_api.BlockConfig, output_shape, compute_cycles, output_cycles + ) + + @property + def output_shape(self) -> List[int]: + return list(self._output_shape) + + @property + def compute_cycles(self) -> int: + return int(self._compute_cycles) + + @property + def output_cycles(self) -> int: + return int(self._output_cycles) + + def __repr__(self) -> str: + return f"BlockConfig(output_shape={self.output_shape})" diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py new file mode 100644 index 000000000000..5ad7fde1ed52 --- /dev/null +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -0,0 +1,661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Device config class to hold information about the target hardware""" +from typing import Tuple, List, Dict +from functools import reduce + +import math + +from . import BlockConfig +from . import StripeConfig +from . import Propagator + + +def _round_up(a: int, b: int) -> int: + """Round up to a multiple of b""" + return ((a + b - 1) // b) * b + + +def _round_up_div(a: int, b: int) -> int: + """Divide by b and round up to a multiple of b""" + return (a + b - 1) // b + + +class _Shape: + """Helper class for dealing with Tensor shapes of different layouts""" + + def __init__(self, shape: List[int], layout="NHWC"): + if layout == "NHCWB16": + self.height = int(shape[1]) + self.width = int(shape[3]) + self.depth = int(shape[2]) * int(shape[4]) + else: + self.height = int(shape[1]) + self.width = int(shape[2]) + self.depth = int(shape[3]) + + def round_up(self, other: "_Shape"): + self.height = _round_up(self.height, other.height) + self.width = _round_up(self.width, other.width) + self.depth = _round_up(self.depth, other.depth) + + def area(self) -> int: + return self.height * self.width + + def as_list(self): + return [1, self.height, self.width, self.depth] + + +class EthosuDeviceConfig: + """Arm(R) Ethos(TM)-U NPU config class""" + + def __init__(self, device: str): + self._device = device + self._subkernel_limits = (8, 8) + self._output_cycles = (1, 2, 3, 4, 6) + self._split_depth = 16 + self._max_block_shape = _Shape([1, 32, 64, 128]) + self._bank_size_bytes = 1024 + if self._device == "ethos-u55-256": + self._micro_block = _Shape([1, 2, 2, 8]) + self._input_micro_block = _Shape([1, 2, 2, 8]) + self._delay_cycles = (2, 2) + self._activation_cycles = (0.25, 1) + self._output_units = 8 + + self._total_banks = 48 + self._reserved_banks = 4 + self._input_granularity = 8 + self._accumulator_granularity = {4: 16, 5: 20} + self._lut_reserved = True + elif self._device == "ethos-u55-128": + self._micro_block = _Shape([1, 1, 2, 8]) + self._input_micro_block = _Shape([1, 1, 2, 8]) + self._delay_cycles = (2, 3) + self._activation_cycles = (0.5, 1) + self._output_units = 4 + + self._total_banks = 24 + self._reserved_banks = 4 + self._input_granularity = 4 + self._accumulator_granularity = {4: 8, 5: 12} + self._lut_reserved = True + elif self._device == "ethos-u55-64": + self._micro_block = _Shape([1, 1, 1, 8]) + self._input_micro_block = _Shape([1, 1, 1, 8]) + self._delay_cycles = (2, 3) + self._activation_cycles = (1, 1) + self._output_units = 2 + + self._total_banks = 16 + self._reserved_banks = 2 + self._input_granularity = 2 + self._accumulator_granularity = {4: 4, 5: 8} + self._lut_reserved = False + elif self._device == "ethos-u55-32": + self._micro_block = _Shape([1, 1, 1, 4]) + self._input_micro_block = _Shape([1, 1, 1, 8]) + self._delay_cycles = (3, 7) + self._activation_cycles = (1, 2) + self._output_units = 1 + + self._total_banks = 16 + self._reserved_banks = 2 + self._input_granularity = 2 + self._accumulator_granularity = {4: 4, 5: 8} + self._lut_reserved = False + + def _get_output_cycles( + self, op_type: str, op_str: str, ifm_dtype: str, ofm_dtype: str, activation: str + ) -> float: + """Estimate cycles per output element for an NPU operator + + Parameters + ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" + op_str : str + The type of NPU operator. + "MAX" + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str + Datatype of the Ouput Feature Map tensor (OFM) + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + + Returns + ------- + float + The cycles per output element + """ + cycles = 0 + bw_limit = 0 + if op_type == "ethosu_pooling" and op_str == "MAX": + cycles = self._output_cycles[0] + elif op_type in ("ethosu_pooling", "ethosu_conv2d", "ethosu_depthwise_conv2d"): + cycles = self._output_cycles[1] if ifm_dtype == "int8" else self._output_cycles[2] + elif op_type == "ethosu_binary_elementwise": + # Binary Bandwidth Limitations + if ifm_dtype == "int8": + bw_limit = 0.125 if ofm_dtype == "int8" else 0.75 + elif ifm_dtype == "int16": + bw_limit = 0.75 if ofm_dtype == "int16" else 1 + else: + bw_limit = 1.5 + + if op_str in ("MIN", "MAX"): + cycles = self._output_cycles[1] + elif op_str == "MUL": + cycles = self._output_cycles[2] + if op_str in ("ADD", "SUB"): + if ofm_dtype == "int32": + cycles = ( + self._output_cycles[2] if ifm_dtype == "int32" else self._output_cycles[3] + ) + else: + cycles = self._output_cycles[4] + + elif op_type == "ethosu_unary_elementwise": + # Unary Bandwidth Limitations + if ifm_dtype == "int16": + bw_limit = 0.25 + elif ifm_dtype == "int32": + bw_limit = 1 + + if op_str == "CLZ": + cycles = self._output_cycles[1] + elif op_str in ("SHL", "SHR"): + cycles = self._output_cycles[2] + elif op_str in ("LRELU", "ABS"): + cycles = self._output_cycles[1] + if ifm_dtype == "int16": + bw_limit = 0.5 + + act_cycles = 0 + if activation == "CLIP": + act_cycles = self._activation_cycles[0] + elif activation in ("LUT", "TANH", "SIGMOID"): + act_cycles = self._activation_cycles[1] + + return max((cycles / self._output_units), act_cycles, bw_limit) + + def _get_delay_cycles(self, op_type: str, ifm_dtype: str) -> int: + """Get the number of delay cycles during a bubble + + Parameters + ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" + op_str : str + The type of NPU operator. + "MAX" + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + + Returns + ---------- + int + The amount of delay cycles + """ + if op_type in ("ethosu_conv2d", "ethosu_depthwise2d", "ethosu_pooling"): + if ifm_dtype == "int16": + return self._delay_cycles[1] + + return self._delay_cycles[0] + + return 0 + + def _get_weight_decoder_cycles(self, op_type: str) -> int: + """Get cycle estimate for weight decoding + + Parameters + ---------- + op_type: str + The NPU primitive operator + "ethosu_pooling" + + Returns + ---------- + int + Estimated cycles for weight decoding + """ + if op_type in ("ethosu_conv2d", "ethosu_depthwise2d"): + return 32 * self._micro_block.depth // 8 + + return 0 + + def get_output_quantum(self, ofm_layout: str) -> Tuple[int]: + """Get the atomic output volume + + Parameters + ---------- + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ---------- + Tuple[int] + The atomic output volume formatted to the ofm_layout parameter + """ + if ofm_layout == "NHCWB16": + return [ + 1, + self._micro_block.height, + 1, + self._micro_block.width, + self._micro_block.depth, + ] + + return self._micro_block.as_list() + + def _align(self, x: int, n: int) -> int: + return int(math.ceil(x / n) * n) + + def _get_input_size( + self, output_size: int, kernel_stride: int, border: int, upscaling_factor: int + ) -> int: + return int(math.ceil(((output_size - 1) * kernel_stride + border)) / upscaling_factor) + + def _get_dilated_kernel_size(self, kernel_size: int, dilation: int) -> int: + return (kernel_size - 1) * dilation + 1 + + def _get_input_block( + self, + output_block: _Shape, + input_shape: _Shape, + dtype: str, + op_type: str, + is_partkernel: bool, + stride_h: int, + stride_w: int, + dilated_kernel_h: int, + dilated_kernel_w: int, + upscaling_factor: int, + ) -> _Shape: + height = self._get_input_size( + output_block.height, + stride_h, + min(dilated_kernel_h, self._subkernel_limits[0]), + upscaling_factor, + ) + width = self._get_input_size( + output_block.width, + stride_w, + min(dilated_kernel_w, self._subkernel_limits[1]), + upscaling_factor, + ) + + if op_type == "ethosu_conv2d": + if dtype == "int8": + if is_partkernel: + depth = self._align(min(32, input_shape.depth), 8) + else: + depth = self._align(min(16, input_shape.depth), 8) + elif dtype == "int16": + depth = self._align(min(16, input_shape.depth), 4) + else: + depth = self._align(min(8, input_shape.depth), 2) + else: + depth = output_block.depth + + return _Shape( + [ + 1, + self._align(height, self._micro_block.height), + self._align(width, self._micro_block.width), + depth, + ] + ) + + def get_kernel_steps( + self, + dilated_kernel_h: int, + dilated_kernel_w: int, + ifm_dtype: str, + is_partkernel: bool = False, + ) -> List[int]: + """Calculate the total number of subkernels and their sizes + + Parameters + ---------- + dilated_kernel_h: int + Height of dilated kernel + dilated_kernel_w: int + Width of dilated kernel + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + is_partkernel: bool + Flag showing whether part-kernel first traversal is used + + Returns + ---------- + List[int] + List where each entry contains the amount of elements in one of the subkernels + """ + subkernels = self._get_subkernels(dilated_kernel_h, dilated_kernel_w) + + # Determine the number of kernel steps per subkernel + kernel_steps = [] + for y, x in subkernels: + subkernel_elements = x * y + if is_partkernel: + # Part-kernel-first traversal + divisor = 4 if ifm_dtype == "int8" else 2 + kernel_steps.append(int(_round_up_div(subkernel_elements, divisor))) + else: + # Depth-first traversal + kernel_steps.append(int(subkernel_elements)) + + return kernel_steps + + def _get_subkernels(self, dilated_kernel_h: int, dilated_kernel_w: int): + num_subkernels_y = _round_up_div(dilated_kernel_h, self._subkernel_limits[0]) + num_subkernels_x = _round_up_div(dilated_kernel_w, self._subkernel_limits[1]) + subkernels_y = [ + min((dilated_kernel_h - i * self._subkernel_limits[0]), self._subkernel_limits[0]) + for i in range(num_subkernels_y) + ] + subkernels_x = [ + min((dilated_kernel_w - i * self._subkernel_limits[1]), self._subkernel_limits[1]) + for i in range(num_subkernels_x) + ] + + subkernels = [] + for y in subkernels_y: + for x in subkernels_x: + subkernels.append((y, x)) + + return subkernels + + def _get_accumulator_width(self, op_type: str, ifm_dtype: str): + if ifm_dtype == "int16" and op_type != "ethosu_pooling": + return 5 + + return 4 + + def is_partkernel( + self, op_type: str, ifm_channels: int, ifm_dtype: str, kernel_elements: int + ) -> bool: + """Determine which block traversal strategy has better DPU utilization + + Parameters + ---------- + op_type: str + The NPU primitive operator + "ethosu_pooling" + ifm_channels: int + Number of input channels + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + kernel_elements: int + Total number of elements in the kernel + + Returns + ---------- + bool + True if partkernel first has best DPU utilization + """ + if op_type != "ethosu_conv2d": + return False + + depth_first_utilization = ifm_channels / _round_up( + ifm_channels, 32 if ifm_dtype == "int8" else 16 + ) + part_kernel_first_utilization = (ifm_channels / _round_up(ifm_channels, 8)) * ( + kernel_elements / _round_up(kernel_elements, 4 if ifm_dtype == "int8" else 2) + ) + + return part_kernel_first_utilization > depth_first_utilization or ifm_channels <= 8 + + def get_valid_block_configs( + self, + ifm_propagator: Propagator, + op_attrs: Dict, + output_shape: List[int], + ofm_channels: int, + ifm_channels: int, + output_layout: str, + input_layout: str, + ifm_dtype: str, + ofm_dtype: str, + kernel_h: int = 1, + kernel_w: int = 1, + ) -> List[BlockConfig]: + """Get all of the valid block configs + + Parameters + ---------- + ifm_propagator: Propagator, + The propagator containing the data dependencies between input and output + op_attrs: Dict, + Dictionary containing operator attributes + output_shape: List[int], + Shape of the output tensor + ofm_channels: int, + Number of output channels + ifm_channels: int, + Number of input channels + output_layout: str, + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + input_layout: str, + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm_dtype: str, + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str, + Datatype of the Output Feature Map tensor (OFM) + kernel_h: int, + Height of kernel + kernel_h: int + Width of kernel + + Returns + ---------- + List[BlockConfig] + List containing all of the valid block configs + """ + valid_block_configs = [] + + op_type = op_attrs.get("op") + op_str = op_attrs.get("op_str") + activation = op_attrs.get("activation", "NONE") + stride_h = int(op_attrs.get("stride_h", 1)) + stride_w = int(op_attrs.get("stride_w", 1)) + upscaling_factor = 1 if op_attrs.get("upscale", "NONE") == "NONE" else 2 + + subkernel_transform = ifm_propagator.transform + if output_layout == "NHCWB16": + output_shape = _Shape([1, output_shape[1], output_shape[3], ofm_channels]) + else: + output_shape = _Shape(output_shape) + + if input_layout == "NHCWB16": + subkernel_transform[1][-1] = min( + subkernel_transform[1][-1], self._subkernel_limits[0] - stride_h + ) + subkernel_transform[3][-1] = min( + subkernel_transform[3][-1], self._subkernel_limits[1] - stride_w + ) + else: + subkernel_transform[1][-1] = min( + subkernel_transform[1][-1], self._subkernel_limits[0] - stride_h + ) + subkernel_transform[2][-1] = min( + subkernel_transform[2][-1], self._subkernel_limits[1] - stride_w + ) + + subkernel_propagator = Propagator(subkernel_transform, ifm_propagator.offset) + + # Define search space + max_height = min(output_shape.height, self._max_block_shape.height) + min_height = max(self._micro_block.height, upscaling_factor) + + max_width = min(output_shape.width, self._max_block_shape.width) + min_width = max(self._micro_block.width, upscaling_factor) + + max_depth = min(ofm_channels, self._max_block_shape.depth) + min_depth = max(self._micro_block.depth, upscaling_factor) + + input_bytewidth = 1 if ifm_dtype == "int8" else 2 + acc_bytewidth = self._get_accumulator_width(op_type, ifm_dtype) + banks_available = self._total_banks - self._reserved_banks + if activation == "LUT" and not self._lut_reserved: + banks_available -= 2 + + # Input block depth has additional limitations for Operators that require full input depth + input_block_depth = 0 + is_partkernel = self.is_partkernel(op_type, ifm_channels, ifm_dtype, kernel_h * kernel_w) + if op_type == "ethosu_conv2d": + if is_partkernel: + input_block_depth = min(ifm_channels, 16) + else: + input_block_depth = min(ifm_channels, 32) + + for depth in range(min_depth, max_depth + min_depth, min_depth): + if (depth < output_shape.depth) and (depth % self._split_depth != 0): + # Block depth has to be less than full depth or a multiple of the split depth + continue + + for width in range(min_width, max_width + min_width, min_width): + for height in range(min_height, max_height + min_height, min_height): + if output_layout == "NHCWB16": + output_block = ( + 1, + height, + 1 + ((depth - 1) // 16), + width, + _round_up( + min(16, max(ofm_channels, min_depth)), self._micro_block.depth + ), + ) + order = [1, 2, 4, 3, 0] + else: + output_block = (1, height, width, depth) + order = [1, 2, 3, 4] + + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + block_stripe_config = StripeConfig( + output_block, + output_block, + output_block, + order, + stripes, + offset, + ) + + # Propagate output block + input_block = subkernel_propagator.propagate(block_stripe_config) + + input_block_shape = _Shape(input_block.shape, input_layout) + input_block_shape.round_up(self._input_micro_block) + output_block_shape = _Shape(output_block, output_layout) + + if op_type == "ethosu_conv2d": + input_block_shape.depth = input_block_depth + + # Banks required for input block + input_bytes = input_block_shape.area() * self._align( + input_block_shape.depth * input_bytewidth, 8 + ) + input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2 + input_banks = _round_up(input_banks, self._input_granularity) + + # Banks required for accumulation + acc_depth = _round_up(min(output_block_shape.depth, ofm_channels), 8) + acc_bytes = ( + output_block_shape.area() * self._align(acc_depth, 8) * acc_bytewidth + ) + acc_banks = _round_up_div(acc_bytes, self._bank_size_bytes) * 2 + acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth]) + + if (input_banks + acc_banks) <= banks_available: + + output_cycles = self._get_output_cycles( + op_type, op_str, ifm_dtype, ofm_dtype, activation + ) + output_cycles *= reduce(lambda a, b: a * b, output_block, 1) + output_cycles = int(_round_up(output_cycles, 1)) + compute_cycles = self._estimate_compute_cycles_per_block( + op_type, + output_block_shape, + input_block_shape, + kernel_h, + kernel_w, + ifm_channels, + is_partkernel, + ) + valid_block_configs.append( + BlockConfig(output_block, compute_cycles, output_cycles) + ) + else: + # Block config does not fit into SHRAM + # Any Block config that is strictly larger than this one will also fail + break + + return valid_block_configs + + def _estimate_compute_cycles_per_block( + self, + op_type: str, + block_shape: _Shape, + input_block_shape: _Shape, + kernel_h: int, + kernel_w: int, + input_channels: int, + ifm_dtype: str, + is_partkernel: bool = False, + ) -> Tuple[int, int]: + # Calculate the amount of micro blocks per block, per axis + num_quantum_x = _round_up_div(block_shape.width, self._micro_block.width) + num_quantum_y = _round_up_div(block_shape.height, self._micro_block.height) + num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth) + num_quantum_xy = num_quantum_x * num_quantum_y + + kernel_steps = self.get_kernel_steps(kernel_h, kernel_w, ifm_dtype, is_partkernel) + + wd_cycles = self._get_weight_decoder_cycles(op_type) + delay_cycles = self._get_delay_cycles(op_type, ifm_dtype) + cycle_quantum = 4 + + compute_cycles = 0 + for subkernel_steps in kernel_steps: + compute_cycles += ( + max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_steps * num_quantum_z + ) + + if num_quantum_xy == 1: + if num_quantum_z == 1: + compute_cycles += delay_cycles * subkernel_steps + elif subkernel_steps > 1: + compute_cycles += delay_cycles * (subkernel_steps - 1) * num_quantum_z + + if is_partkernel: + compute_cycles *= _round_up_div(input_block_shape.depth, 8) + + if op_type == "ethosu_conv2d": + compute_cycles *= _round_up_div(input_channels, input_block_shape.depth) + + return compute_cycles diff --git a/python/tvm/contrib/ethosu/cascader/graph.py b/python/tvm/contrib/ethosu/cascader/graph.py index 9b22e632ff89..7aa4a26513cd 100644 --- a/python/tvm/contrib/ethosu/cascader/graph.py +++ b/python/tvm/contrib/ethosu/cascader/graph.py @@ -16,6 +16,7 @@ # under the License. """Graph objects to define compute graphs for the NPU cascader.""" from typing import List, Dict +from enum import IntEnum from collections import namedtuple import numpy as np @@ -24,6 +25,7 @@ from tvm.runtime import Object from .stripe_config import StripeConfig +from .device_config import EthosuDeviceConfig from . import _ffi_api @@ -34,6 +36,11 @@ TESubgraph = namedtuple("TESubgraph", ["input_tensors", "output_tensor"]) +class BufferMode(IntEnum): + RECOMPUTE = 0 + ROLLING = 1 + + @tvm._ffi.register_object("contrib.ethosu.cascader.PerformanceInfo") class PerformanceInfo(Object): """PerformanceInfo class""" @@ -113,9 +120,9 @@ def get_stripe_align_hint(self) -> List[int]: return list(_ffi_api.PartGetStripeAlignHint(self)) def get_performance_info( - self, stripe_config: StripeConfig, is_rolling: bool + self, stripe_config: StripeConfig, buffer_mode: BufferMode ) -> PerformanceInfo: - return _ffi_api.PartGetPerformanceInfo(self, stripe_config, is_rolling) + return _ffi_api.PartGetPerformanceInfo(self, stripe_config, buffer_mode) @property def input_tensors(self): @@ -188,7 +195,9 @@ def register_matcher(matcher): return matcher -def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray]) -> CascaderGraph: +def create_cascader_graph( + te_graph: TESubgraph, const_dict: Dict[int, np.ndarray], device_config: EthosuDeviceConfig +) -> CascaderGraph: """Create a CascaderGraph from a Tensor Expression graph and constant dictionary. Parameters @@ -197,6 +206,8 @@ def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray The Tensor Expression graph. const_dict : Dict[int, np.ndarray] The constant dictionary. + device_config : EthosuDeviceConfig + Target device configuration. Returns ------- @@ -227,7 +238,7 @@ def _visit_tensor(tensor): input_tensors = [] # Check whether any of the registered matchers match the current tensor for matcher in REGISTERED_MATCHERS: - part = matcher(tensor) + part = matcher(tensor, device_config) if part: input_tensors = part.subgraph.input_tensors break diff --git a/python/tvm/contrib/ethosu/cascader/parts.py b/python/tvm/contrib/ethosu/cascader/parts.py index 9cc67d5760dd..12588799a66a 100644 --- a/python/tvm/contrib/ethosu/cascader/parts.py +++ b/python/tvm/contrib/ethosu/cascader/parts.py @@ -20,6 +20,8 @@ from .propagator import Propagator from .graph import Part, TESubgraph +from .block_config import BlockConfig +from .stripe_config import StripeConfig from . import _ffi_api @@ -52,7 +54,9 @@ def __init__( te_subgraph: TESubgraph, propagators: List[Propagator], output_quantum: List[int], - quantum_cycles: int, + subkernels: int, + valid_block_configs: List[BlockConfig], + weight_tensor_idx: int = -1, ): self.__init_handle_by_constructor__( _ffi_api.EthosuPart, @@ -60,5 +64,10 @@ def __init__( te_subgraph.output_tensor, propagators, output_quantum, - quantum_cycles, + subkernels, + valid_block_configs, + weight_tensor_idx, ) + + def get_block_config(self, stripe_config: StripeConfig) -> BlockConfig: + return _ffi_api.EthosuPartGetBlockConfig(self, stripe_config) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 766af0dbbeef..c61082beb737 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -180,7 +180,7 @@ def conv2d_compute( [1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0], - [0, 0, 16, 0, 0, 0], + [0, 0, 16, 0, 1, -16], [0, 0, 0, 0, 0, 1], ] ifm_matrix = [ @@ -236,7 +236,7 @@ def conv2d_compute( @register_matcher -def match_ethosu_conv2d(output_tensor): +def match_ethosu_conv2d(output_tensor, device_config): """Match a Tensor Expression corresponding to an NPU Conv2D. If the Tensor Expression matches, an EthosuPart will be created that models the @@ -246,6 +246,8 @@ def match_ethosu_conv2d(output_tensor): ---------- output_tensor : tvm.te.Tensor The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration Returns ------- @@ -277,17 +279,50 @@ def match_ethosu_conv2d(output_tensor): conv2d.op.input_tensors[1], conv2d.op.input_tensors[2], ] + subgraph = TESubgraph(input_tensors, output_tensor) propagators = [ write.op.attrs["ifm_propagator"], write.op.attrs["weights_propagator"], write.op.attrs["bias_propagator"], ] - # TODO(@jacobbohlin) Both the output_quantum and quantum_cycles here are placeholders, - # needs true implementation. - if convert_to_nhcwb16.op.attrs["layout"] == "NHWC": - output_quantum = [1, 2, 2, 1] - else: - output_quantum = [1, 2, 1, 2, 1] - quantum_cycles = 1000 - return EthosuPart(subgraph, propagators, output_quantum, quantum_cycles) + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels, kernel_height, kernel_width = (int(axis) for axis in input_tensors[1].shape[0:3]) + kernel_elements = kernel_height * kernel_width + + is_part_kernel = device_config.is_partkernel( + conv2d.op.name, ifm_channels, ifm_dtype, kernel_elements + ) + subkernels = len( + device_config.get_kernel_steps(kernel_height, kernel_width, ifm_dtype, is_part_kernel) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + conv2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + kernel_height, + kernel_width, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + 1, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/inline.py b/python/tvm/relay/backend/contrib/ethosu/te/inline.py index 95e7342d5e82..79631f4b8c1c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/inline.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/inline.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """Tensor Expressions for operations that will be inlined""" import numpy as np # type: ignore @@ -24,7 +25,7 @@ @register_matcher -def match_ethosu_inline(output_tensor): +def match_ethosu_inline(output_tensor, device_config): """Match a Tensor Expression corresponding to an operator that will be inlined. If the Tensor Expression matches, an InlinePart will be created that models the @@ -37,6 +38,8 @@ def match_ethosu_inline(output_tensor): ---------- output_tensor : tvm.te.Tensor The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration Returns ------- diff --git a/src/contrib/ethosu/cascader/block_config.cc b/src/contrib/ethosu/cascader/block_config.cc new file mode 100644 index 000000000000..fe698aa17aac --- /dev/null +++ b/src/contrib/ethosu/cascader/block_config.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "block_config.h" + +#include +#include +#include + +#include +#include + +#include "common.h" + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +void BlockConfigNode::VisitAttrs(AttrVisitor* v) { + Array tmp_arr = make_array(output_shape_); + v->Visit("_output_shape", &tmp_arr); +} + +BlockConfig::BlockConfig(const std::vector& output_shape, int compute_cycles, + int output_cycles) { + auto n = make_object(); + n->output_shape_ = std::move(output_shape); + n->compute_cycles_ = compute_cycles; + n->output_cycles_ = output_cycles; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.BlockConfig") + .set_body_typed([](Array output_shape, int compute_cycles, int output_cycles) { + std::vector voutput_shape = make_vector(output_shape); + return BlockConfig(voutput_shape, compute_cycles, output_cycles); + }); + +TVM_REGISTER_NODE_TYPE(BlockConfigNode); + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/ethosu/cascader/block_config.h b/src/contrib/ethosu/cascader/block_config.h new file mode 100644 index 000000000000..d7da1d90e82e --- /dev/null +++ b/src/contrib/ethosu/cascader/block_config.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/ethosu/cascader/block_config.h + * \brief BlockConfig object for the NPU cascader + */ +#ifndef TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ +#define TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +class BlockConfig; + +/*! \brief Node to represent a BlockConfig */ +class BlockConfigNode : public Object { + public: + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Get the shape of output block. + * \return The output shape of the block config. + */ + inline std::vector GetOutputBlockShape() const { return output_shape_; } + + /*! + * \brief Get the number of cycles required to output this block + * \return The output cycles + */ + inline int GetOutputCycles() const { return output_cycles_; } + + /*! + * \brief Get the number of cycles required to compute this block + * \return The compute cycles + */ + inline int GetComputeCycles() const { return compute_cycles_; } + + static constexpr const char* _type_key = "contrib.ethosu.cascader.BlockConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockConfigNode, Object); + + protected: + friend class BlockConfig; + + /*! \brief The shape of the output block */ + std::vector output_shape_; + /*! \brief Cycles required to compute this block */ + int compute_cycles_; + /*! \brief Cycles required to output this block */ + int output_cycles_; +}; + +/*! + * \brief An object that contains a an output block shape as well as the output and compute cycles + * required to compute this block + */ +class BlockConfig : public ObjectRef { + public: + BlockConfig(const std::vector& output_shape, int compute_cycles, int output_cycles); + + TVM_DEFINE_OBJECT_REF_METHODS(BlockConfig, ObjectRef, BlockConfigNode); +}; + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ diff --git a/src/contrib/ethosu/cascader/common.h b/src/contrib/ethosu/cascader/common.h index ec62861049a3..b4b5664e04b9 100644 --- a/src/contrib/ethosu/cascader/common.h +++ b/src/contrib/ethosu/cascader/common.h @@ -68,6 +68,22 @@ inline Array make_array(const std::vector& vec) { return arr; } +/*! + * \brief Make a tvm::Array from an int64_t vector. + * \param vec The int64_t vector. + * \return The IntImm Array. + * \note Array(std::vector) doesn't work as this implicit + * type conversion fails. This is why this helper is required. + */ +inline Array make_array(const std::vector& vec) { + Array arr; + arr.resize(vec.size()); + for (unsigned int i = 0; i < vec.size(); ++i) { + arr.Set(i, IntImm(DataType::Int(64), vec[i])); + } + return arr; +} + /*! * \brief Make a tvm::Array from an float vector. * \param vec The float vector. @@ -82,6 +98,16 @@ inline Array make_array(const std::vector& vec) { return arr; } +/*! + * \brief Calculate the ceil of an Integer division + * \param dividend The dividend of the division + * \param divisor The divisor of the division + * \return The quotient + */ +inline int round_up_divide(int dividend, int divisor) { + return dividend / divisor + (dividend % divisor != 0); +} + /*! * \brief Make a vector from a tvm::Array. * \param arr The Array. diff --git a/src/contrib/ethosu/cascader/graph.cc b/src/contrib/ethosu/cascader/graph.cc index a930c2606e18..ce28f728d838 100644 --- a/src/contrib/ethosu/cascader/graph.cc +++ b/src/contrib/ethosu/cascader/graph.cc @@ -38,12 +38,10 @@ namespace ethosu { namespace cascader { void PerformanceInfoNode::VisitAttrs(AttrVisitor* v) { - int compute_cycles_int = static_cast(compute_cycles); - v->Visit("_compute_cycles", &compute_cycles_int); - Array tmp_reads = make_array(read_bytes); + v->Visit("_compute_cycles", &compute_cycles); + Array tmp_reads = make_array(read_bytes); v->Visit("_read_bytes", &tmp_reads); - int write_bytes_int = static_cast(write_bytes); - v->Visit("_write_bytes", &write_bytes_int); + v->Visit("_write_bytes", &write_bytes); } TVM_REGISTER_NODE_TYPE(PerformanceInfoNode); @@ -147,8 +145,9 @@ TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetStripeAlignHint").set_body_t return make_array(align_hint); }); TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetPerformanceInfo") - .set_body_typed([](Part part, StripeConfig stripe_config, bool is_rolling) { - return part->GetPerformanceInfo(stripe_config, is_rolling); + .set_body_typed([](Part part, StripeConfig stripe_config, int buffer_mode) { + BufferMode ebuffer_mode = static_cast(buffer_mode); + return part->GetPerformanceInfo(stripe_config, ebuffer_mode); }); CascaderGraphNode::CascaderGraphNode(std::vector input_tensors, diff --git a/src/contrib/ethosu/cascader/graph.h b/src/contrib/ethosu/cascader/graph.h index 2bea890c722b..81cbd1c9da5f 100644 --- a/src/contrib/ethosu/cascader/graph.h +++ b/src/contrib/ethosu/cascader/graph.h @@ -44,6 +44,14 @@ class Tensor; class Part; class StripeConfig; +/*! + * \brief The buffering mode to use when realizing a tensor. + * RECOMPUTE - The 'default' behaviour of TVM. Overlapping stripes will be recomputed. + * ROLLING - Apply both the sliding window and storage folding optimizations to the tensor + * realization. + */ +enum BufferMode { RECOMPUTE, ROLLING }; + /*! \brief A struct to hold a Tensor Expression subgraph */ struct TESubgraph { /*! \brief The input te::Tensors to the subgraph */ @@ -58,11 +66,11 @@ class PerformanceInfoNode : public Object { void VisitAttrs(AttrVisitor* v); /*! \brief The cycles to compute a block */ - size_t compute_cycles; + int64_t compute_cycles; /*! \brief The number of bytes read per input tensor */ - std::vector read_bytes; + std::vector read_bytes; /*! \brief The number of bytes written to the output tensor */ - size_t write_bytes; + int64_t write_bytes; static constexpr const char* _type_key = "contrib.ethosu.cascader.PerformanceInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PerformanceInfoNode, Object); @@ -77,7 +85,7 @@ class PerformanceInfoNode : public Object { */ class PerformanceInfo : public ObjectRef { public: - PerformanceInfo(size_t compute_cycles, std::vector read_bytes, size_t write_bytes) { + PerformanceInfo(int64_t compute_cycles, std::vector read_bytes, int64_t write_bytes) { auto n = make_object(); n->compute_cycles = compute_cycles; n->read_bytes = std::move(read_bytes); @@ -190,7 +198,7 @@ class PartNode : public Object { * \return The performance information containing the compute cycles and read/write bytes. */ virtual const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) = 0; + BufferMode buffer_mode) = 0; static constexpr const char* _type_key = "contrib.ethosu.cascader.Part"; TVM_DECLARE_BASE_OBJECT_INFO(PartNode, Object); diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc index 29b43269c7b6..c5f236761ba0 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.cc +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -21,6 +21,9 @@ #include #include +#include +#include +#include #include #include @@ -32,62 +35,114 @@ namespace contrib { namespace ethosu { namespace cascader { -const std::vector EthosuPartNode::GetBlockShape(const StripeConfig& output_stripe_config, - bool is_rollling) { - std::vector block_shape; - for (int axis : output_stripe_config->GetShape()) { - block_shape.push_back(std::min(axis, 4)); - } - return block_shape; -} +const std::vector EthosuPartNode::GetBytesRead(const std::vector& block_shape, + const std::vector& full_shape) { + std::vector bytes_per_input(propagators_.size(), 0); -const std::vector EthosuPartNode::GetBlockInputBytes_(const std::vector& block_shape) { - std::vector bytes_per_input; - std::vector strides; std::vector order; std::vector stripes; std::vector offset; + std::vector strides; for (size_t i = 0; i < block_shape.size(); i++) { - strides.push_back(1.0); order.push_back(1); - stripes.push_back(1); + stripes.push_back(round_up_divide(full_shape[i], block_shape[i])); offset.push_back(0); + strides.push_back(static_cast(block_shape[i])); } - StripeConfig output_block_config(block_shape, block_shape, strides, order, stripes, offset); + + StripeConfig output_block_config(block_shape, full_shape, strides, order, stripes, offset); auto input_block_configs = CalculateInputStripeConfigs(output_block_config); + + int i = 0; for (const auto& input_block_config : input_block_configs) { - bytes_per_input.push_back(mul_reduce(input_block_config->GetShape())); + std::map, int> input_blocks = CountStripes(input_block_config, false); + + for (const auto& block : input_blocks) { + bytes_per_input[i] += mul_reduce(block.first) * block.second; + } + i++; } + + if (weight_tensor_idx_ != -1) { + bytes_per_input[weight_tensor_idx_] *= (stripes[height_idx_] * stripes[width_idx_]); + } + return bytes_per_input; } +const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stripe_config) { + BlockConfig best_block_config; + float best_cost = std::numeric_limits::infinity(); + std::vector output_stripe_shape = output_stripe_config->GetShape(); + + for (const auto& block_config : valid_block_configs_) { + std::vector output_block = block_config->GetOutputBlockShape(); + + std::vector bytes_per_input = GetBytesRead(output_block, output_stripe_shape); + bytes_per_input[0] *= subkernels_; + + // Calculate bytes read per output element + float relative_cost = + (bytes_per_input[0] + bytes_per_input[1]) / mul_reduce(output_stripe_shape); + + // Single buffering hardware optimization + if (mul_reduce(output_stripe_shape) <= 2 * mul_reduce(output_block)) { + relative_cost /= 2; + } + + if (relative_cost < best_cost) { + best_block_config = block_config; + best_cost = relative_cost; + } + } + + return best_block_config; +} + const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) { - std::vector block_shape = GetBlockShape(output_stripe_config, is_rolling); - std::vector bytes_per_input = GetBlockInputBytes_(block_shape); - int bytes_per_output = mul_reduce(block_shape); - int num_blocks = 1; + BufferMode buffer_mode) { + BlockConfig block_config = GetBlockConfig(output_stripe_config); + std::vector block_shape = block_config->GetOutputBlockShape(); + + std::vector bytes_per_input = + GetBytesRead(block_shape, output_stripe_config->GetShape()); + + int elements_per_block = mul_reduce(block_shape); + int bytes_per_output = elements_per_block; + float num_blocks = 1.0f; for (size_t i = 0; i < block_shape.size(); i++) { - if (!is_rolling) { - num_blocks *= output_stripe_config->GetShape()[i] * output_stripe_config->GetStripes()[i] / + if (buffer_mode == BufferMode::RECOMPUTE) { + num_blocks *= static_cast(output_stripe_config->GetShape()[i] * + output_stripe_config->GetStripes()[i]) / block_shape[i]; } else { - num_blocks *= output_stripe_config->GetExtent()[i] / block_shape[i]; + num_blocks *= static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i]; } } - int num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1; - std::vector read_bytes; + float num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1.0f; + std::vector read_bytes; for (int block_bytes : bytes_per_input) { read_bytes.push_back((num_blocks + num_stripes) * block_bytes); } - int write_bytes = (num_blocks + num_stripes) * bytes_per_output; - auto shape = output_stripe_config->GetShape(); - PerformanceInfo info(0, read_bytes, write_bytes); + int64_t write_bytes = (num_blocks + num_stripes) * bytes_per_output; + + int block_output_cycles = block_config->GetOutputCycles(); + int block_compute_cycles = block_config->GetComputeCycles(); + + int64_t total_cycles = 0; + if (block_output_cycles > block_compute_cycles) { + total_cycles = (block_output_cycles * num_blocks) + block_compute_cycles; + } else { + total_cycles = (block_compute_cycles * num_blocks) + block_output_cycles; + } + + PerformanceInfo info(total_cycles, read_bytes, write_bytes); return info; } EthosuPart::EthosuPart(const TESubgraph& subgraph, const std::vector propagators, - const std::vector output_quantum, int quantum_cycles) { + const std::vector& output_quantum, int subkernels, + const std::vector& valid_block_configs, int weight_tensor_idx) { auto n = make_object(); ICHECK_GT(propagators.size(), 0) << "The Part must include at least one Propagator."; n->subgraph_ = subgraph; @@ -95,21 +150,40 @@ EthosuPart::EthosuPart(const TESubgraph& subgraph, const std::vector n->in_line_ = false; n->input_tensors_.resize(propagators.size()); n->output_quantum_ = output_quantum; - n->quantum_cycles_ = quantum_cycles; + n->valid_block_configs_ = valid_block_configs; + n->subkernels_ = subkernels; + n->weight_tensor_idx_ = weight_tensor_idx; + if (output_quantum.size() == 5) { + // NHCWB16 Format + n->height_idx_ = 1; + n->width_idx_ = 3; + } else { + // NHWC Format + n->height_idx_ = 1; + n->width_idx_ = 2; + } data_ = std::move(n); } TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPart") .set_body_typed([](Array subgraph_inputs, te::Tensor subgraph_output, - Array propagators, Array output_quantum, - int quantum_cycles) { + Array propagators, Array output_quantum, int subkernels, + Array valid_block_configs, int weight_tensor_idx) { std::vector vsubgraph_inputs(subgraph_inputs.begin(), subgraph_inputs.end()); std::vector vpropagators(propagators.begin(), propagators.end()); + std::vector voutput_quantum(output_quantum.begin(), output_quantum.end()); TESubgraph subgraph; subgraph.input_tensors = vsubgraph_inputs; subgraph.output_tensor = subgraph_output; - std::vector voutput_quantum = make_vector(output_quantum); - return EthosuPart(subgraph, vpropagators, voutput_quantum, quantum_cycles); + std::vector vvalid_block_configs(valid_block_configs.begin(), + valid_block_configs.end()); + return EthosuPart(subgraph, vpropagators, voutput_quantum, subkernels, vvalid_block_configs, + weight_tensor_idx); + }); + +TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPartGetBlockConfig") + .set_body_typed([](EthosuPart part, StripeConfig stripe_config) { + return part->GetBlockConfig(stripe_config); }); TVM_REGISTER_NODE_TYPE(EthosuPartNode); diff --git a/src/contrib/ethosu/cascader/parts/ethosu.h b/src/contrib/ethosu/cascader/parts/ethosu.h index ab3ca69d2717..cd8fa84eca2b 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.h +++ b/src/contrib/ethosu/cascader/parts/ethosu.h @@ -28,6 +28,7 @@ #include +#include "../block_config.h" #include "../graph.h" namespace tvm { @@ -39,11 +40,10 @@ namespace cascader { class EthosuPartNode : public PartNode { public: /*! - * \brief Get the optimal block shape to use. + * \brief Get the optimal BlockConfig to use given a StripeConfig * \param output_stripe_config The output StripeConfig. - * \param is_rolling Whether the output config should be computed as a rolling buffer. */ - const std::vector GetBlockShape(const StripeConfig& output_stripe_config, bool is_rolling); + const BlockConfig GetBlockConfig(const StripeConfig& output_stripe_config); /*! * \brief Get the preferred alignment in each axis for a stripe of the Part. * \note This is used to bias the selection of StripeConfigs towards those that are integer @@ -53,11 +53,11 @@ class EthosuPartNode : public PartNode { /*! * \brief Get the performance information for a given output stripe config. * \param output_stripe_config The output stripe config to compute the performance for. - * \param is_rolling Whether the output config should be computed as a rolling buffer. + * \param buffer_mode The mode of buffering, rolling or recompute. * \return The performance information containing the compute cycles and read/write bytes. */ const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) final; + BufferMode buffer_mode) final; static constexpr const char* _type_key = "contrib.ethosu.cascader.EthosuPart"; TVM_DECLARE_FINAL_OBJECT_INFO(EthosuPartNode, PartNode); @@ -66,16 +66,27 @@ class EthosuPartNode : public PartNode { friend class EthosuPart; /*! - * \brief Get the size of input required (per input tensor) to compute a block. - * \param block_shape The shape of the block to compute. + * \brief Get the size of input required (per input tensor) to compute a stripe given a + * block_shape + * \param block_shape The shape of the block(s) the stripe is split into + * \param stripe_shape The shape of the full stripe to compute. * \return The bytes required per input tensor. */ - const std::vector GetBlockInputBytes_(const std::vector& block_shape); + const std::vector GetBytesRead(const std::vector& block_shape, + const std::vector& full_shape); + /*! \brief List of block configs that are valid for this part */ + std::vector valid_block_configs_; /*! \brief The output volume that is atomically computed */ std::vector output_quantum_; - /*! \brief The cycles taken to compute a single output quantum */ - int quantum_cycles_; + /*! \brief Index for output height dimension */ + int height_idx_; + /*! \brief Index for output width dimension */ + int width_idx_; + /*! \brief Index of weight tensor, -1 if the Part has no weights */ + int weight_tensor_idx_; + /*! \brief Number of sub-kernels the kernel has been split into */ + int subkernels_; }; /*! @@ -86,7 +97,8 @@ class EthosuPartNode : public PartNode { class EthosuPart : public Part { public: EthosuPart(const TESubgraph& subgraph, const std::vector propagators, - const std::vector output_quantum, int quantum_cycles); + const std::vector& output_quantum, int subkernels, + const std::vector& valid_block_configs, int weight_tensor_idx); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EthosuPart, Part, EthosuPartNode); }; diff --git a/src/contrib/ethosu/cascader/parts/inline.cc b/src/contrib/ethosu/cascader/parts/inline.cc index ff5e055084e6..cb216e7d1454 100644 --- a/src/contrib/ethosu/cascader/parts/inline.cc +++ b/src/contrib/ethosu/cascader/parts/inline.cc @@ -31,8 +31,8 @@ namespace ethosu { namespace cascader { const PerformanceInfo InlinePartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) { - std::vector read_bytes(input_tensors_.size()); + BufferMode buffer_mode) { + std::vector read_bytes(input_tensors_.size()); PerformanceInfo info(0, read_bytes, 0); return info; } diff --git a/src/contrib/ethosu/cascader/parts/inline.h b/src/contrib/ethosu/cascader/parts/inline.h index 44f2762319fb..11d94f17397d 100644 --- a/src/contrib/ethosu/cascader/parts/inline.h +++ b/src/contrib/ethosu/cascader/parts/inline.h @@ -45,7 +45,7 @@ class InlinePartNode : public PartNode { * \return The performance information containing the compute cycles and read/write bytes. */ const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) final; + BufferMode buffer_mode) final; static constexpr const char* _type_key = "contrib.ethosu.cascader.InlinePart"; TVM_DECLARE_FINAL_OBJECT_INFO(InlinePartNode, PartNode); diff --git a/tests/python/contrib/test_ethosu/cascader/infra.py b/tests/python/contrib/test_ethosu/cascader/infra.py index baf398dc3602..c2b6073fb62e 100644 --- a/tests/python/contrib/test_ethosu/cascader/infra.py +++ b/tests/python/contrib/test_ethosu/cascader/infra.py @@ -18,6 +18,8 @@ from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import extract_constants, lower_to_te +import numpy as np + def create_te_graph(func): func, consts = extract_constants(func) @@ -25,3 +27,67 @@ def create_te_graph(func): func = relay.transform.InferType()(mod)["main"] te_graph = lower_to_te(func) return te_graph, consts + + +def make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + scale_bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weight_matrix = np.matmul(weight_matrix, nhcwb16_to_nhwc).tolist() + scale_bias_matrix = np.matmul(scale_bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_offset = ( + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0] + ) + weight_offset = [0, 0, 0, 0] + scale_bias_offset = [0, 0] + return ( + ifm_matrix, + ifm_offset, + weight_matrix, + weight_offset, + scale_bias_matrix, + scale_bias_offset, + ) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py new file mode 100644 index 000000000000..3418bb58351e --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +import tvm.contrib.ethosu.cascader as cs + +from .infra import make_matrices + + +@pytest.mark.parametrize( + "id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", + [ + ( + 0, + "ethosu_conv2d", + "NONE", + (34, 19), + (2, 2), + (1, 1), + (0, 0, 0, 0), + (1, 266, 111, 15), + (1, 117, 47, 15), + ), + ( + 1, + "ethosu_conv2d", + "NONE", + (14, 14), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 125, 63, 64), + (1, 112, 50, 128), + ), + ( + 2, + "ethosu_conv2d", + "NONE", + (7, 1), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 13, 4, 12), + (1, 4, 4, 511), + ), + ( + 3, + "ethosu_conv2d", + "NONE", + (5, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 96, 16, 276), + (1, 92, 12, 16), + ), + ( + 4, + "ethosu_conv2d", + "NONE", + (5, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 96, 16, 276), + (1, 92, 12, 1), + ), + ( + 5, + "ethosu_conv2d", + "NONE", + (3, 3), + (1, 1), + (2, 2), + (0, 0, 0, 0), + (1, 62, 94, 32), + (1, 58, 90, 16), + ), + ], +) +@pytest.mark.parametrize( + "layouts", + [ + ("NHWC", "NHWC"), + ("NHCWB16", "NHCWB16"), + ("NHWC", "NHCWB16"), + ("NHCWB16", "NHWC"), + ], +) +@pytest.mark.parametrize( + "acc_config, expected_block_configs", + [ + ( + "ethos-u55-32", + [ + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 10, 6, 4), (1, 16, 1, 4, 4)), + ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ], + ), + ( + "ethos-u55-64", + [ + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 10, 6, 8), (1, 16, 1, 4, 8)), + ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ], + ), + ( + "ethos-u55-128", + [ + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), + ((1, 5, 8, 16), (1, 5, 1, 8, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 16, 4, 16), (1, 16, 1, 4, 16)), + ((1, 8, 12, 8), (1, 8, 1, 12, 8)), + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), + ], + ), + ( + "ethos-u55-256", + [ + ((1, 14, 8, 16), (1, 14, 1, 8, 16)), + ((1, 16, 8, 16), (1, 16, 1, 8, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 32, 4, 16), (1, 32, 1, 4, 16)), + ((1, 20, 12, 8), (1, 20, 1, 12, 8)), + ((1, 20, 6, 16), (1, 20, 1, 6, 16)), + ], + ), + ], +) +def test_best_block_config( + id, + op_type, + activation, + kernel, + stride, + dilation, + padding, + in_shape, + out_shape, + layouts, + acc_config, + expected_block_configs, +): + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + kernel, stride, dilation, padding, in_shape[3], layouts[0], layouts[1] + ) + + ofm_channels = out_shape[3] + ifm_channels = in_shape[3] + + if layouts[0] == "NHCWB16": + in_shape = [ + int(math.ceil(n)) for n in np.matmul(nhwc_to_nhcwb16, in_shape + (1,)).tolist()[:-1] + ] + if layouts[1] == "NHCWB16": + out_shape = [ + int(math.ceil(n)) for n in np.matmul(nhwc_to_nhcwb16, out_shape + (1,)).tolist()[:-1] + ] + + propagator = cs.Propagator(ifm_matrix, ifm_offset) + weight_propagator = cs.Propagator(weight_matrix, weight_offset) + + subkernels = ((kernel[0] + 7) // 8) * ((kernel[1] + 7) // 8) + + op_attrs = { + "op": op_type, + "activation": activation, + "stride_h": stride[0], + "stride_w": stride[1], + "dilation_h": dilation[0], + "dilation_w": dilation[1], + } + + device_config = cs.EthosuDeviceConfig(acc_config) + block_configs = device_config.get_valid_block_configs( + propagator, + op_attrs, + out_shape, + ofm_channels, + ifm_channels, + layouts[1], + layouts[0], + "int8", + "int8", + kernel[0], + kernel[1], + ) + + output_quantum = [1, 1, 2, 8] + if layouts[1] == "NHCWB16": + output_quantum = [1, 1, 1, 2, 8] + + # Create EthosUPart + te_subgraph = cs.TESubgraph([], None) + part = cs.EthosuPart( + te_subgraph, + [propagator, weight_propagator], + output_quantum, + subkernels, + block_configs, + 1, + ) + + order = [1, 2, 3, 4] if layouts[1] == "NHCWB16" else [1, 2, 4, 3, 0] + stripes = [1] * len(output_quantum) + offset = [0] * len(output_quantum) + + stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) + + block = part.get_block_config(stripe_config) + block_shape = tuple(int(a) for a in block.output_shape) + if layouts[1] == "NHCWB16": + assert block_shape == expected_block_configs[id][1] + else: + assert block_shape == expected_block_configs[id][0] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py index 79a139594b3e..8ff5ef09fdc3 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py @@ -22,71 +22,9 @@ import tvm.contrib.ethosu.cascader as cs from tvm.relay.backend.contrib.ethosu.te.convolution import match_ethosu_conv2d, conv2d_compute -import numpy as np +from .infra import make_matrices - -def _make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): - kernel_h, kernel_w = kernel - stride_h, stride_w = stride - dilation_h, dilation_w = dilation - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - nhwc_to_nhcwb16 = [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 1 / 16, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 16], - [0, 0, 0, 0, 1], - ] - nhcwb16_to_nhwc = [ - [1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 16, 0, 0, 0], - [0, 0, 0, 0, 0, 1], - ] - ifm_matrix = [ - [1, 0, 0, 0, 0], - [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], - [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - weight_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, kernel_h], - [0, 0, 0, 0, kernel_w], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - scale_bias_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 10], - [0, 0, 0, 0, 1], - ] - if ofm_layout == "NHCWB16": - ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() - weight_matrix = np.matmul(weight_matrix, nhcwb16_to_nhwc).tolist() - scale_bias_matrix = np.matmul(scale_bias_matrix, nhcwb16_to_nhwc).tolist() - if ifm_layout == "NHCWB16": - ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() - - ifm_offset = ( - [0, -padding[0], -padding[1], 0] - if ifm_layout == "NHWC" - else [0, -padding[0], 0, -padding[1], 0] - ) - weight_offset = [0, 0, 0, 0] - scale_bias_offset = [0, 0] - return ( - ifm_matrix, - ifm_offset, - weight_matrix, - weight_offset, - scale_bias_matrix, - scale_bias_offset, - ) +import pytest @pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) @@ -137,7 +75,7 @@ def test_ethosu_conv2d_matcher( weight_offset, scale_bias_transform, scale_bias_offset, - ) = _make_matrices( + ) = make_matrices( kernel, stride, dilation, @@ -147,7 +85,8 @@ def test_ethosu_conv2d_matcher( ofm_layout, ) - part = match_ethosu_conv2d(out) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_conv2d(out, device_config) assert isinstance(part, cs.EthosuPart) assert len(part.propagators) == 3 diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py index a3639ba03077..1eebbe40c1b3 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py @@ -37,7 +37,8 @@ def test_ethosu_inline_matcher(): ] ifm_offset = [0, 0, 0] - part = match_ethosu_inline(out) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_inline(out, device_config) assert isinstance(part, cs.InlinePart) assert len(part.propagators) == 1 diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py index ef449a49976c..fca136cf4ab4 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py @@ -18,28 +18,40 @@ pytest.importorskip("ethosu.vela") -import tvm.contrib.ethosu.cascader as pl +import tvm.contrib.ethosu.cascader as cs +from tvm.contrib.ethosu.cascader.graph import BufferMode from tvm.contrib.ethosu.cascader.parts import EthosuPart def test_ethosu_part(): - te_subgraph = pl.TESubgraph([], None) + te_subgraph = cs.TESubgraph([], None) output_quantum = [1, 2, 2, 8] - quantum_cycles = 32 - propagator = pl.Propagator( + propagator = cs.Propagator( [[1, 0, 0, 0, 2], [0, 1, 0, 0, 2], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [0, 0, 0, 0], ) - stripe_config = pl.StripeConfig( + stripe_config = cs.StripeConfig( [1, 4, 4, 16], [1, 64, 72, 96], [1, 4, 4, 16], [1, 2, 3, 4], [1, 16, 13, 6], [0, 0, 0, 0] ) + subkernels = 3 - part = EthosuPart(te_subgraph, [propagator], output_quantum, quantum_cycles) + valid_block_configs = [cs.BlockConfig([1, 2, 4, 16], 15000, 7500)] + + part = EthosuPart( + te_subgraph, + [propagator], + output_quantum, + subkernels, + valid_block_configs, + 1, + ) + input_tensor = cs.Tensor(shape=[1, 66, 74, 16], dtype="int8") + part.set_input(0, input_tensor) assert part.get_stripe_align_hint() == output_quantum # Check that the performance model runs, don't verify output - part.get_performance_info(stripe_config, False) - part.get_performance_info(stripe_config, True) + part.get_performance_info(stripe_config, BufferMode.ROLLING) + part.get_performance_info(stripe_config, BufferMode.RECOMPUTE) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py new file mode 100644 index 000000000000..297fbaa89059 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +from functools import reduce +import numpy as np +import math + +import tvm.contrib.ethosu.cascader as cs +from tvm.contrib.ethosu.cascader.device_config import _Shape + +from .infra import make_matrices + + +@pytest.mark.parametrize( + "acc_config, expected", + [ + ("ethos-u55-256", (1, 0.125, 0.75, 0.375, 0.75)), + ("ethos-u55-128", (1, 0.25, 1.5, 0.75, 0.75)), + ("ethos-u55-64", (1, 0.5, 3, 1.5, 1.5)), + ("ethos-u55-32", (2, 1, 6, 3, 3)), + ], +) +def test_device_config_cycles(acc_config, expected): + device_config = cs.EthosuDeviceConfig(acc_config) + + conv_type = "ethosu_conv2d" + conv_str = None + conv_ifm_dtype = "int8" + conv_ofm_dtype = "int8" + conv_activation = "LUT" + conv_cycles = device_config._get_output_cycles( + conv_type, conv_str, conv_ifm_dtype, conv_ofm_dtype, conv_activation + ) + assert conv_cycles == expected[0] + + pool_type = "ethosu_pooling" + pool_str = "MAX" + pool_ifm_dtype = "int8" + pool_ofm_dtype = "int8" + pool_activation = "NONE" + pool_cycles = device_config._get_output_cycles( + pool_type, pool_str, pool_ifm_dtype, pool_ofm_dtype, pool_activation + ) + assert pool_cycles == expected[1] + + add_type = "ethosu_binary_elementwise" + add_str = "ADD" + add_ifm_dtype = "int8" + add_ofm_dtype = "int8" + add_activation = "NONE" + add_cycles = device_config._get_output_cycles( + add_type, add_str, add_ifm_dtype, add_ofm_dtype, add_activation + ) + assert add_cycles == expected[2] + + mul_type = "ethosu_binary_elementwise" + mul_str = "MUL" + mul_ifm_dtype = "int8" + mul_ofm_dtype = "int8" + mul_activation = "NONE" + mul_cycles = device_config._get_output_cycles( + mul_type, mul_str, mul_ifm_dtype, mul_ofm_dtype, mul_activation + ) + assert mul_cycles == expected[3] + + mul_32_type = "ethosu_binary_elementwise" + mul_32_str = "MUL" + mul_32_ifm_dtype = "int8" + mul_32_ofm_dtype = "int32" + mul_32_activation = "NONE" + mul_32_cycles = device_config._get_output_cycles( + mul_32_type, mul_32_str, mul_32_ifm_dtype, mul_32_ofm_dtype, mul_32_activation + ) + assert mul_32_cycles == expected[4] + + +@pytest.mark.parametrize( + "accelerator, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape, block_shape, input_block_shape, expected", + [ + ( + "ethos-u55-128", + "ethosu_conv2d", + "NONE", + (3, 3), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 16, 16, 96), + (1, 16, 16, 96), + (1, 8, 8, 16), + (1, 10, 10, 32), + 167733, + ), + ( + "ethos-u55-128", + "ethosu_conv2d", + "NONE", + (10, 4), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 58, 13, 1), + (1, 25, 10, 276), + (1, 6, 10, 32), + (1, 18, 14, 8), + 174105, + ), + ], +) +def test_conv_performance( + accelerator, + op_type, + activation, + kernel, + stride, + dilation, + padding, + in_shape, + out_shape, + block_shape, + input_block_shape, + expected, +): + ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + kernel, + stride, + dilation, + padding, + in_shape[3], + "NHWC", + "NHWC", + ) + ifm_channels = in_shape[3] + + propagator = cs.Propagator(ifm_matrix, ifm_offset) + weight_propagator = cs.Propagator(weight_matrix, weight_offset) + + subkernels = ((kernel[0] + 7) // 8) * ((kernel[1] + 7) // 8) + + device_config = cs.EthosuDeviceConfig(accelerator) + + output_cycles = device_config._get_output_cycles(op_type, "", "int8", "int8", activation) + output_cycles *= reduce(lambda a, b: a * b, block_shape, 1) + is_partkernel = device_config.is_partkernel( + op_type, ifm_channels, "int8", kernel[0] * kernel[1] + ) + compute_cycles = device_config._estimate_compute_cycles_per_block( + op_type, + _Shape(block_shape), + _Shape(input_block_shape), + kernel[0], + kernel[1], + ifm_channels, + "int8", + is_partkernel, + ) + block_configs = [cs.BlockConfig(block_shape, compute_cycles, int(output_cycles))] + + output_quantum = [1, 1, 2, 8] + te_subgraph = cs.TESubgraph([], None) + part = cs.EthosuPart( + te_subgraph, + [propagator, weight_propagator], + output_quantum, + subkernels, + block_configs, + 1, + ) + + stripes = [1] * len(output_quantum) + offset = [0] * len(output_quantum) + order = [1, 2, 3, 4] + + stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) + + compute_cycles = part.get_performance_info(stripe_config, cs.BufferMode.ROLLING).compute_cycles + tolerance = expected * 0.05 + + assert expected - tolerance <= compute_cycles <= expected + tolerance + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py index 3bab83f24143..da31ad346b4f 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_graph.py +++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py @@ -54,7 +54,7 @@ def test_inline_part(): assert len(part.propagators) == 1 assert part.in_line == True assert part.get_stripe_align_hint() == [1, 1] - performance_info = part.get_performance_info(output_stripe_config, is_rolling=False) + performance_info = part.get_performance_info(output_stripe_config, cs.BufferMode.RECOMPUTE) assert performance_info.compute_cycles == 0 assert performance_info.read_bytes == [0] assert performance_info.write_bytes == 0 @@ -127,7 +127,8 @@ def test_small_graph(): def test_create_cascader_graph(TwoConv2DWithSliceTE): _, te_graph, const_dict = TwoConv2DWithSliceTE - graph = cs.create_cascader_graph(te_graph, const_dict) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + graph = cs.create_cascader_graph(te_graph, const_dict, device_config) output_tensor = graph.output_tensors[0] assert output_tensor.shape == [1, 6, 1, 6, 16]