Skip to content

Commit

Permalink
Merge branch 'main' into integration
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 6, 2023
2 parents 6835b51 + 8e33b68 commit e9c1e0c
Show file tree
Hide file tree
Showing 20 changed files with 398 additions and 144 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add `tree_ravel` function for JAX/NumPy/PyTorch array/tensor tree manipulation by [@XuehaiPan](https://github.com/XuehaiPan) in [#100](https://github.com/metaopt/optree/pull/100).
- Expose node kind enum for `PyTreeSpec` by [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/optree/pull/98).
- Expose function `tree_flatten_one_level` by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/optree/pull/101).
- Add tree broadcast functions `broadcast_common`, `tree_broadcast_common`, `tree_broadcast_map`, and `tree_broadcast_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#87](https://github.com/metaopt/optree/pull/87).
- Add function `tree_is_leaf` and add `is_leaf` argument to function `all_leaves` by [@XuehaiPan](https://github.com/XuehaiPan) in [#93](https://github.com/metaopt/optree/pull/93).
- Add methods `PyTreeSpec.entry` and `PyTreeSpec.child` by [@XuehaiPan](https://github.com/XuehaiPan) in [#88](https://github.com/metaopt/optree/pull/88).
Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Tree Manipulation Functions
broadcast_common
tree_broadcast_map
tree_broadcast_map_with_path
tree_flatten_one_level
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -65,6 +66,7 @@ Tree Manipulation Functions
.. autofunction:: broadcast_common
.. autofunction:: tree_broadcast_map
.. autofunction:: tree_broadcast_map_with_path
.. autofunction:: tree_flatten_one_level
.. autofunction:: prefix_errors

------
Expand Down
6 changes: 6 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ initializer
CPython
CPython's
sortable
OrderedDict
defaultdict
unsubclassable
sys
structseq
MyTuple
jax
numpy
torch
Expand Down
8 changes: 4 additions & 4 deletions docs/source/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Typing Support

PyTreeSpec
PyTreeDef
PyTreeKind
PyTree
PyTreeTypeVar
CustomTreeNode
Expand All @@ -24,17 +25,16 @@ Typing Support

.. autoclass:: PyTreeDef

.. autoclass:: PyTreeKind
:members:

.. autoclass:: PyTree
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: PyTreeTypeVar

.. autoclass:: CustomTreeNode
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: is_namedtuple

Expand Down
12 changes: 6 additions & 6 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ class PyTreeTypeRegistry {
// The Python type object, used to identify the type.
py::object type;
// A function with signature: object -> (iterable, metadata, entries)
py::function to_iterable;
py::function flatten_func;
// A function with signature: (metadata, iterable) -> object
py::function from_iterable;
py::function unflatten_func;
};

// Registers a new custom type. Objects of `cls` will be treated as container node types in
// PyTrees.
static void Register(const py::object &cls,
const py::function &to_iterable,
const py::function &from_iterable,
const py::function &flatten_func,
const py::function &unflatten_func,
const std::string &registry_namespace = "");

// Finds the custom type registration for `type`. Returns nullptr if none exists.
Expand All @@ -93,8 +93,8 @@ class PyTreeTypeRegistry {

template <bool NoneIsLeaf>
static void RegisterImpl(const py::object &cls,
const py::function &to_iterable,
const py::function &from_iterable,
const py::function &flatten_func,
const py::function &unflatten_func,
const std::string &registry_namespace);

class TypeHash {
Expand Down
6 changes: 4 additions & 2 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class PyTreeSpec {

[[nodiscard]] py::object GetType() const;

[[nodiscard]] PyTreeKind GetPyTreeKind() const;

bool operator==(const PyTreeSpec &other) const;
// Inequality is defined as the logical negation of operator==.
inline bool operator!=(const PyTreeSpec &other) const { return !this->operator==(other); }
// `a < b` is evaluated as `a.IsPrefix(b, true)`; the meaning of the second argument is
// defined by IsPrefix, which is declared elsewhere in this class.
inline bool operator<(const PyTreeSpec &other) const { return this->IsPrefix(other, true); }
Expand Down Expand Up @@ -196,14 +198,14 @@ class PyTreeSpec {
// For an OrderedDict, contains a list of keys.
// For a DefaultDict, contains a tuple of (default_factory, sorted list of keys).
// For a Deque, contains the `maxlen` attribute.
// For a Custom type, contains the auxiliary data returned by the `to_iterable` function.
// For a Custom type, contains the auxiliary data returned by the `flatten_func` function.
py::object node_data;

// The tuple of path entries.
// This is optional; if not specified, `range(arity)` is used.
// For a sequence, contains the index of the element.
// For a mapping, contains the key of the element.
// For a Custom type, contains the path entries returned by the `to_iterable` function.
// For a Custom type, contains the path entries returned by the `flatten_func` function.
py::object node_entries;

// Custom type registration. Must be null for non-custom types.
Expand Down
19 changes: 17 additions & 2 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# pylint: disable=all

import builtins
import enum
from collections.abc import Callable, Iterable, Sequence
from typing import Any

Expand Down Expand Up @@ -64,13 +65,27 @@ def is_structseq(obj: object | type) -> bool: ...
def is_structseq_class(cls: type) -> bool: ...
def structseq_fields(obj: builtins.tuple | type[builtins.tuple]) -> builtins.tuple[str, ...]: ...

class PyTreeKind(enum.IntEnum):
    """Integer enum naming the kind of a pytree node; exposed as ``PyTreeSpec.kind``.

    ``CUSTOM`` is pinned to 0 and every following member takes the next integer via
    ``enum.auto()``, so the declaration order below determines the numeric values.
    """

    # NOTE(review): presumably this order must match the enum defined in the C++
    # extension -- confirm before reordering or inserting members.
    CUSTOM = 0  # a custom type
    LEAF = enum.auto()  # an opaque leaf node
    NONE = enum.auto()  # None
    TUPLE = enum.auto()  # a tuple
    LIST = enum.auto()  # a list
    DICT = enum.auto()  # a dict
    NAMEDTUPLE = enum.auto()  # a collections.namedtuple
    ORDEREDDICT = enum.auto()  # a collections.OrderedDict
    DEFAULTDICT = enum.auto()  # a collections.defaultdict
    DEQUE = enum.auto()  # a collections.deque
    STRUCTSEQUENCE = enum.auto()  # a PyStructSequence

class PyTreeSpec:
num_nodes: int
num_leaves: int
num_children: int
none_is_leaf: bool
namespace: str
type: builtins.type | None
kind: PyTreeKind
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ...
Expand Down Expand Up @@ -100,7 +115,7 @@ class PyTreeSpec:

def register_node(
cls: type[CustomTreeNode[T]],
to_iterable: FlattenFunc,
from_iterable: UnflattenFunc,
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
namespace: str,
) -> None: ...
4 changes: 4 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
tree_broadcast_map_with_path,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_one_level,
tree_flatten_with_path,
tree_is_leaf,
tree_leaves,
Expand Down Expand Up @@ -72,6 +73,7 @@
FlattenFunc,
PyTree,
PyTreeDef,
PyTreeKind,
PyTreeSpec,
PyTreeTypeVar,
UnflattenFunc,
Expand Down Expand Up @@ -116,6 +118,7 @@
'tree_min',
'tree_all',
'tree_any',
'tree_flatten_one_level',
'prefix_errors',
'treespec_is_prefix',
'treespec_is_suffix',
Expand All @@ -139,6 +142,7 @@
# Typing
'PyTreeSpec',
'PyTreeDef',
'PyTreeKind',
'PyTree',
'PyTreeTypeVar',
'CustomTreeNode',
Expand Down
130 changes: 81 additions & 49 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,21 @@
import itertools
import textwrap
from collections import OrderedDict, defaultdict, deque
from typing import Any, Callable, cast, overload
from typing import Any, Callable, overload

from optree import _C
from optree.registry import (
AttributeKeyPathEntry,
FlattenedKeyPathEntry,
KeyPath,
KeyPathEntry,
PyTreeNodeRegistryEntry,
register_keypaths,
register_pytree_node,
)
from optree.typing import (
Children,
Iterable,
MetaData,
NamedTuple,
PyTree,
PyTreeSpec,
S,
Expand Down Expand Up @@ -81,6 +80,7 @@
'tree_min',
'tree_all',
'tree_any',
'tree_flatten_one_level',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_paths',
Expand Down Expand Up @@ -1686,6 +1686,80 @@ def tree_any(
)


def tree_flatten_one_level(
    tree: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[
    list[PyTree[T]],
    MetaData,
    tuple[Any, ...],
    Callable[[MetaData, list[PyTree[T]]], PyTree[T]],
]:
    """Flatten only the root node of a pytree and return ``(children, metadata, entries, unflatten_func)``.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_path`.

    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level({'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})
    >>> children, metadata, entries
    ([1, (2, [3, 4]), None, 5], ['a', 'b', 'c', 'd'], ('a', 'b', 'c', 'd'))
    >>> unflatten_func(metadata, children)
    {'a': 1, 'b': (2, [3, 4]), 'c': None, 'd': 5}
    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level([{'a': 1, 'b': (2, 3)}, (4, 5)])
    >>> children, metadata, entries
    ([{'a': 1, 'b': (2, 3)}, (4, 5)], None, (0, 1))
    >>> unflatten_func(metadata, children)
    [{'a': 1, 'b': (2, 3)}, (4, 5)]

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A 4-tuple ``(children, metadata, entries, unflatten_func)``. The first element is a list of
        one-level children of the pytree node. The second element is the auxiliary data used to
        reconstruct the pytree node. The third element is a tuple of path entries to the children.
        The fourth element is a function that can be used to unflatten the auxiliary data and
        children back to the pytree node.

    Raises:
        ValueError: If ``tree`` is a leaf: it is :data:`None` while ``none_is_leaf`` is set,
            ``is_leaf`` returns a true value for it, or no flatten handler is found for its type.
        RuntimeError: If the handler's flatten function returns something other than a 2- or
            3-tuple, or returns a different number of children and path entries.
    """  # pylint: disable=line-too-long
    node_type = type(tree)

    def _leaf_error() -> ValueError:
        # Built lazily so that ``repr(tree)`` is only evaluated when an error is actually raised.
        return ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

    # Leaf checks come first: ``None`` as a leaf, then the user-supplied predicate.
    if tree is None and none_is_leaf:  # type: ignore[unreachable]
        raise _leaf_error()
    if is_leaf is not None and is_leaf(tree):  # type: ignore[arg-type]
        raise _leaf_error()

    handler: PyTreeNodeRegistryEntry | None = register_pytree_node.get(  # type: ignore[attr-defined]
        node_type,
        namespace=namespace,
    )
    if not handler:
        # No flatten handler was found for this type, so the object is treated as a leaf.
        raise _leaf_error()

    result = tuple(handler.flatten_func(tree))  # type: ignore[arg-type]
    num_fields = len(result)
    if num_fields == 2:
        # The flatten function omitted the path entries.
        raw_children, metadata = result
        raw_entries = None
    elif num_fields == 3:
        raw_children, metadata, raw_entries = result
    else:
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
            f'got {num_fields}.',
        )

    children = list(raw_children)  # type: ignore[arg-type]
    if raw_entries is None:
        # Missing entries default to the positional indices of the children.
        entries = tuple(range(len(children)))
    else:
        entries = tuple(raw_entries)

    if len(children) != len(entries):
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} returned inconsistent '
            f'number of children ({len(children)}) and number of entries ({len(entries)}).',
        )
    return children, metadata, entries, handler.unflatten_func  # type: ignore[return-value]


def treespec_is_prefix(
treespec: PyTreeSpec,
other_treespec: PyTreeSpec,
Expand Down Expand Up @@ -1938,48 +2012,6 @@ def treespec_tuple(
return _C.tuple(list(treespecs), none_is_leaf)


def flatten_one_level(
    tree: PyTree[T],
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[Children[T], MetaData, tuple[Any, ...]]:
    """Flatten only the root node of a pytree into ``(children, metadata, entries)``.

    Raises:
        ValueError: If ``tree`` is :data:`None` while ``none_is_leaf`` is set, or if its type has
            no flatten handler and it is neither a namedtuple nor a structseq.
        RuntimeError: If the handler's flatten function returns something other than a 2- or
            3-tuple, or returns a different number of children and path entries.
    """
    # ``None`` is either a leaf (error) or an empty node with no children.
    if tree is None and none_is_leaf:  # type: ignore[unreachable]
        raise ValueError(f'Cannot flatten leaf-type: {type(None)}.')
    if tree is None:
        return [], None, ()

    node_type = type(tree)
    handler = register_pytree_node.get(node_type, namespace=namespace)  # type: ignore[attr-defined]

    if not handler:
        # Fallback for namedtuple and structseq instances: the children are the tuple items,
        # the metadata is the concrete type, and the entries are the positional indices.
        if is_namedtuple(tree) or is_structseq(tree):
            sequence = cast(tuple, tree)
            return list(sequence), node_type, tuple(range(len(sequence)))
        raise ValueError(f'Cannot flatten leaf-type: {node_type}.')

    result = tuple(handler.to_iterable(tree))
    num_fields = len(result)
    if num_fields == 2:
        # The flatten function omitted the path entries.
        raw_children, metadata = result
        raw_entries = None
    elif num_fields == 3:
        raw_children, metadata, raw_entries = result
    else:
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
            f'got {num_fields}.',
        )

    children = list(raw_children)
    if raw_entries is None:
        # Missing entries default to the positional indices of the children.
        entries = tuple(range(len(children)))
    else:
        entries = tuple(raw_entries)

    if len(children) != len(entries):
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} returned inconsistent '
            f'number of children ({len(children)}) and number of entries ({len(entries)}).',
        )
    return children, metadata, entries


def prefix_errors(
prefix_tree: PyTree[T],
full_tree: PyTree[S],
Expand Down Expand Up @@ -2043,12 +2075,12 @@ def _prefix_error(
# Or they may disagree if their roots have different numbers of children (note that because both
# prefix_tree and full_tree have the same type at this point, and because prefix_tree is not a
# leaf, each can be flattened once):
prefix_tree_children, prefix_tree_metadata, _ = flatten_one_level(
prefix_tree_children, prefix_tree_metadata, _, __ = tree_flatten_one_level(
prefix_tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
full_tree_children, full_tree_metadata, _ = flatten_one_level(
full_tree_children, full_tree_metadata, _, __ = tree_flatten_one_level(
full_tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
Expand Down Expand Up @@ -2161,8 +2193,8 @@ def _prefix_error(
for k, t1, t2 in zip(keys, prefix_tree_children, full_tree_children):
yield from _prefix_error(
key_path + k,
cast(PyTree[T], t1),
cast(PyTree[S], t2),
t1,
t2,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
Expand Down
Loading

0 comments on commit e9c1e0c

Please sign in to comment.