diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 48638de20ccb..3f47d58bb8c3 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -224,16 +224,25 @@ void BatchNormForward(const OpContext &ctx, const BatchNormParam& param, */ template void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, + const std::vector &inputs, const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U); - CHECK_EQ(in_data.size(), 3U); - CHECK_EQ(out_data.size(), 3U); - CHECK_EQ(in_grad.size(), 3U); + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 8U); + CHECK_EQ(outputs.size(), 3U); + std::vector out_grad(1); + std::vector out_data(3); + std::vector in_data(3); + std::vector aux_states(2); + + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + const std::vector &in_grad = outputs; mshadow::Stream *s = ctx.get_stream(); BatchNormBackwardImpl(s, ctx, param, out_grad, in_data, out_data, req, in_grad, aux_states); @@ -261,23 +270,11 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 8U); const BatchNormParam& param = nnvm::get(attrs.parsed); - int num_out_grads = param.output_mean_var ? 3U : 1U; - int in_data_start = 3; - int aux_states_start = in_data_start + batchnorm::kInMovingMean; - int out_data_start = in_data_start + batchnorm::kInMovingVar + 1; - std::vector out_grad(inputs.begin(), inputs.begin() + num_out_grads); - std::vector in_data(inputs.begin() + in_data_start, - inputs.begin() + aux_states_start); - std::vector aux_states(inputs.begin() + aux_states_start, - inputs.begin() + out_data_start); - std::vector out_data(inputs.begin() + out_data_start, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); - - MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, in_data, out_data, req, - in_grad, aux_states); + + MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { + BatchNormBackward(ctx, param, inputs, req, outputs); }); } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index c8b5d58156e5..457f536d7fa0 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -413,24 +413,26 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 8U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); - int num_out_grads = param.output_mean_var ? 3U : 1U; - int in_data_start = 3; - int aux_states_start = in_data_start + batchnorm::kInMovingMean; - int out_data_start = in_data_start + batchnorm::kInMovingVar + 1; TShape shape = inputs[0].shape(); // MKLDNN batchnorm only works well on the special MKLDNN layout. if (SupportMKLDNNBN(inputs[0], param) - && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { - std::vector out_grad(inputs.begin(), inputs.begin() + num_out_grads); - std::vector in_data(inputs.begin() + in_data_start, - inputs.begin() + aux_states_start); - std::vector aux_states(inputs.begin() + aux_states_start, - inputs.begin() + out_data_start); - std::vector out_data(inputs.begin() + out_data_start, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); + && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { + std::vector out_grad(1); + std::vector out_data(3); + std::vector in_data(3); + std::vector aux_states(2); + out_grad[0] = inputs[0]; + out_data[batchnorm::kMean] = inputs[1]; + out_data[batchnorm::kVar] = inputs[2]; + in_data[batchnorm::kData] = inputs[3]; + in_data[batchnorm::kGamma] = inputs[4]; + in_data[batchnorm::kBeta] = inputs[5]; + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; + const std::vector &in_grad = outputs; if (inputs[0].dtype() == mshadow::kFloat32) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); @@ -470,8 +472,6 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, DispatchMode *dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 11); - CHECK_EQ(out_attrs->size(), 5); DispatchMode wanted_mode; #if MXNET_USE_MKLDNN == 1 if (dev_mask == mshadow::cpu::kDevMask) @@ -486,6 +486,46 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, dispatch_mode, wanted_mode); } +std::vector BatchNormGrad(const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector out_data(n->num_outputs()); + for (uint32_t i = 0; i < out_data.size(); ++i) { + out_data[i] = nnvm::NodeEntry{n, i, 0}; + } + std::vector heads; + heads.reserve(8); + heads.push_back(ograds[0]); + heads.push_back(out_data[batchnorm::kMean]); + heads.push_back(out_data[batchnorm::kVar]); + heads.push_back(n->inputs[batchnorm::kData]); + heads.push_back(n->inputs[batchnorm::kGamma]); + heads.push_back(n->inputs[batchnorm::kBeta]); + heads.push_back(n->inputs[batchnorm::kInMovingMean]); + heads.push_back(n->inputs[batchnorm::kInMovingVar]); + + nnvm::NodePtr gnode = nnvm::Node::Create(); + gnode->inputs = std::move(heads); + gnode->control_deps.emplace_back(n); + gnode->attrs = n->attrs; + gnode->attrs.op = nnvm::Op::Get("_backward_BatchNorm"); + gnode->attrs.name = n->attrs.name + "_backward"; + // The input of batchnorm + std::vector in_grad(5); + for (uint32_t i = 0; i < 3; ++i) { + in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; + } + + // attach no gradient node to forbid gradient on aux_state + nnvm::NodePtr ng = nnvm::Node::Create(); + ng->attrs.op = Op::Get("_NoGradient"); + ng->attrs.name = "NoGradient"; + // the aux state of batchnorm + for (uint32_t i = 0; i < 2; ++i) { + in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0}; + } + return in_grad; +} + NNVM_REGISTER_OP(BatchNorm) .describe(R"code(Batch normalization. @@ -559,7 +599,7 @@ then set ``gamma`` to 1 and its gradient to 0. #if MXNET_USE_MKLDNN == 1 .set_attr("FComputeEx", BatchNormComputeExCPU) #endif -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_BatchNorm"}) +.set_attr("FGradient", BatchNormGrad) #if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; @@ -583,7 +623,7 @@ then set ``gamma`` to 1 and its gradient to 0. }); NNVM_REGISTER_OP(_backward_BatchNorm) -.set_num_outputs(5) +.set_num_outputs(3) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", backward_BatchNormStorageType) #if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index b8657fc4d367..c310a93d700f 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -690,13 +690,8 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CHECK_EQ(inputs.size(), 11U); + CHECK_EQ(inputs.size(), 8U); BatchNormParam param = nnvm::get(attrs.parsed); - std::vector out_grad(1, inputs[0]); - std::vector in_data(inputs.begin() + 3, inputs.begin() + 6); - std::vector aux_states(inputs.begin() + 6, inputs.begin() + 8); - std::vector out_data(inputs.begin() + 8, inputs.end()); - std::vector in_grad(outputs.begin(), outputs.begin() + 3); int dtype = inputs[0].type_flag_; TShape shape = inputs[0].shape_; @@ -705,19 +700,18 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - GetCuDNNOp(param).Backward(ctx, out_grad, in_data, out_data, - req, in_grad, aux_states); + GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); }) } else { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, - in_data, out_data, req, in_grad, aux_states); + BatchNormBackward(ctx, param, inputs, req, outputs); }) } #else + aux_states[batchnorm::kMovingMean] = inputs[6]; + aux_states[batchnorm::kMovingVar] = inputs[7]; MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, { - BatchNormBackward(ctx, param, out_grad, - in_data, out_data, req, in_grad, aux_states); + BatchNormBackward(ctx, param, inputs, req, outputs); }); #endif } diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index d0dd7dd27a60..c98a010774d7 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -124,6 +124,8 @@ struct ConvolutionParam : public dmlc::Parameter { } }; +typedef ParamOpSign ConvSignature; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index d7f9e564a603..f6d14e3558b8 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -41,13 +41,40 @@ static CuDNNConvolutionOp &GetCuDNNConvOp(const ConvolutionParam& param, const std::vector& in_shape, const std::vector& out_shape, const Context& ctx) { #if DMLC_CXX11_THREAD_LOCAL - static thread_local CuDNNConvolutionOp op; + static thread_local std::unordered_map >, + OpHash> ops; #else - static MX_THREAD_LOCAL CuDNNConvolutionOp op; + static MX_THREAD_LOCAL std::unordered_map >, + OpHash> ops; #endif - op.Init(param, forward_compute_type, backward_compute_type, - in_shape, out_shape, ctx); - return op; + ConvSignature key(param); + size_t ndim = 0; + for (auto &s : in_shape) + ndim += s.ndim(); + for (auto &s : out_shape) + ndim += s.ndim(); + key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + + ndim + 1 /* for dev_id */); + + key.AddSign(forward_compute_type); + key.AddSign(backward_compute_type); + key.AddSign(in_shape); + key.AddSign(out_shape); + key.AddSign(ctx.dev_id); + + auto it = ops.find(key); + if (it == ops.end()) { + std::shared_ptr> op(new CuDNNConvolutionOp()); + auto ins_ret = ops.insert(std::pair>>( + key, op)); + CHECK(ins_ret.second); + it = ins_ret.first; + it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, + out_shape, ctx); + } + return *it->second; } #endif diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index e2337049060e..e3d5dd9204b9 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -67,10 +67,10 @@ class CuDNNBatchNormOp { } void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_states) { + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 3U); @@ -158,29 +158,30 @@ class CuDNNBatchNormOp { } void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), 1U); - CHECK_EQ(in_data.size(), 3U); - CHECK_EQ(out_data.size(), 3U); - CHECK_EQ(in_grad.size(), 3U); + CHECK_EQ(inputs.size(), 8U); + CHECK_EQ(outputs.size(), 3U); CHECK(ctx.is_train && !param_.use_global_stats) << "use global statistics is not yet supported in CuDNNBatchNorm"; - Init(in_data[cudnnbatchnorm::kData]); + // Rename the inputs and outputs. + const TBlob &out_grad = inputs[0]; + const TBlob &out_mean = inputs[1]; + const TBlob &out_var = inputs[2]; + const TBlob &in_data = inputs[3]; + const TBlob &in_gamma = inputs[4]; + const std::vector &in_grad = outputs; + + Init(in_data); Stream *s = ctx.get_stream(); - Tensor x = - in_data[cudnnbatchnorm::kData].get_with_shape(shape_, s); + Tensor x = in_data.get_with_shape(shape_, s); Tensor dx = in_grad[cudnnbatchnorm::kData].get_with_shape(shape_, s); - Tensor dy = - out_grad[cudnnbatchnorm::kOut].get_with_shape(shape_, s); + Tensor dy = out_grad.get_with_shape(shape_, s); #if CUDNN_VERSION >= 4007 #if CUDNN_VERSION >= 7002 @@ -190,15 +191,15 @@ class CuDNNBatchNormOp { #endif MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, { Tensor gamma = - in_data[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); + in_gamma.get_with_shape(Shape1(shape_[1]), s); Tensor dbeta = in_grad[cudnnbatchnorm::kBeta].get_with_shape(Shape1(shape_[1]), s); Tensor dgamma = in_grad[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); Tensor save_mean = - out_data[cudnnbatchnorm::kMean].get_with_shape(Shape1(shape_[1]), s); + out_mean.get_with_shape(Shape1(shape_[1]), s); Tensor save_inv_var = - out_data[cudnnbatchnorm::kInvVar].get_with_shape(Shape1(shape_[1]), s); + out_var.get_with_shape(Shape1(shape_[1]), s); typename DataType::ScaleType a = 1.0f; typename DataType::ScaleType b = 0.0f; @@ -232,15 +233,15 @@ class CuDNNBatchNormOp { #else // CUDNN_VERSION < 4007 MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, { Tensor gamma = - in_data[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); + in_gamma.get_with_shape(Shape1(shape_[1]), s); Tensor dbeta = in_grad[cudnnbatchnorm::kBeta].get_with_shape(Shape1(shape_[1]), s); Tensor dgamma = in_grad[cudnnbatchnorm::kGamma].get_with_shape(Shape1(shape_[1]), s); Tensor save_mean = - out_data[cudnnbatchnorm::kMean].get_with_shape(Shape1(shape_[1]), s); + out_mean.get_with_shape(Shape1(shape_[1]), s); Tensor save_inv_var = - out_data[cudnnbatchnorm::kInvVar].get_with_shape(Shape1(shape_[1]), s); + out_var.get_with_shape(Shape1(shape_[1]), s); typename DataType::ScaleType a = 1.0f; typename DataType::ScaleType b = 0.0f; diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index badbb8b9d672..b41ecf4aa41e 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -169,6 +169,8 @@ struct DeconvolutionParam : public dmlc::Parameter { } }; +typedef ParamOpSign DeconvSignature; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index c7395428c2a0..086b47000b2c 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -40,9 +40,35 @@ static CuDNNDeconvolutionOp &GetCuDNNDeconvOp(const DeconvolutionParam& p const std::vector& in_shape, const std::vector& out_shape, const Context& ctx) { - static thread_local CuDNNDeconvolutionOp op; - op.Init(param, forward_compute_type, backward_compute_type, in_shape, out_shape, ctx); - return op; + static thread_local std::unordered_map >, + OpHash> ops; + DeconvSignature key(param); + size_t ndim = 0; + for (auto &s : in_shape) + ndim += s.ndim(); + for (auto &s : out_shape) + ndim += s.ndim(); + key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + + ndim + 1 /* for dev_id */); + + key.AddSign(forward_compute_type); + key.AddSign(backward_compute_type); + key.AddSign(in_shape); + key.AddSign(out_shape); + key.AddSign(ctx.dev_id); + + auto it = ops.find(key); + if (it == ops.end()) { + std::shared_ptr> op(new CuDNNDeconvolutionOp()); + auto ins_ret = ops.insert( + std::pair>>(key, op)); + CHECK(ins_ret.second); + it = ins_ret.first; + it->second->Init(param, forward_compute_type, backward_compute_type, in_shape, + out_shape, ctx); + } + return *it->second; } #endif diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index 71fdf4ca585b..8c19850ced38 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -93,7 +93,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); } -typedef MKLDNNParamOpSign MKLDNNActSignature; +typedef ParamOpSign MKLDNNActSignature; class MKLDNNActForward { std::shared_ptr fwd; @@ -137,7 +137,7 @@ class MKLDNNActForward { static MKLDNNActForward &GetActForward(const ActivationParam& param, const OpContext &ctx, const NDArray &in_data, const mkldnn::memory &in_mem) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; MKLDNNActSignature key(param); key.AddSign(ctx.is_train); key.AddSign(param.act_type); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 1c583e1f671e..362f5fbde727 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -296,111 +296,6 @@ class MKLDNNStream { } }; -class MKLDNNOpSignature { - std::vector eles; - uint64_t hash; - - public: - MKLDNNOpSignature() { - hash = 0; - } - - explicit MKLDNNOpSignature(uint64_t hash) { - this->hash = hash; - } - - /* - * We provide different methods to add signature to an op. - * For operations, such as convolutin and fully connected, which determines - * the optimal data layout for the op, we only need to use the shape and data - * type to sign the op. For other operations, such as activation, which uses - * whatever layout in the input array, we have to use the shape, the data type - * and the layout to sign the op. - */ - - void AddSign(const mkldnn::memory &mem) { - auto desc = mem.get_primitive_desc().desc(); - hash = hash * 2 + desc.data.format; - eles.push_back(desc.data.format); - hash = hash * 2 + desc.data.data_type; - eles.push_back(desc.data.data_type); - for (int i = 0; i < desc.data.ndims; i++) { - hash = hash * 2 + desc.data.dims[i]; - eles.push_back(desc.data.dims[i]); - } - } - - void AddSign(const std::vector &arrs) { - for (auto &arr : arrs) { - AddSign(arr); - } - } - - void AddSign(const NDArray &arr) { - if (arr.IsMKLDNNData()) { - AddSign(*(arr.GetMKLDNNData())); - } else { - hash = hash * 2 + arr.dtype(); - eles.push_back(arr.dtype()); - AddSign(arr.shape()); - } - } - - void AddSign(const TShape &shape) { - for (size_t i = 0; i < shape.ndim(); i++) { - hash = hash * 2 + shape[i]; - eles.push_back(shape[i]); - } - } - - void AddSign(int val) { - hash = hash * 2 + val; - eles.push_back(val); - } - - bool operator==(const MKLDNNOpSignature &sign) const { - if (hash != sign.hash) - return false; - if (eles.size() != sign.eles.size()) - return false; - for (size_t i = 0; i < eles.size(); i++) - if (eles[i] != sign.eles[i]) - return false; - return true; - } - - uint64_t GetHash() const { - return hash; - } -}; - -struct MKLDNNOpHash { - size_t operator()(const MKLDNNOpSignature &sign) const { - return sign.GetHash(); - } -}; - -template -class MKLDNNParamOpSign: public MKLDNNOpSignature { - const ParamType param; - - static size_t hash(const ParamType ¶m) { - std::hash fn; - return fn(param); - } - - public: - explicit MKLDNNParamOpSign(const ParamType &_param): MKLDNNOpSignature( - hash(_param)), param(_param) { - } - - bool operator==(const MKLDNNParamOpSign &sign) const { - const MKLDNNOpSignature &this_upper = *this; - const MKLDNNOpSignature &other_upper = sign; - return this_upper == other_upper && param == sign.param; - } -}; - enum OutDataOp { Noop, CopyBack, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index a685ebfb4abe..16f9874bd5c8 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -98,7 +98,7 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags)); } -typedef MKLDNNParamOpSign MKLDNNBNSignature; +typedef ParamOpSign MKLDNNBNSignature; class MKLDNNBNForward { std::shared_ptr data_m; @@ -184,7 +184,7 @@ template static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, const OpContext &ctx, const NDArray &in_data, unsigned flags) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; MKLDNNBNSignature key(param); key.AddSign(ctx.is_train); key.AddSign(in_data); @@ -302,7 +302,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &in_grad, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U); + CHECK_EQ(out_grad.size(), 1U); CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 76efc244fc42..453221f9b377 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -226,13 +226,13 @@ class MKLDNNConvForward { } }; -typedef MKLDNNParamOpSign MKLDNNConvSignature; +typedef ParamOpSign MKLDNNConvSignature; static inline MKLDNNConvForward &GetConvFwd( const nnvm::NodeAttrs& attrs, bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { - static thread_local std::unordered_map fwds; + static thread_local std::unordered_map fwds; const ConvolutionParam& param = nnvm::get(attrs.parsed); MKLDNNConvSignature key(param); key.AddSign(is_train); diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index a0d3df7bb477..af57b68cfd37 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -289,16 +289,14 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, } } -typedef MKLDNNParamOpSign MKLDNNDeconvSignature; - static inline MKLDNNDeconvForward &GetDeconvFwd( const nnvm::NodeAttrs& attrs, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { static thread_local - std::unordered_map fwds; + std::unordered_map fwds; const DeconvolutionParam& param = nnvm::get(attrs.parsed); - MKLDNNDeconvSignature key(param); + DeconvSignature key(param); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the // data type of the arrays. @@ -313,7 +311,7 @@ static inline MKLDNNDeconvForward &GetDeconvFwd( bool has_bias = (bias != nullptr); MKLDNNDeconvForward fwd(param, data, weights, has_bias, output); auto ins_ret = fwds.insert( - std::pair(key, fwd)); + std::pair(key, fwd)); CHECK(ins_ret.second); it = ins_ret.first; } diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 61895b4d4423..2097d57ba92f 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -104,7 +104,7 @@ inline bool MKLDNNRequireWorkspace(const PoolingParam ¶m) { return param.pool_type != pool_enum::kAvgPooling; } -typedef MKLDNNParamOpSign MKLDNNPoolingSignature; +typedef ParamOpSign MKLDNNPoolingSignature; void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace); diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 86f13145eaa5..1aeb7d48dc35 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -188,7 +188,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, const NDArray &output) { static thread_local std::unordered_map pooling_fwds; + OpHash> pooling_fwds; bool with_workspace = is_train && MKLDNNRequireWorkspace(param); MKLDNNPoolingSignature key(param); diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 10581d14ba72..a629ba5eed8b 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -489,6 +489,130 @@ inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); } +class OpSignature { + std::vector eles; + uint64_t hash; + + public: + OpSignature() { + hash = 0; + } + + explicit OpSignature(uint64_t hash) { + this->hash = hash; + } + + /* + * This is to reserve space for the vector. + */ + void Reserve(size_t num) { + eles.reserve(num); + } + + /* + * We provide different methods to add signature to an op. + * For operations, such as convolutin and fully connected, which determines + * the optimal data layout for the op, we only need to use the shape and data + * type to sign the op. For other operations, such as activation, which uses + * whatever layout in the input array, we have to use the shape, the data type + * and the layout to sign the op. + */ + +#if MXNET_USE_MKLDNN == 1 + void AddSign(const mkldnn::memory &mem) { + auto desc = mem.get_primitive_desc().desc(); + hash = hash * 2 + desc.data.format; + eles.push_back(desc.data.format); + hash = hash * 2 + desc.data.data_type; + eles.push_back(desc.data.data_type); + for (int i = 0; i < desc.data.ndims; i++) { + hash = hash * 2 + desc.data.dims[i]; + eles.push_back(desc.data.dims[i]); + } + } +#endif + + void AddSign(const std::vector &arrs) { + for (auto &arr : arrs) { + AddSign(arr); + } + } + + void AddSign(const NDArray &arr) { +#if MXNET_USE_MKLDNN == 1 + if (arr.IsMKLDNNData()) { + AddSign(*(arr.GetMKLDNNData())); + } else { +#endif + hash = hash * 2 + arr.dtype(); + eles.push_back(arr.dtype()); + AddSign(arr.shape()); +#if MXNET_USE_MKLDNN == 1 + } +#endif + } + + void AddSign(const std::vector &shapes) { + for (auto &shape : shapes) { + AddSign(shape); + } + } + + void AddSign(const TShape &shape) { + for (size_t i = 0; i < shape.ndim(); i++) { + hash = hash * 2 + shape[i]; + eles.push_back(shape[i]); + } + } + + void AddSign(int val) { + hash = hash * 2 + val; + eles.push_back(val); + } + + bool operator==(const OpSignature &sign) const { + if (hash != sign.hash) + return false; + if (eles.size() != sign.eles.size()) + return false; + for (size_t i = 0; i < eles.size(); i++) + if (eles[i] != sign.eles[i]) + return false; + return true; + } + + uint64_t GetHash() const { + return hash; + } +}; + +struct OpHash { + size_t operator()(const OpSignature &sign) const { + return sign.GetHash(); + } +}; + +template +class ParamOpSign: public OpSignature { + const ParamType param; + + static size_t hash(const ParamType ¶m) { + std::hash fn; + return fn(param); + } + + public: + explicit ParamOpSign(const ParamType &_param): OpSignature( + hash(_param)), param(_param) { + } + + bool operator==(const ParamOpSign &sign) const { + const OpSignature &this_upper = *this; + const OpSignature &other_upper = sign; + return this_upper == other_upper && param == sign.param; + } +}; + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index 63f5c91911ed..7dc05fda2cc6 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -141,8 +141,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer static auto gradient = nnvm::Op::GetAttr("FGradient"); nnvm::FGradient grad_fun = gradient.get(op_, nullptr); if (grad_fun) { - std::vector out_grads; - std::vector entries = grad_fun(MakeNode(), out_grads); + auto n = MakeNode(); + std::vector out_grads(n->num_outputs()); + std::vector entries = grad_fun(n, out_grads); CHECK_GE(entries.size(), 1U); res.reserve(entries.size()); for (const nnvm::NodeEntry& node_entry : entries) { @@ -467,7 +468,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer input_shapes_ = input_shapes; // BWD Output shapes output_shapes = backward_for_op->input_shapes_; - CHECK_EQ(output_shapes.size(), inferred_num_outputs); + output_shapes.resize(inferred_num_outputs); } else { output_shapes = input_shapes; output_shapes.resize(inferred_num_outputs); diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index 4b08d985de3e..2f9de742a35a 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -77,10 +77,10 @@ enum ForwardOutputs { * \brief Backward */ enum BackwardInputs { - /* out_grad */ bwd_out_grad_Grad, bwd_out_grad_Mean, bwd_out_grad_Var, + /* out_grad */ bwd_out_grad_Grad, + /* out_data */ bwd_out_data_Mean, bwd_out_data_Var, /* in_data */ bwd_in_data_Data, bwd_in_data_Gamma, bwd_in_data_Beta, - /* aux_states */ bwd_aux_states_MovingMean, bwd_aux_states_MovingVar, - /* in_grad */ bwd_out_data_Data, bwd_out_data_Mean, bwd_out_data_Var + /* aux_states */ bwd_aux_states_MovingMean, bwd_aux_states_MovingVar }; enum BackwardOutputs { /* in_grad */ bwd_in_grad_Data /* Original input data */, @@ -250,17 +250,12 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor { test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingMean), 0); test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingVar), 1); - val = -.101; - test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_data_Data), [&val]() -> double { - return val += 1; }); test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Mean), 0.0); test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Var), 1.0); val = -.001; test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_grad_Grad), [&val]() -> double { return val += 0.01; }); - test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Mean), 0.0); - test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Var), 1.0); } const bool hasWeightAndBias_; // This will cause forward pass validation to fail