This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve stack operator performance by oneDNN #20621

Merged
merged 4 commits on Nov 22, 2021
1 change: 0 additions & 1 deletion python/mxnet/amp/lists/symbol_fp16.py
@@ -673,7 +673,6 @@
'_npi_not_equal',
'_npi_dstack',
'_npi_hstack',
'_npi_stack',
'_npi_tensordot',
'_npi_tensordot_int_axes',
'_npi_vstack',
@@ -3191,7 +3191,6 @@ def convert_embedding(node, **kwargs):


@mx_op.register("stack")
@mx_op.register("_npi_stack")
def convert_stack(node, **kwargs):
"""Map MXNet's stack operator to onnx operators.
"""
@@ -551,7 +551,6 @@ def convert_expand_dims(node, **kwargs):


@mx_op.register("stack", OPSET_VERSION)
@mx_op.register("_npi_stack", OPSET_VERSION)
def convert_stack(node, **kwargs):
"""Map MXNet's stack operator to onnx operators.
"""
7 changes: 4 additions & 3 deletions src/operator/nn/dnnl/dnnl_base-inl.h
@@ -197,6 +197,7 @@ bool SupportDNNLTranspose(const NDArray& data);
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
bool SupportDNNLStack(const std::vector<NDArray>& inputs);
} // namespace op

static int GetTypeSize(int dtype) {
@@ -607,9 +608,9 @@ class DNNLMemory {
dnnl::memory::data_type data_type = dnnl::memory::data_type::undef) const {
dnnl::memory::dims dims(desc.data.dims, desc.data.dims + desc.data.ndims);
dnnl::memory::data_type cpp_type =
(data_type == dnnl::memory::data_type::undef)
? static_cast<dnnl::memory::data_type>(desc.data.data_type)
: data_type;
(data_type == dnnl::memory::data_type::undef) ?
static_cast<dnnl::memory::data_type>(desc.data.data_type) :
data_type;
dnnl::memory::desc data_md(dims, cpp_type, static_cast<dnnl::memory::format_tag>(format));
return data_md;
}
5 changes: 4 additions & 1 deletion src/operator/nn/dnnl/dnnl_concat-inl.h
@@ -52,14 +52,17 @@ class DNNLConcatFwd {

static DNNLConcatFwd& GetConcatForward(int concat_dim,
const std::vector<NDArray>& in_data,
const std::vector<dnnl::memory::desc>& data_md) {
const std::vector<dnnl::memory::desc>& data_md,
int stack_axis = -1 /*used only by stack op*/) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLConcatFwd, OpHash> fwds;
#endif

OpSignature key;
key.AddSign(concat_dim);
key.AddSign(stack_axis);
key.AddSign(in_data);

auto it = fwds.find(key);
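
The lookup above is MXNet's usual thread-local primitive cache: build a oneDNN primitive once per unique configuration and reuse it on later calls with the same key. A minimal standalone sketch of the pattern (Signature and ConcatFwd are simplified illustrative stand-ins, not the repo's OpSignature/DNNLConcatFwd):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for OpSignature: a flat list of ints plus a hash.
struct Signature {
  std::vector<int64_t> vals;
  void AddSign(int64_t v) { vals.push_back(v); }
  bool operator==(const Signature& o) const { return vals == o.vals; }
};

struct SignatureHash {
  size_t operator()(const Signature& s) const {
    size_t h = 0;
    for (int64_t v : s.vals)  // hash_combine-style mixing
      h ^= std::hash<int64_t>()(v) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};

struct ConcatFwd { /* would own the dnnl::concat primitive and its pd */ };

ConcatFwd& GetFwd(int concat_dim, int stack_axis,
                  const std::vector<std::vector<int64_t>>& in_shapes) {
  // One cache per thread: cached primitives are not shared across threads.
  static thread_local std::unordered_map<Signature, ConcatFwd, SignatureHash> cache;
  Signature key;
  key.AddSign(concat_dim);
  key.AddSign(stack_axis);
  for (const auto& shape : in_shapes)
    for (int64_t d : shape) key.AddSign(d);
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, ConcatFwd{}).first;  // build primitive here
  return it->second;
}

Adding stack_axis to the key is the point of the new parameter: the signature otherwise covers only the original (unreshaped) inputs and the concat dimension, and for stack the concat dimension is always the artificial middle dimension, so two stack calls differing only in axis would collide in the cache.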
6 changes: 6 additions & 0 deletions src/operator/nn/dnnl/dnnl_ops-inl.h
@@ -180,6 +180,12 @@ void DNNLLayerNormBackward(const nnvm::NodeAttrs& attrs,

void DNNLSum(const dnnl::memory& arr1, const dnnl::memory& arr2, const dnnl::memory& out);

void DNNLStackForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data);

template <class ParamType>
void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
123 changes: 123 additions & 0 deletions src/operator/nn/dnnl/dnnl_stack.cc
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file dnnl_stack.cc
*/

#include "./dnnl_base-inl.h"
#include "./dnnl_concat-inl.h"
#include "./dnnl_ops-inl.h"

#include "../../tensor/matrix_op-inl.h"

#if MXNET_USE_ONEDNN == 1
namespace mxnet {
namespace op {

bool SupportDNNLStack(const std::vector<NDArray>& inputs) {
if (inputs[0].dtype() != mshadow::kFloat32 && inputs[0].dtype() != mshadow::kBfloat16) {
return false;
}

int src_dtype = inputs[0].dtype();
for (const auto& arr : inputs) {
if (arr.dtype() != src_dtype) {
return false;
}
// Do not support zero-size tensors.
if (arr.shape().Size() == 0) {
return false;
}

int ndim = arr.shape().ndim();
if (ndim <= 0) {
return false;
}
}
return true;
}

void DNNLStackForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);

// Constant index of the artificial new dimension along which the
// oneDNN concat primitive joins the reshaped inputs.
constexpr int stacking_dim = 1;

const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
const int axis = CheckAxis(param.axis, out_data[0].shape().ndim());
const TShape oshape = out_data[0].shape();
const int src_dtype = in_data[0].dtype();
const int dst_dtype = out_data[0].dtype();
const int mid_dim = oshape[axis];
int leading_dim = 1;
int trailing_dim = 1;

for (int i = 0; i < axis; ++i) {
leading_dim *= oshape[i];
}
for (int i = axis + 1; i < oshape.ndim(); ++i) {
trailing_dim *= oshape[i];
}

std::vector<dnnl::memory::desc> data_md;
std::vector<dnnl::memory> data_mem;
dnnl::memory::desc in_md({leading_dim, stacking_dim, trailing_dim},
get_dnnl_type(src_dtype),
dnnl::memory::format_tag::abc);
dnnl::memory::desc out_md({leading_dim, mid_dim, trailing_dim},
get_dnnl_type(dst_dtype),
dnnl::memory::format_tag::any);

const int num_in_data = in_data.size();
data_md.reserve(num_in_data);
data_mem.reserve(num_in_data);

MSHADOW_TYPE_SWITCH(src_dtype, DType, {
for (int i = 0; i < num_in_data; i++) {
NDArray tmp = in_data[i].Reorder2Default();
dnnl::memory tmp_mem(in_md, CpuEngine::Get()->get_engine(), tmp.data().dptr<DType>());
data_mem.emplace_back(tmp_mem);
data_md.emplace_back(in_md);
}
});

auto& fwd = GetConcatForward(stacking_dim, in_data, data_md, axis);
mxnet::dnnl_output_t out_mem =
CreateDNNLMem(out_data[concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);

std::unordered_map<int, dnnl::memory> net_args;
net_args.insert({DNNL_ARG_DST, *out_mem.second});
for (int i = 0; i < num_in_data; i++) {
net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, data_mem[i]});
}

DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
CommitOutput(out_data[concat_enum::kOut], out_mem);
DNNLStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet
#endif
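
The kernel never materializes an actual reshape: each input is only described to oneDNN as a {leading, 1, trailing} block, and concatenating those blocks along the middle dimension is exactly a stack along axis. A standalone sketch of the same mapping against the plain oneDNN C++ API (v2.x-style constructors assumed; error handling omitted):

#include <dnnl.hpp>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Stack same-shaped f32 tensors along `axis` by viewing each source as
// {leading, 1, trailing} and concatenating on the middle dimension.
dnnl::memory stack_via_concat(const std::vector<float*>& srcs,
                              const std::vector<int64_t>& in_shape, int axis,
                              dnnl::engine& eng, dnnl::stream& strm) {
  int64_t leading = 1, trailing = 1;
  for (int i = 0; i < axis; ++i) leading *= in_shape[i];
  for (size_t i = axis; i < in_shape.size(); ++i) trailing *= in_shape[i];

  dnnl::memory::desc src_md({leading, 1, trailing},
                            dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::abc);
  std::vector<dnnl::memory::desc> src_mds(srcs.size(), src_md);

  // Concat on dim 1 yields {leading, srcs.size(), trailing}: the stacked
  // tensor with the new axis inserted at position `axis`.
  dnnl::concat::primitive_desc pd(/*concat_dimension=*/1, src_mds, eng);
  dnnl::memory dst(pd.dst_desc(), eng);

  std::unordered_map<int, dnnl::memory> args{{DNNL_ARG_DST, dst}};
  for (size_t i = 0; i < srcs.size(); ++i)
    args.insert({DNNL_ARG_MULTIPLE_SRC + static_cast<int>(i),
                 dnnl::memory(src_md, eng, srcs[i])});

  dnnl::concat(pd).execute(strm, args);
  strm.wait();
  return dst;
}

This also motivates the checks in SupportDNNLStack above: the {leading, 1, trailing} descriptors assume every dimension is positive, so zero-size and zero-dim inputs fall back to the native CPU kernel.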
40 changes: 0 additions & 40 deletions src/operator/numpy/np_matrix_op.cc
@@ -638,46 +638,6 @@ struct NumpyConcatGrad {
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
NNVM_REGISTER_OP(_npi_stack)
.describe(R"code(Join a sequence of arrays along a new axis.

The axis parameter specifies the index of the new axis in the dimensions of the
result. For example, if axis=0 it will be the first dimension and if axis=-1 it
will be the last dimension.

Examples::

x = [1, 2]
y = [3, 4]

stack(x, y) = [[1, 2],
[3, 4]]
stack(x, y, axis=1) = [[1, 3],
[2, 4]]
)code")
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
const StackParam& param = dmlc::get<StackParam>(attrs.parsed);
return static_cast<uint32_t>(param.num_args);
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<StackParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
uint32_t num_args =
dmlc::get<StackParam>(attrs.parsed).num_args;
std::vector<std::string> ret;
for (uint32_t i = 0; i < num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<mxnet::FInferShape>("FInferShape", StackOpShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FCompute>("FCompute<cpu>", StackOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_stack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());

bool NumpyColumnStackType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
2 changes: 0 additions & 2 deletions src/operator/numpy/np_matrix_op.cu
@@ -34,8 +34,6 @@ NNVM_REGISTER_OP(_np_reshape).set_attr<FCompute>("FCompute<gpu>", UnaryOp::Ident

NNVM_REGISTER_OP(_npi_squeeze).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

NNVM_REGISTER_OP(_npi_stack).set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);

NNVM_REGISTER_OP(_npi_vstack).set_attr<FCompute>("FCompute<gpu>", NumpyVstackForward<gpu>);

NNVM_REGISTER_OP(_backward_np_vstack).set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);
44 changes: 43 additions & 1 deletion src/operator/tensor/matrix_op.cc
@@ -140,7 +140,8 @@ bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
return DNNLStorageType(
attrs, dev_mask, /*support_dnnl*/ true, dispatch_mode, in_attrs, out_attrs);
}
#endif

@@ -930,7 +931,39 @@ NNVM_REGISTER_OP(_backward_reverse)
})
.set_attr<FCompute>("FCompute<cpu>", ReverseOpForward<cpu>);

#if MXNET_USE_ONEDNN == 1
static void StackForwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& op_ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK(!inputs.empty());
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
if (req[0] == kNullOp) {
return;
}

if (SupportDNNLStack(inputs)) {
DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs);
DNNLRun(DNNLStackForward, attrs, op_ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(StackOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
} else {
FallBackCompute(StackOpForward<cpu>, attrs, op_ctx, inputs, req, outputs);
}
}

inline static bool StackInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_ONEDNN == 1

NNVM_REGISTER_OP(stack)
.add_alias("_npi_stack")
.describe(R"code(Join a sequence of arrays along a new axis.
The axis parameter specifies the index of the new axis in the dimensions of the
result. For example, if axis=0 it will be the first dimension and if axis=-1 it
@@ -965,6 +998,15 @@ Examples::
.set_attr<mxnet::FInferShape>("FInferShape", StackOpShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FCompute>("FCompute<cpu>", StackOpForward<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", StackForwardEx)
.set_attr<bool>("TIsDNNL", true)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FInferStorageType>("FInferStorageType", StackInferStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_stack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());
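
For intuition on the equivalence that DNNL_OPCHECK_RUN verifies against StackOpForward<cpu>, here is a plain reference version of stack written in the same leading/middle/trailing coordinates (a sketch, independent of MXNet and oneDNN; it reproduces the examples from the operator's docstring):

#include <cassert>
#include <cstddef>
#include <vector>

// Reference stack: copy input i into slice i of the middle dimension of
// the {leading, num_inputs, trailing} row-major output.
std::vector<float> stack_ref(const std::vector<std::vector<float>>& inputs,
                             const std::vector<size_t>& in_shape, size_t axis) {
  size_t leading = 1, trailing = 1;
  for (size_t i = 0; i < axis; ++i) leading *= in_shape[i];
  for (size_t i = axis; i < in_shape.size(); ++i) trailing *= in_shape[i];

  const size_t n = inputs.size();
  std::vector<float> out(leading * n * trailing);
  for (size_t i = 0; i < n; ++i) {
    assert(inputs[i].size() == leading * trailing);
    for (size_t l = 0; l < leading; ++l)
      for (size_t t = 0; t < trailing; ++t)
        out[(l * n + i) * trailing + t] = inputs[i][l * trailing + t];
  }
  return out;
}

int main() {
  // stack(x, y) = [[1, 2], [3, 4]]; stack(x, y, axis=1) = [[1, 3], [2, 4]]
  assert((stack_ref({{1, 2}, {3, 4}}, {2}, 0) == std::vector<float>{1, 2, 3, 4}));
  assert((stack_ref({{1, 2}, {3, 4}}, {2}, 1) == std::vector<float>{1, 3, 2, 4}));
  return 0;
}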