From f8ea1a68e253b5abb4169df3e0f3c9a700159656 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Jun 2024 20:00:35 -0700 Subject: [PATCH] Update to new Lux API --- Project.toml | 4 ++-- docs/src/tutorials/reduced_dim_deq.md | 2 +- src/layers.jl | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 450b5c66..87502fb8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "2.1.1" +version = "2.1.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -41,7 +41,7 @@ FastClosures = "0.3" ForwardDiff = "0.10.36" Functors = "0.4.10" LinearSolve = "2.21.2" -Lux = "0.5.50" +Lux = "0.5.56" LuxCUDA = "0.3.2" LuxCore = "0.1.14" LuxTestUtils = "0.1.15" diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 9f72ac69..1e01b3a5 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -57,7 +57,7 @@ function construct_model(solver; model_type::Symbol=:regdeq) else # This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here # we are only using Zygote so this is fine. - init = WrappedFunction(x -> Zygote.@ignore(fill!( + init = WrappedFunction{:direct_call}(x -> Zygote.@ignore(fill!( similar(x, 128, size(x, 2)), false))) end diff --git a/src/layers.jl b/src/layers.jl index 4665f492..35d490f9 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -142,7 +142,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] - `init`: Initial Condition for the rootfinding problem. If `nothing`, the initial condition is set to `zero(x)`. If `missing`, the initial condition is set to - `WrappedFunction(zero)`. In other cases the initial condition is set to + `WrappedFunction{:direct_call}(zero)`. In other cases the initial condition is set to `init(x, ps, st)`. - `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff` or `AutoZygote`. @@ -160,8 +160,8 @@ julia> model = DeepEquilibriumNetwork( DeepEquilibriumNetwork( model = Parallel( + - Dense(2 => 2, bias=false), # 4 parameters - Dense(2 => 2, bias=false), # 4 parameters + layer_1 = Dense(2 => 2, bias=false), # 4 parameters + layer_2 = Dense(2 => 2, bias=false), # 4 parameters ), init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Nothing}(DeepEquilibriumNetworks.__zeros_init, nothing)), ) # Total: 8 parameters, @@ -184,7 +184,8 @@ function DeepEquilibriumNetwork( model isa AbstractExplicitLayer || (model = Lux.transform(model)) if init === missing # Regular DEQ - init = WrappedFunction(Base.Fix1(__zeros_init, __getproperty(model, Val(:scales)))) + init = WrappedFunction{:direct_call}(Base.Fix1( + __zeros_init, __getproperty(model, Val(:scales)))) elseif init === nothing # SkipRegDEQ init = NoOpLayer() elseif !(init isa AbstractExplicitLayer)