From d24af7d3271dbb75810f6e0a6fc412cf9931edc7 Mon Sep 17 00:00:00 2001 From: Vaino Granat Date: Fri, 4 Oct 2024 14:06:19 +0300 Subject: [PATCH] Dnnl style codegen (#1) * Checkpoint, nothing works * DNNL based codegen almost works * Work in dnnl style * Work in dnnl style * Arg passing works * Work in dnnl style * Codegen somewhat works * Requantization not working * Codegen works * Remove headsail_old --- python/tvm/relay/op/contrib/headsail.py | 279 ++++++++++++- src/relay/backend/contrib/headsail/codegen.cc | 378 +++++++++++------ .../backend/contrib/headsail/codegen_c.h | 31 -- .../contrib/headsail/codegen_headsail.h | 386 ++++++++++++++++++ .../contrib/headsail/genereted_example.c | 13 - 5 files changed, 905 insertions(+), 182 deletions(-) delete mode 100644 src/relay/backend/contrib/headsail/codegen_c.h create mode 100644 src/relay/backend/contrib/headsail/codegen_headsail.h delete mode 100644 src/relay/backend/contrib/headsail/genereted_example.c diff --git a/python/tvm/relay/op/contrib/headsail.py b/python/tvm/relay/op/contrib/headsail.py index 6001f1d7aa42..de1a1d3081b9 100644 --- a/python/tvm/relay/op/contrib/headsail.py +++ b/python/tvm/relay/op/contrib/headsail.py @@ -32,10 +32,18 @@ - The other way is to implement the function by themselves to check the attributes of the op and decide if it should be offloaded to DNNL. """ +import logging import tvm.ir -from ...dataflow_pattern import wildcard, is_op +from tvm import relay +from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard +from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const +from tvm.relay import transform from .register import register_pattern_table +from ..strategy.generic import is_depthwise_conv2d +logger = logging.getLogger("HEADSAIL") + +conv2d_counter = True def _register_external_op_helper(op_name, supported=True): """The helper function to indicate that a given operator can be supported @@ -53,32 +61,269 @@ def _register_external_op_helper(op_name, supported=True): """ @tvm.ir.register_op_attr(op_name, "target.headsail") def _func_wrapper(expr): + args = expr.args + typ = args[0].checked_type + if typ.dtype != "int8": + return False + + global conv2d_counter + if conv2d_counter == True: + conv2d_counter = False + logger.info(expr.span) return supported return _func_wrapper -#_register_external_op_helper("nn.conv2d") -_register_external_op_helper("nn.relu") -#_register_external_op_helper("add") - +#_register_external_op_helper("qnn.add") +#_register_external_op_helper("qnn.conv2d") +#_register_external_op_helper("qnn.relu") -def make_pattern(with_bias=True): +# Special case to handle tflite models converted to relay with fused activation +def qnn_tflite_conv2d_bias_relu(): data = wildcard() weight = wildcard() bias = wildcard() - conv = is_op('nn.conv2d')(data, weight) - if with_bias: - conv_out = is_op('add')(conv, bias) - else: - conv_out = conv - return is_op('nn.relu')(conv_out) + pattern = is_op("qnn.conv2d")( + data, weight, is_constant(), is_constant(), is_constant(), is_constant() + ) + pattern = is_op("nn.bias_add")(pattern, bias) + pattern = is_op("qnn.requantize")( + pattern, is_constant(), is_constant(), is_constant(), is_constant() + ) + pattern = is_op("clip")(pattern) + return pattern + +def make_qnn_conv2d_pattern(): + """Make qnn.conv2d based pattern supported by DNNL + + Returns + ------- + pattern : Tuple(pattern_name, CallPattern) + Created pattern name, along with its CallPattern. 
+ """ + data = wildcard() + weight = is_constant() + bias = is_constant() + o_scl = is_constant() + dst_zp = is_constant() + act_scl = is_constant() + sum_scl = is_constant() + sum_src = wildcard() + zero_zp = is_expr(const(0, dtype="int32")) + + pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant()) + pat = is_op("cast")(pat) + pat = is_op("add")(pat, bias) | pat # optional bias + pat = is_op("multiply")(pat, o_scl) + pat = is_op("clip")(pat) # TBD, not only clip + pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. Ex: act_scl == 1 + pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat # optional sum + pat = is_op("add")(pat, dst_zp) | pat # optional dst_zp, can be dst_zp == 0 + pat = is_op("cast")(pat) + return pat @register_pattern_table("headsail") def pattern_table(): - conv2d_bias_relu_pat = ("headsail.conv2d_bias_relu", make_pattern(with_bias=True)) - conv2d_relu_pat = ("headsail.conv2d_relu", make_pattern(with_bias=False)) - patterns = [conv2d_bias_relu_pat, conv2d_relu_pat] - return [] - return patterns + tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", qnn_tflite_conv2d_bias_relu()) + #tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", make_qnn_conv2d_pattern()) + #tflite_conv2d_bias= ("headsail.tflite_conv2d_bias", qnn_tflite_conv2d_bias()) + return [tflite_conv2d_bias_relu] + #return [tflite_conv2d_bias_relu, tflite_conv2d_b//ias] + +class LegalizeQnnOpForHeadsail(DFPatternCallback): + """Legalize QNN based patterns to match DNNL + + original pattern: + OP = qnn.conv2d + %1 = OP(SRC, WGH) - OP(src_zp, WGH) // qnn.conv2d + %2 = %1 + orig_bias // bias + %2 = (%1 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp // qnn.requantize + %3 = act(%2) // activation == clip + + transform to DNNL compatible: + %1 = OP(SRC, WGH) + %2 = cast(%1, dtype="float") + %2 = (%1 + bias) * o_scl + %3 = act(%2) * act_scl + %4 = %3 + SRC2 * sum_scl + %5 = %4 + dst_zp + %6 = cast(%5, dtype="float") + + where: + o_scl = rq_in_scl / rq_out_scl + act_scl = sum_lhs_scl / sum_out_scl + sum_scl = sum_rhs_scl / sum_out_scl + bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl + dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl - + sum_rhs_zp * sum_rhs_scl / sum_out_scl + """ + + def __init__(self): + super(LegalizeQnnOpForHeadsail, self).__init__() + self.src = wildcard() + self.wgh = wildcard() + self.bias = wildcard() + self.sum_src = wildcard() + + self.src_scl = is_constant() + self.src_zp = is_constant() + self.wgh_scl = is_constant() + self.wgh_zp = is_expr(const(0)) + + self.rq_in_scl = is_constant() + self.rq_in_zp = is_constant() + self.rq_out_scl = is_constant() + self.rq_out_zp = is_constant() + + self.sum_lhs_scl = is_constant() + self.sum_lhs_zp = is_constant() + self.sum_rhs_scl = is_constant() + self.sum_rhs_zp = is_constant() + self.sum_out_scl = is_constant() + self.sum_out_zp = is_constant() + + self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))( + self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl + ) + pat = is_op("add")(self.root, self.bias) | self.root # optional bias + pat = is_op("qnn.requantize")( + pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp + ) + pat = is_op("clip")(pat) + cast = is_op("cast")(pat) + pat = is_op("qnn.add")( + cast, + self.sum_src, + self.sum_lhs_scl, + self.sum_lhs_zp, + self.sum_rhs_scl, + self.sum_rhs_zp, + self.sum_out_scl, + self.sum_out_zp, + ) + pat = is_op("clip")(pat) + self.pattern = 
pat | cast + + def callback(self, pre, post, node_map): + root = node_map[self.root][0] + src = node_map[self.src][0] + wgh = node_map[self.wgh][0] + bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0] + src_zp = node_map[self.src_zp][0] + rq_in_scl = node_map[self.rq_in_scl][0] + rq_in_zp = node_map[self.rq_in_zp][0] + rq_out_scl = node_map[self.rq_out_scl][0] + rq_out_zp = node_map[self.rq_out_zp][0] + + final_dtype = node_map[self.pattern][0].checked_type.dtype + + if root.op == relay.op.get("qnn.conv2d"): + dst_layout = root.attrs.out_layout + dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout + wgh_layout = root.attrs.kernel_layout + else: + # qnn.dense has no layout attributes. Assume that is plain + dst_layout = "NC" + wgh_layout = "OI" + + # TODO(@apeskov): dst_layout may ne blocked + bias_rank = len(dst_layout) - dst_layout.index("C") + + sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None + # Default values if qnn.sum is not present + sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32") + sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32") + sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32") + sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32") + sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32") + sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32") + + def cast_fp(op): + return relay.op.cast(op, dtype="float32") + + # recalculate some factors + o_scl = rq_in_scl / rq_out_scl + act_scl = sum_lhs_scl / sum_out_scl + sum_scl = sum_rhs_scl / sum_out_scl + dst_zp = ( + cast_fp(sum_out_zp) + - cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl + - cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl + ) + bias = self.squeeze_bias(bias, dst_layout) + bias = ( + cast_fp(bias) + - cast_fp(self.fake_op(src_zp, wgh, wgh_layout)) + - cast_fp(rq_in_zp) + + cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl + ) + bias = self.broadcast_to_rank(bias, bias_rank) + + zero_zp = relay.const(0, dtype="int32") + one_scl = relay.const(1.0, dtype="float32") + + # construct new graph with proper post op ordering + gr = tvm.relay.Call( + root.op, + [src, wgh, zero_zp, zero_zp, one_scl, one_scl], + root.attrs, + root.type_args, + root.span, + ) + gr = relay.op.cast(gr, dtype="float32") + gr = gr + bias + gr = gr * o_scl + gr = relay.op.clip(gr, 0, 255) * act_scl + gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr + gr = gr + dst_zp + gr = relay.op.cast(gr, dtype=final_dtype) + return gr + + @staticmethod + def fake_op(zp, wgh, layout): + """Fake operator implementation for zp broadcast input""" + # Conv: reduce kernel {OC, IC, KH, KW} -> {OC} in case of group that is still correct + # Dense: reduce kernel {OC, IC} -> {OC} + wgh_int = relay.op.cast(wgh, dtype="int32") + reduced_kernel = relay.op.sum( + wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True + ) + return zp * reduced_kernel + + @staticmethod + def squeeze_bias(bias, layout): + shape = transform.InferTypeLocal(bias).concrete_shape + c_position = layout.index("C") - len(layout) + len(shape) + squeeze_idxs = [i for i in range(len(shape)) if i != c_position] + return relay.op.squeeze(bias, squeeze_idxs) + + @staticmethod + def broadcast_to_rank(op, rank): + """Scalar or 1D tensor are supported""" + shape = transform.InferTypeLocal(op).concrete_shape + if 
len(shape) == 0: + return op + if len(shape) == 1: + return relay.op.expand_dims(op, 1, rank - 1) + raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.") + + +def legalize_qnn_for_headsail(mod): + """Transform qnn primitives to DNNL compatible form. Eliminate source zero point and apply + strict sequence of post ops.""" + print("Legalizing qnn for headsail") + #mod["main"] = rewrite(LegalizeQnnOpForHeadsail(), mod["main"]) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + # transform.SimplifyInference(), # TODO: this pass decompose nn.layer_norm + # transform.FoldScaleAxis(), # TODO: fail inside TVM in case of grouped convolutions. + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + return mod diff --git a/src/relay/backend/contrib/headsail/codegen.cc b/src/relay/backend/contrib/headsail/codegen.cc index c676be5849ec..2a9efe557398 100644 --- a/src/relay/backend/contrib/headsail/codegen.cc +++ b/src/relay/backend/contrib/headsail/codegen.cc @@ -17,19 +17,30 @@ * under the License. */ +#include #include #include #include #include +#include +#include +#include +#include +#include #include #include +#include +#include "../../../transforms/pattern_utils.h" #include #include #include +#include #include "../../utils.h" -#include "../codegen_c/codegen_c.h" +#include "./codegen_headsail.h" +//#include "../codegen_c/codegen_c.h" +#include "../../../../target/source/codegen_c_host.h" namespace tvm { namespace relay { @@ -37,88 +48,165 @@ namespace contrib { using namespace backend; +struct CompositeCallables { + std::vector passed_args; // Args used in wrapped call + std::vector static_args; // Static arguments used in function call +}; + + inline size_t GetShape1DSize(const Type& type) { const auto shape = GetShape(type); return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } -inline std::string GetShapeString(std::vector shape) { - std::string v = "std::vector{"; - for (auto s : shape) { - v += std::to_string(s) + ","; - } - v += "}"; - return v; -} - - -std::vector Conv2d(const CallNode* call) { - std::vector args; - const auto* conv2d_attr = call->attrs.as(); - ICHECK(conv2d_attr); - - auto ishape = GetShape(call->args[0]->checked_type()); - auto wshape = GetShape(call->args[1]->checked_type()); - - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - - // Args: O, G, Ph0, Pw0, Ph1, Pw1, Kh, Kw, Sh, Sw - args.push_back(std::to_string(wshape[0])); - args.push_back(std::to_string(conv2d_attr->groups)); - args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); - args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); - args.push_back(std::to_string(conv2d_attr->padding[2].as()->value)); - args.push_back(std::to_string(conv2d_attr->padding[3].as()->value)); - args.push_back(std::to_string(wshape[2])); - args.push_back(std::to_string(wshape[3])); - args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); - args.push_back(std::to_string(conv2d_attr->strides[1].as()->value)); - - return args; -} - -std::vector Relu(const CallNode* call) { - std::vector args; - auto ishape = GetShape(call->args[0]->checked_type()); - // Args: N, C, H, W - args.push_back(GetShapeString(ishape)); - return args; -} - -std::vector Add(const CallNode* call) { - std::vector args; - auto ishape = GetShape(call->args[0]->checked_type()); - // Args: N, C, H, W - args.push_back(GetShapeString(ishape)); - return args; -} - -class 
CodegenHeadsail : public MemoizedExprTranslator>, public CodegenCBase { +class CodegenHeadsail : public MemoizedExprTranslator>, public HeadsailCodegenCBase { public: - explicit CodegenHeadsail(const std::string& id) { this->ext_func_id_ = id; } + //CodegenHeadsail(const std::string& id) { this->ext_func_id_ = id; } + CodegenHeadsail(std::unordered_map* const_name_to_constant, + Array* const_names, std::string ext_func_id) + : const_name_to_constant_(const_name_to_constant), + const_names_(const_names), + ext_func_id_(std::move(ext_func_id)) {} + + + CompositeCallables Conv2d_bias(const FunctionNode* callee) { + + CompositeCallables callables; + + const ClipAttrs* clip_attr = nullptr; + const qnn::RequantizeAttrs* requantize_attr = nullptr; + const BiasAddAttrs* bias_attr = nullptr; + const Conv2DAttrs* conv2d_attr = nullptr; + + const auto* current_call = callee->body.as(); + + + if (backend::IsOp(current_call, "clip")) { + std::cout << "CLIP!!!!!!!!!!!!!" << std::endl; + clip_attr = current_call->attrs.as(); + current_call = current_call->args[0].as(); + ICHECK(clip_attr); + } + if (backend::IsOp(current_call, "qnn.requantize")) { + std::cout << "REQ!!!!!!!!!!!!!" << std::endl; + requantize_attr = current_call->attrs.as(); + + // Input scale + for (auto const& arg : VisitExpr(current_call->args[1])) { + callables.static_args.push_back(arg.name); // Const calls + } + + // Input zero + for (auto const& arg : VisitExpr(current_call->args[2])) { + callables.static_args.push_back(arg.name); // Const calls + } + + // Output zero + for (auto const& arg : VisitExpr(current_call->args[3])) { + callables.static_args.push_back(arg.name); // Const calls + } + + // Output scale + for (auto const& arg : VisitExpr(current_call->args[4])) { + callables.static_args.push_back(arg.name); // Const calls + } + current_call = current_call->args[0].as(); + ICHECK(requantize_attr); + } + + if (backend::IsOp(current_call, "nn.bias_add")) { + std::cout << "BIAS!!!!!!!!!!!!!" << std::endl; + bias_attr = current_call->attrs.as(); + current_call = current_call->args[0].as(); + ICHECK(bias_attr); + } + + if (backend::IsOp(current_call, "qnn.conv2d")) { + std::cout << "CONV!!!!!!!!!!!!!" 
<< std::endl; + //auto conv2d_args = GetArgumentNames(callee->body.as()); + //callables.passed_args.insert(callables.passed_args.end(), conv2d_args.begin(), conv2d_args.end()); + conv2d_attr = current_call->attrs.as(); + ICHECK(conv2d_attr); + } + + auto ishape = GetShape(current_call->args[0]->checked_type()); // Input shape + auto wshape = GetShape(current_call->args[1]->checked_type()); // Kernel shape + + std::cout << std::endl; + callables.static_args.push_back(std::to_string(ishape[3])); // Input channels + callables.static_args.push_back(std::to_string(ishape[1])); // Input height + callables.static_args.push_back(std::to_string(ishape[2])); // Input width + + // Input layout + char data_layout[6]; + std::strcpy(data_layout, "\""); + std::strcat(data_layout, &conv2d_attr->data_layout.c_str()[1]); + std::strcat(data_layout, "\""); + + std::cout << "Data layout: " << data_layout << std::endl; + //callables.static_args.push_back(data_layout); + callables.static_args.push_back("\"HWC\""); + + callables.static_args.push_back(std::to_string(wshape[3])); // Kernels amount + callables.static_args.push_back(std::to_string(wshape[2])); // Kernels channels + callables.static_args.push_back(std::to_string(wshape[0])); // Kernels height + callables.static_args.push_back(std::to_string(wshape[1])); // Kernels width + + + // Kernel layout + char kernel_layout[7]; + std::strcpy(kernel_layout, "\""); + std::strcat(kernel_layout, conv2d_attr->kernel_layout.c_str()); + std::strcat(kernel_layout, "\""); + + // Convert TVM layout string to Headsail layout string + for (int i = 0; i < 7; ++i) { + if (kernel_layout[i] == 'I') { + kernel_layout[i] = 'C'; + } else if (kernel_layout[i] == 'O') { + kernel_layout[i] = 'K'; + } + } + //callables.static_args.push_back(kernel_layout); + callables.static_args.push_back("\"HWCK\""); + + callables.static_args.push_back(std::to_string(conv2d_attr->groups * wshape[3])); + + // Padding + callables.static_args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); // Pad top + callables.static_args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); // Pad left + callables.static_args.push_back(std::to_string(conv2d_attr->padding[3].as()->value)); // Pad right + callables.static_args.push_back(std::to_string(conv2d_attr->padding[2].as()->value)); // Pad bottom + callables.static_args.push_back(std::to_string(0)); // Pad value + + callables.static_args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); // Stride x + callables.static_args.push_back(std::to_string(conv2d_attr->strides[1].as()->value)); // Stride y + callables.static_args.push_back(std::to_string(0)); // Mac clip + callables.static_args.push_back(std::to_string(8)); // PP clip + + return callables; + } + std::vector VisitExprDefault_(const Object* op) final { LOG(FATAL) << "Headsail codegen doesn't support: " << op->GetTypeKey(); } - // Generates function parameter - std::vector VisitExpr_(const VarNode* node) override { + std::vector VisitExpr_(const VarNode* node) final { ext_func_args_.push_back(GetRef(node)); Output output; output.name = node->name_hint(); + std::cout << "Input variable:" << output.name << std::endl; return {output}; } std::vector VisitExpr_(const TupleNode* node) final { std::vector outs; for (auto field : node->fields) { - auto res = VisitExpr(field); - ICHECK_EQ(res.size(), 1U) << "Do not support tuple nest"; - outs.push_back(res[0]); + auto res = VisitExpr(field); + ICHECK_EQ(res.size(), 1U) << "Do not support tuple nest"; + outs.push_back(res[0]); 
} return outs; } @@ -134,35 +222,66 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ std::vector VisitExpr_(const ConstantNode* cn) final { Output output; - // Get const: static_cast(dnnl_0_consts[0]->data) - output.name = CreateDataReference(ext_func_id_, const_idx_); - output.dtype = "float"; - - // Generate the global variable for needed ndarrays - if (const_array_name_.empty()) { - const_array_name_ = CreateNDArrayPool(ext_func_id_); - std::string checker = CreateInitChecker(ext_func_id_); - ext_func_body_.insert(ext_func_body_.begin(), checker); - } - // Give the ndarray a unique name to ease the initialization of it at - // runtime. - std::string const_symbol = "dnnl_" + ext_func_id_; - std::string const_var_name = CreateConstVar(const_symbol, const_idx_); - const_vars_.push_back(const_var_name); - const_idx_++; + size_t const_id = const_name_to_constant_->size(); - const auto* type_node = cn->checked_type().as(); - ICHECK(type_node); - ICHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now."; + //output.name = CreateDataReference(ext_func_id_, const_id); + const auto* type_node = cn->checked_type().as(); + ICHECK(type_node); + const auto& dtype = GetDtypeString(type_node); - return {output}; + output.dtype = dtype; + + std::string const_var_name = CreateConstVar(ext_func_id_, const_id); + output.name = const_var_name; + + std::vector constant_values; + + tvm::runtime::NDArray data = cn->data; + + int ndim = data->ndim; + int num_elements = 1; + for (int i = 0; i < ndim; i++) { + num_elements *= data->shape[i]; + } + + + // Extract constant values + if (data->dtype.code == kDLFloat && data->dtype.bits == 32) { + const float* values = static_cast(data->data); + // // Convert the constant values to string and push to vector + for (int64_t i = 0; i < num_elements; ++i) { + //std::cout << "D:" << std::to_string(values[i]) << std::endl; + constant_values.push_back(std::to_string(values[i])); + } + } + else if (data->dtype.code == kDLInt && data->dtype.bits == 32) { + const int* values = static_cast(data->data); + // // Convert the constant values to string and push to vector + for (int64_t i = 0; i < num_elements; ++i) { + //std::cout << "D:" << std::to_string(values[i]) << std::endl; + constant_values.push_back(std::to_string(values[i])); + } + } + + ExtractedConstArray extracted; + extracted.arr = constant_values; + extracted.dtype = dtype; + extracted.size = num_elements; + + + extracted_constants.insert({const_var_name, extracted}); + const_name_to_constant_->emplace(const_var_name, cn->data); + const_names_->push_back(const_var_name); + + return {output}; } - std::vector VisitExpr_(const CallNode* call) override { + std::vector VisitExpr_(const CallNode* call) final { GenerateBodyOutput ret; if (const auto* func = call->op.as()) { ret = GenerateCompositeFunctionCall(func, call); + } else { ret = GenerateOpCall(call); } @@ -172,12 +291,13 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ return ret.outputs; } - std::string JIT(const std::vector& out) override { - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); + std::string JIT(const std::vector& out) { + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, extracted_constants, out); } private: + // TODO: Fix this to parse composite std::vector GetArgumentNames(const CallNode* call) { std::vector arg_names; for (size_t i = 0; i < call->args.size(); ++i) { @@ -195,9 +315,7 @@ class 
CodegenHeadsail : public MemoizedExprTranslator>, publ using ArgFunType = std::function(const CallNode*)>; static const std::map> op_map = { - {"nn.conv2d", {"headsail_conv2d", Conv2d}}, - {"nn.relu", {"dnnl_relu", Relu}}, - {"add", {"dnnl_binary_op", Add}}, + //{"qnn.conv2d", {"dla_conv2d", Conv2d_bias}}, }; const auto op_name = GetRef(op_node)->name; @@ -209,27 +327,18 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ LOG(FATAL) << "Unsupported op: " << AsText(call->op, false); } + GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, const CallNode* caller) { const auto pattern_name = callee->GetAttr(attr::kComposite); ICHECK(pattern_name.defined()) << "Only functions with composite attribute supported"; - if (pattern_name == "dnnl.conv2d_bias_relu") { - const auto* conv_call = - GetRootCall(callee->body.as(), 2, {"nn.conv2d", "add", "nn.relu"}); - return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller), - Conv2d(conv_call)); - } else if (pattern_name == "dnnl.conv2d_relu") { - const auto* conv_call = GetRootCall(callee->body.as(), 1, - (const std::vector){"nn.conv2d", "nn.relu"}); - return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller), - Conv2d(conv_call)); - } else if (pattern_name == "dnnl.conv2d_bias") { - const auto* conv_call = GetRootCall(callee->body.as(), 1, - (const std::vector){"nn.conv2d", "add"}); - return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller), - Conv2d(conv_call)); - } + if (pattern_name == "headsail.tflite_conv2d_bias_relu") { + const auto* conv_call = GetRootCall(callee->body.as(), 3, + (const std::vector){"qnn.conv2d", "nn.bias_add", "qnn.requantize", "clip"}); + CompositeCallables callables = Conv2d_bias(callee); + return GenerateBody(conv_call, "dla_tvm_qnn_conv2d", GetArgumentNames(caller), callables.static_args); + } LOG(FATAL) << "Unknown composite function:" << pattern_name; } @@ -245,6 +354,8 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ // Make function call with input buffers when visiting arguments ICHECK_GT(func_args.size(), 0); std::ostringstream decl_stream; + + // Wildcard arguments i.e. input, weight, output decl_stream << "(" << func_args[0]; for (size_t i = 1; i < func_args.size(); ++i) { decl_stream << ", " << func_args[i]; @@ -265,6 +376,7 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false); } + // Generate buffers to hold results GenerateBodyOutput ret; for (const auto& out_type : out_types) { this->PrintIndents(); @@ -277,13 +389,12 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ output.size = out_size; output.dtype = GetDtypeString(out_type.as()); output.need_copy = true; - // NOTE: This needs to be extended for int16_t - ret.buffers.push_back("int8_t* " + out + " = (int8_t*)std::malloc(" + + ret.buffers.push_back("int8_t* " + out + " = (int8_t*)malloc(" + std::to_string(out_size) + ");"); ret.outputs.push_back(output); } - // Attach attribute arguments + // Attach attribute arguments, op specific defined by the codegen for (size_t i = 0; i < attribute_args.size(); ++i) { decl_stream << ", " << attribute_args[i]; } @@ -292,8 +403,15 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ return ret; } + /*! + * \brief The accumulated constant name to constant mapping. Shared between all generated + * functions. 
+ */ + std::unordered_map* const_name_to_constant_; + std::unordered_map extracted_constants; + /*! \brief The id of the external dnnl ext_func. */ - std::string ext_func_id_{""}; + std::string ext_func_id_; /*! * \brief The index to track the output buffer. Each kernel will redirect the * output to a buffer that may be consumed by other kernels. @@ -310,7 +428,7 @@ class CodegenHeadsail : public MemoizedExprTranslator>, publ /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; /*! \brief The variable name to constant mapping. */ - Array const_vars_; + Array* const_names_; friend class HeadsailModuleCodegen; }; @@ -323,14 +441,21 @@ class HeadsailModuleCodegen : public CSourceModuleCodegenBase { // Record the external symbol for runtime lookup. auto sid = GetExtSymbol(func); + func_names_.push_back(sid); - CodegenHeadsail builder(sid); + CodegenHeadsail builder(&const_name_to_constant_, &const_names_, sid); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); - return {sid, builder.const_vars_}; + return {sid, const_names_}; } + /*! \brief Returns the accumulated constant name to constant mapping. */ + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; + } + + /*! * \brief The overridden function that will create a CSourceModule. In order * to compile the generated C source code, users need to specify the paths to @@ -344,19 +469,16 @@ class HeadsailModuleCodegen : public CSourceModuleCodegenBase { */ runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; - code_stream_ << "#include \n"; code_stream_ << "#include \n"; // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't // expose it to ordinary users. To make export_library use it, users need to // pass -I${PATH_TO_TVM}/src/runtime/contrib - code_stream_ << "#include \n"; - code_stream_ << "using namespace tvm::runtime;\n"; - code_stream_ << "using namespace tvm::runtime::contrib;\n"; + code_stream_ << "#include \n"; code_stream_ << "\n"; ICHECK(ref->IsInstance()); @@ -365,11 +487,21 @@ class HeadsailModuleCodegen : public CSourceModuleCodegenBase { String sym = std::get<0>(res); Array variables = std::get<1>(res); + std::cout << "Sym: " << sym << std::endl; + + int i = 0; + for (auto x : variables) { + std::cout << i << " | " << "Var: " << x << std::endl; + ++i; + } + // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - // TODO(@manupa-arm): pass the function names to enable system-lib creation - return (*pf)(code, "c", Array{sym}, variables); + //// TODO(@manupa-arm): pass the function names to enable system-lib creation + //return (*pf)(code, "c", Array{sym}, variables); + // Use this if things break + return codegen::CSourceModuleCreate(code, "c", func_names_); } private: @@ -378,6 +510,9 @@ class HeadsailModuleCodegen : public CSourceModuleCodegenBase { * external codegen tools. 
*/ std::ostringstream code_stream_; + Array func_names_; + std::unordered_map const_name_to_constant_; + Array const_names_; }; @@ -388,6 +523,7 @@ runtime::Module HeadsailCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.headsail").set_body_typed(HeadsailCompiler); + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/headsail/codegen_c.h b/src/relay/backend/contrib/headsail/codegen_c.h deleted file mode 100644 index 58194a046a34..000000000000 --- a/src/relay/backend/contrib/headsail/codegen_c.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ -#define TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace contrib { - -struct Output { - std::string name; - std::string dtype; - int size; - bool need_copy; -}; - -struct GenerateBodyOutput { - std::string decl; - std::vector buffers; - std::vector outputs; - Array headers; -}; - -#endif // TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ diff --git a/src/relay/backend/contrib/headsail/codegen_headsail.h b/src/relay/backend/contrib/headsail/codegen_headsail.h new file mode 100644 index 000000000000..82f3c4793281 --- /dev/null +++ b/src/relay/backend/contrib/headsail/codegen_headsail.h @@ -0,0 +1,386 @@ + /* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/codegen_c/codegen_c.h + * \brief The base class for external codegen tools. + */ +#ifndef TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ +#define TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { + +struct Output { + std::string name; + std::string dtype; + int size; + bool need_copy; +}; + +struct ExtractedConstArray { + std::string dtype; + int size; + std::vector arr; +}; + +struct GenerateBodyOutput { + std::string decl; + std::vector buffers; + std::vector outputs; + Array headers; +}; + +class CSourceModuleCodegenBase { + public: + CSourceModuleCodegenBase() = default; + virtual ~CSourceModuleCodegenBase() = default; + + /*! + * \brief Create a runtime module for the external library. For example, it + * could be a CSourceModule that can be directly compiled and linked together + * with a DSOModule, or a json style module that emitts a json artifact that + * is able to be executed by a customized json runtime. + * + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * + * \return A runtime module. 
+ */ + virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0; +}; + +// The base class to generate the declaration functions in C. +class HeadsailCodegenCBase { + public: + virtual ~HeadsailCodegenCBase() {} + + protected: + /*! \brief Print indents using spaces. */ + void PrintIndents() { + for (int i = 0; i < indent_; i++) { + code_stream_ << ' '; + } + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + ICHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! + * \brief Creates a runtime function header + */ + void PrintRuntimeFunctionHeader(std::string func_name) { + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "extern \"C\" {\n"; + code_stream_ << "#endif\n"; + code_stream_ << "TVM_DLL int32_t "; + code_stream_ << func_name << "("; + code_stream_ << "TVMValue* args, "; + code_stream_ << "int* type_code, "; + code_stream_ << "int num_args, "; + code_stream_ << "TVMValue* out_value, "; + code_stream_ << "int* out_type_code) {\n"; + } + + /*! + * \brief Adds a line to convert TVMValue args to DLTensors + */ + void PrintArgToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* arg" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Adds a line to convert TVMValue rets to DLTensors + */ + void PrintRetToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* ret" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Gerenate C code for the external function. + * + * \param func_name The name of the external function. + * \param arg_types Types of arguments represented as string + * + * \code + * + * Array foo_consts; + * + * // An example code for the generated C function. + * int foo_wrapper_(DLTensor* arg0, + * DLTensor* arg1, + * DLTensor* out) { + * foo_((float*)(arg0->data), + * (float*)(arg1->data), + * (float*)(out->data)); + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); + * + * int foo_init_wrapper_(Array arr) { + * foo_consts = arr; + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * + * \endcode + */ + void GenerateBackendCFunc(const std::string& func_name, const std::vector& arg_types, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + // Print signature + code_stream_ << "\n"; + + } + + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + std::vector arg_types; + for (size_t i = 0; i < args.size(); i++) { + arg_types.push_back(GetDtypeString(args[i])); + } + return GenerateBackendCFunc(func_name, arg_types, const_arr_name, outs, pass_dl_tensor); + } + + std::string GenerateConstantArray(const std::string name, const ExtractedConstArray extracted) { + std::ostringstream stream; + stream << " " << extracted.dtype << " " << name << "[" << std::to_string(extracted.size) << "]" << " = {"; + for (const auto &x : extracted.arr) { + stream << x <<", "; + } + stream << "};\n"; + return stream.str(); + } + + /*! + * \brief Emit the code for external runtime. + * + * \param out The outputs. + * + * \return The code string. + */ + virtual std::string JIT(const std::vector& out) = 0; + + /*! 
+ * \brief A common interface that is used by various external runtime to + * generate the wrapper to invoke external kernels. + * + * \param ext_func_id The unique id of an external function. It will be used + * during runtime to pick the correct external function. + * \param args The arguments used by the external function. + * \param buf_decl The declaration of temporary buffers that used to store the + * intermeidate of each external kernel. + * \param body The statements of the external function. + * \param out The name and id pairs for output. + * + * \return The emitted code string. + */ + std::string JitImpl(const std::string& ext_func_id, const Array& args, + const std::vector& buf_decl, + const std::vector& body, const std::string& const_arr_name, + const std::unordered_map& extracted_constants, + const std::vector& outs) { + // Create a declaration for global ndarrays that contain constant data. + code_stream_ << "//This was generated with headsail codegen\n"; + + // Create the signature. For example, it could be: + // void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {} + code_stream_ << "int " << ext_func_id << "("; + + for (const auto& arg : args) { + const auto& dtype_str = GetDtypeString(arg); + code_stream_ << dtype_str << "* " << arg->name_hint() << ", "; + } + for (size_t i = 0; i < outs.size() - 1; ++i) { + code_stream_ << outs[i].dtype << "* out" << i << ", "; + } + code_stream_ << outs.back().dtype << "* out" << outs.size() - 1 << ") {\n"; + + // TODO: Constants here + for (auto const& x : extracted_constants) { + this->PrintIndents(); + code_stream_ << GenerateConstantArray(x.first, x.second); + std::cout<< std::endl; + } + + this->EnterScope(); + + // Function body + for (auto decl : buf_decl) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : body) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + // Copy output + for (size_t i = 0; i < outs.size(); ++i) { + if (!outs[i].need_copy) { + continue; + } + this->PrintIndents(); + code_stream_ << "memcpy(out" << i << ", " << outs[i].name << ", " << outs[i].size + << ");\n"; + } + + // Free buffers + for (size_t i = 0; i < buf_decl.size(); i++) { + this->PrintIndents(); + code_stream_ << "free(buf_" << i << ");\n"; + } + + // Return success + this->PrintIndents(); + code_stream_ << "return 0;\n"; + + this->ExitScope(); + code_stream_ << "}\n"; + + // Create the wrapper to call the ext_func + this->GenerateBackendCFunc(ext_func_id, args, const_arr_name, outs); + return code_stream_.str(); + } + + /*! + * \brief Returns dtype string + * + * \param var Var to get the dtype of + * + * \return The dtype string. + */ + std::string GetDtypeString(const Var& var) { + auto ttype = var->checked_type().as(); + ICHECK(ttype) << "Expect TensorTypeNode"; + return GetDtypeString(ttype); + } + + /*! + * \brief Returns dtype string + * + * \param ttype TensorTypeNode* to get the dtype of + * + * \return The dtype string. 
+ */ + std::string GetDtypeString(const TensorTypeNode* ttype) { + std::string dtype; + if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) { + dtype = "float"; + } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) { + dtype = "half"; + } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) { + dtype = "bfloat"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) { + dtype = "int"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { + dtype = "int64_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 8)) { + dtype = "int8_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLUInt, 8)) { + dtype = "uint8_t"; + } else { + LOG(FATAL) << "Unsupported dtype " << ttype->dtype; + } + + return dtype; + } + + /*! + * \brief Generates the global ndarray pool declaration + * + * \param symobl The Symbol of the current function + * + * \return The created declaration + */ + std::string CreateNDArrayPool(const std::string& symbol) const { + return "tvm::runtime::Array " + symbol + "_consts;"; + } + + /*! + * \brief Generates the reference to the data of a constant ndarray + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created reference + */ + std::string CreateDataReference(const std::string& symbol, size_t const_id) const { + return "(int*)(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)"; + } + + /*! + * \brief Returns the variable name for a constant variable + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created variable name + */ + std::string CreateConstVar(const std::string& symbol, size_t const_id) const { + // tvmgen_default_headsail_const_0 etc... + return symbol + "_const_" + std::to_string(const_id); + } + + /*! \brief The external function source code stream. */ + std::ostringstream code_stream_; + + private: + /*! \brief Indent of the source code. */ + int indent_{0}; +}; + +/*! + * \brief A pass to translate all "Primitive" Relay functions with "Compiler=ccompiler" to + * a \p CSourceModule. + */ +transform::Pass HeadsailCompilerPass(); + +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_HEADSAIL_CODEGEN_C_H_ diff --git a/src/relay/backend/contrib/headsail/genereted_example.c b/src/relay/backend/contrib/headsail/genereted_example.c deleted file mode 100644 index ae5d734993e3..000000000000 --- a/src/relay/backend/contrib/headsail/genereted_example.c +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include -#include -#include - -extern "C" void headsail_call_conv2d_int8(int8_t* headsail_0_in0, int8_t* out0) { - int8* buf_0 = (int8_t*)std::malloc() - - //Call headsail-bsp conv2d layer - buf_0 = conv2d(input, kernels, padding, stride); - -}
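
Reviewer note on the legalization math: the bias and scale folding documented in the LegalizeQnnOpForHeadsail docstring can be sanity-checked with plain scalars. The sketch below uses illustrative numbers, ignores clipping, and leaves out the optional qnn.add branch (so act_scl == 1, sum_scl == 0, dst_zp == 0); it only checks that the DNNL-style form (conv + bias) * o_scl reproduces qnn.requantize applied to the zero-point-corrected convolution plus the original bias.

    # Scalar check of the folding formulas in the LegalizeQnnOpForHeadsail docstring.
    # All values are illustrative.
    rq_in_scl, rq_in_zp = 0.05, 3
    rq_out_scl, rq_out_zp = 0.1, -2

    conv_no_zp = 1234   # OP(SRC, WGH) with the source zero point removed
    zp_term = 57        # OP(src_zp, WGH), the zero-point correction term
    orig_bias = -40

    # Reference: requantize((conv_no_zp - zp_term) + orig_bias)
    ref = (conv_no_zp - zp_term + orig_bias - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp

    # Legalized form: fold everything into a new bias and a single output scale.
    o_scl = rq_in_scl / rq_out_scl
    bias = orig_bias - zp_term - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
    legalized = (conv_no_zp + bias) * o_scl

    assert abs(ref - legalized) < 1e-6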
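
For readers of Conv2d_bias() in codegen.cc, this is a sketch of the layout-dependent static arguments it appends after the four requantize constants, computed here in Python for example NHWC/HWIO shapes. The argument order mirrors the C++ above; the shapes, padding and strides are illustrative values only.

    # Example NHWC input and HWIO kernel, mirroring the indexing in Conv2d_bias().
    ishape = (1, 32, 32, 3)    # N, H, W, C
    wshape = (3, 3, 3, 16)     # KH, KW, IC, OC
    groups, padding, strides = 1, (1, 1, 1, 1), (1, 1)

    static_args = [
        ishape[3], ishape[1], ishape[2], '"HWC"',              # input channels, height, width, layout
        wshape[3], wshape[2], wshape[0], wshape[1], '"HWCK"',  # kernel count, channels, height, width, layout
        groups * wshape[3],                                    # output channels
        padding[0], padding[1], padding[3], padding[2], 0,     # pad top, left, right, bottom, pad value
        strides[0], strides[1],                                # stride x, stride y
        0, 8,                                                  # mac clip, pp clip
    ]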
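
Finally, a minimal sketch of how the registered "headsail" pattern table and legalize_qnn_for_headsail() might be wired into a BYOC partitioning flow. The partition_for_headsail() driver is hypothetical and not part of this patch; MergeComposite, AnnotateTarget, MergeCompilerRegions and PartitionGraph are the standard TVM Relay passes, and get_pattern_table() retrieves the table registered with @register_pattern_table above.

    # Illustrative driver, assuming the registrations in
    # python/tvm/relay/op/contrib/headsail.py have been imported.
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import get_pattern_table
    from tvm.relay.op.contrib.headsail import legalize_qnn_for_headsail

    def partition_for_headsail(mod, params=None):
        """Hypothetical helper: offload supported QNN conv2d patterns to headsail."""
        if params:
            mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
        mod = legalize_qnn_for_headsail(mod)
        seq = tvm.transform.Sequential(
            [
                relay.transform.MergeComposite(get_pattern_table("headsail")),
                relay.transform.AnnotateTarget("headsail"),
                relay.transform.MergeCompilerRegions(),
                relay.transform.PartitionGraph(),
            ]
        )
        with tvm.transform.PassContext(opt_level=3):
            return seq(mod)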