Multiple ODE solves from a single time series on GPU #279
Hey,
Ah, I see. I tried DiffEqGPU.jl to no avail; I opened a separate issue there regarding GPU usage: SciML/DiffEqGPU.jl#57. I get

Need an adjoint for constructor EnsembleSolution{Float32,3,Array{ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Tuple{Float32,Float32,Float32},ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lorenz),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},1}}. Gradient is of type Array{Float64,3}

while using the following code:

using Pkg
Pkg.activate(".")
using DiffEqGPU, DifferentialEquations, Flux, DiffEqSensitivity
function lorenz(du,u,p,t)
    @inbounds begin
        du[1] = p[1]*(u[2]-u[1])
        du[2] = u[1]*(p[2]-u[3]) - u[2]
        du[3] = u[1]*u[2] - p[3]*u[3]
    end
    nothing
end
nBatchsize = 10
nEpochs = 10
tspan = 100.0
tsize = 1000
nbatches = div(tsize,nBatchsize)
t = range(0.0,tspan,length=tsize)
NN_encode = Chain(Dense(3,10,tanh),Dense(10,10,tanh),Dense(10,3,tanh))
NN_decode = Chain(Dense(3,10),Dense(10,10,tanh),Dense(10,3,tanh))
u0 = Float32[1.0, 0.0, 0.0]
p = [10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,tspan,p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy)
args_ = Dict()
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lorenz,u0,(0.0f0,Float32(tspan/nbatches)),p)
function ensemble_solve(x)
    prob_func = (prob,i,repeat) -> remake(prob,u0=x[:,i])
    monteprob = EnsembleProblem(prob2, prob_func = prob_func)
    return Array(solve(monteprob,Tsit5(),EnsembleThreads(),trajectories=nbatches,saveat=t_batch,sensealg=ForwardDiffSensitivity()))
end
function loss_func(data_)
    enc_ = NN_encode(data_)
    enc_ODE_solve = hcat([ensemble_solve(enc_[:,1:nBatchsize:end])[:,:,i] for i in 1:nbatches]...)
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end
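# ensemble_solve takes every nBatchsize-th column of its input as a fresh initial
# condition and solves all nbatches short problems through one EnsembleProblem;
# the returned Array has size (state, time, trajectory), and the hcat in loss_func
# stitches the trajectories back into a (state, tsize) matrix aligned with data_.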
ensemble_solve(data)
loss_func(data)
opt = ADAM(0.001)
for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end
Hmm, Zygote is having issues here.

using ZygoteRules
ZygoteRules.@adjoint function EnsembleSolution(sim,time,converged)
    EnsembleSolution(sim,time,converged), y -> (y,nothing,nothing)
end

gets you pretty far. The next issue you hit is that Zygote doesn't seem to work with the parallel map (pmap/tmap). So the way to solve this would be to fix:
The first two should get MWEs on Zygote. That adjoint should get added to DiffEq, and @dhairyagandhi96 I might want help debugging the last part. This is generally something that would be good to have, and the knock-on effects of fixing this case are likely very valuable (pmap/tmap adjoints are probably more broadly useful, so we should look at that first).
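For illustration only (this sketch is not from the thread, and it is not the rule that eventually landed in Zygote): a pmap adjoint can be written by taking a per-element pullback on the forward pass and mapping the output cotangents back through those pullbacks. Gradients with respect to anything captured inside f are dropped here, and the reverse pass runs serially for simplicity.

using Distributed, Zygote, ZygoteRules

# Sketch of an adjoint for pmap: differentiate element-wise via Zygote.pullback.
# `nothing` is returned as the cotangent for `f`, so parameters closed over by `f`
# receive no gradient in this simplified version.
ZygoteRules.@adjoint function Distributed.pmap(f, xs::AbstractVector)
    ys_and_backs = map(x -> Zygote.pullback(f, x), xs)   # could itself be a pmap
    ys = map(first, ys_and_backs)
    function pmap_pullback(ȳ)
        x̄s = map((yb, ȳi) -> yb[2](ȳi)[1], ys_and_backs, ȳ)
        return (nothing, x̄s)
    end
    return ys, pmap_pullback
end

# Usage sketch: Zygote.gradient(v -> sum(pmap(x -> 2x, v)), [1.0, 2.0, 3.0])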
I see. I'm afraid I'm unable to dig through the Zygote issue at present. I do, however, have a temporary workaround on the GPU while also training the initial conditions. The RHS can be defined better, but gradients are being computed successfully:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations
# Training hyperparameters
nBatchsize = 30
nEpochs = 10
tspan = 20.0
tsize = 300
nbatches = div(tsize,nBatchsize)
statesize = 2
# ODE solve
function lotka_volterra_func!(du,u,p,t)
    du[1] = u[1]*(p[1]-p[2]*u[2])
    du[2] = u[2]*(p[3]*u[1]-p[4])
    return du
end
function lotka_volterra(du,u,p,t)
    dx = Zygote.Buffer(du,size(du)[1])
    for i in 1:nbatches
        index = (i-1)*statesize
        dx[index+1:index+statesize] = lotka_volterra_func!(dx[index+1:index+statesize],u[index+1:index+statesize],p,t)
    end
    du .= copy(dx)
    nothing
end
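# Aside (illustrative sketch, not part of the original script): Zygote cannot
# differentiate through in-place mutation of ordinary arrays, which is presumably
# why the RHS above writes into a Zygote.Buffer and then `copy`s it into du.
# A minimal standalone example of the same pattern:
#
#   function doubled(u)
#       buf = Zygote.Buffer(u, length(u))
#       for i in eachindex(u)
#           buf[i] = 2*u[i]
#       end
#       return copy(buf)                          # copy(::Buffer) is differentiable
#   end
#   Zygote.gradient(u -> sum(doubled(u)), rand(3))   # no "mutating arrays" error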
# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]
t = range(0.0,tspan,length=tsize)
# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra_func!,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy) |> gpu
# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
u0_train = rand(statesize*nbatches)
# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0_train,(0.0f0,Float32(tspan/nbatches)),p)
#ODE solve to be used for training
function predict_ODE_solve()
    return Array(solve(prob2,Tsit5(),saveat=t_batch,reltol=1e-4))
end
args_ = Dict()   # loss log used by the training loop below
function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    #enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) #|> gpu
    enc_ODE_solve = hcat([predict_ODE_solve()[(i-1)*statesize+1:i*statesize,:] for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end
opt = ADAM(0.001)
loss_func(data) # This works
for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode,u0_train), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

My package list reads:

[c7e460c6] ArgParse v1.1.0
[fbb218c0] BSON v0.2.6
[6e4b80f9] BenchmarkTools v0.5.0
[3895d2a7] CUDAapi v4.0.0
[c5f51814] CUDAdrv v6.3.0
[be33ccc6] CUDAnative v3.1.0
[3a865a2d] CuArrays v2.2.0
[31a5f54b] Debugger v0.6.4
[aae7a2af] DiffEqFlux v1.12.0
[41bf760c] DiffEqSensitivity v6.19.0
[0c46a032] DifferentialEquations v6.14.0
[31c24e10] Distributions v0.23.3
[5789e2e9] FileIO v1.3.0
[587475ba] Flux v0.10.4
[0c68f7d7] GPUArrays v3.4.1
[033835bb] JLD2 v0.1.13
[429524aa] Optim v0.21.0
[1dea7af3] OrdinaryDiffEq v5.39.1
[91a5bcdd] Plots v1.3.6
[8d666b04] PolyChaos v0.2.1
[ee283ea6] Rebugger v0.3.3
[295af30f] Revise v2.7.1
[9f7883ad] Tracker v0.2.6
[e88e6eb3] Zygote v0.4.20
[9a3f8284] Random
This can be
Update: Zygote is now compatible with parallelism (FluxML/Zygote.jl#728), so this should be possible now.
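Untested sketch (not from the thread): with that Zygote fix and the ensemble adjoints discussed below in place, the multi-trajectory gradient should also go through with EnsembleThreads, along these lines:

using OrdinaryDiffEq, DiffEqSensitivity, Flux

pa = [1.0]
u0 = [3.0]
function threaded_model()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    prob_func(prob, i, repeat) = remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100)
end
# Array(::EnsembleSolution) stacks the trajectories into one array, keeping the loss differentiable
gs = Flux.gradient(() -> sum(abs2, 1.0 .- Array(threaded_model())), Flux.params(pa, u0))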
Same issue as this user on Discourse: https://discourse.julialang.org/t/error-loaderror-need-an-adjoint-for-constructor-ensemblesolution/42611/7
using DifferentialEquations, Flux
pa = [1.0]
function model1(input)
    prob = ODEProblem((u, p, t) -> 1.01u * pa[1], 0.5, (0.0, 1.0))
    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() * prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 100)
end
Input_time_series = zeros(5, 100)
# loss function
loss(x, y) = Flux.mse(model1(x), y)
data = Iterators.repeated((Input_time_series, 0), 1)
cb = function () # callback function to observe training
    println("Tracked Parameters: ", params(pa))
end
opt = ADAM(0.1)
println("Starting to train")
Flux.@epochs 10 Flux.train!(loss, params(pa), data, opt; cb = cb)

This is a good MWE.
@DhairyaLGandhi can I get help completing this? It just needs some tweaks to the adjoint definitions now:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
using ZygoteRules
ZygoteRules.@adjoint EnsembleSolution(sim,time,converged) = EnsembleSolution(sim,time,converged), p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) = sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true), 0.0, true)
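# The first rule wraps the incoming cotangent back into an EnsembleSolution so
# gradients can flow through the constructor; the second does the same for accesses
# to `sim.u` (literal_getproperty), which is what Zygote hits when the trajectories
# are pulled out of the solution.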
pa = [1.0]
u0 = [1.0]
function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = rand() .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end
# loss function
loss() = sum(abs2,1.0.-model1())
data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end
opt = ADAM(0.05)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test l1 < 10l2
#321 is somewhat related, since we really need adjoints of the solution types and the literal_getproperty calls. This almost works if you comment out the warns (https://github.com/SciML/DiffEqBase.jl/blob/master/src/ensemble/basic_ensemble_solve.jl#L132) in there.
The two adjoints defined in #279 (comment)? I can take a look in a bit; I'm a bit swamped at the moment. Edit: you probably want …
Alright, thanks!
SciML/DiffEqBase.jl#557 fixes this. Final test:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, Test
pa = [1.0]
u0 = [3.0]
function model1()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100)
end
# loss function
loss() = sum(abs2,1.0.-Array(model1()))
data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1
function model2()
    prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100).u
end
loss() = sum(abs2,[sum(abs2,1.0.-u) for u in model2()])
pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
l1 = loss()
Flux.@epochs 10 Flux.train!(loss, params([pa,u0]), data, opt; cb = cb)
l2 = loss()
@test 10l2 < l1
* test the fixed ensemble parallelism. Fixes #279
* update the first two tutorials
I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series. Say we have a time series y of dimension 2x100, and I'd like to solve the ODE over small time intervals [0, 10] using initial conditions y[:,1:10:end]. It works fine on the CPU using hcat([Array(solve(...))]...); however, using the GPU gives me the error:

Here is the code:

and here is the current status of packages:
Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)
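For reference, a minimal sketch of the pattern discussed above (illustrative only; the names rhs, slice_u0, and n_slices are made up for this example, and Float32/GPU transfer details are omitted): take one initial condition per slice of the time series, solve all of the short problems through a single EnsembleProblem, and convert the EnsembleSolution to an Array so the result stays differentiable.

using OrdinaryDiffEq, DiffEqSensitivity

p = Float32[10.0, 28.0, 8/3]
y = rand(Float32, 3, 100)             # stand-in for the measured time series
slice_u0 = y[:, 1:10:end]             # one initial condition every 10 columns
n_slices = size(slice_u0, 2)

function rhs(du, u, p, t)             # Lorenz right-hand side, as in the thread
    du[1] = p[1]*(u[2]-u[1])
    du[2] = u[1]*(p[2]-u[3]) - u[2]
    du[3] = u[1]*u[2] - p[3]*u[3]
    nothing
end

base_prob = ODEProblem(rhs, slice_u0[:, 1], (0.0f0, 10.0f0), p)
prob_func = (prob, i, repeat) -> remake(prob, u0 = slice_u0[:, i])
ensemble = EnsembleProblem(base_prob, prob_func = prob_func)
sol = Array(solve(ensemble, Tsit5(), EnsembleThreads(), trajectories = n_slices,
                  saveat = 0.1f0, sensealg = ForwardDiffSensitivity()))
# sol has size (state, time, trajectory); hcat the trajectory slices back together as needed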