diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index ce7ae1fea2f7..53f3ec15febe 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,7 +33,6 @@ check the attributes of the op and decide if it should be offloaded to DNNL. """ import logging -import math import tvm.ir from tvm import relay @@ -49,7 +48,6 @@ from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table -import re logger = logging.getLogger("DNNL") @@ -617,10 +615,21 @@ def rewrite_layer_norm(mod): class DenseReshapeBiasGeluRewrite(DFPatternCallback): """ - A callback to reorder reshape operators when the patten is as below: - 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; + A callback to reorder reshape operators when the patterns are as below: + + Pattern #1: + 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, + units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; + 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; + 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) + /* ty=Tensor[(1, 3136, 64), float32] */; + + Pattern #2: + 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, + units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; 2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */; - 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */; + 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) + /* ty=Tensor[(1, 3136, 512), float32] */; 4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; 5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */; 6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; @@ -628,7 +637,7 @@ class DenseReshapeBiasGeluRewrite(DFPatternCallback): 8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; """ - def __init__(self): + def __init__(self, has_gelu=True): super(DenseReshapeBiasGeluRewrite, self).__init__() self.data = wildcard() self.weight = wildcard() @@ -638,29 +647,30 @@ def __init__(self): self.const3 = wildcard() self.attr_map = {} + self.has_gelu = has_gelu den = is_op("nn.dense")(self.data, self.weight) re_den = is_op("reshape")(den) added = is_op("add")(self.bias, re_den) - divisor = is_op("divide")(added, self.const1) - val_erf = is_op("erf")(divisor) - added_erf = is_op("add")(val_erf, self.const2) - mul1 = is_op("multiply")(added, added_erf) - mul2 = is_op("multiply")(mul1, self.const3) - self.pattern = mul2 + if self.has_gelu: + divisor = is_op("divide")(added, self.const1) + val_erf = is_op("erf")(divisor) + added_erf = is_op("add")(val_erf, self.const2) + mul1 = is_op("multiply")(added, added_erf) + mul2 = is_op("multiply")(mul1, self.const3) + self.pattern = mul2 + else: + self.pattern = added def get_attr(self, pre): + """Recursively retrieve attributes from reshape operator.""" + def visit_func(expr): if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): new_attrs = {} for k in expr.attrs.keys(): new_attrs[k] = expr.attrs[k] self.attr_map["reshape"] = new_attrs - elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["nn.dense"] = new_attrs _analysis.post_order_visit(pre, visit_func) @@ -670,12 +680,16 @@ def callback(self, pre, post, node_map): data = node_map[self.data][0] weight = node_map[self.weight][0] bias = node_map[self.bias][0] + + den = relay.op.nn.dense(data, weight) + added = relay.op.add(bias, den) + if not self.has_gelu: + return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) + const1 = node_map[self.const1][0] const2 = node_map[self.const2][0] const3 = node_map[self.const3][0] - - den = relay.op.nn.dense(data, weight) - added = relay.op.add(bias, den) + divisor = relay.op.divide(added, const1) val_erf = relay.op.erf(divisor) added_erf = relay.op.add(val_erf, const2) @@ -692,57 +706,9 @@ def rewrite_dense_bias_gelu_reshape_last(mod): return mod -class DenseReshapeBiasRewrite(DFPatternCallback): - """ - A callback to reorder reshape operators when the patten is as below: - 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; - 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; - 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */; - """ - - def __init__(self): - super(DenseReshapeBiasRewrite, self).__init__() - self.data = wildcard() - self.weight = wildcard() - self.bias = wildcard() - - self.attr_map = {} - - den = is_op("nn.dense")(self.data, self.weight) - re_den = is_op("reshape")(den) - added = is_op("add")(self.bias, re_den) - self.pattern = added - - def get_attr(self, pre): - def visit_func(expr): - if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["reshape"] = new_attrs - elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["nn.dense"] = new_attrs - - _analysis.post_order_visit(pre, visit_func) - - def callback(self, pre, post, node_map): - self.get_attr(pre) - - data = node_map[self.data][0] - weight = node_map[self.weight][0] - bias = node_map[self.bias][0] - - den = relay.op.nn.dense(data, weight) - added = relay.op.add(bias, den) - return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) - - def rewrite_dense_bias_reshape_last(mod): """Rewrite the input graph to reorder reshape operators so that we can perform dense_bias fusion and then offload them to byoc part. """ - mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"]) + mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"]) return mod diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index dbf63d46e116..5045f3323af7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -158,7 +158,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (std::regex_match(op_name, gelu_pat)) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); } - if (ops.len() != 0){ + if (ops.len() != 0) { attr.set_post_ops(ops); } diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 33be57f014fb..0daacffc7933 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -912,7 +912,7 @@ def test_dense(run_module, dtype="float32"): config = dense, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) - dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype) + dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype) dense = tvm.IRModule.from_expr(dense) config = dense, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) @@ -932,7 +932,7 @@ def test_dense_pattern(run_module, dtype="float32"): config = dense_bias, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) - dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype) + dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype) dense_bias = tvm.IRModule.from_expr(dense_bias) config = dense_bias, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype)