[BYOC][DNNL] Improve performance of DNNL BYOC dense operator #11513

Merged
merged 10 commits on Jun 10, 2022
123 changes: 118 additions & 5 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -40,10 +40,15 @@
from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table


logger = logging.getLogger("DNNL")


@@ -139,12 +144,22 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
data = wildcard()
weight = wildcard()
bias = wildcard()

dense = is_op("nn.dense")(data, weight)
if with_bias:
dense_out = is_op("add")(dense, bias)
else:
dense_out = dense
if with_eltwise:
if with_eltwise == "gelu":
const1 = wildcard()
const2 = wildcard()
const3 = wildcard()
div = is_op("divide")(dense_out, const1)
erf_val = is_op("erf")(div)
added_erf_val = is_op("add")(erf_val, const2)
mul_val = is_op("multiply")(dense_out, added_erf_val)
dense_out = is_op("multiply")(mul_val, const3)
elif with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out
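
The five-op chain matched in the gelu branch above is the erf-based GELU decomposition, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). The three constants are wildcards in the pattern; in the graphs this targets they are 1.41421, 1, and 0.5, as shown in the Pattern #2 docstring added further down. A minimal numeric sanity check of that equivalence (illustrative only, not part of the patch; scipy is used just for erf):

```python
import numpy as np
from scipy.special import erf

def gelu_erf_chain(x):
    # Same op order the pattern matches: divide, erf, add, multiply, multiply.
    return (x * (erf(x / 1.41421) + 1.0)) * 0.5

x = np.linspace(-3.0, 3.0, 13)
reference = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
assert np.allclose(gelu_erf_chain(x), reference, atol=1e-4)
```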

@@ -176,7 +191,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
else:
logger.warning(
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, "
"dense op are supported, but got %s.",
op_name,
)
@@ -193,20 +208,21 @@ def pattern_table():
dnnl_patterns : List[dnnl_pattern]
Created patterns.
"""
elt_list = ["nn.relu", "tanh", "sigmoid", None]
elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
dnnl_patterns = []
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
- return dnnl_patterns
+ continue
for conv_name in [
"nn.conv1d",
"nn.conv2d",
"nn.conv3d",
"nn.conv2d_transpose",
"nn.conv3d_transpose",
]:
- dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+ if elt != "gelu":
+     dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
return dnnl_patterns

@@ -339,6 +355,7 @@ def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
res += i
else:
raise ValueError("Unsupport layout format: %s" % input_data)

return res


@@ -594,3 +611,99 @@ def rewrite_layer_norm(mod):
"""
mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
return mod


class DenseReshapeBiasGeluRewrite(DFPatternCallback):
"""
A callback to reorder reshape operators when the graph matches one of the patterns below:

Pattern #1:
1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
/* ty=Tensor[(1, 3136, 64), float32] */;

Pattern #2:
1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
/* ty=Tensor[(1, 3136, 512), float32] */;
4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
"""

def __init__(self, has_gelu=True):
super(DenseReshapeBiasGeluRewrite, self).__init__()
self.data = wildcard()
self.weight = wildcard()
self.bias = wildcard()
self.const1 = wildcard()
self.const2 = wildcard()
self.const3 = wildcard()

self.attr_map = {}
self.has_gelu = has_gelu

den = is_op("nn.dense")(self.data, self.weight)
re_den = is_op("reshape")(den)
added = is_op("add")(self.bias, re_den)
if self.has_gelu:
divisor = is_op("divide")(added, self.const1)
val_erf = is_op("erf")(divisor)
added_erf = is_op("add")(val_erf, self.const2)
mul1 = is_op("multiply")(added, added_erf)
mul2 = is_op("multiply")(mul1, self.const3)
self.pattern = mul2
else:
self.pattern = added

def get_attr(self, pre):
"""Recursively retrieve attributes from reshape operator."""

def visit_func(expr):
if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["reshape"] = new_attrs

_analysis.post_order_visit(pre, visit_func)

def callback(self, pre, post, node_map):
self.get_attr(pre)

data = node_map[self.data][0]
weight = node_map[self.weight][0]
bias = node_map[self.bias][0]

den = relay.op.nn.dense(data, weight)
added = relay.op.add(bias, den)
if not self.has_gelu:
return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])

const1 = node_map[self.const1][0]
const2 = node_map[self.const2][0]
const3 = node_map[self.const3][0]

divisor = relay.op.divide(added, const1)
val_erf = relay.op.erf(divisor)
added_erf = relay.op.add(val_erf, const2)
mul1 = relay.op.multiply(added, added_erf)
mul2 = relay.op.multiply(mul1, const3)
return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"])


def rewrite_dense_bias_gelu_reshape_last(mod):
"""Rewrite the input graph to reorder reshape operators so that
we can perform dense_bias_gelu/dense_bias fusion and then offload
them to byoc part.
"""
mod["main"] = rewrite(
[DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"]
)
return mod
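
A rough usage sketch of the new pass (assuming the usual DNNL BYOC flow; the hand-built graph below mirrors Pattern #2, and calling the rewrite explicitly before partition_for_dnnl is an assumption of this sketch, not something prescribed by the patch):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl, rewrite_dense_bias_gelu_reshape_last

# dense -> reshape -> bias add -> erf-based gelu, as in Pattern #2.
x = relay.var("x", shape=(3136, 64), dtype="float32")
w = relay.const(np.random.uniform(size=(512, 64)).astype("float32"))
b = relay.const(np.random.uniform(size=(512,)).astype("float32"))

dense = relay.nn.dense(x, w)
reshaped = relay.reshape(dense, newshape=[1, 3136, 512])
biased = relay.add(b, reshaped)
erf_term = relay.erf(relay.divide(biased, relay.const(1.41421)))
gelu = relay.multiply(relay.multiply(biased, relay.add(erf_term, relay.const(1.0))), relay.const(0.5))
mod = tvm.IRModule.from_expr(relay.Function([x], gelu))

# Move the reshape after the dense/bias/gelu chain so the dnnl.dense_bias_gelu
# composite can be matched, then partition the graph for the DNNL backend.
mod = rewrite_dense_bias_gelu_reshape_last(mod)
mod = partition_for_dnnl(mod)
print(mod["main"])
```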
36 changes: 34 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -436,6 +436,38 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

#else // DNNL JSON runtime

/*!
* \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in
* relu(add(conv2d))
* \param current_call A Relay call node; on the first invocation this is the outermost call of the fused function body.
* \param max_depth The maximum number of calls before the root op, counting from current_call.
* \param root_name The name of the expected "root" op in this fused call.
* \return A CallNode corresponding to the root op
*/
inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth,
const std::string& root_name) {
ICHECK(current_call && max_depth >= 0);

if (max_depth == 0) {
ICHECK(current_call && IsOp(current_call, root_name));
return current_call;
}
if (IsOp(current_call, root_name)) {
return current_call;
}

ICHECK_GT(current_call->args.size(), 0);

size_t valid_node_idx = 0;
while (valid_node_idx < current_call->args.size() &&
current_call->args[valid_node_idx].as<VarNode>()) {
valid_node_idx++;
}

const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
return FindCallWithName(next_call, max_depth - 1, root_name);
}
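
For example, with a fused dnnl.dense_bias_gelu body shaped like Pattern #2 above (the reshape already moved out by the new rewrite pass), the traversal starts at the outermost multiply, skips leading variable arguments at each level (such as the bias parameter of the composite function), and descends through the two multiply calls and the add until it reaches the nn.dense call or max_depth is exhausted.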

class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
@@ -454,6 +486,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
ICHECK_NE(pattern_name, "");
std::vector<std::string> op_list;
size_t pos = 0, start = 0;

while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
std::string op_name = pattern_name.substr(start, pos - start);
if (op_name.find("dnnl") != std::string::npos) {
@@ -508,8 +541,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.dense") != std::string::npos) {
- std::vector<std::string> op_list = ParsingOpList(name);
- call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+ call = FindCallWithName(fn->body.as<CallNode>(), 10, "nn.dense");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else {
LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
9 changes: 7 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,7 +83,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Find proper dnnl::memory buffers
std::unordered_map<int, dnnl::memory> mem_args;
for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);

prim.execute(stream_, mem_args);
}
}
@@ -143,6 +142,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex gelu_pat(".*_gelu.*");

// Parsing post-ops.
dnnl::post_ops ops;
@@ -155,7 +155,12 @@
if (std::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
- attr.set_post_ops(ops);
+ if (std::regex_match(op_name, gelu_pat)) {
+   ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
+ }
+ if (ops.len() != 0) {
+   attr.set_post_ops(ops);
+ }

// Parsing bias_add.
return std::regex_match(op_name, bias_add_pat) ? true : false;