Please help me understand the cause of the error I get when running the DEQ example from the Julia blog post (Deep Equilibrium Models).
This code
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using OrdinaryDiffEq
#using CUDA
using Plots
using LinearAlgebra
#CUDA.allowscalar(false)

struct DeepEquilibriumNetwork{M,P,RE,A,K}
    model::M
    p::P
    re::RE
    args::A
    kwargs::K
end

Flux.@functor DeepEquilibriumNetwork

function DeepEquilibriumNetwork(model, args...; kwargs...)
    p, re = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, p, re, args, kwargs)
end

Flux.trainable(deq::DeepEquilibriumNetwork) = (deq.p,)

function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    z = deq.re(p)(x)
    # Solving the equation f(u) - u = du = 0
    # The key part of DEQ is similar to that of NeuralODEs
    dudt(u, _p, t) = deq.re(_p)(u .+ x) .- u
    ssprob = SteadyStateProblem(ODEProblem(dudt, z, (zero(T), one(T)), p))
    return solve(ssprob, deq.args...; u0 = z, deq.kwargs...).u
end

ann = Chain(Dense(1, 5), Dense(5, 1))
deq = DeepEquilibriumNetwork(ann, DynamicSS(Tsit5(), abstol = 1.0f-2, reltol = 1.0f-2))
# Let's run a DEQ model on linear regression for y = 2x
X = reshape(Float32[1; 2; 3; 4; 5; 6; 7; 8; 9; 10], 1, :)
Y = 2 .* X
opt = ADAM(0.05)
loss(x, y) = sum(abs2, y .- deq(x))
Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)
throws the following error on the line
Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)
TypeError: non-boolean (Nothing) used in boolean context
yadmtr changed the title from "TypeErro in DEQ exampler: non-boolean (Nothing) used in boolean context" to "TypeErro in DEQ example: non-boolean (Nothing) used in boolean context" on Jan 23, 2022
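For context on the message in the title: Julia raises this `TypeError` whenever a value that is not a `Bool` (here `nothing`) reaches an `if`, `&&`, `||`, or ternary condition. The snippet below is only a minimal sketch of that error class in plain Julia. The names `maybe_flag` and `describe` are hypothetical, and nothing here is taken from Flux, DiffEqSensitivity, or SteadyStateDiffEq, so it does not show where the `nothing` comes from in the DEQ example.

```julia
# Hypothetical helper: returns `nothing` instead of a Bool for some inputs.
maybe_flag(x) = x > 0 ? true : nothing

function describe(x)
    try
        # `if` requires a Bool; a `nothing` condition raises a TypeError.
        if maybe_flag(x)
            return "positive"
        end
        return "not positive"
    catch err
        # Only handle the TypeError; let anything else propagate.
        err isa TypeError || rethrow()
        return sprint(showerror, err)
    end
end

println(describe(1))
println(describe(-1))
```

In the second call the condition evaluates to `nothing`, so the `catch` branch returns the error text. In a real stack trace, the frame that raised the error usually points at the condition that received `nothing`.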
Operating System: Windows 10
Julia 1.6.5
VScode 1.63.2
Pkg.status
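
As a reference for what the forward pass in the code above is solving, the condition `f(u) - u = 0` from its comment can be sketched without any packages. This is a plain fixed-point iteration on a hypothetical scalar layer `f(v) = w*v + b`. It is not the `DynamicSS(Tsit5())` solver used in the example, and `fixed_point`, `w`, `b`, and `tol` are illustrative names rather than part of any library.

```julia
# Hypothetical scalar "layer": f(v) = w*v + b, contractive because |w| < 1.
w, b = 0.5, 1.0
f(v) = w * v + b

# Iterate u <- f(u + x) until successive iterates differ by less than `tol`,
# i.e. until the residual f(u + x) - u is approximately zero.
function fixed_point(f, x; u0 = f(x), tol = 1e-10, maxiter = 1_000)
    u = u0
    for _ in 1:maxiter
        unew = f(u + x)
        abs(unew - u) < tol && return unew
        u = unew
    end
    return u
end

x = 2.0
u_star = fixed_point(f, x)

# Residual of the steady-state condition, mirroring dudt in the DEQ code.
residual = f(u_star + x) - u_star

# For this linear layer the fixed point has a closed form to compare against.
closed_form = (w * x + b) / (1 - w)

println((u_star, residual, closed_form))
```

The starting guess `u0 = f(x)` mirrors `z = deq.re(p)(x)` in the example, and the residual plays the role of `dudt`. The iteration converges here only because the chosen `w` makes the map a contraction.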