Skip to content

Commit

Permalink
Update users of jax.tree.map() to be more careful about how they hand…
Browse files Browse the repository at this point in the history
…le Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 673258116
  • Loading branch information
hawkinsp authored and jax authors committed Sep 11, 2024
1 parent e3c4b20 commit 808003b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,17 @@ def replace_non_float_or_none(arg_tf):
ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
# We must make the float0s that JAX expects
def fix_float0(arg_jax, ct_arg_jax):
if arg_jax is None:
return None
arg_dtype = dtypes.result_type(arg_jax) # May be scalar
ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
if ct_arg_dtype != ct_arg_jax.dtype:
return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
ct_arg_dtype))
return ct_arg_jax

ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax,
is_leaf=lambda x: x is None)
return ct_args_jax_fixed

make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
Expand Down

0 comments on commit 808003b

Please sign in to comment.