diff --git a/src/gather.jl b/src/gather.jl index 1ad69df24..d75f89a2c 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) return dst end -function gather!(dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) +function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] max_dims_idx = prod(dims) diff --git a/src/scatter.jl b/src/scatter.jl index 6edf6e379..3507b906d 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -81,7 +81,7 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA dst end -for AT in (AbstractArray, AbstractGPUArray) +for AT in (AbstractArray, AnyGPUArray) @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT) Ns = scatter!(+, zero(dst), one.(src), idx) dst_ = scatter!(+, zero(dst), src, idx) @@ -90,7 +90,7 @@ for AT in (AbstractArray, AbstractGPUArray) end end -function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP +function scatter!(op::OP, dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) where OP n_dims = scatter_dims(dst, src, idx) args = if n_dims == 0 ndrange = length(idx) @@ -228,7 +228,7 @@ end function ∇scatter_src( op::Union{typeof(*), typeof(/)}, Δ, dst, - src::AbstractGPUArray{Tsrc, Nsrc}, idx::AbstractGPUArray{Tidx, Nidx}, + src::AnyGPUArray{Tsrc, Nsrc}, idx::AnyGPUArray{Tidx, Nidx}, ) where {Tsrc, Nsrc, Tidx, Nidx} n_dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) diff --git a/src/upsample.jl b/src/upsample.jl index a320ca9e6..babd613fb 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -411,7 +411,7 @@ end end # Linear (GPU): parallelization along width dimension. -# TODO replace AbstractArray -> AbstractGPUArray once device arrays subtype it. +# TODO replace AbstractArray -> AnyGPUArray once device arrays subtype it. @kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 3}, A, diff --git a/test/ext_cuda/gather.jl b/test/ext_cuda/gather.jl index 36d42dbcc..9fa30efa8 100644 --- a/test/ext_cuda/gather.jl +++ b/test/ext_cuda/gather.jl @@ -89,4 +89,18 @@ @test y isa CuArray{Float32,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) + + @testset "views" begin + x = cu(rand(2, 5)) + v = view(x, axes(x)...) + i = cu([1, 2]) + outx = NNlib.gather(x, i) + outv = NNlib.gather(v, i) + @test outx == outv + + # discontinuous view + v2 = view(x, :, [1,3,5]) + outv2 = NNlib.gather(v2, i) + @test collect(outv2) == NNlib.gather(collect(v2), collect(i)) + end end