feat: model building (and gradients thereof) with jax as the default backend #1912
Conversation
@kratsg should I wait on your suggested refactor of the above? Or do you think we could hack in a short-term solution?
@phinate do you mind rebasing and fixing the conflicts in staterror? We had to fix a bug/regression there.
@kratsg I looked at the changes, but the only actual code change looked to be the removal of …
It is. Remove it.
Done! This bit of code was written by Lukas, so I just wanted to double-check.
It was buggy. See #1965.
Description
Building on work by @lukasheinrich in #1676 to address #882, this PR attempts to keep going down the debugging rabbit hole to make pyhf model construction properly differentiable for all modifier types.
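For context, here is a rough standalone sketch of the kind of end-to-end gradient this PR is working toward, using pyhf's existing `simplemodels` API. This is illustrative only, not the exact test added in this PR, and it isn't expected to run on `main` until this work lands:

```python
import jax
import pyhf

pyhf.set_backend("jax")


def pipeline(signal_yield):
    # build a model whose spec depends on a traced input value
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[signal_yield], bkg=[50.0], bkg_uncertainty=[7.0]
    )
    data = pyhf.tensorlib.astensor([55.0] + model.config.auxdata)
    pars = pyhf.tensorlib.astensor(model.config.suggested_init())
    return model.logpdf(pars, data)[0]


# differentiate the log-likelihood w.r.t. a model-construction input
grad = jax.grad(pipeline)(12.0)
```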
Right now, there is some jax-only syntax that I plan to refactor for all tensor backends if we get this working, e.g. `array = array.at[index].set(value)`, which is the jax equivalent of `array[index] = value`, since jax arrays are immutable.

Another notable addition is the change to the jax concatenate implementation referenced in #1655, replacing it with something many factors slower; this is necessary to match numpy's behaviour. It doesn't produce a user-facing slowdown in my experience, since the arrays involved are small and the operation is infrequent.
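For illustration, a minimal standalone sketch of the immutable-update pattern above (plain jax, not pyhf code):

```python
import jax.numpy as jnp
import numpy as np

# numpy arrays support in-place assignment ...
np_array = np.zeros(3)
np_array[1] = 5.0

# ... but jax arrays are immutable, so the same assignment raises a TypeError:
jax_array = jnp.zeros(3)
# jax_array[1] = 5.0  # TypeError: jax arrays do not support item assignment

# the functional equivalent returns a new array instead of mutating in place
jax_array = jax_array.at[1].set(5.0)
```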
One roadblock right now, as discussed offline with @kratsg, concerns the following logic in the parameter "finalize" step:

`pyhf/src/pyhf/parameters/utils.py`, lines 37 to 40 at `acde7f4`
When building models with arrays in them, the `v` here can contain jax arrays, which are not hashable. I tried patching in a hashable construct, but this also failed: when taking gradients, these values become `JVPTracer` objects that I don't know a good way to hash (I was using `jax.numpy.array_str` at the time).

To add some optimism though: when I bypassed this logic by incorrectly using a list here, I seemed to be able to run the new test that this PR adds, which really would make model construction differentiable for all modifiers. That would mitigate the need, from my standpoint, to work on #1894, which was my most recent attempt at bypassing the model-building logic for the sake of gradient calculations, though it may have other utility.
So: if we can get that logic re-written with collections that don't require hashable elements, then we could be good to go here! :)
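To make the constraint concrete, here is a small self-contained sketch (not pyhf's actual code) of the hashability problem, together with a hypothetical hash-free deduplication along the lines suggested above:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
try:
    hash(x)  # jax arrays (and the JVPTracers they become under jax.grad) are unhashable
except TypeError as err:
    print(err)  # e.g. "unhashable type: ..."


# hypothetical alternative: deduplicate by scanning a plain list instead of
# using a set or dict keys, so traced values pass through without hashing
def dedupe(items):
    unique = []
    for item in items:
        if not any(item is seen for seen in unique):  # identity check, no hash
            unique.append(item)
    return unique
```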
Checklist Before Requesting Reviewer
Before Merging
For the PR Assignees: