From 1e76987668f02b1276c67f2c2cfd9d5aeff0d374 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 17 Dec 2023 21:43:15 -0500
Subject: [PATCH] Final few tests

---
 docs/Project.toml |  6 ++++++
 docs/src/index.md | 17 +++++++----------
 src/layers.jl     | 15 +++++----------
 test/layers.jl    | 12 +++++++++++-
 4 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index be3a01ae..29f987d2 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -2,6 +2,12 @@
 DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 DeepEquilibriumNetworks = "2"
diff --git a/docs/src/index.md b/docs/src/index.md
index ab3143b1..e9b0753c 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -16,8 +16,8 @@ Pkg.add("DeepEquilibriumNetworks")
 
 ## Quick-start
 
-```julia
-using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
+```@example
+using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
 # using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
 
 seed = 0
@@ -46,14 +46,11 @@ If you are using this project for research or other academic purposes, consider
 paper:
 
 ```bibtex
-@misc{pal2022mixing,
-    title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
-        ODEs (Continuous DEQs)},
-    author={Avik Pal and Alan Edelman and Christopher Rackauckas},
-    year={2022},
-    eprint={2201.12240},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+@inproceedings{pal2022continuous,
+  title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
+  author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
+  booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
+  year={2023}
 }
 ```
 
diff --git a/src/layers.jl b/src/layers.jl
index 467e91eb..74467e73 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -27,13 +27,8 @@ function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jaco
     sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
     ∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
     function ∇DeepEquilibriumSolution(∂sol)
-        ∂z_star = ∂sol.z_star
-        ∂u0 = ∂sol.u0
-        ∂residual = ∂sol.residual
-        ∂jacobian_loss = ∂sol.jacobian_loss
-        ∂nfe = ∂sol.nfe
-        ∂original = CRC.NoTangent()
-        return (CRC.NoTangent(), ∂z_star, ∂u0, ∂residual, ∂jacobian_loss, ∂nfe, ∂original)
+        return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, ∂sol.jacobian_loss,
+            ∂sol.nfe, CRC.NoTangent())
     end
     return sol, ∇DeepEquilibriumSolution
 end
@@ -149,11 +144,11 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
 
 ## Example
 
-```julia
+```@example
 using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
 
 model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
-    Dense(2, 2; use_bias=false)), VCABM3(); save_everystep=true)
+    Dense(2, 2; use_bias=false)), VCABM3())
 
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, model)
@@ -218,7 +213,7 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
 
 ## Example
 
-```julia
+```@example
 using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
 
 main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),
diff --git a/test/layers.jl b/test/layers.jl
index f1f4d991..0da64ee6 100644
--- a/test/layers.jl
+++ b/test/layers.jl
@@ -1,5 +1,5 @@
 using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
-    SciMLBase, Test, NLsolve
+    SciMLSensitivity, SciMLBase, Test, NLsolve
 
 include("test_utils.jl")
 
@@ -152,6 +152,11 @@ end
             @test maximum(abs, st.solution.residual) ≤ 1e-3
         end
 
+        _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+        @test __is_finite_gradient(gs_x)
+        @test __is_finite_gradient(gs_ps)
+
         ps, st = Lux.setup(rng, model)
         st = Lux.update_state(st, :fixed_depth, Val(10))
         @test st.solution == DeepEquilibriumSolution()
@@ -165,6 +170,11 @@ end
             @test size(z_) == (sum(prod, scale), size(x, ndims(x)))
             @test st.solution isa DeepEquilibriumSolution
             @test st.solution.nfe == 10
+
+            _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+            @test __is_finite_gradient(gs_x)
+            @test __is_finite_gradient(gs_ps)
         end
     end
 end