diff --git a/docs/Project.toml b/docs/Project.toml
index be3a01ae..0556b9c0 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -2,6 +2,11 @@
 DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 DeepEquilibriumNetworks = "2"
diff --git a/docs/src/index.md b/docs/src/index.md
index ab3143b1..e9b0753c 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -16,8 +16,8 @@ Pkg.add("DeepEquilibriumNetworks")
 
 ## Quick-start
 
-```julia
-using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
+```@example
+using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
 # using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
 
 seed = 0
@@ -46,14 +46,11 @@ If you are using this project for research or other academic purposes, consider
 paper:
 
 ```bibtex
-@misc{pal2022mixing,
-    title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
-    ODEs (Continuous DEQs)},
-    author={Avik Pal and Alan Edelman and Christopher Rackauckas},
-    year={2022},
-    eprint={2201.12240},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+@inproceedings{pal2022continuous,
+    title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
+    author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
+    booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
+    year={2023}
 }
 ```
 
diff --git a/test/layers.jl b/test/layers.jl
index f1f4d991..0da64ee6 100644
--- a/test/layers.jl
+++ b/test/layers.jl
@@ -1,5 +1,5 @@
 using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
-    SciMLBase, Test, NLsolve
+    SciMLSensitivity, SciMLBase, Test, NLsolve
 
 include("test_utils.jl")
 
@@ -152,6 +152,11 @@
             @test maximum(abs, st.solution.residual) ≤ 1e-3
         end
 
+        _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+        @test __is_finite_gradient(gs_x)
+        @test __is_finite_gradient(gs_ps)
+
         ps, st = Lux.setup(rng, model)
         st = Lux.update_state(st, :fixed_depth, Val(10))
         @test st.solution == DeepEquilibriumSolution()
@@ -165,6 +170,11 @@
             @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
             @test st.solution isa DeepEquilibriumSolution
             @test st.solution.nfe == 10
+
+            _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+            @test __is_finite_gradient(gs_x)
+            @test __is_finite_gradient(gs_ps)
         end
     end
 end
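
For context on the dependency changes above: `SciMLSensitivity` is loaded alongside `Zygote` so that the adjoint rules for the DEQ's fixed-point solve are available during differentiation, which is what the new finite-gradient tests exercise. Below is a minimal sketch of that pattern, not part of the diff; the `DeepEquilibriumNetwork` constructor and the `NewtonRaphson` solver choice follow the package quick-start and are assumptions here, as is the `loss` helper.

```julia
# Sketch only: differentiate a small DEQ with Zygote; SciMLSensitivity must be
# loaded so the implicit (steady-state) adjoint is used on the backward pass.
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random, SciMLSensitivity, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

# Two-stage chain whose second stage is a DEQ solved with NewtonRaphson.
model = Chain(Dense(2 => 2),
    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
            Dense(2 => 2; use_bias=false)), NewtonRaphson()))

ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 4)

# Scalar loss over the model output; `first` drops the returned Lux state.
loss(x, ps) = sum(abs2, first(model(x, ps, st)))

gs_x, gs_ps = Zygote.gradient(loss, x, ps)  # both gradients should be finite
```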