diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 274f148e9134f..8f2dddbf88a6e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -426,6 +426,60 @@ def __init__( self.params_class = params_class self.pattern = pattern + @staticmethod + def reshape_input( + inputs: List["TensorParams"], + ) -> List[tvm.relay.Expr]: + """Reshape the inputs so that the following binary elementwise + operator receives 4-dimensional inputs. + + Parameters + ---------- + inputs: List[TensorParams] + The inputs to reshape. + + Returns + ------- + reshaped_inputs: List[tvm.relay.Expr] + The new reshaped inputs. + """ + reshaped_inputs = [] + for i in inputs: + in_shape = i.shape + if len(in_shape) < 4: + pad_size = 4 - len(in_shape) + new_shape = ([1] * pad_size) + in_shape + new_call = relay.reshape(i.tensor, new_shape) + reshaped_inputs.append(new_call) + else: + reshaped_inputs.append(i.tensor) + return reshaped_inputs + + @staticmethod + def reshape_output(output: tvm.relay.Expr, ifm_input_shape: List[int]) -> tvm.relay.Expr: + """Reshape the output back to the original dimensionality. + Since the NPU must have the broadcastable tensor as the + second operand, the original shape of the first ifm must + be the output shape. + + Parameters + ---------- + output: tvm.relay.Expr + The output to reshape. + + ifm_input_shape: List[int] + The shape of the non-reshaped ifm tensor. + + Returns + ------- + reshaped_output: tvm.relay.Expr + The reshaped output expression. + """ + if len(ifm_input_shape) == 4: + return output + reshaped_output = relay.reshape(output, ifm_input_shape) + return reshaped_output + def callback( self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map ) -> tvm.relay.Expr: @@ -451,9 +505,12 @@ def callback( # We don't yet support activation functions that need to get legalized to LUTs. 
lut = relay.const([], dtype="int8") - return ethosu_ops.ethosu_binary_elementwise( - ifm=params.ifm.tensor, - ifm2=params.ifm2.tensor, + inputs = [params.ifm, params.ifm2] + inputs = self.reshape_input(inputs) + + ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise( + ifm=inputs[0], + ifm2=inputs[1], lut=lut, operator_type=params.operator_type, ifm_scale=float(params.ifm.q_params.scale_f32), @@ -462,8 +519,8 @@ def callback( ifm2_zero_point=int(params.ifm2.q_params.zero_point), ofm_scale=float(params.ofm.q_params.scale_f32), ofm_zero_point=int(params.ofm.q_params.zero_point), - ifm_channels=params.ifm.shape[3], - ifm2_channels=params.ifm2.shape[3], + ifm_channels=params.ifm.shape[-1], + ifm2_channels=params.ifm2.shape[-1], reversed_operands=params.reversed_operands, ofm_dtype=params.ofm.dtype, activation=activation, @@ -473,6 +530,8 @@ def callback( ifm2_layout=str(params.ifm2.layout), ofm_layout=str(params.ofm.layout), ) + output = self.reshape_output(ethosu_binary_elementwise, params.ifm.shape) + return output class AddRewriter(BinaryElementwiseRewriter): diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index f37fcf6f97f45..73de3329c45f8 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -516,11 +516,12 @@ def __init__(self, func_body: Call, operator_type: str, has_quantization_paramet self.activation = clip self.operator_type = operator_type - def can_broadcast(x, y): - for i in range(1, 4): - if x.shape[i] == y.shape[i] or y.shape[i] == 1: - continue + def can_broadcast(ifm, ifm2): + if len(ifm.shape) < len(ifm2.shape): return False + for m, n in zip(ifm.shape[::-1], ifm2.shape[::-1]): + if m != n and m == 1: + return False return True if can_broadcast(self.ifm, self.ifm2): @@ -539,9 +540,14 @@ def is_valid(self): """ if np.dtype(self.ofm) == np.int32 and self.activation is not None: return False - if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4: + # 
Due to identity operator requiring ofm != int32 for now + if np.dtype(self.ofm) == np.int32 and len(self.ofm.shape) < 4: return False - if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1: + if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4: + return False + if len(self.ifm.shape) == 4 and self.ifm.shape[0] != 1: + return False + if len(self.ifm2.shape) == 4 and self.ifm2.shape[0] != 1: return False if not self.valid_broadcast: return False diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 806feae60e8a9..92a1ad71deda1 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -349,6 +349,7 @@ def representative_dataset(): ([1, 2, 3, 4], [1, 2, 3, 4]), ([1, 2, 3, 4], [1, 1, 1, 1]), ([1, 1, 1, 1], [1, 2, 3, 4]), + ([1, 4, 4], [4, 1]), ], ) @pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) @@ -435,6 +436,84 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([4], [4]), + ([4], [1, 2, 3, 4]), + ([1, 4, 4], [4, 1]), + ], +) +def test_binary_add_with_non_4d_shapes( + accel_type, + ifm_shape, + ifm2_shape, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, lhs, rhs): + return tf.math.add(lhs, rhs) + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = 
[tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape}, + dtype_dict={"ifm": dtype, "ifm2": dtype}, + ) + mod = partition_for_ethosu(mod, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) def test_binary_add_from_constant_scalar(accel_type): dtype = "uint8" diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 12bdddc978e30..dbe11cd2d7ad7 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -565,6 +565,10 @@ def verify(ext_func): ([1, 2, 3, 4], [1, 2, 3, 4], False), ([1, 2, 3, 4], [1, 1, 3, 1], False), ([1, 1, 3, 1], [1, 2, 3, 4], True), + ([1, 4, 4], [4, 1], False), + ([4], [4], False), + ([4], [1, 2, 3, 4], True), + ([1, 4, 4], [4, 1], False), ], ) 
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) @@ -621,16 +625,27 @@ def verify(ext_func): shapes = [ifm_shape, ifm2_shape] ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) op = ext_func.body - assert list(op.args[0].checked_type.shape) == shapes[ifm_index] - assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + + has_reshaped_output = False + shapes_padded = [[1] * (4 - len(s)) + s for s in shapes] + out_padded = [1] * (4 - len(out_shape)) + out_shape + if op.op.name != "contrib.ethosu.binary_elementwise": + has_reshaped_output = True + op = op.args[0] + + assert list(op.args[0].checked_type.shape) == shapes_padded[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes_padded[ifm2_index] assert op.args[0].checked_type.dtype == dtype - assert list(op.checked_type.shape) == out_shape + assert list(op.checked_type.shape) == out_padded assert op.checked_type.dtype == dtype assert op.attrs.operator_type == operator_type assert op.attrs.reversed_operands == reversed_operands if activation_function == "RELU": assert str(op.attrs.activation) == "CLIP" + if has_reshaped_output: + assert list(ext_func.body.checked_type.shape) == out_shape + if operator_type == "ADD": rewriter = legalize.AddRewriter() pattern_table = [