Skip to content

Commit

Permalink
Improve log_softmax op performance by using DNNL support (apache#18320)
Browse files Browse the repository at this point in the history
* Improve log_softmax performance by OneDNN library

* Adapt tests for MKLDNN log_softmax

* Fix lint errors

* Fix indent and comments
  • Loading branch information
bgawrych authored and ys2843 committed Jun 2, 2020
1 parent 5d40e16 commit e5d4c18
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 6 deletions.
83 changes: 83 additions & 0 deletions src/operator/nn/log_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,79 @@
#include "../tensor/elemwise_unary_op.h"
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#include "mkldnn/mkldnn_base-inl.h"
#include "mkldnn/mkldnn_ops-inl.h"
#endif

namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLogSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>;
MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>, attrs, ctx,
inputs, req, outputs);
}

static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNLogSoftmaxBackward, attrs, ctx, inputs, req, outputs);
auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::left, mxnet_op::log_softmax_bwd>;
MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(SoftmaxGradCompute<cpu, op::mshadow_op::left, mxnet_op::log_softmax_bwd>,
attrs, ctx, inputs, req, outputs);
}

inline static bool LogSoftmaxStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
out_attrs);
}

inline static bool LogSoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
bool support = true;
int num_inputs = 2U;
if (softmax_has_dtype_override(attrs)) {
support = false;
num_inputs = 3U;
}

CHECK_EQ(in_attrs->size(), num_inputs);
CHECK_EQ(out_attrs->size(), 1U);
return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
}
#endif

NNVM_REGISTER_OP(log_softmax)
.add_alias("_npx_log_softmax")
.describe(R"code(Computes the log softmax of the input.
Expand All @@ -49,7 +118,16 @@ Examples::
)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs){
return std::vector<std::string>{"data"};
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
Expand All @@ -71,6 +149,11 @@ NNVM_REGISTER_OP(_backward_log_softmax)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxGradComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxGradStorageType)
#endif
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);

Expand Down
2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
bool SupportMKLDNNConv(const ConvolutionParam &params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
} // namespace op
Expand Down
226 changes: 226 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_log_softmax.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_log_softmax.cc
* \brief Implementation of log_softmax function with MKLDNN support
*/

#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

static mkldnn::logsoftmax_forward::primitive_desc GetLogSoftmaxFwdPd(
bool is_train,
const int axis,
const mkldnn::memory &input_mem) {
mkldnn::memory::desc data_md = input_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto prop = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::logsoftmax_forward::desc(prop, data_md, axis);
return mkldnn::logsoftmax_forward::primitive_desc(desc, cpu_engine);
}

static mkldnn::logsoftmax_backward::primitive_desc GetLogSoftmaxBwdPd(
const mkldnn::memory &diff_mem,
const mkldnn::memory &data_mem,
const int axis,
const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) {
mkldnn::memory::desc diff_md = diff_mem.get_desc();
mkldnn::memory::desc data_md = data_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto desc = mkldnn::logsoftmax_backward::desc(diff_md, data_md, axis);
return mkldnn::logsoftmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
}


bool SupportMKLDNNLogSoftmax(const SoftmaxParam &param,
const NDArray &data,
const NDArray &output) {
const int ndim = data.shape().ndim();
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();
const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their log_softmax function
// now. Need update this once they start to support it.
// Currently, MKLDNN shows bad performance when log_softmax is not performed on the last dimension
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
}

// only supports ndim = 1, 2, 3, 4 for now
return (ndim >= 1 && ndim <= 4);
}

class MKLDNNLogSoftmaxFwd {
public:
mkldnn::logsoftmax_forward::primitive_desc pd;

MKLDNNLogSoftmaxFwd(const bool is_train,
const int axis,
const mkldnn::memory &input) : pd(GetLogSoftmaxFwdPd(is_train, axis, input)) {
fwd_ = std::make_shared<mkldnn::logsoftmax_forward>(pd);
}

const mkldnn::logsoftmax_forward &GetFwd() const {
return *fwd_;
}

private:
std::shared_ptr<mkldnn::logsoftmax_forward> fwd_;
};

typedef ParamOpSign<SoftmaxParam> MKLDNNSoftmaxSignature;

static MKLDNNLogSoftmaxFwd &GetLogSoftmaxFwd(const SoftmaxParam &param,
const int real_axis,
const bool is_train,
const NDArray &data,
const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxFwd,
OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxFwd,
OpHash> fwds;
#endif

MKLDNNSoftmaxSignature key(param);
key.AddSign(real_axis);
key.AddSign(is_train);
key.AddSign(data);
key.AddSign(output);

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNLogSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &in_data,
const OpReqType &req,
const NDArray &out_data) {
if (req == kNullOp) return;
// same as the FCompute path, log_softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);

const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data.shape().ndim());
auto fwd = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);

auto in_mem = in_data.GetMKLDNNData();
auto out_mem = out_data.GetMKLDNNData(fwd.pd.dst_desc());
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
stream->Submit();
}

class MKLDNNLogSoftmaxBwd {
public:
mkldnn::logsoftmax_backward::primitive_desc pd;

MKLDNNLogSoftmaxBwd(const mkldnn::memory &diff_mem,
const mkldnn::memory &data_mem,
const int axis,
const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) :
pd(GetLogSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
bwd_ = std::make_shared<mkldnn::logsoftmax_backward>(pd);
}

const mkldnn::logsoftmax_backward &GetBwd() const {
return *bwd_;
}

private:
std::shared_ptr<mkldnn::logsoftmax_backward> bwd_;
};

static MKLDNNLogSoftmaxBwd &GetLogSoftmaxBwd(const SoftmaxParam &param,
const int real_axis,
const std::vector<NDArray> &data,
const std::vector<NDArray> &output) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxBwd,
OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
MKLDNNLogSoftmaxBwd,
OpHash> bwds;
#endif

MKLDNNSoftmaxSignature key(param);
key.AddSign(real_axis);
key.AddSign(data);
key.AddSign(output);

auto it = bwds.find(key);
if (it == bwds.end()) {
auto diff_mem = data[0].GetMKLDNNData();
auto data_mem = data[1].GetMKLDNNData();
auto fwd_pd = GetLogSoftmaxFwdPd(true, real_axis, *data_mem);
MKLDNNLogSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
it = AddToCache(&bwds, key, bwd);
}
return it->second;
}

void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
if (req[0] == kNullOp) return;
CHECK_EQ(in_data.size(), 2U);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
auto diff_mem = in_data[0].GetMKLDNNData();
auto data_mem = in_data[1].GetMKLDNNData();
auto bwd = GetLogSoftmaxBwd(param, axis, in_data, out_data);

auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_args_map_t args = {
{ MKLDNN_ARG_DST, *data_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_mem },
{ MKLDNN_ARG_DIFF_SRC, *out_mem.second }
};

stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(out_data[0], out_mem);
stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
9 changes: 9 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For log_softmax */
void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def test_loss_ndarray():

loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]))
assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

L = loss(output, label, weighting).asnumpy()
assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]))
assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)


def get_net(num_hidden, flatten=True):
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_numpy_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ def test_np_loss_ndarray():

loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False)
assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False, rtol=1e-3)

L = loss(output, label, weighting).asnumpy()
assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False)
assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False, rtol=1e-3)


@with_seed()
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5362,8 +5362,8 @@ def test_log_softmax():
axis = np.random.randint(0, ndim)
data = np.random.uniform(-2, 2, size=shape)
sym = mx.sym.log_softmax(axis=axis-ndim)
check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)])
check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)], rtol=1e-3, atol=1e-4)
check_numeric_gradient(sym, [data], rtol=1e-1, atol=1e-2)

def test_softmax_with_large_inputs():
def softmax_forward(input_data, true_output):
Expand Down

0 comments on commit e5d4c18

Please sign in to comment.