[MXNET-359] fix checks on convolution parameters in MKLDNN. #10666
Changes from 1 commit
First changed file: the MKLDNN convolution implementation.
@@ -31,6 +31,12 @@
 namespace mxnet {
 namespace op {

+bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     const ConvolutionParam& param, bool is_train, const NDArray &data,
     const NDArray &weights, const NDArray *bias, const NDArray &output) {
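Note: SupportMKLDNNConv gates the MKLDNN path; when it returns false, the operator is expected to fall back to MXNet's default implementation instead of failing inside MKLDNN descriptor setup. The following standalone sketch illustrates that dispatch pattern; the struct and function names are hypothetical stand-ins, not MXNet's actual code.

#include <iostream>

// Hypothetical stand-ins for ConvolutionParam and NDArray, for illustration only.
struct ParamSketch { int kernel_ndim; };
struct InputSketch { bool is_float32; int shape_ndim; };

// Mirrors the logic of SupportMKLDNNConv above: 2-D kernel, float32, 4-D data.
static bool SupportsMKLDNNSketch(const ParamSketch &p, const InputSketch &in) {
  if (p.kernel_ndim != 2)
    return false;
  return in.is_float32 && in.shape_ndim == 4;
}

int main() {
  ParamSketch param{3};        // e.g. a 3-D kernel
  InputSketch input{true, 5};  // float32, but 5-D data
  if (SupportsMKLDNNSketch(param, input))
    std::cout << "dispatch to the MKLDNN convolution\n";
  else
    std::cout << "fall back to the default implementation\n";
  return 0;
}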
@@ -39,16 +45,15 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
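The CHECK_GE calls are dmlc logging macros that log a fatal message and abort the process when the condition fails; that is what makes the now-unconditional reads of stride[1] and pad[1] safe. A self-contained sketch of the failure behavior (a simplified stand-in, not dmlc's actual implementation):

#include <cstdlib>
#include <iostream>

// Simplified stand-in for dmlc's CHECK_GE: log and abort on failure.
#define CHECK_GE_SKETCH(a, b)                            \
  do {                                                   \
    if (!((a) >= (b))) {                                 \
      std::cerr << "Check failed: " #a " >= " #b "\n";   \
      std::abort();                                      \
    }                                                    \
  } while (0)

int main() {
  unsigned stride_ndim = 1;          // a malformed one-element stride tuple
  CHECK_GE_SKETCH(stride_ndim, 2U);  // aborts before stride[1] is ever read
  std::cout << "both stride elements are safe to read\n";  // not reached above
  return 0;
}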
@@ -61,10 +66,8 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     return mkldnn::convolution_forward::primitive_desc(desc, engine);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
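The "- 1" adjustment reflects differing dilation conventions: MXNet's dilate uses 1 to mean no dilation, while MKLDNN's dilates uses 0. With a kernel of size k, the effective kernel extent is (k - 1) * dilate_mxnet + 1. A small standalone check of both facts:

#include <cassert>

int main() {
  int mxnet_dilate = 2;                  // sample every other input element
  int mkldnn_dilate = mxnet_dilate - 1;  // MKLDNN counts extra gaps, so 0 = none
  assert(mkldnn_dilate == 1);

  int k = 3;                               // 3x3 kernel
  int k_eff = (k - 1) * mxnet_dilate + 1;  // effective receptive extent
  assert(k_eff == 5);                      // dilated 3x3 spans a 5x5 window
  return 0;
}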
@@ -88,26 +91,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0) {
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
     return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, dilates, padding, padding,
         mkldnn::padding_kind::zero);
@@ -123,16 +123,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);

Review comment on the new CHECK_GE(param.stride.ndim(), 2U) line: should this be CHECK_EQ? (The replies appear in the review thread at the end of this page.)
@@ -145,10 +144,8 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
Second changed file: the MKLDNN deconvolution implementation.
@@ -32,6 +32,12 @@
 namespace mxnet {
 namespace op {

+bool SupportMKLDNNConv(const DeconvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
   mkldnn::memory::dims dims(1);
   // This is convolution on 4D data. The second dimension is the channel.

Review thread on the new kernel.ndim() check:
Comment: Do we need to add a check for strides and dilate too?
Reply: I think we should have a check in the parameter parser of mxnet conv, so we don't need to check it in the MKLDNN code.
Reply: If we are checking that strides and dilates have ndim greater than or equal to 2, can we fall back to the default implementation and return false here when the ndim of stride, pad, or dilates is less than 2?
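The last reply above proposes returning false from the support check rather than CHECK-failing later. A hedged sketch of that alternative (the reviewer's proposal, not what this PR implements; it assumes the same headers and types as the file above, and the function name is hypothetical):

// Fall back to the default implementation instead of aborting later
// when stride/pad/dilate are not the expected 2-D tuples.
bool SupportMKLDNNDeconvStrict(const DeconvolutionParam &params,
                               const NDArray &input) {
  if (params.kernel.ndim() != 2)
    return false;
  if (params.stride.ndim() != 2 || params.pad.ndim() != 2 ||
      params.dilate.ndim() != 2)
    return false;  // reject quietly rather than CHECK-fail
  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
}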
@@ -67,31 +73,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                               strides, padding, dilate);
   mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,

Review thread on the removed one-element handling:
Comment: param.pad.ndim() == 1 will not use MKLDNN anymore?
Reply: MXNet always assumes 2 elements in the tuple. In Python, if the input is one element, it is converted to a 2-element tuple, so in practice we don't get a stride with one element.
Reply: Python will extend one element to a two-element tuple. What about other frontend languages, or someone calling the C APIs to build their model?
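The first reply above argues the one-element case should be handled once in MXNet's parameter parser, so that every frontend, not just Python, reaches this code with two-element tuples. A standalone sketch of that kind of normalization (a hypothetical helper, not MXNet's actual parser):

#include <stdexcept>
#include <vector>

// Broadcast a 1-element stride/pad/dilate tuple to 2 elements, as the Python
// frontend already does; reject anything else.
static std::vector<int> NormalizeTuple2D(const std::vector<int> &t) {
  if (t.size() == 2) return t;
  if (t.size() == 1) return {t[0], t[0]};  // repeat the single element
  throw std::invalid_argument("expected a 1- or 2-element tuple");
}

int main() {
  auto s = NormalizeTuple2D({3});  // (3,) becomes (3, 3)
  return (s[0] == 3 && s[1] == 3) ? 0 : 1;
}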
@@ -107,31 +100,18 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                        strides, padding, dilate);
 }
@@ -144,31 +124,18 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   if (!has_bias) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
Review thread on CHECK_GE(param.stride.ndim(), 2U):
Comment: Should this be CHECK_EQ?
Reply: After ONNX is fixed, this should be CHECK_EQ. I didn't know if ONNX would be fixed when I submitted the PR.
Reply: Please disregard my comment. I think this change shouldn't depend on whether ONNX is fixed or not. CHECK_GE looks good to make it consistent with existing behavior.