Enable layer normalization in DNNL byoc.
billishyahao committed May 30, 2022
1 parent 559f0c7 commit 98bebf6
Showing 2 changed files with 94 additions and 1 deletion.
python/tvm/relay/op/contrib/dnnl.py (51 additions, 1 deletion)
@@ -41,7 +41,7 @@
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
from .register import register_pattern_table

logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
_register_external_op_helper("nn.layer_norm")


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -526,3 +527,52 @@ def visit_call(self, call):
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod


class LayerNormRewrite(DFPatternCallback):
    """
    A callback to rewrite the following operators into a single layer normalization operator.

    1 %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
    2 %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
    3 %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
    4 %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
    5 %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
    6 %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
    7 %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
    8 %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
    9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
    10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
    """

    def __init__(self):
        super(LayerNormRewrite, self).__init__()
        self.data = wildcard()
        self.eps = wildcard()
        self.gamma = wildcard()
        self.beta = wildcard()
        mu = is_op("mean")(self.data)
        diff = is_op("subtract")(self.data, mu)
        cdiff = diff | is_op("cast")(diff)
        p1 = is_op("power")(cdiff, is_constant())
        mp1 = is_op("mean")(p1)
        added_eps = is_op("add")(mp1, self.eps)
        deno = is_op("sqrt")(added_eps)
        div_out = is_op("divide")(diff, deno)
        weighted = is_op("multiply")(div_out, self.gamma)
        added_bias = is_op("add")(weighted, self.beta)
        self.pattern = added_bias

    def callback(self, pre, post, node_map):
        data = node_map[self.data][0]
        gamma = node_map[self.gamma][0]
        beta = node_map[self.beta][0]
        # The epsilon value matched by `self.eps` is not propagated; the rewritten
        # op always uses the default epsilon of 1e-5.
        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)


def rewrite_layer_norm(mod):
    """Rewrite the input graph to replace the decomposed sequence of operators with a single
    TVM-native layer normalization operator, so that it can be offloaded to the DNNL BYOC
    layer normalization implementation.
    """
    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
    return mod
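
A minimal usage sketch of the new rewrite, not taken from this commit: the decomposed Relay expression, the tensor shapes, and the partitioning pass sequence below are illustrative assumptions about how the fused nn.layer_norm would then be offloaded to DNNL.

# Illustrative sketch (not from the commit): fuse a decomposed layer norm and
# partition the result for the DNNL BYOC backend.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

x = relay.var("x", shape=(1, 3136, 64), dtype="float32")
gamma = relay.const(np.ones((64,), dtype="float32"))
beta = relay.const(np.zeros((64,), dtype="float32"))

# Decomposed layer normalization, as a framework frontend might emit it.
mu = relay.mean(x, axis=[-1], keepdims=True)
diff = relay.subtract(x, mu)
var = relay.mean(relay.power(diff, relay.const(2.0)), axis=[-1], keepdims=True)
norm = relay.divide(diff, relay.sqrt(relay.add(var, relay.const(1e-5))))
out = relay.add(relay.multiply(norm, gamma), beta)

mod = tvm.IRModule.from_expr(out)
mod = dnnl.rewrite_layer_norm(mod)                  # collapses the sequence into nn.layer_norm
mod = relay.transform.AnnotateTarget("dnnl")(mod)   # relies on the nn.layer_norm registration above
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)
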
src/runtime/contrib/dnnl/dnnl_json_runtime.cc (43 additions, 0 deletions)
@@ -718,6 +718,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
          Binary(nid, dnnl::algorithm::binary_add);
        } else if ("multiply" == op_name) {
          Binary(nid, dnnl::algorithm::binary_mul);
        } else if ("nn.layer_norm" == op_name) {
          LayerNorm(nid);
        } else {
          LOG(FATAL) << "Unsupported op: " << op_name;
        }
@@ -1084,6 +1086,47 @@
                         {DNNL_ARG_VARIANCE, variance_memory}});
  }

void LayerNorm(const size_t& nid) {
auto node = nodes_[nid];

auto data_entry = node.GetInputs()[0];
auto gamma_entry = node.GetInputs()[1];
auto beta_entry = node.GetInputs()[2];

dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];

float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);

// Memory description.
dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);

// LN description.
auto lnorm_desc = dnnl::layer_normalization_forward::desc(
dnnl::prop_kind::forward_inference, data_md, epsilon,
dnnl::normalization_flags::use_scale | dnnl::normalization_flags::use_shift);

auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
auto lnorm_prim = dnnl::layer_normalization_forward(lnorm_prim_desc);

net_.push_back(lnorm_prim);

// Memories.
auto data_memory = BindDNNLMemory(data_entry, data_md);
JSONGraphNodeEntry out_entry(nid, 0);
auto dst_memory = BindDNNLMemory(out_entry, data_md);
auto scale_memory = BindDNNLMemory(gamma_entry, data_md);
auto shift_memory = BindDNNLMemory(beta_entry, data_md);
auto mean_memory = dnnl::memory(lnorm_prim_desc.mean_desc(), engine_);
auto variance_memory = dnnl::memory(lnorm_prim_desc.variance_desc(), engine_);

net_args_.push_back({{DNNL_ARG_SRC, data_memory},
{DNNL_ARG_MEAN, mean_memory},
{DNNL_ARG_VARIANCE, variance_memory},
{DNNL_ARG_SCALE, scale_memory},
{DNNL_ARG_SHIFT, shift_memory},
{DNNL_ARG_DST, dst_memory}});
}

void Pooling(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];

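Continuing the Python sketch above, a rough end-to-end check, also illustrative and not part of this commit: it assumes a TVM build with the DNNL BYOC codegen and runtime enabled, so the partitioned "dnnl" functions are dispatched to this JSON runtime, including the new LayerNorm() handler.

# Illustrative only: compile and run the partitioned module from the sketch above;
# requires a TVM build with the DNNL BYOC codegen/runtime enabled.
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")  # `mod` is the partitioned module from the previous sketch

dev = tvm.cpu()
runtime = graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("x", np.random.randn(1, 3136, 64).astype("float32"))
runtime.run()
print(runtime.get_output(0).numpy().shape)  # expected: (1, 3136, 64)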
