From e35d643357a11d1eb951f58a8fa8da699f2323a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Nov 2024 12:16:52 -0500 Subject: [PATCH] chore: run the formatter --- .../ext/MLDataDevicesChainRulesExt.jl | 5 ++-- .../ext/MLDataDevicesZygoteExt.jl | 6 ++--- lib/MLDataDevices/test/misc_tests.jl | 24 +++++++++---------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 039058cffa..eef457df10 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,14 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using ChainRules: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 9bec6a82fc..53544a520e 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,16 +1,16 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end - diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 05e98b6a2a..2a22df3702 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -222,6 +222,18 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end +@testset "Zygote and ChainRules OneElement #1016" begin + using Zygote + + cpu = cpu_device() + gpu = gpu_device() + + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1, 2, 3])[1] + @test g isa Vector{Float32} + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + @test g isa Matrix{Float32} +end + @testset "OneHotArrays" begin using OneHotArrays @@ -241,15 +253,3 @@ end @test x_rd isa Reactant.ConcreteRArray{Bool, 2} end end - -@testset "Zygote and ChainRules OneElement" begin - # Issue #1016 - using Zygote - cpu = cpu_device() - gpu = gpu_device() - - g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] - @test g isa Vector{Float32} - g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] - @test g isa Matrix{Float32} -end