From 98bebf65ad0e5a08cbb0993b3bbc57f454a95e60 Mon Sep 17 00:00:00 2001
From: billishyahao
Date: Mon, 30 May 2022 22:05:41 +0800
Subject: [PATCH] Enable layer normalization in DNNL byoc.

---
 python/tvm/relay/op/contrib/dnnl.py           | 52 ++++++++++++++++++-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 43 +++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2e975cf49c885..fe79512edbfa7 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")
 
 
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -526,3 +527,52 @@ def visit_call(self, call):
     new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
     new_mod = transform.RemoveUnusedFunctions()(new_mod)
     return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.eps = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = diff | is_op("cast")(diff)
+        p1 = is_op("power")(cdiff, is_constant())
+        mp1 = is_op("mean")(p1)
+        added_eps = is_op("add")(mp1, self.eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph to replace the decomposed operators above with a single TVM
+    native layer normalization operator, so that it can be offloaded to the DNNL BYOC runtime.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index f6a1c3b790807..0d70f7b56ce81 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -718,6 +718,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
           Binary(nid, dnnl::algorithm::binary_add);
         } else if ("multiply" == op_name) {
           Binary(nid, dnnl::algorithm::binary_mul);
+        } else if ("nn.layer_norm" == op_name) {
+          LayerNorm(nid);
         } else {
           LOG(FATAL) << "Unsupported op: " << op_name;
         }
@@ -1084,6 +1086,47 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_VARIANCE, variance_memory}});
   }
 
+  void LayerNorm(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto data_entry = node.GetInputs()[0];
+    auto gamma_entry = node.GetInputs()[1];
+    auto beta_entry = node.GetInputs()[2];
+
+    dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+
+    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+
+    // Memory descriptions. Gamma and beta are 1-D tensors over the innermost
+    // (normalized) axis, so they get their own descriptor rather than the data one.
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+    dnnl::memory::desc scale_shift_md =
+        GenDNNLMemDescByShape(dnnl::memory::dims{data_shape.back()}, dt::f32);
+
+    // Layer normalization description.
+    auto lnorm_desc = dnnl::layer_normalization_forward::desc(
+        dnnl::prop_kind::forward_inference, data_md, epsilon,
+        dnnl::normalization_flags::use_scale | dnnl::normalization_flags::use_shift);
+
+    auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
+    auto lnorm_prim = dnnl::layer_normalization_forward(lnorm_prim_desc);
+
+    net_.push_back(lnorm_prim);
+
+    // Memories.
+    auto data_memory = BindDNNLMemory(data_entry, data_md);
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto dst_memory = BindDNNLMemory(out_entry, data_md);
+    auto scale_memory = BindDNNLMemory(gamma_entry, scale_shift_md);
+    auto shift_memory = BindDNNLMemory(beta_entry, scale_shift_md);
+    auto mean_memory = dnnl::memory(lnorm_prim_desc.mean_desc(), engine_);
+    auto variance_memory = dnnl::memory(lnorm_prim_desc.variance_desc(), engine_);
+
+    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
+                         {DNNL_ARG_MEAN, mean_memory},
+                         {DNNL_ARG_VARIANCE, variance_memory},
+                         {DNNL_ARG_SCALE, scale_memory},
+                         {DNNL_ARG_SHIFT, shift_memory},
+                         {DNNL_ARG_DST, dst_memory}});
+  }
+
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
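
Usage sketch (not part of the patch): the new rewrite is meant to run before DNNL partitioning,
so that the fused nn.layer_norm is picked up by the op registration added above and dispatched
to LayerNorm() in dnnl_json_runtime.cc. The snippet below assumes the existing partition_for_dnnl
entry point in the same module; mod, params, and build_with_dnnl_layer_norm are illustrative
names, not code from this change.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import partition_for_dnnl, rewrite_layer_norm

    def build_with_dnnl_layer_norm(mod, params, target="llvm"):
        # Fold the decomposed mean/subtract/power/sqrt/divide chain back into nn.layer_norm.
        mod = rewrite_layer_norm(mod)
        # Partition the graph; the fused nn.layer_norm is annotated for the DNNL BYOC runtime.
        mod = partition_for_dnnl(mod, params)
        with tvm.transform.PassContext(opt_level=3):
            return relay.build(mod, target=target, params=params)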