
use jnp.nanmin and jnp.nanmax to compute new stepsize factor #235

Open
wants to merge 1 commit into main
Conversation

virajpandya

…instead of jnp.clip in diffrax.step_size_controller.adaptive.adapt_step_size()

Care was taken to make sure that the order of jnp.nanmin and jnp.nanmax reflects the actual behavior of jnp.clip. According to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html, using jnp.(nan)min and jnp.(nan)max may be a bit slower than jnp.clip, but at least it will be robust against NaNs in y_error.

This solves #223 and is similar to my proposed bugfix / pull request for jax.experimental.ode.optimal_step_size() here: jax-ml/jax#14612 and jax-ml/jax#14624

I confirmed that this works with different explicit/implicit diffrax solvers, and I get the expected correct solution to my non-autonomous ODE system compared against scipy.integrate.solve_ivp, manual non-adaptive Euler integration with extremely small timesteps, and a manual adaptive RK23 (Bogacki-Shampine) solver in both pure Python and JAX.

@patrick-kidger
Owner

Okay, sorry for taking so long to get back around to this!

One thing I would like to understand is why the previous approach failed. It looks like both possible fixes involve detecting NaNs.

In particular, you mentioned wanting to autodiff through these solvers. If a NaN is generated and then removed later, it can sometimes still reappear on the backward pass. Being robust to such issues usually means catching the NaN as soon as possible.

Are you able to track down how the NaN still sneaks by in the previous approach? I'd definitely like to get some version of this fix in.

@virajpandya
Author

virajpandya commented Mar 23, 2023

Sorry for the delay -- I'm finally getting back to this. Thanks for all the help! The good news is that with my fix, the system can be both successfully solved and forward-mode autodiff'd (jax.jacfwd and jax.jvp). I verified the resulting Jacobian and JVP against finite differences (using atol and rtol << the parameter perturbations). That was with your old v0.2.2 using NoAdjoint(). Switching to your latest v0.3.1 without either my fix or yours, the solve and the forward-mode autodiff are also both successful (using DirectAdjoint()). I don't yet know exactly what changed in v0.3.1 to make this work out of the box, but it's great!

What doesn't work is reverse-mode autodiff and I would like to know why. RecursiveCheckpointAdjoint() gives NaN gradients and BacksolveAdjoint() leads to a 'max_steps reached' error, even though the system is solved in 290 steps.

Where in diffrax do you recommend putting jax.debug.print statements so I can see if something wonky is happening when solving my ODE system backwards in time? I checked ad.py, adjoint.py and integrate.py and I see a lot of equinox calls, but are there specific places where, e.g., I can print the values of the 8 state variables and their time derivatives along the backward pass?
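One generic JAX pattern for this (not a diffrax API; `probe` is a hypothetical helper) is an identity function with a custom VJP, which prints values on the forward pass and cotangents on the backward pass. It can be wrapped around the state inside a vector field without changing the computation:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def probe(x):
    return x  # identity: does not change the computation

def _probe_fwd(x):
    jax.debug.print("forward:  {x}", x=x)
    return x, None  # no residuals needed

def _probe_bwd(_, g):
    jax.debug.print("backward: {g}", g=g)
    return (g,)  # pass the cotangent through unchanged

probe.defvjp(_probe_fwd, _probe_bwd)

# Example: wrap the state inside a function being differentiated.
loss = lambda y: jnp.sum(probe(y) ** 2)
grads = jax.grad(loss)(jnp.array([1.0, 2.0]))  # prints both passes
```

Unlike plain print, jax.debug.print also fires under jit/scan, so it survives the compiled solver loop.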

@Randl
Contributor

Randl commented Aug 6, 2024

Any updates on this one?

@patrick-kidger
Owner

Getting NaNs in reverse mode was expected here, I think -- papering over them in the forward pass doesn't remove them from the computation graph, so they commonly reappear on the backward pass. I think the goal here should be to find where the NaNs arose in the forward pass in the first place.

@Randl
Contributor

Randl commented Aug 6, 2024

So, my use case has a rapidly growing function, and for too large a step size the result is NaN. I think what happens is that this currently propagates to the step size and kills the solver. Applying this PR locally fixed the problem for me; that's why I'm asking.

@patrick-kidger
Owner

Right! So I think the correct fix is not this PR, but instead to locate where the NaNs arise, and prevent them from ever being created.

@Randl
Contributor

Randl commented Aug 6, 2024

So let's say I have dy/dt = f(y) where f is not defined on all of R but only on some subset U. Of course the dynamics are such that y is never outside of U. Yet nothing prevents the solver (and specifically the step size controller) from trying values outside U, especially if the solution approaches the boundary. What would be the solution (except praying that the step size is not too large, of course)?

@patrick-kidger
Owner

This is a good question. See also #200 for another take on this.

Right now the answer is really to:
(a) write your vector field such that it returns non-NaN values for all possible inputs [as above, the moment you get a NaN then it might all be over]
(b) use the InBoundsSolver of #200 to ensure you stay within the valid region.

I think this answer is a little dissatisfying. We should probably add InBoundsSolver into Diffrax itself.
I am also trying to figure out a way to make (a) happen automatically -- or maybe some way that we can be robust to NaNs in all cases. Getting that right is a tricky thing for a user to do.

@Randl
Contributor

Randl commented Aug 11, 2024

OK, but what are the arguments against this PR? It will never break working code and may help in cases like mine. Propagating NaNs is also not the best way of debugging; a user who is concerned about that is better off checking for NaNs in their own functions. A minor slowdown? Maybe make this one optional, then?

@patrick-kidger
Owner

I can see that this PR improves forward integration of systems only defined in some domain.

However, my concern with this PR is that when backpropagating through a solve which produces NaNs, it delays the error from the forward pass to the backward pass. NaNs arising only on the backward pass are incredibly tricky to debug!
