Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(treespec): add tree broadcast functions broadcast_common, tree_broadcast_common, tree_broadcast_map, and tree_broadcast_map_with_path #87

Merged
merged 11 commits into from
Oct 12, 2023
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add tree broadcast functions `broadcast_common`, `tree_broadcast_common`, `tree_broadcast_map`, and `tree_broadcast_map_with_path` by [@XuehaiPan](https://github.com/XuehaiPan) in [#87](https://github.com/metaopt/optree/pull/87).
- Add function `tree_is_leaf` and add `is_leaf` argument to function `all_leaves` by [@XuehaiPan](https://github.com/XuehaiPan) in [#93](https://github.com/metaopt/optree/pull/93).
- Add methods `PyTreeSpec.entry` and `PyTreeSpec.child` by [@XuehaiPan](https://github.com/XuehaiPan) in [#88](https://github.com/metaopt/optree/pull/88).
- Add Python 3.12 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#90](https://github.com/metaopt/optree/pull/90).
Expand Down
8 changes: 8 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Tree Manipulation Functions
tree_transpose
tree_broadcast_prefix
broadcast_prefix
tree_broadcast_common
broadcast_common
tree_broadcast_map
tree_broadcast_map_with_path
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -57,6 +61,10 @@ Tree Manipulation Functions
.. autofunction:: tree_transpose
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_broadcast_common
.. autofunction:: broadcast_common
.. autofunction:: tree_broadcast_map
.. autofunction:: tree_broadcast_map_with_path
.. autofunction:: prefix_errors

------
Expand Down
14 changes: 14 additions & 0 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class PyTreeSpec {
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
[[nodiscard]] py::list FlattenUpTo(const py::handle &full_tree) const;

// Broadcast to a common suffix of this PyTreeSpec and other PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> BroadcastToCommonSuffix(
const PyTreeSpec &other) const;

// Test whether the given object is a leaf node.
static bool ObjectIsLeaf(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
Expand Down Expand Up @@ -229,6 +233,9 @@ class PyTreeSpec {
// The registry namespace used to resolve the custom pytree node types.
std::string m_namespace{};

// Helper that returns the string representation of a node kind.
static std::string NodeKindToString(const Node &node);

// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node,
const py::object *children,
Expand Down Expand Up @@ -256,6 +263,13 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

static std::tuple<ssize_t, ssize_t, ssize_t, ssize_t> BroadcastToCommonSuffixImpl(
std::vector<Node> &nodes, // NOLINT[runtime/references]
const std::vector<Node> &traversal,
const ssize_t &pos,
const std::vector<Node> &other_traversal,
const ssize_t &other_pos);

template <bool NoneIsLeaf>
static bool ObjectIsLeafImpl(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
Expand Down
1 change: 1 addition & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class PyTreeSpec:
type: builtins.type | None
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ...
def compose(self, inner_treespec: PyTreeSpec) -> PyTreeSpec: ...
def walk(
self,
Expand Down
8 changes: 8 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
NONE_IS_LEAF,
NONE_IS_NODE,
all_leaves,
broadcast_common,
broadcast_prefix,
prefix_errors,
tree_all,
tree_any,
tree_broadcast_common,
tree_broadcast_map,
tree_broadcast_map_with_path,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_with_path,
Expand Down Expand Up @@ -102,6 +106,10 @@
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_broadcast_common',
'broadcast_common',
'tree_broadcast_map',
'tree_broadcast_map_with_path',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down
Loading
Loading