From bcdc345fbf9e0ad1482c2f8479ed8113015dbdb6 Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Thu, 20 Jan 2022 10:41:09 +0000 Subject: [PATCH] [microNPU][2d] Add more Part matchers to cascader (#9785) * [microNPU][2d] Add more Part matchers for the cascader Adds Part matchers for ethosu_depthwise_conv2d, ethosu_pooling and ethosu_binary_elementwise. Also adds additional testing for the CascaderGraph creation. Co-authored-by: Jacob Bohlin * Extended testing for block config * Add test guards Co-authored-by: Matthew Barrett --- .../contrib/ethosu/cascader/device_config.py | 156 ++++++++++++- .../contrib/ethosu/te/binary_elementwise.py | 153 ++++++++++++- .../backend/contrib/ethosu/te/convolution.py | 4 +- .../backend/contrib/ethosu/te/depthwise.py | 162 +++++++++++++- .../backend/contrib/ethosu/te/pooling.py | 139 +++++++++++- .../contrib/ethosu/te/unary_elementwise.py | 118 +++++++++- src/contrib/ethosu/cascader/parts/ethosu.cc | 8 +- .../contrib/test_ethosu/cascader/conftest.py | 65 +++++- .../contrib/test_ethosu/cascader/infra.py | 63 ++++-- .../test_ethosu_binary_elementwise_matcher.py | 209 ++++++++++++++++++ .../cascader/test_ethosu_block_config.py | 97 +++++++- .../cascader/test_ethosu_conv2d_matcher.py | 7 +- .../test_ethosu_depthwise2d_matcher.py | 102 +++++++++ .../cascader/test_ethosu_part_performance.py | 37 +++- .../cascader/test_ethosu_pooling_matcher.py | 81 +++++++ .../test_ethosu_unary_elementwise_matcher.py | 158 +++++++++++++ .../test_ethosu/cascader/test_graph.py | 24 ++ 17 files changed, 1521 insertions(+), 62 deletions(-) create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py create mode 100644 tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py index 5ad7fde1ed52..68a218da2616 100644 --- a/python/tvm/contrib/ethosu/cascader/device_config.py +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name """Device config class to hold information about the target hardware""" -from typing import Tuple, List, Dict +from typing import Tuple, List, Dict, Optional from functools import reduce import math @@ -332,6 +332,7 @@ def _get_input_block( def get_kernel_steps( self, + op_type: str, dilated_kernel_h: int, dilated_kernel_w: int, ifm_dtype: str, @@ -341,6 +342,9 @@ def get_kernel_steps( Parameters ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" dilated_kernel_h: int Height of dilated kernel dilated_kernel_w: int @@ -355,18 +359,23 @@ def get_kernel_steps( List[int] List where each entry contains the amount of elements in one of the subkernels """ + if op_type == "ethosu_binary_elementwise": + return [1] + subkernels = self._get_subkernels(dilated_kernel_h, dilated_kernel_w) # Determine the number of kernel steps per subkernel kernel_steps = [] for y, x in subkernels: subkernel_elements = x * y - if is_partkernel: - # Part-kernel-first traversal + if op_type == "ethosu_conv2d" and is_partkernel: + # Part-kernel-first traversal conv2d divisor = 4 if ifm_dtype == "int8" else 2 kernel_steps.append(int(_round_up_div(subkernel_elements, divisor))) + elif op_type == "ethosu_depthwise_conv2d": + kernel_steps.append(int(_round_up_div(subkernel_elements, 4))) else: - # Depth-first traversal + # Depth-first traversal conv2d or pooling kernel_steps.append(int(subkernel_elements)) return kernel_steps @@ -430,11 +439,133 @@ def is_partkernel( return part_kernel_first_utilization > depth_first_utilization or ifm_channels <= 8 + def get_elementwise_block_config( + self, + ifm_propagator: Propagator, + ifm2_propagator: Optional[Propagator], + op_attrs: Dict, + ofm_shape: List[int], + output_layout: str, + input_layout: str, + input2_layout: Optional[str], + ifm_dtype: str, + ofm_dtype: str, + ) -> List[BlockConfig]: + """Get a suitable block config for an elementwise operator + + Parameters + ---------- + ifm_propagator: Propagator, + The propagator containing the data dependencies between input and output + ifm2_propagator: Propagator, + The propagator containing the data dependencies between input2 and output + op_attrs: Dict, + Dictionary containing operator attributes + ofm_shape: List[int], + Shape of the output tensor + output_layout: str, + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + input_layout: str, + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + input2_layout: str, + The layout of the Input2 Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm_dtype: str, + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str, + Datatype of the Output Feature Map tensor (OFM) + + Returns + ---------- + List[BlockConfig] + List containing a single suitable block config + """ + block_config = [] + output_shape = [int(a) for a in ofm_shape] + + op_type = op_attrs.get("op") + op_str = op_attrs.get("op_str") + activation = op_attrs.get("activation", "NONE") + + input_bytewidth = 1 if ifm_dtype == "int8" else 2 if ifm_dtype == "int16" else 4 + banks_available = self._total_banks - self._reserved_banks + if activation == "LUT" and not self._lut_reserved: + banks_available -= 2 + + # Split the block in half until it fits into SHRAM + if output_layout == "NHCWB16": + split_order = (a for a in [1, 3, 2]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2] * output_shape[4], self._max_block_shape.depth), + min(output_shape[3], self._max_block_shape.width), + 16, + ] + else: + split_order = (a for a in [1, 2, 3]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2], self._max_block_shape.width), + min(output_shape[3], self._max_block_shape.depth), + ] + split_axis = next(split_order) + while True: + # Create stripe config for output block + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4] + output_stripe_config = StripeConfig( + output_block, output_block, output_block, order, stripes, offset + ) + + # Propagate the output to obtain the two input blocks + input_block = _Shape(ifm_propagator.propagate(output_stripe_config).shape, input_layout) + if ifm2_propagator: + input2_block = _Shape( + ifm2_propagator.propagate(output_stripe_config).shape, input2_layout + ) + else: + # Unary elementwise + input2_block = _Shape([0, 0, 0, 0]) + + input_block.round_up(self._input_micro_block) + input2_block.round_up(self._input_micro_block) + + # Banks required for input block + input_bytes = input_block.area() * self._align(input_block.depth * input_bytewidth, 8) + input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2 + input_banks = _round_up(input_banks, self._input_granularity) + + # Banks required for input2 block + input2_bytes = input2_block.area() * self._align( + input2_block.depth * input_bytewidth, 8 + ) + input2_banks = _round_up_div(input2_bytes, self._bank_size_bytes) * 2 + input2_banks = _round_up(input2_banks, self._input_granularity) + + # Check whether or not both IFMs fit into SHRAM + if (input_banks + input2_banks) <= banks_available: + output_cycles = self._get_output_cycles( + op_type, op_str, ifm_dtype, ofm_dtype, activation + ) + output_cycles *= reduce(lambda a, b: a * b, output_block, 1) + output_cycles = int(math.ceil(output_cycles)) + block_config.append(BlockConfig(output_block, 0, output_cycles)) + break + + if output_block[split_axis] == 1: + split_axis = next(split_order) + + output_block[split_axis] = _round_up_div(output_block[split_axis], 2) + + return block_config + def get_valid_block_configs( self, ifm_propagator: Propagator, op_attrs: Dict, - output_shape: List[int], + ofm_shape: List[int], ofm_channels: int, ifm_channels: int, output_layout: str, @@ -452,7 +583,7 @@ def get_valid_block_configs( The propagator containing the data dependencies between input and output op_attrs: Dict, Dictionary containing operator attributes - output_shape: List[int], + ofm_shape: List[int], Shape of the output tensor ofm_channels: int, Number of output channels @@ -487,9 +618,9 @@ def get_valid_block_configs( subkernel_transform = ifm_propagator.transform if output_layout == "NHCWB16": - output_shape = _Shape([1, output_shape[1], output_shape[3], ofm_channels]) + output_shape = _Shape([1, ofm_shape[1], ofm_shape[3], ofm_channels]) else: - output_shape = _Shape(output_shape) + output_shape = _Shape(ofm_shape) if input_layout == "NHCWB16": subkernel_transform[1][-1] = min( @@ -571,6 +702,7 @@ def get_valid_block_configs( input_block_shape = _Shape(input_block.shape, input_layout) input_block_shape.round_up(self._input_micro_block) + output_block_shape = _Shape(output_block, output_layout) if op_type == "ethosu_conv2d": @@ -592,12 +724,11 @@ def get_valid_block_configs( acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth]) if (input_banks + acc_banks) <= banks_available: - output_cycles = self._get_output_cycles( op_type, op_str, ifm_dtype, ofm_dtype, activation ) output_cycles *= reduce(lambda a, b: a * b, output_block, 1) - output_cycles = int(_round_up(output_cycles, 1)) + output_cycles = int(math.ceil(output_cycles)) compute_cycles = self._estimate_compute_cycles_per_block( op_type, output_block_shape, @@ -634,7 +765,7 @@ def _estimate_compute_cycles_per_block( num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth) num_quantum_xy = num_quantum_x * num_quantum_y - kernel_steps = self.get_kernel_steps(kernel_h, kernel_w, ifm_dtype, is_partkernel) + kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, is_partkernel) wd_cycles = self._get_weight_decoder_cycles(op_type) delay_cycles = self._get_delay_cycles(op_type, ifm_dtype) @@ -642,8 +773,9 @@ def _estimate_compute_cycles_per_block( compute_cycles = 0 for subkernel_steps in kernel_steps: + subkernel_cycles = 1 if op_type == "ethosu_pooling" else subkernel_steps compute_cycles += ( - max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_steps * num_quantum_z + max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_cycles * num_quantum_z ) if num_quantum_xy == 1: diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py index c1d39556d11d..8446b0c2e4ad 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -17,7 +17,10 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for binary_elementwise""" import operator +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -123,6 +126,12 @@ def binary_elementwise_compute( te.Tensor The Output Feature Map tensor. """ + assert ifm.shape[0] == 1 + assert ifm2.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ifm2_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute( ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0) @@ -187,5 +196,147 @@ def binary_elementwise_compute( attrs=binary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - int(broadcast[1])), 0, 0, int(broadcast[1])], + [0, 0, (1 - int(broadcast[2])), 0, int(broadcast[2])], + [0, 0, 0, (1 - int(broadcast[3])), int(broadcast[3])], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + ifm2_propagator = Propagator( + ifm2_matrix, + [0, 0, 0, 0] if ifm2_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "ifm2_propagator": ifm2_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels) + return dma_ofm_compute( + binary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ifm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_binary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Binary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + binary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if binary_elementwise.op.name != "ethosu_binary_elementwise": + return None + pad = binary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + pad2 = binary_elementwise.op.input_tensors[1] + if pad2.op.name != "ethosu_pad": + return None + convert_to_nhwc2 = pad2.op.input_tensors[0] + if convert_to_nhwc2.op.name != "ethosu_convert_to_nhwc": + return None + read2 = convert_to_nhwc2.op.input_tensors[0] + if read2.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + read2.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["ifm2_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + input2_layout = convert_to_nhwc2.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + propagators[1], + binary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + input2_layout, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index c61082beb737..ea2290ef1e5f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -297,7 +297,9 @@ def match_ethosu_conv2d(output_tensor, device_config): conv2d.op.name, ifm_channels, ifm_dtype, kernel_elements ) subkernels = len( - device_config.get_kernel_steps(kernel_height, kernel_width, ifm_dtype, is_part_kernel) + device_config.get_kernel_steps( + conv2d.op.name, kernel_height, kernel_width, ifm_dtype, is_part_kernel + ) ) output_layout = convert_to_nhcwb16.op.attrs["layout"] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index f54f2f3654e2..ff09662cc14a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -17,8 +17,11 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for depthwise convolutions""" from typing import Tuple, Union, List +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -110,9 +113,10 @@ def depthwise_conv2d_compute( assert ifm_layout in {"NHWC", "NHCWB16"} assert ofm_layout in {"NHWC", "NHCWB16"} - stride_h, stride_w = strides - dilation_h, dilation_w = dilation - channels, kernel_h, kernel_w, _ = weight.shape + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + dilation_h, dilation_w = [int(v) for v in dilation] + channels, kernel_h, kernel_w, _ = [int(v) for v in weight.shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding) @@ -165,5 +169,155 @@ def depthwise_conv2d_compute( attrs=depthwise_conv2d_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weights_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weights_matrix = np.matmul(weights_matrix, nhcwb16_to_nhwc).tolist() + bias_matrix = np.matmul(bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + weights_propagator = Propagator( + weights_matrix, + [0, 0, 0, 0], + ) + bias_propagator = Propagator( + bias_matrix, + [0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "weights_propagator": weights_propagator, + "bias_propagator": bias_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels) + return dma_ofm_compute( + depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_depthwise_conv2d(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Depthwise Conv2D. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration. + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + depthwise2d = convert_to_nhcwb16.op.input_tensors[0] + if depthwise2d.op.name != "ethosu_depthwise_conv2d": + return None + pad = depthwise2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + depthwise2d.op.input_tensors[1], + depthwise2d.op.input_tensors[2], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["weights_propagator"], + write.op.attrs["bias_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels, kernel_height, kernel_width = (int(axis) for axis in input_tensors[1].shape[0:3]) + + subkernels = len( + device_config.get_kernel_steps(depthwise2d.op.name, kernel_height, kernel_width, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + depthwise2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + kernel_height, + kernel_width, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + 1, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index e98a72db7f02..aaf79e8a8c8d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -18,7 +18,10 @@ """Tensor Expressions for poolings""" from typing import Tuple +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -99,8 +102,13 @@ def pooling_compute( te.Tensor The OFM tensor. """ - stride_h, stride_w = strides - pool_shape_h, pool_shape_w = pool_shape + assert ifm.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + pool_shape_h, pool_shape_w = [int(v) for v in pool_shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) @@ -114,6 +122,8 @@ def pooling_compute( pooling_attrs = { "op": "ethosu_pooling", "pooling_type": pooling_type, + "pool_shape_h": pool_shape_h, + "pool_shape_w": pool_shape_w, "stride_h": stride_h, "stride_w": stride_w, "activation": activation, @@ -144,5 +154,128 @@ def pooling_compute( attrs=pooling_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (pool_shape_h - stride_h)], + [0, 0, stride_w, 0, (pool_shape_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_pooling(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Pooling. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + pool2d = convert_to_nhcwb16.op.input_tensors[0] + if pool2d.op.name != "ethosu_pooling": + return None + pad = pool2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels = ifm_channels + pool_shape_h = int(pool2d.op.attrs["pool_shape_h"]) + pool_shape_w = int(pool2d.op.attrs["pool_shape_w"]) + + subkernels = len( + device_config.get_kernel_steps(pool2d.op.name, pool_shape_h, pool_shape_w, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + pool2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + pool_shape_h, + pool_shape_w, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py index 0aefc1c35d4c..68d1c603ad98 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for unary_elementwise for the NPU""" +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher from .dma import dma_ofm_compute, dma_ifm_compute @@ -127,5 +129,119 @@ def clz_imp(inp): attrs=unary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = {"ifm_propagator": ifm_propagator} + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(unary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + unary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ofm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_unary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Unary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + unary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if unary_elementwise.op.name != "ethosu_unary_elementwise": + return None + pad = unary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + None, + unary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + None, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc index c5f236761ba0..cdbbda18c142 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.cc +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -56,7 +56,6 @@ const std::vector EthosuPartNode::GetBytesRead(const std::vector& int i = 0; for (const auto& input_block_config : input_block_configs) { std::map, int> input_blocks = CountStripes(input_block_config, false); - for (const auto& block : input_blocks) { bytes_per_input[i] += mul_reduce(block.first) * block.second; } @@ -82,8 +81,8 @@ const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stri bytes_per_input[0] *= subkernels_; // Calculate bytes read per output element - float relative_cost = - (bytes_per_input[0] + bytes_per_input[1]) / mul_reduce(output_stripe_shape); + float relative_cost = static_cast(bytes_per_input[0] + bytes_per_input[1]) / + mul_reduce(output_stripe_shape); // Single buffering hardware optimization if (mul_reduce(output_stripe_shape) <= 2 * mul_reduce(output_block)) { @@ -116,7 +115,8 @@ const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& out output_stripe_config->GetStripes()[i]) / block_shape[i]; } else { - num_blocks *= static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i]; + num_blocks *= + std::max(static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i], 1.0f); } } float num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1.0f; diff --git a/tests/python/contrib/test_ethosu/cascader/conftest.py b/tests/python/contrib/test_ethosu/cascader/conftest.py index 58ffb51a5967..eacf57c251a8 100644 --- a/tests/python/contrib/test_ethosu/cascader/conftest.py +++ b/tests/python/contrib/test_ethosu/cascader/conftest.py @@ -29,7 +29,11 @@ from tvm.relay.testing import run_opt_pass from .infra import create_te_graph - from ..infra import make_ethosu_conv2d + from ..infra import ( + make_ethosu_conv2d, + make_ethosu_depthwise_conv2d, + make_ethosu_binary_elementwise, + ) def make_TwoConv2DWithSliceTE(): def _get_func(): @@ -71,3 +75,62 @@ def _get_func(): @pytest.fixture def TwoConv2DWithSliceTE(): return make_TwoConv2DWithSliceTE() + + def make_MobileNetv2DiamondTE(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 56, 56, 96), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm=ifm, + ifm_channels=96, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + conv2 = make_ethosu_conv2d( + ifm=conv1, + ifm_channels=24, + ofm_channels=144, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + depth1 = make_ethosu_depthwise_conv2d( + ifm=conv2, + channels=144, + kernel_shape=(3, 3), + padding=(1, 1, 1, 1), + strides=(1, 1), + dilation=(1, 1), + ) + conv3 = make_ethosu_conv2d( + ifm=depth1, + ifm_channels=144, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + add1 = make_ethosu_binary_elementwise( + ifm=conv1, + ifm2=conv3, + ifm_channels=24, + ifm2_channels=24, + operator_type="ADD", + ofm_dtype="int8", + ) + func = relay.Function(relay.analysis.free_vars(add1), add1) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + te_graph, const_dict = create_te_graph(func) + sch = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + return sch, te_graph, const_dict + + @pytest.fixture + def MobileNetv2DiamondTE(): + return make_MobileNetv2DiamondTE() diff --git a/tests/python/contrib/test_ethosu/cascader/infra.py b/tests/python/contrib/test_ethosu/cascader/infra.py index c2b6073fb62e..5f41dce30147 100644 --- a/tests/python/contrib/test_ethosu/cascader/infra.py +++ b/tests/python/contrib/test_ethosu/cascader/infra.py @@ -29,7 +29,9 @@ def create_te_graph(func): return te_graph, consts -def make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): +def make_matrices( + op_type, kernel, stride, padding, ifm_layout, ofm_layout, dilation=(1, 1), ifm_channels=1 +): kernel_h, kernel_w = kernel stride_h, stride_w = stride dilation_h, dilation_w = dilation @@ -50,20 +52,51 @@ def make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, o [0, 0, 16, 0, 1, -16], [0, 0, 0, 0, 0, 1], ] - ifm_matrix = [ - [1, 0, 0, 0, 0], - [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], - [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - weight_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, kernel_h], - [0, 0, 0, 0, kernel_w], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] + if op_type == "ethosu_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_depthwise_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_pooling": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] scale_bias_matrix = [ [0, 0, 0, 1, 0], [0, 0, 0, 0, 10], diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py new file mode 100644 index 000000000000..bb1be7b8e251 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.binary_elementwise import ( + match_ethosu_binary_elementwise, + binary_elementwise_compute, +) + + +def _make_matrices(broadcast, ifm_layout, ifm2_layout, ofm_layout): + broadcast_h, broadcast_w, broadcast_c = broadcast + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - broadcast_h), 0, 0, broadcast_h], + [0, 0, (1 - broadcast_w), 0, broadcast_w], + [0, 0, 0, (1 - broadcast_c), broadcast_c], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + + return (ifm_matrix, ifm2_matrix) + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 73, 51, 20], + [1, 124, 172, 5], + ], +) +@pytest.mark.parametrize("ifm2_broadcast", [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ifm2_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["MUL", "ADD", "MIN"]) +def test_ethosu_binary_elementwise_matcher( + ofm_shape, ifm2_broadcast, ifm_layout, ifm2_layout, ofm_layout, op_type +): + ifm_shape = ofm_shape.copy() + ifm2_shape = [1] + [1 if (b == 1) else a for a, b in zip(ofm_shape[1:], ifm2_broadcast)] + ifm_channels = ifm_shape[3] + ifm2_channels = ifm2_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + broadcast = [1 if a == 1 else 0 for a in ifm2_shape[1:]] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ifm2_layout == "NHCWB16": + ifm2_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm2_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + ifm2 = te.placeholder(ifm2_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = binary_elementwise_compute( + ifm=ifm, + ifm2=ifm2, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ifm2_scale=1, + ifm2_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ifm_channels=ifm_channels, + ifm2_channels=ifm2_channels, + reversed_operands=False, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ifm2_layout=ifm2_layout, + ofm_layout=ofm_layout, + ofm_dtype="int8", + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + ifm2_propagator = out.op.attrs["ifm2_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + (ifm_transform, ifm2_transform) = _make_matrices( + broadcast, + ifm_layout, + ifm2_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_binary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 2 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[1].transform == ifm2_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + propagated_ifm2 = ifm2_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm(2)_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + if ifm2_layout != ofm_layout: + assert ifm2_shape[:-1] == propagated_ifm2[:-1] + assert ((ifm2_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm2[-1] + else: + assert ifm2_shape == propagated_ifm2 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py index 3418bb58351e..3f3935fff1f9 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py @@ -27,8 +27,9 @@ @pytest.mark.parametrize( - "id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", + "test_id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", [ + # Conv2D ( 0, "ethosu_conv2d", @@ -95,6 +96,52 @@ (1, 62, 94, 32), (1, 58, 90, 16), ), + # Depthwise Conv2D + ( + 6, + "ethosu_depthwise_conv2d", + "NONE", + (3, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 77, 23, 18), + (1, 75, 19, 18), + ), + ( + 7, + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + ), + # Pooling + ( + 8, + "ethosu_pooling", + "NONE", + (13, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 13, 5, 276), + (1, 1, 1, 276), + ), + ( + 9, + "ethosu_pooling", + "NONE", + (7, 3), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 317, 14, 21), + (1, 156, 12, 21), + ), ], ) @pytest.mark.parametrize( @@ -112,51 +159,79 @@ ( "ethos-u55-32", [ + # Conv2D ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 6, 5, 16), (1, 6, 1, 5, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 8, 4, 16), (1, 8, 1, 4, 16)), - ((1, 10, 6, 4), (1, 16, 1, 4, 4)), - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 16, 1, 4, 4)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), ], ), ( "ethos-u55-64", [ + # Conv2D ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 6, 5, 16), (1, 6, 1, 5, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 10, 6, 8), (1, 16, 1, 4, 8)), - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), ], ), ( "ethos-u55-128", [ + # Conv2D ((1, 7, 6, 16), (1, 7, 1, 6, 16)), ((1, 5, 8, 16), (1, 5, 1, 8, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 16, 4, 16), (1, 16, 1, 4, 16)), ((1, 8, 12, 8), (1, 8, 1, 12, 8)), ((1, 10, 6, 16), (1, 10, 1, 6, 16)), + # Depthwise Conv2D + ((1, 7, 10, 16), (1, 7, 1, 10, 16)), + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), + # Pooling + ((1, 1, 2, 80), (1, 1, 5, 2, 16)), + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), ], ), ( "ethos-u55-256", [ + # Conv2D ((1, 14, 8, 16), (1, 14, 1, 8, 16)), ((1, 16, 8, 16), (1, 16, 1, 8, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), - ((1, 32, 4, 16), (1, 32, 1, 4, 16)), + ((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)), ((1, 20, 12, 8), (1, 20, 1, 12, 8)), - ((1, 20, 6, 16), (1, 20, 1, 6, 16)), + ((1, 12, 10, 16), (1, 12, 1, 10, 16)), + # Depthwise Conv2D + ((1, 8, 20, 16), (1, 8, 1, 20, 16)), + ((1, 14, 6, 16), (1, 14, 1, 6, 16)), + # Pooling + ((1, 2, 2, 48), (1, 2, 3, 2, 16)), + ((1, 10, 12, 16), (1, 10, 1, 12, 16)), ], ), ], ) def test_best_block_config( - id, + test_id, op_type, activation, kernel, @@ -185,7 +260,7 @@ def test_best_block_config( [0, 0, 0, 0, 0, 1], ] ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( - kernel, stride, dilation, padding, in_shape[3], layouts[0], layouts[1] + op_type, kernel, stride, padding, layouts[0], layouts[1], dilation, in_shape[3] ) ofm_channels = out_shape[3] @@ -252,10 +327,8 @@ def test_best_block_config( block = part.get_block_config(stripe_config) block_shape = tuple(int(a) for a in block.output_shape) - if layouts[1] == "NHCWB16": - assert block_shape == expected_block_configs[id][1] - else: - assert block_shape == expected_block_configs[id][0] + + assert block_shape in expected_block_configs[test_id] if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py index 8ff5ef09fdc3..5bd2be49f620 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py @@ -24,8 +24,6 @@ from .infra import make_matrices -import pytest - @pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) @pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) @@ -76,13 +74,14 @@ def test_ethosu_conv2d_matcher( scale_bias_transform, scale_bias_offset, ) = make_matrices( + "ethosu_conv2d", kernel, stride, - dilation, padding, - ifm_channels, ifm_layout, ofm_layout, + dilation, + ifm_channels, ) device_config = cs.EthosuDeviceConfig("ethos-u55-256") diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py new file mode 100644 index 000000000000..c2c45b6524f1 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.depthwise import ( + match_ethosu_depthwise_conv2d, + depthwise_conv2d_compute, +) +from .infra import make_matrices + + +@pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("dilation", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_depthwise2d_matcher(kernel, stride, dilation, padding, ifm_layout, ofm_layout): + ofm_channels = 57 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + kernel_h, kernel_w = kernel + ifm = te.placeholder(ifm_shape, dtype="int8") + weight = te.placeholder((ofm_channels, kernel_h, kernel_w, 1), dtype="int8") + scale_bias = te.placeholder((ofm_channels, 10), dtype="uint8") + lut = te.placeholder((), dtype="uint8") + out = depthwise_conv2d_compute( + ifm=ifm, + weight=weight, + scale_bias=scale_bias, + lut=lut, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + weight_zero_point=0, + strides=stride, + padding=padding, + dilation=dilation, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ofm_dtype=ifm.dtype, + ) + ( + ifm_transform, + ifm_offset, + weight_transform, + weight_offset, + scale_bias_transform, + scale_bias_offset, + ) = make_matrices( + "ethosu_depthwise_conv2d", + kernel, + stride, + padding, + ifm_layout, + ofm_layout, + dilation, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_depthwise_conv2d(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 3 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + assert part.propagators[1].transform == weight_transform + assert part.propagators[1].offset == weight_offset + assert part.propagators[2].transform == scale_bias_transform + assert part.propagators[2].offset == scale_bias_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py index 297fbaa89059..ba6346afa5d5 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py @@ -122,6 +122,34 @@ def test_device_config_cycles(acc_config, expected): (1, 18, 14, 8), 174105, ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + (1, 7, 6, 16), + (1, 15, 14, 16), + 17590, + ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (4, 9), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 28, 81, 42), + (1, 25, 73, 41), + (1, 4, 16, 16), + (1, 7, 24, 16), + 173414, + ), ], ) def test_conv_performance( @@ -138,16 +166,17 @@ def test_conv_performance( input_block_shape, expected, ): + ifm_channels = in_shape[3] ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + op_type, kernel, stride, - dilation, padding, - in_shape[3], "NHWC", "NHWC", + dilation, + ifm_channels, ) - ifm_channels = in_shape[3] propagator = cs.Propagator(ifm_matrix, ifm_offset) weight_propagator = cs.Propagator(weight_matrix, weight_offset) @@ -191,7 +220,7 @@ def test_conv_performance( stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) compute_cycles = part.get_performance_info(stripe_config, cs.BufferMode.ROLLING).compute_cycles - tolerance = expected * 0.05 + tolerance = expected * 0.1 assert expected - tolerance <= compute_cycles <= expected + tolerance diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py new file mode 100644 index 000000000000..6ce8ee9a2986 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.pooling import match_ethosu_pooling, pooling_compute +from .infra import make_matrices + + +@pytest.mark.parametrize("pool_shape", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_pooling_matcher(pool_shape, stride, padding, ifm_layout, ofm_layout): + ofm_channels = 21 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = pooling_compute( + ifm=ifm, + lut=lut, + pooling_type="MAX", + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + pool_shape=pool_shape, + ofm_channels=ofm_channels, + strides=stride, + padding=padding, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + (ifm_transform, ifm_offset, _, _, _, _) = make_matrices( + "ethosu_pooling", + pool_shape, + stride, + padding, + ifm_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_pooling(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py new file mode 100644 index 000000000000..0570524e0907 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.unary_elementwise import ( + match_ethosu_unary_elementwise, + unary_elementwise_compute, +) + + +def _make_matrices(ifm_layout, ofm_layout): + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + return ifm_matrix + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 53, 91, 7], + [1, 182, 12, 72], + ], +) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["ABS", "CLZ"]) +def test_ethosu_unary_elementwise_matcher(ofm_shape, ifm_layout, ofm_layout, op_type): + ifm_shape = ofm_shape.copy() + ofm_channels = ofm_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = unary_elementwise_compute( + ifm=ifm, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ofm_channels=ofm_channels, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + ifm_transform = _make_matrices(ifm_layout, ofm_layout) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_unary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py index da31ad346b4f..616800f69d7e 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_graph.py +++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py @@ -176,5 +176,29 @@ def test_create_cascader_graph(TwoConv2DWithSliceTE): assert conv1_part.input_tensors[2].is_constant +def test_create_diamond_graph(MobileNetv2DiamondTE): + _, te_graph, const_dict = MobileNetv2DiamondTE + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + graph = cs.create_cascader_graph(te_graph, const_dict, device_config) + + output_tensor = graph.output_tensors[0] + assert output_tensor.shape == [1, 56, 56, 24] + assert len(output_tensor.producers) == 1 + assert not output_tensor.is_constant + + add1_part = output_tensor.producers[0] + assert isinstance(add1_part, cs.EthosuPart) + assert len(add1_part.input_tensors) == 2 + assert graph.get_part_id(add1_part) == 0 + + assert add1_part.input_tensors[0].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + assert add1_part.input_tensors[1].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + if __name__ == "__main__": pytest.main([__file__])