Skip to content

Commit

Permalink
test: add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed May 1, 2024
1 parent 178e9fe commit 3f731e7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 55 deletions.
4 changes: 2 additions & 2 deletions optree/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def __new__( # type: ignore[misc]
return path_entry_type(entry, type, kind)

# The __init__() method will be called if the returned instance is a subtype of AutoEntry.
# We should return an uninitialized instance. But we will never reach this point.
raise NotImplementedError('Unreachable code.')
# We should return an uninitialized instance.
return super().__new__(path_entry_type)


class GetItemEntry(PyTreeEntry):
Expand Down
117 changes: 64 additions & 53 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from helpers import TREE_ACCESSORS, SysFloatInfoType, parametrize


def assert_equal_type_and_value(a, b):
assert type(a) == type(b)
assert a == b
def assert_equal_type_and_value(actual, expected, expected_type=None):
if expected_type is None:
expected_type = type(expected)
assert type(actual) == expected_type
assert actual == expected


def test_pytree_accessor_new():
Expand Down Expand Up @@ -309,68 +311,77 @@ class MySequence(UserList):
class MyObject:
pass

entry = optree.AutoEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is SysFloatInfoType
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.StructSequenceEntry
assert_equal_type_and_value(
optree.AutoEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM),
optree.StructSequenceEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM),
expected_type=optree.StructSequenceEntry,
)

entry = optree.AutoEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is CustomTuple
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.NamedTupleEntry
assert_equal_type_and_value(
optree.AutoEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM),
optree.NamedTupleEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM),
expected_type=optree.NamedTupleEntry,
)

entry = optree.AutoEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is CustomDataclass
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.DataclassEntry
assert_equal_type_and_value(
optree.AutoEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM),
optree.DataclassEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM),
expected_type=optree.DataclassEntry,
)

entry = optree.AutoEntry('foo', dict, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is dict
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.MappingEntry
assert_equal_type_and_value(
optree.AutoEntry('foo', dict, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', dict, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

entry = optree.AutoEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is MyMapping
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.MappingEntry
assert_equal_type_and_value(
optree.AutoEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

entry = optree.AutoEntry(0, tuple, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is tuple
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, list, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is list
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry
assert_equal_type_and_value(
optree.AutoEntry(0, list, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, list, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

entry = optree.AutoEntry(0, str, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is str
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry
assert_equal_type_and_value(
optree.AutoEntry(0, str, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, str, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

entry = optree.AutoEntry(0, bytes, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is bytes
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry
assert_equal_type_and_value(
optree.AutoEntry(0, bytes, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, bytes, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

entry = optree.AutoEntry(0, MySequence, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is MySequence
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry
assert_equal_type_and_value(
optree.AutoEntry(0, MySequence, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, MySequence, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

entry = optree.AutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is MyObject
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.FlattenedEntry
assert_equal_type_and_value(
optree.AutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
optree.FlattenedEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
expected_type=optree.FlattenedEntry,
)

class SubclassedAutoEntry(optree.AutoEntry):
pass

assert_equal_type_and_value(
SubclassedAutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
optree.PyTreeEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
expected_type=SubclassedAutoEntry,
)

0 comments on commit 3f731e7

Please sign in to comment.