[MXNET-359] fix checks on convolution parameters in MKLDNN. #10666
First file (MKLDNN convolution):

```diff
@@ -31,6 +31,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     const ConvolutionParam& param, bool is_train, const NDArray &data,
     const NDArray &weights, const NDArray *bias, const NDArray &output) {
```
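The new gate only hands 2-D kernels on 4-D float32 inputs to MKLDNN; everything else keeps the default path. Below is a minimal standalone sketch of that predicate, with stand-in structs instead of MXNet's `ConvolutionParam`/`NDArray` (the `kFloat32 = 0` value mirroring mshadow's enum is an assumption here):

```cpp
#include <cstdio>

// Stand-ins for ConvolutionParam / NDArray, just to make the check runnable.
struct Param { int kernel_ndim; };
struct Input { int dtype; int shape_ndim; };
constexpr int kFloat32 = 0;  // assumption: mirrors mshadow::kFloat32

bool SupportConv(const Param& p, const Input& in) {
  if (p.kernel_ndim != 2) return false;               // only 2-D convolution
  return in.dtype == kFloat32 && in.shape_ndim == 4;  // NCHW float32 only
}

int main() {
  printf("%d\n", SupportConv({2}, {kFloat32, 4}));  // 1: eligible for MKLDNN
  printf("%d\n", SupportConv({3}, {kFloat32, 5}));  // 0: 3-D conv falls back
}
```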
```diff
@@ -39,16 +45,15 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
```
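What the change buys: previously a stride or pad tuple with fewer than two elements slipped past the `if` and left the dims at `{0, 0}`, handing MKLDNN a nonsense descriptor; now the `CHECK_GE` aborts up front and the copies run unconditionally. A small sketch of the two behaviors, using `assert` as a stand-in for dmlc's `CHECK_GE`:

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

using dims = std::vector<int>;  // stands in for mkldnn::memory::dims

// Old behavior: a short tuple is silently ignored, leaving {0, 0}.
dims StridesOld(const dims& stride) {
  dims s{0, 0};
  if (stride.size() == 2) { s[0] = stride[0]; s[1] = stride[1]; }
  return s;  // stride {3} -> {0, 0}: garbage passed on to MKLDNN
}

// New behavior: fail fast, then copy unconditionally.
dims StridesNew(const dims& stride) {
  assert(stride.size() >= 2);  // stands in for CHECK_GE(param.stride.ndim(), 2U)
  return dims{stride[0], stride[1]};
}

int main() {
  dims bad{3};  // a one-element stride tuple
  dims old = StridesOld(bad);
  printf("old: {%d, %d}\n", old[0], old[1]);  // old: {0, 0}
  // StridesNew(bad) would abort here instead of computing with zeros.
  dims good = StridesNew({2, 2});
  printf("new: {%d, %d}\n", good[0], good[1]);  // new: {2, 2}
}
```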
```diff
@@ -61,10 +66,8 @@ static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     return mkldnn::convolution_forward::primitive_desc(desc, engine);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
```
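The `- 1` converts conventions: in MXNet `dilate = 1` means no dilation, while MKLDNN counts the extra gaps inserted between kernel taps, so no dilation is 0. A tiny check of that mapping (the helper name is ours, not from the patch):

```cpp
#include <cassert>

// Assumed conventions: MXNet dilate d means d-1 inserted gaps between taps,
// which is exactly the value MKLDNN expects in its dilates dims.
int ToMKLDNNDilate(int mxnet_dilate) { return mxnet_dilate - 1; }

int main() {
  assert(ToMKLDNNDilate(1) == 0);  // no dilation
  assert(ToMKLDNNDilate(2) == 1);  // one gap inserted between kernel taps
}
```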
```diff
@@ -88,26 +91,23 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0) {
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
     return mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, dilates, padding, padding,
         mkldnn::padding_kind::zero);
```
```diff
@@ -123,16 +123,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   if (param.dilate.ndim() == 0 && bias == nullptr) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero);
```

Review comment (on the new `CHECK_GE(param.stride.ndim(), 2U);`): should this be CHECK_EQ?
```diff
@@ -145,10 +144,8 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     return mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd);
   } else {
     mkldnn::memory::dims dilates{0, 0};
-    if (param.dilate.ndim() == 2) {
-      dilates[0] = param.dilate[0] - 1;
-      dilates[1] = param.dilate[1] - 1;
-    }
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
     if (bias == nullptr) {
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
```
Second file (MKLDNN deconvolution):

```diff
@@ -32,6 +32,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input) {
+  if (params.kernel.ndim() != 2)
+    return false;
+  return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
+}
+
 static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) {
   mkldnn::memory::dims dims(1);
   // This is convolution on 4D data. The second dimension is the channel.
```
```diff
@@ -67,31 +73,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   auto bwd_pd = GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                               strides, padding, dilate);
   mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
```

Review thread (on the removed one-element stride/pad handling):
- param.pad.ndim() == 1 will not use MKLDNN anymore?
- MXNet always assumes 2 elements in the tuple. In Python, if the input has one element, it is converted to a 2-element tuple, so in practice we don't get a stride with one element.
- Python will extend one element to a two-element tuple. What about other frontend languages, or about someone calling the C APIs to build their model?
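For context on that thread: the claim is that the Python frontend broadcasts a one-element stride or pad tuple across both spatial dimensions before the value ever reaches this C++ code, which is why the `ndim() == 1` branches are considered dead; the open question is whether other frontends or raw C API callers get the same normalization. A hypothetical sketch of that frontend-side broadcast (illustrative only, not MXNet code):

```cpp
#include <vector>

// Hypothetical frontend-side normalization: broadcast a one-element stride
// to one entry per spatial dimension, as the Python frontend does for tuples.
std::vector<int> NormalizeStride(std::vector<int> stride, int spatial_ndim) {
  if (stride.size() == 1) {
    const int v = stride[0];
    stride.assign(spatial_ndim, v);  // {2} -> {2, 2} for a 2-D kernel
  }
  return stride;
}

int main() {
  auto s = NormalizeStride({2}, 2);  // what Python does before reaching C++
  // A raw C API caller that skips this step now hits CHECK_GE and aborts,
  // where the old deconvolution code quietly replicated stride[0].
  return (s[0] == 2 && s[1] == 2) ? 0 : 1;
}
```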
```diff
@@ -107,31 +100,18 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine,
                        strides, padding, dilate);
 }
```
```diff
@@ -144,31 +124,18 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
+  CHECK_GE(param.stride.ndim(), 2U);
+  CHECK_GE(param.pad.ndim(), 2U);
+  CHECK_GE(param.dilate.ndim(), 2U);
   mkldnn::memory::dims strides{0, 0};
-  if (param.stride.ndim() == 2) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[1];
-  } else if (param.stride.ndim() == 1) {
-    strides[0] = param.stride[0];
-    strides[1] = param.stride[0];
-  } else {
-    LOG(FATAL) << "Unsupported stride dim";
-  }
+  strides[0] = param.stride[0];
+  strides[1] = param.stride[1];
   mkldnn::memory::dims padding{0, 0};
-  if (param.pad.ndim() == 2) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[1];
-  } else if (param.pad.ndim() == 1) {
-    padding[0] = param.pad[0];
-    padding[1] = param.pad[0];
-  } else {
-    LOG(FATAL) << "Unsupported pad dim";
-  }
+  padding[0] = param.pad[0];
+  padding[1] = param.pad[1];
   mkldnn::memory::dims dilate{0, 0};
-  if (param.dilate.ndim() == 2) {
-    dilate[0] = param.dilate[0] - 1;
-    dilate[1] = param.dilate[1] - 1;
-  }
+  dilate[0] = param.dilate[0] - 1;
+  dilate[1] = param.dilate[1] - 1;
   if (!has_bias) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
         out_md, weight_md, data_md, strides, dilate, padding, padding, mkldnn::padding_kind::zero);
```
Review thread (on the CHECK_GE checks):
- should this be CHECK_EQ?
- After ONNX is fixed, this should be CHECK_EQ. I didn't know whether ONNX would be fixed when I submitted the PR.
- Please disregard my comment. I think this change shouldn't depend on whether ONNX is fixed or not. CHECK_GE looks good to make it consistent with existing behavior.
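The distinction being debated: `CHECK_GE` tolerates tuples with more than two elements (such as what an unfixed ONNX importer might produce) and silently uses only the first two, whereas `CHECK_EQ` would abort on those as well. A sketch with assert-style stand-ins for the dmlc macros:

```cpp
#include <cstdio>
#include <vector>

// Assert-style stand-ins for dmlc's CHECK_GE / CHECK_EQ (which abort the process).
#define CHECK_GE_(a, b) if (!((a) >= (b))) { printf("CHECK_GE failed\n"); return; }
#define CHECK_EQ_(a, b) if (!((a) == (b))) { printf("CHECK_EQ failed\n"); return; }

void WithGE(const std::vector<int>& stride) {
  CHECK_GE_(stride.size(), 2u);
  printf("GE ok: uses {%d, %d}\n", stride[0], stride[1]);
}

void WithEQ(const std::vector<int>& stride) {
  CHECK_EQ_(stride.size(), 2u);
  printf("EQ ok: uses {%d, %d}\n", stride[0], stride[1]);
}

int main() {
  std::vector<int> three{1, 1, 1};  // e.g. what a broken importer might pass
  WithGE(three);  // GE ok: the extra trailing element is ignored
  WithEQ(three);  // CHECK_EQ failed
}
```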