diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 73db44805c8e..16324b51c322 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -26,10 +26,79 @@
 #include "../tensor/elemwise_unary_op.h"
 #include "../tensor/elemwise_binary_op.h"
 #include "../operator_common.h"
+#if MXNET_USE_MKLDNN == 1
+#include "mkldnn/mkldnn_base-inl.h"
+#include "mkldnn/mkldnn_ops-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
 
+#if MXNET_USE_MKLDNN == 1
+static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNLogSoftmaxForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+    auto fn = SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>;
+    MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>, attrs, ctx,
+                  inputs, req, outputs);
+}
+
+static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNLogSoftmaxBackward, attrs, ctx, inputs, req, outputs);
+    auto fn = SoftmaxGradCompute<cpu, mshadow_op::left, mxnet_op::log_softmax_bwd>;
+    MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(SoftmaxGradCompute<cpu, mshadow_op::left, mxnet_op::log_softmax_bwd>,
+                  attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool LogSoftmaxStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+
+inline static bool LogSoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
+                                             const int dev_mask,
+                                             DispatchMode* dispatch_mode,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  bool support = true;
+  int num_inputs = 2U;
+  if (softmax_has_dtype_override(attrs)) {
+    support = false;
+    num_inputs = 3U;
+  }
+
+  CHECK_EQ(in_attrs->size(), num_inputs);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(log_softmax)
 .add_alias("_npx_log_softmax")
 .describe(R"code(Computes the log softmax of the input.
@@ -49,7 +118,16 @@ Examples::
 
 )code")
 .set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxComputeExCPU)
+.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxStorageType)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
@@ -71,6 +149,11 @@ NNVM_REGISTER_OP(_backward_log_softmax)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", LogSoftmaxGradComputeExCPU)
+.set_attr<FInferStorageType>("FInferStorageType", LogSoftmaxGradStorageType)
+#endif
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
                                                         mxnet_op::log_softmax_bwd>);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 65a0a6918558..3e73103b2f14 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -211,6 +211,8 @@ bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
 bool SupportMKLDNNConv(const ConvolutionParam &params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
+bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
+                             const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_log_softmax.cc b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
new file mode 100644
index 000000000000..0d992b252fa8
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_log_softmax.cc
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_log_softmax.cc
+ * \brief Implementation of log_softmax function with MKLDNN support
+*/
+
+#include "../softmax-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+#if MXNET_USE_MKLDNN == 1
+namespace mxnet {
+namespace op {
+
+static mkldnn::logsoftmax_forward::primitive_desc GetLogSoftmaxFwdPd(
+                           bool is_train,
+                           const int axis,
+                           const mkldnn::memory &input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training
+                       : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::logsoftmax_forward::desc(prop, data_md, axis);
+  return mkldnn::logsoftmax_forward::primitive_desc(desc, cpu_engine);
+}
+
+static mkldnn::logsoftmax_backward::primitive_desc GetLogSoftmaxBwdPd(
+                           const mkldnn::memory &diff_mem,
+                           const mkldnn::memory &data_mem,
+                           const int axis,
+                           const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) {
+  mkldnn::memory::desc diff_md = diff_mem.get_desc();
+  mkldnn::memory::desc data_md = data_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto desc = mkldnn::logsoftmax_backward::desc(diff_md, data_md, axis);
+  return mkldnn::logsoftmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
+}
+
+
+bool SupportMKLDNNLogSoftmax(const SoftmaxParam &param,
+                             const NDArray &data,
+                             const NDArray &output) {
+  const int ndim = data.shape().ndim();
+  const int in_dtype = data.dtype();
+  const int out_dtype = output.dtype();
+  const int axis = CheckAxis(param.axis, ndim);
+  // MKLDNN does not support the temperature argument in its log_softmax function
+  // yet. Update this once it is supported.
+  // Currently, MKLDNN shows poor performance when log_softmax is not performed
+  // on the last dimension.
+  if (param.temperature.has_value() ||
+      in_dtype != mshadow::kFloat32 ||
+      in_dtype != out_dtype ||
+      axis != (ndim - 1)) {
+    return false;
+  }
+
+  // only supports ndim = 1, 2, 3, 4 for now
+  return (ndim >= 1 && ndim <= 4);
+}
+
+class MKLDNNLogSoftmaxFwd {
+ public:
+  mkldnn::logsoftmax_forward::primitive_desc pd;
+
+  MKLDNNLogSoftmaxFwd(const bool is_train,
+                      const int axis,
+                      const mkldnn::memory &input) : pd(GetLogSoftmaxFwdPd(is_train, axis, input)) {
+    fwd_ = std::make_shared<mkldnn::logsoftmax_forward>(pd);
+  }
+
+  const mkldnn::logsoftmax_forward &GetFwd() const {
+    return *fwd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::logsoftmax_forward> fwd_;
+};
+
+typedef ParamOpSign<SoftmaxParam> MKLDNNSoftmaxSignature;
+
+static MKLDNNLogSoftmaxFwd &GetLogSoftmaxFwd(const SoftmaxParam &param,
+                                             const int real_axis,
+                                             const bool is_train,
+                                             const NDArray &data,
+                                             const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
+                                         MKLDNNLogSoftmaxFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
+                                            MKLDNNLogSoftmaxFwd, OpHash> fwds;
+#endif
+
+  MKLDNNSoftmaxSignature key(param);
+  key.AddSign(real_axis);
+  key.AddSign(is_train);
+  key.AddSign(data);
+  key.AddSign(output);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNLogSoftmaxFwd fwd(is_train, real_axis, *(data.GetMKLDNNData()));
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const NDArray &in_data,
+                             const OpReqType &req,
+                             const NDArray &out_data) {
+  if (req == kNullOp) return;
+  // same as the FCompute path, log_softmax only supports kWriteTo and kWriteInplace for now.
+  CHECK_NE(req, kAddTo);
+
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, in_data.shape().ndim());
+  auto fwd = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
+
+  auto in_mem = in_data.GetMKLDNNData();
+  auto out_mem = out_data.GetMKLDNNData(fwd.pd.dst_desc());
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  stream->RegisterPrimArgs(fwd.GetFwd(), {{MKLDNN_ARG_SRC, *in_mem}, {MKLDNN_ARG_DST, *out_mem}});
+  stream->Submit();
+}
+
+class MKLDNNLogSoftmaxBwd {
+ public:
+  mkldnn::logsoftmax_backward::primitive_desc pd;
+
+  MKLDNNLogSoftmaxBwd(const mkldnn::memory &diff_mem,
+                      const mkldnn::memory &data_mem,
+                      const int axis,
+                      const mkldnn::logsoftmax_forward::primitive_desc &hint_fwd_pd) :
+                      pd(GetLogSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
+    bwd_ = std::make_shared<mkldnn::logsoftmax_backward>(pd);
+  }
+
+  const mkldnn::logsoftmax_backward &GetBwd() const {
+    return *bwd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::logsoftmax_backward> bwd_;
+};
+
+static MKLDNNLogSoftmaxBwd &GetLogSoftmaxBwd(const SoftmaxParam &param,
+                                             const int real_axis,
+                                             const std::vector<NDArray> &data,
+                                             const std::vector<NDArray> &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNSoftmaxSignature,
+                                         MKLDNNLogSoftmaxBwd, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature,
+                                            MKLDNNLogSoftmaxBwd, OpHash> bwds;
+#endif
+
+  MKLDNNSoftmaxSignature key(param);
+  key.AddSign(real_axis);
+  key.AddSign(data);
+  key.AddSign(output);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    auto diff_mem = data[0].GetMKLDNNData();
+    auto data_mem = data[1].GetMKLDNNData();
+    auto fwd_pd = GetLogSoftmaxFwdPd(true, real_axis, *data_mem);
+    MKLDNNLogSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
+    it = AddToCache(&bwds, key, bwd);
+  }
+  return it->second;
+}
+
+void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+                              const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data) {
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(in_data.size(), 2U);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
+  auto diff_mem = in_data[0].GetMKLDNNData();
+  auto data_mem = in_data[1].GetMKLDNNData();
+  auto bwd = GetLogSoftmaxBwd(param, axis, in_data, out_data);
+
+  auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_DST, *data_mem },
+    { MKLDNN_ARG_DIFF_DST, *diff_mem },
+    { MKLDNN_ARG_DIFF_SRC, *out_mem.second }
+  };
+
+  stream->RegisterPrimArgs(bwd.GetBwd(), args);
+  CommitOutput(out_data[0], out_mem);
+  stream->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index c862607372a9..32f2e9f74130 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -98,6 +98,15 @@ void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                            const std::vector<OpReqType> &req,
                            const std::vector<NDArray> &out_data);
 
+/* For log_softmax */
+void MKLDNNLogSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                             const NDArray &in_data, const OpReqType &req,
+                             const NDArray &out_data);
+void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                              const std::vector<NDArray> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<NDArray> &out_data);
+
 /* For softmax_output */
 void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                 const std::vector<NDArray> &in_data,
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index a1a49c97d7f4..0b8374b9b90a 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -49,10 +49,10 @@ def test_loss_ndarray():
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
     L = loss(output, label).asnumpy()
-    assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]))
+    assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)
 
     L = loss(output, label, weighting).asnumpy()
-    assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]))
+    assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)
 
 
 def get_net(num_hidden, flatten=True):
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index 0d1e5fed59b3..204fde6bd2bd 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -153,10 +153,10 @@ def test_np_loss_ndarray():
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
     L = loss(output, label).asnumpy()
-    assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False)
+    assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False, rtol=1e-3)
 
     L = loss(output, label, weighting).asnumpy()
-    assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False)
+    assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False, rtol=1e-3)
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c73b8456240b..6b1782a2c48c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5338,8 +5338,8 @@ def test_log_softmax():
             axis = np.random.randint(0, ndim)
             data = np.random.uniform(-2, 2, size=shape)
             sym = mx.sym.log_softmax(axis=axis-ndim)
-            check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)])
-            check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
+            check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)], rtol=1e-3, atol=1e-4)
+            check_numeric_gradient(sym, [data], rtol=1e-1, atol=1e-2)
 
 def test_softmax_with_large_inputs():
     def softmax_forward(input_data, true_output):
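
A minimal usage sketch (not part of the patch), assuming an MXNet CPU build with MXNET_USE_MKLDNN=1: it exercises the case the new MKLDNN path accepts (float32 input, softmax over the last axis, no temperature) and compares against a NumPy reference with the same relaxed tolerances used in the updated tests. The np_log_softmax helper is written here purely for illustration.

import numpy as np
import mxnet as mx

def np_log_softmax(x, axis=-1):
    # numerically stable reference: log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

# float32 input, last-axis reduction: the configuration SupportMKLDNNLogSoftmax accepts
data = np.random.uniform(-2, 2, size=(3, 4, 5)).astype(np.float32)
out = mx.nd.log_softmax(mx.nd.array(data), axis=-1).asnumpy()
np.testing.assert_allclose(out, np_log_softmax(data, axis=-1), rtol=1e-3, atol=1e-4)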