Skip to content

Commit

Permalink
test: try fixing enzyme test (#1119)
Browse files Browse the repository at this point in the history
* test: try fixing enzyme test

* test: temporarily mark the test as broken
  • Loading branch information
avik-pal authored Dec 5, 2024
1 parent 4c90a39 commit 5ad4fa9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
8 changes: 2 additions & 6 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ export generic_loss_function, compute_enzyme_gradient, compute_zygote_gradient,
end

@testitem "Enzyme Integration" setup=[EnzymeTestSetup, SharedTestSetup] tags=[
:autodiff, :enzyme] begin
:autodiff, :enzyme] timeout=3600 begin
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
# TODO: Currently all the tests are run on CPU. We should eventually add tests for
# CUDA and AMDGPU.
ongpu && continue

@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
Expand All @@ -106,15 +104,13 @@ end
end
end

@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] tags=[
@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] timeout=3600 tags=[
:autodiff, :enzyme] begin
using ComponentArrays

rng = StableRNG(12345)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
# TODO: Currently all the tests are run on CPU. We should eventually add tests for
# CUDA and AMDGPU.
ongpu && continue

@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
Expand Down
19 changes: 13 additions & 6 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,19 @@ end
)
x = randn(SVector{N, Float64})

fun = let d = d, x = x
ps -> sum(d(x, ps, (;))[1])
end
grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1]
@test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6
broken = pkgversion(Enzyme) v"0.13.18"

@test begin
grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps
sumabs2first(d, x, ps, (;))
end

grad2 = Enzyme.gradient(
Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;))
)[3]

maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6
end broken=broken
end
end

Expand Down

0 comments on commit 5ad4fa9

Please sign in to comment.