Skip to content

Commit

Permalink
refactor: update definition of structseq sentinel
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 6, 2023
1 parent 05beea6 commit ff23de4
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,18 @@
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import methodcaller
from threading import Lock
from typing import Any, Callable, Iterable, NamedTuple, Sequence, overload
from typing import (
Any,
Callable,
Generic,
Iterable,
NamedTuple,
NoReturn,
Sequence,
TypeVar,
overload,
)
from typing_extensions import Final # Python 3.8+
from typing_extensions import Self # Python 3.11+

from optree import _C
Expand Down Expand Up @@ -380,15 +391,33 @@ def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T]) -> Na
return cls(*children) # type: ignore[call-overload]


class __StructSequenceSentinel(tuple): # noqa: N801
pass
_T_co = TypeVar('_T_co', covariant=True)


# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `structseq` classes are unsubclassable, so are all decorated with `@final`.
class structseq(tuple, Generic[_T_co]): # noqa: N801
n_fields: Final[int] # type: ignore[misc]
n_unnamed_fields: Final[int] # type: ignore[misc]
n_sequence_fields: Final[int] # type: ignore[misc]

def __init_subclass__(cls) -> NoReturn:
raise TypeError("type 'structseq' is not an acceptable base type")

Check warning on line 406 in optree/registry.py

View check run for this annotation

Codecov / codecov/patch

optree/registry.py#L406

Added line #L406 was not covered by tests

def __new__( # type: ignore[empty-body]
cls: type[Self],
sequence: Iterable[_T_co],
dict: dict[str, Any] = ...,
) -> Self:
...

Check warning on line 413 in optree/registry.py

View check run for this annotation

Codecov / codecov/patch

optree/registry.py#L413

Added line #L413 was not covered by tests


def _structseq_flatten(structseq: tuple[T, ...]) -> tuple[tuple[T, ...], type[tuple[T, ...]]]:
return structseq, type(structseq)
def _structseq_flatten(seq: structseq[T]) -> tuple[tuple[T, ...], type[structseq[T]]]:
return seq, type(seq)


def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tuple[T, ...]:
def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> structseq[T]:
return cls(children)


Expand All @@ -402,7 +431,7 @@ def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tup
OrderedDict: PyTreeNodeRegistryEntry(_ordereddict_flatten, _ordereddict_unflatten), # type: ignore[arg-type]
defaultdict: PyTreeNodeRegistryEntry(_defaultdict_flatten, _defaultdict_unflatten), # type: ignore[arg-type]
deque: PyTreeNodeRegistryEntry(_deque_flatten, _deque_unflatten), # type: ignore[arg-type]
__StructSequenceSentinel: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
structseq: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
}
# pylint: enable=all

Expand All @@ -421,7 +450,7 @@ def _pytree_node_registry_get(
if is_namedtuple_class(cls):
return _NODETYPE_REGISTRY.get(namedtuple) # type: ignore[call-overload] # noqa: PYI024
if is_structseq_class(cls):
return _NODETYPE_REGISTRY.get(__StructSequenceSentinel)
return _NODETYPE_REGISTRY.get(structseq)
return None


Expand Down

0 comments on commit ff23de4

Please sign in to comment.