From 73ebc762c3dbaa122897d377f128634ec291d3a6 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Thu, 14 Jun 2018 13:11:25 -0700 Subject: [PATCH] Enable CUDNN for conv1D (#11194) (#11270) * enable cudnn for conv1d * add checks for backward * fix build * fix build * fix lint * Update convolution.cc --- src/operator/nn/convolution.cu | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index 65a320ded169..9f61212d5c78 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -89,8 +89,11 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param = nnvm::get(attrs.parsed); int dtype = inputs[conv::kData].type_flag_; - // If 1D convolution, use MXNet implementation - if (param.kernel.ndim() == 1) { +#if CUDNN_MAJOR < 5 + if (param_.layout.value() != kNCW && + param_.layout.value() != kNCHW && + param_.layout.value() != kNCDHW) { + // Need CuDNN > 5.0 for layout support. use MXNet implementation MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { ConvolutionOp op; op.Init(param); @@ -98,6 +101,8 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs, }) return; } +#endif + #if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7 if (param.num_filter == param.num_group && param.layout.value() == mshadow::kNCHW && @@ -162,8 +167,11 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, const std::vector &in_grad = outputs; int dtype = out_grad.type_flag_; - // If 1D convolution, use MXNet implementation - if (param.kernel.ndim() == 1) { +#if CUDNN_MAJOR < 5 + if (param_.layout.value() != kNCW && + param_.layout.value() != kNCHW && + param_.layout.value() != kNCDHW) { + // Need CuDNN > 5.0 for layout support. use MXNet implementation MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { ConvolutionOp op; op.Init(param); @@ -171,6 +179,7 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, }) return; } +#endif #if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7 if (param.num_filter == param.num_group && param.layout.value() == mshadow::kNCHW &&