Skip to content

Commit

Permalink
make gather/scatter work with views (#546)
Browse files Browse the repository at this point in the history
* AbstractGPUArray -> AnyGPUArray

* tests

* don't test Enzyme

* add test on discontinuous view

* Update test/runtests.jl
  • Loading branch information
CarloLucibello authored Nov 6, 2023
1 parent 00fccc7 commit 8598c08
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
return dst
end

function gather!(dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray)
function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray)
n_dims = scatter_dims(src, dst, idx)
dims = size(src)[1:n_dims]
max_dims_idx = prod(dims)
Expand Down
6 changes: 3 additions & 3 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA
dst
end

for AT in (AbstractArray, AbstractGPUArray)
for AT in (AbstractArray, AnyGPUArray)
@eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT)
Ns = scatter!(+, zero(dst), one.(src), idx)
dst_ = scatter!(+, zero(dst), src, idx)
Expand All @@ -90,7 +90,7 @@ for AT in (AbstractArray, AbstractGPUArray)
end
end

function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP
function scatter!(op::OP, dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) where OP
n_dims = scatter_dims(dst, src, idx)
args = if n_dims == 0
ndrange = length(idx)
Expand Down Expand Up @@ -228,7 +228,7 @@ end

function ∇scatter_src(
op::Union{typeof(*), typeof(/)}, Δ, dst,
src::AbstractGPUArray{Tsrc, Nsrc}, idx::AbstractGPUArray{Tidx, Nidx},
src::AnyGPUArray{Tsrc, Nsrc}, idx::AnyGPUArray{Tidx, Nidx},
) where {Tsrc, Nsrc, Tidx, Nidx}
n_dims = Nsrc - Nidx
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
Expand Down
2 changes: 1 addition & 1 deletion src/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ end
end

# Linear (GPU): parallelization along width dimension.
# TODO replace AbstractArray -> AbstractGPUArray once device arrays subtype it.
# TODO replace AbstractArray -> AnyGPUArray once device arrays subtype it.

@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where {
B <: GPU, T <: AbstractArray{<:Any, 3}, A,
Expand Down
14 changes: 14 additions & 0 deletions test/ext_cuda/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,18 @@
@test y isa CuArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)

@testset "views" begin
x = cu(rand(2, 5))
v = view(x, axes(x)...)
i = cu([1, 2])
outx = NNlib.gather(x, i)
outv = NNlib.gather(v, i)
@test outx == outv

# discontinuous view
v2 = view(x, :, [1,3,5])
outv2 = NNlib.gather(v2, i)
@test collect(outv2) == NNlib.gather(collect(v2), collect(i))
end
end

0 comments on commit 8598c08

Please sign in to comment.