Skip to content

Commit

Permalink
test: update tests for braodcast
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 7, 2023
1 parent dd5d9b0 commit 23a82f1
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 34 deletions.
288 changes: 256 additions & 32 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,24 @@ def test_tree_map_with_path():
optree.tree_map_with_path(lambda *xs: tuple(xs), x, y)


def test_tree_broadcast_map():
x = ((1, 2, None), (3, (4, [5]), 6))
y = (([7], None, None), ({'foo': 'bar'}, 9, [10, 11]))
out = optree.tree_broadcast_map(lambda *xs: tuple(xs), x, y)
assert out == (
([(1, 7)], None, None),
({'foo': (3, 'bar')}, ((4, 9), [(5, 9)]), [(6, 10), (6, 11)]),
)

x = ((1, 2, None), (3, (4, [5]), 6))
y = (([7], None, 8), ({'foo': 'bar'}, 9, [10, 11]))
out = optree.tree_broadcast_map(lambda *xs: tuple(xs), x, y)
assert out == (
([(1, 7)], None, None),
({'foo': (3, 'bar')}, ((4, 9), [(5, 9)]), [(6, 10), (6, 11)]),
)


def test_tree_map_none_is_leaf():
x = ((1, 2, None), (3, 4, 5))
y = (([6], None, None), ({'foo': 'bar'}, 7, [8, 9]))
Expand Down Expand Up @@ -566,6 +584,24 @@ def test_tree_map_with_path_none_is_leaf():
)


def test_tree_broadcast_map_none_is_leaf():
x = ((1, 2, None), (3, (4, [5]), 6))
y = (([7], None, None), ({'foo': 'bar'}, 9, [10, 11]))
out = optree.tree_broadcast_map(lambda *xs: tuple(xs), x, y, none_is_leaf=True)
assert out == (
([(1, 7)], (2, None), (None, None)),
({'foo': (3, 'bar')}, ((4, 9), [(5, 9)]), [(6, 10), (6, 11)]),
)

x = ((1, 2, None), (3, (4, [5]), 6))
y = (([7], None, 8), ({'foo': 'bar'}, 9, [10, 11]))
out = optree.tree_broadcast_map(lambda *xs: tuple(xs), x, y, none_is_leaf=True)
assert out == (
([(1, 7)], (2, None), (None, 8)),
({'foo': (3, 'bar')}, ((4, 9), [(5, 9)]), [(6, 10), (6, 11)]),
)


def test_tree_map_key_order():
tree = {'b': 2, 'a': 1, 'c': 3, 'd': None, 'e': 4}
leaves = []
Expand Down Expand Up @@ -900,6 +936,16 @@ def add_leaves(p, x):
assert leaves == [1, 2, 3, None, 4]


def test_tree_replace_nones():
sentinel = object()
assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == {
'a': 1,
'b': sentinel,
'c': (2, sentinel),
}
assert optree.tree_replace_nones(sentinel, None) == sentinel


@parametrize(
tree=TREES,
none_is_leaf=[False, True],
Expand Down Expand Up @@ -1025,57 +1071,235 @@ class MyExtraDict(MyAnotherDict):


def test_tree_broadcast_prefix():
assert optree.tree_broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
assert optree.tree_broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1]
assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3]
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'),
):
optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, (3, 3)]
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [
1,
2,
{'a': 3, 'b': 3, 'c': (None, 3)},
]
optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, (3, 3)]
assert optree.tree_broadcast_prefix(
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
) == [1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
assert optree.tree_broadcast_prefix(
[1, 2, 3],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
none_is_leaf=True,
) == [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]
assert optree.tree_broadcast_prefix(
[1, OrderedDict(b=3, c=4, a=2)],
[(5, 6), {'c': (None, 9), 'a': 7, 'b': 8}],
) == [(1, 1), OrderedDict(b=3, c=(None, 4), a=2)]
assert optree.tree_broadcast_prefix(
[1, OrderedDict(b=3, c=4, a=2)],
[(5, 6), {'c': (None, 9), 'a': 7, 'b': 8}],
none_is_leaf=True,
) == [(1, 1), OrderedDict(b=3, c=(4, 4), a=2)]
assert optree.tree_broadcast_prefix(
[1, {'c': 4, 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=8, c=(None, 9), a=7)],
) == [(1, 1), {'c': (None, 4), 'b': 3, 'a': 2}]
assert optree.tree_broadcast_prefix(
[1, {'c': 4, 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=8, c=(None, 9), a=7)],
none_is_leaf=True,
) == [(1, 1), {'c': (4, 4), 'b': 3, 'a': 2}]


def test_broadcast_prefix():
assert optree.broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
assert optree.broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1]
assert optree.broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3]
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'),
):
optree.broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, 3, 3]
assert optree.broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [
1,
2,
3,
3,
3,
]
optree.broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
assert optree.broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, 3, 3]
assert optree.broadcast_prefix(
[1, 2, 3],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
) == [1, 2, 3, 3, 3]
assert optree.broadcast_prefix(
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
none_is_leaf=True,
) == [1, 2, 3, 3, 3, 3]
assert optree.broadcast_prefix(
[1, OrderedDict(b=3, c=4, a=2)],
[(5, 6), {'c': (None, 9), 'a': 7, 'b': 8}],
) == [1, 1, 3, 4, 2]
assert optree.broadcast_prefix(
[1, OrderedDict(b=3, c=4, a=2)],
[(5, 6), {'c': (None, 9), 'a': 7, 'b': 8}],
none_is_leaf=True,
) == [1, 1, 3, 4, 4, 2]
assert optree.broadcast_prefix(
[1, {'c': 4, 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=8, c=(None, 9), a=7)],
) == [1, 1, 2, 3, 4]
assert optree.broadcast_prefix(
[1, {'c': 4, 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=8, c=(None, 9), a=7)],
none_is_leaf=True,
) == [1, 1, 2, 3, 4, 4]


def test_tree_replace_nones():
sentinel = object()
assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == {
'a': 1,
'b': sentinel,
'c': (2, sentinel),
}
assert optree.tree_replace_nones(sentinel, None) == sentinel
def test_tree_broadcast_common():
assert optree.tree_broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4])
assert optree.tree_broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6])
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4.'),
):
optree.tree_broadcast_common([1, 2, 3], [1, 2, 3, 4])
assert optree.tree_broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == (
[1, 2, (3, 3)],
[4, 5, (6, 7)],
)
assert optree.tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == (
[1, (2, 3), (4, 4)],
[5, (6, 6), (7, 8)],
)
assert optree.tree_broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, (2, 3), {'a': 4, 'b': 4, 'c': (None, 4)}],
[5, (6, 6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (None, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
) == (
[(1, 1), OrderedDict(b=4, c=(None, 5), a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
none_is_leaf=True,
) == (
[(1, 1), OrderedDict(b=4, c=(5, 5), a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}],
)
assert optree.tree_broadcast_common(
[1, {'c': (None, 4), 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
) == (
[(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}],
[(5, 6), OrderedDict(b=9, c=(None, 0), a=(7, 8))],
)
assert optree.tree_broadcast_common(
[1, {'b': 3, 'a': 2, 'c': (None, 4)}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
none_is_leaf=True,
) == (
[(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}],
[(5, 6), OrderedDict(b=9, c=(0, 0), a=(7, 8))],
)


def test_broadcast_common():
assert optree.broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4])
assert optree.broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6])
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4.'),
):
optree.broadcast_common([1, 2, 3], [1, 2, 3, 4])
assert optree.broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == (
[1, 2, 3, 3],
[4, 5, 6, 7],
)
assert optree.broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == (
[1, 2, 3, 4, 4],
[5, 6, 6, 7, 8],
)
assert optree.broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, 2, 3, 4, 4, 4],
[5, 6, 6, 7, 8, 9],
)
assert optree.broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, 2, 3, 4, 4, 4, 4],
[5, 6, 6, 7, 8, None, 9],
)
assert optree.broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, 2, 3, 4, 4, 4],
[5, 6, 6, 7, 8, 9],
)
assert optree.broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, 2, 3, 4, 4, 4, 4],
[5, 6, 6, 7, 8, None, 9],
)
assert optree.broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
) == (
[1, 1, 4, 5, 2, 3],
[6, 7, 9, 0, 8, 8],
)
assert optree.broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
none_is_leaf=True,
) == (
[1, 1, 4, 5, 5, 2, 3],
[6, 7, 9, None, 0, 8, 8],
)
assert optree.broadcast_common(
[1, {'c': (None, 4), 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
) == (
[1, 1, 2, 2, 3, 4],
[5, 6, 7, 8, 9, 0],
)
assert optree.broadcast_common(
[1, {'b': 3, 'a': 2, 'c': (None, 4)}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
none_is_leaf=True,
) == (
[1, 1, 2, 2, 3, None, 4],
[5, 6, 7, 8, 9, 0, 0],
)


def test_tree_reduce():
Expand Down
6 changes: 4 additions & 2 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_treespec_compose_children(tree, inner_tree, none_is_leaf, namespace):
inner_treespec.num_nodes * treespec.num_leaves
)
assert composed_treespec.num_nodes == expected_nodes
leaves = [1] * expected_leaves
leaves = list(range(expected_leaves))
composed = optree.tree_unflatten(composed_treespec, leaves)
assert leaves == optree.tree_leaves(
composed,
Expand Down Expand Up @@ -569,7 +569,9 @@ def test_treespec_num_nodes(tree, none_is_leaf, namespace):
while stack:
spec = stack.pop()
nodes.append(spec)
stack.extend(reversed(spec.children()))
children = spec.children()
stack.extend(reversed(children))
assert spec.num_nodes == sum(child.num_nodes for child in children) + 1
assert treespec.num_nodes == len(nodes)


Expand Down

0 comments on commit 23a82f1

Please sign in to comment.