diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 77f5a5116..c5f84e22e 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -47,7 +47,7 @@ DifferentiationInterface = "0.6.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" -Flux = "0.15" +Flux = "0.16" ForwardDiff = "0.10.36" Functors = "0.4, 0.5" JET = "0.4 - 0.8, 0.9" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index c6ba08f27..ee0d5b168 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -44,31 +44,22 @@ end function DIT.flux_isapprox(a, b; atol, rtol) isapprox_results = fmapstructure_with_path(a, b) do kp, x, y - if :state in kp # ignore RNN and LSTM state + if x isa AbstractArray{<:Number} + return isapprox(x, y; atol, rtol) + else # ignore non-arrays return true - else - if x isa AbstractArray{<:Number} - return isapprox(x, y; atol, rtol) - else # ignore non-arrays - return true - end end end return all(fleaves(isapprox_results)) end -function square_loss(model, x) - y = model(x) - y = y isa Tuple ? y[1] : y # handle LSTM - return mean(abs2, y) -end +square_loss(model, x) = mean(abs2, model(x)) function square_loss_iterated(cell, x) - st = cell(x) # uses default initial state + y, st = cell(x) # uses default initial state for _ in 1:2 - st = cell(x, st) + y, st = cell(x, st) end - y = st isa Tuple ? st[1] : st # handle LSTM return mean(abs2, y) end @@ -158,6 +149,10 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2, 1) ), + ( + Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), + randn(rng, Float32, 3, 2, 1) + ), #! format: on ]