Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow custom dot_general for einsum. #3884

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def dot_product_attention_weights(
precision: PrecisionLike = None,
module: Optional[Module] = None,
force_fp32_for_softmax: bool = False,
einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
"""Computes dot-product attention weights given query and key.

Expand Down Expand Up @@ -87,6 +88,7 @@ def dot_product_attention_weights(
force_fp32_for_softmax: bool, whether to force the softmax to be computed in
fp32. This is useful for mixed-precision training where higher precision
is desired for numerical stability.
einsum_dot_general: the dot_general to use in einsum.

Returns:
Output of shape ``[batch..., num_heads, q_length, kv_length]``.
Expand All @@ -104,7 +106,11 @@ def dot_product_attention_weights(
query = query / jnp.sqrt(depth).astype(dtype)
# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk', query, key, precision=precision
'...qhd,...khd->...hqk',
query,
key,
precision=precision,
_dot_general=einsum_dot_general,
)

# apply attention bias: masking, dropout, proximity bias, etc.
Expand Down Expand Up @@ -153,6 +159,7 @@ def dot_product_attention(
precision: PrecisionLike = None,
module: Optional[Module] = None,
force_fp32_for_softmax: bool = False,
einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
"""Computes dot-product attention given query, key, and value.

Expand Down Expand Up @@ -191,6 +198,7 @@ def dot_product_attention(
force_fp32_for_softmax: bool, whether to force the softmax to be computed in
fp32. This is useful for mixed-precision training where higher precision
is desired for numerical stability.
einsum_dot_general: the dot_general to use in einsum.

Returns:
Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
Expand Down Expand Up @@ -220,11 +228,16 @@ def dot_product_attention(
precision,
module,
force_fp32_for_softmax,
einsum_dot_general=einsum_dot_general,
)

# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
'...hqk,...khd->...qhd',
attn_weights,
value,
precision=precision,
_dot_general=einsum_dot_general,
)


Expand Down
Loading