Skip to content

Commit

Permalink
Merge pull request #9081 from kbinias/kbinias/mkldnn-activations
Browse files Browse the repository at this point in the history
MKLDNN Relu Tanh Sqrt Abs activations added
  • Loading branch information
luotao1 authored Mar 26, 2018
2 parents 8ccc61f + 6461e80 commit cb3bbbd
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 10 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,12 @@ function(op_library TARGET)

# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
endif()

# pybind USE_OP
Expand Down
193 changes: 193 additions & 0 deletions paddle/fluid/operators/activation_mkldnn_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "mkldnn.hpp"
#include "mkldnn_activation_op.h"
#include "paddle/fluid/operators/activation_op.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

namespace {
template <typename T, typename ExecContext>
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");

auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();

// get buffers
const auto *src = ctx.template Input<Tensor>("X");
const auto *src_data = src->template data<T>();

auto *dst = ctx.template Output<Tensor>("Out");
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());

// get memory dim
PADDLE_ENFORCE(src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW");
std::vector<int> src_tz = framework::vectorize2int(src->dims());

// create memory description
// TODO(kbinias-intel): support more formats
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);

// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);

auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);

// save prim desc into global device context to be referred in backward path
const std::string key = ctx.op().Output("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
forward_desc, mkldnn_engine);
dev_ctx.SetBlob(key_eltwise_pd, forward_pd);

auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);

// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}

template <typename T, typename ExecContext>
void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();

// get buffers
const auto *x = ctx.template Input<Tensor>("X");
const auto *src = x->template data<T>();

auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto *diff_dst = dout->template data<T>();

auto *dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());

// get memory dim
std::vector<int> src_tz = framework::vectorize2int(x->dims());

// create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);

// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
auto diff_src_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
auto diff_dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);

auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);

// retrieve eltwise primitive desc from device context
const std::string key = ctx.op().Input("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
PADDLE_ENFORCE(forward_pd != nullptr,
"Fail to find eltwise_pd in device context");
auto *p_forward_pd =
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());

auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
backward_desc, mkldnn_engine, *p_forward_pd);

auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
diff_dst_memory, diff_src_memory);

// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // anonymous namespace

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_forward<T>(ctx, algorithm);
}
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};

template <typename T>
using ReluMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using ReluMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
31 changes: 22 additions & 9 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -87,6 +88,9 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Out", "Output of Relu operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Relu Activation Operator.
Expand Down Expand Up @@ -140,6 +144,9 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Out", "Output of Tanh operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Tanh Activation Operator.
Expand Down Expand Up @@ -193,6 +200,9 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Out", "Output of Sqrt operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Sqrt Activation Operator.
Expand All @@ -208,6 +218,9 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Out", "Output of Abs operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Abs Activation Operator.
Expand Down Expand Up @@ -524,23 +537,23 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);

REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
ops::ActivationWithMKLDNNOpGrad);

REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationWithMKLDNNOpGrad);

REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad);

REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
softshrink_grad, ops::ActivationOpGrad);

REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationWithMKLDNNOpGrad);

REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
ops::ActivationWithMKLDNNOpGrad);

REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
ops::ActivationOpGrad);
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/operators/activation_op.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,10 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

Expand Down
Loading

0 comments on commit cb3bbbd

Please sign in to comment.