Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/nested #709

Merged
merged 9 commits into from
May 24, 2023
Merged

Features/nested #709

merged 9 commits into from
May 24, 2023

Conversation

matthewghgriffiths
Copy link
Collaborator

Updating Factor api to allow nested arguments to factor

See example,

>>> import autofit.graphical as graph 
>>> def func(a, b):
...     a0 = a[0]
...     c = a[1]['c']
...     return a0 * c * b
... 
>>> a, b, c = graph.variables("a, b, c")
>>> func((1, {'c': 2}), 3)
6
>>> factor = graph.Factor(func, [a, {'c': c}], b)
>>> factor({a: 1, b: 3, c:2})
FactorValue(6, {})
>>> factor.numerical_func_jacobian({a: 1, b: 3, c:2})
(FactorValue(6., {}), JacobianVectorProduct((a, c, b) → ∂(FactorValue)ᵀ (a, c, b)))
>>> factor.func_gradient({a: 1, b: 3, c:2})
(FactorValue(6., {}), VariableData({<class 'autofit.mapper.variable.FactorValue'>: 1.0, Variable(a): 5.999999963535174, Variable(c): 2.999999892949745, Variable(b): 1.999999899027216}))

If we end up merging functionality with jax.PyTree this will allow us to pass objects to functions as well. I think the jaxtyping library could also allow closer integration with the declarative PyAutoFit style with jax arrays.

Copy link
Owner

@rhayes777 rhayes777 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably worth writing some unit tests for the nested_* functions checking both standard behaviour and some edge cases.

jac[v0][i][indexes] = jac_v0v_i
key_path = (v0, *ks, indexes)
nested_set(jac, key_path, jac_v0v_i)
# jac[v0][i][indexes] = jac_v0v_i
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented code

Comment on lines 80 to 113
def test_nested_factor():
def func(a, b):
a0 = a[0]
c = a[1]['c']
return a0 * c * b

a, b, c = graph.variables("a, b, c")

f = func((1, {'c': 2}), 3)
values = {a: 1., b: 3., c: 2.}

factor = graph.Factor(func, [a, {'c': c}], b)

assert factor(values) == pytest.approx(f)

fval, grad = factor.func_gradient(values)

assert fval == pytest.approx(f)
assert grad[a] == pytest.approx(6)
assert grad[b] == pytest.approx(2)
assert grad[c] == pytest.approx(3)


factor = graph.Factor(func, (a, {'c': c}), b, vjp=True)

assert factor(values) == pytest.approx(f)

fval, grad = factor.func_gradient(values)

assert fval == pytest.approx(f)
assert grad[a] == pytest.approx(6)
assert grad[b] == pytest.approx(2)
assert grad[c] == pytest.approx(3)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd break this down into two tests

@rhayes777
Copy link
Owner

I've actually added PyTree support for autofit objects on the feature/jax branch. There's some other fairly major changes so we won't be merging for a while.

@matthewghgriffiths
Copy link
Collaborator Author

If we're merging functionality with jax, we might want to make the nested_* functionality align with the functionality in jax.tree_utils ?

@rhayes777
Copy link
Owner

I guess at the moment JAX isn't a requirement whereas it would be if we used that functionality

@matthewghgriffiths
Copy link
Collaborator Author

Yeah I imagine that we'd want a backstop implementation that doesn't rely on jax but matches it's behaviour (might be a good way of testing the functionality as well)

"""
out, *_ = args
if isinstance(out, dict):
for k in out:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to match jax functionality (and to provide stable ordering) may want from k in sorted(out) @rhayes777

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense

@codecov
Copy link

codecov bot commented May 18, 2023

Codecov Report

Merging #709 (32a09d4) into main (dd0a2ec) will increase coverage by 0.08%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##             main     #709      +/-   ##
==========================================
+ Coverage   81.91%   82.00%   +0.08%     
==========================================
  Files         180      180              
  Lines       13219    13264      +45     
==========================================
+ Hits        10829    10877      +48     
+ Misses       2390     2387       -3     

see 6 files with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@matthewghgriffiths
Copy link
Collaborator Author

Have updated library and added more tests. @rhayes777

@matthewghgriffiths matthewghgriffiths merged commit 560ecca into main May 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants