From 527d6982713f8ad7ac9821ebf7a00899dde2cc79 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 18 Sep 2024 13:43:14 -0700 Subject: [PATCH] Clean up and fix primal type to tangent type mapping This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types. Changes: 1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself. 2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion. 3. Add `to_tangent_type` calls in various other places where they're missing. 4. Remove non-support for float0 in custom derivatives? 5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.) 
PiperOrigin-RevId: 676115753 --- distrax/_src/utils/math.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/distrax/_src/utils/math.py b/distrax/_src/utils/math.py index 9a4d373..b664b63 100644 --- a/distrax/_src/utils/math.py +++ b/distrax/_src/utils/math.py @@ -14,10 +14,13 @@ # ============================================================================== """Utility math functions.""" +import functools from typing import Optional, Tuple import chex import jax +from jax import core as jax_core +from jax.custom_derivatives import SymbolicZero import jax.numpy as jnp Array = chex.Array @@ -44,7 +47,24 @@ def multiply_no_nan(x: Array, y: Array) -> Array: return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y) -@multiply_no_nan.defjvp +# TODO(dougalm): move helpers like these into JAX AD utils +def add_maybe_symbolic(x, y): + if isinstance(x, SymbolicZero): + return y + elif isinstance(y, SymbolicZero): + return x + else: + return x + y + + +def scale_maybe_symbolic(result_aval, tangent, scale): + if isinstance(tangent, SymbolicZero): + return SymbolicZero(result_aval) + else: + return tangent * scale + + +@functools.partial(multiply_no_nan.defjvp, symbolic_zeros=True) def multiply_no_nan_jvp( primals: Tuple[Array, Array], tangents: Tuple[Array, Array]) -> Tuple[Array, Array]: @@ -52,8 +72,11 @@ def multiply_no_nan_jvp( x, y = primals x_dot, y_dot = tangents primal_out = multiply_no_nan(x, y) - tangent_out = y * x_dot + x * y_dot - return primal_out, tangent_out + primal_aval = jax_core.raise_to_shaped(jax_core.get_aval(primal_out)) + result_aval = primal_aval.at_least_vspace() + tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y) + tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x) + return primal_out, add_maybe_symbolic(tangent_out_1, tangent_out_2) @jax.custom_jvp