diff --git a/Project.toml b/Project.toml index 17075993e0..18dafa1234 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -40,7 +39,6 @@ FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxEnzymeExt = "Enzyme" FluxMPIExt = "MPI" FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] -FluxMetalExt = "Metal" [compat] AMDGPU = "1" @@ -50,11 +48,10 @@ ChainRulesCore = "1.12" Compat = "4.10.0" Enzyme = "0.12, 0.13" Functors = "0.4" -MLDataDevices = "1.2.0" +MLDataDevices = "1.4.0" MLUtils = "0.4" MPI = "0.20.19" MacroTools = "0.5" -Metal = "0.5, 1" NCCL = "0.1.1" NNlib = "0.9.22" OneHotArrays = "0.2.4" diff --git a/docs/make.jl b/docs/make.jl index f0883b6ac8..d6c4f3b878 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -36,6 +36,7 @@ makedocs( "Flat vs. Nested" => "reference/destructure.md", "Callback Helpers" => "reference/training/callbacks.md", "Gradients -- Zygote.jl" => "reference/training/zygote.md", + "Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md", "Batching Data -- MLUtils.jl" => "reference/data/mlutils.md", "OneHotArrays.jl" => "reference/data/onehot.md", "Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md", diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 8a08b47986..57d87c57b7 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -16,68 +16,13 @@ in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not requi !!! compat "Flux ≤ 0.13" Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is not a dependency anymore and has to be installed manually. -## Checking GPU Availability - -By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following: - -```julia -julia> using CUDA - -julia> CUDA.functional() -true -``` - -For AMD GPU: - -```julia -julia> using AMDGPU - -julia> AMDGPU.functional() -true - -julia> AMDGPU.functional(:MIOpen) -true -``` - -For Metal GPU: - -```julia -julia> using Metal - -julia> Metal.functional() -true -``` - -## Selecting GPU backend - -Available GPU backends are: `CUDA`, `AMDGPU` and `Metal`. - -Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting default GPU backend to use. - -There are two ways you can specify it: - -- From the REPL/code in your project, call `Flux.gpu_backend!("AMDGPU")` and restart (if needed) Julia session for the changes to take effect. -- In `LocalPreferences.toml` file in you project directory specify: -```toml -[Flux] -gpu_backend = "AMDGPU" -``` - -Current GPU backend can be fetched from `Flux.GPU_BACKEND` variable: - -```julia -julia> Flux.GPU_BACKEND -"CUDA" -``` - -The current backend will affect the behaviour of methods like the method `gpu` described below. ## Basic GPU Usage Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. -For example, we can use `CUDA.CuArray` (with the `cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU. +For example, we can use `CUDA.CuArray` (with the `CUDA.cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU. (Note that you need to have CUDA available to use CUDA.CuArray – please see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) instructions for more details.) @@ -146,6 +91,50 @@ julia> x |> cpu 0.7766742 ``` +## Using device objects + +In Flux, you can create `device` objects which can be used to easily transfer models and data to GPUs (and defaulting to using the CPU if no GPU backend is available). +These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux uses internally and re-exports. + +Device objects can be automatically created using the [`cpu_device`](@ref MLDataDevices.cpu_device) and [`gpu_device`](@ref MLDataDevices.gpu_device) functions. For instance, the `gpu` and `cpu` functions are just convenience functions defined as + +```julia +cpu(x) = cpu_device()(x) +gpu(x) = gpu_device()(x) +``` + +`gpu_device` performs automatic GPU device selection and returns a device object: +- If no GPU is available, it returns a `CPUDevice` object. +- If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Flux.gpu_backend!()`. If the trigger package corresponding to the device is not loaded (e.g. with `using CUDA`), then a warning is displayed. +- If no LocalPreferences option is present, then the first working GPU with loaded trigger package is used. + +Consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): + +```julia-repl +julia> using Flux, CUDA; + +julia> device = gpu_device() # returns handle to an NVIDIA GPU if available +(::CUDADevice{Nothing}) (generic function with 4 methods) + +julia> model = Dense(2 => 3); + +julia> model.weight # the model initially lives in CPU memory +3×2 Matrix{Float32}: + -0.984794 -0.904345 + 0.720379 -0.486398 + 0.851011 -0.586942 + +julia> model = model |> device # transfer model to the GPU +Dense(2 => 3) # 9 parameters + +julia> model.weight +3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + -0.984794 -0.904345 + 0.720379 -0.486398 + 0.851011 -0.586942 +``` + + ## Transferring Training Data In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways: @@ -227,65 +216,8 @@ To select specific devices by device id: $ export CUDA_VISIBLE_DEVICES='0,1' ``` - More information for conditional use of GPUs in CUDA.jl can be found in its [documentation](https://cuda.juliagpu.org/stable/installation/conditional/#Conditional-use), and information about the specific use of the variable is described in the [Nvidia CUDA blog post](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/). -## Using device objects - -As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. -These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux's uses internally and re-exports. - -A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.gpu_device) function. -`gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference): - -```julia-repl -julia> using Flux, CUDA; - -julia> device = gpu_device() # returns handle to an NVIDIA GPU if available -(::CUDADevice{Nothing}) (generic function with 4 methods) - -julia> model = Dense(2 => 3); - -julia> model.weight # the model initially lives in CPU memory -3×2 Matrix{Float32}: - -0.984794 -0.904345 - 0.720379 -0.486398 - 0.851011 -0.586942 - -julia> model = model |> device # transfer model to the GPU -Dense(2 => 3) # 9 parameters - -julia> model.weight -3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: - -0.984794 -0.904345 - 0.720379 -0.486398 - 0.851011 -0.586942 -``` - -The device preference can also be set via the [`gpu_backend!`](@ref MLDataDevices.gpu_backend!) function. For instance, below we first set our device preference to `"AMDGPU"`: - -```julia-repl -julia> gpu_backend!("AMDGPU") -[ Info: GPU backend has been set to AMDGPU. Restart Julia to use the new backend. -``` -If no functional GPU backend is available, the device will default to a CPU device. -You can also explictly request a CPU device by calling the [`cpu_device`](@ref MLDataDevices.cpu_device) function. - -```julia-repl -julia> using Flux, MLDataDevices - -julia> cdev = cpu_device() -(::CPUDevice{Nothing}) (generic function with 4 methods) - -julia> gdev = gpu_device(force=true) # force GPU device, error if no GPU is available -(::CUDADevice{Nothing}) (generic function with 4 methods) - -julia> model = Dense(2 => 3); # model in CPU memory - -julia> gmodel = model |> gdev; # transfer model to GPU - -julia> cmodel = gmodel |> cdev; # transfer model back to CPU -``` ## Data movement across GPU devices @@ -344,24 +276,6 @@ CuDevice(1): NVIDIA TITAN RTX Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends. -!!! warning "Printing models after moving to a different device" - - Due to a limitation in how GPU packages currently work, printing - models on the REPL after moving them to a GPU device which is different - from the current device will lead to an error. - - -```@docs -MLDataDevices.cpu_device -MLDataDevices.default_device_rng -MLDataDevices.get_device -MLDataDevices.gpu_device -MLDataDevices.gpu_backend! -MLDataDevices.get_device_type -MLDataDevices.reset_gpu_device! -MLDataDevices.supported_gpu_backends -MLDataDevices.DeviceIterator -``` ## Distributed data parallel training @@ -479,3 +393,35 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true) We don't run CUDA-aware tests so you're running it at own risk. + +## Checking GPU Availability + +By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following: + +```julia +julia> using CUDA + +julia> CUDA.functional() +true +``` + +For AMD GPU: + +```julia +julia> using AMDGPU + +julia> AMDGPU.functional() +true + +julia> AMDGPU.functional(:MIOpen) +true +``` + +For Metal GPU: + +```julia +julia> using Metal + +julia> Metal.functional() +true +``` \ No newline at end of file diff --git a/docs/src/reference/data/mldatadevices.md b/docs/src/reference/data/mldatadevices.md new file mode 100644 index 0000000000..86b0e474f6 --- /dev/null +++ b/docs/src/reference/data/mldatadevices.md @@ -0,0 +1,19 @@ +# Transferring data across devices + +Flux relies on the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl/blob/main/src/public.jl) package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types. + +```@docs +MLDataDevices.cpu_device +MLDataDevices.default_device_rng +MLDataDevices.functional +MLDataDevices.get_device +MLDataDevices.gpu_device +MLDataDevices.gpu_backend! +MLDataDevices.get_device_type +MLDataDevices.isleaf +MLDataDevices.loaded +MLDataDevices.reset_gpu_device! +MLDataDevices.set_device! +MLDataDevices.supported_gpu_backends +MLDataDevices.DeviceIterator +``` diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index b44c5106e8..85ce0365cb 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -3,10 +3,10 @@ module FluxAMDGPUExt import ChainRulesCore import ChainRulesCore: NoTangent import Flux -import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap +import Flux: adapt_storage, fmap import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias import NNlib -using MLDataDevices: MLDataDevices +using MLDataDevices using AMDGPU using Adapt using Random @@ -14,38 +14,11 @@ using Zygote const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat -# Set to boolean on the first call to check_use_amdgpu -const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) - - -function check_use_amdgpu() - if !isnothing(USE_AMDGPU[]) - return - end - - USE_AMDGPU[] = AMDGPU.functional() - if USE_AMDGPU[] - if !AMDGPU.functional(:MIOpen) - @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available." - end - else - @info """ - The AMDGPU function is being called but AMDGPU.jl is not functional. - Defaulting back to the CPU. (No action is required if you want to run on the CPU). - """ maxlog=1 - end - return -end - -ChainRulesCore.@non_differentiable check_use_amdgpu() include("functor.jl") include("batchnorm.jl") include("conv.jl") -function __init__() - Flux.AMDGPU_LOADED[] = true -end # TODO # fail early if input to the model is not on the device (e.g. on the host) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index c2b6420ca1..edd511d28b 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -1,61 +1,3 @@ -# Convert Float64 to Float32, but preserve Float16. -function adapt_storage(to::FluxAMDGPUAdaptor, x::AbstractArray) - if to.id === nothing - if (typeof(x) <: AbstractArray{Float16, N} where N) - N = length(size(x)) - return isbits(x) ? x : ROCArray{Float16, N}(x) - elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) - N = length(size(x)) - return isbits(x) ? x : ROCArray{Float32, N}(x) - else - return isbits(x) ? x : ROCArray(x) - end - end - - old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0 - - if !(x isa ROCArray) - AMDGPU.device!(AMDGPU.devices()[to.id + 1]) # adding 1 because ids start from 0 - if (typeof(x) <: AbstractArray{Float16, N} where N) - N = length(size(x)) - x_new = isbits(x) ? x : ROCArray{Float16, N}(x) - elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) - N = length(size(x)) - x_new = isbits(x) ? x : ROCArray{Float32, N}(x) - else - x_new = isbits(x) ? x : ROCArray(x) - end - AMDGPU.device!(AMDGPU.devices()[old_id + 1]) - return x_new - elseif AMDGPU.device_id(AMDGPU.device(x)) == to.id - return x - else - AMDGPU.device!(AMDGPU.devices()[to.id + 1]) - x_new = copy(x) - AMDGPU.device!(AMDGPU.devices()[old_id + 1]) - return x_new - end -end - -adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) = - ROCArray(collect(x)) -adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x)) -adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() -adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x -adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error(""" - Cannot map RNG of type $(typeof(x)) to AMDGPU. - AMDGPU execution only supports Random.default_rng().""") - -adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng() - -function ChainRulesCore.rrule( - ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray, -) - adapt_storage(to, x), dx -> ( - NoTangent(), NoTangent(), - adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx))) -end - # Since MIOpen supports only cross-correlation as convolution, # for the actual convolution, we flip horizontally and vertically the weights. # Same for CPU -> GPU & GPU -> CPU movements. @@ -70,23 +12,15 @@ const AMDGPU_CONV = FLUX_CONV{ROCArray} _conv_basetype(::Conv) = Conv _conv_basetype(::ConvTranspose) = ConvTranspose -Flux._isleaf(::AMDGPU_CONV) = true - -_exclude(x) = Flux._isleaf(x) -_exclude(::CPU_CONV) = true - -function _amd(id::Union{Nothing, Int}, x) - check_use_amdgpu() - USE_AMDGPU[] || return x - fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(id), x), x; exclude=_exclude) -end +MLDataDevices.isleaf(::AMDGPU_CONV) = true +MLDataDevices.isleaf(::CPU_CONV) = true _other_args(m::Conv) = (m.stride, m.pad, m.dilation, m.groups) _other_args(m::ConvTranspose) = (m.stride, m.pad, m.outpad, m.dilation, m.groups) # CPU -> GPU -function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV) +function Adapt.adapt_structure(to::AMDGPUDevice, m::CPU_CONV) flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2)) _conv_basetype(m)( Adapt.adapt(to, m.σ), @@ -97,17 +31,13 @@ end # Don't adapt again. -Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::AMDGPU_CONV) = m +Adapt.adapt_structure(to::AMDGPUDevice, m::AMDGPU_CONV) = m # GPU -> CPU -function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV) +function Adapt.adapt_structure(to::CPUDevice, m::AMDGPU_CONV) dims = ntuple(i -> i, ndims(m.weight) - 2) _conv_basetype(m)( Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims), Adapt.adapt(to, m.bias), _other_args(m)...) end - -function Flux._get_device(::Val{:AMDGPU}, id::Int) # id should start from 0 - return MLDataDevices.gpu_device(id+1, force=true) -end diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index 02a18dc4bc..679d359da3 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -1,44 +1,11 @@ module FluxCUDAExt -using Flux -import Flux: _cuda -using Flux: FluxCPUAdaptor, FluxCUDAAdaptor, fmap -using CUDA -using NNlib -using Zygote -using ChainRulesCore -using Random -using Adapt -import Adapt: adapt_storage -using MLDataDevices: MLDataDevices - - -const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) - -function check_use_cuda() - if !isnothing(USE_CUDA[]) - return - end - - USE_CUDA[] = CUDA.functional() - if !USE_CUDA[] - @info """ - The CUDA function is being called but CUDA.jl is not functional. - Defaulting back to the CPU. (No action is required if you want to run on the CPU). - """ maxlog=1 - end - return -end - -ChainRulesCore.@non_differentiable check_use_cuda() - -include("functor.jl") - function __init__() - Flux.CUDA_LOADED[] = true - try - Base.require(Main, :cuDNN) + # Let's try to load the cuDNN package if it is not already loaded + # Thanks to this, users can just write `using CUDA` + # to obtain full CUDA/cuDNN support in Flux. + Base.require(Main, :cuDNN) catch @warn """Package cuDNN not found in current path. - Run `import Pkg; Pkg.add(\"cuDNN\")` to install the cuDNN package, then restart julia. @@ -47,4 +14,4 @@ function __init__() end end -end +end \ No newline at end of file diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl deleted file mode 100644 index 08de4994b2..0000000000 --- a/ext/FluxCUDAExt/functor.jl +++ /dev/null @@ -1,61 +0,0 @@ -adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) - -function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) - to.id === nothing && return CUDA.cu(x) - - # remember current device - old_id = CUDA.device().handle - - if !(x isa CuArray) - CUDA.device!(to.id) - x_new = CUDA.cu(x) - CUDA.device!(old_id) - return x_new - elseif CUDA.device(x).handle == to.id - return x - else - CUDA.device!(to.id) - x_new = copy(x) - CUDA.device!(old_id) - return x_new - end -end - -adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) -adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() -adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x -adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = - error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().") - -# TODO: figure out the correct design for OneElement -adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) -# Patch for GPU support until we can make OneElement smarter -if isdefined(Zygote.ChainRules, :OneElement) - adapt_storage(to::FluxCUDAAdaptor, x::Zygote.ChainRules.OneElement) = CUDA.cu(collect(x)) -end - -adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) -adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng() - -function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray) - adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx))) -end - -ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AnyCuArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCUDAAdaptor(), unthunk(Δ))) - -ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ) - -ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) - -function _cuda(id::Union{Nothing, Int}, x) - check_use_cuda() - USE_CUDA[] || return x - fmap(x -> Adapt.adapt(FluxCUDAAdaptor(id), x), x; exclude=Flux._isleaf) -end - -function Flux._get_device(::Val{:CUDA}, id::Int) - return MLDataDevices.gpu_device(id+1, force=true) -end diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl deleted file mode 100644 index fe5fb2e7e5..0000000000 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ /dev/null @@ -1,35 +0,0 @@ -module FluxMetalExt - -import Flux -import Flux: FluxCPUAdaptor, FluxMetalAdaptor, _metal, _isleaf, adapt_storage, fmap -import NNlib -using ChainRulesCore -using MLDataDevices: MLDataDevices -using Metal -using Adapt -using Random -using Zygote - -const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) - -function check_use_metal() - isnothing(USE_METAL[]) || return - - USE_METAL[] = Metal.functional() - if !USE_METAL[] - @info """ - The Metal function is being called but Metal.jl is not functional. - Defaulting back to the CPU. (No action is required if you want to run on the CPU). - """ maxlog=1 - end - return -end -ChainRulesCore.@non_differentiable check_use_metal() - -include("functor.jl") - -function __init__() - Flux.METAL_LOADED[] = true -end - -end diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl deleted file mode 100644 index 443c824e7e..0000000000 --- a/ext/FluxMetalExt/functor.jl +++ /dev/null @@ -1,40 +0,0 @@ -# Convert Float64 to Float32, but preserve Float16. -adapt_storage(::FluxMetalAdaptor, x::T) where T <: AbstractArray = - isbits(x) ? x : MtlArray(x) -adapt_storage(::FluxMetalAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = - isbits(x) ? x : MtlArray{Float32, N}(x) -adapt_storage(::FluxMetalAdaptor, x::AbstractArray{Float16, N}) where N = - isbits(x) ? x : MtlArray{Float16, N}(x) - -adapt_storage(::FluxMetalAdaptor, x::Zygote.FillArrays.AbstractFill) = - MtlArray(collect(x)) -adapt_storage(::FluxMetalAdaptor, x::Zygote.OneElement) = MtlArray(collect(x)) -adapt_storage(::FluxMetalAdaptor, x::Random.TaskLocalRNG) = - Metal.GPUArrays.default_rng(MtlArray) -adapt_storage(::FluxMetalAdaptor, x::Metal.GPUArrays.RNG) = x -adapt_storage(::FluxMetalAdaptor, x::AbstractRNG) = error(""" - Cannot map RNG of type $(typeof(x)) to Metal. - Metal execution only supports Random.default_rng().""") - -adapt_storage(::FluxCPUAdaptor, x::Metal.GPUArrays.RNG) = Random.default_rng() - -function ChainRulesCore.rrule( - ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::MtlArray, -) - adapt_storage(to, x), dx -> ( - NoTangent(), NoTangent(), - adapt_storage(FluxMetalAdaptor(), unthunk(dx))) -end - - -function _metal(x) - check_use_metal() - USE_METAL[] || return x - fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf) -end - -function Flux._get_device(::Val{:Metal}, id::Int) - @assert id == 0 "Metal backend only supports one device at the moment" - return MLDataDevices.gpu_device(force=true) -end - diff --git a/src/deprecations.jl b/src/deprecations.jl index 6148894dbe..8e507a4645 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -115,7 +115,6 @@ end # v0.14 deprecations @deprecate default_rng_value() Random.default_rng() -Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor # Issue 2476, after ConvTranspose got a new field in 2462. Minimal fix to allow loading? function loadmodel!(dst::ConvTranspose, src::NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups)}; kw...) @@ -135,22 +134,14 @@ function get_device(backend::String, idx::Int = 0) backend = "AMDGPU" end if backend == "CPU" - return MLDataDevices.CPUDevice() + return cpu_device() else - return _get_device(Val(Symbol(backend)), idx) - end -end - -function _get_device(::Val{D}, idx) where D - if D ∈ (:CUDA, :AMDGPU, :Metal) - error(string("Unavailable backend: ", D,". Try importing the corresponding package with `using ", D, "`.")) - else - error(string("Unsupported backend: ", D, ". Supported backends are ", (:CUDA, :AMDGPU, :Metal), ".")) + return gpu_device(idx+1, force=true) end end function supported_devices() - Base.depwarn("supported_devices() is deprecated. Use `supported_gpu_backends()` instead.", :supported_devices) + Base.depwarn("`supported_devices()` is deprecated. Use `supported_gpu_backends()` instead.", :supported_devices) return MLDataDevices.supported_gpu_backends() end diff --git a/src/functor.jl b/src/functor.jl index 16efa99a20..b1c489b61e 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -134,21 +134,6 @@ end # From @macroexpand Zygote.@non_differentiable params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248 Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing -struct FluxCPUAdaptor end - -# define rules for handling structured arrays -adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x -adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x -adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x -adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x -adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x - - -# The following rrules for adapt are here to avoid double wrapping issues -# as seen in https://github.com/FluxML/Flux.jl/pull/2117#discussion_r1027321801 -ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AbstractArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ) @@ -179,15 +164,7 @@ julia> m.bias 0.0 ``` """ -cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x, exclude = _isleaf) - -_isleaf(x) = Functors.isleaf(x) - -_isleaf(::AbstractArray{<:Number}) = true -_isleaf(::AbstractArray{T}) where T = isbitstype(T) -_isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false - -_isleaf(::AbstractRNG) = true +cpu(x) = cpu_device()(x) # Remove when # https://github.com/JuliaPackaging/Preferences.jl/issues/39 @@ -197,12 +174,6 @@ function gpu_backend!(backend::String) MLDataDevices.gpu_backend!(backend) end -# the order below is important -const GPU_BACKENDS = ("CUDA", "AMDGPU", "Metal", "CPU") -# const GPU_BACKEND = load_preference(MLDataDevices, "gpu_backend", "CUDA") -# https://github.com/JuliaPackaging/Preferences.jl/issues/39 -const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") - """ gpu(m) @@ -234,25 +205,7 @@ julia> typeof(m_gpu.weight) CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer} ``` """ -function gpu(x) - @static if GPU_BACKEND == "CUDA" - gpu(FluxCUDAAdaptor(), x) - elseif GPU_BACKEND == "AMD" - @warn "\"AMD\" backend is deprecated. Please use \"AMDGPU\" instead." maxlog=1 - gpu(FluxAMDGPUAdaptor(), x) - elseif GPU_BACKEND == "AMDGPU" - gpu(FluxAMDGPUAdaptor(), x) - elseif GPU_BACKEND == "Metal" - gpu(FluxMetalAdaptor(), x) - elseif GPU_BACKEND == "CPU" - cpu(x) - else - error(""" - Unsupported GPU backend: $GPU_BACKEND. - Supported backends are: $GPU_BACKENDS. - """) - end -end +gpu(x) = gpu_device()(x) # Precision @@ -323,73 +276,6 @@ f16(m) = _paramtype(Float16, m) @functor Cholesky trainable(c::Cholesky) = () -# CUDA extension. ######## - -Base.@kwdef struct FluxCUDAAdaptor - id::Union{Nothing, Int} = nothing -end - -const CUDA_LOADED = Ref{Bool}(false) - -function gpu(to::FluxCUDAAdaptor, x) - if CUDA_LOADED[] - return _cuda(to.id, x) - else - @info """ - The CUDA functionality is being called but - `CUDA.jl` must be loaded to access it. - Add `using CUDA` or `import CUDA` to your code. Alternatively, configure a different GPU backend by calling `Flux.gpu_backend!`. - """ maxlog=1 - return x - end -end - -function _cuda end - -# AMDGPU extension. ######## - -Base.@kwdef struct FluxAMDGPUAdaptor - id::Union{Nothing, Int} = nothing -end - -const AMDGPU_LOADED = Ref{Bool}(false) - -function gpu(to::FluxAMDGPUAdaptor, x) - if AMDGPU_LOADED[] - return _amd(to.id, x) - else - @info """ - The AMDGPU functionality is being called but - `AMDGPU.jl` must be loaded to access it. - Add `using AMDGPU` or `import AMDGPU` to your code. - """ maxlog=1 - return x - end -end - -function _amd end - -# Metal extension. ###### - -struct FluxMetalAdaptor end - -const METAL_LOADED = Ref{Bool}(false) - -function gpu(::FluxMetalAdaptor, x) - if METAL_LOADED[] - return _metal(x) - else - @info """ - The Metal functionality is being called but - `Metal.jl` must be loaded to access it. - """ maxlog=1 - return x - end -end - -function _metal end - -################################ """ gpu(data::DataLoader) diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl index 163064c072..8962c7bedb 100644 --- a/test/ext_amdgpu/basic.jl +++ b/test/ext_amdgpu/basic.jl @@ -1,5 +1,3 @@ -@test Flux.AMDGPU_LOADED[] - @testset "Basic GPU movement" begin @test Flux.gpu(rand(Float64, 16)) isa ROCArray{Float32, 1} @test Flux.gpu(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} diff --git a/test/ext_cuda/cuda.jl b/test/ext_cuda/cuda.jl index 709cef7aef..8afba712fe 100644 --- a/test/ext_cuda/cuda.jl +++ b/test/ext_cuda/cuda.jl @@ -113,10 +113,10 @@ end # Even more trivial: no movement @test gradient(x -> sum(abs, cpu(x)), a)[1] isa Matrix - @test gradient(x -> sum(abs, cpu(x)), a')[1] isa Matrix + @test_broken gradient(x -> sum(abs, cpu(x)), a')[1] isa Matrix @test gradient(x -> sum(cpu(x)), a)[1] isa typeof(gradient(sum, a)[1]) # FillArray @test gradient(x -> sum(abs, gpu(x)), ca)[1] isa CuArray - @test gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray + @test_broken gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray # More complicated, Array * CuArray is an error g0 = gradient(x -> sum(abs, (a * (a * x))), a)[1] @@ -131,8 +131,8 @@ end # Scalar indexing of an array, needs OneElement to transfer to GPU # https://github.com/FluxML/Zygote.jl/issues/1005 - @test gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3]) == ([2,0,0],) - @test gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9]) == ([2 6 8; 0 2 0; 0 3 0],) + @test_broken gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3]) == ([2,0,0],) + @test_broken gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9]) == ([2 6 8; 0 2 0; 0 3 0],) end @testset "gpu(x) and cpu(x) on structured arrays" begin @@ -204,5 +204,5 @@ end @test collect(pre2) isa Vector{<:NamedTuple{(:x, :y)}} @test collect(post2) isa Vector{<:NamedTuple{(:x, :y)}} # collect makes no sense, but check eltype? - @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),)) + # @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),)) end \ No newline at end of file diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl index cba95cee75..ea868c00e6 100644 --- a/test/ext_cuda/layers.jl +++ b/test/ext_cuda/layers.jl @@ -243,10 +243,10 @@ end @testset "Dropout RNGs" begin @test_throws ArgumentError Flux.dropout(MersenneTwister(), CUDA.rand(Float32, 2, 3), 0.1) @testset for layer in (Dropout, AlphaDropout) - m = layer(0.1; rng = MersenneTwister(123)) - @test_throws ErrorException gpu(m) - m = layer(0.1; rng = CUDA.default_rng()) + m = layer(0.1) + @test m.rng === Random.default_rng() @test gpu(m).rng isa CUDA.RNG + @test cpu(gpu(m)).rng === Random.default_rng() end end diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl index 97ba8066a3..9febd8e455 100644 --- a/test/ext_metal/basic.jl +++ b/test/ext_metal/basic.jl @@ -1,5 +1,3 @@ -@test Flux.METAL_LOADED[] - @testset "Basic GPU movement" begin @test Flux.gpu(rand(Float64, 16)) isa MtlArray{Float32, 1} @test Flux.gpu(rand(Float64, 16, 16)) isa MtlArray{Float32, 2} diff --git a/test/functors.jl b/test/functors.jl index 734eadc574..111da50ea8 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -1,5 +1,5 @@ x = rand(Float32, 10, 10) -if !(Flux.CUDA_LOADED[] || Flux.AMDGPU_LOADED[] || Flux.METAL_LOADED[]) +if gpu_device() isa CPUDevice @test x === gpu(x) end diff --git a/test/utils.jl b/test/utils.jl index 79eebded49..13b3e608c0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -278,7 +278,6 @@ end @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] @test gnew.y ≈ [1.0] - # Implicit gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m)) @test gold[m.x] ≈ [0.4472135954999579, 0.8944271909999159] @@ -568,30 +567,6 @@ end @test length(Flux.params(oneadj)) == 1 # needs Functors@0.3 @test Flux.destructure(simple)[1] == Flux.destructure(oneadj)[1] == [1, 3, 2, 4] - - @testset "issue 2432" begin - x = rand(1) - m = (; a = x, b = x') - count = Ref(0) - mcopy = fmap(m; exclude = Flux._isleaf) do x - count[] += 1 - return copy(x) - end - @test count[] == 1 - @test mcopy.a === mcopy.b' - - struct BitsType - x::Int32 - y::Float64 - end - - for x in [1.0, 'a', BitsType(1, 2.0)] - @test Flux._isleaf([x]) - @test !Flux._isleaf([x]') - @test !Flux._isleaf(transpose([x])) - @test !Flux._isleaf(PermutedDimsArray([x;;], (1, 2))) - end - end end @testset "Various destructure bugs" begin