[BYOC][DNNL] Improve performance of DNNL BYOC dense operator #11513

Merged: 10 commits, Jun 10, 2022
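
Summary of the change set below: the DNNL BYOC dense path learns to fuse the erf-based GELU expansion, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). The pattern table gains dense(+bias)+gelu composites, a new DFPattern rewrite pushes an interposed reshape past the dense + bias (+ gelu) chain so the chain stays contiguous, the codegen learns to find the nn.dense root inside a fused gelu body, the JSON runtime appends a DNNL eltwise_gelu_erf post-op, and the tests cover the new dense and dense_bias gelu cases.
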
128 changes: 123 additions & 5 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -40,10 +40,15 @@
from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table


logger = logging.getLogger("DNNL")


@@ -139,12 +144,22 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
    data = wildcard()
    weight = wildcard()
    bias = wildcard()

    dense = is_op("nn.dense")(data, weight)
    if with_bias:
        dense_out = is_op("add")(dense, bias)
    else:
        dense_out = dense
    if with_eltwise == "gelu":
        const1 = wildcard()
        const2 = wildcard()
        const3 = wildcard()
        div = is_op("divide")(dense_out, const1)
        erf_val = is_op("erf")(div)
        added_erf_val = is_op("add")(erf_val, const2)
        mul_val = is_op("multiply")(dense_out, added_erf_val)
        dense_out = is_op("multiply")(mul_val, const3)
    elif with_eltwise:
        dense_out = is_op(with_eltwise)(dense_out)
    return dense_out
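
For orientation, a minimal sketch (not part of the diff) of what the new gelu branch matches: the erf-based GELU expansion applied to a biased dense output. It assumes a TVM build where this module is importable; shapes and names are illustrative.

```python
from tvm import relay
from tvm.relay.op.contrib.dnnl import make_dense_pattern

# Build a dense + bias + erf-based-GELU chain by hand.
x = relay.var("x", shape=(4, 16), dtype="float32")
w = relay.var("w", shape=(8, 16), dtype="float32")
b = relay.var("b", shape=(8,), dtype="float32")
biased = relay.add(relay.nn.dense(x, w), b)
gelu = relay.multiply(
    relay.multiply(
        biased,
        relay.add(relay.erf(relay.divide(biased, relay.const(1.41421))), relay.const(1.0)),
    ),
    relay.const(0.5),
)

# The new branch above should recognize the whole chain as one composite.
pat = make_dense_pattern(with_bias=True, with_eltwise="gelu")
assert pat.match(gelu)
```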

@@ -176,7 +191,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
        dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
    else:
        logger.warning(
            "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, "
            "dense op are supported, but got %s.",
            op_name,
        )
@@ -193,20 +208,21 @@ def pattern_table():
    dnnl_patterns : List[dnnl_pattern]
        Created patterns.
    """
    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
    dnnl_patterns = []
    for with_bias in [True, False]:
        for elt in elt_list:
            if not with_bias and not elt:
                continue
            for conv_name in [
                "nn.conv1d",
                "nn.conv2d",
                "nn.conv3d",
                "nn.conv2d_transpose",
                "nn.conv3d_transpose",
            ]:
                if elt != "gelu":
                    dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
            dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
    return dnnl_patterns
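
A hedged sketch of how the extended table is typically consumed; it mirrors the partition_for_dnnl helper in the tests further down, and the wrapper name here is illustrative.

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import pattern_table


def partition_with_dnnl_patterns(mod):
    """Illustrative only: merge the dnnl composites (now including *_gelu ones)
    and hand the matched regions to the "dnnl" codegen."""
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.MergeComposite(pattern_table()),
            relay.transform.AnnotateTarget("dnnl"),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)
```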

@@ -339,6 +355,7 @@ def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
            res += i
        else:
            raise ValueError("Unsupport layout format: %s" % input_data)

    return res


@@ -594,3 +611,104 @@ def rewrite_layer_norm(mod):
    """
    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
    return mod


class DenseReshapeBiasGeluRewrite(DFPatternCallback):
    """
    A callback to reorder reshape operators when the patterns are as below:

    Pattern #1:
    1   %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
            units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
    2   %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
    3   %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
            /* ty=Tensor[(1, 3136, 64), float32] */;

    Pattern #2:
    1   %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
            units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
    2   %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
    3   %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
            /* ty=Tensor[(1, 3136, 512), float32] */;
    4   %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
    5   %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
    6   %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
    7   %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
    8   %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
    """

    def __init__(self, has_gelu=True):
        super(DenseReshapeBiasGeluRewrite, self).__init__()
        self.data = wildcard()
        self.weight = wildcard()
        self.bias = wildcard()
        self.const1 = wildcard()
        self.const2 = wildcard()
        self.const3 = wildcard()

        self.attr_map = {}
        self.has_gelu = has_gelu

        den = is_op("nn.dense")(self.data, self.weight)
        re_den = is_op("reshape")(den)
        added = is_op("add")(self.bias, re_den)
        if self.has_gelu:
            divisor = is_op("divide")(added, self.const1)
            val_erf = is_op("erf")(divisor)
            added_erf = is_op("add")(val_erf, self.const2)
            mul1 = is_op("multiply")(added, added_erf)
            mul2 = is_op("multiply")(mul1, self.const3)
            self.pattern = mul2
        else:
            self.pattern = added

    def get_attr(self, pre):
        """Recursively retrieve attributes from reshape operator."""

        def visit_func(expr):
            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
                new_attrs = {}
                for k in expr.attrs.keys():
                    new_attrs[k] = expr.attrs[k]
                self.attr_map["reshape"] = new_attrs

        _analysis.post_order_visit(pre, visit_func)

    def callback(self, pre, post, node_map):
        self.get_attr(pre)

        data = node_map[self.data][0]
        weight = node_map[self.weight][0]
        bias = node_map[self.bias][0]

        den = relay.op.nn.dense(data, weight)
        added = relay.op.add(bias, den)
        if not self.has_gelu:
            return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])

        const1 = node_map[self.const1][0]
        const2 = node_map[self.const2][0]
        const3 = node_map[self.const3][0]

        divisor = relay.op.divide(added, const1)
        val_erf = relay.op.erf(divisor)
        added_erf = relay.op.add(val_erf, const2)
        mul1 = relay.op.multiply(added, added_erf)
        mul2 = relay.op.multiply(mul1, const3)
        return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"])


def rewrite_dense_bias_gelu_reshape_last(mod):
    """Rewrite the input graph to reorder reshape operators so that
    we can perform dense_bias_gelu fusion and then offload them to byoc part.
    """
    mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(), mod["main"])
    return mod


def rewrite_dense_bias_reshape_last(mod):
    """Rewrite the input graph to reorder reshape operators so that
    we can perform dense_bias fusion and then offload them to byoc part.
    """
    mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"])
    return mod
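
A hedged usage sketch (not part of the diff) of the new rewrites: given Pattern #1 from the docstring above, where the bias add sits on the reshaped output, the callback re-emits dense then add and pushes the reshape to the end, so the chain can later be matched as a single dnnl composite. Shapes follow the docstring; the zero-valued constants are placeholders.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

x = relay.var("x", shape=(3136, 64), dtype="float32")
w = relay.const(np.zeros((512, 64), dtype="float32"))
b = relay.const(np.zeros((512,), dtype="float32"))

y = relay.nn.dense(x, w)                       # (3136, 512)
y = relay.reshape(y, newshape=[1, 3136, 512])  # reshape sits between dense and bias
y = relay.add(b, y)

mod = tvm.IRModule.from_expr(y)
mod = dnnl.rewrite_dense_bias_reshape_last(mod)
print(mod["main"])  # now dense -> add -> reshape, with reshape last
```
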
36 changes: 36 additions & 0 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -436,6 +436,38 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

#else // DNNL JSON runtime

/*!
 * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in
 *        relu(add(conv2d))
 * \param call A Relay call node. Typically nn.relu when called the first time.
 * \param max_depth The maximum number of calls before the root op, counting from current_call.
 * \param root_name The name of expected "root" op in this fused call.
 * \return A CallNode corresponding to the root op
 */
inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth,
                                        const std::string& root_name) {
  ICHECK(current_call && max_depth >= 0);

  if (max_depth == 0) {
    ICHECK(current_call && IsOp(current_call, root_name));
    return current_call;
  }
  if (IsOp(current_call, root_name)) {
    return current_call;
  }

  ICHECK_GT(current_call->args.size(), 0);

  size_t valid_node_idx = 0;
  while (valid_node_idx < current_call->args.size() &&
         current_call->args[valid_node_idx].as<VarNode>()) {
    valid_node_idx++;
  }

  const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
  return FindCallWithName(next_call, max_depth - 1, root_name);
}

class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
@@ -454,6 +486,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
    ICHECK_NE(pattern_name, "");
    std::vector<std::string> op_list;
    size_t pos = 0, start = 0;

    while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
      std::string op_name = pattern_name.substr(start, pos - start);
      if (op_name.find("dnnl") != std::string::npos) {
@@ -507,6 +540,9 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
      std::vector<std::string> op_list = ParsingOpList(name);
      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
      ICHECK(call->op.as<OpNode>()) << "Not op node";
    } else if (name.find("gelu") != std::string::npos) {
      call = FindCallWithName(fn->body.as<CallNode>(), 10, "nn.dense");
      ICHECK(call->op.as<OpNode>()) << "Not op node";
    } else if (name.find("dnnl.dense") != std::string::npos) {
      std::vector<std::string> op_list = ParsingOpList(name);
      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
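
Note on the new branch: the fused body of a *_gelu composite contains the expanded divide/erf/add/multiply chain, which is not reflected in the composite name, so the serializer walks the body with FindCallWithName (bounded at a depth of 10) to reach the nn.dense root instead of relying on the name-derived op list used by GetRootCall.
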
9 changes: 7 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,7 +83,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
      // Find proper dnnl::memory buffers
      std::unordered_map<int, dnnl::memory> mem_args;
      for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);
      prim.execute(stream_, mem_args);
    }
  }
@@ -143,6 +142,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
    std::regex relu_pat(".*_relu.*");
    std::regex tanh_pat(".*_tanh.*");
    std::regex sigmoid_pat(".*_sigmoid.*");
    std::regex gelu_pat(".*_gelu.*");

    // Parsing post-ops.
    dnnl::post_ops ops;
@@ -155,7 +155,12 @@
    if (std::regex_match(op_name, sigmoid_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
    }
    if (std::regex_match(op_name, gelu_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
    }
    if (ops.len() != 0) {
      attr.set_post_ops(ops);
    }

    // Parsing bias_add.
    return std::regex_match(op_name, bias_add_pat) ? true : false;
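
For reference, a hedged NumPy sketch (helper and variable names are illustrative, not part of the runtime) of what the fused dense primitive computes once the eltwise_gelu_erf post-op is appended: bias add and exact erf-based GELU happen inside a single DNNL call.

```python
import numpy as np
from math import erf, sqrt


def gelu_erf(v):
    # Exact (erf-based) GELU, matching dnnl::algorithm::eltwise_gelu_erf.
    return 0.5 * v * (1.0 + erf(v / sqrt(2.0)))


x = np.random.randn(2, 16).astype("float32")   # input
w = np.random.randn(32, 16).astype("float32")  # TVM dense weight layout: (units, in_features)
b = np.random.randn(32).astype("float32")      # bias

y = np.vectorize(gelu_erf)(x @ w.T + b)        # what the fused primitive produces in one shot
```
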
31 changes: 31 additions & 0 deletions tests/python/contrib/test_dnnl.py
@@ -19,6 +19,7 @@
import numpy as np
import sys
import subprocess
import math

import tvm
from tvm import relay
@@ -121,6 +122,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)
        mod = dnnl.prune_dnnl_subgraphs(mod)
@@ -170,6 +172,7 @@ def check_dnnl_used(mod, subgraph_num=None):
    ]
    if test_bf16 and bf16_supported():
        configs += [(True, False, True), (True, True, True)]

    for use_dnnl, alter_layout, use_bf16 in configs:
        result_key = (
            mode
@@ -585,12 +588,27 @@ def get_conv3d_transpose_bias(
    return out, dic, param_lst


def gelu_helper(data):
    const1 = relay.const(math.sqrt(2.0))
    const2 = relay.const(1.0)
    const3 = relay.const(0.5)
    divisor = relay.op.divide(data, const1)
    val_erf = relay.op.erf(divisor)
    added_erf = relay.op.add(val_erf, const2)
    mul1 = relay.op.multiply(data, added_erf)
    out = relay.op.multiply(mul1, const3)
    return out
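
A hedged self-check, not part of the PR, that gelu_helper encodes the exact erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))); it reuses the module-level imports above and evaluates the Relay expression the helper builds:

```python
data = relay.var("data", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data], gelu_helper(data)))

x = np.array([-2.0, -0.5, 0.5, 2.0], dtype="float32")
got = relay.create_executor(mod=mod).evaluate()(x).numpy()
ref = np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x], dtype="float32")
np.testing.assert_allclose(got, ref, rtol=1e-5, atol=1e-6)
```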


def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
    x = relay.var("x", shape=(x_shape), dtype=dtype)
    kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
    out = relay.nn.dense(x, kernel, units=k_shape[0])
    dic = {"x": x_shape, "kernel": k_shape}
    param_lst = ["kernel"]

    if activation == "gelu":
        out = gelu_helper(out)
    return out, dic, param_lst


@@ -600,6 +618,9 @@ def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"):
    out = relay.nn.bias_add(dense, bias)
    dic["bias"] = (k_shape[0],)
    param_lst += ["bias"]

    if activation == "gelu":
        out = gelu_helper(out)
    return out, dic, param_lst


@@ -891,6 +912,11 @@ def test_dense(run_module, dtype="float32"):
    config = dense, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)

    dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype)
    dense = tvm.IRModule.from_expr(dense)
    config = dense, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)


def test_dense_pattern(run_module, dtype="float32"):
    x_shape = (1, 16)
@@ -906,6 +932,11 @@ def test_dense_pattern(run_module, dtype="float32"):
    config = dense_bias, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)

    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype)
    dense_bias = tvm.IRModule.from_expr(dense_bias)
    config = dense_bias, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
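
A hedged follow-on sketch (illustrative, not part of the test suite) that uses the helpers in this file to confirm the gelu chain lands in a DNNL partition:

```python
out, dic, param_lst = get_dense_bias(activation="gelu")
mod = partition_for_dnnl(tvm.IRModule.from_expr(out))
# Expect a function annotated with Compiler="dnnl" wrapping the dense + bias + gelu composite.
print(mod.astext(show_meta_data=False))
```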


def test_pool2d(run_module, dtype="float32"):
    def get_graph(
6 changes: 3 additions & 3 deletions tests/python/relay/test_pass_partition_graph.py
@@ -928,9 +928,9 @@ def test_dnnl_fuse():
    ) = (
        dnnl_patterns[1],
        dnnl_patterns[13],
        dnnl_patterns[20],
        dnnl_patterns[26],
        dnnl_patterns[38],
    )

    def get_blocks(