use jnp.nanmin and jnp.nanmax to compute new stepsize factor #235
base: main
Conversation
…of jnp.clip
Care was taken to make sure that the order of jnp.nanmin and jnp.nanmax reflects the actual behavior of jnp.clip. According to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html, using jnp.(nan)min and jnp.(nan)max may be a bit slower than jnp.clip, but at least it will be robust against NaNs in y_error.
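For reference, a minimal sketch (not the actual diff) of what a NaN-robust replacement for jnp.clip looks like in spirit; the helper name _nan_safe_clip and the surrounding values are made up for illustration:

```python
import jax.numpy as jnp

def _nan_safe_clip(factor, low, high):
    # Hypothetical helper, not the PR's exact code.
    # jnp.clip(x, low, high) == jnp.minimum(jnp.maximum(x, low), high), so the
    # NaN-ignoring reductions are applied in the same order: max first, min last.
    factor = jnp.nanmax(jnp.stack([factor, low]))   # lower bound; a NaN factor collapses to `low`
    factor = jnp.nanmin(jnp.stack([factor, high]))  # upper bound
    return factor

# With a NaN error estimate, jnp.clip would return NaN, whereas this gives the
# most conservative (smallest) allowed step-size factor:
print(_nan_safe_clip(jnp.nan, 0.1, 10.0))  # 0.1
print(_nan_safe_clip(3.0, 0.1, 10.0))      # 3.0, same as jnp.clip
```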
Okay, sorry for taking so long to get back around to this! One thing I would like to understand is why this approach failed. It looks like both possible fixes involve detecting NaNs. In particular you mentioned wanting to autodiff through these solvers. If one generates a NaN and then removes it later, then it can sometimes still reappear on the backward pass. Being robust to such issues usually means catching the NaN as soon as possible. Are you able to track down how the NaN still sneaks by in the previous approach? I'd definitely like to get some version of this fix in.
Sorry for the delay -- I'm finally getting back to this. Thanks for all the help! The good news is that with my fix, the system can both be successfully solved and forward-mode autodiff'd (jax.jacfwd and jax.jvp). I verified the resulting Jacobian and JVP against finite differences (using atol and rtol << the parameter perturbations). That was with your old v0.2.2 using NoAdjoint().

Switching to your latest v0.3.1 without either my fix or yours, the solve and forward-mode autodiff are also both successful (using DirectAdjoint()). I don't yet know exactly what changed in v0.3.1 to make this work out of the box now, but it's great!

What doesn't work is reverse-mode autodiff, and I would like to know why. RecursiveCheckpointAdjoint() gives NaN gradients and BacksolveAdjoint() leads to a 'max_steps reached' error, even though the system is solved in 290 steps. Where in diffrax do you recommend putting jax.debug.print statements so I can see if something wonky is happening when my ODE system is solved backwards in time? I checked ad.py, adjoint.py and integrate.py and I see a lot of Equinox calls, but are there specific places where I can, e.g., print the values of the 8 state variables and their time derivatives along the backward pass?
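One place such prints can go without touching diffrax internals is the user's own vector field, since BacksolveAdjoint re-evaluates it during the backward-in-time adjoint solve; a minimal sketch (toy right-hand side and names made up, not from this thread):

```python
import jax
import jax.numpy as jnp
import diffrax

def vector_field(t, y, args):
    # Toy linear RHS standing in for the real 8-state system.
    dy = -args * y
    # Prints fire every time the vector field is evaluated, which for
    # BacksolveAdjoint should include the backward-in-time adjoint solve.
    jax.debug.print("t = {t}, y = {y}, dy = {dy}", t=t, y=y, dy=dy)
    return dy

def loss(args):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0, t1=1.0, dt0=0.1,
        y0=jnp.ones(8), args=args,
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
        adjoint=diffrax.BacksolveAdjoint(),
    )
    return jnp.sum(sol.ys)

print(jax.grad(loss)(0.5))
```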
Any updates on this one?
Getting NaNs in reverse mode was expected here, I think -- papering over them in the forward pass doesn't remove them from the computation graph, so it is common for them to reappear in the backward pass. I think the goal here should be to find where the NaNs arose in the forward pass in the first place.
Well, my use case has a fast-growing function, and for a too-large step size the result is NaN. I think what happens is that this NaN currently propagates to the step size and kills the solver. Applying this PR locally fixed the problem for me; that's why I'm asking.
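As a generic illustration of that failure mode (not code from this thread): with a rapidly growing right-hand side, one overly large explicit step can overflow intermediate values to inf, the embedded error estimate then becomes inf - inf = nan, and jnp.clip propagates that NaN straight into the next step-size factor:

```python
import jax.numpy as jnp

big = jnp.exp(jnp.float32(200.0))      # overflows float32 to inf
y_error = big - big                    # inf - inf -> nan, i.e. a NaN error estimate
print(big, y_error)                    # inf nan
# Any factor computed from a NaN error estimate is NaN, and jnp.clip keeps it NaN:
print(jnp.clip(y_error, 0.1, 10.0))    # nan
```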
Right! So I think the correct fix is not this PR, but instead to locate where the NaNs arise, and prevent them from ever being created. |
So let's say I have …
This is a good question. See also #200 for another take on this. Right now the answer is really to: … I think this answer is a little dissatisfying. We should probably add …
OK, but what are the arguments against this PR? It will never break working code and may help in cases like mine. Propagating NaNs is also not the best way of debugging; the user is better off checking for NaNs in user functions if that is a concern. Minor slowdown? Maybe make this one optional then?
I can see that this PR improves forward integration of systems only defined on some domain. However, my concern with this PR is that when backpropagating through a solve which produces NaNs, it delays the error from the forward pass to the backward pass. NaNs arising only on the backward pass are incredibly tricky to debug!
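A minimal, generic illustration of why masking NaNs only in the forward pass is not enough (the standard jnp.where gradient pitfall, not code from this PR):

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sqrt(x)                     # nan for x < 0
    return jnp.where(x >= 0.0, y, 0.0)  # NaN masked out in the forward pass

print(f(-1.0))            # 0.0  -- the forward value looks fine
print(jax.grad(f)(-1.0))  # nan  -- the masked NaN reappears in the gradient
```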
…instead of jnp.clip in diffrax.step_size_controller.adaptive.adapt_step_size()
This solves #223 and is similar to my proposed bugfix / pull request for jax.experimental.ode.optimal_step_size() here: jax-ml/jax#14612 and jax-ml/jax#14624
I confirmed that this works with different explicit/implicit diffrax solvers, and I get the expected correct solution for my non-autonomous ODE system compared against scipy.integrate.solve_ivp, manual non-adaptive Euler integration with extremely small timesteps, and a manual adaptive RK23 (Bogacki-Shampine) solver in both pure Python and JAX.