Skip to content

Commit

Permalink
Incorrect trace
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent e773e63 commit ebf1645
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ version = "6.144.0"

[[deps.DiffEqCallbacks]]
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"]
git-tree-sha1 = "d0b94b3694d55e7eedeee918e7daee9e3b873399"
git-tree-sha1 = "e48b985459d1cbe8c809de192529f1e25c3382a6"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
version = "2.35.0"
version = "2.36.0"

[deps.DiffEqCallbacks.weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)

# Jacobian Stabilization
function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
__f = u -> first(model((u, x), ps))
__f = u -> model((u, x), ps)
res = zero(eltype(x))
ϵ = cbrt(eps(typeof(res)))
ϵ⁻¹ = inv(ϵ)
Expand Down
9 changes: 5 additions & 4 deletions test/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end

model_type = (:deq, :skipdeq, :skipregdeq)
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
jacobian_regularizations = Any[nothing, AutoZygote()]
!ongpu && push!(jacobian_regularizations, AutoFiniteDiff())

@testset "Solver: $(__nameof(solver))" for solver in solvers,
mtype in model_type, jacobian_regularization in jacobian_regularizations
Expand Down Expand Up @@ -133,10 +134,10 @@ end
jacobian_regularization)
end

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
@test st.solution == DeepEquilibriumSolution()

x = randn(rng, Float32, x_size...)
x = randn(rng, Float32, x_size...) |> dev
z, st = model(x, ps, st)
z_ = DEQs.__flatten_vcat(z)

Expand All @@ -157,7 +158,7 @@ end
@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model)
ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
@test st.solution == DeepEquilibriumSolution()

Expand Down
2 changes: 2 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
import LuxTestUtils: @jet
using LuxCUDA

CUDA.allowscalar(false)

__nameof(::X) where {X} = nameof(X)

__get_prng(seed::Int) = StableRNG(seed)
Expand Down

0 comments on commit ebf1645

Please sign in to comment.