From 010ab29ec3efb0fe3fe83417f5d6c94250336c0d Mon Sep 17 00:00:00 2001 From: billishyahao Date: Tue, 31 May 2022 15:00:49 +0800 Subject: [PATCH 01/10] Enhance dnnl byoc dense operators performance by 1) introducing gelu fusion and 2) introducing alter dense weight layout. --- python/tvm/relay/op/contrib/dnnl.py | 388 +++++++++++++++++- src/relay/backend/contrib/dnnl/codegen.cc | 23 +- .../backend/contrib/dnnl/query_layout.cc | 42 ++ src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 74 +++- 4 files changed, 506 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index c87a7162b070..6631cf8fa10b 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -40,10 +40,15 @@ from tvm.relay.expr import GlobalVar from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.analysis import analysis as _analysis +from tvm.relay import expr as _expr + from ... import _ffi_api from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table +import re + logger = logging.getLogger("DNNL") @@ -144,7 +149,14 @@ def make_dense_pattern(with_bias=True, with_eltwise=None): dense_out = is_op("add")(dense, bias) else: dense_out = dense - if with_eltwise: + if with_eltwise == "gelu": + div = is_op("divide")(dense_out, is_constant()) + erf_val = is_op("erf")(div) + added_erf_val = is_op("add")(erf_val, is_constant()) + mul_val = is_op("multiply")(dense_out, added_erf_val) + dense_out = is_op("multiply")(mul_val, is_constant()) + + elif with_eltwise: dense_out = is_op(with_eltwise)(dense_out) return dense_out @@ -168,22 +180,60 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise): pat_name = op_name.replace("nn", "dnnl") if "_transpose" in op_name: pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::] + if "contrib_dense_pack" in op_name: + pat_name = "dnnl.packeddense" pat_name += "_bias" if with_bias else "" pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else "" if "conv" in op_name: dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise)) elif op_name == "nn.dense": dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise)) + elif op_name == "nn.contrib_dense_pack": + dnnl_pattern = (pat_name, make_packed_dense_pattern(with_bias, with_eltwise)) else: logger.warning( - "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and " - "dense op are supported, but got %s.", + "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, " + "dense and packed dense op are supported, but got %s.", op_name, ) dnnl_pattern = () return dnnl_pattern +def make_packed_dense_pattern(with_bias=True, with_eltwise=None): + """Create patterns related to nn.contrib_dense_pack. + + Parameters + ---------- + with_bias : bool + Whether attach `bias_add` to `nn.dense`. + with_eltwise : str + The attached elementwise post-op name. + Returns + ------- + dense_out : CallPattern + Call node sequence. 
+ """ + data = wildcard() + weight = wildcard() + bias = wildcard() + dense = is_op("nn.contrib_dense_pack")(data, weight) + if with_bias: + dense_out = is_op("add")(dense, bias) + else: + dense_out = dense + if with_eltwise == "gelu": + div = is_op("divide")(dense_out, is_constant()) + erf_val = is_op("erf")(div) + added_erf_val = is_op("add")(erf_val, is_constant()) + mul_val = is_op("multiply")(dense_out, added_erf_val) + dense_out = is_op("multiply")(mul_val, is_constant()) + + elif with_eltwise: + dense_out = is_op(with_eltwise)(dense_out) + return dense_out + + @register_pattern_table("dnnl") def pattern_table(): """Create dnnl patterns. @@ -193,12 +243,12 @@ def pattern_table(): dnnl_patterns : List[dnnl_pattern] Created patterns. """ - elt_list = ["nn.relu", "tanh", "sigmoid", None] + elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None] dnnl_patterns = [] for with_bias in [True, False]: for elt in elt_list: if not with_bias and not elt: - return dnnl_patterns + continue for conv_name in [ "nn.conv1d", "nn.conv2d", @@ -206,8 +256,10 @@ def pattern_table(): "nn.conv2d_transpose", "nn.conv3d_transpose", ]: - dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt)) + if elt != "gelu": + dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt)) dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt)) + dnnl_patterns.append(make_dnnl_pattern("nn.contrib_dense_pack", with_bias, elt)) return dnnl_patterns @@ -280,6 +332,29 @@ def get_optimal_layout_for_conv_transpose( ) +def get_optimal_layout_for_dense( + data_layout, weight_shape, out_shape +): + """Get the optimal layout of dnnl, given shape of dense. + + Parameters + ---------- + data_layout, weight_shape, out_shape + : String + Input argument. + + Returns + ------- + layouts : string + The result. + """ + return _ffi_api.get_optimal_layout_for_dense( + data_layout, + weight_shape, + out_shape, + ) + + def get_shape(tensor): """Get tensor's shape.""" if isinstance(tensor, relay.expr.Var): @@ -310,20 +385,24 @@ def get_dtype(tensor): raise TypeError("Unsupport data type: %s" % type(tensor)) -def tag2layout(input_data, is_weight=False, conv_type="Conv1D"): +def tag2layout(input_data, is_weight=False, op_type="Conv1D"): """Transfer layout, denoted with `a, b, c, d, e`, into valid layout (NCHW / OIHW) of TVM.""" - if "Conv1D" in conv_type: + if "Conv1D" in op_type: data_dic = {"a": "N", "b": "C", "c": "W"} weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"} - elif "Conv2D" in conv_type: + elif "Conv2D" in op_type: data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"} if "e" in input_data: weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"} - elif "Conv3D" in conv_type: + elif "Conv3D" in op_type: data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"} weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"} + elif "Dense" in op_type: + data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} + weight_dic = data_dic + dic = weight_dic if is_weight else data_dic res = "" @@ -339,6 +418,17 @@ def tag2layout(input_data, is_weight=False, conv_type="Conv1D"): res += i else: raise ValueError("Unsupport layout format: %s" % input_data) + + if "Dense" in op_type: + # Post process for dense weight layout + # e.g. 
NC16c64n => NC64n16c + regexN = '\d+n' + regexC = '\d+c' + + matchN = re.findall(regexN, res) + matchC = re.findall(regexC, res) + res = "NC" + "".join(matchN) + "".join(matchC) + return res @@ -387,9 +477,9 @@ def alter_conv(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) - new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) - new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, op_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, op_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, op_type=conv_type) if conv_type == "Conv1D": return relay.nn.conv1d(data, weight, **new_attrs) @@ -426,9 +516,9 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) - new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) - new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, op_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, op_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, op_type=conv_type) if conv_type == "Conv1DTranspose": return relay.nn.conv1d_transpose(data, weight, **new_attrs) @@ -437,6 +527,34 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): return relay.nn.conv3d_transpose(data, weight, **new_attrs) +def alter_dense(attrs, inputs, tinfos, out_type): + """The packed dense's layout auto-query func for dnnl.""" + + data, weight = inputs + + weight_shape_list = [str(x) for x in get_shape(weight)] + out_shape_list = [str(x) for x in get_shape(out_type)] + + data_shape = ",".join([out_shape_list[0], weight_shape_list[1]]) + weight_shape = ",".join(weight_shape_list) + out_shape = ",".join(out_shape_list) + + res = get_optimal_layout_for_dense( + data_shape, + weight_shape, + out_shape + ) + + _, weight_df, _ = res.split(",") + + new_attrs = {} + new_attrs["weight_layout"] = tag2layout(weight_df, is_weight=True, op_type="Dense") + + weight_transform = relay.layout_transform(weight, "NC", dst_layout=new_attrs["weight_layout"]) + return relay.nn.contrib_dense_pack(data, weight_transform, weight_layout=new_attrs["weight_layout"], + units=None, out_dtype=out_type.dtype) + + class IsComputeIntensiveGraph(ExprVisitor): """ Visits the Graph recursively and checks if it contains compute heavy ops like convolutions and @@ -594,3 +712,241 @@ def rewrite_layer_norm(mod): """ mod["main"] = rewrite(LayerNormRewrite(), mod["main"]) return mod + + +class DenseReshapeBiasGeluRewrite(DFPatternCallback): + """ + A callback to reorder reshape operators when the patten is as below: + 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; + 2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */; + 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */; + 4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; + 5 %80 = erf(%79) /* 
ty=Tensor[(1, 3136, 512), float32] */; + 6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; + 7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */; + 8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; + """ + + def __init__(self, pack_wei=False): + super(DenseReshapeBiasGeluRewrite, self).__init__() + self.data = wildcard() + self.weight = wildcard() + self.bias = wildcard() + self.const1 = wildcard() + self.const2 = wildcard() + self.const3 = wildcard() + + self.pack_wei = pack_wei + + self.attr_map = {} + + den = is_op("nn.dense")(self.data, self.weight) + re_den = is_op("reshape")(den) + added = is_op("add")(self.bias, re_den) + divisor = is_op("divide")(added, self.const1) + val_erf = is_op("erf")(divisor) + added_erf = is_op("add")(val_erf, self.const2) + mul1 = is_op("multiply")(added, added_erf) + mul2 = is_op("multiply")(mul1, self.const3) + self.pattern = mul2 + + def get_attr(self, pre): + def visit_func(expr): + if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): + new_attrs = {} + for k in expr.attrs.keys(): + new_attrs[k] = expr.attrs[k] + self.attr_map["reshape"] = new_attrs + elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): + new_attrs = {} + for k in expr.attrs.keys(): + new_attrs[k] = expr.attrs[k] + self.attr_map["nn.dense"] = new_attrs + + _analysis.post_order_visit(pre, visit_func) + + def callback(self, pre, post, node_map): + self.get_attr(pre) + + data = node_map[self.data][0] + weight = node_map[self.weight][0] + bias = node_map[self.bias][0] + const1 = node_map[self.const1][0] + const2 = node_map[self.const2][0] + const3 = node_map[self.const3][0] + + if self.pack_wei: + weight_shape_list = [str(x) for x in get_shape(weight)] + data_shape_list = [str(x) for x in get_shape(data)] + + data_shape = ",".join(data_shape_list) + weight_shape = ",".join(weight_shape_list) + out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) + + res = get_optimal_layout_for_dense( + data_shape, + weight_shape, + out_shape + ) + + _, weight_df, _ = res.split(",") + reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + + weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + + den = relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, + units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") + else: + den = relay.op.nn.dense(data, weight) + added = relay.op.add(bias, den) + divisor = relay.op.divide(added, const1) + val_erf = relay.op.erf(divisor) + added_erf = relay.op.add(val_erf, const2) + mul1 = relay.op.multiply(added, added_erf) + mul2 = relay.op.multiply(mul1, const3) + return relay.op.reshape(mul2, self.attr_map['reshape']['newshape']) + + +def rewrite_dense_bias_gelu_reshape_last(mod, pack_wei=False): + """Rewrite the input graph to reorder reshape operators so that + we can perform dense_bias_gelu fusion and then offload them to byoc part. 
+ """ + mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(pack_wei), mod["main"]) + return mod + + +class DenseReshapeBiasRewrite(DFPatternCallback): + """ + A callback to reorder reshape operators when the patten is as below: + 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; + 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; + 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */; + """ + + def __init__(self, pack_wei=False): + super(DenseReshapeBiasRewrite, self).__init__() + self.data = wildcard() + self.weight = wildcard() + self.bias = wildcard() + + self.pack_wei = pack_wei + self.attr_map = {} + + den = is_op("nn.dense")(self.data, self.weight) + re_den = is_op("reshape")(den) + added = is_op("add")(self.bias, re_den) + self.pattern = added + + def get_attr(self, pre): + def visit_func(expr): + if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): + new_attrs = {} + for k in expr.attrs.keys(): + new_attrs[k] = expr.attrs[k] + self.attr_map["reshape"] = new_attrs + elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): + new_attrs = {} + for k in expr.attrs.keys(): + new_attrs[k] = expr.attrs[k] + self.attr_map["nn.dense"] = new_attrs + + _analysis.post_order_visit(pre, visit_func) + + def callback(self, pre, post, node_map): + self.get_attr(pre) + + data = node_map[self.data][0] + weight = node_map[self.weight][0] + bias = node_map[self.bias][0] + + if self.pack_wei: + weight_shape_list = [str(x) for x in get_shape(weight)] + data_shape_list = [str(x) for x in get_shape(data)] + + data_shape = ",".join(data_shape_list) + weight_shape = ",".join(weight_shape_list) + out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) + + res = get_optimal_layout_for_dense( + data_shape, + weight_shape, + out_shape + ) + + _, weight_df, _ = res.split(",") + reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + + den = relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, + units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") + else: + den = relay.op.nn.dense(data, weight) + added = relay.op.add(bias, den) + return relay.op.reshape(added, self.attr_map['reshape']['newshape']) + + +def rewrite_dense_bias_reshape_last(mod, pack_wei=False): + """Rewrite the input graph to reorder reshape operators so that + we can perform dense_bias fusion and then offload them to byoc part. 
+ """ + mod["main"] = rewrite(DenseReshapeBiasRewrite(pack_wei), mod["main"]) + return mod + + +class PackDenseRewrite(DFPatternCallback): + """A callback to rewrite nn.dense to nn.contrib_dense_pack.""" + + def __init__(self): + super(PackDenseRewrite, self).__init__() + self.data = wildcard() + self.weight = wildcard() + + self.attr_map = {} + + den = is_op("nn.dense")(self.data, self.weight) + self.pattern = den + + def get_attr(self, pre): + def visit_func(expr): + if isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): + new_attrs = {} + for k in expr.attrs.keys(): + new_attrs[k] = expr.attrs[k] + self.attr_map["nn.dense"] = new_attrs + + _analysis.post_order_visit(pre, visit_func) + + def callback(self, pre, post, node_map): + self.get_attr(pre) + + data = node_map[self.data][0] + weight = node_map[self.weight][0] + + weight_shape_list = [str(x) for x in get_shape(weight)] + data_shape_list = [str(x) for x in get_shape(data)] + + data_shape = ",".join(data_shape_list) + weight_shape = ",".join(weight_shape_list) + out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) + + res = get_optimal_layout_for_dense( + data_shape, + weight_shape, + out_shape + ) + + _, weight_df, _ = res.split(",") + + reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + + weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + return relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, + units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") + + +def rewrite_dense_to_pack(mod): + """Rewrite the input graph to use packed dense operators so that + we can gain better performance boost in dnnl byoc part. 
+ """ + mod["main"] = rewrite(PackDenseRewrite(), mod["main"]) + return mod diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 41480ed33b0a..61e19dd342ba 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -447,6 +447,9 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { {"sigmoid", "sigmoid"}, {"nn.deconv2d", "nn.conv2d_transpose"}, {"nn.deconv3d", "nn.conv3d_transpose"}, + {"add", "add"}, + {"multiply", "multiply"}, + {"nn.packeddense", "nn.contrib_dense_pack"}, }; std::vector ParsingOpList(const std::string& pattern_name, @@ -454,12 +457,20 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { ICHECK_NE(pattern_name, ""); std::vector op_list; size_t pos = 0, start = 0; - while ((pos = pattern_name.find(interval, start)) != std::string::npos) { - std::string op_name = pattern_name.substr(start, pos - start); + + std::string raw_name = pattern_name; + if (raw_name.find("gelu") != std::string::npos) { + //TODO(billishyahao): Remove me after introducing new gelu operator + raw_name.replace(raw_name.find("gelu"), 4, "multiply_multiply_"); + } + while ((pos = raw_name.find(interval, start)) != std::string::npos) { + std::string op_name = raw_name.substr(start, pos - start); if (op_name.find("dnnl") != std::string::npos) { op_name.replace(op_name.find("dnnl"), 4, "nn"); if (op_name.find("deconv") != std::string::npos) { op_name = op_map[op_name]; + } else if (op_name.find("packeddense") != std::string::npos) { + op_name = op_map[op_name]; } } else { op_name = op_map[op_name]; @@ -467,8 +478,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { if (pos > start) op_list.push_back(op_name); start = pos + interval.size(); } - if (pattern_name.size() > start) { - op_list.push_back(op_map[pattern_name.substr(start)]); + if (raw_name.size() > start) { + op_list.push_back(op_map[raw_name.substr(start)]); } return op_list; } @@ -511,6 +522,10 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; + } else if (name.find("dnnl.packeddense") != std::string::npos) { + std::vector op_list = ParsingOpList(name); + call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); + ICHECK(call->op.as()) << "Not op node"; } else { LOG(FATAL) << "Unrecognized DNNL pattern: " << name; } diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc index 3762c1906f40..4bf5e4d27fa7 100755 --- a/src/relay/backend/contrib/dnnl/query_layout.cc +++ b/src/relay/backend/contrib/dnnl/query_layout.cc @@ -362,6 +362,43 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, return res; } +std::string get_optimal_layout_for_dense(std::string data_layout, std::string weight_shape, + std::string out_shape) { + dnnl::engine eng(dnnl::engine::kind::cpu, 0); + dnnl::stream s(eng); + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + dnnl::memory::dims data_dims = str2dims(data_layout); + dnnl::memory::dims weight_dims = str2dims(weight_shape); + dnnl::memory::dims out_dims = str2dims(out_shape); + dnnl::memory::dims bias_dims = {out_dims[1]}; + + // Memory descriptions. 
+ auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::any}); + auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::any}); + auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::any}); + auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::any}); + + // Dense description. + auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, + weight_md, bias_md, dst_md); + + dnnl::primitive_attr attr; + auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, eng); + + auto src_format = dense_prim_desc.src_desc(); + auto weights_format = dense_prim_desc.weights_desc(); + auto dst_format = dense_prim_desc.dst_desc(); + std::string src_df, weight_df, dst_df; + + src_df = md2fmt_tag_str(&src_format); + weight_df = md2fmt_tag_str(&weights_format); + dst_df = md2fmt_tag_str(&dst_format); + std::string res = src_df + "," + weight_df + "," + dst_df; + return res; +} + TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5], @@ -374,6 +411,11 @@ TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose") args[5], args[6], args[7], args[8], args[9]); }); +TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_dense") + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = get_optimal_layout_for_dense(args[0], args[1], args[2]); + }); + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index db8f25e2a6ea..ef0fad0015b5 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -143,6 +143,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex relu_pat(".*_relu.*"); std::regex tanh_pat(".*_tanh.*"); std::regex sigmoid_pat(".*_sigmoid.*"); + std::regex gelu_pat(".*_gelu.*"); // Parsing post-ops. dnnl::post_ops ops; @@ -155,7 +156,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (std::regex_match(op_name, sigmoid_pat)) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f); } - attr.set_post_ops(ops); + if (std::regex_match(op_name, gelu_pat)) { + ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + } + if (ops.len() != 0){ + attr.set_post_ops(ops); + } // Parsing bias_add. return std::regex_match(op_name, bias_add_pat) ? true : false; @@ -173,6 +179,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex deconv_pat(".*deconv[1-3]d.*"); std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); std::regex dense_pat(".*dense.*"); + std::regex dense_pack_pat(".*packeddense.*"); std::regex max_pool_pat(".*max_pool[1-3]d"); std::regex avg_pool_pat(".*avg_pool[1-3]d"); @@ -187,6 +194,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Deconvolution(nid); } else if (std::regex_match(op_name, conv_pat)) { Convolution(nid); + } else if (std::regex_match(op_name, dense_pack_pat) || + "nn.contrib_dense_pack" == op_name) { + DensePack(nid); } else if (std::regex_match(op_name, dense_pat)) { Dense(nid); } else if ("nn.batch_norm" == op_name) { @@ -405,6 +415,68 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {DNNL_ARG_DST, dst_tr}}); } + void DensePack(const size_t& nid) { + auto node = nodes_[nid]; + auto op_name = node.GetOpName(); + dnnl::primitive_attr attr; + bool has_bias = ParsingOpName(op_name, attr); + + // Setup attributes. 
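+    // The weight tensor reaching this node has already been blocked by the
+    // relay.layout_transform inserted on the Python side; its layout string
+    // (e.g. "NC64n16c") is read from the "weight_layout" attribute below and
+    // mapped to a dnnl memory::format_tag via layout2tag().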
+ auto data_entry = node.GetInputs()[0]; + auto weight_entry = node.GetInputs()[1]; + JSONGraphNodeEntry out_entry(nid, 0); + dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; + dnnl::memory::dim OC = out_shape[1]; + + std::string weight_layout = node.GetAttr>("weight_layout")[0]; + + // Memory shapes. + dnnl::memory::dims data_dims = input_shape; + dnnl::memory::dims weight_dims = {out_shape[1], input_shape[1]}; + dnnl::memory::dims bias_dims = {OC}; + dnnl::memory::dims out_dims = out_shape; + + // Memory descriptions. + auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); + auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, layout2tag(weight_layout)}); + auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); + auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc}); + + // Dense description. + auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, + weight_md, bias_md, dst_md); + + // Enable elementwise post-ops. + auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); + + auto dense = dnnl::inner_product_forward(dense_prim_desc); + net_.push_back(dense); + + // Memories. + auto data_memory = BindDNNLMemory(data_entry, data_md); + auto weight_memory = BindDNNLMemory(weight_entry, weight_md); + + // Bias memory. + auto bias_memory = dnnl::memory(bias_md, engine_); + if (has_bias) { + auto bias_entry = node.GetInputs()[2]; + BindDNNLMemory(bias_entry, bias_memory); + } else { + float bias[OC] = {0}; + write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float)); + } + + // Output memory. + auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); + + net_args_.push_back({{DNNL_ARG_SRC, data_memory}, + {DNNL_ARG_WEIGHTS, weight_memory}, + {DNNL_ARG_BIAS, bias_memory}, + {DNNL_ARG_DST, dst_memory}}); + } + void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; From 8d9b0856eaa7f3228fa9d66b832f15425c40d54a Mon Sep 17 00:00:00 2001 From: billishyahao Date: Tue, 31 May 2022 23:07:44 +0800 Subject: [PATCH 02/10] fix lint issue --- python/tvm/relay/op/contrib/dnnl.py | 139 +++++++++++++++------------- 1 file changed, 73 insertions(+), 66 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 6631cf8fa10b..ae994243f1f5 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -332,9 +332,7 @@ def get_optimal_layout_for_conv_transpose( ) -def get_optimal_layout_for_dense( - data_layout, weight_shape, out_shape -): +def get_optimal_layout_for_dense(data_layout, weight_shape, out_shape): """Get the optimal layout of dnnl, given shape of dense. Parameters @@ -402,7 +400,6 @@ def tag2layout(input_data, is_weight=False, op_type="Conv1D"): elif "Dense" in op_type: data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} weight_dic = data_dic - dic = weight_dic if is_weight else data_dic res = "" @@ -419,15 +416,15 @@ def tag2layout(input_data, is_weight=False, op_type="Conv1D"): else: raise ValueError("Unsupport layout format: %s" % input_data) - if "Dense" in op_type: - # Post process for dense weight layout + if "Dense" in op_type: + # Post process for dense weight layout # e.g. 
NC16c64n => NC64n16c - regexN = '\d+n' - regexC = '\d+c' + regexN = "\d+n" + regexC = "\d+c" matchN = re.findall(regexN, res) matchC = re.findall(regexC, res) - res = "NC" + "".join(matchN) + "".join(matchC) + res = "NC" + "".join(matchN) + "".join(matchC) return res @@ -539,20 +536,21 @@ def alter_dense(attrs, inputs, tinfos, out_type): weight_shape = ",".join(weight_shape_list) out_shape = ",".join(out_shape_list) - res = get_optimal_layout_for_dense( - data_shape, - weight_shape, - out_shape - ) + res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") - + new_attrs = {} new_attrs["weight_layout"] = tag2layout(weight_df, is_weight=True, op_type="Dense") - + weight_transform = relay.layout_transform(weight, "NC", dst_layout=new_attrs["weight_layout"]) - return relay.nn.contrib_dense_pack(data, weight_transform, weight_layout=new_attrs["weight_layout"], - units=None, out_dtype=out_type.dtype) + return relay.nn.contrib_dense_pack( + data, + weight_transform, + weight_layout=new_attrs["weight_layout"], + units=None, + out_dtype=out_type.dtype, + ) class IsComputeIntensiveGraph(ExprVisitor): @@ -716,7 +714,7 @@ def rewrite_layer_norm(mod): class DenseReshapeBiasGeluRewrite(DFPatternCallback): """ - A callback to reorder reshape operators when the patten is as below: + A callback to reorder reshape operators when the patten is as below: 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; 2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */; 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */; @@ -739,7 +737,7 @@ def __init__(self, pack_wei=False): self.pack_wei = pack_wei self.attr_map = {} - + den = is_op("nn.dense")(self.data, self.weight) re_den = is_op("reshape")(den) added = is_op("add")(self.bias, re_den) @@ -762,7 +760,7 @@ def visit_func(expr): for k in expr.attrs.keys(): new_attrs[k] = expr.attrs[k] self.attr_map["nn.dense"] = new_attrs - + _analysis.post_order_visit(pre, visit_func) def callback(self, pre, post, node_map): @@ -782,20 +780,23 @@ def callback(self, pre, post, node_map): data_shape = ",".join(data_shape_list) weight_shape = ",".join(weight_shape_list) out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense( - data_shape, - weight_shape, - out_shape - ) + + res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - + weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) - - den = relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, - units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") + + den = relay.op.nn.contrib_dense_pack( + data, + weight_transform, + weight_layout=reco_weight_layout, + units=None, + out_dtype=self.attr_map["nn.dense"]["out_dtype"] + if "out_dtype" in self.attr_map["nn.dense"] + else "", + ) else: den = relay.op.nn.dense(data, weight) added = relay.op.add(bias, den) @@ -804,11 +805,11 @@ def callback(self, pre, post, node_map): added_erf = relay.op.add(val_erf, const2) mul1 = relay.op.multiply(added, added_erf) mul2 = relay.op.multiply(mul1, const3) - return relay.op.reshape(mul2, 
self.attr_map['reshape']['newshape']) + return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"]) def rewrite_dense_bias_gelu_reshape_last(mod, pack_wei=False): - """Rewrite the input graph to reorder reshape operators so that + """Rewrite the input graph to reorder reshape operators so that we can perform dense_bias_gelu fusion and then offload them to byoc part. """ mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(pack_wei), mod["main"]) @@ -817,7 +818,7 @@ def rewrite_dense_bias_gelu_reshape_last(mod, pack_wei=False): class DenseReshapeBiasRewrite(DFPatternCallback): """ - A callback to reorder reshape operators when the patten is as below: + A callback to reorder reshape operators when the patten is as below: 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */; @@ -828,10 +829,10 @@ def __init__(self, pack_wei=False): self.data = wildcard() self.weight = wildcard() self.bias = wildcard() - + self.pack_wei = pack_wei self.attr_map = {} - + den = is_op("nn.dense")(self.data, self.weight) re_den = is_op("reshape")(den) added = is_op("add")(self.bias, re_den) @@ -849,7 +850,7 @@ def visit_func(expr): for k in expr.attrs.keys(): new_attrs[k] = expr.attrs[k] self.attr_map["nn.dense"] = new_attrs - + _analysis.post_order_visit(pre, visit_func) def callback(self, pre, post, node_map): @@ -858,7 +859,7 @@ def callback(self, pre, post, node_map): data = node_map[self.data][0] weight = node_map[self.weight][0] bias = node_map[self.bias][0] - + if self.pack_wei: weight_shape_list = [str(x) for x in get_shape(weight)] data_shape_list = [str(x) for x in get_shape(data)] @@ -866,28 +867,31 @@ def callback(self, pre, post, node_map): data_shape = ",".join(data_shape_list) weight_shape = ",".join(weight_shape_list) out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense( - data_shape, - weight_shape, - out_shape - ) + + res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) - den = relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, - units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") + den = relay.op.nn.contrib_dense_pack( + data, + weight_transform, + weight_layout=reco_weight_layout, + units=None, + out_dtype=self.attr_map["nn.dense"]["out_dtype"] + if "out_dtype" in self.attr_map["nn.dense"] + else "", + ) else: den = relay.op.nn.dense(data, weight) added = relay.op.add(bias, den) - return relay.op.reshape(added, self.attr_map['reshape']['newshape']) + return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) def rewrite_dense_bias_reshape_last(mod, pack_wei=False): - """Rewrite the input graph to reorder reshape operators so that - we can perform dense_bias fusion and then offload them to byoc part. + """Rewrite the input graph to reorder reshape operators so that + we can perform dense_bias fusion and then offload them to byoc part. 
""" mod["main"] = rewrite(DenseReshapeBiasRewrite(pack_wei), mod["main"]) return mod @@ -900,9 +904,9 @@ def __init__(self): super(PackDenseRewrite, self).__init__() self.data = wildcard() self.weight = wildcard() - + self.attr_map = {} - + den = is_op("nn.dense")(self.data, self.weight) self.pattern = den @@ -913,7 +917,7 @@ def visit_func(expr): for k in expr.attrs.keys(): new_attrs[k] = expr.attrs[k] self.attr_map["nn.dense"] = new_attrs - + _analysis.post_order_visit(pre, visit_func) def callback(self, pre, post, node_map): @@ -921,32 +925,35 @@ def callback(self, pre, post, node_map): data = node_map[self.data][0] weight = node_map[self.weight][0] - + weight_shape_list = [str(x) for x in get_shape(weight)] data_shape_list = [str(x) for x in get_shape(data)] data_shape = ",".join(data_shape_list) weight_shape = ",".join(weight_shape_list) out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense( - data_shape, - weight_shape, - out_shape - ) + + res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") - + reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - + weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) - return relay.op.nn.contrib_dense_pack(data, weight_transform, weight_layout=reco_weight_layout, - units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if 'out_dtype' in self.attr_map['nn.dense'] else "") - + return relay.op.nn.contrib_dense_pack( + data, + weight_transform, + weight_layout=reco_weight_layout, + units=None, + out_dtype=self.attr_map["nn.dense"]["out_dtype"] + if "out_dtype" in self.attr_map["nn.dense"] + else "", + ) + def rewrite_dense_to_pack(mod): - """Rewrite the input graph to use packed dense operators so that - we can gain better performance boost in dnnl byoc part. + """Rewrite the input graph to use packed dense operators so that + we can gain better performance boost in dnnl byoc part. """ mod["main"] = rewrite(PackDenseRewrite(), mod["main"]) return mod From 05d53f0cd720a83d3530de14462e99a9b90a8b1c Mon Sep 17 00:00:00 2001 From: billishyahao Date: Thu, 2 Jun 2022 14:58:28 +0800 Subject: [PATCH 03/10] add unittest for dense pack --- python/tvm/relay/op/contrib/dnnl.py | 46 +++++++++---------- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 4 +- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index ae994243f1f5..d77c4f96f752 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -418,15 +418,16 @@ def tag2layout(input_data, is_weight=False, op_type="Conv1D"): if "Dense" in op_type: # Post process for dense weight layout - # e.g. NC16c64n => NC64n16c + # e.g. 
NC16c64n => NC64n regexN = "\d+n" regexC = "\d+c" matchN = re.findall(regexN, res) - matchC = re.findall(regexC, res) - res = "NC" + "".join(matchN) + "".join(matchC) + layout_fmt = "NC" + "".join(matchN) + full_layout_fmt = res + return layout_fmt, full_layout_fmt - return res + return res, res def legalize_group_conv(attrs, inputs, types): @@ -474,9 +475,9 @@ def alter_conv(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, op_type=conv_type) - new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, op_type=conv_type) - new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, op_type=conv_type) + new_attrs["data_layout"], _ = tag2layout(src_df, is_weight=False, op_type=conv_type) + new_attrs["kernel_layout"], _ = tag2layout(weight_df, is_weight=True, op_type=conv_type) + new_attrs["out_layout"], _ = tag2layout(dst_df, is_weight=False, op_type=conv_type) if conv_type == "Conv1D": return relay.nn.conv1d(data, weight, **new_attrs) @@ -513,9 +514,9 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, op_type=conv_type) - new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, op_type=conv_type) - new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, op_type=conv_type) + new_attrs["data_layout"], _ = tag2layout(src_df, is_weight=False, op_type=conv_type) + new_attrs["kernel_layout"], _ = tag2layout(weight_df, is_weight=True, op_type=conv_type) + new_attrs["out_layout"], _ = tag2layout(dst_df, is_weight=False, op_type=conv_type) if conv_type == "Conv1DTranspose": return relay.nn.conv1d_transpose(data, weight, **new_attrs) @@ -540,14 +541,13 @@ def alter_dense(attrs, inputs, tinfos, out_type): _, weight_df, _ = res.split(",") - new_attrs = {} - new_attrs["weight_layout"] = tag2layout(weight_df, is_weight=True, op_type="Dense") + wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - weight_transform = relay.layout_transform(weight, "NC", dst_layout=new_attrs["weight_layout"]) + weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) return relay.nn.contrib_dense_pack( data, weight_transform, - weight_layout=new_attrs["weight_layout"], + weight_layout=full_wei_layout, units=None, out_dtype=out_type.dtype, ) @@ -784,14 +784,14 @@ def callback(self, pre, post, node_map): res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") - reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) den = relay.op.nn.contrib_dense_pack( data, weight_transform, - weight_layout=reco_weight_layout, + weight_layout=full_wei_layout, units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if "out_dtype" in self.attr_map["nn.dense"] @@ -871,13 +871,13 @@ def callback(self, pre, post, node_map): res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) _, weight_df, _ = res.split(",") - reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + wei_layout, 
full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) den = relay.op.nn.contrib_dense_pack( data, weight_transform, - weight_layout=reco_weight_layout, + weight_layout=full_wei_layout, units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if "out_dtype" in self.attr_map["nn.dense"] @@ -937,13 +937,13 @@ def callback(self, pre, post, node_map): _, weight_df, _ = res.split(",") - reco_weight_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") + wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - weight_transform = relay.layout_transform(weight, "NC", dst_layout=reco_weight_layout) + weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) return relay.op.nn.contrib_dense_pack( data, weight_transform, - weight_layout=reco_weight_layout, + weight_layout=full_wei_layout, units=None, out_dtype=self.attr_map["nn.dense"]["out_dtype"] if "out_dtype" in self.attr_map["nn.dense"] diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index ef0fad0015b5..e941fe750fe3 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -439,6 +439,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dims out_dims = out_shape; // Memory descriptions. + auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto dtype = dtype_dl2dnnl(dl_dtype); auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, layout2tag(weight_layout)}); auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); @@ -465,7 +467,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { BindDNNLMemory(bias_entry, bias_memory); } else { float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float)); + write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8)); } // Output memory. 
From 60304035dc476fcd32ea4f2ee6e519760e862855 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Wed, 8 Jun 2022 00:56:53 +0800 Subject: [PATCH 04/10] Make code compatible after introducing TensorRequisite(PR-11345) --- python/tvm/relay/op/contrib/dnnl.py | 14 +++- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 65 +------------------ tests/python/contrib/test_dnnl.py | 32 +++++++++ 3 files changed, 44 insertions(+), 67 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index d77c4f96f752..357b20b4eebb 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -150,11 +150,19 @@ def make_dense_pattern(with_bias=True, with_eltwise=None): else: dense_out = dense if with_eltwise == "gelu": - div = is_op("divide")(dense_out, is_constant()) + const1 = wildcard() + const2 = wildcard() + const3 = wildcard() + # div = is_op("divide")(dense_out, is_constant()) + # erf_val = is_op("erf")(div) + # added_erf_val = is_op("add")(erf_val, is_constant()) + # mul_val = is_op("multiply")(dense_out, added_erf_val) + # dense_out = is_op("multiply")(mul_val, is_constant()) + div = is_op("divide")(dense_out, const1) erf_val = is_op("erf")(div) - added_erf_val = is_op("add")(erf_val, is_constant()) + added_erf_val = is_op("add")(erf_val, const2) mul_val = is_op("multiply")(dense_out, added_erf_val) - dense_out = is_op("multiply")(mul_val, is_constant()) + dense_out = is_op("multiply")(mul_val, const3) elif with_eltwise: dense_out = is_op(with_eltwise)(dense_out) diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index e941fe750fe3..9040bd1d0d3e 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -196,7 +196,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Convolution(nid); } else if (std::regex_match(op_name, dense_pack_pat) || "nn.contrib_dense_pack" == op_name) { - DensePack(nid); + // DensePack(nid); } else if (std::regex_match(op_name, dense_pat)) { Dense(nid); } else if ("nn.batch_norm" == op_name) { @@ -415,69 +415,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {DNNL_ARG_DST, dst_tr}}); } - void DensePack(const size_t& nid) { - auto node = nodes_[nid]; - auto op_name = node.GetOpName(); - dnnl::primitive_attr attr; - bool has_bias = ParsingOpName(op_name, attr); - - // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim OC = out_shape[1]; - - std::string weight_layout = node.GetAttr>("weight_layout")[0]; - - // Memory shapes. - dnnl::memory::dims data_dims = input_shape; - dnnl::memory::dims weight_dims = {out_shape[1], input_shape[1]}; - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = out_shape; - - // Memory descriptions. 
- auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; - auto dtype = dtype_dl2dnnl(dl_dtype); - auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, layout2tag(weight_layout)}); - auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc}); - - // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); - - // Enable elementwise post-ops. - auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); - - auto dense = dnnl::inner_product_forward(dense_prim_desc); - net_.push_back(dense); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - auto weight_memory = BindDNNLMemory(weight_entry, weight_md); - - // Bias memory. - auto bias_memory = dnnl::memory(bias_md, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8)); - } - - // Output memory. - auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_WEIGHTS, weight_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST, dst_memory}}); - } void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 3e4e831aa594..6dca99da0a66 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -19,6 +19,7 @@ import numpy as np import sys import subprocess +import math import tvm from tvm import relay @@ -121,6 +122,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): transform.PartitionGraph(), ] ) + with tvm.transform.PassContext(opt_level=3): mod = byoc_seq(mod) mod = dnnl.prune_dnnl_subgraphs(mod) @@ -170,6 +172,7 @@ def check_dnnl_used(mod, subgraph_num=None): ] if test_bf16 and bf16_supported(): configs += [(True, False, True), (True, True, True)] + for use_dnnl, alter_layout, use_bf16 in configs: result_key = ( mode @@ -185,6 +188,8 @@ def check_dnnl_used(mod, subgraph_num=None): continue if use_dnnl: processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) + print("hebi-dbg: processed_mod: ", result_key) + print(processed_mod) check_dnnl_used(processed_mod) with tvm.transform.PassContext(opt_level=3): @@ -194,6 +199,10 @@ def check_dnnl_used(mod, subgraph_num=None): if run_module: if isinstance(input, dict): result_dict[result_key] = func(**input, **params) + print("input:", input) + print("params:", params) + print("result_dict[result_key]:") + print(result_dict[result_key]) else: result_dict[result_key] = func(input, **params) @@ -585,12 +594,27 @@ def get_conv3d_transpose_bias( return out, dic, param_lst +def gelu_helper(data): + const1 = relay.const(math.sqrt(2.0)) + const2 = relay.const(1.0) + const3 = relay.const(0.5) + divisor = relay.op.divide(data, const1) + val_erf = relay.op.erf(divisor) + added_erf = relay.op.add(val_erf, const2) + mul1 = relay.op.multiply(data, added_erf) + out = relay.op.multiply(mul1, const3) + return out + + def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): x = relay.var("x", shape=(x_shape), dtype=dtype) kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) out = 
relay.nn.dense(x, kernel, units=k_shape[0]) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] + + if activation == "gelu": + out = gelu_helper(out) return out, dic, param_lst @@ -600,6 +624,9 @@ def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="fl out = relay.nn.bias_add(dense, bias) dic["bias"] = (k_shape[0],) param_lst += ["bias"] + + if activation == "gelu": + out = gelu_helper(out) return out, dic, param_lst @@ -906,6 +933,11 @@ def test_dense_pattern(run_module, dtype="float32"): config = dense_bias, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) + dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype) + dense_bias = tvm.IRModule.from_expr(dense_bias) + config = dense_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + def test_pool2d(run_module, dtype="float32"): def get_graph( From 54ae4d01fcfa3255085685809d30b07e2033bb24 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Wed, 8 Jun 2022 10:40:35 +0800 Subject: [PATCH 05/10] Fix comments & refactor code --- python/tvm/relay/op/contrib/dnnl.py | 269 ++---------------- src/relay/backend/contrib/dnnl/codegen.cc | 57 ++-- .../backend/contrib/dnnl/query_layout.cc | 42 --- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 6 - tests/python/contrib/test_dnnl.py | 11 +- 5 files changed, 69 insertions(+), 316 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 357b20b4eebb..ce7ae1fea2f7 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,6 +33,7 @@ check the attributes of the op and decide if it should be offloaded to DNNL. """ import logging +import math import tvm.ir from tvm import relay @@ -43,6 +44,7 @@ from tvm.relay.analysis import analysis as _analysis from tvm.relay import expr as _expr + from ... 
import _ffi_api from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table @@ -144,6 +146,7 @@ def make_dense_pattern(with_bias=True, with_eltwise=None): data = wildcard() weight = wildcard() bias = wildcard() + dense = is_op("nn.dense")(data, weight) if with_bias: dense_out = is_op("add")(dense, bias) @@ -153,17 +156,11 @@ def make_dense_pattern(with_bias=True, with_eltwise=None): const1 = wildcard() const2 = wildcard() const3 = wildcard() - # div = is_op("divide")(dense_out, is_constant()) - # erf_val = is_op("erf")(div) - # added_erf_val = is_op("add")(erf_val, is_constant()) - # mul_val = is_op("multiply")(dense_out, added_erf_val) - # dense_out = is_op("multiply")(mul_val, is_constant()) div = is_op("divide")(dense_out, const1) erf_val = is_op("erf")(div) added_erf_val = is_op("add")(erf_val, const2) mul_val = is_op("multiply")(dense_out, added_erf_val) dense_out = is_op("multiply")(mul_val, const3) - elif with_eltwise: dense_out = is_op(with_eltwise)(dense_out) return dense_out @@ -188,60 +185,22 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise): pat_name = op_name.replace("nn", "dnnl") if "_transpose" in op_name: pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::] - if "contrib_dense_pack" in op_name: - pat_name = "dnnl.packeddense" pat_name += "_bias" if with_bias else "" pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else "" if "conv" in op_name: dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise)) elif op_name == "nn.dense": dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise)) - elif op_name == "nn.contrib_dense_pack": - dnnl_pattern = (pat_name, make_packed_dense_pattern(with_bias, with_eltwise)) else: logger.warning( "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, " - "dense and packed dense op are supported, but got %s.", + "dense op are supported, but got %s.", op_name, ) dnnl_pattern = () return dnnl_pattern -def make_packed_dense_pattern(with_bias=True, with_eltwise=None): - """Create patterns related to nn.contrib_dense_pack. - - Parameters - ---------- - with_bias : bool - Whether attach `bias_add` to `nn.dense`. - with_eltwise : str - The attached elementwise post-op name. - Returns - ------- - dense_out : CallPattern - Call node sequence. - """ - data = wildcard() - weight = wildcard() - bias = wildcard() - dense = is_op("nn.contrib_dense_pack")(data, weight) - if with_bias: - dense_out = is_op("add")(dense, bias) - else: - dense_out = dense - if with_eltwise == "gelu": - div = is_op("divide")(dense_out, is_constant()) - erf_val = is_op("erf")(div) - added_erf_val = is_op("add")(erf_val, is_constant()) - mul_val = is_op("multiply")(dense_out, added_erf_val) - dense_out = is_op("multiply")(mul_val, is_constant()) - - elif with_eltwise: - dense_out = is_op(with_eltwise)(dense_out) - return dense_out - - @register_pattern_table("dnnl") def pattern_table(): """Create dnnl patterns. @@ -267,7 +226,6 @@ def pattern_table(): if elt != "gelu": dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt)) dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt)) - dnnl_patterns.append(make_dnnl_pattern("nn.contrib_dense_pack", with_bias, elt)) return dnnl_patterns @@ -340,27 +298,6 @@ def get_optimal_layout_for_conv_transpose( ) -def get_optimal_layout_for_dense(data_layout, weight_shape, out_shape): - """Get the optimal layout of dnnl, given shape of dense. 
- - Parameters - ---------- - data_layout, weight_shape, out_shape - : String - Input argument. - - Returns - ------- - layouts : string - The result. - """ - return _ffi_api.get_optimal_layout_for_dense( - data_layout, - weight_shape, - out_shape, - ) - - def get_shape(tensor): """Get tensor's shape.""" if isinstance(tensor, relay.expr.Var): @@ -391,23 +328,20 @@ def get_dtype(tensor): raise TypeError("Unsupport data type: %s" % type(tensor)) -def tag2layout(input_data, is_weight=False, op_type="Conv1D"): +def tag2layout(input_data, is_weight=False, conv_type="Conv1D"): """Transfer layout, denoted with `a, b, c, d, e`, into valid layout (NCHW / OIHW) of TVM.""" - if "Conv1D" in op_type: + if "Conv1D" in conv_type: data_dic = {"a": "N", "b": "C", "c": "W"} weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"} - elif "Conv2D" in op_type: + elif "Conv2D" in conv_type: data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"} if "e" in input_data: weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"} - elif "Conv3D" in op_type: + elif "Conv3D" in conv_type: data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"} weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"} - elif "Dense" in op_type: - data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} - weight_dic = data_dic dic = weight_dic if is_weight else data_dic res = "" @@ -424,18 +358,7 @@ def tag2layout(input_data, is_weight=False, op_type="Conv1D"): else: raise ValueError("Unsupport layout format: %s" % input_data) - if "Dense" in op_type: - # Post process for dense weight layout - # e.g. NC16c64n => NC64n - regexN = "\d+n" - regexC = "\d+c" - - matchN = re.findall(regexN, res) - layout_fmt = "NC" + "".join(matchN) - full_layout_fmt = res - return layout_fmt, full_layout_fmt - - return res, res + return res def legalize_group_conv(attrs, inputs, types): @@ -483,9 +406,9 @@ def alter_conv(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"], _ = tag2layout(src_df, is_weight=False, op_type=conv_type) - new_attrs["kernel_layout"], _ = tag2layout(weight_df, is_weight=True, op_type=conv_type) - new_attrs["out_layout"], _ = tag2layout(dst_df, is_weight=False, op_type=conv_type) + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) if conv_type == "Conv1D": return relay.nn.conv1d(data, weight, **new_attrs) @@ -522,9 +445,9 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): dtype, ) src_df, weight_df, dst_df = res.split(",") - new_attrs["data_layout"], _ = tag2layout(src_df, is_weight=False, op_type=conv_type) - new_attrs["kernel_layout"], _ = tag2layout(weight_df, is_weight=True, op_type=conv_type) - new_attrs["out_layout"], _ = tag2layout(dst_df, is_weight=False, op_type=conv_type) + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) if conv_type == "Conv1DTranspose": return relay.nn.conv1d_transpose(data, weight, **new_attrs) @@ -533,34 +456,6 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): return relay.nn.conv3d_transpose(data, weight, **new_attrs) -def 
alter_dense(attrs, inputs, tinfos, out_type): - """The packed dense's layout auto-query func for dnnl.""" - - data, weight = inputs - - weight_shape_list = [str(x) for x in get_shape(weight)] - out_shape_list = [str(x) for x in get_shape(out_type)] - - data_shape = ",".join([out_shape_list[0], weight_shape_list[1]]) - weight_shape = ",".join(weight_shape_list) - out_shape = ",".join(out_shape_list) - - res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) - - _, weight_df, _ = res.split(",") - - wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - - weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) - return relay.nn.contrib_dense_pack( - data, - weight_transform, - weight_layout=full_wei_layout, - units=None, - out_dtype=out_type.dtype, - ) - - class IsComputeIntensiveGraph(ExprVisitor): """ Visits the Graph recursively and checks if it contains compute heavy ops like convolutions and @@ -733,7 +628,7 @@ class DenseReshapeBiasGeluRewrite(DFPatternCallback): 8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; """ - def __init__(self, pack_wei=False): + def __init__(self): super(DenseReshapeBiasGeluRewrite, self).__init__() self.data = wildcard() self.weight = wildcard() @@ -742,8 +637,6 @@ def __init__(self, pack_wei=False): self.const2 = wildcard() self.const3 = wildcard() - self.pack_wei = pack_wei - self.attr_map = {} den = is_op("nn.dense")(self.data, self.weight) @@ -780,33 +673,8 @@ def callback(self, pre, post, node_map): const1 = node_map[self.const1][0] const2 = node_map[self.const2][0] const3 = node_map[self.const3][0] - - if self.pack_wei: - weight_shape_list = [str(x) for x in get_shape(weight)] - data_shape_list = [str(x) for x in get_shape(data)] - - data_shape = ",".join(data_shape_list) - weight_shape = ",".join(weight_shape_list) - out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) - - _, weight_df, _ = res.split(",") - wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - - weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) - - den = relay.op.nn.contrib_dense_pack( - data, - weight_transform, - weight_layout=full_wei_layout, - units=None, - out_dtype=self.attr_map["nn.dense"]["out_dtype"] - if "out_dtype" in self.attr_map["nn.dense"] - else "", - ) - else: - den = relay.op.nn.dense(data, weight) + + den = relay.op.nn.dense(data, weight) added = relay.op.add(bias, den) divisor = relay.op.divide(added, const1) val_erf = relay.op.erf(divisor) @@ -816,11 +684,11 @@ def callback(self, pre, post, node_map): return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"]) -def rewrite_dense_bias_gelu_reshape_last(mod, pack_wei=False): +def rewrite_dense_bias_gelu_reshape_last(mod): """Rewrite the input graph to reorder reshape operators so that we can perform dense_bias_gelu fusion and then offload them to byoc part. 
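    The trailing divide / erf / add / multiply / multiply ops handled here are the
    erf formulation of GeLU, gelu(y) = y * 0.5 * (1 + erf(y / sqrt(2))). A minimal
    sketch of the unfused graph this callback matches and re-emits, built from the
    same relay ops the callback itself uses (shapes and variable names below are
    illustrative assumptions, not taken from any real model):

        import math
        from tvm import relay

        x = relay.var("x", shape=(3136, 64), dtype="float32")
        w = relay.var("w", shape=(512, 64), dtype="float32")
        b = relay.var("b", shape=(512,), dtype="float32")

        den = relay.op.nn.dense(x, w)
        # reshape sits between dense and bias add, which blocks fusion
        re_den = relay.op.reshape(den, newshape=(1, 3136, 512))
        added = relay.op.add(b, re_den)
        # erf-based GeLU applied to the biased output
        div = relay.op.divide(added, relay.const(math.sqrt(2.0), "float32"))
        erf_val = relay.op.erf(div)
        added_erf = relay.op.add(erf_val, relay.const(1.0, "float32"))
        gelu = relay.op.multiply(relay.op.multiply(added, added_erf), relay.const(0.5, "float32"))

    The rewrite re-orders this so the reshape comes last, letting the dense + bias +
    gelu chain form a single composite that dnnl can offload.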
""" - mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(pack_wei), mod["main"]) + mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(), mod["main"]) return mod @@ -832,13 +700,12 @@ class DenseReshapeBiasRewrite(DFPatternCallback): 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */; """ - def __init__(self, pack_wei=False): + def __init__(self): super(DenseReshapeBiasRewrite, self).__init__() self.data = wildcard() self.weight = wildcard() self.bias = wildcard() - self.pack_wei = pack_wei self.attr_map = {} den = is_op("nn.dense")(self.data, self.weight) @@ -867,101 +734,15 @@ def callback(self, pre, post, node_map): data = node_map[self.data][0] weight = node_map[self.weight][0] bias = node_map[self.bias][0] - - if self.pack_wei: - weight_shape_list = [str(x) for x in get_shape(weight)] - data_shape_list = [str(x) for x in get_shape(data)] - - data_shape = ",".join(data_shape_list) - weight_shape = ",".join(weight_shape_list) - out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) - - _, weight_df, _ = res.split(",") - wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) - - den = relay.op.nn.contrib_dense_pack( - data, - weight_transform, - weight_layout=full_wei_layout, - units=None, - out_dtype=self.attr_map["nn.dense"]["out_dtype"] - if "out_dtype" in self.attr_map["nn.dense"] - else "", - ) - else: - den = relay.op.nn.dense(data, weight) + + den = relay.op.nn.dense(data, weight) added = relay.op.add(bias, den) return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) -def rewrite_dense_bias_reshape_last(mod, pack_wei=False): +def rewrite_dense_bias_reshape_last(mod): """Rewrite the input graph to reorder reshape operators so that we can perform dense_bias fusion and then offload them to byoc part. 
""" - mod["main"] = rewrite(DenseReshapeBiasRewrite(pack_wei), mod["main"]) - return mod - - -class PackDenseRewrite(DFPatternCallback): - """A callback to rewrite nn.dense to nn.contrib_dense_pack.""" - - def __init__(self): - super(PackDenseRewrite, self).__init__() - self.data = wildcard() - self.weight = wildcard() - - self.attr_map = {} - - den = is_op("nn.dense")(self.data, self.weight) - self.pattern = den - - def get_attr(self, pre): - def visit_func(expr): - if isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["nn.dense"] = new_attrs - - _analysis.post_order_visit(pre, visit_func) - - def callback(self, pre, post, node_map): - self.get_attr(pre) - - data = node_map[self.data][0] - weight = node_map[self.weight][0] - - weight_shape_list = [str(x) for x in get_shape(weight)] - data_shape_list = [str(x) for x in get_shape(data)] - - data_shape = ",".join(data_shape_list) - weight_shape = ",".join(weight_shape_list) - out_shape = ",".join([data_shape_list[0], weight_shape_list[0]]) - - res = get_optimal_layout_for_dense(data_shape, weight_shape, out_shape) - - _, weight_df, _ = res.split(",") - - wei_layout, full_wei_layout = tag2layout(weight_df, is_weight=True, op_type="Dense") - - weight_transform = relay.layout_transform(weight, "NC", dst_layout=wei_layout) - return relay.op.nn.contrib_dense_pack( - data, - weight_transform, - weight_layout=full_wei_layout, - units=None, - out_dtype=self.attr_map["nn.dense"]["out_dtype"] - if "out_dtype" in self.attr_map["nn.dense"] - else "", - ) - - -def rewrite_dense_to_pack(mod): - """Rewrite the input graph to use packed dense operators so that - we can gain better performance boost in dnnl byoc part. - """ - mod["main"] = rewrite(PackDenseRewrite(), mod["main"]) + mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"]) return mod diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 61e19dd342ba..ff1776715cc2 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -436,6 +436,38 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { #else // DNNL JSON runtime +/*! + * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in + * relu(add(conv2d)) + * \param call A Relay call node. Typically nn.relu when called the first time. + * \param max_depth The maximum number of calls before the root op, counting from current_call. + * \param root_name The name of expected "root" op in this fused call. 
+ * \return A CallNode corresponding to the root op + */ +inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth, + const std::string& root_name) { + ICHECK(current_call && max_depth >= 0); + + if (max_depth == 0) { + ICHECK(current_call && IsOp(current_call, root_name)); + return current_call; + } + if (IsOp(current_call, root_name)) { + return current_call; + } + + ICHECK_GT(current_call->args.size(), 0); + + size_t valid_node_idx = 0; + while (valid_node_idx < current_call->args.size() && + current_call->args[valid_node_idx].as()) { + valid_node_idx++; + } + + const auto* next_call = current_call->args[valid_node_idx].as(); + return FindCallWithName(next_call, max_depth - 1, root_name); +} + class DNNLJSONSerializer : public backend::contrib::JSONSerializer { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; @@ -447,9 +479,6 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { {"sigmoid", "sigmoid"}, {"nn.deconv2d", "nn.conv2d_transpose"}, {"nn.deconv3d", "nn.conv3d_transpose"}, - {"add", "add"}, - {"multiply", "multiply"}, - {"nn.packeddense", "nn.contrib_dense_pack"}, }; std::vector ParsingOpList(const std::string& pattern_name, @@ -458,19 +487,12 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { std::vector op_list; size_t pos = 0, start = 0; - std::string raw_name = pattern_name; - if (raw_name.find("gelu") != std::string::npos) { - //TODO(billishyahao): Remove me after introducing new gelu operator - raw_name.replace(raw_name.find("gelu"), 4, "multiply_multiply_"); - } - while ((pos = raw_name.find(interval, start)) != std::string::npos) { - std::string op_name = raw_name.substr(start, pos - start); + while ((pos = pattern_name.find(interval, start)) != std::string::npos) { + std::string op_name = pattern_name.substr(start, pos - start); if (op_name.find("dnnl") != std::string::npos) { op_name.replace(op_name.find("dnnl"), 4, "nn"); if (op_name.find("deconv") != std::string::npos) { op_name = op_map[op_name]; - } else if (op_name.find("packeddense") != std::string::npos) { - op_name = op_map[op_name]; } } else { op_name = op_map[op_name]; @@ -478,8 +500,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { if (pos > start) op_list.push_back(op_name); start = pos + interval.size(); } - if (raw_name.size() > start) { - op_list.push_back(op_map[raw_name.substr(start)]); + if (pattern_name.size() > start) { + op_list.push_back(op_map[pattern_name.substr(start)]); } return op_list; } @@ -518,11 +540,10 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; - } else if (name.find("dnnl.dense") != std::string::npos) { - std::vector op_list = ParsingOpList(name); - call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); + } else if (name.find("gelu") != std::string::npos) { + call = FindCallWithName(fn->body.as(), 10, "nn.dense"); ICHECK(call->op.as()) << "Not op node"; - } else if (name.find("dnnl.packeddense") != std::string::npos) { + } else if (name.find("dnnl.dense") != std::string::npos) { std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc index 
4bf5e4d27fa7..3762c1906f40 100755 --- a/src/relay/backend/contrib/dnnl/query_layout.cc +++ b/src/relay/backend/contrib/dnnl/query_layout.cc @@ -362,43 +362,6 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, return res; } -std::string get_optimal_layout_for_dense(std::string data_layout, std::string weight_shape, - std::string out_shape) { - dnnl::engine eng(dnnl::engine::kind::cpu, 0); - dnnl::stream s(eng); - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - - dnnl::memory::dims data_dims = str2dims(data_layout); - dnnl::memory::dims weight_dims = str2dims(weight_shape); - dnnl::memory::dims out_dims = str2dims(out_shape); - dnnl::memory::dims bias_dims = {out_dims[1]}; - - // Memory descriptions. - auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::any}); - auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::any}); - auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::any}); - auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::any}); - - // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); - - dnnl::primitive_attr attr; - auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, eng); - - auto src_format = dense_prim_desc.src_desc(); - auto weights_format = dense_prim_desc.weights_desc(); - auto dst_format = dense_prim_desc.dst_desc(); - std::string src_df, weight_df, dst_df; - - src_df = md2fmt_tag_str(&src_format); - weight_df = md2fmt_tag_str(&weights_format); - dst_df = md2fmt_tag_str(&dst_format); - std::string res = src_df + "," + weight_df + "," + dst_df; - return res; -} - TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5], @@ -411,11 +374,6 @@ TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose") args[5], args[6], args[7], args[8], args[9]); }); -TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_dense") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = get_optimal_layout_for_dense(args[0], args[1], args[2]); - }); - } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 9040bd1d0d3e..dbf63d46e116 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -83,7 +83,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Find proper dnnl::memory buffers std::unordered_map mem_args; for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second); - prim.execute(stream_, mem_args); } } @@ -179,7 +178,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex deconv_pat(".*deconv[1-3]d.*"); std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); std::regex dense_pat(".*dense.*"); - std::regex dense_pack_pat(".*packeddense.*"); std::regex max_pool_pat(".*max_pool[1-3]d"); std::regex avg_pool_pat(".*avg_pool[1-3]d"); @@ -194,9 +192,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { Deconvolution(nid); } else if (std::regex_match(op_name, conv_pat)) { Convolution(nid); - } else if (std::regex_match(op_name, dense_pack_pat) || - "nn.contrib_dense_pack" == op_name) { - // DensePack(nid); } else if (std::regex_match(op_name, dense_pat)) { Dense(nid); } else if ("nn.batch_norm" == op_name) { @@ -415,7 +410,6 @@ 
class DNNLJSONRuntime : public JSONRuntimeBase { {DNNL_ARG_DST, dst_tr}}); } - void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 6dca99da0a66..33be57f014fb 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -188,8 +188,6 @@ def check_dnnl_used(mod, subgraph_num=None): continue if use_dnnl: processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) - print("hebi-dbg: processed_mod: ", result_key) - print(processed_mod) check_dnnl_used(processed_mod) with tvm.transform.PassContext(opt_level=3): @@ -199,10 +197,6 @@ def check_dnnl_used(mod, subgraph_num=None): if run_module: if isinstance(input, dict): result_dict[result_key] = func(**input, **params) - print("input:", input) - print("params:", params) - print("result_dict[result_key]:") - print(result_dict[result_key]) else: result_dict[result_key] = func(input, **params) @@ -918,6 +912,11 @@ def test_dense(run_module, dtype="float32"): config = dense, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) + dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype) + dense = tvm.IRModule.from_expr(dense) + config = dense, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + def test_dense_pattern(run_module, dtype="float32"): x_shape = (1, 16) From e71778e64342d3ee883d6216b97baa237ecb78c4 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Wed, 8 Jun 2022 15:17:53 +0800 Subject: [PATCH 06/10] Fix lint --- python/tvm/relay/op/contrib/dnnl.py | 104 ++++++------------ src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 2 +- tests/python/contrib/test_dnnl.py | 4 +- 3 files changed, 38 insertions(+), 72 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index ce7ae1fea2f7..53f3ec15febe 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,7 +33,6 @@ check the attributes of the op and decide if it should be offloaded to DNNL. 
""" import logging -import math import tvm.ir from tvm import relay @@ -49,7 +48,6 @@ from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table -import re logger = logging.getLogger("DNNL") @@ -617,10 +615,21 @@ def rewrite_layer_norm(mod): class DenseReshapeBiasGeluRewrite(DFPatternCallback): """ - A callback to reorder reshape operators when the patten is as below: - 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; + A callback to reorder reshape operators when the patterns are as below: + + Pattern #1: + 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, + units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; + 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; + 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) + /* ty=Tensor[(1, 3136, 64), float32] */; + + Pattern #2: + 1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, + units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */; 2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */; - 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */; + 3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) + /* ty=Tensor[(1, 3136, 512), float32] */; 4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; 5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */; 6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; @@ -628,7 +637,7 @@ class DenseReshapeBiasGeluRewrite(DFPatternCallback): 8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */; """ - def __init__(self): + def __init__(self, has_gelu=True): super(DenseReshapeBiasGeluRewrite, self).__init__() self.data = wildcard() self.weight = wildcard() @@ -638,29 +647,30 @@ def __init__(self): self.const3 = wildcard() self.attr_map = {} + self.has_gelu = has_gelu den = is_op("nn.dense")(self.data, self.weight) re_den = is_op("reshape")(den) added = is_op("add")(self.bias, re_den) - divisor = is_op("divide")(added, self.const1) - val_erf = is_op("erf")(divisor) - added_erf = is_op("add")(val_erf, self.const2) - mul1 = is_op("multiply")(added, added_erf) - mul2 = is_op("multiply")(mul1, self.const3) - self.pattern = mul2 + if self.has_gelu: + divisor = is_op("divide")(added, self.const1) + val_erf = is_op("erf")(divisor) + added_erf = is_op("add")(val_erf, self.const2) + mul1 = is_op("multiply")(added, added_erf) + mul2 = is_op("multiply")(mul1, self.const3) + self.pattern = mul2 + else: + self.pattern = added def get_attr(self, pre): + """Recursively retrieve attributes from reshape operator.""" + def visit_func(expr): if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): new_attrs = {} for k in expr.attrs.keys(): new_attrs[k] = expr.attrs[k] self.attr_map["reshape"] = new_attrs - elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["nn.dense"] = new_attrs _analysis.post_order_visit(pre, visit_func) @@ -670,12 +680,16 @@ def callback(self, pre, post, node_map): data = node_map[self.data][0] weight = node_map[self.weight][0] bias = 
node_map[self.bias][0] + + den = relay.op.nn.dense(data, weight) + added = relay.op.add(bias, den) + if not self.has_gelu: + return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) + const1 = node_map[self.const1][0] const2 = node_map[self.const2][0] const3 = node_map[self.const3][0] - - den = relay.op.nn.dense(data, weight) - added = relay.op.add(bias, den) + divisor = relay.op.divide(added, const1) val_erf = relay.op.erf(divisor) added_erf = relay.op.add(val_erf, const2) @@ -692,57 +706,9 @@ def rewrite_dense_bias_gelu_reshape_last(mod): return mod -class DenseReshapeBiasRewrite(DFPatternCallback): - """ - A callback to reorder reshape operators when the patten is as below: - 1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */; - 2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */; - 3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */; - """ - - def __init__(self): - super(DenseReshapeBiasRewrite, self).__init__() - self.data = wildcard() - self.weight = wildcard() - self.bias = wildcard() - - self.attr_map = {} - - den = is_op("nn.dense")(self.data, self.weight) - re_den = is_op("reshape")(den) - added = is_op("add")(self.bias, re_den) - self.pattern = added - - def get_attr(self, pre): - def visit_func(expr): - if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["reshape"] = new_attrs - elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"): - new_attrs = {} - for k in expr.attrs.keys(): - new_attrs[k] = expr.attrs[k] - self.attr_map["nn.dense"] = new_attrs - - _analysis.post_order_visit(pre, visit_func) - - def callback(self, pre, post, node_map): - self.get_attr(pre) - - data = node_map[self.data][0] - weight = node_map[self.weight][0] - bias = node_map[self.bias][0] - - den = relay.op.nn.dense(data, weight) - added = relay.op.add(bias, den) - return relay.op.reshape(added, self.attr_map["reshape"]["newshape"]) - - def rewrite_dense_bias_reshape_last(mod): """Rewrite the input graph to reorder reshape operators so that we can perform dense_bias fusion and then offload them to byoc part. 
""" - mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"]) + mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"]) return mod diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index dbf63d46e116..5045f3323af7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -158,7 +158,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (std::regex_match(op_name, gelu_pat)) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); } - if (ops.len() != 0){ + if (ops.len() != 0) { attr.set_post_ops(ops); } diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 33be57f014fb..0daacffc7933 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -912,7 +912,7 @@ def test_dense(run_module, dtype="float32"): config = dense, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) - dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype) + dense, dic, param_lst = get_dense(x_shape, k_shape, activation="gelu", dtype=dtype) dense = tvm.IRModule.from_expr(dense) config = dense, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) @@ -932,7 +932,7 @@ def test_dense_pattern(run_module, dtype="float32"): config = dense_bias, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) - dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype) + dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, activation="gelu", dtype=dtype) dense_bias = tvm.IRModule.from_expr(dense_bias) config = dense_bias, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) From 2bd1dd3d93b852108c37de3fa305d253f53fcb67 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Wed, 8 Jun 2022 19:34:33 +0800 Subject: [PATCH 07/10] Fix partition graph unittest case --- tests/python/relay/test_pass_partition_graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 761a430997b0..dedeae56e9da 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -928,9 +928,9 @@ def test_dnnl_fuse(): ) = ( dnnl_patterns[1], dnnl_patterns[13], - dnnl_patterns[19], - dnnl_patterns[25], - dnnl_patterns[37], + dnnl_patterns[20], + dnnl_patterns[26], + dnnl_patterns[38], ) def get_blocks( From ef2e7a1a1c684774d877406e417cd44cebba7de1 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Thu, 9 Jun 2022 20:23:35 +0800 Subject: [PATCH 08/10] Fix comments --- python/tvm/relay/op/contrib/dnnl.py | 15 ++--- src/relay/backend/contrib/dnnl/codegen.cc | 6 +- tests/python/contrib/test_dnnl.py | 72 ++++++++++++++++++----- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 53f3ec15febe..6581f10a2f56 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -700,15 +700,10 @@ def callback(self, pre, post, node_map): def rewrite_dense_bias_gelu_reshape_last(mod): """Rewrite the input graph to reorder reshape operators so that - we can perform dense_bias_gelu fusion and then offload them to byoc part. 
+ we can perform dense_bias_gelu/dense_bias fusion and then offload + them to byoc part. """ - mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(), mod["main"]) - return mod - - -def rewrite_dense_bias_reshape_last(mod): - """Rewrite the input graph to reorder reshape operators so that - we can perform dense_bias fusion and then offload them to byoc part. - """ - mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"]) + mod["main"] = rewrite( + [DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"] + ) return mod diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index ff1776715cc2..f9436e490ff7 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -540,12 +540,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; - } else if (name.find("gelu") != std::string::npos) { - call = FindCallWithName(fn->body.as(), 10, "nn.dense"); - ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.dense") != std::string::npos) { - std::vector op_list = ParsingOpList(name); - call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); + call = FindCallWithName(fn->body.as(), 10, "nn.dense"); ICHECK(call->op.as()) << "Not op node"; } else { LOG(FATAL) << "Unrecognized DNNL pattern: " << name; diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 0daacffc7933..c884665421cb 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -57,7 +57,7 @@ def bf16_supported(): return _bf16_supported -def partition_for_dnnl(mod, params=None, alter_layout=True): +def partition_for_dnnl(mod, params=None, alter_layout=True, prune_subgraphs=True): """Partition the graph greedily offloading supported operators to DNNL. 
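    A typical invocation, assuming `mod` and `params` were produced by the
    surrounding test helpers:

        mod = partition_for_dnnl(mod, params, alter_layout=True, prune_subgraphs=True)

    Passing prune_subgraphs=False keeps every dnnl subgraph, which the
    reshape-rewrite test below relies on when it checks the number of dnnl
    subgraphs.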
Parameters @@ -113,6 +113,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): mod = alter_layout_seq(mod) mod = dnnl.rewrite_layer_norm(mod) + mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod) byoc_seq = tvm.transform.Sequential( [ @@ -125,7 +126,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): with tvm.transform.PassContext(opt_level=3): mod = byoc_seq(mod) - mod = dnnl.prune_dnnl_subgraphs(mod) + if prune_subgraphs: + mod = dnnl.prune_dnnl_subgraphs(mod) return mod @@ -152,16 +154,15 @@ def assert_result_dict_holds(result_dict): tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3) -def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True): - def check_dnnl_used(mod, subgraph_num=None): - num_dnnl_subgraphs = sum( - [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()] - ) - if subgraph_num: - assert num_dnnl_subgraphs == subgraph_num - else: - assert num_dnnl_subgraphs >= 1 +def check_dnnl_used(mod, subgraph_num=None): + num_dnnl_subgraphs = sum([1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]) + if subgraph_num: + assert num_dnnl_subgraphs == subgraph_num + else: + assert num_dnnl_subgraphs >= 1 + +def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True): dev = tvm.cpu() result_dict = dict() for mode in ["graph", "vm"]: @@ -600,10 +601,15 @@ def gelu_helper(data): return out -def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): +def get_dense( + x_shape=(1, 16), k_shape=(32, 16), activation=None, has_reshape=False, dtype="float32" +): x = relay.var("x", shape=(x_shape), dtype=dtype) kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) out = relay.nn.dense(x, kernel, units=k_shape[0]) + # out = relay.nn.dense(x, kernel, units=None) + if has_reshape: + out = relay.reshape(out, newshape=(1, x_shape[0], k_shape[0])) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] @@ -612,10 +618,22 @@ def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32 return out, dic, param_lst -def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): - dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype) +def get_dense_bias( + x_shape=(1, 16), + k_shape=(32, 16), + activation=None, + has_reshape=False, + use_add=False, + dtype="float32", +): + dense, dic, param_lst = get_dense( + x_shape=x_shape, k_shape=k_shape, has_reshape=has_reshape, dtype=dtype + ) bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) - out = relay.nn.bias_add(dense, bias) + if use_add: + out = relay.add(dense, bias) + else: + out = relay.nn.bias_add(dense, bias) dic["bias"] = (k_shape[0],) param_lst += ["bias"] @@ -1084,5 +1102,29 @@ def test_layer_norm(run_module, dtype="float32"): run_and_verify_func(config, run_module=run_module, dtype=dtype) +def test_rewrite_dense_bias_gelu_reshape_last(run_module, dtype="float32"): + def get_graph(act=None): + x_shape = (1, 16) + k_shape = (32, 16) + + dense_bias, dic, param_lst = get_dense_bias( + x_shape, k_shape, activation=act, has_reshape=True, use_add=True, dtype=dtype + ) + dense_bias = tvm.IRModule.from_expr(dense_bias) + processed_dense_bias = partition_for_dnnl( + dense_bias, params=None, alter_layout=False, prune_subgraphs=False + ) + check_dnnl_used(processed_dense_bias, 1) + + return dense_bias, dic, param_lst + + run_and_verify_func( + get_graph("gelu"), subgraph_num=1, run_module=run_module, dtype=dtype, 
test_bf16=False + ) + run_and_verify_func( + get_graph(), subgraph_num=1, run_module=run_module, dtype=dtype, test_bf16=False + ) + + if __name__ == "__main__": tvm.testing.main() From fc8dbb24609437425f7cbd4210c4e73f6c305c24 Mon Sep 17 00:00:00 2001 From: billishyahao Date: Fri, 10 Jun 2022 07:17:28 +0800 Subject: [PATCH 09/10] Fix comments --- src/relay/backend/contrib/dnnl/codegen.cc | 34 +---------------------- src/relay/backend/utils.h | 32 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f9436e490ff7..927cd12ae0fb 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -436,38 +436,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { #else // DNNL JSON runtime -/*! - * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in - * relu(add(conv2d)) - * \param call A Relay call node. Typically nn.relu when called the first time. - * \param max_depth The maximum number of calls before the root op, counting from current_call. - * \param root_name The name of expected "root" op in this fused call. - * \return A CallNode corresponding to the root op - */ -inline const CallNode* FindCallWithName(const CallNode* current_call, int max_depth, - const std::string& root_name) { - ICHECK(current_call && max_depth >= 0); - - if (max_depth == 0) { - ICHECK(current_call && IsOp(current_call, root_name)); - return current_call; - } - if (IsOp(current_call, root_name)) { - return current_call; - } - - ICHECK_GT(current_call->args.size(), 0); - - size_t valid_node_idx = 0; - while (valid_node_idx < current_call->args.size() && - current_call->args[valid_node_idx].as()) { - valid_node_idx++; - } - - const auto* next_call = current_call->args[valid_node_idx].as(); - return FindCallWithName(next_call, max_depth - 1, root_name); -} - class DNNLJSONSerializer : public backend::contrib::JSONSerializer { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; @@ -541,7 +509,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.dense") != std::string::npos) { - call = FindCallWithName(fn->body.as(), 10, "nn.dense"); + call = GetRootCall(fn->body.as(), 10, "nn.dense"); ICHECK(call->op.as()) << "Not op node"; } else { LOG(FATAL) << "Unrecognized DNNL pattern: " << name; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 360f366a162e..573e6ec629e6 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -480,6 +480,38 @@ inline const CallNode* GetRootCall(const CallNode* current_call, const std::stri return GetRootCall(next_call, op_name); } +/*! + * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in + * relu(add(conv2d)) + * \param call A Relay call node. Typically nn.relu when called the first time. + * \param max_depth The maximum number of calls before the root op, counting from current_call. + * \param op_name The name of expected "root" op in this fused call. 
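+ * For example, the dnnl.dense_bias_gelu composite nests nn.dense under add, divide,
+ * erf and multiply post-ops; the DNNL serializer calls this overload with max_depth 10
+ * and op_name "nn.dense" to step past those elementwise calls and their non-call
+ * arguments (e.g. constants) and recover the nn.dense call node.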
+ * \return A CallNode corresponding to the root op + */ +inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth, + const std::string& op_name) { + ICHECK(current_call && max_depth >= 0); + + if (max_depth == 0) { + ICHECK(current_call && IsOp(current_call, op_name)); + return current_call; + } + if (IsOp(current_call, op_name)) { + return current_call; + } + + ICHECK_GT(current_call->args.size(), 0); + + size_t valid_node_idx = 0; + while (valid_node_idx < current_call->args.size() && + current_call->args[valid_node_idx].as()) { + valid_node_idx++; + } + + const auto* next_call = current_call->args[valid_node_idx].as(); + return GetRootCall(next_call, max_depth - 1, op_name); +} + /*! * \brief Get the external symbol of the Relay function name. * From 6b811c662e90b98c7f7ec373a7e8e452303f1c3c Mon Sep 17 00:00:00 2001 From: billishyahao Date: Fri, 10 Jun 2022 09:31:19 +0800 Subject: [PATCH 10/10] Fix lint --- src/relay/backend/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 573e6ec629e6..70080254c414 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -489,7 +489,7 @@ inline const CallNode* GetRootCall(const CallNode* current_call, const std::stri * \return A CallNode corresponding to the root op */ inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth, - const std::string& op_name) { + const std::string& op_name) { ICHECK(current_call && max_depth >= 0); if (max_depth == 0) {