Skip to content

Commit

Permalink
adapt to Flux v0.16 (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Jan 2, 2025
1 parent 62ec930 commit 9df2763
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 44 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Down/Flux/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ LOGGING = get(ENV, "CI", "false") == "false"
test_differentiation(
[
AutoZygote(),
# AutoEnzyme() # TODO: fix
# AutoEnzyme(), # TODO a few scenarios fail
],
DIT.flux_scenarios(Random.MersenneTwister(0));
isapprox=DIT.flux_isapprox,
Expand Down
4 changes: 3 additions & 1 deletion DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
Expand Down Expand Up @@ -46,7 +47,7 @@ DifferentiationInterface = "0.6.0"
DocStringExtensions = "0.8,0.9"
ExplicitImports = "1.10.1"
FiniteDifferences = "0.12"
Flux = "0.13,0.14"
Flux = "0.16"
ForwardDiff = "0.10.36"
Functors = "0.4, 0.5"
JET = "0.4 - 0.8, 0.9"
Expand All @@ -61,6 +62,7 @@ SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
SparseMatrixColorings = "0.4.9"
StaticArrays = "1.9"
Statistics = "1"
Test = "<0.0.1,1"
Zygote = "0.6"
julia = "1.10"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ using Flux:
ConvTranspose,
Dense,
GRU,
GRUCell,
LSTM,
LSTMCell,
Maxout,
MeanPool,
RNN,
RNNCell,
SamePad,
Scale,
SkipConnection,
Expand All @@ -24,6 +27,7 @@ using Flux:
relu
using Functors: @functor, fmapstructure_with_path, fleaves
using LinearAlgebra
using Statistics: mean
using Random: AbstractRNG, default_rng

#=
Expand All @@ -43,31 +47,23 @@ end

function DIT.flux_isapprox(a, b; atol, rtol)
isapprox_results = fmapstructure_with_path(a, b) do kp, x, y
if :state in kp # ignore RNN and LSTM state
if x isa AbstractArray{<:Number}
return isapprox(x, y; atol, rtol)
else # ignore non-arrays
return true
else
if x isa AbstractArray{<:Number}
return isapprox(x, y; atol, rtol)
else # ignore non-arrays
return true
end
end
end
return all(fleaves(isapprox_results))
end

function square_loss(model, x)
Flux.reset!(model)
return sum(abs2, model(x))
end
square_loss(model, x) = mean(abs2, model(x))

function square_loss_iterated(model, x)
Flux.reset!(model)
y = copy(x)
for _ in 1:3
y = model(y)
function square_loss_iterated(cell, x)
y, st = cell(x) # uses default initial state
for _ in 1:2
y, st = cell(x, st)
end
return sum(abs2, y)
return mean(abs2, y)
end

struct SimpleDense{W,B,F}
Expand Down Expand Up @@ -132,37 +128,33 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
Maxout(() -> Dense(5 => 4, tanh; init), 3),
randn(rng, Float32, 5, 1)
),
(
RNN(3 => 2; init),
randn(rng, Float32, 3, 2)
),
(
Chain(RNN(3 => 4; init), RNN(4 => 3; init)),
randn(rng, Float32, 3, 2)
(
SkipConnection(Dense(2 => 2; init), vcat),
randn(rng, Float32, 2, 3)
),
(
LSTM(3 => 5; init),
randn(rng, Float32, 3, 2)
Bilinear((2, 2) => 3; init),
randn(rng, Float32, 2, 1)
),
(
Chain(LSTM(3 => 5; init), LSTM(5 => 3; init)),
randn(rng, Float32, 3, 2)
ConvTranspose((3, 3), 3 => 2; stride=2, init),
rand(rng, Float32, 5, 5, 3, 1)
),
(
SkipConnection(Dense(2 => 2; init), vcat),
randn(rng, Float32, 2, 3)
RNN(3 => 4; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2, 1)
),
(
Bilinear((2, 2) => 3; init),
randn(rng, Float32, 2, 1)
LSTM(3 => 4; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2, 1)
),
(
GRU(3 => 5; init),
randn(rng, Float32, 3, 10)
GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2, 1)
),
(
ConvTranspose((3, 3), 3 => 2; stride=2, init),
rand(rng, Float32, 5, 5, 3, 1)
Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)),
randn(rng, Float32, 3, 2, 1)
),
#! format: on
]
Expand All @@ -176,29 +168,32 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
push!(scens, scen)
end

# Recurrence
# Recurrent Cells

recurrent_models_and_xs = [
#! format: off
(
RNN(3 => 3; init),
RNNCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2)
),
(
LSTMCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2)
),
(
LSTM(3 => 3; init),
GRUCell(3 => 3; init_kernel=init, init_recurrent_kernel=init),
randn(rng, Float32, 3, 2)
),
#! format: on
]

for (model, x) in recurrent_models_and_xs
Flux.trainmode!(model)
g = gradient_finite_differences(square_loss, model, x)
g = gradient_finite_differences(square_loss_iterated, model, x)
scen = DIT.Scenario{:gradient,:out}(
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
)
# TODO: figure out why these tests are broken
# push!(scens, scen)
push!(scens, scen)
end

return scens
Expand Down

0 comments on commit 9df2763

Please sign in to comment.