Skip to content

Commit

Permalink
Final few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent d4f6903 commit 461aa79
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
5 changes: 5 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
DeepEquilibriumNetworks = "2"
Expand Down
17 changes: 7 additions & 10 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ Pkg.add("DeepEquilibriumNetworks")

## Quick-start

```julia
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
```@example
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
seed = 0
Expand Down Expand Up @@ -46,14 +46,11 @@ If you are using this project for research or other academic purposes, consider
paper:

```bibtex
@misc{pal2022mixing,
title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
ODEs (Continuous DEQs)},
author={Avik Pal and Alan Edelman and Christopher Rackauckas},
year={2022},
eprint={2201.12240},
archivePrefix={arXiv},
primaryClass={cs.LG}
@article{pal2022continuous,
title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
year={2023}
}
```

Expand Down
12 changes: 11 additions & 1 deletion test/layers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
SciMLBase, Test, NLsolve
SciMLSensitivity, SciMLBase, Test, NLsolve

include("test_utils.jl")

Expand Down Expand Up @@ -152,6 +152,11 @@ end
@test maximum(abs, st.solution.residual) 1e-3
end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model)
st = Lux.update_state(st, :fixed_depth, Val(10))
@test st.solution == DeepEquilibriumSolution()
Expand All @@ -165,6 +170,11 @@ end
@test size(z_) == (sum(prod, scale), size(x, ndims(x)))
@test st.solution isa DeepEquilibriumSolution
@test st.solution.nfe == 10

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
end
end
end
Expand Down

0 comments on commit 461aa79

Please sign in to comment.