Fix lint
billishyahao committed Jun 9, 2022
1 parent 54ae4d0 commit e71778e
Showing 3 changed files with 38 additions and 72 deletions.
104 changes: 35 additions & 69 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -33,7 +33,6 @@
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
import logging
import math

import tvm.ir
from tvm import relay
@@ -49,7 +48,6 @@
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table

import re

logger = logging.getLogger("DNNL")

@@ -617,18 +615,29 @@ def rewrite_layer_norm(mod):

class DenseReshapeBiasGeluRewrite(DFPatternCallback):
"""
A callback to reorder reshape operators when the patten is as below:
1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
A callback to reorder reshape operators when the patterns are as below:
Pattern #1:
1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
/* ty=Tensor[(1, 3136, 64), float32] */;
Pattern #2:
1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */;
3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
/* ty=Tensor[(1, 3136, 512), float32] */;
4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
"""

def __init__(self):
def __init__(self, has_gelu=True):
super(DenseReshapeBiasGeluRewrite, self).__init__()
self.data = wildcard()
self.weight = wildcard()
@@ -638,29 +647,30 @@ def __init__(self):
self.const3 = wildcard()

self.attr_map = {}
self.has_gelu = has_gelu

den = is_op("nn.dense")(self.data, self.weight)
re_den = is_op("reshape")(den)
added = is_op("add")(self.bias, re_den)
divisor = is_op("divide")(added, self.const1)
val_erf = is_op("erf")(divisor)
added_erf = is_op("add")(val_erf, self.const2)
mul1 = is_op("multiply")(added, added_erf)
mul2 = is_op("multiply")(mul1, self.const3)
self.pattern = mul2
if self.has_gelu:
divisor = is_op("divide")(added, self.const1)
val_erf = is_op("erf")(divisor)
added_erf = is_op("add")(val_erf, self.const2)
mul1 = is_op("multiply")(added, added_erf)
mul2 = is_op("multiply")(mul1, self.const3)
self.pattern = mul2
else:
self.pattern = added

def get_attr(self, pre):
"""Recursively retrieve attributes from reshape operator."""

def visit_func(expr):
if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["reshape"] = new_attrs
elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["nn.dense"] = new_attrs

_analysis.post_order_visit(pre, visit_func)

@@ -670,12 +680,16 @@ def callback(self, pre, post, node_map):
data = node_map[self.data][0]
weight = node_map[self.weight][0]
bias = node_map[self.bias][0]

den = relay.op.nn.dense(data, weight)
added = relay.op.add(bias, den)
if not self.has_gelu:
return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])

const1 = node_map[self.const1][0]
const2 = node_map[self.const2][0]
const3 = node_map[self.const3][0]

den = relay.op.nn.dense(data, weight)
added = relay.op.add(bias, den)

divisor = relay.op.divide(added, const1)
val_erf = relay.op.erf(divisor)
added_erf = relay.op.add(val_erf, const2)
@@ -692,57 +706,9 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
return mod


class DenseReshapeBiasRewrite(DFPatternCallback):
"""
A callback to reorder reshape operators when the patten is as below:
1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */;
"""

def __init__(self):
super(DenseReshapeBiasRewrite, self).__init__()
self.data = wildcard()
self.weight = wildcard()
self.bias = wildcard()

self.attr_map = {}

den = is_op("nn.dense")(self.data, self.weight)
re_den = is_op("reshape")(den)
added = is_op("add")(self.bias, re_den)
self.pattern = added

def get_attr(self, pre):
def visit_func(expr):
if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["reshape"] = new_attrs
elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["nn.dense"] = new_attrs

_analysis.post_order_visit(pre, visit_func)

def callback(self, pre, post, node_map):
self.get_attr(pre)

data = node_map[self.data][0]
weight = node_map[self.weight][0]
bias = node_map[self.bias][0]

den = relay.op.nn.dense(data, weight)
added = relay.op.add(bias, den)
return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])


def rewrite_dense_bias_reshape_last(mod):
"""Rewrite the input graph to reorder reshape operators so that
we can perform dense_bias fusion and then offload them to byoc part.
"""
mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"])
mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"])
return mod
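
For context beyond the diff itself: the two rewrite entry points above are intended to run before the DNNL BYOC partitioning passes. Below is a minimal usage sketch; the toy graph, its shapes, and the exact pass ordering are assumptions for illustration, not part of this commit.

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

# Build a small dense -> reshape -> bias-add graph (shapes are illustrative only).
x = relay.var("x", shape=(3136, 64), dtype="float32")
w = relay.var("w", shape=(64, 64), dtype="float32")
b = relay.var("b", shape=(64,), dtype="float32")
y = relay.nn.dense(x, w)
y = relay.reshape(y, newshape=[1, 3136, 64])
y = relay.add(b, y)
mod = tvm.IRModule.from_expr(relay.Function([x, w, b], y))

# Reorder reshape past the bias add (and GELU, when present) so that
# dense + bias (+ gelu) forms a single fusable pattern for the DNNL codegen.
mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod)
mod = dnnl.rewrite_dense_bias_reshape_last(mod)

# Standard BYOC flow (assumed ordering): merge DNNL patterns, annotate, partition.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(dnnl.pattern_table()),
        relay.transform.AnnotateTarget("dnnl"),
        relay.transform.PartitionGraph(),
    ]
)
mod = seq(mod)
print(mod)
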
2 changes: 1 addition & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -158,7 +158,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
if (ops.len() != 0){
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

4 changes: 2 additions & 2 deletions tests/python/contrib/test_dnnl.py
@@ -912,7 +912,7 @@ def test_dense(run_module, dtype="float32"):
config = dense, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
dense = tvm.IRModule.from_expr(dense)
config = dense, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)
@@ -932,7 +932,7 @@ def test_dense_pattern(run_module, dtype="float32"):
config = dense_bias, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
dense_bias = tvm.IRModule.from_expr(dense_bias)
config = dense_bias, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)
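
As a side note on what the "gelu" activation in these tests corresponds to: the erf-based GELU matched by the rewrite pattern expands to 0.5 * x * (1 + erf(x / sqrt(2))). A minimal hand-built Relay sketch of that chain follows; the variable name and shape are illustrative and not taken from get_dense_bias.

from tvm import relay

x = relay.var("x", shape=(1, 3136, 512), dtype="float32")
inv_sqrt2 = relay.const(1.41421, dtype="float32")  # ~= sqrt(2), the divisor in the pattern
one = relay.const(1.0, dtype="float32")
half = relay.const(0.5, dtype="float32")

# Same divide -> erf -> add -> multiply -> multiply chain that
# DenseReshapeBiasGeluRewrite matches after the bias add.
gelu = relay.multiply(relay.multiply(x, relay.add(relay.erf(relay.divide(x, inv_sqrt2)), one)), half)
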
