Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-359] fix checks on convolution parameters in MKLDNN. #10666

Merged
merged 4 commits into from
May 2, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand All @@ -68,7 +69,8 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(params, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand Down
6 changes: 4 additions & 2 deletions src/operator/nn/deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNDeconvolutionForward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(DeconvolutionCompute<cpu>, attrs, ctx, inputs, req,
Expand All @@ -320,7 +321,8 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConv(inputs[0])) {
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
if (SupportMKLDNNConv(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
Expand Down
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ static inline bool SupportMKLDNN(const NDArray &input) {
&& SupportStorageMKLDNN(input.storage_type());
}

static inline bool SupportMKLDNNConv(const NDArray &input) {
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

/*
* This is to align address to a certain alignment.
*/
void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space);

namespace op {
struct ActivationParam;
bool SupportMKLDNNAct(const op::ActivationParam& param);
struct ConvolutionParam;
struct DeconvolutionParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNConv(const DeconvolutionParam& params, const NDArray &input);
}

static int GetTypeSize(int dtype) {
Expand Down
69 changes: 33 additions & 36 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
if (params.kernel.ndim() != 2)
return false;
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
const ConvolutionParam& param, bool is_train, const NDArray &data,
const NDArray &weights, const NDArray *bias, const NDArray &output) {
Expand All @@ -39,16 +45,15 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0 && bias == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
Expand All @@ -61,10 +66,8 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
return mkldnn::convolution_forward::primitive_desc(desc, engine);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
if (bias == nullptr) {
mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
Expand All @@ -88,26 +91,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be CHECK_EQ ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after ONNX is fixed, this should be CHECK_EQ. I didn't know if ONNIX would be fixed when I submitted the PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please disregard my comment. I think this change shouldnt depend on whether onnx is fixed or not. CHECK_GE looks good to make it consistent with existing behavior.

CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0) {
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
mkldnn::padding_kind::zero);
Expand All @@ -123,16 +123,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be CHECK_EQ ?

CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
if (param.dilate.ndim() == 0 && bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
Expand All @@ -145,10 +144,8 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
} else {
mkldnn::memory::dims dilates{0, 0};
if (param.dilate.ndim() == 2) {
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
}
dilates[0] = param.dilate[0] - 1;
dilates[1] = param.dilate[1] - 1;
if (bias == nullptr) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
data_md, weight_md, out_md, strides, dilates, padding, padding,
Expand Down
99 changes: 33 additions & 66 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNConv(const DeconvolutionParam& params, const NDArray &input) {
if (params.kernel.ndim() != 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to add check for strides and dilate too ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a check in the parameter parser of mxnet conv, so we don't need to check it in the MKLDNN code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are checking strides and dilates ndim to be greater than equal to 2, can we fallback to default implementation and return false here when ndim of stride, pad or dilates is less than 2 ?

return false;
return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}

static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
mkldnn::memory::dims dims(1);
// This is convolution on 4D data. The second dimension is the channel.
Expand Down Expand Up @@ -67,31 +73,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

param.pad.ndim() == 1 will not use mkldnn anymore ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mxnet always assume 2 elements in the tuple. in the python, if the input is one element, it'll convert it to 2-element tuple, so in practice, we don't get stride with one element.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python will extend one element to two-element tuple. What about other frontend languages or what about someone calling c APIs to build their model?

strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
strides, padding, dilate);
mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
Expand All @@ -107,31 +100,18 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
strides, padding, dilate);
}
Expand All @@ -144,31 +124,18 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
CHECK_GE(param.stride.ndim(), 2U);
CHECK_GE(param.pad.ndim(), 2U);
CHECK_GE(param.dilate.ndim(), 2U);
mkldnn::memory::dims strides{0, 0};
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
strides[0] = param.stride[0];
strides[1] = param.stride[1];
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
padding[0] = param.pad[0];
padding[1] = param.pad[1];
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
}
dilate[0] = param.dilate[0] - 1;
dilate[1] = param.dilate[1] - 1;
if (!has_bias) {
mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
Expand Down