Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add methods PyTreeSpec.is_prefix and PyTreeSpec.is_suffix and function tree_broadcast_prefix #40

Merged
merged 7 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add methods `PyTreeSpec.is_prefix` and `PyTreeSpec.is_suffix` and function `tree_broadcast_prefix` by [@XuehaiPan](https://github.com/XuehaiPan) in [#40](https://github.com/metaopt/optree/pull/40).
- Add tree reduce functions `tree_sum`, `tree_max`, and `tree_min` by [@XuehaiPan](https://github.com/XuehaiPan) in [#39](https://github.com/metaopt/optree/pull/39).
- Test dict key equality with `PyDict_Contains` ($O (n)$) rather than sorting ($O (n \log n)$) by [@XuehaiPan](https://github.com/XuehaiPan) in [#37](https://github.com/metaopt/optree/pull/37).
- Make error message more clear when value mismatch by [@XuehaiPan](https://github.com/XuehaiPan) in [#36](https://github.com/metaopt/optree/pull/36).
Expand Down
6 changes: 6 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Tree Manipulation Functions
tree_map_with_path
tree_map_with_path_
tree_transpose
tree_broadcast_prefix
broadcast_prefix
tree_replace_nones
prefix_errors
Expand All @@ -51,6 +52,7 @@ Tree Manipulation Functions
.. autofunction:: tree_map_with_path
.. autofunction:: tree_map_with_path_
.. autofunction:: tree_transpose
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_replace_nones
.. autofunction:: prefix_errors
Expand Down Expand Up @@ -84,13 +86,17 @@ PyTreeSpec Functions

.. autosummary::

treespec_is_prefix
treespec_is_suffix
treespec_children
treespec_is_leaf
treespec_is_strict_leaf
treespec_leaf
treespec_none
treespec_tuple

.. autofunction:: treespec_is_prefix
.. autofunction:: treespec_is_suffix
.. autofunction:: treespec_children
.. autofunction:: treespec_is_leaf
.. autofunction:: treespec_is_strict_leaf
Expand Down
46 changes: 29 additions & 17 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ class PyTreeSpec {
// Composes two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Compose(const PyTreeSpec &inner_treespec) const;

// Maps a function over a PyTree structure, applying f_leaf to each leaf, and
// f_node(children, node_data) to each container node.
[[nodiscard]] py::object Walk(const py::function &f_node,
const py::handle &f_leaf,
const py::iterable &leaves) const;

// Returns true if this PyTreeSpec is a prefix of `other`.
[[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const;

// Returns true if this PyTreeSpec is a suffix of `other`.
[[nodiscard]] inline bool IsSuffix(const PyTreeSpec &other, const bool &strict = false) const {
return other.IsPrefix(*this, strict);
}

[[nodiscard]] std::vector<std::unique_ptr<PyTreeSpec>> Children() const;

[[nodiscard]] bool IsLeaf(const bool &strict = true) const;

// Makes a Tuple PyTreeSpec out of a vector of PyTreeSpecs.
static std::unique_ptr<PyTreeSpec> Tuple(const std::vector<PyTreeSpec> &treespecs,
const bool &none_is_leaf);
Expand All @@ -114,30 +132,24 @@ class PyTreeSpec {
// Makes a PyTreeSpec representing a `None` node.
static std::unique_ptr<PyTreeSpec> None(const bool &none_is_leaf);

[[nodiscard]] std::vector<std::unique_ptr<PyTreeSpec>> Children() const;

// Maps a function over a PyTree structure, applying f_leaf to each leaf, and
// f_node(children, node_data) to each container node.
[[nodiscard]] py::object Walk(const py::function &f_node,
const py::handle &f_leaf,
const py::iterable &leaves) const;

[[nodiscard]] ssize_t num_leaves() const;

[[nodiscard]] ssize_t num_nodes() const;
[[nodiscard]] ssize_t GetNumLeaves() const;

[[nodiscard]] ssize_t num_children() const;
[[nodiscard]] ssize_t GetNumNodes() const;

[[nodiscard]] bool get_none_is_leaf() const;
[[nodiscard]] ssize_t GetNumChildren() const;

[[nodiscard]] std::string get_namespace() const;
[[nodiscard]] bool GetNoneIsLeaf() const;

[[nodiscard]] py::object get_type() const;
[[nodiscard]] std::string GetNamespace() const;

[[nodiscard]] bool is_leaf(const bool &strict = true) const;
[[nodiscard]] py::object GetType() const;

bool operator==(const PyTreeSpec &other) const;
bool operator!=(const PyTreeSpec &other) const;
inline bool operator!=(const PyTreeSpec &other) const { return !(*this == other); }
inline bool operator<(const PyTreeSpec &other) const { return IsPrefix(other, true); }
inline bool operator<=(const PyTreeSpec &other) const { return IsPrefix(other, false); }
inline bool operator>(const PyTreeSpec &other) const { return IsSuffix(other, true); }
inline bool operator>=(const PyTreeSpec &other) const { return IsSuffix(other, false); }

template <typename H>
friend H AbslHashValue(H h, const Node &n) {
Expand Down
2 changes: 1 addition & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,5 +464,5 @@ inline std::pair<py::list, py::list> DictKeysDifference(const py::list& /* uniqu
py::list extra_keys{got_keys - expected_keys};
TotalOrderSort(missing_keys);
TotalOrderSort(extra_keys);
return {missing_keys, extra_keys};
return std::make_pair(std::move(missing_keys), std::move(extra_keys));
}
14 changes: 14 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,27 @@ class PyTreeSpec:
f_leaf: Callable[[T], U] | None,
leaves: Iterable[T],
) -> U: ...
def is_prefix(
self,
other: PyTreeSpec,
strict: bool = ..., # False
) -> bool: ...
def is_suffix(
self,
other: PyTreeSpec,
strict: bool = ..., # False
) -> bool: ...
def children(self) -> list[PyTreeSpec]: ...
def is_leaf(
self,
strict: bool = ..., # True
) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __le__(self, other: object) -> bool: ...
def __lt__(self, other: object) -> bool: ...
def __ge__(self, other: object) -> bool: ...
def __gt__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __len__(self) -> int: ...

Expand Down
6 changes: 6 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
prefix_errors,
tree_all,
tree_any,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_with_path,
tree_leaves,
Expand All @@ -42,7 +43,9 @@
tree_unflatten,
treespec_children,
treespec_is_leaf,
treespec_is_prefix,
treespec_is_strict_leaf,
treespec_is_suffix,
treespec_leaf,
treespec_none,
treespec_tuple,
Expand Down Expand Up @@ -89,6 +92,7 @@
'tree_map_with_path',
'tree_map_with_path_',
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_replace_nones',
'tree_reduce',
Expand All @@ -98,6 +102,8 @@
'tree_all',
'tree_any',
'prefix_errors',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_children',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down
103 changes: 103 additions & 0 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
'tree_map_with_path',
'tree_map_with_path_',
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_replace_nones',
'tree_reduce',
Expand All @@ -73,6 +74,8 @@
'tree_min',
'tree_all',
'tree_any',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_children',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down Expand Up @@ -691,6 +694,80 @@ def tree_transpose(
return inner_treespec.unflatten(subtrees) # type: ignore[arg-type]


def tree_broadcast_prefix(
prefix_tree: PyTree[T],
full_tree: PyTree[S],
is_leaf: Callable[[T], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
) -> PyTree[T]: # PyTree[PyTree[T]]
"""Return a pytree of same structure of ``full_tree`` with broadcasted subtrees in ``prefix_tree``.

See also :func:`broadcast_prefix` and :func:`treespec_is_prefix`.

If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.

This function returns a pytree with the same size as ``full_tree``. The leaves are replicated
from ``prefix_tree``. The number of replicas is determined by the corresponding subtree in
``full_tree``.

>>> tree_broadcast_prefix(1, [1, 2, 3])
[1, 1, 1]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3])
[1, 2, 3]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
Traceback (most recent call last):
...
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
[1, 2, (3, 3)]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
[1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True)
[1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]

Args:
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)

Returns:
A pytree of same structure of ``full_tree`` with broadcasted subtrees in ``prefix_tree``.
"""

def broadcast_leaves(x: T, subtree: PyTree[S]) -> PyTree[T]:
subtreespec = tree_structure(
subtree,
is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return subtreespec.unflatten([x] * subtreespec.num_leaves)

# If prefix_tree is not a tree prefix of full_tree, this code can raise a ValueError;
# use prefix_errors to find disagreements and raise more precise error messages.
# prefix_errors = prefix_errors(prefix_tree, full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
return tree_map(
broadcast_leaves, # type: ignore[arg-type]
prefix_tree,
full_tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)


def broadcast_prefix(
prefix_tree: PyTree[T],
full_tree: PyTree[S],
Expand All @@ -701,6 +778,8 @@ def broadcast_prefix(
) -> list[T]:
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

See also :func:`tree_broadcast_prefix` and :func:`treespec_is_prefix`.

If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.

Expand Down Expand Up @@ -1182,6 +1261,30 @@ def tree_any(
return any(tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)) # type: ignore[arg-type]


def treespec_is_prefix(
treespec: PyTreeSpec,
other_treespec: PyTreeSpec,
strict: bool = False,
) -> bool:
"""Return whether ``treespec`` is a prefix of ``other_treespec``.

See also :func:`treespec_is_prefix` and :meth:`PyTreeSpec.is_prefix`.
"""
return treespec.is_prefix(other_treespec, strict=strict)


def treespec_is_suffix(
treespec: PyTreeSpec,
other_treespec: PyTreeSpec,
strict: bool = False,
) -> bool:
"""Return whether ``treespec`` is a suffix of ``other_treespec``.

See also :func:`treespec_is_suffix` :meth:`PyTreeSpec.is_suffix`.
"""
return treespec.is_suffix(other_treespec, strict=strict)


def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]:
"""Return a list of treespecs for the children of a treespec."""
return treespec.children()
Expand Down
Loading