
Profiling #34

Open · chrisiacovella opened this issue Feb 22, 2024 · 2 comments

Comments

@chrisiacovella (Member)

This issue is meant to serve as a placeholder for optimization and profiling of the code, and for how performance may be impacted by our design choices (e.g., the use of deepcopy of SamplerState in places).

@chrisiacovella (Member, Author)

Reproducing John Chodera's comment from the PR:

While "premature optimization is the root of all evil",
we should think ahead to the very small number of key methods we will need to ensure can be efficiently jitted to perform well on GPU hardware and pay attention to how we structure these methods.

My current understanding is that we will really want the following methods to be very fast:

  1. The inner loop of MD integration
  2. The computation of the potential energy gradient (from grad applied to the potential energy function) within this MD step loop
  3. The pairlist/displacement vector (re)generation that is also called within the MD step loop

It's worth making sure we clear any irregular computation out of these loops, and potentially building some profiling harnesses so we can monitor what kernel execution looks like in these inner loops and make sure we are keeping the GPU busy.

We can always systematically improve this later, but the key idea here is that we want to make sure our API and code structure will permit us to work hard on those three methods, since we expect the overwhelmingly large fraction of our compute time to be spent there.

I think those are the three parts that will need the most attention.
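To make items 1 and 2 concrete, here is a minimal sketch of the pattern I have in mind: the forces come from jax.grad of the potential, and the whole inner loop runs on-device as a single jitted jax.lax.fori_loop. Everything in this sketch (the harmonic stand-in potential, the constants, the function names) is illustrative, not our actual code:

```python
import jax
import jax.numpy as jnp

dt = 1.0e-3   # illustrative time step
mass = 1.0    # illustrative particle mass

def potential(positions):
    # Harmonic well as a stand-in; the real code would call the LJ potential.
    return 0.5 * jnp.sum(positions ** 2)

# Item 2: the potential energy gradient comes straight from autodiff.
grad_fn = jax.grad(potential)

def md_step(_, state):
    # One velocity-Verlet step; plain jnp math like this jits cleanly.
    positions, velocities = state
    velocities = velocities + 0.5 * dt * (-grad_fn(positions)) / mass
    positions = positions + dt * velocities
    velocities = velocities + 0.5 * dt * (-grad_fn(positions)) / mass
    return positions, velocities

@jax.jit
def run_md(positions, velocities, n_steps):
    # Item 1: fori_loop keeps the entire inner loop in one compiled program,
    # so the host dispatches once per trajectory segment, not once per step.
    return jax.lax.fori_loop(0, n_steps, md_step, (positions, velocities))

positions = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
positions, velocities = run_md(positions, jnp.zeros_like(positions), 1000)
```

Structured this way, the hot paths stay inside a single compiled region, which is exactly the "keep the GPU busy" property described in the quoted comment.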

Right now, the pair computation (currently limited to Lennard-Jones, as that is the only potential implemented) and the neighbor/pair list functions are already written with the appropriate functions jitted for performance. Many of the other routines will not benefit from jitting (or in many cases cannot really be jitted).
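For reference, the basic shape of that pattern looks something like the sketch below (again illustrative, not our actual implementation): the pair list is a fixed-shape integer array, so regenerating it changes values rather than shapes, and the jitted energy function compiles only once.

```python
import jax
import jax.numpy as jnp

@jax.jit
def lj_energy(positions, pair_indices, sigma=1.0, epsilon=1.0):
    # pair_indices has fixed shape (n_pairs, 2); in practice, padding/masking
    # would keep that shape static as the neighbor list is regenerated.
    r_ij = positions[pair_indices[:, 0]] - positions[pair_indices[:, 1]]
    r2 = jnp.sum(r_ij ** 2, axis=-1)
    sr6 = (sigma ** 2 / r2) ** 3                      # (sigma / r)^6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))  # 4*eps*[(s/r)^12 - (s/r)^6]

# The forces for the MD loop then come for free from autodiff:
# forces = -jax.grad(lj_energy)(positions, pair_indices)
```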

I had done some "offline" (i.e., outside of this PR) benchmarking of the Langevin integrator, looking at different approaches to how we jit the routine (e.g., there are a few consistent patterns that can be used to minimize jit overhead). Once this is all merged, I was going to create a few benchmarks of different approaches in an issue so we can figure out, discuss, and brainstorm the most efficient implementation.

Preliminarily, it seems that jax.numpy itself is already quite efficient even without jitting the integration, which I suspect is because the amount of "work" done to integrate is small compared to, say, computing the potential. That is, any speed gains are offset by jit overhead and by moving data between the CPU and GPU.
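When I write those benchmarks up, the harness will look roughly like this (a sketch, with made-up sizes and a toy update standing in for the real Langevin step): the important details are excluding the first call, which includes compilation, and calling block_until_ready so JAX's asynchronous dispatch doesn't make the timings meaningless.

```python
import time
import jax
import jax.numpy as jnp

def bare_step(x, v):
    # Toy update standing in for one Langevin integrator step.
    v = 0.99 * v - 1.0e-3 * x
    return x + 1.0e-3 * v, v

jitted_step = jax.jit(bare_step)

def benchmark(step_fn, n_steps=1000):
    x = jax.random.normal(jax.random.PRNGKey(0), (10_000, 3))
    v = jnp.zeros_like(x)
    x, v = step_fn(x, v)       # warm-up: triggers compilation for the jitted case
    x.block_until_ready()
    start = time.perf_counter()
    for _ in range(n_steps):
        x, v = step_fn(x, v)
    x.block_until_ready()      # flush the async dispatch queue before stopping the clock
    return time.perf_counter() - start

print("unjitted:", benchmark(bare_step))
print("jitted:  ", benchmark(jitted_step))
```

For the kernel-level monitoring mentioned in the quoted comment, the timed region can also be wrapped in `with jax.profiler.trace(log_dir):` and inspected in TensorBoard, which would show whether the GPU actually stays busy between steps.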

