TypeError in DEQ example: non-boolean (Nothing) used in boolean context #36

Closed
yadmtr opened this issue Jan 23, 2022 · 2 comments

yadmtr commented Jan 23, 2022

Please help me understand the cause of the error I get when running the DEQ example from the Julia blog post (Deep Equilibrium Models).

This code:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using OrdinaryDiffEq
#using CUDA
using Plots
using LinearAlgebra
#CUDA.allowscalar(false)

struct DeepEquilibriumNetwork{M,P,RE,A,K}
    model::M
    p::P
    re::RE
    args::A
    kwargs::K
end

Flux.@functor DeepEquilibriumNetwork

function DeepEquilibriumNetwork(model, args...; kwargs...)
    p, re = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, p, re, args, kwargs)
end

Flux.trainable(deq::DeepEquilibriumNetwork) = (deq.p,)

function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    z = deq.re(p)(x)
    # Solving the equation f(u) - u = du = 0
    # The key part of DEQ is similar to that of NeuralODEs
    dudt(u, _p, t) = deq.re(_p)(u .+ x) .- u
    ssprob = SteadyStateProblem(ODEProblem(dudt, z, (zero(T), one(T)), p))
    return solve(ssprob, deq.args...; u0 = z, deq.kwargs...).u
end

ann = Chain(Dense(1, 5), Dense(5, 1))

deq = DeepEquilibriumNetwork(ann, DynamicSS(Tsit5(), abstol = 1.0f-2, reltol = 1.0f-2))

# Let's run a DEQ model on linear regression for y = 2x
X = reshape(Float32[1; 2; 3; 4; 5; 6; 7; 8; 9; 10], 1, :) 
Y = 2 .* X
opt = ADAM(0.05)

loss(x, y) = sum(abs2, y .- deq(x))

Flux.train!(loss, Flux.params(deq), ((X, Y),), opt)

throws the following error on the line Flux.train!(loss, Flux.params(deq), ((X, Y),), opt):

ERROR: LoadError: TypeError: non-boolean (Nothing) used in boolean context
Stacktrace:
  [1] _concrete_solve_adjoint(::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:92
  [2] _concrete_solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqSensitivity\Kg0cc\src\concrete_solve.jl:72 [inlined]
  [3] #_solve_adjoint#56
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:347 [inlined]
  [4] _solve_adjoint
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:322 [inlined]
  [5] #rrule#54
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [6] rrule
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:310 [inlined]
  [7] rrule
    @ C:\Users\D\.julia\packages\ChainRulesCore\oBjCg\src\rules.jl:134 [inlined]
  [8] chain_rrule
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\chainrules.jl:216 [inlined]
  [9] macro expansion
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(DiffEqBase.solve_up), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Nothing, ::Matrix{Float32}, ::Vector{Float32}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:9
 [11] _apply
    @ .\boot.jl:804 [inlined]
 [12] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [13] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [14] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:73 [inlined]
 [15] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#38", ::Nothing, ::Matrix{Float32}, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [16] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [17] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [18] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [19] _pullback
    @ C:\Users\D\.julia\packages\DiffEqBase\0PaUK\src\solve.jl:68 [inlined]
 [20] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:u0,), Tuple{Matrix{Float32}}}, ::typeof(solve), ::SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt#10"{DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [21] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [22] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [23] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [24] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:33 [inlined]
 [25] _pullback(::Zygote.Context, ::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [26] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:28 [inlined]
 [27] _pullback(ctx::Zygote.Context, f::DeepEquilibriumNetwork{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, Tuple{DynamicSS{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Float32, Float32, Float64}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, args::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [28] _pullback
    @ c:\Users\D\w7d\test_flux_e[ample.jl:45 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(loss), ::Matrix{Float32}, ::Matrix{Float32})
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [30] _apply
    @ .\boot.jl:804 [inlined]
 [31] adjoint
    @ C:\Users\D\.julia\packages\Zygote\umM0L\src\lib\lib.jl:200 [inlined]
 [32] _pullback
    @ C:\Users\D\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [33] _pullback
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:105 [inlined]
 [34] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{Matrix{Float32}, Matrix{Float32}}})   
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface2.jl:0
 [35] pullback(f::Function, ps::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:352
 [36] gradient(f::Function, args::Zygote.Params)
    @ Zygote C:\Users\D\.julia\packages\Zygote\umM0L\src\compiler\interface.jl:75
 [37] macro expansion
    @ C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:104 [inlined]
 [38] macro expansion
    @ C:\Users\D\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [39] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:102
 [40] train!(loss::Function, ps::Zygote.Params, data::Tuple{Tuple{Matrix{Float32}, Matrix{Float32}}}, opt::ADAM)
    @ Flux.Optimise C:\Users\D\.julia\packages\Flux\BPPNj\src\optimise\train.jl:100
 [41] top-level scope
    @ c:\Users\D\w7d\test_flux_e[ample.jl:47
in expression starting at c:\Users\D\w7d\test_flux_e[ample.jl:47

Operating system: Windows 10
Julia: 1.6.5
VS Code: 1.63.2
Pkg.status output:

  [052768ef] CUDA v3.6.4
  [31a5f54b] Debugger v0.7.0
  [2b5f629d] DiffEqBase v6.81.0
  [41bf760c] DiffEqSensitivity v6.68.0
  [587475ba] Flux v0.12.8
  [5903a43b] Infiltrator v1.1.2
  [98e50ef6] JuliaFormatter v0.21.2
  [aa1ae85d] JuliaInterpreter v0.9.1
  [eb30cadb] MLDatasets v0.5.14
  [2774e3e8] NLsolve v4.5.1
  [1dea7af3] OrdinaryDiffEq v6.4.2
  [91a5bcdd] Plots v1.25.6
  [ee283ea6] Rebugger v0.2.2
  [9672c7b4] SteadyStateDiffEq v1.6.6
  [c3572dad] Sundials v4.9.1
  [e88e6eb3] Zygote v0.6.33
  [37e2e46d] LinearAlgebra
  [8dfed614] Test
ChrisRackauckas (Member) commented

@avik-pal

avik-pal self-assigned this Jan 23, 2022

ChrisRackauckas (Member) commented

This works now, and for DEQs we have a whole library: https://github.com/SciML/FastDEQ.jl

avik-pal transferred this issue from SciML/DiffEqFlux.jl Feb 2, 2022