From ab5403a0b0975e01bc11f00c9adadcca5ee14690 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 22 Mar 2023 14:15:00 +0000 Subject: [PATCH 1/2] Change how we trace metadata. --- benchmarks/benchmarks.ipynb | 33 ++++-- mytree/mytree.py | 211 ++++++++++++++++++++---------------- tests/test_mytree.py | 24 ---- 3 files changed, 139 insertions(+), 129 deletions(-) diff --git a/benchmarks/benchmarks.ipynb b/benchmarks/benchmarks.ipynb index 999a708..af7473f 100644 --- a/benchmarks/benchmarks.ipynb +++ b/benchmarks/benchmarks.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "3a7f5936", "metadata": {}, "outputs": [], @@ -24,10 +24,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "536f66e1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danieldodd/miniconda3/lib/python3.10/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n", + " jax.tree_util.register_keypaths(\n" + ] + } + ], "source": [ "from mytree import Mytree, param_field, Softplus\n", "\n", @@ -59,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "1ce3a220", "metadata": {}, "outputs": [], @@ -135,7 +144,7 @@ "source": [ "Run on a M1 Pro CPU.\n", "\n", - "- **Initialisation**: is faster for mytree, despite it unpacking metadata, and working out what attributes are leaves of the nested pytree structure.\n", + "- **Initialisation**: is faster for mytree.\n", "- **Transformations**: is faster for mytree.\n", "- **Replacing attributes**: is faster for mytree implimentation.\n", "\n", @@ -144,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "8db39ca8", "metadata": {}, "outputs": [ @@ -154,14 +163,14 @@ "text": [ "\n", " mytree:\n", - "50.7 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "862 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", - "1.34 µs ± 31.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", + "52.1 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "1.02 s ± 35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "1.52 µs ± 15.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", "\n", " pytree:\n", - "51.8 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "893 ms ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", - "1.7 µs ± 44.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" + "58 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "1.08 s ± 20.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "2.05 µs ± 76.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n" ] } ], diff --git a/mytree/mytree.py b/mytree/mytree.py index 34c7e80..f167fa6 100644 --- a/mytree/mytree.py +++ b/mytree/mytree.py @@ -2,22 +2,32 @@ __all__ = ["Mytree", "meta_leaves", "meta"] +import dataclasses from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List +from typing import Any, Callable, Dict, Iterable, Tuple import jax import jax.tree_util as jtu +from jax._src.tree_util import _registry from simple_pytree import Pytree, static_field from .bijectors import Bijector, Identity class Mytree(Pytree): - _pytree__leaf_meta: Dict[str, Any] = static_field() + _pytree__meta: Dict[str, Any] = static_field() def __init_subclass__(cls, mutable: bool = False): - cls._pytree__leaf_meta = dict() + cls._pytree__meta = dict() super().__init_subclass__(mutable=mutable) + class_vars = vars(cls) + for field, value in class_vars.items(): + if ( + field not in cls._pytree__static_fields + and isinstance(value, dataclasses.Field) + and value.metadata is not None + ): + cls._pytree__meta[field] = {**value.metadata} def replace(self, **kwargs: Any) -> Mytree: """ @@ -34,47 +44,6 @@ def replace(self, **kwargs: Any) -> Mytree: pytree.__dict__.update(kwargs) return pytree - if not TYPE_CHECKING: - - def __setattr__(self, field: str, value: Any): - super().__setattr__(field, value) - - # TODO: Clean up this mess. - if field not in self._pytree__static_fields: - _not_pytree = ( - jtu.tree_map( - lambda x: isinstance(x, Pytree), - value, - is_leaf=lambda x: isinstance(x, Pytree), - ) - == False - ) - - if _not_pytree: - try: - field_metadata = { - **type(self) - .__dict__["__dataclass_fields__"][field] - .metadata - } - except KeyError: - try: - dataclass_field_ = type(self).__dict__[field] - try: - field_metadata = {**dataclass_field_.metadata} - except AttributeError: - field_metadata = {} - - except KeyError: - field_metadata = {} - - if field_metadata.get("pytree_node", True): - object.__setattr__( - self, - "_pytree__leaf_meta", - self._pytree__leaf_meta | {field: field_metadata}, - ) - def replace_meta(self, **kwargs: Any) -> Mytree: """ Replace the values of the fields of the object with the values of the @@ -83,13 +52,11 @@ def replace_meta(self, **kwargs: Any) -> Mytree: type as the original object. """ for key in kwargs: - if key not in self._pytree__leaf_meta.keys(): + if key not in self._pytree__meta.keys(): raise ValueError(f"'{key}' is not a leaf of {type(self).__name__}") pytree = copy(self) - pytree.__dict__.update( - _pytree__leaf_meta={**pytree._pytree__leaf_meta, **kwargs} - ) + pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs}) return pytree def update_meta(self, **kwargs: Any) -> Mytree: @@ -100,19 +67,19 @@ def update_meta(self, **kwargs: Any) -> Mytree: type as the original object. """ for key in kwargs: - if key not in self._pytree__leaf_meta.keys(): + if key not in self._pytree__meta.keys(): raise ValueError( f"'{key}' is not an attribute of {type(self).__name__}" ) pytree = copy(self) - new = deepcopy(pytree._pytree__leaf_meta) + new = deepcopy(pytree._pytree__meta) for key, value in kwargs.items(): if key in new: new[key].update(value) else: new[key] = value - pytree.__dict__.update(_pytree__leaf_meta=new) + pytree.__dict__.update(_pytree__meta=new) return pytree def replace_trainable(Mytree: Mytree, **kwargs: Dict[str, bool]) -> Mytree: @@ -129,9 +96,12 @@ def constrain(self) -> Mytree: Returns: Mytree: tranformed to the constrained space. """ - return _meta_map( - lambda leaf, meta: meta.get("bijector", Identity).forward(leaf), self - ) + + def _apply_constrain(meta_leaf): + meta, leaf = meta_leaf + return meta.get("bijector", Identity).forward(leaf) + + return meta_map(_apply_constrain, self) def unconstrain(self) -> Mytree: """Transform model parameters to the unconstrained space according to their defined bijectors. @@ -139,9 +109,12 @@ def unconstrain(self) -> Mytree: Returns: Mytree: tranformed to the unconstrained space. """ - return _meta_map( - lambda leaf, meta: meta.get("bijector", Identity).inverse(leaf), self - ) + + def _apply_unconstrain(meta_leaf): + meta, leaf = meta_leaf + return meta.get("bijector", Identity).inverse(leaf) + + return meta_map(_apply_unconstrain, self) def stop_gradient(self) -> Mytree: """Stop gradients flowing through the Mytree. @@ -154,70 +127,122 @@ def stop_gradient(self) -> Mytree: def _stop_grad(leaf: jax.Array, trainable: bool) -> jax.Array: return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, leaf) - return _meta_map( - lambda leaf, meta: _stop_grad(leaf, meta.get("trainable", True)), self - ) + def _apply_stop_grad(meta_leaf): + meta, leaf = meta_leaf + return _stop_grad(leaf, meta.get("trainable", True)) + return meta_map(_apply_stop_grad, self) -def _meta_map(f: Callable[[Any, Dict[str, Any]], Any], pytree: Mytree) -> Mytree: - """Apply a function to a pytree where the first argument are the pytree leaves, and the second argument are the pytree metadata leaves. + +def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]: + """Unpacks a list of meta corresponding to the top-level nodes of the pytree. Args: - f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree. - pytree (Mytree): The pytree to apply the function to. + pytree (Any): pytree to unpack the meta from. Returns: - Mytree: The transformed pytree. + List[Dict[str, Any]]: meta of the top-level nodes of the pytree. """ - leaves, treedef = jtu.tree_flatten(pytree) - all_leaves = [leaves] + [meta_leaves(pytree)] - return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) - - -def _toplevel_meta(pytree: Mytree) -> List[Dict[str, Any]]: + if isinstance(pytree, Iterable): + return [None] * len(pytree) return [ - pytree._pytree__leaf_meta[k] for k in sorted(pytree._pytree__leaf_meta.keys()) + pytree._pytree__meta.get(field, {}) + for field, _ in sorted(vars(pytree).items()) + if field not in pytree._pytree__static_fields ] -def meta_leaves(pytree: Mytree) -> List[Dict[str, Any]]: +def meta_leaves( + pytree: Mytree, + *, + is_leaf: Callable[[Any], bool] | None = None, +) -> List[Tuple[Dict[str, Any], Any]]: """ - Returns a list of the Mytree Mytree leaves' metadata. + Returns the meta of the leaves of the pytree. Args: - pytree (Mytree): Mytree to get the metadata of the leaves. + pytree (Mytree): pytree to get the meta of. + is_leaf (Callable[[Any], bool]): predicate to determine if a node is a leaf. Defaults to None. Returns: - List[Dict[str, Any]]: List of the Mytree leaves' metadata. + List[Tuple[Dict[str, Any], Any]]: meta of the leaves of the pytree. """ - _leaf_metadata = _toplevel_meta(pytree) - def _nested_unpack_metadata(pytree: Mytree, *rest: Mytree) -> None: - if isinstance(pytree, Mytree): - _leaf_metadata.extend(_toplevel_meta(pytree)) - _unpack_metadata(pytree, *rest) + def _unpack_metadata( + meta: Any, + pytree: Mytree, + is_leaf: Callable[[Any], bool] | None, + ): + """Recursively unpack leaf metadata.""" + if is_leaf and is_leaf(pytree): + yield meta + return + + if type(pytree) in _registry: # Registry tree trick, thanks to PyTreeClass! + leaves_values, _ = _registry[type(pytree)].to_iter(pytree) + leaves_meta = _toplevel_meta(pytree) - def _unpack_metadata(pytree: Mytree, *rest: Mytree) -> None: - pytrees = (pytree,) + rest - _ = jax.tree_map( - _nested_unpack_metadata, - *pytrees, - is_leaf=lambda x: isinstance(x, Mytree) and not x in pytrees, - ) + elif pytree is not None: + yield meta + return - _unpack_metadata(pytree) + for metadata, leaf in zip(leaves_meta, leaves_values): + yield from _unpack_metadata((metadata, leaf), leaf, is_leaf) - return _leaf_metadata + return list(_unpack_metadata(pytree, pytree, is_leaf)) -def meta(pytree: Mytree) -> Mytree: +def meta_flatten( + pytree: Mytree, *, is_leaf: Callable[[Any], bool] | None = None +) -> Mytree: """ - Returns the meta of the Mytree Mytree. + Returns the meta of the Mytree. Args: pytree (Mytree): Mytree to get the meta of. + is_leaf (Callable[[Any], bool]): predicate to determine if a node is a leaf. Defaults to None. Returns: Mytree: meta of the Mytree. """ - return jtu.tree_structure(pytree).unflatten(meta_leaves(pytree)) + return meta_leaves(pytree, is_leaf=is_leaf), jtu.tree_structure( + pytree, is_leaf=is_leaf + ) + + +def meta_map( + f: Callable[[Any, Dict[str, Any]], Any], + pytree: Mytree, + *rest: Any, + is_leaf: Callable[[Any], bool] | None = None, +) -> Mytree: + """Apply a function to a mytree where the first argument are the pytree leaves, and the second argument are the mytree metadata leaves. + Args: + f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree. + pytree (Mytree): The pytree to apply the function to. + rest (Any, optional): Additional pytrees to apply the function to. Defaults to None. + is_leaf (Callable[[Any], bool], optional): predicate to determine if a node is a leaf. Defaults to None. + + Returns: + Mytree: The transformed pytree. + """ + leaves, treedef = meta_flatten(pytree, is_leaf=is_leaf) + all_leaves = [leaves] + [treedef.treedef.flatten_up_to(r) for r in rest] + return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) + + +def meta(pytree: Mytree, *, is_leaf: Callable[[Any], bool] | None = None) -> Mytree: + """Returns the metadata of the mytree as a pytree. + + Args: + pytree (Mytree): pytree to get the metadata of. + + Returns: + Mytree: metadata of the pytree. + """ + + def _filter_meta(meta_leaf): + meta, _ = meta_leaf + return meta + + return meta_map(_filter_meta, pytree, is_leaf=is_leaf) diff --git a/tests/test_mytree.py b/tests/test_mytree.py index 626f524..7c7867c 100644 --- a/tests/test_mytree.py +++ b/tests/test_mytree.py @@ -447,30 +447,6 @@ def __init__(self, trees): assert unconstrained_tree.trees[2].c == Identity.inverse(3.0) -@pytest.mark.parametrize("is_dataclass", [True, False]) -def test_pytree_leaf_meta_inheritence_and_unmarked_fields(is_dataclass): - class A(Mytree): - a: int = static_field() - - def __init__(self, a=1, b=2) -> None: - self.a = a - self.b = b - self.h = 3 - - class B(A): - c: int - - def __init__(self, c=3) -> None: - super().__init__() - self.c = c - - if is_dataclass: - A = dataclasses.dataclass(A) - B = dataclasses.dataclass(B) - - assert B()._pytree__leaf_meta == {"b": {}, "h": {}, "c": {}} - - # The following tests are adapted from equinox 🏴‍☠️ From 5ac21b6ecea115d031dc60a73ddd3098ce1c976b Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 22 Mar 2023 14:50:56 +0000 Subject: [PATCH 2/2] Add bijectors tests. --- mytree/bijectors.py | 21 ++++++++++----------- tests/test_bijectors.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 tests/test_bijectors.py diff --git a/mytree/bijectors.py b/mytree/bijectors.py index 0c5ea5b..5a25762 100644 --- a/mytree/bijectors.py +++ b/mytree/bijectors.py @@ -2,19 +2,17 @@ __all__ = ["Bijector", "Identity", "Softplus"] -import importlib -import jax.numpy as jnp -from typing import Callable from dataclasses import dataclass +from typing import Callable + +import jax.numpy as jnp from simple_pytree import Pytree, static_field + @dataclass class Bijector(Pytree): - forward: Callable = static_field() - inverse: Callable = static_field() - -def __init__(self, forward: Callable, inverse: Callable) -> None: - """Initialise the bijector. + """ + Create a bijector. Args: forward(Callable): The forward transformation. @@ -23,13 +21,14 @@ def __init__(self, forward: Callable, inverse: Callable) -> None: Returns: Bijector: A bijector. """ - self.forward = forward - self.inverse = inverse + + forward: Callable = static_field() + inverse: Callable = static_field() Identity = Bijector(forward=lambda x: x, inverse=lambda x: x) Softplus = Bijector( - forward=lambda x: jnp.log(1 + jnp.exp(x)), + forward=lambda x: jnp.log(1.0 + jnp.exp(x)), inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), ) diff --git a/tests/test_bijectors.py b/tests/test_bijectors.py new file mode 100644 index 0000000..9e6eb59 --- /dev/null +++ b/tests/test_bijectors.py @@ -0,0 +1,22 @@ +import jax.numpy as jnp +import pytest + +from mytree.bijectors import Bijector, Identity, Softplus + + +def test_bijector(): + bij = Bijector(forward=lambda x: jnp.exp(x), inverse=lambda x: jnp.log(x)) + assert bij.forward(1.0) == pytest.approx(jnp.exp(1.0)) + assert bij.inverse(jnp.exp(1.0)) == pytest.approx(1.0) + + +def test_identity(): + bij = Identity + assert bij.forward(1.0) == pytest.approx(1.0) + assert bij.inverse(1.0) == pytest.approx(1.0) + + +def test_softplus(): + bij = Softplus + assert bij.forward(1.0) == pytest.approx(jnp.log(1.0 + jnp.exp(1.0))) + assert bij.inverse(jnp.log(1.0 + jnp.exp(1.0))) == pytest.approx(1.0)