Skip to content

Commit

Permalink
refactor: update definition of structseq sentinel
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 6, 2023
1 parent 05beea6 commit 1dc7b93
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import methodcaller
from threading import Lock
from typing import Any, Callable, Iterable, NamedTuple, Sequence, overload
from typing import Any, Callable, Generic, Iterable, NamedTuple, Sequence, TypeVar, overload
from typing_extensions import Final # Python 3.8+
from typing_extensions import Self # Python 3.11+

from optree import _C
Expand Down Expand Up @@ -380,15 +381,30 @@ def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T]) -> Na
return cls(*children) # type: ignore[call-overload]


class __StructSequenceSentinel(tuple): # noqa: N801
pass
_T_co = TypeVar('_T_co', covariant=True)


def _structseq_flatten(structseq: tuple[T, ...]) -> tuple[tuple[T, ...], type[tuple[T, ...]]]:
return structseq, type(structseq)
# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `structseq` classes are unsubclassable, so are all decorated with `@final`.
class structseq(tuple, Generic[_T_co]): # noqa: N801
n_fields: Final[int] # type: ignore[misc]
n_unnamed_fields: Final[int] # type: ignore[misc]
n_sequence_fields: Final[int] # type: ignore[misc]

def __new__( # type: ignore[empty-body]
cls: type[Self],
sequence: Iterable[_T_co],
dict: dict[str, Any] = ...,
) -> Self:
...


def _structseq_flatten(seq: structseq[T]) -> tuple[tuple[T, ...], type[structseq[T]]]:
return seq, type(seq)


def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tuple[T, ...]:
def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> structseq[T]:
return cls(children)


Expand All @@ -402,7 +418,7 @@ def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tup
OrderedDict: PyTreeNodeRegistryEntry(_ordereddict_flatten, _ordereddict_unflatten), # type: ignore[arg-type]
defaultdict: PyTreeNodeRegistryEntry(_defaultdict_flatten, _defaultdict_unflatten), # type: ignore[arg-type]
deque: PyTreeNodeRegistryEntry(_deque_flatten, _deque_unflatten), # type: ignore[arg-type]
__StructSequenceSentinel: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
structseq: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
}
# pylint: enable=all

Expand All @@ -421,7 +437,7 @@ def _pytree_node_registry_get(
if is_namedtuple_class(cls):
return _NODETYPE_REGISTRY.get(namedtuple) # type: ignore[call-overload] # noqa: PYI024
if is_structseq_class(cls):
return _NODETYPE_REGISTRY.get(__StructSequenceSentinel)
return _NODETYPE_REGISTRY.get(structseq)
return None


Expand Down

0 comments on commit 1dc7b93

Please sign in to comment.