Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-359] fix checks on convolution parameters in MKLDNN. (#10666)
Browse files Browse the repository at this point in the history
* fix check on tuples of conv.

* check params in (de)conv.

* rename.

* add messages.
  • Loading branch information
zheng-da authored and piiswrong committed May 2, 2018
1 parent ebd8a6b commit 1420697
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 111 deletions.
18 changes: 16 additions & 2 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand All @@ -68,7 +69,8 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand Down Expand Up @@ -363,6 +365,18 @@ static void ConvolutionParamParser(nnvm::NodeAttrs* attrs) {
if (param_.dilate.ndim() == 0) param_.dilate = Shape3(1, 1, 1);
if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
}
CHECK_EQ(param_.kernel.ndim(), param_.stride.ndim())
<< "Stride must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while stride is "
<< param_.stride;
CHECK_EQ(param_.kernel.ndim(), param_.dilate.ndim())
<< "Dilate must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while dilate is "
<< param_.dilate;
CHECK_EQ(param_.kernel.ndim(), param_.pad.ndim())
<< "Padding must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while padding is "
<< param_.pad;
attrs->parsed = std::move(param_);
}

Expand Down
22 changes: 20 additions & 2 deletions src/operator/nn/deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
if (SupportMKLDNNDeconv(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req,
Expand All @@ -320,7 +321,8 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
if (SupportMKLDNNDeconv(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
Expand Down Expand Up @@ -356,6 +358,22 @@ static void DeconvolutionParamParser(nnvm::NodeAttrs* attrs) {
if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
if (param_.adj.ndim() == 0) param_.adj = Shape3(0, 0, 0);
}
CHECK_EQ(param_.kernel.ndim(), param_.stride.ndim())
<< "Stride must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while stride is "
<< param_.stride;
CHECK_EQ(param_.kernel.ndim(), param_.dilate.ndim())
<< "Dilate must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while dilate is "
<< param_.dilate;
CHECK_EQ(param_.kernel.ndim(), param_.pad.ndim())
<< "Padding must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while padding is "
<< param_.pad;
CHECK_EQ(param_.kernel.ndim(), param_.adj.ndim())
<< "Adjustment must have the same number of dimensions with kernel_size,"
<< "but kernel_size is set to " << param_.kernel << " while adjustment is "
<< param_.adj;
attrs->parsed = std::move(param_);
}

Expand Down
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,18 @@ static inline bool SupportMKLDNN(const NDArray &input) {
&& SupportStorageMKLDNN(input.storage_type());
}

static inline bool SupportMKLDNNConv(const NDArray &input) {
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

/*
* This is to align address to a certain alignment.
*/
void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space);

namespace op {
struct ActivationParam;
bool SupportMKLDNNAct(const op::ActivationParam& param);
struct ConvolutionParam;
struct DeconvolutionParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
}

static int GetTypeSize(int dtype) {
Expand Down
69 changes: 33 additions & 36 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
if (params.kernel.ndim() != 2)
return false;
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
const ConvolutionParam& param, bool is_train, const NDArray &data,
const NDArray &weights, const NDArray *bias, const NDArray &output) {
Expand All @@ -39,16 +45,15 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0 && bias == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
Expand All @@ -61,10 +66,8 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
return mkldnn::convolution_forward::primitive_desc(desc, engine);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
if (bias == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
Expand All @@ -88,26 +91,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0) {
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
mkldnn::padding_kind::zero);
Expand All @@ -123,16 +123,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0 && bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
Expand All @@ -145,10 +144,8 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
if (bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
Expand Down
99 changes: 33 additions & 66 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) {
if (params.kernel.ndim() != 2)
return false;
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
mkldnn::memory::dims dims(1);
// This is convolution on 4D data. The second dimension is the channel.
Expand Down Expand Up @@ -67,31 +73,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
strides, padding, dilate);
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
Expand All @@ -107,31 +100,18 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
strides, padding, dilate);
}
Expand All @@ -144,31 +124,18 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
if (!has_bias) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
Expand Down

0 comments on commit 1420697

Please sign in to comment.