From b386b71500d7b3b5260fdbc08e0b603e76fb10b6 Mon Sep 17 00:00:00 2001
From: Flax Team
Date: Wed, 24 Apr 2024 16:28:06 -0700
Subject: [PATCH] allow custom dot_general for einsum.

PiperOrigin-RevId: 627886032
---
 flax/linen/attention.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 4987ccd067..99b79f2de5 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -54,6 +54,7 @@ def dot_product_attention_weights(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   force_fp32_for_softmax: bool = False,
+  einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
 ):
   """Computes dot-product attention weights given query and key.
 
@@ -87,6 +88,7 @@ def dot_product_attention_weights(
     force_fp32_for_softmax: bool, whether to force the softmax to be computed in
       fp32. This is useful for mixed-precision training where higher precision
       is desired for numerical stability.
+    einsum_dot_general: the dot_general to use in einsum.
 
   Returns:
     Output of shape ``[batch..., num_heads, q_length, kv_length]``.
@@ -104,7 +106,11 @@ def dot_product_attention_weights(
   query = query / jnp.sqrt(depth).astype(dtype)
   # attn weight shape is (batch..., num_heads, q_length, kv_length)
   attn_weights = jnp.einsum(
-    '...qhd,...khd->...hqk', query, key, precision=precision
+    '...qhd,...khd->...hqk',
+    query,
+    key,
+    precision=precision,
+    _dot_general=einsum_dot_general,
   )
 
   # apply attention bias: masking, dropout, proximity bias, etc.
@@ -153,6 +159,7 @@ def dot_product_attention(
   precision: PrecisionLike = None,
   module: Optional[Module] = None,
   force_fp32_for_softmax: bool = False,
+  einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
 ):
   """Computes dot-product attention given query, key, and value.
 
@@ -191,6 +198,7 @@ def dot_product_attention(
     force_fp32_for_softmax: bool, whether to force the softmax to be computed in
       fp32. This is useful for mixed-precision training where higher precision
       is desired for numerical stability.
+    einsum_dot_general: the dot_general to use in einsum.
 
   Returns:
     Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
@@ -220,11 +228,16 @@ def dot_product_attention(
     precision,
     module,
     force_fp32_for_softmax,
+    einsum_dot_general=einsum_dot_general,
   )
 
   # return weighted sum over values for each query position
   return jnp.einsum(
-    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+    '...hqk,...khd->...qhd',
+    attn_weights,
+    value,
+    precision=precision,
+    _dot_general=einsum_dot_general,
   )
 
 
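
Below is a minimal usage sketch of the new parameter; it is not part of the patch. It assumes a Flax build that includes this change, and bf16_dot_general is a hypothetical wrapper used only to illustrate that any callable with the jax.lax.dot_general signature can be passed in (for example, to run the attention contractions in lower precision or to route them through a custom kernel).

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical dot_general replacement: casts both operands to bfloat16 before
# delegating to jax.lax.dot_general. Any callable matching the
# jax.lax.dot_general signature can be supplied as einsum_dot_general.
def bf16_dot_general(
  lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
  return jax.lax.dot_general(
    lhs.astype(jnp.bfloat16),
    rhs.astype(jnp.bfloat16),
    dimension_numbers,
    precision=precision,
    preferred_element_type=preferred_element_type,
  )

key = jax.random.PRNGKey(0)
# Shapes follow the documented convention: [batch, length, num_heads, head_dim].
q = jax.random.normal(key, (2, 16, 4, 32))
k = jax.random.normal(key, (2, 16, 4, 32))
v = jax.random.normal(key, (2, 16, 4, 32))

# Both einsums inside dot_product_attention now contract via bf16_dot_general.
out = nn.dot_product_attention(q, k, v, einsum_dot_general=bf16_dot_general)
print(out.shape)  # (2, 16, 4, 32)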