using DeepEquilibriumNetworks, Lux, Random, Zygote
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)

# Dense layer followed by a DEQ layer with a continuous solver
model = Chain(Dense(2 => 2),
    DeepEquilibriumNetwork(Parallel(+,
            Dense(2 => 2; use_bias=false),
            Dense(2 => 2; use_bias=false)),
        ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
            reltol_termination=0.1f0);
        save_everystep=true))

gdev = gpu_device()
cdev = cpu_device()

# Move parameters, states, and data to the GPU device
ps, st = Lux.setup(rng, model) |> gdev
x = rand(rng, Float32, 2, 100) |> gdev
y = rand(rng, Float32, 2, 100) |> gdev

# Gradient of the squared-error loss with respect to the parameters
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
This gives the following error and warnings on 1.9 (captured with a try/catch log, because the original error flooded the REPL and could not be scrolled back to 😅):
┌ Warning: Automatic AD choice of autojacvec failed in ODE adjoint, failing back to ODE adjoint + numerical vjp
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/U8Axh/src/sensitivity_interface.jl:381
┌ Warning: AD choice of autojacvec failed in nonlinear solve adjoint
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/U8Axh/src/steadystate_adjoint.jl:112
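For anyone who hits the same REPL flooding, here is a minimal sketch of the try/catch logging approach mentioned above. The helper name `log_error_to_file` and the log path are hypothetical and not part of any package. The `error(...)` call is only a stand-in for the failing `Zygote.gradient` call.

```julia
# Hypothetical helper: run `f`; if it throws, write the exception and its
# backtrace to `path` instead of printing them to the REPL.
function log_error_to_file(f, path::AbstractString)
    try
        return f()
    catch err
        bt = catch_backtrace()  # backtrace of the exception being handled
        open(path, "w") do io
            showerror(io, err, bt)
        end
        return nothing
    end
end

logfile = joinpath(mktempdir(), "error.log")

# Stand-in for the failing call; replace the body with the gradient call.
result = log_error_to_file(logfile) do
    error("simulated failure")
end
```

The full message can then be read from the file in an editor or with `read(logfile, String)`, which avoids losing it in the terminal scrollback.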