From fdfd16c951c8d1f06406d7f71e824b84c3e29697 Mon Sep 17 00:00:00 2001
From: Aleksei-grovety <113356454+Aleksei-grovety@users.noreply.github.com>
Date: Fri, 13 Oct 2023 15:23:55 +0400
Subject: [PATCH] [microNPU][ETHOSU] MatMul legalization support (#15780)

The NPU requires the weights to be constant, so the matrix
multiplication operation is expressed using split, elementwise
multiplication, reduce-sum, and concatenation operations.

---
 .../relay/backend/contrib/ethosu/legalize.py  |  98 ++++++++++++++-
 python/tvm/relay/op/contrib/ethosu.py         |  43 +++++++
 .../contrib/test_ethosu/test_codegen.py       |  24 ++++
 .../contrib/test_ethosu/test_legalize.py      | 117 ++++++++++++++++++
 4 files changed, 281 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 175f97c87e02..242d3c2d0cc5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1401,6 +1401,101 @@ def callback(self, pre, post, node_map):
         return ethosu_fc
 
 
+class MatMulRewriter(DFPatternCallback):
+    """Legalize matrix multiplication to a sequence of NPU operators"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name})
+        )(wildcard(), wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.MatMulParams(post.op.body)
+        ifm = post.args[0]
+        ifm2 = post.args[1]
+        lut = relay.const([], dtype="int8")
+        activation_map = {"clip": "CLIP"}
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+
+        # Reshape ifm to NHWC
+        ifm = relay.reshape(ifm, (1, 1, *params.ifm.shape))
+        # Split the second matrix to get columns
+        columns = list(relay.op.split(ifm2, params.ofm.shape[-1], axis=0))
+
+        res_columns = []
+        for column in columns:
+            ifm2 = relay.reshape(column, (1, 1, 1, params.ifm.shape[-1]))
+            # Multiply the first matrix by a column
+            ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
+                ifm=ifm,
+                ifm2=ifm2,
+                lut=lut,
+                operator_type="MUL",
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ifm_scale=0.0,
+                ifm2_zero_point=int(params.weights.q_params.zero_point),
+                ifm2_scale=0.0,
+                ofm_scale=0.0,
+                ofm_zero_point=0,
+                ifm_channels=params.ifm.shape[-1],
+                ifm2_channels=params.ifm.shape[-1],
+                reversed_operands=False,
+                ofm_dtype="int32",
+            )
+
+            # Use reduce sum to get the result column
+            reduce_sum = ethosu_ops.ethosu_pooling(
+                ifm=ethosu_binary_elementwise,
+                lut=lut,
+                pooling_type="SUM",
+                ifm_zero_point=0,
+                ifm_scale=float(params.weights.q_params.scale_f32)
+                * float(params.ifm.q_params.scale_f32),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=0,
+                pool_shape=(1, 1),
+                ofm_channels=1,
+                ofm_dtype="int32",
+                activation=activation,
+                clip_min=clip_min,
+                clip_max=clip_max,
+                rounding_mode="NATURAL",
+            )
+
+            # Convert tensor dtype from int32 to int8
+            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
+            reduce_sum = ethosu_ops.ethosu_binary_elementwise(
+                ifm=reduce_sum,
+                ifm2=scalar_tensor,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=0.0,
+                ifm_zero_point=0,
+                ifm2_scale=0.0,
+                ifm2_zero_point=0,
+                ofm_scale=0.0,
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                ifm_channels=1,
+                ifm2_channels=1,
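+                # NOTE: multiplying by the constant ones tensor is an identity
+                # MUL; it is used only to cast the int32 reduce-sum result to
+                # int8 and to apply the OFM zero point.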
+                reversed_operands=False,
+                ofm_dtype="int8",
+            )
+
+            res_columns.append(reduce_sum)
+
+        # Concatenate result columns
+        concat = relay.op.concatenate(relay.Tuple(res_columns), axis=3)
+        return relay.reshape(concat, params.ofm.shape)
+
+
 class PadRewriter(DFPatternCallback):
     """Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
     operator"""
@@ -1546,12 +1641,13 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
         """
         rewriters = [
             PartitionedSplitRewriter(),
+            FullyConnectedRewriter(),
+            MatMulRewriter(),
             SplitRewriter(),
             ChannelPadRewriter(),
             Conv2DRewriter(),
             Conv2DTransposeRewriter(),
             DepthwiseConv2DRewriter(),
-            FullyConnectedRewriter(),
             MaxPoolingRewriter(),
             AvgPoolingRewriter(),
             PadRewriter(),
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 386ef9038e49..73cee9d0cd23 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1900,6 +1900,44 @@ def qnn_fc_pattern():
     return optional_clip
 
 
+class MatMulParams(FullyConnectedParams):
+    """
+    This class will parse a call to an ethos-u.matmul composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.matmul"
+
+    @requires_vela
+    def __init__(self, func_body):
+        FullyConnectedParams.__init__(self, func_body)
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether the matrix multiplication has attributes compatible with the hardware
+        """
+
+        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
+            return False
+        if len(self.ifm.shape) != 2:
+            return False
+        if len(self.ofm.shape) != 2:
+            return False
+        # The weights must be transposed, so the inner dimensions have to match
+        if self.ifm.shape[1] != self.weights.shape[1]:
+            return False
+        return True
+
+
+def matmul_pattern():
+    """Create the pattern for matmul: qnn.dense with non-constant weights,
+    followed by qnn.requantize and an optional clip."""
+    dense = is_op("qnn.dense")(
+        wildcard(), wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    req = is_op("qnn.requantize")(dense, is_constant(), is_constant(), is_constant(), is_constant())
+    optional_clip = req.optional(is_op("clip"))
+    return optional_clip
+
+
 class HardSwishParams:
     """
     This class will parse a call to a ethos-u.hard_swish composite function
@@ -2185,6 +2223,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_fc_pattern(),
             lambda pat: FullyConnectedParams(pat).is_valid(),
         ),
+        (
+            MatMulParams.composite_name,
+            matmul_pattern(),
+            lambda pat: MatMulParams(pat).is_valid(),
+        ),
         (
             MaxPool2DParams.composite_name,
             qnn_maxpool2d_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index e094bb74b2e1..fde2e284347e 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1564,6 +1564,30 @@ def fully_connected(x):
     )
 
 
+@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
+@pytest.mark.parametrize("ifm_shape", [(1, 16), (4, 8)])
+@pytest.mark.parametrize("ofm_channels", [8, 32])
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_matmul(
+    accel_type,
+    ifm_shape,
+    ofm_channels,
+    activation_function,
+):
+    np.random.seed(0)
+
+    @tf.function
+    def matmul(x, y):
+        x = tf.matmul(x, y, transpose_b=True)
+        if activation_function == "RELU":
+            x = tf.nn.relu(x)
+        return x
+
+    infra.compare_tvm_with_tflite(
+        matmul, [ifm_shape, [ofm_channels, ifm_shape[-1]]], accel_type, enable_cascader=False
+    )
+
+
 @pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
 def test_tflite_subtract_sigmoid(accel_type):
     np.random.seed(0)
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 6dd533c73042..35a8cc358ed5 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -3806,5 +3806,122 @@ def representative_dataset():
     assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"
 
 
+def test_tflite_matmul():
+    ifm_shape = [1, 4]
+    ifm2_shape = [2, 4]
+    ifm_shapes = [ifm_shape, ifm2_shape]
+    ofm_shape = [ifm_shape[0], ifm2_shape[0]]
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def matmul(self, x, y):
+                res = tf.matmul(x, y, transpose_b=True)
+                return res
+
+        model = Model()
+        concrete_func = model.matmul.get_concrete_function(
+            *[tf.TensorSpec(shape, tf.float32) for shape in ifm_shapes]
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                datas = [np.random.rand(*shape) for shape in ifm_shapes]
+                yield [data.astype(np.float32) for data in datas]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        ofm = ext_func.body
+        ops = []
+
+        def _visit(stmt):
+            if isinstance(stmt, relay.expr.Call):
+                ops.append(stmt)
+
+        relay.analysis.post_order_visit(ofm, _visit)
+        ofm_checked_type = ofm.checked_type
+        ofm_channels = ofm_shape[-1]
+
+        # check IFM
+        ifm = ops[1].checked_type
+        assert list(ifm.shape) == ifm_shape
+        assert str(ifm.dtype) == dtype
+
+        # check IFM2
+        ifm2 = ops[3].checked_type
+        assert list(ifm2.shape) == ifm2_shape
+        assert str(ifm2.dtype) == dtype
+
+        # check split
+        split = ops[4]
+        split_checked_types = list(split.checked_type.fields)
+        assert split.op.name == "split"
+        assert split.attrs.axis == 0
+        assert int(split.attrs.indices_or_sections) == ofm_channels
+        for split_checked_type in split_checked_types:
+            assert list(split_checked_type.shape) == ifm_shape
+            assert str(split_checked_type.dtype) == dtype
+
+        # check MUL
+        mul_ops = [ops[6], ops[10]]
+        for mul_op in mul_ops:
+            assert mul_op.op.name == "contrib.ethosu.binary_elementwise"
+            assert mul_op.attrs.operator_type == "MUL"
+            assert mul_op.attrs.ofm_dtype == "int32"
+
+        # check reduce sum
+        reduce_sum_ops = [ops[7], ops[11]]
+        for reduce_sum_op in reduce_sum_ops:
+            assert reduce_sum_op.op.name == "contrib.ethosu.pooling"
+            assert reduce_sum_op.attrs.pooling_type == "SUM"
+            assert list(reduce_sum_op.checked_type.shape) == [1, 1, 1, 1]
+
+        # check concatenation
+        concatenation = ofm.args[0]
+        concatenation_shape = concatenation.checked_type.shape
+        assert concatenation.op.name == "concatenate"
+        assert list(concatenation_shape) == [1, 1, 1, ofm_channels]
+
+        # check OFM
+        assert ofm.op.name == "reshape"
+        assert list(ofm_checked_type.shape) == ofm_shape
+        assert str(ofm_checked_type.dtype) == dtype
+
+    matmul_pattern_table = [
+        (
+            ethosu.MatMulParams.composite_name,
+            ethosu.matmul_pattern(),
+            lambda pat: ethosu.MatMulParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(ifm_shapes)},
+        dtype_dict={("ifm" + str(i)): dtype for i, _ in enumerate(ifm_shapes)},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, matmul_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.MatMulRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
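
For reference, the decomposition performed by MatMulRewriter can be sketched in plain NumPy. This is a minimal float-only sketch that ignores quantization, the NHWC reshapes, and the int32-to-int8 cast; it is not part of the patch, and the function name is illustrative:

import numpy as np

def matmul_via_split_mul_reduce(ifm, weights_transposed):
    # weights_transposed has shape (N, K), matching tf.matmul(..., transpose_b=True):
    # each of its rows is one column of the effective weight matrix.
    columns = np.split(weights_transposed, weights_transposed.shape[0], axis=0)
    res_columns = []
    for column in columns:  # column: (1, K)
        product = ifm * column  # broadcast elementwise MUL: (M, K)
        res = product.sum(axis=-1, keepdims=True)  # reduce-sum: (M, 1)
        res_columns.append(res)
    return np.concatenate(res_columns, axis=-1)  # concatenate columns: (M, N)

ifm = np.random.rand(4, 8)  # activations, shape (M, K)
weights_t = np.random.rand(3, 8)  # transposed weights, shape (N, K)
assert np.allclose(matmul_via_split_mul_reduce(ifm, weights_t), ifm @ weights_t.T)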