From 03726da3052630d2a67f1c572570a799460d641d Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Tue, 17 May 2022 14:09:03 +0300 Subject: [PATCH] [DNNL] Add TensorRequisite concept Allow to use DNNL runtime in multi instance mode. Thread safe execution of Run() method. Signed-off-by: Alexander Peskov --- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 953 +++++++----------- .../contrib/dnnl/dnnl_tensor_requisite.h | 664 ++++++++++++ src/runtime/contrib/dnnl/dnnl_utils.h | 194 ++++ 3 files changed, 1220 insertions(+), 591 deletions(-) create mode 100644 src/runtime/contrib/dnnl/dnnl_tensor_requisite.h create mode 100644 src/runtime/contrib/dnnl/dnnl_utils.h diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index dc2afecbaf91..b80bcb43bf46 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -32,7 +32,12 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" -#include "dnnl.hpp" + +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_tensor_requisite.h" namespace tvm { namespace runtime { @@ -40,59 +45,85 @@ namespace contrib { using namespace tvm::runtime; using namespace tvm::runtime::json; +using namespace utils; class DNNLJSONRuntime : public JSONRuntimeBase { - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + : JSONRuntimeBase(symbol_name, graph_json, const_names), + next_unique_eid_offset_(data_entry_.size()), + run_arg_eid_(input_var_eid_) { + for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e)); + } - const char* type_key() const { return "dnnl_json"; } + const char* type_key() const override { return "dnnl_json"; } void Init(const Array& consts) override { - BuildEngine(); - ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; // Setup constants entries for weights. SetupConstants(consts); + BuildEngine(); } - void Run() override { - // Fill in the input buffers. - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto eid = EntryID(input_nodes_[i], 0); - // TODO(@comaniac): Support other data lengths. - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; - size_t buffer_size = GetDataSize(*data_entry_[eid]); - write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); + /* Unused stub implementation */ + void Run() override { LOG(FATAL) << "Unreachable code"; } + + /* Thread safe implementation of Run. 
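     All per-call state (the dnnl::memory objects produced by the memory solver created from
     tensor_registry_) lives on the stack of this method, so concurrent calls do not share
     mutable buffers.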
Keep runtime instance immutable */ + void Run(const TVMArgs& args) const { + auto arg_data_provider = makeIODataProvider(args); + auto mem_solver = tensor_registry_.makeSolver(arg_data_provider); + // Execute primitives one by one + for (const auto& act : net_) { + auto prim = std::get<0>(act); + auto arg_reqs = std::get<1>(act); + + // Find proper dnnl::memory buffers + std::unordered_map mem_args; + for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second); + + prim.execute(stream_, mem_args); } + } + + /* Override GetFunction to reimplement Run method */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; - // Invoke the engine through intepreting the stream. - for (size_t i = 0; i < net_.size(); ++i) { - net_.at(i).execute(stream_, net_args_.at(i)); + ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + << "Found mismatch in the number of provided data entries and required."; + + Run(args); + }); + } else { + return JSONRuntimeBase::GetFunction(name, sptr_to_self); } - stream_.wait(); - - // Read output buffers. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto eid = EntryID(outputs_[i]); - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; - size_t buffer_size = GetDataSize(*data_entry_[eid]); - read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); + } + + /* Same as makeInitDataProvider but in case of InputOutput return real DLTensor */ + TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args) const { + auto extract_dl_tensor = [](const TVMArgValue& val) -> const DLTensor* { + ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor"; + return val.IsObjectRef() ? val.operator NDArray().operator->() + : val.operator DLTensor*(); + }; + + std::map io_map; // eid to dl tensor map + for (size_t i = 0; i < run_arg_eid_.size(); i++) { + io_map[run_arg_eid_[i]] = extract_dl_tensor(args[i]); } + + // lambda with captured IO data handlers + return [io_map](uint32_t eid) -> const DLTensor* { return io_map.at(eid); }; } private: - // Build up the engine based on the input graph. - - std::map elt_name2algo{ + const std::map elt_name2algo{ {"abs", dnnl::algorithm::eltwise_abs}, {"exp", dnnl::algorithm::eltwise_exp}, {"log", dnnl::algorithm::eltwise_log}, @@ -106,62 +137,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"clip", dnnl::algorithm::eltwise_clip}, }; - std::map layout_dict{ - {"", tag::any}, - {"NCW", tag::ncw}, - {"NWC", tag::nwc}, - {"OIW", tag::oiw}, - {"GOIW", tag::goiw}, - {"NCHW", tag::nchw}, - {"NHWC", tag::nhwc}, - {"OIHW", tag::oihw}, - {"GOIHW", tag::goihw}, - {"NCDHW", tag::ncdhw}, - {"NDHWC", tag::ndhwc}, - {"OIDHW", tag::oidhw}, - {"GOIDHW", tag::goidhw}, - {"IOHW", tag::iohw}, - {"GIOHW", tag::giohw}, - {"IODHW", tag::iodhw}, - {"GIODHW", tag::giodhw}, - - // Blocking layout. 
- {"NCW8c", tag::nCw8c}, - {"NCW16c", tag::nCw16c}, - {"OIW16i16o", tag::OIw8i8o}, - {"OIW16i16o", tag::OIw16i16o}, - {"OWI8o", tag::Owi8o}, - {"OWI16o", tag::Owi16o}, - {"NCHW4c", tag::nChw4c}, - {"NCHW8c", tag::nChw8c}, - {"NCHW16c", tag::nChw16c}, - {"OIHW8i8o", tag::OIhw8i8o}, - {"IOHW8i8o", tag::any}, - {"OIHW16i16o", tag::OIhw16i16o}, - {"IOHW16i16o", tag::IOhw16i16o}, - {"GOIHW4i4o", tag::gOIhw4i4o}, - {"GOIHW8i8o", tag::gOIhw8i8o}, - {"GOIHW16i16o", tag::gOIhw16i16o}, - {"OHWI8o", tag::Ohwi8o}, - {"OHWI16o", tag::Ohwi16o}, - {"OHWI32o", tag::Ohwi32o}, - {"OHWI48o", tag::Ohwi48o}, - {"OHWI64o", tag::Ohwi64o}, - {"GOIHW8g", tag::Goihw8g}, - {"GOIHW16g", tag::Goihw16g}, - {"NCDHW8c", tag::nCdhw8c}, - {"NCDHW16c", tag::nCdhw16c}, - {"OIDHW16i16o", tag::OIdhw16i16o}, - {"IODHW16i16o", tag::IOdhw16i16o}, - {"OIDHW8i8o", tag::OIdhw8i8o}, - {"IODHW8i8o", tag::any}, - {"ODHWI8o", tag::Odhwi8o}, - {"ODHWI16o", tag::Odhwi16o}, - {"ODHWI32o", tag::Odhwi32o}, - {"ODHWI48o", tag::Odhwi48o}, - {"ODHWI64o", tag::Odhwi64o}, - }; - bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) { // Define RegExp. std::regex bias_add_pat(".*_bias.*"); @@ -186,62 +161,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return std::regex_match(op_name, bias_add_pat) ? true : false; } - dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims, std::string layout) { - std::vector axis = { - 'N', 'C', 'O', 'I', 'D', 'H', 'W', - }; - dnnl::memory::dims out_dims; - std::string::iterator t = layout.begin(); - // Remove numbers in layout string to match the size of input_dims - while (t != layout.end()) { - if (*t >= '0' && *t <= '9') { - layout.erase(t); - } else { - t++; - } - } - // Push the correct shapes of each axis into the output_dims - for (auto a : axis) { - dnnl::memory::dim shape = 1; - if (layout.find(a) != std::string::npos) { - shape *= input_dims[layout.find(a)]; - char lower_a = std::tolower(a); - if (layout.find(lower_a) != std::string::npos) { - shape *= input_dims[layout.find(lower_a)]; - } - out_dims.push_back(shape); - } - } - // Multiply O and I with G, respectively - if (layout.find("G") != std::string::npos) { - dnnl::memory::dim G = 1; - if (layout.find("g") != std::string::npos) { - G = input_dims[layout.find("g")] * input_dims[layout.find("G")]; - } else { - G = input_dims[layout.find("G")]; - } - out_dims[0] *= G; - out_dims[1] *= G; - } - return out_dims; - } - - dnnl::memory::dims TransformStr2Dims(std::vector strs, bool dilates = false) { - dnnl::memory::dims out_dims; - if (dilates) { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str) - 1; }); - } else { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str); }); - } - return out_dims; - } - + // Build up the engine based on the input graph. void BuildEngine() { engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); stream_ = dnnl::stream(engine_); + std::set io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end()); + tensor_registry_ = TensorRegistry(engine_, io_eid_set); + std::regex conv_pat(".*conv[1-3]d.*"); std::regex deconv_pat(".*deconv[1-3]d.*"); std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); @@ -283,582 +210,426 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } } - // Bind a JSON graph node entry to a DNNL memory. 
- dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc, - size_t offset = 0) { - auto eid = EntryID(entry); - if (entry_out_mem_.count(eid) == 0) { - return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset); - } - return entry_out_mem_[eid].first; - } - - // Bind a JSON graph node entry to a given DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem, - size_t offset = 0) { - auto eid = EntryID(entry); - // Since the DNNL memory has been created before calling this function, we assume the entry - // has not yet been bound to the other DNNL memory; otherwise it may have memory leak. - ICHECK_EQ(entry_out_mem_.count(eid), 0); - - // TODO(@comanic): Support other data types (i.e., int8). - auto data_node = nodes_[entry.id_]; - auto dltype = data_node.GetOpDataType()[entry.index_]; - ICHECK_EQ(dltype.bits, 32); - - entry_out_mem_[eid] = {mem, offset}; - return entry_out_mem_[eid].first; - } - void Convolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported data layout for conv: " << data_layout; - } - - if (layout_dict.find(kernel_layout) == layout_dict.end()) { - layout_dict.insert({kernel_layout, tag::any}); - LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout - << ", transfer to tag::any"; - } - - // Memory shapes. 
- dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - weights_dims_[0] = channels; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1; - } - - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - if (kernel_layout == "OIHW") { - kernel_layout.insert(0, "G"); - } + auto src_tr = getInput(nid, 0); + auto wgh_tr = getInput(nid, 1); + auto dst_tr = getOutput(nid, 0); + auto bias_tr = has_bias ? getInput(nid, 2) : getInput(nid, -1); + auto strides = getAttr>(node, "strides"); + auto dilates = getAttr>(node, "dilation"); + auto padding = getAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = getAttr(node, "groups"); + auto src_layout = getAttr(node, "data_layout"); + auto dst_layout = getAttr(node, "out_layout"); + auto wgh_layout = getAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.treatAs(src_layout); + dst_tr = dst_tr.treatAs(dst_layout); + wgh_tr = wgh_tr.treatAs(wgh_layout); + + // Should support G mixed with O. Like { G*O, I, H, W } + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.reshape(w_dims); } - // Memory descriptions. - auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); - auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.reshape({dst_tr.dims()[1]}); // Conv description. - auto conv_desc = - has_bias ? 
dnnl::convolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, - conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, - dnnl::algorithm::convolution_direct, conv_src_md, - conv_weights_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r); + auto conv_desc = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, + src_tr.layoutAny().desc(), wgh_tr.layoutAny().desc(), bias_tr.layoutAny().desc(), + dst_tr.layoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); - // Push to the network. - auto conv = dnnl::convolution_forward(conv_prim_desc); - net_.push_back(conv); - - // Data memory. - auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md); - - // Weight memory. - auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_prim_desc.weights_desc()); - - // Output memory. - auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); + src_tr = src_tr.requestLayout(conv_prim_desc.src_desc()); + wgh_tr = wgh_tr.requestLayout(conv_prim_desc.weights_desc()); + dst_tr = dst_tr.requestLayout(conv_prim_desc.dst_desc()); + bias_tr = bias_tr.requestLayout(conv_prim_desc.bias_desc()); - // Bias memory. - auto conv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, conv_bias_memory); + auto scratchpad_tr = TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc()); - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } else { - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } + submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Deconvolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? 
std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_out_padding = - node.GetAttr>("output_padding"); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported data layout for deconv: " << data_layout; + auto src_tr = getInput(nid, 0); + auto wgh_tr = getInput(nid, 1); + auto dst_tr = getOutput(nid, 0); + auto bias_tr = has_bias ? getInput(nid, 2) : getInput(nid, -1); + + auto strides = getAttr>(node, "strides"); + auto dilates = getAttr>(node, "dilation"); + auto padding = getAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = getAttr(node, "groups"); + auto src_layout = getAttr(node, "data_layout"); + auto dst_layout = getAttr(node, "out_layout"); + auto wgh_layout = getAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // TODO(@apeskov): WA. conv3dTranspose uses wrong layout specifier. IO instead of OI. + auto wgh_logic_layout = utils::DefaultLogicLayoutFor(wgh_layout); + if (wgh_logic_layout == "OIDHW") wgh_logic_layout = "IODHW"; + if (wgh_logic_layout == "GOIDHW") wgh_logic_layout = "GIODHW"; + + // Take into account provided layout strings + src_tr = src_tr.treatAs(src_layout); + dst_tr = dst_tr.treatAs(dst_layout); + wgh_tr = wgh_tr.treatAs(wgh_layout, wgh_logic_layout); + + // Should support G mixed with O. Like { G*O, I, H, W } + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.reshape(w_dims); } - if (layout_dict.find(kernel_layout) == layout_dict.end()) { - layout_dict.insert({kernel_layout, tag::any}); - LOG(WARNING) << "Unregistered kernel layout for deconv: " << data_layout - << ", transfer to tag::any"; - } + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.reshape({dst_tr.dims()[1]}); - // Memory shapes. 
- dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - // legalize shape IOHW with layout OIHW - if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) { - std::swap(weights_dims_[0], weights_dims_[1]); - if (kernel_layout.find("OI") == 0) { - kernel_layout.replace(kernel_layout.find("OI"), 2, "IO"); - } - } - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim OP = out_padding[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP; - } - - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - } - - // Memory descriptions. - auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); - auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); - auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); - - // Transposed covn2d description. - auto deconv_desc = - has_bias ? dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_bias_md, deconv_dst_md, - strides_dims, dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_dst_md, strides_dims, dilates_dims, - padding_dims_l, padding_dims_r); + // Conv description. + auto deconv_desc = dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + src_tr.layoutAny().desc(), wgh_tr.layoutAny().desc(), bias_tr.layoutAny().desc(), + dst_tr.layoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); - // Push to the network. - auto deconv = dnnl::deconvolution_forward(deconv_prim_desc); - net_.push_back(deconv); + src_tr = src_tr.requestLayout(deconv_prim_desc.src_desc()); + wgh_tr = wgh_tr.requestLayout(deconv_prim_desc.weights_desc()); + dst_tr = dst_tr.requestLayout(deconv_prim_desc.dst_desc()); + bias_tr = bias_tr.requestLayout(deconv_prim_desc.bias_desc()); - // Data memory. - auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md); + auto scratchpad_tr = TensorRequisite::AsIs(deconv_prim_desc.scratchpad_desc()); - // Weight memory. 
- auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_prim_desc.weights_desc()); - - // Output memory. - auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc()); - - // Bias memory. - auto deconv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, deconv_bias_memory); - - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_BIAS, deconv_bias_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } else { - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } + submit(dnnl::deconvolution_forward(deconv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Dense(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim OC = out_shape[1]; - - // Memory shapes. - dnnl::memory::dims data_dims = input_shape; - dnnl::memory::dims weight_dims = weight_shape; - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = out_shape; - - // Memory descriptions. - auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc}); - auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc}); + auto src_tr = getInput(nid, 0); + auto wgh_tr = getInput(nid, 1); + auto dst_tr = getOutput(nid, 0); + auto bias_tr = has_bias ? getInput(nid, 2) : getInput(nid, -1); + + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.reshape({dst_tr.dims()[1]}); // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); + auto dense_desc = dnnl::inner_product_forward::desc( + dnnl::prop_kind::forward_inference, src_tr.layoutAny().desc(), wgh_tr.layoutAny().desc(), + bias_tr.layoutAny().desc(), dst_tr.layoutAny().desc()); // Enable elementwise post-ops. auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); - auto dense = dnnl::inner_product_forward(dense_prim_desc); - net_.push_back(dense); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - auto weight_memory = BindDNNLMemory(weight_entry, weight_md); + src_tr = src_tr.requestLayout(dense_prim_desc.src_desc()); + wgh_tr = wgh_tr.requestLayout(dense_prim_desc.weights_desc()); + dst_tr = dst_tr.requestLayout(dense_prim_desc.dst_desc()); + bias_tr = bias_tr.requestLayout(dense_prim_desc.bias_desc()); - // Bias memory. 
- auto bias_memory = dnnl::memory(bias_md, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float)); - } + auto scratchpad_tr = TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc()); - // Output memory. - auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_WEIGHTS, weight_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST, dst_memory}}); + submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - auto gamma_entry = node.GetInputs()[1]; - auto beta_entry = node.GetInputs()[2]; - auto mean_entry = node.GetInputs()[3]; - auto variance_entry = node.GetInputs()[4]; - dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dim IC = data_shape[1]; - float epsilon = std::stof(node.GetAttr>("epsilon")[0]); + auto src_tr = getInput(nid, 0); + auto gamma_tr = getInput(nid, 1); + auto beta_tr = getInput(nid, 2); + auto mean_tr = getInput(nid, 3); + auto var_tr = getInput(nid, 4); + auto dst_tr = getOutput(nid, 0); + + auto axis = getAttr(node, "axis"); + auto epsilon = getAttr(node, "epsilon"); + auto center = getAttr(node, "center"); + auto scale = getAttr(node, "scale"); - // Memory description. - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case"; - // BN description. auto bn_desc = dnnl::batch_normalization_forward::desc( - dnnl::prop_kind::forward_inference, data_md, epsilon, + dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon, dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift); auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_); - auto bn = dnnl::batch_normalization_forward(bn_prim_desc); - net_.push_back(bn); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc()); - auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc()); - - // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but - // assign an offset to beta data for runtime serialization. 
- auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0); - BindDNNLMemory(beta_entry, weight_memory, IC); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_DST, out_memory}, - {DNNL_ARG_SCALE_SHIFT, weight_memory}, - {DNNL_ARG_MEAN, mean_memory}, - {DNNL_ARG_VARIANCE, variance_memory}}); + + // Concatenate scale and shift tensors + auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(), genUniqueEid()); + auto sc_sh_dims = scale_shift_tr.dims(); + ICHECK(sc_sh_dims.size() == 2); + ICHECK(sc_sh_dims[0] == 2); + sc_sh_dims[0] /= 2; + auto scale_tr = scale_shift_tr.crop(sc_sh_dims, {0, 0}).squeeze(); + auto shift_tr = scale_shift_tr.crop(sc_sh_dims, {1, 0}).squeeze(); + + auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) { + dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc()); + submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); + }; + + register_copy(gamma_tr, scale_tr); + register_copy(beta_tr, shift_tr); + + submit(dnnl::batch_normalization_forward(bn_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_DST, dst_tr}, + {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}, + {DNNL_ARG_MEAN, mean_tr}, + {DNNL_ARG_VARIANCE, var_tr}}); } void Pooling(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + auto src_tr = getInput(nid, 0); + auto dst_tr = getOutput(nid, 0); + // Setup attributes. - auto data_entry = node.GetInputs()[0]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - std::vector str_kernel = node.GetAttr>("pool_size"); - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_dilates = node.GetAttr>("dilation"); - std::string layout = node.GetAttr>("layout")[0]; - - // Check layout. - if (layout_dict.find(layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported layout for pooling: " << layout; - } + auto strides = getAttr>(node, "strides"); + auto dilates = getAttr>(node, "dilation"); + auto padding = getAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto kernel = getAttr>(node, "pool_size"); + auto src_layout = getAttr(node, "layout"); + auto dst_layout = getAttr(node, "out_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.treatAs(src_layout); + dst_tr = dst_tr.treatAs(dst_layout); // Attributes related to AvgPool if (algo == dnnl::algorithm::pooling_avg) { - int int_countpad = std::stoi(node.GetAttr>("count_include_pad")[0]); - bool count_include_pad = int_countpad != 0 ? true : false; - algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding - : dnnl::algorithm::pooling_avg_exclude_padding; + auto include_pad = getAttr(node, "count_include_pad"); + algo = include_pad ? 
dnnl::algorithm::pooling_avg_include_padding + : dnnl::algorithm::pooling_avg_exclude_padding; } - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout); - dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout); - dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel); - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - - // Memory descriptions. - auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]); - auto pool_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); - // Pooling description. - auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo, - pool_src_md, pool_dst_md, strides_dims, - kernel_dims, padding_dims_l, padding_dims_r); - - auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc, engine_, true); - auto pool = dnnl::pooling_forward(pool_prim_desc); - net_.push_back(pool); + auto pool_desc = dnnl::pooling_v2_forward::desc( + dnnl::prop_kind::forward_inference, algo, src_tr.desc(), //<= Do not use any for src tensor + dst_tr.layoutAny().desc(), strides, kernel, dilates, padding_l, padding_r); + auto pool_prim_desc = dnnl::pooling_v2_forward::primitive_desc(pool_desc, engine_); - // Memories. - auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md); + src_tr = src_tr.requestLayout(pool_prim_desc.src_desc()); + dst_tr = dst_tr.requestLayout(pool_prim_desc.dst_desc()); - auto pool2d_dst_memory = BindDNNLMemory(out_entry, pool_prim_desc.dst_desc()); + auto scratchpad_tr = TensorRequisite::AsIs(pool_prim_desc.scratchpad_desc()); - // Bind memory buffers. 
- net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST, pool2d_dst_memory}}); + submit(dnnl::pooling_v2_forward(pool_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCRATCHPAD, scratchpad_tr}}); } void Eltwise(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); - auto algo = elt_name2algo[op_name]; + auto algo = elt_name2algo.at(op_name); + + auto src_tr = getInput(nid, 0); + auto dst_tr = getOutput(nid, 0); - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); float alpha = 0., beta = 0.; if (op_name == "clip") { - alpha = std::stof(node.GetAttr>("a_min")[0]); - beta = std::stof(node.GetAttr>("a_max")[0]); + alpha = getAttr(node, "a_min"); + beta = getAttr(node, "a_max"); } else if (op_name == "nn.leaky_relu") { - alpha = std::stof(node.GetAttr>("alpha")[0]); + alpha = getAttr(node, "alpha"); } - auto elt_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, alpha, beta); + auto elt_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, + src_tr.desc(), alpha, beta); auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); - ICHECK(data_md == elt_prim_desc.dst_desc()); - - auto elt = dnnl::eltwise_forward(elt_prim_desc); - net_.push_back(elt); + ICHECK(src_tr.desc() == elt_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Softmax(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - int axis = std::stoi(node.GetAttr>("axis")[0]); + auto src_tr = getInput(nid, 0); + auto dst_tr = getOutput(nid, 0); + + auto axis = getAttr(node, "axis"); if (axis < 0) { - axis = shape.size() + axis; + axis = src_tr.dims().size() + axis; } - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); auto softmax_desc = - dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis); + dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, src_tr.desc(), axis); auto softmax_prim_desc = dnnl::softmax_forward::primitive_desc(softmax_desc, engine_); - ICHECK(data_md == softmax_prim_desc.dst_desc()); - - auto softmax = dnnl::softmax_forward(softmax_prim_desc); - net_.push_back(softmax); + ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + submit(dnnl::softmax_forward(softmax_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + ICHECK_EQ(node.GetInputs().size(), 2U); // Memory and compute description. 
- std::vector data_dims; - std::vector data_mds; - std::vector data_memories; + auto lhs_tr = getInput(nid, 0); + auto rhs_tr = getInput(nid, 1); + auto dst_tr = getOutput(nid, 0); - ICHECK_EQ(node.GetInputs().size(), 2U); - for (auto entry : node.GetInputs()) { - auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + lhs_tr = lhs_tr.broadcast(dst_tr.dims()); + rhs_tr = rhs_tr.broadcast(dst_tr.dims()); - data_dims.push_back(data_shape); - data_mds.push_back(data_md); - data_memories.push_back(BindDNNLMemory(entry, data_md)); - } - ICHECK(data_dims[0] == data_dims[1]); - auto out_md = data_mds[0]; - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); - - auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md); + auto binary_desc = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(), dst_tr.desc()); auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_); - auto binary = dnnl::binary(binary_prim_desc); - net_.push_back(binary); - net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]}, - {DNNL_ARG_SRC_1, data_memories[1]}, - {DNNL_ARG_DST, out_memory}}); + submit(dnnl::binary(binary_prim_desc), + {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST, dst_tr}}); } - // Read from DNNL memory (+offset) and write to the handle. - inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* src = static_cast(mem.get_data_handle()); - std::copy(src + offset, src + offset + size, static_cast(handle)); + TensorRequisite getInput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetInputs().size()); + auto data_entry = node.GetInputs()[idx]; + + auto shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + auto dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto eid = node_row_ptr_[data_entry.id_] + data_entry.index_; + auto const_dl_tensor = data_entry_[eid]; + + DLTensor dl_tensor = {}; + if (const_dl_tensor) { + // Const data are present + dl_tensor = *const_dl_tensor; + eid = TensorRequisite::UNDEFINED_EID; + } else { + dl_tensor.dtype = dtype; + dl_tensor.ndim = shape.size(); + dl_tensor.shape = shape.data(); + } + + return TensorRequisite::AsIs(dl_tensor, eid); } - // Read from the handle and write to DNNL memory (+offset). - inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* dst = static_cast(mem.get_data_handle()); - std::copy(reinterpret_cast(handle), reinterpret_cast(handle) + size, - dst + offset); + TensorRequisite getOutput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetNumOutput()); + + auto shape = node.GetOpShape()[idx]; + auto dtype = node.GetOpDataType()[idx]; + auto eid = node_row_ptr_[nid] + static_cast(idx); + + ICHECK(data_entry_[eid] == nullptr); + + DLTensor dl_tensor = {}; + dl_tensor.dtype = dtype; + dl_tensor.ndim = shape.size(); + dl_tensor.shape = shape.data(); + + auto tr = TensorRequisite::AsIs(dl_tensor, eid); + return tr.backward(); } - // Generate DNNL memory description and infer the data layout by the given shape. 
- inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) { - dnnl::memory::desc data_md; - switch (shape.size()) { - case 2: - data_md = dnnl::memory::desc({shape, dtype, tag::ab}); - break; - case 3: - data_md = dnnl::memory::desc({shape, dtype, tag::abc}); - break; - case 4: - data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); - break; - case 5: - data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); - break; - default: - LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); - break; + // Helper function to register primitive into execution queue + void submit(const dnnl::primitive& prim, + const std::unordered_map& tr_args) { + // Register all provided TR arguments + std::unordered_map prim_arg_id; + TensorRegistry::ActionQue post_prim_actions; + for (const auto& kvp : tr_args) { + const auto& key = kvp.first; + const auto& tr = kvp.second; + + if (!tr.defined()) continue; // empty arg is admitted. Just skip it + auto arg_id = tensor_registry_.registerTR(tr, tr.isReversed() ? &post_prim_actions : &net_); + prim_arg_id[key] = arg_id; } - return data_md; + + // Register main primitive + net_.push_back({prim, prim_arg_id}); + + // Register post actions + net_.insert(net_.end(), post_prim_actions.begin(), post_prim_actions.end()); } + uint32_t genUniqueEid() { return next_unique_eid_offset_++; } + /* The dnnl engine. */ dnnl::engine engine_; /* The dnnl stream. */ dnnl::stream stream_; /* The network layers that are represented in dnnl primitives. */ - std::vector net_; - /* The memory that is consumed by arguments. */ - std::vector> net_args_; - /* The entry ID to its corresponding output memory. */ - std::unordered_map> entry_out_mem_; + TensorRegistry::ActionQue net_; + /* Storage for all memory objects */ + TensorRegistry tensor_registry_; + /* Generator of new unique eid which doesn't match with existing data entry */ + uint32_t next_unique_eid_offset_; + /* Map of Run arg idx to corresponding eid */ + std::vector run_arg_eid_; }; runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h new file mode 100644 index 000000000000..3e4b41e29353 --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h @@ -0,0 +1,664 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file + * \brief Helper TR wrapper to handle tensors processing + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace utils; + +/** + * Helper object to simplify handling of tensor + * + * Allow to specify original source tensor and future actions which should be applied to it. + * Can be treated as sequence of reordering or reinterpretation. Finally TR can be solved as + * proper interpretation of source memory buffer, or sequence of dnnl::reorder operators which + * will provide desired data. + * + * Example: + * ``` + * DLTensor src_mem = {5, 2, 128, 128, 8} // 5D tensor + * + * // describe sequence of layout transformation + * auto tr = TensorRequisite.AsIs(src_mem, eid); // 5D + * tr = tr.treatAs("ABCD8b"); // 4D + * tr = tr.permute({0, 2, 3, 1}); // permute axes NCHW -> NHWC + * tr = tr.crop({1, 128, 128, 16}, {0, 0, 0}); // extract first batch + * tr = tr.squeeze(); + * + * // register TR + * TensorRegistry t_reg; + * auto t_id = t_reg.register(tr); + * t_reg.finalize(); + * + * // Get final dnnl::memory object + * auto mem = t_id.solve(t_id); + * ``` + * + * @note: Empty TR object allow any manipulation. Empty TR will be returned. + */ +class TensorRequisite { + public: + using TID = uint32_t; + static constexpr TID UNDEFINED_EID = std::numeric_limits::max() - 1; + + TensorRequisite() {} + + static TensorRequisite AsIs(const DLTensor& tensor, TID id = UNDEFINED_EID) { + auto res = AsIs(utils::convert2dnnl(tensor), id); + + if (tensor.data != nullptr) { + // TODO(@apeskov): should avoid creation of temp dnnl::engine. 
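      // The temporary CPU engine only lets dnnl::memory wrap the already existing constant
      // buffer; constructing a memory object from an external handle does not copy the data.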
+ auto eng = dnnl::engine(dnnl::engine::kind::cpu, 0); + res.mem_ = dnnl::memory(res.t_desc_, eng, tensor.data); + } + + return res; + } + + static TensorRequisite AsIs(const dnnl::memory::desc& desc, TID id = UNDEFINED_EID) { + TensorRequisite res; + res.t_desc_ = desc; + res.orig_ = {}; + res.reinterpret_ = false; + res.mem_ = {}; + res.eid_ = id; + res.reverse_data_flow_ = false; + + return res; + } + + /** return logical shape of tensor */ + dnnl::memory::dims dims() const { return t_desc_.dims(); } + + /** return data type of tensor */ + dnnl::memory::data_type data_type() const { return t_desc_.data_type(); } + + /** return tensor desc */ + dnnl::memory::desc desc() const { return t_desc_; } + + /** Make TR with backward dataflow */ + TensorRequisite backward() const { + if (!defined()) return *this; + ICHECK(orig_ == nullptr); + return {t_desc_, orig_, reinterpret_, mem_, eid_, true}; + } + + /** Produce tensor with permuted axes */ + TensorRequisite permute(const std::vector& permutation) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.permute_axes(permutation); + return {desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Produce tensor with reinterpret data of original tr */ + TensorRequisite reshape(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(shape); + return {desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Produce tensor with broadcasted values */ + TensorRequisite broadcast(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + ICHECK(!reverse_data_flow_); + + auto orig = std::make_shared(*this); + + // numpy like broadcast + auto extended_dims = t_desc_.dims(); + auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(), 1); + extended_dims.insert(extended_dims.begin(), one_filled.begin(), one_filled.end()); + auto desc = t_desc_.reshape(extended_dims); + for (size_t i = 0; i < extended_dims.size(); i++) { + if (extended_dims[i] == shape[i]) continue; + ICHECK(extended_dims[i] == 1); + ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); + + desc.data.dims[i] = shape[i]; + desc.data.padded_dims[i] = shape[i]; + desc.data.format_desc.blocking.strides[i] = 0; + } + + // reinterpret memory buffer with new strides + return {desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Produce tensor with sub memory view (ROI) */ + TensorRequisite crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.submemory_desc(shape, offset); + return {desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Produce tensor with squeeze shape */ + TensorRequisite squeeze(const dnnl::memory::dims& dims_to_squeeze = {}) const { + if (!defined()) return *this; // nothing for empty TR + + dnnl::memory::dims squeezed_dims; + if (dims_to_squeeze.empty()) { + for (auto d : t_desc_.dims()) + if (d != 1) squeezed_dims.push_back(d); + } else { + for (size_t i = 0; i < t_desc_.dims().size(); i++) + if 
(std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) == dims_to_squeeze.end()) + squeezed_dims.push_back(t_desc_.dims()[i]); + } + + if (squeezed_dims.empty()) squeezed_dims = {1}; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(squeezed_dims); + return {desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Produce tensor with specified layout descriptor */ + TensorRequisite requestLayout(dnnl::memory::desc desc) const { + if (!defined()) return *this; // nothing for empty TR + + // If it's the same desc just return self + if (desc == t_desc_) return *this; + + ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with " + "presented shape"; + + auto orig = std::make_shared(*this); + return {desc, orig, false, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** + * Treat tensor as described in layout string + * + * Limitation: blocking dims are always dense + * blocking dims are innermost + * blocking dims are in natural order + * + * NC8cHW4h4cD is not valid tensor in terms of DNNL. + * proper description is only like this NCHWD_8c4h4c. First part is outer dims, second part is + * innermost with digits. + */ + TensorRequisite treatAs(const std::string& layout, std::string layout_logic = "") const { + if (layout_logic.empty()) layout_logic = utils::DefaultLogicLayoutFor(layout); + + const auto origin_dims = dims(); + + // split layout string to tokens {size, tag} like {16, 'C'}, {4, 'O'} + std::vector> layout_tokens; + for (auto it = layout.begin(); it != layout.end();) { + auto start = it; + while (std::isdigit(*it)) it++; + int blk_size = start == it ? -1 : std::stoi(std::string{start, it}); + layout_tokens.push_back({blk_size, std::toupper(*it)}); + it++; + } + + // check applicability of layout + auto it = layout_tokens.begin(); + while (it != layout_tokens.end() && it->first == -1) it++; + int rank = std::distance(layout_tokens.begin(), it); + while (it != layout_tokens.end()) { + ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. 
" + << "But received layout is " << layout; + it++; + } + + ICHECK_EQ(layout_tokens.size(), origin_dims.size()); + ICHECK_EQ(rank, layout_logic.size()) << layout; + + std::vector> outermost_tokens(layout_tokens.begin(), + layout_tokens.begin() + rank); + std::vector> innermost_tokens(layout_tokens.begin() + rank, + layout_tokens.end()); + // define dim resulting dim positions + std::map dim_position_by_tag; + for (size_t i = 0; i < layout_logic.size(); i++) + dim_position_by_tag[std::toupper(layout_logic[i])] = i; + + // Construct resulting desc by modifying original one + dnnl::memory::desc res_desc = t_desc_; + + memset(&res_desc.data.format_desc.blocking, 0, sizeof(res_desc.data.format_desc.blocking)); + std::fill(res_desc.data.dims, res_desc.data.dims + DNNL_MAX_NDIMS, 0); + std::fill(res_desc.data.padded_dims, res_desc.data.padded_dims + DNNL_MAX_NDIMS, 0); + + res_desc.data.ndims = rank; + res_desc.data.format_desc.blocking.inner_nblks = innermost_tokens.size(); + + auto res_dims = res_desc.data.dims; + auto res_strides = res_desc.data.format_desc.blocking.strides; + auto res_inner_blks = res_desc.data.format_desc.blocking.inner_blks; + auto res_inner_idxs = res_desc.data.format_desc.blocking.inner_idxs; + + std::fill(res_dims, res_dims + rank, 1); + + int orig_dim_idx = 0; + for (const auto& p : outermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + + auto result_dim_position = dim_position_by_tag[tag]; + res_dims[result_dim_position] *= dim_size; + res_strides[result_dim_position] = t_desc_.data.format_desc.blocking.strides[orig_dim_idx]; + orig_dim_idx++; + } + for (const auto& p : innermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + auto result_dim_position = dim_position_by_tag[tag]; + ICHECK_EQ(p.first, dim_size) + << "Blocking layout is not applicable to tensor with shape: " << origin_dims + << ". Requested layout is " << layout; + + res_dims[result_dim_position] *= dim_size; + *res_inner_blks++ = dim_size; + *res_inner_idxs++ = result_dim_position; + orig_dim_idx++; + } + + // Assume tensor is dense. There is no additional padding. + std::copy(res_desc.data.dims, res_desc.data.dims + rank, res_desc.data.padded_dims); + + auto orig = std::make_shared(*this); + return {res_desc, orig, true, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** + * Produce tensor with unspecified layout + * Cannot be registered in TensorRegistry. Only for querying DNNL for preferred layouts. 
+ */ + TensorRequisite layoutAny() const { + auto orig = std::make_shared(*this); + // Recreate tensor desc with layout 'any' + dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(), dnnl::memory::format_tag::any}; + return {any_desc, orig, false, {}, UNDEFINED_EID, reverse_data_flow_}; + } + + /** Check is tensor is constant */ + bool isConstant() const { + if (orig_) return orig_->isConstant(); + return mem_.operator bool(); + } + + /** Check is tensor is scalar */ + bool isScalar() const { return t_desc_.dims().size() == 1 && t_desc_.dims()[0] == 1; } + + /** Produce const data memory object with proper content */ + dnnl::memory getConstData() const { + if (mem_) return mem_; + if (!orig_) return {}; + + if (auto orig_const_data = orig_->getConstData()) { + if (reinterpret_) { + return {t_desc_, orig_const_data.get_engine(), orig_const_data.get_data_handle()}; + } else { + auto eng = orig_const_data.get_engine(); + auto res = dnnl::memory{t_desc_, eng}; + dnnl::reorder(orig_const_data, res).execute(dnnl::stream(eng), orig_const_data, res); + return res; + } + } + return {}; + } + + /** + * Same as getConstData but in form of std::vector in case of 1D tensor + * Useful for 1D specifying DNNL attributes like zero_point or per_channel_scale + */ + template + std::vector getConstDataLikeVec() const { + auto const_data = getConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::dnnlDType()); + ICHECK(desc.dims().size() == 1); + + auto size = desc.get_size() / sizeof(T); + auto ptr = static_cast(const_data.get_data_handle()); + + return std::vector(ptr, ptr + size); + } + + /** Get value of constant scalar tensor if possible */ + template + T getConstScalarData() const { + ICHECK(isConstant()); + ICHECK(isScalar()); + auto const_data = getConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::dnnlDType()); + + auto ptr = static_cast(const_data.get_data_handle()); + return *ptr; + } + + /** Check if tensor is not empty */ + bool defined() const { return !t_desc_.is_zero(); } + + /** Same as defined */ + operator bool() const { return defined(); } + + /** Check if tensor represent a reversed data flow. Useful for describing output processing */ + bool isReversed() const { return reverse_data_flow_; } + + private: + TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr& orig, + bool reinterpret, const dnnl::memory& const_mem, uint32_t eid, + bool reverse_data_flow) + : t_desc_(t_desc), + orig_(orig), + reinterpret_(reinterpret), + mem_(const_mem), + eid_(eid), + reverse_data_flow_(reverse_data_flow) { + if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == UNDEFINED_EID); + if (eid_ != UNDEFINED_EID) ICHECK(!orig_); + } + + /** Descriptor of particular tensor */ + dnnl::memory::desc t_desc_ = {}; + /** Parent TR object which is referred from this TR */ + std::shared_ptr orig_ = {}; + /** Flag to specify which action should be done with orig TR, reordering or reinterpretation */ + bool reinterpret_ = false; + /** Const memory object if available */ + dnnl::memory mem_ = {}; + /** Entry ID of tensor if it available */ + uint32_t eid_ = UNDEFINED_EID; + + /** + * Flag to describe reverse data flow case + * All operation on queue will be executed in reverse order. Actual for dst tensor description + */ + bool reverse_data_flow_ = false; + + friend class TensorRegistry; +}; + +/** + * The registry of real dnnl::memory object which respond submitted TR. 
+
+  /** Check if tensor is not empty */
+  bool defined() const { return !t_desc_.is_zero(); }
+
+  /** Same as defined */
+  operator bool() const { return defined(); }
+
+  /** Check if tensor represents a reversed data flow. Useful for describing output processing */
+  bool isReversed() const { return reverse_data_flow_; }
+
+ private:
+  TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr<TensorRequisite>& orig,
+                  bool reinterpret, const dnnl::memory& const_mem, uint32_t eid,
+                  bool reverse_data_flow)
+      : t_desc_(t_desc),
+        orig_(orig),
+        reinterpret_(reinterpret),
+        mem_(const_mem),
+        eid_(eid),
+        reverse_data_flow_(reverse_data_flow) {
+    if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == UNDEFINED_EID);
+    if (eid_ != UNDEFINED_EID) ICHECK(!orig_);
+  }
+
+  /** Descriptor of the particular tensor */
+  dnnl::memory::desc t_desc_ = {};
+  /** Parent TR object which is referred to from this TR */
+  std::shared_ptr<TensorRequisite> orig_ = {};
+  /** Flag to specify which action should be done with the orig TR, reordering or reinterpretation */
+  bool reinterpret_ = false;
+  /** Const memory object if available */
+  dnnl::memory mem_ = {};
+  /** Entry ID of tensor if available */
+  uint32_t eid_ = UNDEFINED_EID;
+
+  /**
+   * Flag to describe the reverse data flow case
+   * All operations on the queue will be executed in reverse order. Relevant for dst tensor
+   * descriptions
+   */
+  bool reverse_data_flow_ = false;
+
+  friend class TensorRegistry;
+};
+
+/**
+ * The registry of real dnnl::memory objects which correspond to submitted TRs.
+ */
+class TensorRegistry {
+ private:
+  enum ArgReqFlag {
+    CONST,        ///< Constant tensor. Execution context independent
+    TMP_STORAGE,  ///< Intermediate tensors. Stored inside TensorRegistry. Inaccessible outside
+    EXT_EID,      ///< External data. Input or Output.
+  };
+
+ public:
+  struct ArgReq {
+    TensorRegistry::ArgReqFlag flag_;
+    uint32_t idx_;
+  };
+
+  using Action = std::tuple<dnnl::primitive, std::unordered_map<int, ArgReq>>;
+  using ActionQue = std::vector<Action>;
+  using DLTensorProvider = std::function<const DLTensor*(uint32_t)>;
+  using MemSolver = std::function<dnnl::memory(const ArgReq&)>;
+
+  TensorRegistry() = default;
+  TensorRegistry(const dnnl::engine& eng, const std::set<uint32_t>& ext_io_eid)
+      : tmp_mem_collection_(1), ext_io_eid_(ext_io_eid), eng_(eng), stream_(eng) {}
+
+  /**
+   * Register a TR
+   *
+   * As a result, produce the corresponding ArgReq and append to `action` the actions which
+   * should be executed before (or after, in case of reverse data flow) usage of this tensor.
+   * @param tr TensorRequisite to register
+   * @param action queue to append required reorder actions to
+   * @return Associated ArgReq
+   */
+  ArgReq registerTR(const TensorRequisite& tr, ActionQue* action) {
+    // 1) Constant tensor. Direct reference
+    if (auto const_data = tr.getConstData()) {
+      auto idx = const_mem_collection_.size();
+      const_mem_collection_.push_back(const_data);
+      return makeArgReq(ArgReqFlag::CONST, static_cast<uint32_t>(idx));
+    }
+
+    // 2) EID mapped tensor. Direct reference
+    if (tr.eid_ != TensorRequisite::UNDEFINED_EID) {
+      if (ext_io_eid_.count(tr.eid_) == 0) {  // Not an IO tensor, so it is intermediate
+        if (eid2idx_tmp_.count(tr.eid_)) {
+          auto idx = eid2idx_tmp_.at(tr.eid_);
+          return makeArgReq(ArgReqFlag::TMP_STORAGE, idx);
+        } else {
+          // register itself
+          auto idx = tmp_mem_collection_.size();
+          tmp_mem_collection_.push_back(tr.t_desc_);
+          eid2idx_tmp_[tr.eid_] = idx;
+          return makeArgReq(ArgReqFlag::TMP_STORAGE, static_cast<uint32_t>(idx));
+        }
+      } else {
+        auto idx = ext_mem_collection_.size();
+        ext_mem_collection_.push_back({tr.eid_, tr.t_desc_});
+        return makeArgReq(ArgReqFlag::EXT_EID, static_cast<uint32_t>(idx));
+      }
+    }
+
+    // 3) Tensors with transform actions
+    if (tr.orig_) {
+      // recursive registration of the orig TR
+      auto orig_arg_req = registerTR(*tr.orig_, action);
+      if (tr.reinterpret_) {
+        return register_reinterp(orig_arg_req, tr.t_desc_);
+      } else {
+        return register_reorder(orig_arg_req, tr.t_desc_, tr.reverse_data_flow_, action);
+      }
+    }
+
+    // 4) Scratchpad
+    ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::UNDEFINED_EID);
+    auto idx = tmp_mem_collection_.size();
+    tmp_mem_collection_.push_back(tr.t_desc_);
+    tmp_mem_mapping_[idx] = 0;  // zero position tmp mem object is reserved for scratchpads
+
+    auto scratchpad_size = tr.t_desc_.get_size();
+    auto glob_scratchpad_size = tmp_mem_collection_[0].get_size();
+    if (scratchpad_size > glob_scratchpad_size) {
+      tmp_mem_collection_[0] =
+          dnnl::memory::desc({static_cast<dnnl::memory::dim>(scratchpad_size)},
+                             dnnl::memory::data_type::u8, dnnl::memory::format_tag::a);
+    }
+    return makeArgReq(TMP_STORAGE, static_cast<uint32_t>(idx));
+  }
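+
+  // Usage sketch (illustrative names): requisites are registered while the graph is being
+  // built, and the resulting ArgReqs are resolved to dnnl::memory objects at Run() time:
+  //   TensorRegistry::ActionQue actions;
+  //   auto arg = registry.registerTR(tr, &actions);      // may append reorder actions
+  //   auto solver = registry.makeSolver(io_provider);    // bind real IO buffers
+  //   dnnl::memory mem = solver(arg);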
+
+  /**
+   * Construct a memory solver for all registered TRs.
+   * ext_provider is used to ask for external IO data buffers
+   */
+  MemSolver makeSolver(const DLTensorProvider& ext_provider) const {
+    return MemSolverImpl(eng_, ext_provider, const_mem_collection_, ext_mem_collection_,
+                         tmp_mem_collection_, tmp_mem_mapping_);
+  }
+
+ private:
+  ArgReq register_reinterp(ArgReq src_ar, const dnnl::memory::desc& desc) {
+    switch (src_ar.flag_) {
+      case TMP_STORAGE: {
+        auto idx = tmp_mem_collection_.size();
+        tmp_mem_collection_.push_back(desc);
+        tmp_mem_mapping_[idx] = src_ar.idx_;
+        return makeArgReq(TMP_STORAGE, idx);
+      }
+      case EXT_EID: {
+        auto ext_req = ext_mem_collection_[src_ar.idx_];
+        auto idx = ext_mem_collection_.size();
+        ext_mem_collection_.push_back({ext_req.first, desc});
+        return makeArgReq(EXT_EID, idx);
+      }
+      default:
+        LOG(FATAL) << "Unknown case";
+    }
+    return {};
+  }
+
+  ArgReq register_reorder(ArgReq src_ar, const dnnl::memory::desc& desc, bool reverse_data_flow,
+                          ActionQue* action) {
+    ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID);
+
+    auto src_desc = src_ar.flag_ == TMP_STORAGE ? tmp_mem_collection_[src_ar.idx_]
+                                                : ext_mem_collection_[src_ar.idx_].second;
+    auto idx = tmp_mem_collection_.size();
+    tmp_mem_collection_.push_back(desc);
+    auto dst_ar = makeArgReq(TMP_STORAGE, idx);
+
+    // Submit the reorder action
+    if (reverse_data_flow) {
+      auto reorder_pd = dnnl::reorder::primitive_desc(eng_, desc, eng_, src_desc);
+      action->insert(action->begin(),
+                     {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}});
+    } else {
+      auto reorder_pd = dnnl::reorder::primitive_desc(eng_, src_desc, eng_, desc);
+      action->push_back(
+          {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}});
+    }
+    return dst_ar;
+  }
+
+  class MemSolverImpl {
+   public:
+    MemSolverImpl(const dnnl::engine& eng, const DLTensorProvider& ext_data_provider,
+                  const std::vector<dnnl::memory>& const_mems,
+                  const std::vector<std::pair<uint32_t, dnnl::memory::desc>>& ext_mems,
+                  const std::vector<dnnl::memory::desc>& tmp_mem_descs,
+                  const std::map<size_t, size_t>& tmp_mem_mapping)
+        : eng_(eng),
+          ext_data_provider_(ext_data_provider),
+          const_mems_(const_mems),
+          ext_mems_(ext_mems) {
+      // Construct temp memory objects on the fly, while there is no scratchpad
+      // support at the VM/GraphExecutor level.
+      tmp_mems_.resize(tmp_mem_descs.size());
+      for (size_t i = 0; i < tmp_mem_descs.size(); i++) {
+        auto found = tmp_mem_mapping.find(i);
+
+        if (found != tmp_mem_mapping.end()) {
+          auto reuse_hdl = tmp_mems_[found->second].get_data_handle();
+          tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_, reuse_hdl);
+        } else {
+          tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_);
+        }
+      }
+    }
+
+    /** Find the memory object associated with the provided ArgReq */
+    dnnl::memory operator()(const ArgReq& ar) const {
+      switch (ar.flag_) {
+        case CONST:
+          return const_mems_.at(ar.idx_);
+        case TMP_STORAGE:
+          return tmp_mems_.at(ar.idx_);
+        case EXT_EID: {
+          auto eid_and_desc = ext_mems_.at(ar.idx_);
+          auto eid = eid_and_desc.first;
+          auto desc = eid_and_desc.second;
+
+          auto ext_dl_tensor = ext_data_provider_(eid);
+          ICHECK(ext_dl_tensor->data);
+          return dnnl::memory{desc, eng_, ext_dl_tensor->data};
+        }
+      }
+      return {};
+    }
+
+   private:
+    const dnnl::engine& eng_;
+    const DLTensorProvider& ext_data_provider_;
+    const std::vector<dnnl::memory>& const_mems_;
+    const std::vector<std::pair<uint32_t, dnnl::memory::desc>>& ext_mems_;
+    std::vector<dnnl::memory> tmp_mems_;
+  };
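+
+  // Note: every makeSolver() call constructs its own MemSolverImpl, so intermediate
+  // buffers are allocated per solver instance and are never shared between solvers.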
+
+  ArgReq makeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; }
+
+  /** Collection of const memory objects. */
+  std::vector<dnnl::memory> const_mem_collection_;
+
+  /**
+   * Collection of intermediate memory descriptors.
+   * Zero position is reserved for scratchpads.
+   */
+  std::vector<dnnl::memory::desc> tmp_mem_collection_;
+
+  /** Mapping of some temp buffers onto previously registered ones. */
+  std::map<size_t, size_t> tmp_mem_mapping_;
+
+  /** Collection of external memory objects.
+   * first - eid of the external buffer to ask for
+   * second - t_desc describing how to treat the external buffer */
+  std::vector<std::pair<uint32_t, dnnl::memory::desc>> ext_mem_collection_;
+
+  /** Map of eid to index of temp buffer in tmp_mem_collection_ */
+  std::unordered_map<uint32_t, uint32_t> eid2idx_tmp_;
+
+  /** List of external eids */
+  std::set<uint32_t> ext_io_eid_;
+
+  /** Engine of all tensors existing in this registry */
+  dnnl::engine eng_;
+
+  /** Execution stream used to reorder const data */
+  dnnl::stream stream_;
+};
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h
new file mode 100644
index 000000000000..09d5d3281e79
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_utils.h
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file
+ * \brief Some DNNL specific utility functions
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+
+#include <dnnl.hpp>
+
+#include <cctype>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "tvm/runtime/logging.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+namespace utils {
+
+/** Printer util */
+std::ostream& operator<<(std::ostream& o, const dnnl::memory::dims& dims) {
+  o << "[";
+  auto d = dims.begin();
+  if (d != dims.end()) o << *d++;
+  while (d != dims.end()) o << "," << *d++;
+  o << "]";
+  return o;
+}
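+
+// For example, dims {1, 3, 224, 224} are printed as "[1,3,224,224]".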
+ * Example: NCWHT4n8c -> rank:5 + */ +inline int GetLayoutRank(const std::string& layout) { + int rank = 0; + for (auto it = layout.begin(); it != layout.end();) { + auto start = it; + while (std::isdigit(*it)) it++; + if (start == it) rank++; // no digits only letter + it++; + } + return rank; +} + +/** Define which logical dims ordering is default for particular layout string */ +inline std::string DefaultLogicLayoutFor(const std::string& layout) { + int rank = GetLayoutRank(layout); + static const std::vector sparse_dims = {"W", "HW", "DHW"}; + if (layout.find("N") != std::string::npos) return "NC" + sparse_dims[rank - 3]; + if (layout.find("G") != std::string::npos) return "GOI" + sparse_dims[rank - 4]; + if (layout.find("O") != std::string::npos) return "OI" + sparse_dims[rank - 3]; + + LOG(FATAL) << "Unknown layout " << layout << "There is no default scheme to handle it"; + return {}; +} + +/** Generator of dnnl format_tag for plain version of tensor */ +inline static dnnl::memory::format_tag plainLayout(uint32_t rank) { + switch (rank) { + case 0: + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return dnnl::memory::format_tag::abc; + case 4: + return dnnl::memory::format_tag::abcd; + case 5: + return dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + case 7: + return dnnl::memory::format_tag::abcdefg; + default: + LOG(FATAL) << "Unsupported data tensor rank: " << rank; + break; + } + return dnnl::memory::format_tag::undef; +} + +/** Converter helper for data type objects */ +inline static dnnl::memory::data_type convert2dnnl(DLDataType dtype) { + if (dtype.code == DLDataTypeCode::kDLInt) { + if (dtype.bits == 8) return dnnl::memory::data_type::s8; + if (dtype.bits == 32) return dnnl::memory::data_type::s32; + } else if (dtype.code == DLDataTypeCode::kDLUInt) { + if (dtype.bits == 8) return dnnl::memory::data_type::u8; + } else if (dtype.code == DLDataTypeCode::kDLFloat) { + if (dtype.bits == 16) return dnnl::memory::data_type::f16; + if (dtype.bits == 32) return dnnl::memory::data_type::f32; + } else if (dtype.code == DLDataTypeCode::kDLBfloat) { + if (dtype.bits == 16) return dnnl::memory::data_type::bf16; + } + LOG(FATAL) << "Data type is not supported"; + return dnnl::memory::data_type::undef; +} + +/** Converter helper for shape objects */ +inline static dnnl::memory::dims convert2dnnl(std::vector shape) { + if (shape.empty()) return {1}; // DNNL scalar representation + return shape; +} + +/** Builder of dnnl memory on top of provided DLTensor */ +dnnl::memory::desc convert2dnnl(const DLTensor& dl_tensor) { + // Assumption! 
+
+/** Generator of dnnl format_tag for the plain version of a tensor */
+inline static dnnl::memory::format_tag plainLayout(uint32_t rank) {
+  switch (rank) {
+    case 0:
+    case 1:
+      return dnnl::memory::format_tag::a;
+    case 2:
+      return dnnl::memory::format_tag::ab;
+    case 3:
+      return dnnl::memory::format_tag::abc;
+    case 4:
+      return dnnl::memory::format_tag::abcd;
+    case 5:
+      return dnnl::memory::format_tag::abcde;
+    case 6:
+      return dnnl::memory::format_tag::abcdef;
+    case 7:
+      return dnnl::memory::format_tag::abcdefg;
+    default:
+      LOG(FATAL) << "Unsupported data tensor rank: " << rank;
+      break;
+  }
+  return dnnl::memory::format_tag::undef;
+}
+
+/** Converter helper for data type objects */
+inline static dnnl::memory::data_type convert2dnnl(DLDataType dtype) {
+  if (dtype.code == DLDataTypeCode::kDLInt) {
+    if (dtype.bits == 8) return dnnl::memory::data_type::s8;
+    if (dtype.bits == 32) return dnnl::memory::data_type::s32;
+  } else if (dtype.code == DLDataTypeCode::kDLUInt) {
+    if (dtype.bits == 8) return dnnl::memory::data_type::u8;
+  } else if (dtype.code == DLDataTypeCode::kDLFloat) {
+    if (dtype.bits == 16) return dnnl::memory::data_type::f16;
+    if (dtype.bits == 32) return dnnl::memory::data_type::f32;
+  } else if (dtype.code == DLDataTypeCode::kDLBfloat) {
+    if (dtype.bits == 16) return dnnl::memory::data_type::bf16;
+  }
+  LOG(FATAL) << "Data type is not supported";
+  return dnnl::memory::data_type::undef;
+}
+
+/** Converter helper for shape objects */
+inline static dnnl::memory::dims convert2dnnl(std::vector<int64_t> shape) {
+  if (shape.empty()) return {1};  // DNNL scalar representation
+  return shape;
+}
+
+/** Build a dnnl memory descriptor on top of a provided DLTensor */
+dnnl::memory::desc convert2dnnl(const DLTensor& dl_tensor) {
+  // Assumption: TVM uses only plain tensors (no strides)
+  ICHECK(dl_tensor.strides == nullptr);
+
+  std::vector<int64_t> dl_dims(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim);
+  dnnl::memory::desc desc{convert2dnnl(dl_dims), convert2dnnl(dl_tensor.dtype),
+                          plainLayout(dl_dims.size())};
+
+  desc.data.offset0 = dl_tensor.byte_offset;
+  return desc;
+}
+
+/** Convert a template type argument to the corresponding dnnl runtime data type object */
+template <typename T>
+dnnl::memory::data_type dnnlDType();
+
+template <>
+dnnl::memory::data_type dnnlDType<int32_t>() {
+  return dnnl::memory::data_type::s32;
+}
+
+template <>
+dnnl::memory::data_type dnnlDType<float>() {
+  return dnnl::memory::data_type::f32;
+}
+
+/** Attribute extractor helpers */
+template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
+T convert(std::vector<std::string> val) {
+  ICHECK_EQ(val.size(), 1);
+  return std::stol(val[0]);
+}
+
+template <typename T, std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
+T convert(std::vector<std::string> val) {
+  ICHECK_EQ(val.size(), 1);
+  return std::stof(val[0]);
+}
+
+template <typename T, std::enable_if_t<std::is_same<T, std::string>::value, int> = 0>
+T convert(std::vector<std::string> val) {
+  ICHECK_EQ(val.size(), 1);
+  return val[0];
+}
+
+template <typename T,
+          std::enable_if_t<std::is_same<T, std::vector<typename T::value_type>>::value, int> = 0>
+T convert(std::vector<std::string> val) {
+  T res;
+  for (const auto& el : val) res.push_back(convert<typename T::value_type>({el}));
+  return res;
+}
+
+template <typename T>
+const T getAttr(const json::JSONGraphNode& node, std::string name,
+                std::vector<std::string> def = {}) {
+  auto attr = node.HasAttr(name) ? node.GetAttr<std::vector<std::string>>(name) : def;
+  return convert<T>(attr);
+}
+
+}  // namespace utils
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_