diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 7fd8bbb55994..0e8a929e1ba5 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -54,7 +54,8 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNConv(inputs[0])) {
+  const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
+  if (SupportMKLDNNConv(params, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -68,7 +69,8 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                         const std::vector<NDArray>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNConv(inputs[0])) {
+  const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
+  if (SupportMKLDNNConv(params, inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -363,6 +365,18 @@ static void ConvolutionParamParser(nnvm::NodeAttrs* attrs) {
     if (param_.dilate.ndim() == 0) param_.dilate = Shape3(1, 1, 1);
     if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
   }
+  CHECK_EQ(param_.kernel.ndim(), param_.stride.ndim())
+      << "Stride must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while stride is "
+      << param_.stride;
+  CHECK_EQ(param_.kernel.ndim(), param_.dilate.ndim())
+      << "Dilate must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while dilate is "
+      << param_.dilate;
+  CHECK_EQ(param_.kernel.ndim(), param_.pad.ndim())
+      << "Padding must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while padding is "
+      << param_.pad;
   attrs->parsed = std::move(param_);
 }
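The new CHECK_EQ guards in ConvolutionParamParser enforce one rule: stride, dilate, and pad must have the same rank as the kernel. A minimal standalone sketch of that invariant, using a hypothetical Shape stand-in rather than MXNet's TShape:

#include <cstddef>

// Hypothetical stand-in for the rank portion of mxnet::TShape.
struct Shape { std::size_t ndim; };

// Mirrors the three CHECK_EQ statements: every spatial hyperparameter
// must share the kernel's number of dimensions.
inline bool SpatialParamsConsistent(const Shape& kernel, const Shape& stride,
                                    const Shape& dilate, const Shape& pad) {
  return stride.ndim == kernel.ndim &&
         dilate.ndim == kernel.ndim &&
         pad.ndim == kernel.ndim;
}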
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 0d1b391104ab..13fc757fb2dc 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -304,7 +304,8 @@ static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
                                       const std::vector<NDArray>& inputs,
                                       const std::vector<OpReqType>& req,
                                       const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNConv(inputs[0])) {
+  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  if (SupportMKLDNNDeconv(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req,
@@ -320,7 +321,8 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                           const std::vector<NDArray>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNConv(inputs[0])) {
+  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  if (SupportMKLDNNDeconv(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
@@ -356,6 +358,22 @@ static void DeconvolutionParamParser(nnvm::NodeAttrs* attrs) {
     if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
     if (param_.adj.ndim() == 0) param_.adj = Shape3(0, 0, 0);
   }
+  CHECK_EQ(param_.kernel.ndim(), param_.stride.ndim())
+      << "Stride must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while stride is "
+      << param_.stride;
+  CHECK_EQ(param_.kernel.ndim(), param_.dilate.ndim())
+      << "Dilate must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while dilate is "
+      << param_.dilate;
+  CHECK_EQ(param_.kernel.ndim(), param_.pad.ndim())
+      << "Padding must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while padding is "
+      << param_.pad;
+  CHECK_EQ(param_.kernel.ndim(), param_.adj.ndim())
+      << "Adjustment must have the same number of dimensions as kernel_size, "
+      << "but kernel_size is set to " << param_.kernel << " while adjustment is "
+      << param_.adj;
   attrs->parsed = std::move(param_);
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 16e5605b668e..ccc4acf8b4b1 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -137,10 +137,6 @@ static inline bool SupportMKLDNN(const NDArray &input) {
          && SupportStorageMKLDNN(input.storage_type());
 }
 
-static inline bool SupportMKLDNNConv(const NDArray &input) {
-  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
-}
-
 /*
  * This is to align address to a certain alignment.
  */
@@ -148,7 +144,11 @@ void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space);
 
 namespace op {
 struct ActivationParam;
-bool SupportMKLDNNAct(const op::ActivationParam& param);
+struct ConvolutionParam;
+struct DeconvolutionParam;
+bool SupportMKLDNNAct(const ActivationParam& param);
+bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
+bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
 }
 
 static int GetTypeSize(int dtype) {
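The header now forward-declares ConvolutionParam and DeconvolutionParam instead of defining the support checks inline. A brief sketch (hypothetical names) of why a forward declaration suffices: a function declaration that takes a reference only needs an incomplete type, so the definitions can move next to the operator code in the .cc files.

// Header side: the incomplete type is enough for a reference parameter.
struct WidgetParam;
bool SupportWidget(const WidgetParam& param);

// Source side (widget.cc), where the complete type is visible:
//   #include "widget_param.h"
//   bool SupportWidget(const WidgetParam& param) { return param.rank == 2; }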
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 453221f9b377..1e09d208b989 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -31,6 +31,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     const ConvolutionParam& param, bool is_train, const NDArray &data,
     const NDArray &weights, const NDArray *bias, const NDArray &output) {
@@ -39,16 +45,15 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
         weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
@@ -61,10 +66,8 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     return mkldnn::convolution_forward::primitive_desc(desc, engine);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
@@ -88,26 +91,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0) {
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
     return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, dilates, padding, padding,
         mkldnn::padding_kind::zero);
@@ -123,16 +123,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
@@ -145,10 +144,8 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
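Throughout these builders the dilation passed to MKLDNN is param.dilate[i] - 1. The shift reflects a difference in conventions: MXNet's dilate is the step between sampled kernel taps (1 means dense), while MKLDNN counts inserted zeros between taps (0 means dense). A small sketch of the equivalence, with a hypothetical helper name:

#include <cstdint>

// Effective extent of a dilated kernel along one spatial axis; both
// conventions give the same extent once shifted by one.
inline int64_t EffectiveKernelExtent(int64_t kernel, int64_t mxnet_dilate) {
  int64_t mkldnn_dilate = mxnet_dilate - 1;  // the param.dilate[i] - 1 shift
  return (kernel - 1) * (mkldnn_dilate + 1) + 1;
}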
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index af57b68cfd37..aedecdddb767 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -32,6 +32,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
   mkldnn::memory::dims dims(1);
   // This is convolution on 4D data. The second dimension is the channel.
@@ -67,31 +73,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                               strides, padding, dilate);
   mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
@@ -107,31 +100,18 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                        strides, padding, dilate);
 }
@@ -144,31 +124,18 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   if (!has_bias) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
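Note how the deconvolution builders reuse MKLDNN's convolution primitives with their roles swapped: GetDeconvFwdImpl builds a convolution_backward_data descriptor, GetDeconvBwdData a convolution_forward one, and GetDeconvBwdWeights a convolution_backward_weights one. The removed per-ndim branches and LOG(FATAL) fallbacks give way to a guard-then-assume pattern; a minimal sketch with hypothetical types:

#include <cassert>
#include <cstdint>

struct Dims { int64_t d[2]; int ndim; };  // hypothetical stand-in for TShape

// cf. SupportMKLDNNDeconv: only 2-D kernels reach the MKLDNN path.
inline bool Support2D(const Dims& kernel) { return kernel.ndim == 2; }

// cf. CHECK_GE(param.stride.ndim(), 2U): once the guard holds (and the param
// parser has matched every spatial rank to the kernel's), indexing elements
// 0 and 1 is unconditionally safe.
inline void CopyStride(const Dims& stride, int64_t out[2]) {
  assert(stride.ndim >= 2);
  out[0] = stride.d[0];
  out[1] = stride.d[1];
}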