From 808003b4e29e878349192e0f63fa1a2454ace56b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Sep 2024 23:53:24 -0700 Subject: [PATCH] Update users of jax.tree.map() to be more careful about how they handle Nones. Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself. Fix code that was relying on this bug. Most commonly, the fix is to write `jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`. PiperOrigin-RevId: 673258116 --- jax/experimental/jax2tf/call_tf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 04dd9f17933d..037f8bbc2a02 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -266,6 +266,8 @@ def replace_non_float_or_none(arg_tf): ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax) # We must make the float0s that JAX expects def fix_float0(arg_jax, ct_arg_jax): + if arg_jax is None: + return None arg_dtype = dtypes.result_type(arg_jax) # May be scalar ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype) if ct_arg_dtype != ct_arg_jax.dtype: @@ -273,7 +275,8 @@ def fix_float0(arg_jax, ct_arg_jax): ct_arg_dtype)) return ct_arg_jax - ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax) + ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax, + is_leaf=lambda x: x is None) return ct_args_jax_fixed make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)