Skip to content

Commit

Permalink
refactor: update definition of structseq sentinel
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 6, 2023
1 parent 05beea6 commit a049d11
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ CPython's
sortable
OrderedDict
defaultdict
unsubclassable
15 changes: 6 additions & 9 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
UnflattenFunc,
is_namedtuple_class,
is_structseq_class,
structseq,
)
from optree.utils import safe_zip, total_order_sorted, unzip2

Expand Down Expand Up @@ -380,15 +381,11 @@ def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T]) -> Na
return cls(*children) # type: ignore[call-overload]


class __StructSequenceSentinel(tuple): # noqa: N801
pass
def _structseq_flatten(seq: structseq[T]) -> tuple[tuple[T, ...], type[structseq[T]]]:
return seq, type(seq)


def _structseq_flatten(structseq: tuple[T, ...]) -> tuple[tuple[T, ...], type[tuple[T, ...]]]:
return structseq, type(structseq)


def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tuple[T, ...]:
def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> structseq[T]:
return cls(children)


Expand All @@ -402,7 +399,7 @@ def _structseq_unflatten(cls: type[tuple[T, ...]], children: Iterable[T]) -> tup
OrderedDict: PyTreeNodeRegistryEntry(_ordereddict_flatten, _ordereddict_unflatten), # type: ignore[arg-type]
defaultdict: PyTreeNodeRegistryEntry(_defaultdict_flatten, _defaultdict_unflatten), # type: ignore[arg-type]
deque: PyTreeNodeRegistryEntry(_deque_flatten, _deque_unflatten), # type: ignore[arg-type]
__StructSequenceSentinel: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
structseq: PyTreeNodeRegistryEntry(_structseq_flatten, _structseq_unflatten), # type: ignore[arg-type]
}
# pylint: enable=all

Expand All @@ -421,7 +418,7 @@ def _pytree_node_registry_get(
if is_namedtuple_class(cls):
return _NODETYPE_REGISTRY.get(namedtuple) # type: ignore[call-overload] # noqa: PYI024
if is_structseq_class(cls):
return _NODETYPE_REGISTRY.get(__StructSequenceSentinel)
return _NODETYPE_REGISTRY.get(structseq)
return None


Expand Down
26 changes: 25 additions & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
Union,
)
from typing_extensions import OrderedDict # Generic OrderedDict: Python 3.7.2+
from typing_extensions import Self # Python 3.11+
from typing_extensions import TypeAlias # Python 3.10+
from typing_extensions import Protocol, runtime_checkable # Python 3.8+
from typing_extensions import Final, Protocol, runtime_checkable # Python 3.8+

from optree import _C
from optree._C import PyTreeKind, PyTreeSpec
Expand Down Expand Up @@ -299,6 +300,29 @@ def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
return cls._fields # type: ignore[attr-defined]


_T_co = TypeVar('_T_co', covariant=True)


# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `structseq` classes are unsubclassable, so are all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
class structseq(tuple, Generic[_T_co]): # noqa: N801
n_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name
n_unnamed_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name
n_sequence_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name

def __init_subclass__(cls) -> NoReturn:
raise TypeError("type 'structseq' is not an acceptable base type")

def __new__( # type: ignore[empty-body] # pylint: disable=unused-argument
cls: type[Self],
sequence: Iterable[_T_co],
dict: dict[str, Any] = ..., # pylint: disable=redefined-builtin
) -> Self:
...


def is_structseq(obj: object | type) -> bool:
"""Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
cls = obj if isinstance(obj, type) else type(obj)
Expand Down

0 comments on commit a049d11

Please sign in to comment.