Skip to content

Commit

Permalink
test: update tests for braodcast
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 6, 2023
1 parent e17276f commit 438ef1b
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_prefix_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,20 @@ def test_different_types():
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = {'a': 1, 'b': 2}, defaultdict(int, {'a': 1, 'b': [2, 3]})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)


Expand Down Expand Up @@ -266,30 +274,50 @@ def test_different_metadata():
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore key ordering
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = defaultdict(list, {'a': 1, 'b': 2}), defaultdict(set, {'b': [4, 5], 'a': 3})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore default factory
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = {'a': 1, 'b': 2}, defaultdict(list, {'b': [4, 5], 'a': 3})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = OrderedDict({'b': 5, 'a': 4}), {'a': 1, 'b': [2, 3]}
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = deque([1, 2], maxlen=None), deque([3, [4, 5]], maxlen=3)
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore maxlen
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = FlatCache([None, 1]), FlatCache(1)
Expand Down Expand Up @@ -432,6 +460,10 @@ def test_no_errors():
optree.tree_map_(lambda x, y: None, lhs, rhs)
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)


Expand Down

0 comments on commit 438ef1b

Please sign in to comment.