Replace @require CUDA with using GPUArraysCore (FluxML#1272)
* require GPUArrays instead of CUDA

* more

* change to unconditionally load GPUArraysCore

* add GPUArrays dep

* trivial trigger commit
mcabbott authored Jul 26, 2022
1 parent 995778d commit 5ffbd43
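
In outline, the change swaps a conditionally loaded dependency for a small unconditional one. A minimal sketch of the two patterns, using the same Requires.jl macros that appear in the diff below (the block body here is a placeholder, not the real definitions):

    # Before: Requires.jl conditional loading -- the GPU adjoints only come
    # into existence if the user's session happens to also load CUDA.jl.
    @init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
        # CUDA-specific adjoint definitions lived here
    end

    # After: GPUArraysCore is lightweight enough to depend on unconditionally,
    # so the same adjoints are always defined, for any GPU backend.
    using GPUArraysCore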
Showing 2 changed files with 13 additions and 20 deletions.
4 changes: 4 additions & 0 deletions Project.toml
@@ -10,6 +10,8 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -31,6 +33,8 @@ ChainRulesTestUtils = "1"
 DiffRules = "1.4"
 FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
 ForwardDiff = "0.10"
+GPUArrays = "8.4.2" # not loaded, just a version bound
+GPUArraysCore = "0.1.1"
 IRTools = "0.4.4"
 LogExpFunctions = "0.3.1"
 MacroTools = "0.5"
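
The two "not loaded, just a version bound" comments mark a deliberate Pkg trick: GPUArrays sits in [deps] only so that a [compat] bound on it can constrain which versions resolve in a user's environment; the code itself only ever loads the lightweight GPUArraysCore. A hedged sketch of what that means in practice (the 8.4.2 bound presumably targets the first GPUArrays release built on top of GPUArraysCore):

    # Zygote loads only the core package, which supplies everything the
    # broadcast code below uses unqualified: AbstractGPUArray,
    # AbstractGPUArrayStyle, and @allowscalar.
    using GPUArraysCore

    # GPUArrays itself is never imported, so it stays undefined here; its
    # Project.toml entries exist purely to steer version resolution.
    @assert !isdefined(@__MODULE__, :GPUArrays)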
29 changes: 9 additions & 20 deletions src/lib/broadcast.jl
@@ -253,43 +253,32 @@ end
   return y, bc_fwd_back
 end
 
-@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame
 
-  const CuArrayStyle = CUDA.AbstractGPUArrayStyle
-
-  if isdefined(CUDA, :cufunc) # CUDA < 3.0
-
-    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
-      broadcast_forward(CUDA.cufunc(f), args...)
-
-  else # CUDA >= 3.0 -- don't need cufunc(f).
   # Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
   # so perhaps this can be deleted? Possible edge case here:
   # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
+  @adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
+    broadcast_forward(f, args...)
 
-    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
-      broadcast_forward(f, args...)
-
-  end
-
-  @adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
+  @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
     T(xs), Δ -> (convert(Array, Δ), )
 
-  @adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
+  @adjoint function sum(xs::AbstractGPUArray; dims = :)
     placeholder = similar(xs)
     sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
   end
 
   # Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
   # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
-  @adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
+  @adjoint function sum(f, xs::AbstractGPUArray; kws...)
     @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
     return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
   end
 
-  @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
+  @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
     Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
   end
 
-  @eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
-end
+  pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
+
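
The practical effect of this change: the adjoints above now fire for any backend whose arrays subtype AbstractGPUArray, not only CUDA's. A quick smoke test, assuming the JLArrays package (the CPU-backed reference AbstractGPUArray implementation from the GPUArrays repository) is installed:

    using Zygote, JLArrays

    xs = JLArray(Float32[1, 2, 3])

    # sum(abs2, xs) hits the sum(f, xs::AbstractGPUArray) adjoint above, which
    # routes the broadcast through forward mode rather than Zygote's generic
    # (and GPU-unfriendly) reverse rule.
    g, = gradient(x -> sum(abs2, x), xs)
    collect(g)   # Float32[2.0, 4.0, 6.0], i.e. 2 .* xs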
