This repository has been archived by the owner on Feb 28, 2022. It is now read-only.
sol_lazy = solve(solver, op)
# We need a forward solution that can be iterated in reverse
sol = collect(sol_lazy) # Option 1
sol = CheckpointedSol(sol_lazy) # Option 2
sol_reversed = reverse(sol)
for (uh_n, t_n) in sol_reversed
end
# Adjoint (from the sol that can be iterated in reverse)
adj_op = Adjoint(op, sol, j_u(sol))
# Lazy adjoint solution
adj_sol = solve(solver, adj_op)
# Iteration of the adjoint backwards in time
for (adj_uh_n, t_n) in adj_sol
end
# If needed, reverse iteration of forward and adjoint solutions
for ((adj_uh_n, t_n), (uh_n, t_n)) in zip(adj_sol, sol_reversed)
end
Create a CheckpointedSolution from a lazy ODESolution. The obvious case is to store all steps, but in the future we could consider more advanced constructors that accept a function deciding whether a given step must be stored.
The CheckpointedSolution should iterate like a standard ODESolution, with the only difference that it does not recompute steps that have already been checkpointed.
checkp_sol = CheckpointedSolution(sol, strategy) # Optional strategy, default: store all steps
Create an Adjoint of an ODEOperator. I will work on the details of this implementation. One of the inputs of its constructor is the CheckpointedSolution.
adj_op = Adjoint(op, checkp_sol)
Make it possible to reverse the iteration of a CheckpointedSolution. This requires computing and storing, in one shot, all time steps between two consecutive checkpoints, and discarding that data when we move to the next checkpoint interval (in reverse mode). We could think about a Forward/Reverse mode trait for CheckpointedSolution that would modify the iteration process.
sol_reversed = reverse(checkp_sol)
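The checkpoint-and-recompute idea above can be illustrated with a minimal, self-contained sketch. It is only a toy under stated assumptions: a scalar ODE advanced with explicit Euler stands in for the lazy ODESolution, and `euler_step`, `reverse_foreach`, and the fields of this toy `CheckpointedSolution` are hypothetical names chosen for illustration, not part of the Gridap API.

```julia
# Hypothetical toy types; none of these names are part of the Gridap API.
# One explicit Euler step of du/dt = f(u, t).
euler_step(f, u, t, dt) = (u + dt * f(u, t), t + dt)

struct CheckpointedSolution{F}
    f::F
    dt::Float64
    nsteps::Int
    stored::Dict{Int,Tuple{Float64,Float64}}  # step index => (u_n, t_n)
end

# Run forward once, keeping step n only when `strategy(n)` is true.
# Step 0 (the initial condition) is always kept so any segment can be recomputed.
function CheckpointedSolution(f, u0, t0, dt, nsteps; strategy = n -> true)
    stored = Dict{Int,Tuple{Float64,Float64}}(0 => (u0, t0))
    u, t = u0, t0
    for n in 1:nsteps
        u, t = euler_step(f, u, t, dt)
        strategy(n) && (stored[n] = (u, t))
    end
    return CheckpointedSolution(f, dt, nsteps, stored)
end

# Reverse iteration: visit steps nsteps, nsteps-1, ..., 0. When a step is not
# stored, recompute the steps from the closest earlier checkpoint up to it in
# one shot, serve them from a buffer, and free the buffer at the next checkpoint.
function reverse_foreach(visit, cs::CheckpointedSolution)
    buffer = Dict{Int,Tuple{Float64,Float64}}()
    for n in cs.nsteps:-1:0
        if haskey(cs.stored, n)
            empty!(buffer)                  # segment finished: drop its data
            visit(cs.stored[n]...)
        else
            if !haskey(buffer, n)
                m = n                       # closest checkpoint below n
                while !haskey(cs.stored, m)
                    m -= 1
                end
                u, t = cs.stored[m]
                for k in (m + 1):n          # recompute steps m+1..n
                    u, t = euler_step(cs.f, u, t, cs.dt)
                    buffer[k] = (u, t)
                end
            end
            visit(buffer[n]...)
        end
    end
    return nothing
end

# Usage: keep every 5th step of du/dt = -u and walk the solution backwards.
cs = CheckpointedSolution((u, t) -> -u, 1.0, 0.0, 0.1, 10; strategy = n -> n % 5 == 0)
reversed = Tuple{Float64,Float64}[]
reverse_foreach(cs) do u, t
    push!(reversed, (u, t))
end
```

In this sketch the memory held at any time is the set of checkpoints plus one segment, which is the trade-off the Forward/Reverse mode discussion is about; a real implementation would presumably expose this through the iteration interface rather than a callback.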
Create a lazy adjoint solution that is iterated backward in time but is otherwise analogous to the forward ODESolution. Internally, it makes use of the reversed checkpointed solution.
# Iteration of the adjoint backwards in time
for (adj_uh_n, uh_n, t_n) in adj_sol
  # here your adjoint solution
  # but also the forward solution at t_n
end
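To make the shape of such an iterator concrete, here is a small sketch on a scalar toy problem. Everything in it is an assumption for illustration: the forward problem is du/dt = -u^2 with explicit Euler, the objective is J = u_N, the forward states are simply collected instead of checkpointed, and `forward_states` and this toy `AdjointSolution` are hypothetical names rather than Gridap types.

```julia
# Hypothetical toy example; these names are not the Gridap API.
# Forward problem: du/dt = -u^2 with explicit Euler, u_{n+1} = u_n - dt * u_n^2.
# Objective: J = u_N, so the discrete adjoint recursion is
#   λ_N = 1,   λ_n = (1 - 2 * dt * u_n) * λ_{n+1}.

function forward_states(u0, dt, nsteps)
    us = zeros(nsteps + 1)
    us[1] = u0
    for n in 1:nsteps
        us[n + 1] = us[n] - dt * us[n]^2
    end
    return us            # us[n + 1] holds u_n
end

struct AdjointSolution
    us::Vector{Float64}  # forward states u_0..u_N (here simply collected)
    t0::Float64
    dt::Float64
end

Base.length(a::AdjointSolution) = length(a.us)
Base.eltype(::Type{AdjointSolution}) = NTuple{3,Float64}

# Iterator state is (n, λ_{n+1}); iteration runs n = N, N-1, ..., 0 and
# yields (λ_n, u_n, t_n), i.e. adjoint and forward state at the same time.
function Base.iterate(a::AdjointSolution, state = (length(a.us) - 1, 1.0))
    n, λ_next = state
    n < 0 && return nothing
    N = length(a.us) - 1
    u_n = a.us[n + 1]
    λ_n = n == N ? 1.0 : (1 - 2 * a.dt * u_n) * λ_next
    t_n = a.t0 + n * a.dt
    return (λ_n, u_n, t_n), (n - 1, λ_n)
end

# Usage: the loop has the same shape as the one proposed above.
u0, dt, nsteps = 1.0, 0.01, 100
adj_sol = AdjointSolution(forward_states(u0, dt, nsteps), 0.0, dt)
for (adj_u_n, u_n, t_n) in adj_sol
    # adj_u_n is dJ/du_n; u_n is the forward state at the same time t_n
end
```

The adjoint update needs u_n at every step because the problem is nonlinear, which is why the forward state is available to yield alongside the adjoint at no extra cost.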
This way, we have both the forward solution in the right direction, the adjoint solution in the right direction, and the forward solution at the same time step as the adjoint solution (since it comes for free).
With the DiffEq wrapper, I think the only thing left that you need for the adjoints is to make use of p. Some of the functions won't vjp with Zygote, though, so it would need to use Array of struct ReverseDiff; these functions should be compatible with ReverseDiffVJP(true).
cc @santiagobadia @oriolcg