Description of your problem
My last-minute PR #3634 refactored the ODE API towards more explicit shapes. Unfortunately (and not detected by my benchmark), it also slowed down NUTS by about 50x compared to Demetri's original implementation.
When printing the parameters of every forward simulation, I made three interesting observations:

1. In the refactored implementation, NUTS has a larger tree size & depth --> more forward simulations (see the sampler-stats sketch after this list).
2. It looks like no parameter set is simulated twice --> sensitivity caching works fine(?)
3. The parameters are printed with 8 significant digits before the refactor and 16 after it.

So my hypothesis is that by consistently using floatX, the refactored implementation runs at double precision; NUTS therefore takes more accurate leapfrog steps, builds deeper trees, and spends more gradient evaluations per iteration, which drives up the runtime.
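The precision difference behind observation 3 is easy to reproduce in isolation. A minimal sketch (the value of `theta` is illustrative, not taken from the model): a float32 prints with roughly 8 significant digits, a float64 with roughly 16, matching the parameter printouts before and after the refactor.

```python
import numpy as np

# float32 prints with ~8 significant digits, float64 with ~16 --
# matching the parameter printouts before vs. after the refactor.
theta = 0.123456789123456789  # illustrative value, not from the model
print(np.float32(theta))  # -> 0.12345679          (~8 digits, pre-refactor)
print(np.float64(theta))  # -> 0.12345678912345678 (~16 digits, post-refactor)
```

If this is indeed the cause, re-running the benchmark with THEANO_FLAGS=floatX=float32 should bring back the pre-refactor behavior.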
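Observation 1 can be quantified from the statistics that PyMC3's NUTS sampler records. A sketch, assuming a `trace` returned by `pm.sample()` (the helper name is mine):

```python
import numpy as np

def summarize_nuts_trees(trace):
    """Summarize NUTS tree statistics from a PyMC3 trace (hypothetical helper)."""
    depth = np.asarray(trace.get_sampler_stats('depth'))
    tree_size = np.asarray(trace.get_sampler_stats('tree_size'))
    print('mean tree depth:', depth.mean())
    print('mean tree size: ', tree_size.mean())
```

Comparing these numbers between the original and the refactored implementation should show how much of the ~50x slowdown is explained by tree growth alone.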
For reference, the approximate seconds/it/chain:
Also see the discussion over at #3634.
Versions and main components