Profiling #34
Reproducing John Chodera's comment from the PR:
I think those are the 3 parts that are going to need the most attention. Right now, the pair computation (currently limited just to LJ, as that is really the only function implemented) and the neighbor/pair list functions are already coded up with the appropriate functions jitted for performance. A lot of the routines will not benefit from jitting (or in many cases can't really be jitted). I had done some "offline" (i.e., outside of this PR) benchmarking of the Langevin integrator, looking at different approaches to how we jit the routine (e.g., there are a few consistent patterns that can be coded up to minimize the overhead of jit). Once this is all merged, I was going to create a few benchmarks of different approaches in an issue so we can figure out, discuss, and brainstorm the most efficient implementation. Preliminarily, it seems that jax.numpy itself is already quite efficient even without jitting the integration, which I suspect is because the amount of "work" done to integrate is much smaller than that of, say, computing the potential. That is, any speed gains are offset by overhead and by moving data from CPU to GPU.
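A minimal sketch of the kind of benchmark described above, comparing a jitted and an un-jitted Langevin-style update step in JAX. This is not the repository's actual integrator; the step function, its parameters, and the system size are illustrative assumptions used only to show how jit overhead can be measured against the amount of work per step.

```python
import time
import jax
import jax.numpy as jnp

def langevin_step(x, v, f, key, dt=1e-3, gamma=1.0, kT=1.0, mass=1.0):
    """One simplified Langevin update (illustrative only, not the repo's integrator)."""
    a = jnp.exp(-gamma * dt)
    b = jnp.sqrt(kT * (1.0 - a**2) / mass)
    v = v + 0.5 * dt * f / mass
    x = x + 0.5 * dt * v
    v = a * v + b * jax.random.normal(key, v.shape)
    x = x + 0.5 * dt * v
    v = v + 0.5 * dt * f / mass
    return x, v

langevin_step_jit = jax.jit(langevin_step)

def bench(fn, x, v, f, key, n=1000):
    # Warm-up call: triggers compilation for the jitted variant.
    xs, vs = fn(x, v, f, key)
    xs.block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n):
        x, v = fn(x, v, f, key)
    x.block_until_ready()
    return (time.perf_counter() - t0) / n

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 3))  # assumed system size
v = jnp.zeros_like(x)
f = jnp.zeros_like(x)

print("un-jitted step:", bench(langevin_step, x, v, f, key))
print("jitted step:   ", bench(langevin_step_jit, x, v, f, key))
```

Because the per-step arithmetic is small relative to a potential evaluation, timings like these are where the dispatch and host-device transfer overhead mentioned in the comment would show up.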
This issue is meant to serve as a placeholder with regard to optimization and profiling of the code and how this may be impacted by our design choices (e.g., the use of deepcopy of SamplerState in places).
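As one way to quantify that particular design choice, a micro-benchmark along these lines could compare deepcopying a state container against a shallow functional update. `State` here is a hypothetical stand-in, since SamplerState's actual fields are not shown in this issue.

```python
import copy
import time
from dataclasses import dataclass, replace

import jax.numpy as jnp

@dataclass
class State:
    # Hypothetical stand-in for SamplerState.
    positions: jnp.ndarray
    velocities: jnp.ndarray

s = State(positions=jnp.zeros((10000, 3)), velocities=jnp.zeros((10000, 3)))

def bench(fn, n=200):
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n

# deepcopy duplicates the whole object graph (including array buffers) each call...
print("deepcopy:", bench(lambda: copy.deepcopy(s)))
# ...whereas a functional update reuses the immutable JAX arrays.
print("replace: ", bench(lambda: replace(s, positions=s.positions + 0.0)))
```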