From 5c40ee9d9d2e07f21ae94af492836dbac362e0ba Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 17 Sep 2024 09:54:15 -0700 Subject: [PATCH] Clean up and fix primal type to tangent type mapping This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types. Changes: 1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself. 2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion. 3. Add `to_tangent_type` calls in various other places where they're missing. 4. Remove non-support for float0 in custom derivatives? 5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.) 
PiperOrigin-RevId: 675606346 --- distrax/_src/utils/math.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/distrax/_src/utils/math.py b/distrax/_src/utils/math.py index 9a4d373..a71a708 100644 --- a/distrax/_src/utils/math.py +++ b/distrax/_src/utils/math.py @@ -14,10 +14,13 @@ # ============================================================================== """Utility math functions.""" +import functools from typing import Optional, Tuple import chex import jax +from jax import core as jax_core +from jax.custom_derivatives import SymbolicZero import jax.numpy as jnp Array = chex.Array @@ -44,7 +47,24 @@ def multiply_no_nan(x: Array, y: Array) -> Array: return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y) -@multiply_no_nan.defjvp +# TODO(dougalm): move helpers like these into JAX AD utils +def add_maybe_symbolic(x, y): + if isinstance(x, SymbolicZero): + return y + elif isinstance(y, SymbolicZero): + return x + else: + return x + y + + +def scale_maybe_symbolic(result_aval, tangent, scale): + if isinstance(tangent, SymbolicZero): + return SymbolicZero(result_aval) + else: + return tangent * scale + + +@functools.partial(multiply_no_nan.defjvp, symbolic_zeros=True) def multiply_no_nan_jvp( primals: Tuple[Array, Array], tangents: Tuple[Array, Array]) -> Tuple[Array, Array]: @@ -52,8 +72,10 @@ def multiply_no_nan_jvp( x, y = primals x_dot, y_dot = tangents primal_out = multiply_no_nan(x, y) - tangent_out = y * x_dot + x * y_dot - return primal_out, tangent_out + result_aval = jax_core.get_aval(primal_out).to_tangent_aval() + tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y) + tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x) + return primal_out, add_maybe_symbolic(tangent_out_1, tangent_out_2) @jax.custom_jvp