-
-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added JumpStepWrapper #484
base: main
Are you sure you want to change the base?
Conversation
78b122a
to
0eac356
Compare
I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have diffrax/diffrax/_step_size_controller/adaptive.py Lines 569 to 574 in 501bed5
I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.
I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere. P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper. |
d022ac1
to
4702380
Compare
Hi @patrick-kidger, diffrax/benchmarks/jump_step_timing.py Lines 126 to 128 in 345e23a
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, quick first pass at a review!
"Maximum number of rejected steps reached. " | ||
"Consider increasing JumpStepWrapper.rejected_step_buffer_len.", | ||
) | ||
rjct_buff = jnp.where(keep_step, rjct_buff, rjct_buff.at[i_rjct].set(t1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this is a very expensive way to describe this operation! You're copying the whole buffer. XLA will sometimes optimize this out -- because I added that optimization to it! -- but not always.
Better is to do rjct_buff.at[i_rjct].set(jnp.where(keep_step, rjct_buff[i_rjct], t1))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than that, I think we may need to extend the API here slightly -- we should be able to mark state like this as being a buffer for the purposes of:
Line 621 in 0679807
final_state = outer_while_loop( |
which is needed to avoid spurious copies during backpropagation.
(You can see that both of these comments are basically us having to work around limitations of the XLA compiler.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first one makes sense, I should have seen this.
I don't really know what you want me to do in your second comment. And frankly diffeqsolve
is something I haven't even started digesting yet. Are you telling me rejected_buffer
should be one of the outer_buffers
, meaning that I should make it an instance of SaveState
or sth like that? I would apprecaite a bit more guidance.
Also damn how you managed to write all this code is beyond me. Even trying to begin understanding it seems a lot! Very impressive!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, you're too kind!
As for my second comment -- I think I've realised that I was wrong. Let me explain. Our backpropagation involves saving copies of our state in checkpoints. Let's suppose we set RecursiveCheckpointAdjoint(checkpoints=max_steps)
, so that's O(max_steps)
memory right? Well, not quite: our updating buffer here is potentially of length max_steps
(as per the debate above), and we're saving a copy of it in every checkpont, so we'd actually be using O(max_steps^2)
memory! That's not acceptable.
The simple solution to this will just be to set the size of this buffer to e.g. 100 by default, and just allow those copies to be made. And given the behaviour you have here -- in which you potentially overwrite values -- then that is actually what's necessary as well.
As for the complicated solution that I was wrong about: let's consider the case of SaveAt(steps=True)
. This also involves a buffer of length max_steps
, that we save into as we go along. Fortunately, this one has a useful extra property, which is that we never overwrite a value. That means we don't actually need to copy our buffer for every checkpoint! We can use a single buffer that is shared across all checkpoints, getting gradually filled in. To support this case then we actually have a special argument eqxi.while_loop(..., buffers=...)
, to declare which elements of our loop state have this behaviour. Unfortunately that's not the case here because we do overwrite the values. (And side-note the presence of this buffers
parameter is the reason I've not made this public API in Equinox, because the buffer-ness is completely unchecked and it's very easy to shoot yourself in the foot.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I see. Thanks for the in-depth explanation! So let's see if I understand this correctly. If this was not getting rewritten, then I should make it register as a buffer in the outer while loop of diffeqsolve. But, because it does get rewritten, I should not do that(??). Still, I am curious, if I did want to register it as a buffer, how would I accomplish that? Is it indeed by making it an instance of SaveState
, or is it something else entirely?
Other than that, should I keep it an Optional[Int]
and just add something like this to the docstring:
For most SDEs, setting this to `100` should be sufficient, but if more consecutive steps are rejected, then an error will be raised.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup that's correct. If you have an array that is getting filled-in then registering it as a buffer will mean that a single copy is used across all checkpoints. (The checkpoints in eqxi.while_loop
, for later backprop.)
But as we're overwriting values here then we actually must keep separate copies of it in each checkpoint, just like every other value that changes from step-to-step of the while loop.
To register something as a buffer it must be specified in eqxi.while_loop(..., buffers=....)
, see here:
Line 627 in 467d95f
cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers |
As for the message you've suggested -- this looks good to me.
Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of |
0050fa2
to
c3c4dcf
Compare
If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.
I've just done another revivew, LMK what you think!
|
||
def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike: | ||
i_min_len = jnp.minimum(i, len(ts) - 1) | ||
return jnp.where(i == len(ts), jnp.inf, ts[i_min_len]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the inf
here -- can you add a test for using this with a backward solve with t0 > t1
? Just to make sure that we're correctly handling that case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For each of the tests (except backprop) in test_adaptive_stepsize_controller.py
I added @pytest.mark.parametrize("backwards", [False, True])
which just swaps the order of t0
and t1
. This also revealed two other problems:
step_ts
andjump_ts
need to be sorted and in particular need to be re-sorted after multiplying them withdirection
inwrap
. I added sorting both inJSW.__init__
andJSW.wrap
.- VBT complains if
t0 >= t1
. I don't think this is necessarily a problem but it could be confusing to some users, so let me know if you want to revisit that design decision and I can try to do something about it.
But other than that it all seems to work perfectly when solving backwards. To be fair the way you dealt with the whole backwards solve business is very clean and makes this work perfectly without any alterations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, great! Then I'm very glad to have these tests :D
For VBT I think I'm happy with either approach -- autoswitching t0
, t1
or just raising an error. No strong feelings on what is the better UX.
keep_step, | ||
next_t0, | ||
next_t1, | ||
_, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think discarding here is correct. We should do the right thing even if we have a doubly-nested JumpStepWrapper(JumpStepWraper(PIDController(...), ...), ...)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've given this a hard think and I think it is nearly impossible to do this perfectly. We can maintain correctness of the DE solution by allowing made_jump
to sometimes be a false positive (e.g. the inner JSW recorded jump_next_step=True
in the previous step, but then the outer JSW further clipped the proposal, so the jump didn't actually happen). That just makes us do one unnecessary VF evaluation at the fanthom jump point, but if I understand correctly the final solution should still be correct. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following on from my main comment below -- perhaps if the this jump-next-step business is handled in integrate.py
, then we can make each stepsize controller just report whether or not they clipped the step, with these wrappers just |
ing things together on whether something was clipped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I handled this all inside the JSW in the new commit and I wrote a long comment about why I think the implementation is correct. But please do double check it.
if step_ts is not None: | ||
# If we stepped to `t1 == step_ts[i_step]` and kept the step, then we | ||
# increment i_step and move on to the next t in step_ts. | ||
step_inc_cond = keep_step & (t1 == _get_t(i_step, step_ts)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I'm comfortable with this ==
between floating point numbers.
More generally speaking I think there could be cases in which the times being passed here do not perfectly align with the times that the adaptive step size controller suggested on the previous step (e.g. because of further wrapping of the step size controller), so I think this kind of logic is wrong anyway. I think you need something more like the jump_ts
branch below, where you just want to snap i_step
to the correct value. (Nothing that the correct value should be determinable statically, we only have state here for efficiency purposes.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I implemented a linear search as you suggest and now use it to determine i_step
and i_jump
.
However, I think that for the rejected step business using linear search is incorrect, because t1
should never be greater than rejected_t = _get_t(i_reject, rejected_buffer)
. So depending on what you think is more appropriate we can either use ==
or jnp.isclose(t1, rejected_t, atol=1e-12)
. I also added an eqxi.error_if
(see below), but I can remove it if you don't think it is necessary. I could also add a parameter that only activates this callback when we are doing tests. Let me know.
i_reject = eqx.error_if(
i_reject,
t1 > rejected_t,
"Jumped over a rejected time. Please report this as a bug.",
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW if you want error_if
here then I'd make it a testing-only branch. During regular user runtime I try to avoid using it for these kinds of asserts as it's actually quite slow.
next_jump_t = _get_t(i_jump, jump_ts) | ||
jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t)) | ||
i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise, if someone else is fiddling with the times then this seems to me like it might be fragile.
Recalling the previous implementation with its inefficient use of searchsorted. I think the robust approach here might be to write something with the same API as that, but whose implementation is just a simple linear search forwards or backwards from the current position (which is a 'hint' about where to start searching).
Most of the time that will just iterate once and be done, as here. But in the edge cases it should now do the right thing.
Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship. |
I am very confused about what the correct value of Suppose there is a jump at t=2. I will present 2 possible scenarios, in both of which I think something goes wrong (although maybe diffeqsolve might correct for the issue in scenario B). I wrote them as if JSW and the controller are separate, but the same holds for just the old PID controller. ====== scenario A =======
====== scenario B =======
Another way of seeing this all is through this:
Hence setting |
So I think the Line 390 in daec89c
I made the decision to handle some of the step-rejection logic in the main So I think this fine? Do double-check my logic though! :p Other than that, one thing I am noticing is that this |
Great, that's exactly the line I was looking for (I must admit I looked in Thinking about it now, the Edit: I already implemented what I mentioned above and wrote the proof in a comment. If you're curious and have extra time (yes I know that's a very far tail event :)) you can find it on my |
c3c4dcf
to
e203b53
Compare
e203b53
to
7325e74
Compare
Hi Patrick! I just pushed a new version of this PR, rebased on top of the most current main. I think I addressed everything you asked me to fix. As it stands this contains 3 commits, contatining:
I left some conversations unresloved. I did try to fix the things mentioned in those, but I am not sure whether what I did was the best way to tackle that so I wanted to hear your opinion. Also the test are failing because pyright doesn't know how to import PS: The linear search I added slows it down compared to the way I wrote it before, but it is still faster than the old implementation with binary search. In particular the times (as obtained by
Additionally, changing the length of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay! I think I really like this.
First of all, I think I'm basically happy with pretty much everything outside of jump_step_wrapper.py
. The changes here are pleasingly simple ^^
For jump_step_wrapper.py
, I think my main question is around whether the rejected-step-buffer should actually be part of this wrapper at all -- since that handles SDEs with any kind of step rejection, which I think is completely orthogonal to clipping steps? (Not sure how I didn't notice this before!) I've also commented on a few other more minor points.
By the way, what did you think of the idea of moving next_made_jump
into _integrate.py
? It doesn't have to be now -- happy for that to be a separate PR -- just checking your thoughts on whether it is a generalisable thing.
Finally: merry Christmas, and a happy new year! :D
```bibtex | ||
@misc{foster2024convergenceadaptiveapproximationsstochastic, | ||
title={On the convergence of adaptive approximations for stochastic differential equations}, | ||
author={James Foster and Andraž Jelinčič}, | ||
year={2024}, | ||
eprint={2311.14201}, | ||
archivePrefix={arXiv}, | ||
primaryClass={math.NA}, | ||
url={https://arxiv.org/abs/2311.14201}, | ||
} | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! I'm really glad to be advertising this paper, it needs to be better-known.
(Nit: whilst the ```bibtex
is fine, the contents are currently all indented one step too far.)
module_meta = type(eqx.Module) | ||
|
||
|
||
class PIDMeta(module_meta): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: this should be private _PIDMeta
. (Also similar for module_meta
, although you could also just inline class _PIDMeta(type(eqx.Module))
)
|
||
|
||
_ControllerState = TypeVar("_ControllerState") | ||
_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You want TypeVar("_Dt0", bound=Optional[RealScalarLike])
here.
-
When you have
TypeVar("T", Foo, Bar)
, then it indicates thatT
must be filled by preciselyFoo
orBar
, and in particular not a subclass. -
Meanwhile
TypeVar("T", bound=Union[Foo, Bar])
indicates that any subclass ofUnion[Foo, Bar]
is acceptable -- in particular this includes bothFoo
(which is a subclass of the union type) andBar
(which is also a subclass of the union type), but also includes any subclass ofFoo
andBar
themselves.
The reason this matters is that RealScalarLike
is itself a union (which if you haven't thought about it before is essentially an anonymous ABC), so by definition no instance can ever have that type! It means that version (1) above can essentially never be useful.
(Okay, I'm lying a little bit: ever-so-technically the static type checker could take a value with a known concrete type T
and pretend that it has type Union[T, S]
, and it actually will do so sometimes... but this is pretty fragile so I never rely on this working in practice.)
at_dtmin = at_dtmin | (prev_dt <= self.dtmin) | ||
keep_step = keep_step | at_dtmin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, does at_dtmin
need to be state? (I'm not sure it ever did.) I think we might just be able to have keep_step = keep_step | (prev_dt <= self.dtmin)
?
The `step_ts` and `jump_ts` are used to force the solver to step to certain times. | ||
They mostly act in the same way, except that when we hit an element of `jump_ts`, | ||
the controller must return `made_jump = True`, so that the diffeqsolve function | ||
knows that the vector field has a discontinuity at that point, in which case it | ||
re-evaluates it right after the jump point. In addition, the | ||
exact time of the jump will be skipped using eqxi.prevbefore and eqxi.nextafter. | ||
So now to the explanation of the two (we will use `step_ts` as an example, but the | ||
same applies to `jump_ts`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I often rewrite parts of docs after merging anyway, so feel free to ignore this for now -- but just a heads-up that this part is discussing a lot of implementation details: made_jump = True
and eqxi.{prevbefore,nextafter}
are not details familiar to most users.
i = jax.lax.while_loop(cond_up, lambda _i: _i + 1, i) | ||
i = jax.lax.while_loop(cond_down, lambda _i: _i - 1, i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have both of these loops? I think we only need a linear search in one direction: to find the next element of ts
to clip to?
(And if we do need a bidirectional search, then given a hint n
it's probably more efficient to search e.g. n / n+1 / n-1 / n+2 / n-2 / ...
etc back and forth?)
|
||
# This is just a logging utility for testing purposes | ||
if self.callback_on_reject is not None: | ||
jax.debug.callback(self.callback_on_reject, keep_step, t1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might suggest making this a pure_callback
or io_callback
, so that it will definitely be called in the right order across steps. JAX doesn't actually offer guarantees about the order in which multiple debug callbacks are called.
See for example how eqx.error_if
works, which does the same thing by requiring a token.
(There is actually jax.debug.callback(..., ordered=True)
, but this works by having JAX sneakily rewriting the jaxpr to thread a dummy argument through as a token so as to order things... and I think that edge cases, so I try to avoid it.)
next_t0, RealScalarLike | ||
), f"type(next_t0) = {type(next_t0)}" | ||
else: | ||
isinstance(next_t0, get_args(RealScalarLike)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line doesn't actually do anything? It just creates a False
or True
and then does nothing. Are you missing an assert
?
if TYPE_CHECKING: # if i don't seperate this out pyright complains | ||
assert isinstance( | ||
next_t0, RealScalarLike | ||
), f"type(next_t0) = {type(next_t0)}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that the string here will never appear as it's inside a static-type-checking-only block.
# Let's prove that the line below is correct. Say the inner controller is | ||
# itself a JumpStepWrapper (JSW) with some inner_jump_ts. Then, given that | ||
# it propsed (next_t0, original_next_t1), there cannot be any jumps in | ||
# inner_jump_ts between next_t0 and original_next_t1. So if the next_t1 | ||
# proposed by the outer JSW is different from the original_next_t1 then | ||
# next_t1 \in (next_t0, original_next_t1) and hence there cannot be a jump | ||
# in inner_jump_ts at next_t1. So the jump_at_next_t1 only depends on | ||
# jump_at_next_t1. | ||
# On the other hand if original_next_t1 == next_t1, then we just take an | ||
# OR of the two. | ||
jump_at_next_t1 = jnp.where( | ||
next_t1 == original_next_t1, | ||
jump_at_original_next_t1, | ||
jump_at_next_t1 | jump_at_original_next_t1, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't think I completely believe this. Can we have the following:
- the PID controller proposes
t1
. - the inner JSW wants to clip to a jump
b < t1
. - the outer JSW wants to clip to a step (not a jump!)
a < b
?
In this case then we will have next_t1 != original_next_t1
, an jump_at_original_next_t1 == True
... but overall we want made_jump == False
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+can we have a test for two tested JSW, including the above scenario? It doesn't need to be a full diffeqsolve, just directly calling adapt_step_size
and checking that we get the right output.
Hi Patrick,
I factored the
jump_ts
andstep_ts
out of thePIDController
intoJumpStepWrapper
(I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:t1-t0 <= prev_dt
(this is checked viaeqx.error_if
), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).next_dt
must be>=prev_dt
.next_dt
must be< t1-t0
.We achieve this in a very simple way here:
diffrax/diffrax/_step_size_controller/jump_step_wrapper.py
Lines 119 to 123 in 78b122a
The next step is to add a parameter
JumpStepWrapper.revisit_rejected_steps
which does what you expect. That will appear in a future commit in this same PR.