From 10694482b3f25d3bb14b173ee4126f295ddd27de Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 15 May 2024 10:07:35 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 633992657 --- fedjax/core/regularizers.py | 2 +- fedjax/core/tree_util.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fedjax/core/regularizers.py b/fedjax/core/regularizers.py index fee9ea0..af2c9bb 100644 --- a/fedjax/core/regularizers.py +++ b/fedjax/core/regularizers.py @@ -26,7 +26,7 @@ def _l2_regularize(params: Params, weight: float, params_weights: Optional[Params]) -> float: """Returns L2 regularization weight.""" if center_params is not None: - params = jax.tree_map(lambda a, b: a - b, params, center_params) + params = jax.tree.map(lambda a, b: a - b, params, center_params) leaves = jax.tree_util.tree_leaves(params) if params_weights is not None: pw_leaves = jax.tree_util.tree_leaves(params_weights) diff --git a/fedjax/core/tree_util.py b/fedjax/core/tree_util.py index 4ffc30c..4d909bd 100644 --- a/fedjax/core/tree_util.py +++ b/fedjax/core/tree_util.py @@ -29,7 +29,7 @@ @jax.jit def tree_weight(pytree: PyTree, weight: float) -> PyTree: """Weights tree leaves by weight.""" - return jax.tree_map(lambda l: l * weight, pytree) + return jax.tree.map(lambda l: l * weight, pytree) def tree_inverse_weight(pytree: PyTree, weight: float) -> PyTree: @@ -41,7 +41,7 @@ def tree_inverse_weight(pytree: PyTree, weight: float) -> PyTree: @jax.jit def tree_zeros_like(pytree: PyTree) -> PyTree: """Creates a tree with zeros with same structure as the input.""" - return jax.tree_map(jnp.zeros_like, pytree) + return jax.tree.map(jnp.zeros_like, pytree) @jax.jit