diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu
index 65a320ded169..9f61212d5c78 100644
--- a/src/operator/nn/convolution.cu
+++ b/src/operator/nn/convolution.cu
@@ -89,8 +89,11 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
   int dtype = inputs[conv::kData].type_flag_;
 
-  // If 1D convolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
+#if CUDNN_MAJOR < 5
+  if (param.layout.value() != mshadow::kNCW &&
+      param.layout.value() != mshadow::kNCHW &&
+      param.layout.value() != mshadow::kNCDHW) {
+    // Need CuDNN >= 5.0 for layout support; use the MXNet implementation
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       ConvolutionOp<gpu, DType> op;
       op.Init(param);
@@ -98,6 +101,8 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs,
     })
     return;
   }
+#endif
+
 #if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
   if (param.num_filter == param.num_group &&
       param.layout.value() == mshadow::kNCHW &&
@@ -162,8 +167,11 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs,
   const std::vector<TBlob> &in_grad = outputs;
   int dtype = out_grad.type_flag_;
 
-  // If 1D convolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
+#if CUDNN_MAJOR < 5
+  if (param.layout.value() != mshadow::kNCW &&
+      param.layout.value() != mshadow::kNCHW &&
+      param.layout.value() != mshadow::kNCDHW) {
+    // Need CuDNN >= 5.0 for layout support; use the MXNet implementation
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       ConvolutionOp<gpu, DType> op;
       op.Init(param);
@@ -171,6 +179,7 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs,
     })
     return;
   }
+#endif
 #if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
   if (param.num_filter == param.num_group &&
       param.layout.value() == mshadow::kNCHW &&
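For context, here is a minimal standalone sketch (not part of the patch) of the dispatch pattern the diff introduces: when the file is compiled against a CuDNN release older than 5.0, any layout that CuDNN cannot handle is routed to the native MXNet implementation; on newer builds the guard is compiled out entirely. The Layout enum here mirrors mshadow's layout tags, and Compute, CudnnConvolution, and FallbackConvolution are hypothetical stand-ins for the real operator entry points.

#include <cstdio>

// Layout tags mirroring mshadow's (values are illustrative only).
enum Layout { kNCW, kNCHW, kNCDHW, kNHWC };

#ifndef CUDNN_MAJOR
#define CUDNN_MAJOR 4  // pretend we built against an old CuDNN for the demo
#endif

// Hypothetical stand-in for the native MXNet ConvolutionOp path.
void FallbackConvolution(Layout /*layout*/) { std::printf("native MXNet path\n"); }

// Hypothetical stand-in for the CuDNN-backed path.
void CudnnConvolution(Layout /*layout*/) { std::printf("CuDNN path\n"); }

void Compute(Layout layout) {
#if CUDNN_MAJOR < 5
  // CuDNN releases before 5.0 lack general layout support, so everything
  // except the default NCW/NCHW/NCDHW layouts takes the native path.
  if (layout != kNCW && layout != kNCHW && layout != kNCDHW) {
    FallbackConvolution(layout);
    return;
  }
#endif
  CudnnConvolution(layout);
}

int main() {
  Compute(kNHWC);  // falls back when CUDNN_MAJOR < 5
  Compute(kNCHW);  // eligible for the CuDNN path either way
  return 0;
}

Because the check is a preprocessor guard rather than a runtime branch, builds against CuDNN 5.0 or newer pay no cost for it.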