Skip to content

Commit

Permalink
feat(treespec): add tree broadcast functions broadcast_common, `tre…
Browse files Browse the repository at this point in the history
…e_broadcast_common`, `tree_broadcast_map`, and `tree_broadcast_map_with_path` (#87)
  • Loading branch information
XuehaiPan authored Oct 12, 2023
1 parent e5ed81c commit 1582f22
Show file tree
Hide file tree
Showing 10 changed files with 1,131 additions and 95 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add tree broadcast functions `broadcast_common`, `tree_broadcast_common`, `tree_broadcast_map`, and `tree_broadcast_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#87](https://github.com/metaopt/optree/pull/87).
- Add function `tree_is_leaf` and add `is_leaf` argument to function `all_leaves` by [@XuehaiPan](https://github.com/XuehaiPan) in [#93](https://github.com/metaopt/optree/pull/93).
- Add methods `PyTreeSpec.entry` and `PyTreeSpec.child` by [@XuehaiPan](https://github.com/XuehaiPan) in [#88](https://github.com/metaopt/optree/pull/88).
- Add Python 3.12 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#90](https://github.com/metaopt/optree/pull/90).
Expand Down
8 changes: 8 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Tree Manipulation Functions
tree_transpose
tree_broadcast_prefix
broadcast_prefix
tree_broadcast_common
broadcast_common
tree_broadcast_map
tree_broadcast_map_with_path
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -57,6 +61,10 @@ Tree Manipulation Functions
.. autofunction:: tree_transpose
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_broadcast_common
.. autofunction:: broadcast_common
.. autofunction:: tree_broadcast_map
.. autofunction:: tree_broadcast_map_with_path
.. autofunction:: prefix_errors

------
Expand Down
14 changes: 14 additions & 0 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class PyTreeSpec {
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
[[nodiscard]] py::list FlattenUpTo(const py::handle &full_tree) const;

// Broadcast to a common suffix of this PyTreeSpec and other PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> BroadcastToCommonSuffix(
const PyTreeSpec &other) const;

// Test whether the given object is a leaf node.
static bool ObjectIsLeaf(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
Expand Down Expand Up @@ -229,6 +233,9 @@ class PyTreeSpec {
// The registry namespace used to resolve the custom pytree node types.
std::string m_namespace{};

// Helper that returns the string representation of a node kind.
static std::string NodeKindToString(const Node &node);

// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node,
const py::object *children,
Expand Down Expand Up @@ -256,6 +263,13 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

static std::tuple<ssize_t, ssize_t, ssize_t, ssize_t> BroadcastToCommonSuffixImpl(
std::vector<Node> &nodes, // NOLINT[runtime/references]
const std::vector<Node> &traversal,
const ssize_t &pos,
const std::vector<Node> &other_traversal,
const ssize_t &other_pos);

template <bool NoneIsLeaf>
static bool ObjectIsLeafImpl(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
Expand Down
1 change: 1 addition & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class PyTreeSpec:
type: builtins.type | None
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ...
def compose(self, inner_treespec: PyTreeSpec) -> PyTreeSpec: ...
def walk(
self,
Expand Down
8 changes: 8 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
NONE_IS_LEAF,
NONE_IS_NODE,
all_leaves,
broadcast_common,
broadcast_prefix,
prefix_errors,
tree_all,
tree_any,
tree_broadcast_common,
tree_broadcast_map,
tree_broadcast_map_with_path,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_with_path,
Expand Down Expand Up @@ -102,6 +106,10 @@
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_broadcast_common',
'broadcast_common',
'tree_broadcast_map',
'tree_broadcast_map_with_path',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down
Loading

0 comments on commit 1582f22

Please sign in to comment.