-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features/nested #709
Features/nested #709
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's probably worth writing some unit tests for the nested_* functions checking both standard behaviour and some edge cases.
jac[v0][i][indexes] = jac_v0v_i | ||
key_path = (v0, *ks, indexes) | ||
nested_set(jac, key_path, jac_v0v_i) | ||
# jac[v0][i][indexes] = jac_v0v_i |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Commented code
def test_nested_factor(): | ||
def func(a, b): | ||
a0 = a[0] | ||
c = a[1]['c'] | ||
return a0 * c * b | ||
|
||
a, b, c = graph.variables("a, b, c") | ||
|
||
f = func((1, {'c': 2}), 3) | ||
values = {a: 1., b: 3., c: 2.} | ||
|
||
factor = graph.Factor(func, [a, {'c': c}], b) | ||
|
||
assert factor(values) == pytest.approx(f) | ||
|
||
fval, grad = factor.func_gradient(values) | ||
|
||
assert fval == pytest.approx(f) | ||
assert grad[a] == pytest.approx(6) | ||
assert grad[b] == pytest.approx(2) | ||
assert grad[c] == pytest.approx(3) | ||
|
||
|
||
factor = graph.Factor(func, (a, {'c': c}), b, vjp=True) | ||
|
||
assert factor(values) == pytest.approx(f) | ||
|
||
fval, grad = factor.func_gradient(values) | ||
|
||
assert fval == pytest.approx(f) | ||
assert grad[a] == pytest.approx(6) | ||
assert grad[b] == pytest.approx(2) | ||
assert grad[c] == pytest.approx(3) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd break this down into two tests
I've actually added PyTree support for autofit objects on the feature/jax branch. There's some other fairly major changes so we won't be merging for a while. |
If we're merging functionality with jax, we might want to make the |
I guess at the moment JAX isn't a requirement whereas it would be if we used that functionality |
Yeah I imagine that we'd want a backstop implementation that doesn't rely on jax but matches it's behaviour (might be a good way of testing the functionality as well) |
autofit/graphical/utils.py
Outdated
""" | ||
out, *_ = args | ||
if isinstance(out, dict): | ||
for k in out: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to match jax
functionality (and to provide stable ordering) may want from k in sorted(out)
@rhayes777
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense
Codecov Report
@@ Coverage Diff @@
## main #709 +/- ##
==========================================
+ Coverage 81.91% 82.00% +0.08%
==========================================
Files 180 180
Lines 13219 13264 +45
==========================================
+ Hits 10829 10877 +48
+ Misses 2390 2387 -3 see 6 files with indirect coverage changes 📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
Have updated library and added more tests. @rhayes777 |
Updating
Factor
api to allow nested arguments to factorSee example,
If we end up merging functionality with
jax.PyTree
this will allow us to pass objects to functions as well. I think the jaxtyping library could also allow closer integration with the declarative PyAutoFit style with jax arrays.