diff --git a/src/cudnn/conv.jl b/src/cudnn/conv.jl index a4f23c1..6dca751 100644 --- a/src/cudnn/conv.jl +++ b/src/cudnn/conv.jl @@ -1,6 +1,6 @@ - -using NNlib: DenseConvDims +using NNlib: DenseConvDims, DepthwiseConvDims import NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act! +import NNlib: depthwise_conv!, ∇depthwise_conv_filter!, ∇depthwise_conv_data! using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims, cudnnConvolutionDescriptor, cudnnConvolutionBwdDataAlgoPerf, @@ -10,8 +10,8 @@ using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims, const CUDNNFloat = Union{Float16,Float32,Float64} -function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) where T - mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION) +function cudnnConvolutionDescriptor(cdims::ConvDims, x::DenseCuArray{T}) where T + mode = (NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION) cudnnConvolutionDescriptor(convdims(nnlibPadding(cdims),size(x),0), convdims(NNlib.stride(cdims),size(x),1), convdims(NNlib.dilation(cdims),size(x),1), @@ -22,8 +22,8 @@ function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) wh Cint(NNlib.groupcount(cdims))) end -function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims; - alpha=1, beta=0, algo=-1) where T<:CUDNNFloat +function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::ConvDims; + alpha = 1, beta = 0, algo = -1) where T<:CUDNNFloat if cudnnversion() < v"6" all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6") end @@ -34,9 +34,9 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y) end -function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, - cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity; - z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat +function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, + cdims::ConvDims, bias::DenseCuArray{T}, σ = identity; + z::DenseCuArray{T} = y, alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat if cudnnversion() < v"6" all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6") end @@ -54,7 +54,7 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{ end function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T}, - cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat + cdims::ConvDims; alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat if cudnnversion() < v"6" all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6") end @@ -72,7 +72,7 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray end function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, - cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat + cdims::ConvDims; alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat if cudnnversion() < v"6" all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6") end @@ -95,3 +95,18 @@ function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0 cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db) return db end + +function depthwise_conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DepthwiseConvDims; + alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat + conv!(y, x, w, cims; alpha, beta, algo) +end + +function ∇depthwise_conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, + cdims::ConvDims; alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat + ∇conv_filter!(dw, x, dy, cdims; alpha, beta, algo) +end + +function ∇depthwise_conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T}, + cdims::ConvDims; alpha = 1, beta = 0, algo = -1) where T <: CUDNNFloat + ∇conv_data!(dx, dy, w, cdims; alpha, beta, algo) +end