Skip to content

Commit

Permalink
Enable CUDNN for conv1D (apache#11194) (apache#11270)
Browse files · Browse the repository at this point in the history
* enable cudnn for conv1d

* add checks for backward

* fix build

* fix build

* fix lint

* Update convolution.cc
  • Loading branch information
eric-haibin-lin authored and marcoabreu committed Jun 14, 2018
1 parent 62a47a7 commit 6ae38f2
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/operator/nn/convolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,20 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
int dtype = inputs[conv::kData].type_flag_;

// If 1D convolution, use MXNet implementation
if (param.kernel.ndim() == 1) {
#if CUDNN_MAJOR < 5
if (param_.layout.value() != kNCW &&
param_.layout.value() != kNCHW &&
param_.layout.value() != kNCDHW) {
// Need CuDNN > 5.0 for layout support. use MXNet implementation
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);
op.Forward(ctx, inputs, req, outputs);
})
return;
}
#endif

#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
Expand Down Expand Up @@ -162,15 +167,19 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &in_grad = outputs;
int dtype = out_grad.type_flag_;

// If 1D convolution, use MXNet implementation
if (param.kernel.ndim() == 1) {
#if CUDNN_MAJOR < 5
if (param_.layout.value() != kNCW &&
param_.layout.value() != kNCHW &&
param_.layout.value() != kNCDHW) {
// Need CuDNN > 5.0 for layout support. use MXNet implementation
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
ConvolutionOp<gpu, DType> op;
op.Init(param);
op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
})
return;
}
#endif
#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
Expand Down

0 comments on commit 6ae38f2

Please sign in to comment.