Multiple ODE solves from a single time series on GPU #279

Closed · mkalia94 opened this issue Jun 3, 2020 · 13 comments · Fixed by #374

@mkalia94 commented Jun 3, 2020

I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series.

Say we have a time series y of dimension 2×100, and I'd like to solve the ODE over small time intervals [0,10] using the initial conditions y[:,1:10:end]. This works fine on the CPU using hcat([Array(solve(...))]...); on the GPU, however, I get the error:

ERROR: LoadError: CuArray only supports bits types
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] CuArrays.CuArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}},1,P} where P(::UndefInitializer, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:106
 [3] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:139
 [4] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Int64) at ./abstractarray.jl:628
 [5] similar(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/tracked.jl:325
 [6] Zygote.Buffer(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/Zygote/YeCEW/src/tools/buffer.jl:42
 [7] lotka_volterra(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:15
[8] (::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::Vararg{Any,N} where N) at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/diffeqfunction.jl:248
 [9] (::DiffEqSensitivity.var"#67#74"{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:113
 [10] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:207
 [11] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:204
 [12] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}; quad::Bool, noiseterm::Bool) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:111
 [13] adjointdiffcache at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:26 [inlined]
 [14] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::Array{Float32,1}, ::Nothing, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:37
 [15] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; checkpoints::Array{Float32,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:115
 [16] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:17
 [17] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:reltol,),Tuple{Float64}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:6
 [18] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/concrete_solve.jl:107
 [19] #512#back at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [20] #174 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
 [21] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{DiffEqBase.var"#512#back#457"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [22] #solve#443 at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/solve.jl:69 [inlined]
 [23] (::typeof((#solve#443)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#174#175"{typeof((#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
 [25] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof((#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [26] (::typeof((solve##kw)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [27] predict_ODE_solve at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:47 [inlined]
 [28] (::typeof((predict_ODE_solve)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [29] #41 at ./none:0 [inlined]
 [30] (::typeof((λ)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [31] #1187 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172 [inlined]
 [32] #3 at ./generator.jl:36 [inlined]
 [33] iterate at ./generator.jl:47 [inlined]
 [34] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((λ)),1},NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}}},Base.var"#3#4"{Zygote.var"#1187#1191"}}) at ./array.jl:665
 [35] map at ./abstractarray.jl:2154 [inlined]
 [36] (::Zygote.var"#1186#1190"{Array{typeof((λ)),1}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172
 [37] (::Zygote.var"#1194#1195"{Zygote.var"#1186#1190"{Array{typeof((λ)),1}}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:187
 [38] loss_func at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:54 [inlined]
 [39] (::typeof((loss_func)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [40] #16 at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:85 [inlined]
 [41] (::typeof((λ)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#49#50"{Params,Zygote.Context,typeof((λ))})(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [43] gradient(::Function, ::Params) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [44] macro expansion at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:84 [inlined]
 [45] macro expansion at /home/manu/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [46] train!(::typeof(loss_func), ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [47] train!(::Function, ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM) at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
 [48] top-level scope at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:69
 [49] include(::String) at ./client.jl:439
in expression starting at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:66

Here is the code:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters
nBatchsize = 10
nEpochs = 10
tspan = 5.0
tsize = 100
nbatches = div(tsize,nBatchsize)

# ODE solve
function lotka_volterra(du,u,p,t)
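    # Zygote.Buffer is a write-once scratch array that Zygote can differentiate
    # through, avoiding the usual "mutating arrays" error in the reverse pass.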
    dx = Zygote.Buffer(u,size(u)[1])
    dx[1] = u[1]*(p[1]-p[2]*u[2])
    dx[2] = u[2]*(p[3]*u[1]-p[4])
    du .= copy(dx)
    nothing
end

# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4] 
u0 = Float32[0.01, 0.01]

t = range(0.0,tspan,length=tsize)

# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data

data = Float32.(yy) |> gpu

# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu

# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0,(0.0f0,Float32(tspan/nbatches)),p)

# ODE solve to be used for training
function predict_ODE_solve(x)
    return Array(solve(prob2,Tsit5(),u0=x,saveat=t_batch,reltol=1e-4)) 
end
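
args_ = Dict() # logging store read by loss_func; assumed here (it is missing from this snippet, but the later snippet defines it)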

function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

opt = ADAM(0.001)

loss_func(data) # This works

for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

And here is the current status of my packages:

  [c7e460c6] ArgParse v1.1.0
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [c5f51814] CUDAdrv v6.3.0
  [be33ccc6] CUDAnative v3.1.0
  [3a865a2d] CuArrays v2.2.0
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.12.0
  [41bf760c] DiffEqSensitivity v6.19.0
  [0c46a032] DifferentialEquations v6.14.0
  [31c24e10] Distributions v0.23.3
  [5789e2e9] FileIO v1.3.0
  [587475ba] Flux v0.10.4
  [0c68f7d7] GPUArrays v3.4.1
  [033835bb] JLD2 v0.1.13
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.39.1
  [91a5bcdd] Plots v1.3.6
  [8d666b04] PolyChaos v0.2.1
  [ee283ea6] Rebugger v0.3.3
  [295af30f] Revise v2.7.1
  [9f7883ad] Tracker v0.2.6
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)

@ChrisRackauckas (Member) commented Jun 5, 2020

Hey,
This is fundamentally the wrong approach. What you're looking for is https://github.com/SciML/DiffEqGPU.jl. DiffEqGPU.jl takes small ODE problems and solves them simultaneously on a GPU for different parameters and initial conditions. You'll want to mix this with forward-mode AD (i.e., use sensealg=ForwardDiffSensitivity()), which will be more efficient for this number of parameters per ODE.
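
For concreteness, a minimal sketch of that direction (a sketch only: the EnsembleGPUArray usage pattern and the rhs!/u0s names are illustrative, not taken from this thread):

using OrdinaryDiffEq, DiffEqGPU, DiffEqSensitivity

# Illustrative in-place RHS; it must be GPU-compatible for EnsembleGPUArray.
function rhs!(du, u, p, t)
    du[1] = u[1] * (p[1] - p[2] * u[2])
    du[2] = u[2] * (p[3] * u[1] - p[4])
    nothing
end

p = Float32[2.2, 1.0, 2.0, 0.4]
u0s = [rand(Float32, 2) for _ in 1:10]  # one initial condition per batch
prob = ODEProblem(rhs!, u0s[1], (0.0f0, 0.5f0), p)
prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i])
monteprob = EnsembleProblem(prob, prob_func = prob_func)

# Solve all trajectories simultaneously on the GPU; ForwardDiffSensitivity is
# the forward-mode sensealg suggested above for few parameters per ODE.
sol = solve(monteprob, Tsit5(), EnsembleGPUArray(), trajectories = length(u0s),
            saveat = 0.1f0, sensealg = ForwardDiffSensitivity())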

@mkalia94 (Author) commented Jun 8, 2020

Ah, I see. I tried DiffEqGPU.jl to no avail; I opened a separate issue there regarding GPU usage: SciML/DiffEqGPU.jl#57.
I get the following error when I try to use EnsembleThreads(), with or without ForwardDiffSensitivity:

 Need an adjoint for constructor EnsembleSolution{Float32,3,Array{ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Tuple{Float32,Float32,Float32},ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},1}}. Gradient is of type Array{Float64,3}

while using the following code:

using Pkg
Pkg.activate(".")
using DiffEqGPU, DifferentialEquations, Flux, DiffEqSensitivity
function lorenz(du,u,p,t)
 @inbounds begin
     du[1] = p[1]*(u[2]-u[1])
     du[2] = u[1]*(p[2]-u[3]) - u[2]
     du[3] = u[1]*u[2] - p[3]*u[3]
 end
 nothing
end

nBatchsize = 10
nEpochs = 10
tspan = 100.0
tsize = 1000
nbatches = div(tsize,nBatchsize)
t = range(0.0,tspan,length=tsize)

NN_encode = Chain(Dense(3,10,tanh),Dense(10,10,tanh),Dense(10,3,tanh))
NN_decode = Chain(Dense(3,10),Dense(10,10,tanh),Dense(10,3,tanh))

u0 = Float32[1.0, 0.0, 0.0]
p = [10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy)

args_ = Dict()

t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lorenz,u0,(0.0f0,Float32(tspan/nbatches)),p)

function ensemble_solve(x)
    prob_func = (prob,i,repeat) -> remake(prob,u0=x[:,i])
    monteprob = EnsembleProblem(prob2, prob_func = prob_func)
    return  Array(solve(monteprob,Tsit5(),EnsembleThreads(),trajectories=nbatches,saveat=t_batch,sensealg=ForwardDiffSensitivity()))
end

function loss_func(data_)
    enc_ = NN_encode(data_)
    enc_ODE_solve = hcat([ensemble_solve(enc_[:,1:nBatchsize:end])[:,:,i] for i in 1:nbatches]...)
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

ensemble_solve(data)

loss_func(data)

opt = ADAM(0.001)

for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

@ChrisRackauckas (Member) commented Jun 11, 2020

Hmm, Zygote is having issues here.

using ZygoteRules
ZygoteRules.@adjoint function EnsembleSolution(sim,time,converged)
    EnsembleSolution(sim,time,converged),y->(y,nothing,nothing)
end

gets you pretty far. The next issue you hit is that Zygote doesn't seem to work with @threads, so I just changed it to EnsembleSerial() to keep moving. Then Zygote wasn't able to compile code with @warn (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L125-L152) because it has a try-catch in there, so I commented those out. That made it start calling the adjoint equation, which then errored for a reason I don't understand yet, but that's very close.

So the way to solve this would be to fix:

  • Compatibility of Zygote with @warn
  • Compatibility of Zygote with pmap, tmap, etc.
  • Add that adjoint rule to the library
  • Debug what's going on with the adjoint that's left, but I think this is a small issue.

The first two should get MWEs on Zygote (a sketch of the first is below). That adjoint should get added to DiffEq, and @dhairyagandhi96, I might want help debugging the last part. This is generally something that would be good to have, and the knock-on effects of fixing this case are likely very valuable (pmap/tmap adjoints are probably more broadly useful, so we should look at that first).
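
A hypothetical sketch of the @warn MWE (the point is just that @warn lowers to logging code containing a try/catch, which Zygote could not compile at the time, even on the branch that never warns):

using Zygote

function f(x)
    x < 0 && @warn "negative input"  # @warn lowers to code with a try/catch
    return x^2
end

Zygote.gradient(f, 2.0)  # errored at the time with "try/catch is not supported"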

@mkalia94 (Author) commented Jun 17, 2020

I see. I'm afraid I'm not currently in a position to dig through the Zygote issue. I do, however, have a temporary GPU workaround that trains the initial conditions as well. The RHS could be defined more cleanly, but gradients are computed successfully:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters
nBatchsize = 30
nEpochs = 10
tspan = 20.0
tsize = 300
nbatches = div(tsize,nBatchsize)
statesize = 2

# ODE solve
function lotka_volterra_func!(du,u,p,t)
    du[1] = u[1]*(p[1]-p[2]*u[2])
    du[2] = u[2]*(p[3]*u[1]-p[4])
    return du
end

function lotka_volterra(du,u,p,t)
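    # Batched RHS: nbatches independent 2-state systems stacked into one state
    # vector of length statesize*nbatches, so a single solve evolves every batch.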
    dx = Zygote.Buffer(du,size(du)[1])
    for i in 1:nbatches
        index = (i-1)*statesize
        dx[index+1:index+statesize] = lotka_volterra_func!(dx[index+1:index+statesize],u[index+1:index+statesize],p,t)
    end
    du .= copy(dx)
    nothing
end

# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4] 
u0 = Float32[0.01, 0.01]

t = range(0.0,tspan,length=tsize)

# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra_func!,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data

data = Float32.(yy) |> gpu

# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
u0_train = rand(statesize*nbatches)

# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0_train,(0.0f0,Float32(tspan/nbatches)),p)

#ODE solve to be used for training
function predict_ODE_solve()
    return Array(solve(prob2,Tsit5(),saveat=t_batch,reltol=1e-4)) 
end
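
args_ = Dict() # logging store read by loss_func (assumed; missing from this snippet as pasted)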

function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    #enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) #|> gpu
    enc_ODE_solve = hcat([predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

opt = ADAM(0.001)

loss_func(data) # This works

for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode,u0_train), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

My package list reads:

  [c7e460c6] ArgParse v1.1.0
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [c5f51814] CUDAdrv v6.3.0
  [be33ccc6] CUDAnative v3.1.0
  [3a865a2d] CuArrays v2.2.0
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.12.0
  [41bf760c] DiffEqSensitivity v6.19.0
  [0c46a032] DifferentialEquations v6.14.0
  [31c24e10] Distributions v0.23.3
  [5789e2e9] FileIO v1.3.0
  [587475ba] Flux v0.10.4
  [0c68f7d7] GPUArrays v3.4.1
  [033835bb] JLD2 v0.1.13
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.39.1
  [91a5bcdd] Plots v1.3.6
  [8d666b04] PolyChaos v0.2.1
  [ee283ea6] Rebugger v0.3.3
  [295af30f] Revise v2.7.1
  [9f7883ad] Tracker v0.2.6
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

@ChrisRackauckas (Member) commented:

Regarding the hcat(...) call in the loss above: this can be reduce(hcat, [predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]).
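
For illustration (cols is a stand-in for the comprehension above), both forms build the same matrix, but the reduce form avoids splatting a long argument list and hits Zygote's dedicated reduce(hcat, ...) adjoint:

cols = [rand(Float32, 2, 5) for _ in 1:10]  # stand-in for the solve slices above

A = hcat(cols...)       # splats an n-element tuple; costly to compile and differentiate
B = reduce(hcat, cols)  # one call; Zygote has a dedicated adjoint for this form
@assert A == B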

@ChrisRackauckas (Member) commented:

Update: Zygote is now compatible with parallelism (FluxML/Zygote.jl#728), so this should now be possible.

@ChrisRackauckas (Member) commented:

using DifferentialEquations, Flux

pa = [1.0]

function model1(input) 
  prob = ODEProblem((u, p, t) -> 1.01u * pa[1], 0.5, (0.0, 1.0))
  
  function prob_func(prob, i, repeat)
    remake(prob, u0 = rand() * prob.u0)
  end
  
  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 100)
end

Input_time_series = zeros(5, 100)
# loss function
loss(x, y) = Flux.mse(model1(x), y)

data = Iterators.repeated((Input_time_series, 0), 1)

cb = function () # callback function to observe training
  println("Tracked Parameters: ", params(pa))
end


opt = ADAM(0.1)
println("Starting to train")
Flux.@epochs 10 Flux.train!(loss, params(pa), data, opt; cb = cb)

is a good MWE

@ChrisRackauckas (Member) commented:

@DhairyaLGandhi can I get help completing this? It just needs some tweaks to the adjoint definitions now:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test

using ZygoteRules
ZygoteRules.@adjoint EnsembleSolution(sim,time,converged) = EnsembleSolution(sim,time,converged), p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) = sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)

pa = [1.0]
u0 = [1.0]
function model1()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

  function prob_func(prob, i, repeat)
    remake(prob, u0 = rand() .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end

# loss function
loss() = sum(abs2,1.0.-model1())

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end


opt = ADAM(0.05)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test l1 < 10l2

Then EnsembleDistributed() should work because of FluxML/Zygote.jl#728.

@ChrisRackauckas (Member) commented:

#321 is somewhat related, since we really need adjoints of the solution types and the literal_getproperty calls.

This almost works if you comment out the warns (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L132) in there.

@DhairyaLGandhi (Member) commented Jul 15, 2020

The two adjoints defined in #279 (comment)?

I can take a look in a bit; I'm a bit swamped at the moment.

Edit: at a quick glance, you probably want (EnsembleSolution(...), nothing) for the second adjoint?
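
Under that suggestion, the second adjoint would read something like the following sketch (one cotangent per argument, with nothing for the Val{:u}; this assumes EnsembleSolution is in scope from DiffEqBase):

using ZygoteRules, DiffEqBase

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u})
    # Wrap the cotangent of sim.u back into an EnsembleSolution for `sim`,
    # and return `nothing` for the Val{:u} argument.
    sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), nothing)
end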

@ChrisRackauckas (Member) commented:

Alright, thanks.

@ChrisRackauckas (Member) commented:

SciML/DiffEqBase.jl#557 fixes this. Final test:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
pa = [1.0]
u0 = [3.0]
function model1()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
end

# loss function
loss() = sum(abs2,1.0.-Array(model1()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end

opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1

function model2()
  prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)

  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end
loss() = sum(abs2,[sum(abs2,1.0.-u) for u in model2()])

pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1

ChrisRackauckas added a commit that referenced this issue on Jul 27, 2020:
* test the fixed ensemble parallelism

Fixes #279

* update the first two tutorials