
feat: model building (and gradients thereof) with jax as the default backend #1912

Draft
phinate wants to merge 15 commits into main
Conversation

@phinate (Contributor) commented Jul 4, 2022

Description

Building on work by @lukasheinrich in #1676 to address #882, this PR continues down the debugging rabbit hole to make pyhf model construction properly differentiable for all modifier types.

Right now, there is some jax-only syntax that I plan to refactor for all tensor backends if we get this working, e.g. array = array.at[index].set(value), which is the jax equivalent of array[index] = value since jax arrays are immutable.
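
For concreteness, here are the two idioms side by side (a standalone snippet, not pyhf code):

    import jax.numpy as jnp
    import numpy as np

    # numpy arrays are mutable, so in-place assignment works
    a = np.zeros(3)
    a[1] = 5.0

    # jax arrays are immutable: .at[index].set(value) returns an updated copy
    b = jnp.zeros(3)
    b = b.at[1].set(5.0)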

Another notable addition is a change to the jax concatenate implementation referenced in #1655: the new version is many times slower, but necessary to match numpy's behaviour. In my experience it doesn't produce a user-facing slowdown, since the arrays involved are small and the operation is infrequent.

One roadblock right now, as discussed offline with @kratsg, concerns the following logic in the parameter "finalize" step:

    for paramset_requirement in paramset_requirements:
        # undefined: the modifier does not support configuring that property
        v = paramset_requirement.get(k, 'undefined')
        combined_paramset.setdefault(k, set()).add(v)

When building models with arrays in them, the v here can contain jax arrays, which are not hashable. I tried patching in a hashable construct, but this also failed: when taking gradients, these values become JVPTracer objects that I don't know a good way to hash (I was using jax.numpy.array_str for that).
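
The failure is easy to reproduce outside pyhf (a minimal standalone sketch):

    import jax.numpy as jnp

    v = jnp.array([1.0, 2.0])
    s = set()
    try:
        s.add(v)
    except TypeError as err:
        # jax arrays don't define a hash, so set membership fails;
        # the JVPTracer objects seen under jax.grad hit the same error
        print(err)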

To add some optimism, though: when I bypassed this logic by (incorrectly) using a list here, I seemed to be able to run the new test that this PR adds, which really would make model construction differentiable for all modifiers. From my standpoint that would remove the need to work on #1894, my most recent attempt at bypassing the model-building logic for the sake of gradient calculations, though that PR may have other utility.
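
To sketch the kind of end-to-end check meant by "the new test" (a hypothetical illustration of the goal, not the actual test in this PR; on main it fails at exactly the hashability step described above):

    import jax
    import jax.numpy as jnp
    import pyhf

    pyhf.set_backend("jax")

    def pipeline(signal):
        # build the model from traced values, then evaluate a scalar
        model = pyhf.simplemodels.uncorrelated_background(
            signal=list(signal), bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
        )
        pars = jnp.asarray(model.config.suggested_init())
        return jnp.sum(model.expected_data(pars))

    # differentiate through model construction itself
    grad = jax.grad(pipeline)(jnp.array([5.0, 10.0]))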

So: if we can get that logic rewritten with collections that don't require hashable elements, then we could be good to go here! :)
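
For reference, the list-based bypass described above amounts to something like the following (a rough sketch of the hack, not the actual diff in this PR, and not a proposed fix, since it drops the deduplication the set provided):

    for paramset_requirement in paramset_requirements:
        # undefined: the modifier does not support configuring that property
        v = paramset_requirement.get(k, 'undefined')
        # a list accepts unhashable values (jax arrays, JVPTracer objects),
        # but no longer deduplicates repeated requirements
        combined_paramset.setdefault(k, []).append(v)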

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR

@phinate changed the title from "Feat: model building (and gradients thereof) with jax as the default backend" to "feat: model building (and gradients thereof) with jax as the default backend" on Jul 4, 2022
@phinate (Contributor, Author) commented Jul 5, 2022

@kratsg should I wait on your suggested refactor of the above, or do you think we could hack in a short-term solution?

@kratsg (Contributor) commented Aug 31, 2022

@phinate do you mind rebasing and fixing the conflicts in staterror? We had to fix a bug/regression there.

@phinate (Contributor, Author) commented Sep 1, 2022

@kratsg I looked at the changes, but the only actual code change appeared to be the removal of self.__staterror_uncrt. I've left it in for now, but can you confirm whether that was part of the bug?

@kratsg (Contributor) commented Sep 1, 2022

> @kratsg I looked at the changes, but the only actual code change appeared to be the removal of self.__staterror_uncrt. I've left it in for now, but can you confirm whether that was part of the bug?

It is. Remove it.

@phinate (Contributor, Author) commented Sep 1, 2022

> It is. Remove it.

Done! This bit of code was written by Lukas, so I just wanted to double-check.

@kratsg (Contributor) commented Sep 1, 2022

> Done! This bit of code was written by Lukas, so I just wanted to double-check.

It was buggy. See #1965.

@matthewfeickert changed the base branch from master to main on September 21, 2022