From ff355284fadf7f1d3ccef41711b87e9305b30d0e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 31 Dec 2024 18:32:44 +0100 Subject: [PATCH] update flux v0.16 --- DifferentiationInterfaceTest/Project.toml | 2 +- .../DifferentiationInterfaceTestFluxExt.jl | 25 ++++++++----------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 77f5a5116..c5f84e22e 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -47,7 +47,7 @@ DifferentiationInterface = "0.6.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" -Flux = "0.15" +Flux = "0.16" ForwardDiff = "0.10.36" Functors = "0.4, 0.5" JET = "0.4 - 0.8, 0.9" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index c6ba08f27..ee0d5b168 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -44,31 +44,22 @@ end function DIT.flux_isapprox(a, b; atol, rtol) isapprox_results = fmapstructure_with_path(a, b) do kp, x, y - if :state in kp # ignore RNN and LSTM state + if x isa AbstractArray{<:Number} + return isapprox(x, y; atol, rtol) + else # ignore non-arrays return true - else - if x isa AbstractArray{<:Number} - return isapprox(x, y; atol, rtol) - else # ignore non-arrays - return true - end end end return all(fleaves(isapprox_results)) end -function square_loss(model, x) - y = model(x) - y = y isa Tuple ? y[1] : y # handle LSTM - return mean(abs2, y) -end +square_loss(model, x) = mean(abs2, model(x)) function square_loss_iterated(cell, x) - st = cell(x) # uses default initial state + y, st = cell(x) # uses default initial state for _ in 1:2 - st = cell(x, st) + y, st = cell(x, st) end - y = st isa Tuple ? st[1] : st # handle LSTM return mean(abs2, y) end @@ -158,6 +149,10 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2, 1) ), + ( + Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), + randn(rng, Float32, 3, 2, 1) + ), #! format: on ]