Skip to content

Commit

Permalink
test: add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed May 1, 2024
1 parent 178e9fe commit 9fa7fbf
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 71 deletions.
4 changes: 2 additions & 2 deletions optree/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def __new__( # type: ignore[misc]
return path_entry_type(entry, type, kind)

# The __init__() method will be called if the returned instance is a subtype of AutoEntry.
# We should return an uninitialized instance. But we will never reach this point.
raise NotImplementedError('Unreachable code.')
# We should return an uninitialized instance.
return super().__new__(path_entry_type)

Check warning on line 190 in optree/accessor.py

View check run for this annotation

Codecov / codecov/patch

optree/accessor.py#L190

Added line #L190 was not covered by tests


class GetItemEntry(PyTreeEntry):
Expand Down
167 changes: 98 additions & 69 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import dataclasses
import itertools
import re
from collections import UserDict, UserList
from collections import OrderedDict, UserDict, UserList, defaultdict, deque
from typing import Any, NamedTuple

import pytest
Expand All @@ -27,9 +27,11 @@
from helpers import TREE_ACCESSORS, SysFloatInfoType, parametrize


def assert_equal_type_and_value(a, b):
assert type(a) == type(b)
assert a == b
def assert_equal_type_and_value(actual, expected, expected_type=None):
if expected_type is None:
expected_type = type(expected)
assert type(actual) == expected_type
assert actual == expected


def test_pytree_accessor_new():
Expand Down Expand Up @@ -309,68 +311,95 @@ class MySequence(UserList):
class MyObject:
pass

entry = optree.AutoEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is SysFloatInfoType
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.StructSequenceEntry

entry = optree.AutoEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is CustomTuple
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.NamedTupleEntry

entry = optree.AutoEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is CustomDataclass
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.DataclassEntry

entry = optree.AutoEntry('foo', dict, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is dict
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.MappingEntry

entry = optree.AutoEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM)
assert entry.entry == 'foo'
assert entry.type is MyMapping
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.MappingEntry

entry = optree.AutoEntry(0, tuple, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is tuple
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, list, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is list
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, str, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is str
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, bytes, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is bytes
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, MySequence, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is MySequence
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.SequenceEntry

entry = optree.AutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM)
assert entry.entry == 0
assert entry.type is MyObject
assert entry.kind == optree.PyTreeKind.CUSTOM
assert type(entry) is optree.FlattenedEntry
assert_equal_type_and_value(
optree.AutoEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM),
optree.StructSequenceEntry(0, SysFloatInfoType, optree.PyTreeKind.CUSTOM),
expected_type=optree.StructSequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM),
optree.NamedTupleEntry(0, CustomTuple, optree.PyTreeKind.CUSTOM),
expected_type=optree.NamedTupleEntry,
)

assert_equal_type_and_value(
optree.AutoEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM),
optree.DataclassEntry('foo', CustomDataclass, optree.PyTreeKind.CUSTOM),
expected_type=optree.DataclassEntry,
)

assert_equal_type_and_value(
optree.AutoEntry('foo', dict, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', dict, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

assert_equal_type_and_value(
optree.AutoEntry('foo', OrderedDict, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', OrderedDict, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

assert_equal_type_and_value(
optree.AutoEntry('foo', defaultdict, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', defaultdict, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

assert_equal_type_and_value(
optree.AutoEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM),
optree.MappingEntry('foo', MyMapping, optree.PyTreeKind.CUSTOM),
expected_type=optree.MappingEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, tuple, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, tuple, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, list, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, list, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, deque, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, deque, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, str, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, str, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, bytes, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, bytes, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, MySequence, optree.PyTreeKind.CUSTOM),
optree.SequenceEntry(0, MySequence, optree.PyTreeKind.CUSTOM),
expected_type=optree.SequenceEntry,
)

assert_equal_type_and_value(
optree.AutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
optree.FlattenedEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
expected_type=optree.FlattenedEntry,
)

class SubclassedAutoEntry(optree.AutoEntry):
pass

assert_equal_type_and_value(
SubclassedAutoEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
optree.PyTreeEntry(0, MyObject, optree.PyTreeKind.CUSTOM),
expected_type=SubclassedAutoEntry,
)

0 comments on commit 9fa7fbf

Please sign in to comment.