diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index da52d87868e0..b685838683e6 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -14,6 +14,7 @@ List of Functions Partial all_leaves build_tree + register_dataclass register_pytree_node register_pytree_node_class register_pytree_with_keys diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index ef1884347fd9..6a26f4148792 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -21,7 +21,7 @@ from functools import partial import operator as op import textwrap -from typing import Any, Callable, NamedTuple, TypeVar, Union, overload +from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -32,7 +32,7 @@ traceback_util.register_exclusion(__file__) T = TypeVar("T") -U = TypeVar("U", bound=type[Any]) +Typ = TypeVar("Typ", bound=type[Any]) H = TypeVar("H", bound=Hashable) Leaf = Any @@ -254,7 +254,7 @@ def register_pytree_node(nodetype: type[T], _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) -def register_pytree_node_class(cls: U) -> U: +def register_pytree_node_class(cls: Typ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. This function is a thin wrapper around ``register_pytree_node``, and provides @@ -807,7 +807,7 @@ def flatten_func_impl(tree): ) -def register_pytree_with_keys_class(cls: U) -> U: +def register_pytree_with_keys_class(cls: Typ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. This function is similar to ``register_pytree_node_class``, but requires a @@ -838,19 +838,82 @@ def tree_unflatten(cls, aux_data, children): def register_dataclass( - nodetype: type, data_fields: list[str], meta_fields: list[str] -): + nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str] +) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. 
This differs from ``register_pytree_with_keys_class`` in that the C++ registries use the optimized C++ dataclass builtin instead of the argument functions. + See :ref:`extending-pytrees` for more information about registering pytrees. + Args: - nodetype: a Python type to treat as an internal pytree node. - meta_fields: auxiliary data field names. - data_fields: data field names. + nodetype: a Python type to treat as an internal pytree node. This is assumed + to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class + attributes represent the whole of the object state, and can be passed + as keywords to the class constructor to create a copy of the object. + All defined attributes should be listed among ``meta_fields`` or ``data_fields``. + meta_fields: auxiliary data field names. These fields *must* contain static, + hashable, immutable objects, as these objects are used to generate JIT cache + keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or + :class:`numpy.ndarray` objects. + data_fields: data field names. These fields *must* be JAX-compatible objects + such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or + pytrees whose leaves are arrays or scalars. Note that ``None`` is a valid + data field value, as JAX recognizes ``None`` as an empty pytree. + + Returns: + The input class ``nodetype`` is returned unchanged after being added to JAX's + pytree registry. This return value allows ``register_dataclass`` to be partially + evaluated and used as a decorator as in the example below. + + Example: + >>> from dataclasses import dataclass + >>> from functools import partial + >>> + >>> @partial(jax.tree_util.register_dataclass, + ... data_fields=['x', 'y'], + ... meta_fields=['op']) + ... @dataclass + ... class MyStruct: + ... x: jax.Array + ... y: jax.Array + ... op: str + ... 
+ >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') + >>> m + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`: + + >>> leaves, treedef = jax.tree.flatten(m) + >>> leaves + [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] + >>> treedef + PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) + >>> jax.tree.unflatten(treedef, leaves) + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + In particular, this registration allows ``m`` to be passed seamlessly through code + wrapped in :func:`jax.jit` and other JAX transformations: + + >>> @jax.jit + ... def compiled_func(m): + ... if m.op == 'add': + ... return m.x + m.y + ... else: + ... raise ValueError(f"{m.op=}") + ... + >>> compiled_func(m) + Array([1., 2., 3.], dtype=float32) """ + # Store inputs as immutable tuples in this scope, because we close over them + # for later evaluation. This prevents potentially confusing behavior if the + # caller were to pass in lists that are later mutated. 
+ meta_fields = tuple(meta_fields) + data_fields = tuple(data_fields) + def flatten_with_keys(x): meta = tuple(getattr(x, name) for name in meta_fields) data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) @@ -867,13 +930,14 @@ def flatten_func(x): data = tuple(getattr(x, name) for name in data_fields) return data, meta - default_registry.register_dataclass_node(nodetype, data_fields, meta_fields) - none_leaf_registry.register_dataclass_node(nodetype, data_fields, meta_fields) - dispatch_registry.register_dataclass_node(nodetype, data_fields, meta_fields) + default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) + none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) + dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( flatten_with_keys, unflatten_func ) + return nodetype def register_static(cls: type[H]) -> type[H]: