diff --git a/python/tvm/contrib/hexagon/transform.py b/python/tvm/contrib/hexagon/transform.py
index 89db2a4a717ca..2844f10452ad7 100644
--- a/python/tvm/contrib/hexagon/transform.py
+++ b/python/tvm/contrib/hexagon/transform.py
@@ -20,6 +20,9 @@
 import functools as ft
 
 import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
+from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
 from ..._ffi.registry import register_func
 
 ### VTCM
@@ -148,3 +151,163 @@ def transform(func, mod, ctx):
 
 def ir_lower_vtcm_pass():
     return [(3, ir_lower_vtcm())]
+
+
+class qdistilbert_rewrite(DFPatternCallback):
+    """
+    A callback to replace the below pattern:
+    Pattern:
+    %35 = strided_slice(%34, begin=[0, 0, 0], end=[1, 128, 64], strides=[1, 1, 1], axes=None);
+    %44 = reshape(%35, newshape=[-1, 64]);
+
+    %42 = strided_slice(%41, begin=[0, 0, 0], end=[1, 64, 128], strides=[1, 1, 1], axes=None);
+    %43 = reshape(%42, newshape=[64, 128]);
+    %45 = transpose(%43, axes=[1, 0]);
+
+    %46 = qnn.dense(%44, %45, 13, 1, 0.0541715f, 0.0489368f, units=None, out_dtype="int32");
+    %47 = qnn.requantize(%46, 0.00265098f, 0, 0.728874f, -14, axis=1, out_dtype="int8");
+
+    %125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;
+    < The above pattern repeats 12 times, which is the batch size >
+
+    %137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136);
+    %138 = concatenate(%137);
+
+    """
+
+    def __init__(self):
+        super(qdistilbert_rewrite, self).__init__()
+        self.A = wildcard()  # Tensor A
+        self.B = wildcard()  # Tensor B
+        self.batch = 12  # Number of times the pattern repeats (the batch size)
+
+        self.d = []  # List of dense quantization parameters
+        self.q = []  # List of requantize parameters
+        L = []  # List of patterns
+
+        z = tvm.tir.IntImm("int64", 0)
+        s1 = tvm.tir.IntImm("int64", 1)
+
+        for i in range(self.batch):
+            x = tvm.tir.IntImm("int64", i)
+
+            self.d.append([is_constant(), is_constant(), is_constant(), is_constant()])
+            self.q.append([is_constant(), is_constant(), is_constant(), is_constant()])
+
+            pat_a = is_op("strided_slice")(self.A).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_a = is_op("reshape")(pat_a)
+
+            pat_b = is_op("strided_slice")(self.B).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_b = is_op("reshape")(pat_b)
+            pat_b = is_op("transpose")(pat_b)
+
+            pat = is_op("qnn.dense")(
+                pat_a, pat_b, self.d[i][0], self.d[i][1], self.d[i][2], self.d[i][3]
+            )
+            pat = is_op("qnn.requantize")(
+                pat, self.q[i][0], self.q[i][1], self.q[i][2], self.q[i][3]
+            )
+            pat = is_op("expand_dims")(pat)
+            L.append(pat)
+
+        T = is_tuple(L)
+        self.pattern = is_op("concatenate")(T)
+
+    def check_quant_params(self, node_map):
+        """Check that the dense and requantize params are identical across all patterns."""
+        r = self.batch
+        x1 = [node_map[self.d[0][i]][0].data.numpy().item() for i in range(4)]
+        x2 = [node_map[self.q[0][i]][0].data.numpy().item() for i in range(4)]
+        for i in range(1, r):
+            for j in range(4):
+                y1 = node_map[self.d[i][j]][0].data.numpy().item()
+                y2 = node_map[self.q[i][j]][0].data.numpy().item()
+                if x1[j] != y1 or x2[j] != y2:
+                    return False
+        return True
+
+    def callback(self, pre, post, node_map):
+        A = node_map[self.A][0]
+        B = node_map[self.B][0]
+
+        if not self.check_quant_params(node_map):
+            return post
+
+        [a0, a1, a2] = [0, 0, 0]  # Tensor A shape
+        [b0, b1, b2] = [0, 0, 0]  # Tensor B shape
+
+        if isinstance(A, relay.expr.Call) and isinstance(B, relay.expr.Call):
+            if A.checked_type is None or B.checked_type is None:
+                # The type-inference pass needs to be run before this pass
+                return post
+            if len(A.checked_type.shape) == 3 and len(B.checked_type.shape) == 3:
+                [a0, a1, a2] = A.checked_type.shape
+                [b0, b1, b2] = B.checked_type.shape
+
+        if isinstance(A, relay.Var) and isinstance(B, relay.Var):
+            if len(A.type_annotation.shape) == 3 and len(B.type_annotation.shape) == 3:
+                [a0, a1, a2] = A.type_annotation.shape
+                [b0, b1, b2] = B.type_annotation.shape
+
+        # Check that the batch size matches the leading dimension of both tensors
+        if (a0 != self.batch) or (b0 != self.batch):
+            return post
+
+        for i in range(self.batch):
+            # end=(x, pa1, pa2) attribute of strided_slice for Tensor A
+            pa1 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[1].value
+            pa2 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[2].value
+
+            # end=(x, pb1, pb2) attribute of strided_slice for Tensor B
+            pb1 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[1].value
+            pb2 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[2].value
+
+            if a1 != pa1 or a2 != pa2 or b1 != pb1 or b2 != pb2:
+                return post
+
+        d = [node_map[self.d[0][i]][0] for i in range(4)]
+        q = [node_map[self.q[0][i]][0] for i in range(4)]
+
+        out = relay.op.transpose(B, axes=[0, 2, 1])
+        out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], out_dtype="int32")
+        out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], out_dtype="int8")
+        return out
+
+
+def rewrite_qdistilbert(mod):
+    """Rewrite the quantized DistilBERT graph to reduce computational complexity."""
+    mod["main"] = rewrite(qdistilbert_rewrite(), mod["main"])
+    return mod
+
+
+class remove_empty_pad_callback(DFPatternCallback):
+    """
+    A callback to remove an empty pad op from the below pattern:
+    Pattern:
+    %0 = cast(0f, dtype="float16");
+    %1 = nn.pad(%inp, %0, pad_width=[[0i64, 0i64], [0i64, 0i64]]);
+    nn.matmul(%1, %inp2, units=None)
+
+    """
+
+    def __init__(self):
+        super(remove_empty_pad_callback, self).__init__()
+        self.A = wildcard()
+        self.B = wildcard()
+        self.a = is_op("nn.pad")(self.A, wildcard()).has_attr({"pad_width": ((0, 0), (0, 0))})
+        self.pattern = is_op("nn.matmul")(self.a, self.B)
+
+    def callback(self, pre, post, node_map):
+        A = node_map[self.A][0]
+        B = node_map[self.B][0]
+        return relay.nn.matmul(A, B)
+
+
+def remove_emptyPad(mod):
+    """Remove the empty pad operator."""
+    mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
+    return mod
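A minimal usage sketch of how these two passes compose, under the assumption (stated in the callback's own comment) that type inference has populated checked_type before rewrite_qdistilbert runs; the driver name apply_hexagon_relay_rewrites is illustrative and not part of this change:

    import tvm
    from tvm import relay
    from tvm.contrib.hexagon.transform import remove_emptyPad, rewrite_qdistilbert

    def apply_hexagon_relay_rewrites(mod: tvm.IRModule) -> tvm.IRModule:
        """Hypothetical driver chaining the two rewrites."""
        # rewrite_qdistilbert reads checked_type, so run type inference first.
        mod = relay.transform.InferType()(mod)
        # Fuse the 12 sliced qnn.dense branches into one qnn.batch_matmul.
        mod = rewrite_qdistilbert(mod)
        # Drop nn.pad ops whose pad_width is all zeros feeding an nn.matmul.
        mod = remove_emptyPad(mod)
        return mod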
diff --git a/tests/python/contrib/test_hexagon/test_relay_transforms.py b/tests/python/contrib/test_hexagon/test_relay_transforms.py
new file mode 100644
index 0000000000000..c2e702e957b9e
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relay_transforms.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test Hexagon relay transforms
+"""
+import tvm
+from tvm import relay
+from tvm.contrib.hexagon.transform import rewrite_qdistilbert, remove_emptyPad
+from tvm import testing
+
+
+def test_rewrite_qdistilbert():
+    """Test case for rewrite_qdistilbert"""
+    A = relay.var("A", shape=(12, 128, 64), dtype="int8")
+    B = relay.var("B", shape=(12, 64, 128), dtype="int8")
+
+    z = tvm.tir.IntImm("int64", 0)
+    s1 = tvm.tir.IntImm("int64", 1)
+    tx = tvm.tir.IntImm("int64", 128)
+    ty = tvm.tir.IntImm("int64", 64)
+    expand_dims = []
+    for i in range(12):
+        d1 = relay.const(13, dtype="int32")
+        d2 = relay.const(1, dtype="int32")
+        d3 = relay.const(0.0541715, dtype="float32")
+        d4 = relay.const(0.0489368, dtype="float32")
+
+        q1 = relay.const(0.00265098, dtype="float32")
+        q2 = relay.const(0, dtype="int32")
+        q3 = relay.const(0.728874, dtype="float32")
+        q4 = relay.const(-14, dtype="int32")
+
+        x = tvm.tir.IntImm("int64", i)
+        y = tvm.tir.IntImm("int64", i + 1)
+
+        SA = relay.op.strided_slice(
+            A, begin=[x, z, z], end=[y, tx, ty], strides=[s1, s1, s1], axes=None
+        )
+        RA = relay.op.reshape(SA, [128, 64])
+        SB = relay.op.strided_slice(
+            B, begin=[x, z, z], end=[y, ty, tx], strides=[s1, s1, s1], axes=None
+        )
+        RB = relay.op.reshape(SB, [64, 128])
+        TB = relay.op.transpose(RB, [1, 0])
+        dense = relay.qnn.op.dense(RA, TB, d1, d2, d3, d4, units=None, out_dtype="int32")
+        requantize = relay.qnn.op.requantize(dense, q1, q2, q3, q4)
+        expand_dims.append(relay.op.expand_dims(requantize, axis=0))
+
+    t = relay.expr.Tuple(expand_dims)
+    graph = relay.op.concatenate(t, axis=0)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+    mod = rewrite_qdistilbert(mod)
+
+    d1 = relay.const(13, dtype="int32")
+    d2 = relay.const(1, dtype="int32")
+    d3 = relay.const(0.0541715, dtype="float32")
+    d4 = relay.const(0.0489368, dtype="float32")
+
+    q1 = relay.const(0.00265098, dtype="float32")
+    q2 = relay.const(0, dtype="int32")
+    q3 = relay.const(0.728874, dtype="float32")
+    q4 = relay.const(-14, dtype="int32")
+
+    ref = relay.op.transpose(B, [0, 2, 1])
+    ref = relay.qnn.op.batch_matmul(A, ref, d1, d2, d3, d4, out_dtype="int32")
+    ref = relay.qnn.op.requantize(ref, q1, q2, q3, q4, out_dtype="int8")
+    ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
+    ref_mod = tvm.IRModule.from_expr(ref_func)
+
+    assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])
+
+    # If the pattern does not match, the pass should return the original module.
+    func = relay.expr.Tuple(expand_dims)  # omitting the concatenate so the pattern fails
+    mod = tvm.IRModule.from_expr(func)
+    out_mod = rewrite_qdistilbert(mod)  # returns the original mod, not ref_mod
+
+    assert tvm.ir.structural_equal(mod["main"], out_mod["main"])
+
+
+def test_remove_emptyPad():
+    """Test case for remove_emptyPad"""
+    A = relay.var("A", shape=(32, 32), dtype="float16")
+    B = relay.var("B", shape=(32, 32), dtype="float16")
+
+    p0 = relay.cast(relay.const(0, dtype="float32"), dtype="float16")
+    p1 = relay.nn.pad(A, pad_value=p0, pad_width=((0, 0), (0, 0)))
+    graph = relay.nn.matmul(p1, B)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+
+    mod = remove_emptyPad(mod)
+
+    ref = relay.nn.matmul(A, B)
+    ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
+    ref_mod = tvm.IRModule.from_expr(ref_func)
+
+    assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])
+
+
+if __name__ == "__main__":
+    testing.main()
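A quick negative check on the pad-removal pattern, sketched with the same APIs the test already uses (the shapes here are illustrative): a pad whose pad_width is not all zeros must fail the has_attr constraint, so the module should come back unchanged:

    import tvm
    from tvm import relay
    from tvm.contrib.hexagon.transform import remove_emptyPad

    A = relay.var("A", shape=(32, 32), dtype="float16")
    B = relay.var("B", shape=(32, 32), dtype="float16")

    pad_value = relay.cast(relay.const(0, dtype="float32"), dtype="float16")
    # One row of padding: pad_width is not ((0, 0), (0, 0)), so no match.
    padded = relay.nn.pad(A, pad_value=pad_value, pad_width=((0, 1), (0, 0)))
    graph = relay.nn.matmul(padded, B)

    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(graph), graph))
    ref_mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(graph), graph))

    out_mod = remove_emptyPad(mod)
    # The nn.pad survives, so the rewritten module equals the reference.
    assert tvm.ir.structural_equal(out_mod["main"], ref_mod["main"])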