From 5ffbd43f70d85ed53ab5ca2cb4f281158414706f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 26 Jul 2022 15:38:50 -0400
Subject: [PATCH] Replace `@require CUDA` with `using GPUArraysCore` (#1272)

* require GPUArrays instead of CUDA

* more

* change to unconditionally load GPUArraysCore

* add GPUArrays dep

* trivial trigger commit
---
 Project.toml         |  4 ++++
 src/lib/broadcast.jl | 29 +++++++++--------------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9920da241..15f08ad2b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,8 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -31,6 +33,8 @@ ChainRulesTestUtils = "1"
 DiffRules = "1.4"
 FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
 ForwardDiff = "0.10"
+GPUArrays = "8.4.2" # not loaded, just a version bound
+GPUArraysCore = "0.1.1"
 IRTools = "0.4.4"
 LogExpFunctions = "0.3.1"
 MacroTools = "0.5"
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 6dbfdb829..b3c16e823 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -253,43 +253,32 @@ end
   return y, bc_fwd_back
 end
 
-@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+using GPUArraysCore  # replaces @require CUDA block, weird indenting to preserve git blame
 
-  const CuArrayStyle = CUDA.AbstractGPUArrayStyle
-
-  if isdefined(CUDA, :cufunc) # CUDA < 3.0
-
-    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
-      broadcast_forward(CUDA.cufunc(f), args...)
-
-  else # CUDA >= 3.0 -- don't need cufunc(f).
   # Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
   # so perhaps this can be deleted? Possible edge case here:
   # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
+  @adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
+    broadcast_forward(f, args...)
 
-    @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
-      broadcast_forward(f, args...)
-
-  end
-
-  @adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
+  @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
    T(xs), Δ -> (convert(Array, Δ), )
 
-  @adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
+  @adjoint function sum(xs::AbstractGPUArray; dims = :)
    placeholder = similar(xs)
    sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
   end
 
  # Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above,
  # not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
-  @adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
+  @adjoint function sum(f, xs::AbstractGPUArray; kws...)
    @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
    return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
   end
 
-  @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
+  @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
    Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
   end
 
-  @eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
-end
+  pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
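
Note on the mechanism, for reviewers (not part of the patch): the swap works because
GPUArraysCore is a tiny package defining only the shared abstract types and utilities
(`AbstractGPUArray`, `AbstractGPUArrayStyle`, `@allowscalar`), so Zygote can load it
unconditionally; CUDA.jl's `CuArray` subtypes these, so the methods above still catch
it without Requires. The `broadcast_forward` these adjoints route to differentiates a
broadcast by evaluating it once with ForwardDiff dual numbers, so value and derivative
both come out of a single fused kernel. A minimal sketch of that trick for a
one-argument function, where the helper name `fwd_broadcast_pullback` is illustrative
and not Zygote's actual API (the real code also handles multiple arguments and a
gradient for `f` itself):

    using ForwardDiff: Dual, value, partials

    # Differentiate y = f.(xs) in one fused broadcast: seed each element with a
    # unit perturbation, then read the value and the partial back out.
    function fwd_broadcast_pullback(f, xs::AbstractArray)
        ys = f.(Dual.(xs, one(eltype(xs))))     # forward pass carries df/dx along
        value.(ys), Δ -> Δ .* partials.(ys, 1)  # pullback reuses the stored partials
    end

    ys, back = fwd_broadcast_pullback(sin, Float32[1, 2, 3])
    back(ones(Float32, 3))                      # ≈ cos.(Float32[1, 2, 3])

Because every step is itself a broadcast over the whole array, the same code runs
unchanged on `Array` or any `AbstractGPUArray`, which is also why the `sum(f, xs)`
adjoint above goes through `pullback` and this machinery rather than ChainRules'
rrule.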