Skip to content

Commit

Permalink
feat(typing): make structseq more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 6, 2023
1 parent a049d11 commit e47e7fc
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
35 changes: 24 additions & 11 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# ==============================================================================
"""Typing utilities for OpTree."""

# mypy: no-warn-unused-ignores

from __future__ import annotations

import types
from typing import (
Any,
Callable,
ClassVar,
DefaultDict,
Deque,
Dict,
Expand Down Expand Up @@ -303,24 +303,37 @@ def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
_T_co = TypeVar('_T_co', covariant=True)


class _StructSequenceMeta(type):
def __subclasscheck__(cls, subclass: type) -> bool:
"""Return whether the class is a PyStructSequence type."""
return is_structseq_class(subclass)

def __instancecheck__(cls, instance: Any) -> bool:
"""Return whether the object is a PyStructSequence instance."""
return is_structseq_class(type(instance))


# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `structseq` classes are unsubclassable, so are all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
class structseq(tuple, Generic[_T_co]): # noqa: N801
n_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name
n_unnamed_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name
n_sequence_fields: Final[int] # type: ignore[misc] # pylint: disable=invalid-name
class structseq(tuple, Generic[_T_co], metaclass=_StructSequenceMeta): # type: ignore[misc] # noqa: N801
"""A generic type stub for CPython's ``PyStructSequence`` type."""

n_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name
n_unnamed_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name
n_sequence_fields: Final[ClassVar[int]] # type: ignore[misc] # pylint: disable=invalid-name

def __init_subclass__(cls) -> NoReturn:
"""Prohibit subclassing."""
raise TypeError("type 'structseq' is not an acceptable base type")

def __new__( # type: ignore[empty-body] # pylint: disable=unused-argument
cls: type[Self],
sequence: Iterable[_T_co],
dict: dict[str, Any] = ..., # pylint: disable=redefined-builtin
) -> Self:
...
# pylint: disable-next=unused-argument,redefined-builtin
def __new__(cls: type[Self], sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
raise NotImplementedError


del _StructSequenceMeta


def is_structseq(obj: object | type) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,8 @@ def test_tree_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901
assert one_level_treespec.kind == optree.PyTreeKind.NAMEDTUPLE
elif optree.is_structseq(node):
assert optree.is_structseq_class(node_type)
assert isinstance(node, optree.typing.structseq)
assert issubclass(node_type, optree.typing.structseq)
assert metadata is node_type
assert one_level_treespec.kind == optree.PyTreeKind.STRUCTSEQUENCE
else:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def test_is_namedtuple():


def test_is_structseq():
with pytest.raises(TypeError, match="type 'structseq' is not an acceptable base type"):

class MyTuple(optree.typing.structseq):
pass

with pytest.raises(NotImplementedError):
optree.typing.structseq(range(1))

assert not optree.is_structseq((1, 2))
assert not optree.is_structseq([1, 2])
assert optree.is_structseq(sys.float_info)
Expand Down

0 comments on commit e47e7fc

Please sign in to comment.