From 266b73368fb2503e670f08ec17747bfb01a96999 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Dec 2024 18:29:00 +0530 Subject: [PATCH] test: temporarily mark the test as broken --- test/enzyme_tests.jl | 8 ++------ test/layers/basic_tests.jl | 18 +++++++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 8fbf085fc5..7141e7d098 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -83,12 +83,10 @@ export generic_loss_function, compute_enzyme_gradient, compute_zygote_gradient, end @testitem "Enzyme Integration" setup=[EnzymeTestSetup, SharedTestSetup] tags=[ - :autodiff, :enzyme] begin + :autodiff, :enzyme] timeout=3600 begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - # TODO: Currently all the tests are run on CPU. We should eventually add tests for - # CUDA and AMDGPU. ongpu && continue @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST) @@ -106,15 +104,13 @@ end end end -@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] tags=[ +@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] timeout=3600 tags=[ :autodiff, :enzyme] begin using ComponentArrays rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - # TODO: Currently all the tests are run on CPU. We should eventually add tests for - # CUDA and AMDGPU. ongpu && continue @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST) diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 0d8926764f..1e36b9899b 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -165,15 +165,19 @@ end ) x = randn(SVector{N, Float64}) - grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps - sumabs2first(d, x, ps, (;)) - end + broken = pkgversion(Enzyme) ≥ v"0.13.18" + + @test begin + grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps + sumabs2first(d, x, ps, (;)) + end - grad2 = Enzyme.gradient( - Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) - )[3] + grad2 = Enzyme.gradient( + Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) + )[3] - @test maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 + maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 + end broken=broken end end