Commit 1e76987: Final few tests

avik-pal committed Dec 18, 2023 (1 parent: d4f6903)
Showing 4 changed files with 29 additions and 21 deletions.

docs/Project.toml (6 additions, 0 deletions)

@@ -2,6 +2,12 @@
 DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 DeepEquilibriumNetworks = "2"

docs/src/index.md (7 additions, 10 deletions)

@@ -16,8 +16,8 @@ Pkg.add("DeepEquilibriumNetworks")

 ## Quick-start

-```julia
-using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
+```@example
+using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
 # using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
 seed = 0

@@ -46,14 +46,11 @@ If you are using this project for research or other academic purposes, consider
 paper:

 ```bibtex
-@misc{pal2022mixing,
-    title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
-        ODEs (Continuous DEQs)},
-    author={Avik Pal and Alan Edelman and Christopher Rackauckas},
-    year={2022},
-    eprint={2201.12240},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+@article{pal2022continuous,
+    title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
+    author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
+    booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
+    year={2023}
 }
 ```
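Note (not part of the diff): the `@example` block these lines begin is collapsed after `seed = 0`. For orientation only, here is a minimal sketch of how such a quick-start typically continues; the `Chain` layout, layer sizes, solver choice (`NewtonRaphson()`), and input shape below are assumptions, not the literal file contents.

```julia
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)

# A DEQ layer outputs the fixed point z* of z = f(z, x); here NewtonRaphson()
# from NonlinearSolve serves as the fixed-point solver.
model = Chain(Dense(2 => 2),
    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
            Dense(2 => 2; use_bias=false)), NewtonRaphson()),
    Dense(2 => 2))

ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 1)

y, st_ = model(x, ps, st)                                   # forward pass
gs = Zygote.gradient(p -> sum(first(model(x, p, st))), ps)  # reverse-mode AD through the implicit solve
```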

src/layers.jl (5 additions, 10 deletions)

@@ -27,13 +27,8 @@ function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jaco
     sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
     ∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
     function ∇DeepEquilibriumSolution(∂sol)
-        ∂z_star = ∂sol.z_star
-        ∂u0 = ∂sol.u0
-        ∂residual = ∂sol.residual
-        ∂jacobian_loss = ∂sol.jacobian_loss
-        ∂nfe = ∂sol.nfe
-        ∂original = CRC.NoTangent()
-        return (CRC.NoTangent(), ∂z_star, ∂u0, ∂residual, ∂jacobian_loss, ∂nfe, ∂original)
+        return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, ∂sol.jacobian_loss,
+            ∂sol.nfe, CRC.NoTangent())
     end
     return sol, ∇DeepEquilibriumSolution
 end
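Note (not part of the diff): the condensed pullback above follows a standard ChainRulesCore pattern (`CRC` appears to be an alias for ChainRulesCore), in which an `rrule` defined on a constructor forwards the fields of the structural tangent straight back to the constructor's positional arguments. A minimal, self-contained sketch of the same pattern on a toy struct:

```julia
# Illustrative sketch only; ToySolution is a stand-in, not a type from the package.
using ChainRulesCore

struct ToySolution{T}
    z_star::T
    nfe::Int
end

function ChainRulesCore.rrule(::Type{<:ToySolution}, z_star, nfe)
    sol = ToySolution(z_star, nfe)
    # One cotangent per rrule argument: the type itself plus the two constructor arguments.
    ∇ToySolution(::NoTangent) = (NoTangent(), NoTangent(), NoTangent())
    # A structural tangent carries one field per constructor argument, so the
    # pullback forwards those fields directly instead of unpacking them first.
    ∇ToySolution(∂sol) = (NoTangent(), ∂sol.z_star, ∂sol.nfe)
    return sol, ∇ToySolution
end
```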

@@ -149,11 +144,11 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
 ## Example
-```julia
+```@example
 using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
 model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
-        Dense(2, 2; use_bias=false)), VCABM3(); save_everystep=true)
+        Dense(2, 2; use_bias=false)), VCABM3())
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, model)

@@ -218,7 +213,7 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
 ## Example
-```julia
+```@example
 using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
 main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),
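Both docstring examples above are cut off by the collapsed diff. For context (not part of the commit), a hedged sketch of how the first example would typically finish, assuming 2-dimensional inputs (the `Dense` layers are 2 => 2) and the usual Lux calling convention:

```julia
x = rand(rng, Float32, 2, 3)   # 2 features, batch of 3
z_star, st = model(x, ps, st)  # Lux convention: (output, updated state)
st.solution.nfe                # solver statistics recorded in the layer state
```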

test/layers.jl (11 additions, 1 deletion)

@@ -1,5 +1,5 @@
 using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
-    SciMLBase, Test, NLsolve
+    SciMLSensitivity, SciMLBase, Test, NLsolve

 include("test_utils.jl")

@@ -152,6 +152,11 @@ end
 @test maximum(abs, st.solution.residual) ≤ 1e-3
 end

+_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+@test __is_finite_gradient(gs_x)
+@test __is_finite_gradient(gs_ps)
+
 ps, st = Lux.setup(rng, model)
 st = Lux.update_state(st, :fixed_depth, Val(10))
 @test st.solution == DeepEquilibriumSolution()

@@ -165,6 +170,11 @@ end
 @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
 @test st.solution isa DeepEquilibriumSolution
 @test st.solution.nfe == 10
+
+_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+@test __is_finite_gradient(gs_x)
+@test __is_finite_gradient(gs_ps)
 end
 end
 end
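`loss_function` and `__is_finite_gradient` are defined in test code this diff does not show. Purely as a hypothetical reconstruction (the real definitions in test_utils.jl may differ), a finiteness check over Zygote's nested gradient structures and a loss closure matching the call sites above could look like:

```julia
# Hypothetical reconstruction; not taken from test_utils.jl.
__is_finite_gradient(::Nothing) = true                    # "no gradient" counts as finite
__is_finite_gradient(x::Number) = isfinite(x)
__is_finite_gradient(x::AbstractArray) = all(isfinite, x)
__is_finite_gradient(gs::Union{Tuple, NamedTuple}) = all(__is_finite_gradient, values(gs))

# Same 4-argument signature as the Zygote.gradient call sites above (also an assumption).
function loss_function(model, x, ps, st)
    z, _ = model(x, ps, st)
    return sum(abs2, z)
end
```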
