Commit 3b7056b

Relay transform for rolling strided_slice, dense and other ops into batch_matmul

rasagna-quic committed Mar 9, 2023
1 parent befdc4e commit 3b7056b
Showing 2 changed files with 283 additions and 0 deletions.
163 changes: 163 additions & 0 deletions python/tvm/contrib/hexagon/transform.py
@@ -20,6 +20,9 @@
import functools as ft

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
from ..._ffi.registry import register_func

### VTCM
@@ -148,3 +151,163 @@ def transform(func, mod, ctx):

def ir_lower_vtcm_pass():
return [(3, ir_lower_vtcm())]


class qdistilbert_rewrite(DFPatternCallback):
"""
    A callback to replace the pattern below:
Pattern:
%35 = strided_slice(%34, begin=[0, 0, 0], end=[1, 128, 64], strides=[1, 1, 1], axes=None);
%44 = reshape(%35, newshape=[-1, 64]);
<snip>
%42 = strided_slice(%41, begin=[0, 0, 0], end=[1, 64, 128], strides=[1, 1, 1], axes=None);
%43 = reshape(%42, newshape=[64, 128]);
%45 = transpose(%43, axes=[1, 0]);
<snip>
%46 = qnn.dense(%44, %45, 13, 1, 0.0541715f, 0.0489368f, units=None, out_dtype="int32");
%47 = qnn.requantize(%46, 0.00265098f, 0, 0.728874f, -14, axis=1, out_dtype="int8");
<snip>
%125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;
< The above pattern repeats 12 times, which is the batch size >
%137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136);
%138 = concatenate(%137);
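    The pattern above is rewritten into a single batched form (see the
    callback below):
        %t = transpose(%B, axes=[0, 2, 1]);
        %m = qnn.batch_matmul(%A, %t, ..., out_dtype="int32");
        %r = qnn.requantize(%m, ..., out_dtype="int8");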
"""

def __init__(self):
        super().__init__()
self.A = wildcard() # Tensor A
self.B = wildcard() # Tensor B
        self.batch = 12  # Number of times the pattern repeats, i.e., the batch size

self.d = [] # List of dense quantization parameters
self.q = [] # List of requantize parameters
L = [] # List of patterns

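        # Expected constant values of the strided_slice begin/strides attributes.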
z = tvm.tir.IntImm("int64", 0)
s1 = tvm.tir.IntImm("int64", 1)

for i in range(self.batch):
x = tvm.tir.IntImm("int64", i)

self.d.append([is_constant(), is_constant(), is_constant(), is_constant()])
self.q.append([is_constant(), is_constant(), is_constant(), is_constant()])

pat_a = is_op("strided_slice")(self.A).has_attr(
{"begin": [x, z, z], "strides": [s1, s1, s1]}
)
pat_a = is_op("reshape")(pat_a)

pat_b = is_op("strided_slice")(self.B).has_attr(
{"begin": [x, z, z], "strides": [s1, s1, s1]}
)
pat_b = is_op("reshape")(pat_b)
pat_b = is_op("transpose")(pat_b)

pat = is_op("qnn.dense")(
pat_a, pat_b, self.d[i][0], self.d[i][1], self.d[i][2], self.d[i][3]
)
pat = is_op("qnn.requantize")(
pat, self.q[i][0], self.q[i][1], self.q[i][2], self.q[i][3]
)
pat = is_op("expand_dims")(pat)
L.append(pat)

T = is_tuple(L)
self.pattern = is_op("concatenate")(T)
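        # Full pattern: concatenate over a tuple of `batch` per-slice
        # dense -> requantize -> expand_dims chains.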

def check_quant_params(self, node_map):
"""checking if dense and requant params are the same across patterns"""
r = self.batch
x1 = [node_map[self.d[0][i]][0].data.numpy().item() for i in range(4)]
x2 = [node_map[self.q[0][i]][0].data.numpy().item() for i in range(4)]
for i in range(1, r):
for j in range(4):
y1 = node_map[self.d[i][j]][0].data.numpy().item()
y2 = node_map[self.q[i][j]][0].data.numpy().item()
if x1[j] != y1 or x2[j] != y2:
return False
return True

def callback(self, pre, post, node_map):
A = node_map[self.A][0]
B = node_map[self.B][0]

if not self.check_quant_params(node_map):
return post

[a0, a1, a2] = [0, 0, 0] # Tensor A shape
[b0, b1, b2] = [0, 0, 0] # Tensor B shape

if isinstance(A, relay.expr.Call) and isinstance(B, relay.expr.Call):
if A.checked_type is None or B.checked_type is None:
                # Type inference must run before this pass
return post
if len(A.checked_type.shape) == 3 and len(B.checked_type.shape) == 3:
[a0, a1, a2] = A.checked_type.shape
[b0, b1, b2] = B.checked_type.shape

if isinstance(A, relay.Var) and isinstance(B, relay.Var):
if len(A.type_annotation.shape) == 3 and len(B.type_annotation.shape) == 3:
[a0, a1, a2] = A.type_annotation.shape
[b0, b1, b2] = B.type_annotation.shape

        # Check that the outer (batch) dimension of both tensors matches the expected batch size
if (a0 != self.batch) or (b0 != self.batch):
return post

for i in range(self.batch):
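            # Walk from the matched concatenate down to each strided_slice:
            # pre.args[0] is the tuple, [i] selects the i-th expand_dims, then
            # requantize -> dense -> reshape -> strided_slice for Tensor A and
            # requantize -> dense -> transpose -> reshape -> strided_slice for Tensor B.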
# end=(x, pa1, pa2) attribute of strided_slice for Tensor A
pa1 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[1].value
pa2 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[2].value

# end=(x, pb1, pb2) attribute of strided_slice for Tensor B
pb1 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[1].value
pb2 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[2].value

if a1 != pa1 or a2 != pa2 or b1 != pb1 or b2 != pb2:
return post

d = [node_map[self.d[0][i]][0] for i in range(4)]
q = [node_map[self.q[0][i]][0] for i in range(4)]

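        # Emit the single batched equivalent: transpose B into the layout
        # expected by qnn.batch_matmul, multiply all slices at once, requantize.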
out = relay.op.transpose(B, axes=[0, 2, 1])
out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], out_dtype="int32")
out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], out_dtype="int8")
return out


def rewrite_qdistilbert(mod):
"""Rewrite the Quantized Distilbert to reduce computational complexity."""
mod["main"] = rewrite(qdistilbert_rewrite(), mod["main"])
return mod


class remove_empty_pad_callback(DFPatternCallback):
"""
    A callback to remove an empty pad op from the pattern below:
Pattern:
%0 = cast(0f, dtype="float16");
%1 = nn.pad(%inp, %0, pad_width=[[0i64, 0i64], [0i64, 0i64]]);
nn.matmul(%1, %inp2, units=None)
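    Since pad_width is all zeros, the pad is a no-op and the pattern is
    rewritten to nn.matmul(%inp, %inp2).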
"""

def __init__(self):
        super().__init__()
self.A = wildcard()
self.B = wildcard()
self.a = is_op("nn.pad")(self.A, wildcard()).has_attr({"pad_width": ((0, 0), (0, 0))})
self.pattern = is_op("nn.matmul")(self.a, self.B)

def callback(self, pre, post, node_map):
A = node_map[self.A][0]
B = node_map[self.B][0]
return relay.nn.matmul(A, B)


def remove_empty_pad(mod):
"""Remove the empty pad operator."""
mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
return mod
120 changes: 120 additions & 0 deletions tests/python/contrib/test_hexagon/test_relay_transforms.py
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-wildcard-import, invalid-name

"""
Test hexagon relay transforms
"""
import tvm
from tvm import relay
from tvm.contrib.hexagon.transform import rewrite_qdistilbert, remove_empty_pad
from tvm import testing


def test_rewrite_qdistilbert():
"""Test case for rewrite_qdistilbert"""
A = relay.var("A", shape=(12, 128, 64), dtype="int8")
B = relay.var("B", shape=(12, 64, 128), dtype="int8")

z = tvm.tir.IntImm("int64", 0)
s1 = tvm.tir.IntImm("int64", 1)
tx = tvm.tir.IntImm("int64", 128)
ty = tvm.tir.IntImm("int64", 64)
expand_dims = []
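    # Build the 12 per-batch slice -> dense -> requantize -> expand_dims
    # chains that rewrite_qdistilbert should collapse into one batch_matmul.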
for i in range(12):
d1 = relay.const(13, dtype="int32")
d2 = relay.const(1, dtype="int32")
d3 = relay.const(0.0541715, dtype="float32")
d4 = relay.const(0.0489368, dtype="float32")

q1 = relay.const(0.00265098, dtype="float32")
q2 = relay.const(0, dtype="int32")
q3 = relay.const(0.728874, dtype="float32")
q4 = relay.const(-14, dtype="int32")

x = tvm.tir.IntImm("int64", i)
y = tvm.tir.IntImm("int64", i + 1)

SA = relay.op.strided_slice(
A, begin=[x, z, z], end=[y, tx, ty], strides=[s1, s1, s1], axes=None
)
RA = relay.op.reshape(SA, [128, 64])
SB = relay.op.strided_slice(
B, begin=[x, z, z], end=[y, ty, tx], strides=[s1, s1, s1], axes=None
)
RB = relay.op.reshape(SB, [64, 128])
TB = relay.op.transpose(RB, [1, 0])
dense = relay.qnn.op.dense(RA, TB, d1, d2, d3, d4, units=None, out_dtype="int32")
requantize = relay.qnn.op.requantize(dense, q1, q2, q3, q4)
expand_dims.append(relay.op.expand_dims(requantize, axis=0))

t = relay.expr.Tuple(expand_dims)
graph = relay.op.concatenate(t, axis=0)

func = relay.Function(relay.analysis.free_vars(graph), graph)
mod = tvm.IRModule.from_expr(func)
mod = rewrite_qdistilbert(mod)

d1 = relay.const(13, dtype="int32")
d2 = relay.const(1, dtype="int32")
d3 = relay.const(0.0541715, dtype="float32")
d4 = relay.const(0.0489368, dtype="float32")

q1 = relay.const(0.00265098, dtype="float32")
q2 = relay.const(0, dtype="int32")
q3 = relay.const(0.728874, dtype="float32")
q4 = relay.const(-14, dtype="int32")

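    # Expected single batched form produced by the rewrite.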
ref = relay.op.transpose(B, [0, 2, 1])
ref = relay.qnn.op.batch_matmul(A, ref, d1, d2, d3, d4, out_dtype="int32")
ref = relay.qnn.op.requantize(ref, q1, q2, q3, q4, out_dtype="int8")
ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
ref_mod = tvm.IRModule.from_expr(ref_func)

assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])

    # If the pattern does not match, the rewrite should return the original module.
func = relay.expr.Tuple(expand_dims) # omitting concatenate
mod = tvm.IRModule.from_expr(func)
    out_mod = rewrite_qdistilbert(mod)  # returns the original mod, not ref_mod

assert tvm.ir.structural_equal(mod["main"], out_mod["main"])


def test_remove_empty_pad():
"""Test case for remove_empty_pad"""
A = relay.var("A", shape=(32, 32), dtype="float16")
B = relay.var("B", shape=(32, 32), dtype="float16")

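    # An all-zero (no-op) pad feeding matmul; the pass should remove it.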
p0 = relay.cast(relay.const(0, dtype="float32"), dtype="float16")
p1 = relay.nn.pad(A, pad_value=p0, pad_width=((0, 0), (0, 0)))
graph = relay.nn.matmul(p1, B)

func = relay.Function(relay.analysis.free_vars(graph), graph)
mod = tvm.IRModule.from_expr(func)

mod = remove_empty_pad(mod)

ref = relay.nn.matmul(A, B)
ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
ref_mod = tvm.IRModule.from_expr(ref_func)

assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])


if __name__ == "__main__":
testing.main()
