Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Document jax.tree_util.register_dataclass
Browse files Browse the repository at this point in the history
jakevdp committed May 16, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 9dd98dc commit 7573dce
Showing 2 changed files with 67 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/jax.tree_util.rst
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@ List of Functions
Partial
all_leaves
build_tree
register_dataclass
register_pytree_node
register_pytree_node_class
register_pytree_with_keys
74 changes: 66 additions & 8 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
traceback_util.register_exclusion(__file__)

T = TypeVar("T")
U = TypeVar("U", bound=type[Any])
Typ = TypeVar("Typ", bound=type[Any])
H = TypeVar("H", bound=Hashable)

Leaf = Any
@@ -254,7 +254,7 @@ def register_pytree_node(nodetype: type[T],
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)


def register_pytree_node_class(cls: U) -> U:
def register_pytree_node_class(cls: Typ) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around ``register_pytree_node``, and provides
@@ -807,7 +807,7 @@ def flatten_func_impl(tree):
)


def register_pytree_with_keys_class(cls: U) -> U:
def register_pytree_with_keys_class(cls: Typ) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is similar to ``register_pytree_node_class``, but requires a
@@ -838,18 +838,75 @@ def tree_unflatten(cls, aux_data, children):


def register_dataclass(
nodetype: type, data_fields: list[str], meta_fields: list[str]
):
nodetype: Typ, data_fields: list[str], meta_fields: list[str]
) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This differs from ``register_pytree_with_keys_class`` in that the C++
registries use the optimized C++ dataclass builtin instead of the argument
functions.
See :ref:`extending-pytrees` for more information about registering pytrees.
Args:
nodetype: a Python type to treat as an internal pytree node.
meta_fields: auxiliary data field names.
data_fields: data field names.
nodetype: a Python type to treat as an internal pytree node. This is assumed
to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class
attributes represent the whole of the object state, and can be passed
as keywords to the class constructor to create a copy of the object.
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
meta_fields: auxiliary data field names. These fields *must* contain static,
hashable, immutable objects, as these objects are used to generate JIT cache
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
:class:`numpy.ndarray` objects.
data_fields: data field names. These fields *must* be JAX-compatible objects
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
``None``, as this is recognized by JAX as an empty pytree.
Returns:
The input class ``nodetype`` is returned unchanged after being added to JAX's
pytree registry. This return value allows ``register_dataclass`` to be partially
evaluated and used as a decorator as in the example below.
Example:
>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
... data_fields=['x', 'y'],
... meta_fields=['op'])
... @dataclass
... class MyStruct:
... x: jax.Array
... y: jax.Array
... op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`:
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
In particular, this registration allows ``m`` to be passed seamlessly through code
wrapped in :func:`jax.jit` and other JAX transformations:
>>> @jax.jit
... def compiled_func(m):
... if m.op == 'add':
... return m.x + m.y
... else:
... raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
"""
def flatten_with_keys(x):
meta = tuple(getattr(x, name) for name in meta_fields)
@@ -874,6 +931,7 @@ def flatten_func(x):
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
return nodetype


def register_static(cls: type[H]) -> type[H]:

0 comments on commit 7573dce

Please sign in to comment.