diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index be1b7ad1a..804bad70e 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -15,7 +15,7 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( [ AutoZygote(), - # AutoEnzyme() # TODO: fix + # AutoEnzyme(), # TODO a few scenarios fail ], DIT.flux_scenarios(Random.MersenneTwister(0)); isapprox=DIT.flux_isapprox, diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index bfa84d4b4..c5f84e22e 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] @@ -46,7 +47,7 @@ DifferentiationInterface = "0.6.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" -Flux = "0.13,0.14" +Flux = "0.16" ForwardDiff = "0.10.36" Functors = "0.4, 0.5" JET = "0.4 - 0.8, 0.9" @@ -61,6 +62,7 @@ SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" SparseMatrixColorings = "0.4.9" StaticArrays = "1.9" +Statistics = "1" Test = "<0.0.1,1" Zygote = "0.6" julia = "1.10" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index 9b7ef48a1..d0825cee3 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -11,10 +11,13 @@ using Flux: ConvTranspose, Dense, GRU, + GRUCell, LSTM, + LSTMCell, Maxout, MeanPool, RNN, + RNNCell, SamePad, Scale, SkipConnection, @@ -24,6 +27,7 @@ using Flux: relu using Functors: @functor, fmapstructure_with_path, fleaves using LinearAlgebra +using Statistics: mean using Random: AbstractRNG, default_rng #= @@ -43,31 +47,23 @@ end function DIT.flux_isapprox(a, b; atol, rtol) isapprox_results = fmapstructure_with_path(a, b) do kp, x, y - if :state in kp # ignore RNN and LSTM state + if x isa AbstractArray{<:Number} + return isapprox(x, y; atol, rtol) + else # ignore non-arrays return true - else - if x isa AbstractArray{<:Number} - return isapprox(x, y; atol, rtol) - else # ignore non-arrays - return true - end end end return all(fleaves(isapprox_results)) end -function square_loss(model, x) - Flux.reset!(model) - return sum(abs2, model(x)) -end +square_loss(model, x) = mean(abs2, model(x)) -function square_loss_iterated(model, x) - Flux.reset!(model) - y = copy(x) - for _ in 1:3 - y = model(y) +function square_loss_iterated(cell, x) + y, st = cell(x) # uses default initial state + for _ in 1:2 + y, st = cell(x, st) end - return sum(abs2, y) + return mean(abs2, y) end struct SimpleDense{W,B,F} @@ -132,37 +128,33 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Maxout(() -> Dense(5 => 4, tanh; init), 3), randn(rng, Float32, 5, 1) ), - ( - RNN(3 => 2; init), - randn(rng, Float32, 3, 2) - ), - ( - Chain(RNN(3 => 4; init), RNN(4 => 3; init)), - randn(rng, Float32, 3, 2) + ( + SkipConnection(Dense(2 => 2; init), vcat), + randn(rng, Float32, 2, 3) ), ( - LSTM(3 => 5; init), - randn(rng, Float32, 3, 2) + Bilinear((2, 2) => 3; init), + randn(rng, Float32, 2, 1) ), ( - Chain(LSTM(3 => 5; init), LSTM(5 => 3; init)), - randn(rng, Float32, 3, 2) + ConvTranspose((3, 3), 3 => 2; stride=2, init), + rand(rng, Float32, 5, 5, 3, 1) ), ( - SkipConnection(Dense(2 => 2; init), vcat), - randn(rng, Float32, 2, 3) + RNN(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) ), ( - Bilinear((2, 2) => 3; init), - randn(rng, Float32, 2, 1) + LSTM(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) ), ( - GRU(3 => 5; init), - randn(rng, Float32, 3, 10) + GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) ), ( - ConvTranspose((3, 3), 3 => 2; stride=2, init), - rand(rng, Float32, 5, 5, 3, 1) + Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), + randn(rng, Float32, 3, 2, 1) ), #! format: on ] @@ -176,16 +168,20 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) push!(scens, scen) end - # Recurrence + # Recurrent Cells recurrent_models_and_xs = [ #! format: off ( - RNN(3 => 3; init), + RNNCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2) + ), + ( + LSTMCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2) ), ( - LSTM(3 => 3; init), + GRUCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2) ), #! format: on @@ -193,12 +189,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) for (model, x) in recurrent_models_and_xs Flux.trainmode!(model) - g = gradient_finite_differences(square_loss, model, x) + g = gradient_finite_differences(square_loss_iterated, model, x) scen = DIT.Scenario{:gradient,:out}( square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g ) - # TODO: figure out why these tests are broken - # push!(scens, scen) + push!(scens, scen) end return scens