Skip to content

Commit

Permalink
test: Reactant with recurrent layers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 12, 2024
1 parent c0e4564 commit f319646
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions test/reactant/layer_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
@testsetup module SharedReactantLayersTestSetup

using Lux, Reactant, Enzyme, Zygote

sumabs2(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

function ∇sumabs2_zygote(model, x, ps, st)
return Zygote.gradient((x, ps) -> sumabs2(model, x, ps, st), x, ps)
end

function ∇sumabs2_enzyme(model, x, ps, st)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(
Enzyme.Reverse, sumabs2, Active,
Const(model), Duplicated(x, dx),
Duplicated(ps, dps), Const(st)
)
return dx, dps
end

export ∇sumabs2_zygote, ∇sumabs2_enzyme

end

@testitem "Recurrent Layers" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux
using LuxTestUtils: check_approx

rng = StableRNG(123)

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for cell in (RNNCell, LSTMCell, GRUCell)
model = Recurrence(cell(4 => 4))
ps, st = Lux.setup(rng, model)
ps_ra, st_ra = (ps, st) |> Reactant.to_rarray
x = rand(Float32, 4, 16, 12)
x_ra = x |> Reactant.to_rarray

y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ray atol=1e-4

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-4
@test check_approx(∂ps_ra, ∂ps; atol=1e-4)
end
end
end
end

0 comments on commit f319646

Please sign in to comment.