Skip to content

Commit

Permalink
making nested test compatibility with jax==0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewghgriffiths committed May 17, 2023
1 parent 9960057 commit 32a09d4
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions test_autofit/graphical/functionality/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,31 +158,33 @@ def test_nested_items():
obj2 = [1, (2, 3), [3, {'a': 1, 'b': 2, }]]
obj3 = [1, NTuple(2, 3), [3, {'a': 1, 'b': 2, }]]

jax_flat = tree_util.tree_flatten_with_path(obj1)[0]
af_flat = utils.nested_items(obj2)

for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat):
jkey = jax_path_to_key(jpath)
assert jkey == akey
assert jval == aval
assert (
utils.nested_get(obj2, jkey)
== utils.nested_get(obj1, jkey)
== utils.nested_get(obj2, akey)
== utils.nested_get(obj1, akey)
)

jax_flat = tree_util.tree_flatten_with_path(obj2)[0]
af_flat = utils.nested_items(obj3)
for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat):
jkey = jax_path_to_key(jpath)
assert jkey == akey
assert jval == aval
assert (
utils.nested_get(obj2, jkey)
== utils.nested_get(obj1, jkey)
== utils.nested_get(obj2, akey)
== utils.nested_get(obj1, akey)
== utils.nested_get(obj3, jkey)
== utils.nested_get(obj3, akey)
)
# Need jax version > 0.4
if hasattr(tree_util, "tree_flatten_with_path"):
jax_flat = tree_util.tree_flatten_with_path(obj1)[0]
af_flat = utils.nested_items(obj2)

for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat):
jkey = jax_path_to_key(jpath)
assert jkey == akey
assert jval == aval
assert (
utils.nested_get(obj2, jkey)
== utils.nested_get(obj1, jkey)
== utils.nested_get(obj2, akey)
== utils.nested_get(obj1, akey)
)

jax_flat = tree_util.tree_flatten_with_path(obj2)[0]
af_flat = utils.nested_items(obj3)
for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat):
jkey = jax_path_to_key(jpath)
assert jkey == akey
assert jval == aval
assert (
utils.nested_get(obj2, jkey)
== utils.nested_get(obj1, jkey)
== utils.nested_get(obj2, akey)
== utils.nested_get(obj1, akey)
== utils.nested_get(obj3, jkey)
== utils.nested_get(obj3, akey)
)

0 comments on commit 32a09d4

Please sign in to comment.