From 8abbcc109e851a05f96329b73039a7a46c7951b2 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Fri, 19 Jan 2018 09:25:23 +0800 Subject: [PATCH] Fix coding style in MKLDNN Pooling (#22) --- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 358 +++----------------- src/operator/nn/mkldnn/mkldnn_pooling.cc | 321 ++++++++++++++++++ src/operator/nn/pooling.cc | 166 ++++----- 3 files changed, 446 insertions(+), 399 deletions(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_pooling.cc diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 6947f66ee424..61895b4d4423 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -38,18 +38,19 @@ class MKLDNNPoolingFwd { public: MKLDNNPoolingFwd(const mxnet::NDArray &input, const mxnet::NDArray &output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int padding_t, int padding_b, int padding_l, int padding_r, - mkldnn::algorithm alg_kind, - bool with_workspace, bool is_train) : - _is_train(is_train), - _with_workspace(with_workspace), - _alg_kind(alg_kind), - fwd(nullptr), data(nullptr), out(nullptr), workspace(nullptr) { - _Init(input, output, - kernel_h, kernel_w, stride_h, stride_w, - padding_t, padding_b, padding_l, padding_r); + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int padding_t, const int padding_b, + const int padding_l, const int padding_r, + const mkldnn::algorithm alg_kind, + const bool with_workspace, const bool is_train) : + is_train_(is_train), + with_workspace_(with_workspace), + alg_kind_(alg_kind), + fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) { + Init(input, output, + kernel_h, kernel_w, stride_h, stride_w, + padding_t, padding_b, padding_l, padding_r); } ~MKLDNNPoolingFwd() {} @@ -59,334 +60,59 @@ class MKLDNNPoolingFwd { void Execute(); private: - bool _is_train; - bool _with_workspace; - mkldnn::algorithm _alg_kind; - std::shared_ptr fwd_pd; - std::shared_ptr fwd; - std::shared_ptr data; - std::shared_ptr out; - std::shared_ptr workspace; + bool is_train_; + bool with_workspace_; + mkldnn::algorithm alg_kind_; + std::shared_ptr fwd_pd_; + std::shared_ptr fwd_; + std::shared_ptr data_; + std::shared_ptr out_; + std::shared_ptr workspace_; private: - void _Init(const mxnet::NDArray &input, - const mxnet::NDArray &output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int padding_t, int padding_b, int padding_l, int padding_r); + void Init(const mxnet::NDArray &input, + const mxnet::NDArray &output, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int padding_t, const int padding_b, + const int padding_l, const int padding_r); }; -void MKLDNNPoolingFwd::_Init(const mxnet::NDArray &input, const mxnet::NDArray &output, - int kernel_h, int kernel_w, int stride_h, int stride_w, - int padding_t, int padding_b, int padding_l, int padding_r) { - auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); - mkldnn::memory::dims dims = {src_md.data.dims[0], - src_md.data.dims[1], - static_cast(output.shape()[2]), - static_cast(output.shape()[3])}; - auto dst_md = mkldnn::memory::desc({dims}, - static_cast(src_md.data.data_type), - static_cast(src_md.data.format)); - auto engine = CpuEngine::Get()->get_engine(); - auto alg_kind = this->_alg_kind; - if (alg_kind != pooling_max && - alg_kind != pooling_avg && - alg_kind != pooling_avg_include_padding && - alg_kind != pooling_avg_exclude_padding) { - LOG(FATAL) << "MKLDNN Pooling: algorithm is not supported"; - } - - auto prop = mkldnn::prop_kind::forward_scoring; - if (this->_is_train && alg_kind != mkldnn::algorithm::pooling_avg) { - prop = mkldnn::prop_kind::forward_training; - } - - if (this->_is_train && prop == mkldnn::prop_kind::forward_scoring) { - LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring"; - } - - mkldnn::memory::dims strides = {stride_h, stride_w }; - mkldnn::memory::dims pad_l = {padding_t, padding_l }; - mkldnn::memory::dims pad_r = {padding_b, padding_r }; - mkldnn::memory::dims kernel = {kernel_h, kernel_w }; - - auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, - strides, kernel, pad_l, pad_r, - mkldnn::padding_kind::zero); - this->fwd_pd.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); - this->data.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); - this->out.reset(new mkldnn::memory(this->fwd_pd->dst_primitive_desc())); - if (this->_with_workspace) { - this->workspace.reset(new mkldnn::memory(this->fwd_pd->workspace_primitive_desc())); - this->fwd.reset(new mkldnn::pooling_forward(*(this->fwd_pd), - mkldnn::primitive::at(*(this->data)), - *(this->out), - *(this->workspace))); - } else { - this->fwd.reset(new mkldnn::pooling_forward(*(fwd_pd), - mkldnn::primitive::at(*(this->data)), - *(this->out))); - } - return; -} - -void MKLDNNPoolingFwd::SetDataHandle(const mxnet::NDArray &data, - const mxnet::NDArray &output, - const mxnet::NDArray *workspace) { - auto data_mem = data.GetMKLDNNData(); - auto out_mem = const_cast(output).CreateMKLDNNData( - this->fwd_pd->dst_primitive_desc()); - this->data->set_data_handle(data_mem->get_data_handle()); - this->out->set_data_handle(out_mem->get_data_handle()); - if (this->_with_workspace && workspace == nullptr) { - LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; - } - - if (this->_with_workspace) { - // auto ws_mem = const_cast(workspace)->CreateMKLDNNData( - // this->fwd_pd->workspace_primitive_desc()); - auto ws_mem = workspace->GetMKLDNNData(); - this->workspace->set_data_handle(ws_mem->get_data_handle()); - } +inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { + return param.kernel.ndim() == 2 && + (param.pool_type == pool_enum::kMaxPooling || + param.pool_type == pool_enum::kAvgPooling); } -void MKLDNNPoolingFwd::Execute() { - if (this->fwd) { - MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); - MKLDNNStream::Get()->Submit(); - } else { - LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; - } -} - -static inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { - return param.kernel.ndim() == 2 - && (param.pool_type == pool_enum::kMaxPooling - || param.pool_type == pool_enum::kAvgPooling); -} - -static inline bool SupportMKLDNNPooling(const PoolingParam ¶m, - const TShape &dshape) { - auto ret = SupportMKLDNNPooling(param); +inline bool SupportMKLDNNPooling(const PoolingParam ¶m, + const TShape &dshape) { + bool ret = SupportMKLDNNPooling(param); if (!ret) return false; + if (param.pooling_convention == pool_enum::kValid) return true; - if ((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0 - && (dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0) + + if (((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0) && + ((dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0)) return true; else return false; } -static inline mkldnn::algorithm -GetMKLDNNPoolAlgo(const PoolingParam ¶m) { - switch (param.pool_type) { - case pool_enum::kMaxPooling: - return mkldnn::algorithm::pooling_max; - break; - case pool_enum::kAvgPooling: - return mkldnn::algorithm::pooling_avg_include_padding; - break; - default: - LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method."; - return mkldnn::algorithm::pooling_max; - } -} - -inline static mkldnn::pooling_forward::primitive_desc -GetPoolingFwd(const PoolingParam ¶m, - bool is_train, - const memory::desc &data_md, - const memory::desc &out_md) { - CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; - int kernel_h_, kernel_w_; - if (param.global_pool) { - kernel_h_ = data_md.data.dims[2]; - kernel_w_ = data_md.data.dims[3]; - } else { - kernel_h_ = param.kernel[0]; - kernel_w_ = param.kernel[1]; - } - - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - - auto pad_t_ = param.pad[0], pad_b_ = param.pad[0]; - auto pad_l_ = param.pad[1], pad_r_ = param.pad[1]; - auto stride_h_ = param.stride[0], stride_w_ = param.stride[1]; - - auto engine = CpuEngine::Get()->get_engine(); - if (param.global_pool) { - CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) - << "With Global_pooling: true; only pad = 0 and stride = 1"; - } - if (pad_t_ != 0 || pad_l_ != 0) { - CHECK(param.pool_type == pool_enum::kAvgPooling || - param.pool_type == pool_enum::kMaxPooling) - << "Padding implemented only for average and max pooling."; - CHECK_LT(pad_l_, kernel_w_); - CHECK_LT(pad_t_, kernel_h_); - } - - auto alg = GetMKLDNNPoolAlgo(param); - auto kind = prop_kind::forward_scoring; - if (is_train && alg != algorithm::pooling_avg) { - kind = prop_kind::forward_training; - } - - pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, - {static_cast(stride_h_), - static_cast(stride_w_)}, - {kernel_h_, kernel_w_}, - {static_cast(pad_t_), - static_cast(pad_l_)}, - {static_cast(pad_b_), - static_cast(pad_r_)}, - padding_kind::zero); - return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); -} - inline bool MKLDNNRequireWorkspace(const PoolingParam ¶m) { return param.pool_type != pool_enum::kAvgPooling; } typedef MKLDNNParamOpSign MKLDNNPoolingSignature; - -static inline MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, - bool is_train, - const NDArray &data, - const NDArray &output) { - static thread_local std::unordered_map pooling_fwds; - - bool with_workspace = is_train && MKLDNNRequireWorkspace(param); - MKLDNNPoolingSignature key(param); - key.AddSign(is_train); - key.AddSign(with_workspace); - key.AddSign(data); - key.AddSign(output); - - auto it = pooling_fwds.find(key); - if (it == pooling_fwds.end()) { - CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; - auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); - int kernel_h_, kernel_w_; - if (param.global_pool) { - kernel_h_ = data_md.data.dims[2]; - kernel_w_ = data_md.data.dims[3]; - } else { - kernel_h_ = param.kernel[0]; - kernel_w_ = param.kernel[1]; - } - - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - - auto pad_t_ = param.pad[0], pad_b_ = param.pad[0]; - auto pad_l_ = param.pad[1], pad_r_ = param.pad[1]; - auto stride_h_ = param.stride[0], stride_w_ = param.stride[1]; - - if (param.global_pool) { - CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) - << "With Global_pooling: true; only pad = 0 and stride = 1"; - } - - if (pad_t_ != 0 || pad_l_ != 0) { - CHECK(param.pool_type == pool_enum::kAvgPooling || - param.pool_type == pool_enum::kMaxPooling) - << "Padding implemented only for average and max pooling."; - CHECK_LT(pad_l_, kernel_w_); - CHECK_LT(pad_t_, kernel_h_); - } - - auto alg = GetMKLDNNPoolAlgo(param); - MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_, - pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train); - auto ins_ret = pooling_fwds.insert( - std::pair(key, fwd)); - CHECK(ins_ret.second); - it = ins_ret.first; - } - return it->second; -} - void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, - const NDArray &in_data, const OpReqType &req, - const NDArray &out_data, const NDArray *workspace) { - auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); - fwd.SetDataHandle(in_data, out_data, workspace); - fwd.Execute(); -} + const NDArray &in_data, const OpReqType req, + const NDArray &out_data, const NDArray *workspace); void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &out_grad, const NDArray &in_data, - const NDArray *workspace, const OpReqType &req, - const NDArray &in_grad) { - if (req == kNullOp) { - return; - } - - TmpMemMgr::Get()->Init(ctx.requested[0]); - auto diff_dst_mem = out_grad.GetMKLDNNData(); - auto input_mem = in_data.GetMKLDNNData(); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - mkldnn::memory::desc data_md = data_mpd.desc(); - memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], - static_cast(out_grad.shape()[2]), - static_cast(out_grad.shape()[3])}; - memory::desc out_md({dims}, - static_cast(data_md.data.data_type), - static_cast(data_md.data.format)); - auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md); - - mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc(); - memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], - static_cast(in_grad.shape()[2]), - static_cast(in_grad.shape()[3])}; - memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - static_cast(diff_md.data.format)); - auto cpu_engine = data_mpd.get_engine(); - - auto alg = GetMKLDNNPoolAlgo(param); - - int kernel_h_, kernel_w_; - if (param.global_pool) { - kernel_h_ = data_md.data.dims[2]; - kernel_w_ = data_md.data.dims[3]; - } else { - kernel_h_ = param.kernel[0]; - kernel_w_ = param.kernel[1]; - } - pooling_backward::desc desc(alg, diff_in_md, diff_md, - {static_cast(param.stride[0]), - static_cast(param.stride[1])}, - {kernel_h_, kernel_w_}, - {static_cast(param.pad[0]), - static_cast(param.pad[1])}, - {static_cast(param.pad[0]), - static_cast(param.pad[1])}, - padding_kind::zero); - pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd); - - auto diff_src_mem = - CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req); - - if (MKLDNNRequireWorkspace(param)) { - CHECK(workspace != nullptr); - auto workspace_mem = workspace->GetMKLDNNData(); - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem), - *diff_src_mem.second)); - } else { - MKLDNNStream::Get()->RegisterPrim( - pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second)); - } - CommitOutput(in_grad, diff_src_mem); - MKLDNNStream::Get()->Submit(); -} + const NDArray *workspace, const OpReqType req, + const NDArray &in_grad); } // namespace op } // namespace mxnet #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc new file mode 100644 index 000000000000..0c068fbacdaa --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_pooling.cc + * \brief + * \author Tao Lv +*/ + +#if MXNET_USE_MKLDNN == 1 + +#include "./mkldnn_pooling-inl.h" + +namespace mxnet { +namespace op { + +void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int padding_t, const int padding_b, + const int padding_l, const int padding_r) { + // mkldnn::memory::desc + auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); + mkldnn::memory::dims dims = {src_md.data.dims[0], + src_md.data.dims[1], + static_cast(output.shape()[2]), + static_cast(output.shape()[3])}; + auto dst_md = mkldnn::memory::desc({dims}, + static_cast(src_md.data.data_type), + static_cast(src_md.data.format)); + const mkldnn::engine engine = CpuEngine::Get()->get_engine(); + const mkldnn::algorithm alg_kind = this->alg_kind_; + if (alg_kind != mkldnn::algorithm::pooling_max && + alg_kind != mkldnn::algorithm::pooling_avg && + alg_kind != mkldnn::algorithm::pooling_avg_include_padding && + alg_kind != mkldnn::algorithm::pooling_avg_exclude_padding) { + LOG(FATAL) << "MKLDNN Pooling: algorithm is not supported"; + } + + mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring; + // if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) { + if (this->is_train_) { + prop = mkldnn::prop_kind::forward_training; + } + + const mkldnn::memory::dims strides = {stride_h, stride_w }; + const mkldnn::memory::dims pad_l = {padding_t, padding_l }; + const mkldnn::memory::dims pad_r = {padding_b, padding_r }; + const mkldnn::memory::dims kernel = {kernel_h, kernel_w }; + // mkldnn::pooling_forward::desc + const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, + strides, kernel, pad_l, pad_r, + mkldnn::padding_kind::zero); + this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); + this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); + this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc())); + if (this->with_workspace_) { + this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc())); + this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), + mkldnn::primitive::at(*(this->data_)), + *(this->out_), + *(this->workspace_))); + } else { + this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), + mkldnn::primitive::at(*(this->data_)), + *(this->out_))); + } + return; +} + +void MKLDNNPoolingFwd::SetDataHandle(const mxnet::NDArray &data, + const mxnet::NDArray &output, + const mxnet::NDArray *workspace) { + // mkldnn::memory + auto data_mem = data.GetMKLDNNData(); + auto out_mem = const_cast(output).CreateMKLDNNData( + this->fwd_pd_->dst_primitive_desc()); + this->data_->set_data_handle(data_mem->get_data_handle()); + this->out_->set_data_handle(out_mem->get_data_handle()); + if (this->with_workspace_ && workspace == nullptr) { + LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; + } + + if (this->with_workspace_) { + // mkldnn::memory + auto ws_mem = workspace->GetMKLDNNData(); + this->workspace_->set_data_handle(ws_mem->get_data_handle()); + } +} + +void MKLDNNPoolingFwd::Execute() { + if (this->fwd_) { + MKLDNNStream::Get()->RegisterPrim(*(this->fwd_)); + MKLDNNStream::Get()->Submit(); + } else { + LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; + } +} + +mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { + switch (param.pool_type) { + case pool_enum::kMaxPooling: + return mkldnn::algorithm::pooling_max; + break; + case pool_enum::kAvgPooling: + return mkldnn::algorithm::pooling_avg_include_padding; + break; + default: + LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method."; + return mkldnn::algorithm::pooling_max; + } +} + +mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam ¶m, + const bool is_train, + const memory::desc &data_md, + const memory::desc &out_md) { + CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + + const int pad_t_ = param.pad[0], pad_b_ = param.pad[0]; + const int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + const int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + + const mkldnn::engine engine = CpuEngine::Get()->get_engine(); + if (param.global_pool) { + CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) + << "With Global_pooling: true; only pad = 0 and stride = 1"; + } + if (pad_t_ != 0 || pad_l_ != 0) { + CHECK(param.pool_type == pool_enum::kAvgPooling || + param.pool_type == pool_enum::kMaxPooling) + << "Padding implemented only for average and max pooling."; + CHECK_LT(pad_l_, kernel_w_); + CHECK_LT(pad_t_, kernel_h_); + } + + + const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring; + // if (is_train && alg != algorithm::pooling_avg) { + if (is_train) { + kind = mkldnn::prop_kind::forward_training; + } + + const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, + {static_cast(stride_h_), + static_cast(stride_w_)}, + {kernel_h_, kernel_w_}, + {static_cast(pad_t_), + static_cast(pad_l_)}, + {static_cast(pad_b_), + static_cast(pad_r_)}, + padding_kind::zero); + return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); +} + +MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, + const bool is_train, + const NDArray &data, + const NDArray &output) { + static thread_local std::unordered_map pooling_fwds; + + bool with_workspace = is_train && MKLDNNRequireWorkspace(param); + MKLDNNPoolingSignature key(param); + key.AddSign(is_train); + key.AddSign(with_workspace); + key.AddSign(data); + key.AddSign(output); + + auto it = pooling_fwds.find(key); + if (it == pooling_fwds.end()) { + CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; + auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + + const int pad_t_ = param.pad[0], pad_b_ = param.pad[0]; + const int pad_l_ = param.pad[1], pad_r_ = param.pad[1]; + const int stride_h_ = param.stride[0], stride_w_ = param.stride[1]; + + if (param.global_pool) { + CHECK(pad_t_ == 0 && pad_l_ == 0 && stride_h_ == 1 && stride_w_ == 1) + << "With Global_pooling: true; only pad = 0 and stride = 1"; + } + + if (pad_t_ != 0 || pad_l_ != 0) { + CHECK(param.pool_type == pool_enum::kAvgPooling || + param.pool_type == pool_enum::kMaxPooling) + << "Padding implemented only for average and max pooling."; + CHECK_LT(pad_l_, kernel_w_); + CHECK_LT(pad_t_, kernel_h_); + } + + const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train); + auto ins_ret = pooling_fwds.insert( + std::pair(key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, + const NDArray &in_data, const OpReqType req, + const NDArray &out_data, const NDArray *workspace) { + auto fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); + fwd.SetDataHandle(in_data, out_data, workspace); + fwd.Execute(); +} + +void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, + const NDArray &out_grad, const NDArray &in_data, + const NDArray *workspace, const OpReqType req, + const NDArray &in_grad) { + if (req == kNullOp) { + return; + } + + TmpMemMgr::Get()->Init(ctx.requested[0]); + // mkldnn::memory + auto diff_dst_mem = out_grad.GetMKLDNNData(); + auto input_mem = in_data.GetMKLDNNData(); + mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); + const mkldnn::memory::desc data_md = data_mpd.desc(); + const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + static_cast(out_grad.shape()[2]), + static_cast(out_grad.shape()[3])}; + const memory::desc out_md({dims}, + static_cast(data_md.data.data_type), + static_cast(data_md.data.format)); + auto pdesc_fwd = GetPoolingFwd(param, ctx.is_train, data_md, out_md); + + const mkldnn::memory::desc diff_md = diff_dst_mem->get_primitive_desc().desc(); + const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + static_cast(in_grad.shape()[2]), + static_cast(in_grad.shape()[3])}; + const memory::desc diff_in_md( + {dims1}, static_cast(diff_md.data.data_type), + static_cast(diff_md.data.format)); + const mkldnn::engine cpu_engine = data_mpd.get_engine(); + const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + + int kernel_h_, kernel_w_; + if (param.global_pool) { + kernel_h_ = data_md.data.dims[2]; + kernel_w_ = data_md.data.dims[3]; + } else { + kernel_h_ = param.kernel[0]; + kernel_w_ = param.kernel[1]; + } + const pooling_backward::desc desc(alg, diff_in_md, diff_md, + {static_cast(param.stride[0]), + static_cast(param.stride[1])}, + {kernel_h_, kernel_w_}, + {static_cast(param.pad[0]), + static_cast(param.pad[1])}, + {static_cast(param.pad[0]), + static_cast(param.pad[1])}, + mkldnn::padding_kind::zero); + const pooling_backward::primitive_desc pdesc(desc, cpu_engine, pdesc_fwd); + + auto diff_src_mem = + CreateMKLDNNMem(in_grad, pdesc.diff_src_primitive_desc(), req); + + if (MKLDNNRequireWorkspace(param)) { + CHECK(workspace != nullptr); + auto workspace_mem = workspace->GetMKLDNNData(); + MKLDNNStream::Get()->RegisterPrim( + pooling_backward(pdesc, *diff_dst_mem, primitive::at(*workspace_mem), + *diff_src_mem.second)); + } else { + MKLDNNStream::Get()->RegisterPrim( + pooling_backward(pdesc, *diff_dst_mem, *diff_src_mem.second)); + } + CommitOutput(in_grad, diff_src_mem); + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 9c0bfa9d72ef..7e9ab2e7f80b 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -37,25 +37,25 @@ namespace op { static void PoolingParamParser(nnvm::NodeAttrs *attrs) { using namespace mshadow; - PoolingParam param_; - param_.Init(attrs->dict); - if (param_.kernel.ndim() == 1) { - if (param_.stride.ndim() == 0) param_.stride = Shape1(1); - if (param_.pad.ndim() == 0) param_.pad = Shape1(0); - } else if (param_.kernel.ndim() == 2) { - if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); - if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); + PoolingParam param; + param.Init(attrs->dict); + if (param.kernel.ndim() == 1) { + if (param.stride.ndim() == 0) param.stride = Shape1(1); + if (param.pad.ndim() == 0) param.pad = Shape1(0); + } else if (param.kernel.ndim() == 2) { + if (param.stride.ndim() == 0) param.stride = Shape2(1, 1); + if (param.pad.ndim() == 0) param.pad = Shape2(0, 0); } else { - CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() + CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D pooling not supported"; - if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1); - if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0); + if (param.stride.ndim() == 0) param.stride = Shape3(1, 1, 1); + if (param.pad.ndim() == 0) param.pad = Shape3(0, 0, 0); } - CHECK_EQ(param_.stride.ndim(), param_.kernel.ndim()) + CHECK_EQ(param.stride.ndim(), param.kernel.ndim()) << "stride and kernel should have the same length"; - CHECK_EQ(param_.pad.ndim(), param_.kernel.ndim()) + CHECK_EQ(param.pad.ndim(), param.kernel.ndim()) << "pad and kernel should have the same length"; - attrs->parsed = std::move(param_); + attrs->parsed = std::move(param); } int GetNumOutputs(const PoolingParam ¶m) { @@ -91,7 +91,7 @@ static bool PoolingType(const nnvm::NodeAttrs& attrs, static bool PoolingShape(const nnvm::NodeAttrs &attrs, std::vector *in_shape, std::vector *out_shape) { - const PoolingParam ¶m_ = nnvm::get(attrs.parsed); + const PoolingParam ¶m = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); const TShape &dshape = (*in_shape)[0]; CHECK_GE(dshape.ndim(), 3U) @@ -100,116 +100,116 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, << " Or 5D in (batch, channel, d, y, x)"; TShape oshape = dshape; if (dshape.ndim() == 0) return false; - if (param_.kernel.ndim() == 1) { + if (param.kernel.ndim() == 1) { CHECK_EQ(dshape.ndim(), 3U) << "Pooling: Input data should be 3D in (batch, channel, x)"; - if (param_.global_pool) { + if (param.global_pool) { oshape[2] = 1; } else { - CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0]) - << "kernel size (" << param_.kernel[0] << ") exceeds input (" - << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0]) + CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0]) + << "kernel size (" << param.kernel[0] << ") exceeds input (" + << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0]) << ")"; - if (param_.pooling_convention == pool_enum::kValid) { + if (param.pooling_convention == pool_enum::kValid) { oshape[2] = 1 + - (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / - param_.stride[0]; + (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; } else { oshape[2] = 1 + static_cast(ceil( - static_cast(dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / - param_.stride[0])); + static_cast(dshape[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 - if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif - } else if (param_.kernel.ndim() == 2) { + } else if (param.kernel.ndim() == 2) { CHECK_EQ(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x)"; - if (param_.global_pool) { + if (param.global_pool) { oshape[2] = 1; oshape[3] = 1; } else { - CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0]) - << "kernel size (" << param_.kernel[0] << ") exceeds input (" - << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0]) + CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0]) + << "kernel size (" << param.kernel[0] << ") exceeds input (" + << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0]) << ")"; - CHECK(param_.kernel[1] <= dshape[3] + 2 * param_.pad[1]) - << "kernel size (" << param_.kernel[1] << ") exceeds input (" - << dshape[3] << " padded to " << (dshape[3] + 2 * param_.pad[1]) + CHECK(param.kernel[1] <= dshape[3] + 2 * param.pad[1]) + << "kernel size (" << param.kernel[1] << ") exceeds input (" + << dshape[3] << " padded to " << (dshape[3] + 2 * param.pad[1]) << ")"; - if (param_.pooling_convention == pool_enum::kValid) { + if (param.pooling_convention == pool_enum::kValid) { oshape[2] = 1 + - (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / - param_.stride[0]; + (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; oshape[3] = 1 + - (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / - param_.stride[1]; + (dshape[3] + 2 * param.pad[1] - param.kernel[1]) / + param.stride[1]; } else { oshape[2] = 1 + static_cast(ceil( - static_cast(dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / - param_.stride[0])); + static_cast(dshape[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); oshape[3] = 1 + static_cast(ceil( - static_cast(dshape[3] + 2 * param_.pad[1] - - param_.kernel[1]) / - param_.stride[1])); + static_cast(dshape[3] + 2 * param.pad[1] - + param.kernel[1]) / + param.stride[1])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 - if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif - } else if (param_.kernel.ndim() == 3) { + } else if (param.kernel.ndim() == 3) { CHECK_EQ(dshape.ndim(), 5U) << "Pooling: Input data should be 5D in (batch, channel, d, y, x)"; - CHECK_LE(param_.kernel[0], dshape[2] + 2 * param_.pad[0]) + CHECK_LE(param.kernel[0], dshape[2] + 2 * param.pad[0]) << "kernel size exceeds input"; - CHECK_LE(param_.kernel[1], dshape[3] + 2 * param_.pad[1]) + CHECK_LE(param.kernel[1], dshape[3] + 2 * param.pad[1]) << "kernel size exceeds input"; - CHECK_LE(param_.kernel[2], dshape[4] + 2 * param_.pad[2]) + CHECK_LE(param.kernel[2], dshape[4] + 2 * param.pad[2]) << "kernel size exceeds input"; - if (param_.global_pool) { + if (param.global_pool) { oshape[2] = 1; oshape[3] = 1; oshape[4] = 1; } else { - if (param_.pooling_convention == pool_enum::kValid) { + if (param.pooling_convention == pool_enum::kValid) { oshape[2] = 1 + - (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) / - param_.stride[0]; + (dshape[2] + 2 * param.pad[0] - param.kernel[0]) / + param.stride[0]; oshape[3] = 1 + - (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) / - param_.stride[1]; + (dshape[3] + 2 * param.pad[1] - param.kernel[1]) / + param.stride[1]; oshape[4] = 1 + - (dshape[4] + 2 * param_.pad[2] - param_.kernel[2]) / - param_.stride[2]; + (dshape[4] + 2 * param.pad[2] - param.kernel[2]) / + param.stride[2]; } else { oshape[2] = 1 + static_cast(ceil( - static_cast(dshape[2] + 2 * param_.pad[0] - - param_.kernel[0]) / - param_.stride[0])); + static_cast(dshape[2] + 2 * param.pad[0] - + param.kernel[0]) / + param.stride[0])); oshape[3] = 1 + static_cast(ceil( - static_cast(dshape[3] + 2 * param_.pad[1] - - param_.kernel[1]) / - param_.stride[1])); + static_cast(dshape[3] + 2 * param.pad[1] - + param.kernel[1]) / + param.stride[1])); oshape[4] = 1 + static_cast(ceil( - static_cast(dshape[4] + 2 * param_.pad[2] - - param_.kernel[2]) / - param_.stride[2])); + static_cast(dshape[4] + 2 * param.pad[2] - + param.kernel[2]) / + param.stride[2])); } } out_shape->clear(); out_shape->push_back(oshape); // save output shape #if MXNET_USE_MKLDNN == 1 - if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_)) + if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif } @@ -273,18 +273,16 @@ inline static bool PoolingStorageType(const nnvm::NodeAttrs &attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1); - *dispatch_mode = DispatchMode::kFCompute; #if MXNET_USE_MKLDNN == 1 const PoolingParam ¶m = nnvm::get(attrs.parsed); if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) { - *dispatch_mode = DispatchMode::kFComputeEx; + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFComputeEx); } -#else - CHECK_EQ(out_attrs->size(), 1); #endif - for (size_t i = 0; i < out_attrs->size(); i++) - (*out_attrs)[i] = kDefaultStorage; - return true; + CHECK_EQ(out_attrs->size(), 1); + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); } inline static bool BackwardPoolingStorageType(const nnvm::NodeAttrs &attrs, @@ -296,17 +294,15 @@ inline static bool BackwardPoolingStorageType(const nnvm::NodeAttrs &attrs, CHECK_EQ(in_attrs->size(), GetNumBackInputs(param)); CHECK_EQ(out_attrs->size(), 1); - *dispatch_mode = DispatchMode::kFCompute; #if MXNET_USE_MKLDNN == 1 if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) { - *dispatch_mode = DispatchMode::kFComputeEx; + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFComputeEx); } -#else - CHECK_EQ(in_attrs->size(), 3); #endif - for (size_t i = 0; i < out_attrs->size(); i++) - (*out_attrs)[i] = kDefaultStorage; - return true; + CHECK_EQ(in_attrs->size(), 3); + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); } DMLC_REGISTER_PARAMETER(PoolingParam); @@ -360,6 +356,10 @@ height, width)*. .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) #endif +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; +}) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output"};