diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 04dd9f17933d..037f8bbc2a02 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -266,6 +266,8 @@ def replace_non_float_or_none(arg_tf): ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax) # We must make the float0s that JAX expects def fix_float0(arg_jax, ct_arg_jax): + if arg_jax is None: + return None arg_dtype = dtypes.result_type(arg_jax) # May be scalar ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype) if ct_arg_dtype != ct_arg_jax.dtype: @@ -273,7 +275,8 @@ def fix_float0(arg_jax, ct_arg_jax): ct_arg_dtype)) return ct_arg_jax - ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax) + ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax, + is_leaf=lambda x: x is None) return ct_args_jax_fixed make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)