This repository has been archived by the owner on Feb 28, 2022. It is now read-only.

Discussion on time-dependent adjoints #28

Open
fverdugo opened this issue Jul 27, 2020 · 2 comments

Comments

@fverdugo
Member

sol_lazy = solve(solver, op)

# We need a forward solution that can be iterated in reverse
sol = collect(sol_lazy)          # Option 1: keep every time step in memory
sol = CheckpointedSol(sol_lazy)  # Option 2: checkpointed solution
sol_reversed = reverse(sol)

for (uh_n, t_n) in sol_reversed
  # uh_n is the forward solution at t_n, visited from the last time step to the first
end

# Adjoint (built from the solution that can be iterated in reverse)
adj_op = Adjoint(op, sol, j_u(sol))

# Lazy adjoint solution
adj_sol = solve(solver, adj_op)

# Iteration of the adjoint backwards in time
for (adj_uh_n, t_n) in adj_sol

end

# If needed, reverse iteration of forward and adjoint solutions
for ((adj_uh_n, t_n), (uh_n, t_n) )in zip(adj_sol,sol_reversed)

end
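To make Option 1 concrete, here is a minimal, self-contained Julia sketch that does not use Gridap. It assumes a toy scalar ODE u' = p*u advanced with explicit Euler; LazyEulerSolution and euler_step are hypothetical stand-ins for the lazy result of solve(solver, op), not Gridap types. Iterating it yields (u_n, t_n) pairs on demand, and collect followed by reverse gives the backward sweep.

```julia
# Hypothetical stand-in for a lazy forward solution of u' = p*u with explicit Euler.
# Nothing is computed until the object is iterated.
struct LazyEulerSolution
    p::Float64    # model parameter
    u0::Float64   # initial condition
    dt::Float64   # time step size
    nsteps::Int   # number of time steps
end

# One explicit Euler step: u_{n+1} = u_n + dt*p*u_n
euler_step(s::LazyEulerSolution, u) = u + s.dt * s.p * u

# Iteration protocol: the state is (u_n, n); each iteration yields (u_n, t_n).
# (The step computed after the final yield is simply discarded.)
function Base.iterate(s::LazyEulerSolution, state = (s.u0, 0))
    u, n = state
    n > s.nsteps && return nothing
    return (u, n * s.dt), (euler_step(s, u), n + 1)
end
Base.length(s::LazyEulerSolution) = s.nsteps + 1
Base.eltype(::Type{LazyEulerSolution}) = Tuple{Float64,Float64}

sol_lazy = LazyEulerSolution(-0.5, 1.0, 0.1, 10)

# Option 1: store all time steps, then iterate backwards
sol = collect(sol_lazy)        # Vector of (u_n, t_n) in forward order
sol_reversed = reverse(sol)    # same pairs, last time step first

for (uh_n, t_n) in sol_reversed
    println("t = ", t_n, ", u = ", uh_n)
end
```

Option 1 is trivial to implement but keeps every state in memory, which is the cost a checkpointed solution (Option 2) is meant to avoid.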

cc @santiagobadia @oriolcg

@santiagobadia
Member

@oriolcg and @fverdugo

I would consider the following steps:

  1. Create a CheckpointedSolution from a lazy ODESolution. The obvious strategy is to store all steps, but we could also consider more advanced constructors that accept a function deciding whether each step must be stored for later use.

  2. The CheckpointedSolution should iterate like a standard ODESolution, the only difference being that it does not recompute steps that have already been checkpointed.

checkp_sol = CheckpointedSolution(sol, strategy) # Optional strategy, default: store all steps
  3. Create an Adjoint of an ODEOperator. I will work on the details of this implementation. One of the inputs to its constructor is the CheckpointedSolution.
adj_op = Adjoint(op, checkp_sol)
  4. Make it possible to reverse the iteration of a CheckpointedSolution. This requires recomputing and storing, in one shot, all time steps between two consecutive checkpoints, and discarding this data when we move to the next
    checkpoint interval (in reverse mode). We could think about a Forward/Reverse mode trait for CheckpointedSolution that would modify the iteration process.
sol_reversed = reverse(checkp_sol)
  5. Create a lazy adjoint solution that is iterated backward in time but is otherwise analogous to the forward ODESolution. Internally, it uses the reversed checkpointed solution.
adj_sol = solve(solver, adj_op)  # lazy adjoint solution
# Iteration of the adjoint backwards in time
for (adj_uh_n, uh_n, t_n) in adj_sol
  # here, the adjoint solution at t_n
  # and also the forward solution at t_n
end
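The checkpointing strategy of step 1 and the reverse iteration of step 4 can be illustrated with a minimal, Gridap-free Julia sketch. It assumes a toy scalar ODE u' = p*u advanced with explicit Euler; CheckpointedEuler, foreach_reverse and the every parameter are hypothetical names, not a proposed Gridap API. The strategy shown keeps every every-th state. Reverse iteration walks the checkpoint intervals from last to first, recomputes the states of one interval from its checkpoint in one shot, hands them to the consumer newest-first, and then drops that buffer.

```julia
# Hypothetical checkpointed forward solution for u' = p*u (explicit Euler).
struct CheckpointedEuler
    p::Float64
    dt::Float64
    nsteps::Int
    every::Int                       # strategy: keep every `every`-th step
    checkpoints::Dict{Int,Float64}   # step index n => stored state u_n
end

euler_step(p, dt, u) = u + dt * p * u

# Forward pass: run the time stepper once, storing only the checkpointed states.
function CheckpointedEuler(p, u0, dt, nsteps; every = 4)
    checkpoints = Dict{Int,Float64}()
    u = u0
    for n in 0:nsteps
        n % every == 0 && (checkpoints[n] = u)
        n < nsteps && (u = euler_step(p, dt, u))
    end
    return CheckpointedEuler(p, dt, nsteps, every, checkpoints)
end

# Reverse pass: calls f(u_n, t_n) for n = nsteps down to 0 and returns the
# number of Euler steps that had to be recomputed.
function foreach_reverse(f, c::CheckpointedEuler)
    recomputed = 0
    last_start = c.every * div(c.nsteps, c.every)     # index of the last checkpoint
    for start in last_start:-c.every:0
        stop = min(start + c.every - 1, c.nsteps)     # last step owned by this interval
        # Recompute u_start..u_stop in one shot from the stored checkpoint.
        buffer = Vector{Float64}(undef, stop - start + 1)
        buffer[1] = c.checkpoints[start]
        for k in 2:length(buffer)
            buffer[k] = euler_step(c.p, c.dt, buffer[k-1])
            recomputed += 1
        end
        # Hand the interval to the consumer newest-first (reverse mode).
        for k in length(buffer):-1:1
            f(buffer[k], (start + k - 1) * c.dt)
        end
        # A new buffer is allocated for the next interval, so only one
        # interval (plus the checkpoints) is kept alive at a time.
    end
    return recomputed
end

c = CheckpointedEuler(-0.5, 1.0, 0.1, 10; every = 4)   # stores u_0, u_4, u_8
visited = Tuple{Float64,Float64}[]
nrecomputed = foreach_reverse(c) do u_n, t_n
    push!(visited, (u_n, t_n))                         # consumed newest-first
end
```

With 10 steps and a checkpoint every 4 steps, the reverse pass recomputes 8 of the 10 forward steps while never holding more than 3 checkpoints and one 4-step interval. A callback is used here only to keep the sketch short; a Reverse-mode iterator, as suggested in step 4, would expose the same traversal through the iteration protocol.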

This way, we get the forward solution iterated in the right direction, the adjoint solution iterated in the right direction, and the forward solution at the same time step as the adjoint solution (which comes for free).
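To see why the forward state is needed at the same time step, here is a minimal, self-contained Julia sketch of a discrete adjoint for a toy problem; it is not the Adjoint operator being designed here. Assumptions: the forward model is u' = p*u discretized with explicit Euler, u_{n+1} = (1 + dt*p) u_n, and the objective is J = u_N^2/2. The adjoint recursion is then λ_N = u_N, λ_n = (1 + dt*p) λ_{n+1}, and dJ/dp = Σ_{n=0}^{N-1} dt u_n λ_{n+1}. AdjointSweep, forward_states and gradient_p are hypothetical names. For brevity the forward states are simply stored; a checkpointed solution could supply them instead.

```julia
# Toy forward model: u' = p*u with explicit Euler; all states stored.
function forward_states(p, u0, dt, nsteps)
    us = Vector{Float64}(undef, nsteps + 1)
    us[1] = u0
    for n in 1:nsteps
        us[n+1] = us[n] + dt * p * us[n]
    end
    return us
end

# Hypothetical lazy adjoint solution for J = u_N^2 / 2. It iterates backward in
# time and yields (λ_n, u_n, t_n), so the forward state at t_n comes along for free.
struct AdjointSweep
    p::Float64
    dt::Float64
    us::Vector{Float64}   # forward states u_0, ..., u_N
end

# State is (λ_n, n); it starts at n = N with λ_N = ∂J/∂u_N = u_N.
function Base.iterate(a::AdjointSweep, state = (a.us[end], length(a.us) - 1))
    λ_n, n = state
    n < 0 && return nothing
    λ_prev = (1 + a.dt * a.p) * λ_n     # λ_{n-1} = (1 + dt*p) * λ_n (unused after n = 0)
    return (λ_n, a.us[n+1], n * a.dt), (λ_prev, n - 1)
end

# dJ/dp = Σ_{n=0}^{N-1} dt * u_n * λ_{n+1}, accumulated during the backward sweep.
function gradient_p(p, u0, dt, nsteps)
    us = forward_states(p, u0, dt, nsteps)
    grad = 0.0
    λ_later = nothing            # λ_{n+1}, i.e. the value from the previous iteration
    for (λ_n, u_n, t_n) in AdjointSweep(p, dt, us)
        λ_later === nothing || (grad += dt * u_n * λ_later)
        λ_later = λ_n
    end
    return grad
end

g = gradient_p(-0.5, 1.0, 0.1, 10)
```

For this toy problem the exact discrete gradient is N*dt*(1 + dt*p)^(2N-1), which for the parameters above equals 0.95^19; a central finite difference on forward_states gives the same value to several digits.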

@ChrisRackauckas

With the DiffEq wrapper, I think the only thing left that you need for the adjoints is to make use of p (the parameters). Some of the functions won't vjp with Zygote, though, so it would need to use array-of-structs ReverseDiff, but these functions should be compatible with ReverseDiffVJP(true).


3 participants