Customizing the einsums allows us to use the quantized einsum layers provided by [AQT](https://github.com/google/aqt).

PiperOrigin-RevId: 658120110
Flax Team committed Jul 31, 2024
1 parent 5b12e9b commit cd6218f
Showing 2 changed files with 129 additions and 15 deletions.
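
For context, a minimal usage sketch (not part of the commit) of the new module-level hooks. It assumes Flax at this revision; the `make_einsum` factory is a hypothetical stand-in that returns plain `jnp.einsum` where an AQT integration would return a quantized einsum.

```python
# Minimal sketch: plugging zero-argument einsum factories into
# MultiHeadDotProductAttention via the attributes added in this commit.
# `make_einsum` is hypothetical; plain `jnp.einsum` stands in for an
# AQT-provided quantized einsum.
import jax
import jax.numpy as jnp
import flax.linen as nn


def make_einsum():
  # An AQT integration would return a quantized einsum here.
  return jnp.einsum


attn = nn.MultiHeadDotProductAttention(
    num_heads=2,
    qk_attn_weights_einsum_cls=make_einsum,
    attn_weights_value_einsum_cls=make_einsum,
)
x = jnp.ones((1, 4, 8))  # (batch, length, features)
variables = attn.init(jax.random.PRNGKey(0), x)
y = attn.apply(variables, x)  # self-attention over x
print(y.shape)  # (1, 4, 8)
```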
103 changes: 88 additions & 15 deletions flax/linen/attention.py
@@ -57,7 +57,8 @@ def dot_product_attention_weights(
precision: PrecisionLike = None,
module: Module | None = None,
force_fp32_for_softmax: bool = False,
einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
einsum_dot_general: Callable[..., Array] | None = None,
einsum: Callable[..., Array] | None = None,
):
"""Computes dot-product attention weights given query and key.
@@ -92,10 +93,30 @@ def dot_product_attention_weights(
fp32. This is useful for mixed-precision training where higher precision
is desired for numerical stability.
einsum_dot_general: the dot_general to use in einsum.
einsum: the einsum to use to compute the attention weights. If unspecified,
the default `jnp.einsum` will be used. This argument is mutually exclusive
with `precision` and `einsum_dot_general`.
Raises:
ValueError: if both `precision`/`einsum_dot_general` and `einsum` are
specified.
Returns:
Output of shape ``[batch..., num_heads, q_length, kv_length]``.
"""
if (precision or einsum_dot_general) and einsum:
raise ValueError(
'precision/einsum_dot_general and einsum are mutually exclusive. Please'
' specify only one of them.'
)
if not einsum:
einsum = functools.partial(
jnp.einsum,
precision=precision,
_dot_general=einsum_dot_general
if einsum_dot_general
else jax.lax.dot_general,
)

query, key = promote_dtype(query, key, dtype=dtype)
dtype = query.dtype

@@ -108,13 +129,7 @@ def dot_product_attention_weights(
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk',
query,
key,
precision=precision,
_dot_general=einsum_dot_general,
)
attn_weights = einsum('...qhd,...khd->...hqk', query, key)

# apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
@@ -162,7 +177,9 @@ def dot_product_attention(
precision: PrecisionLike = None,
module: Module | None = None,
force_fp32_for_softmax: bool = False,
einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
einsum_dot_general: Callable[..., Array] | None = None,
qk_attn_weights_einsum: Callable[..., Array] | None = None,
attn_weights_value_einsum: Callable[..., Array] | None = None,
):
"""Computes dot-product attention given query, key, and value.
Expand Down Expand Up @@ -201,11 +218,39 @@ def dot_product_attention(
force_fp32_for_softmax: bool, whether to force the softmax to be computed in
fp32. This is useful for mixed-precision training where higher precision
is desired for numerical stability.
einsum_dot_general: the dot_general to use in einsum.
einsum_dot_general: the dot_general to use in `jnp.einsum`.
qk_attn_weights_einsum: the einsum for computing the attention weights. When
unspecified, the default `jnp.einsum` will be used. This argument is
mutually exclusive with `precision` and `einsum_dot_general`.
attn_weights_value_einsum: the einsum for computing the product of the
attention weights and the values. When unspecified, the default
`jnp.einsum` will be used. This argument is mutually exclusive with
`precision` and `einsum_dot_general`.
Returns:
Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
Raises:
ValueError: if both `precision`/`einsum_dot_general` and
`qk_attn_weights_einsum`/`attn_weights_value_einsum` are
specified.
"""
if (qk_attn_weights_einsum and not attn_weights_value_einsum) or (
not qk_attn_weights_einsum and attn_weights_value_einsum
):
raise ValueError(
'qk_attn_weights_einsum and attn_weights_value_einsum must be specified'
' together.'
)
if (precision or einsum_dot_general) and (
qk_attn_weights_einsum or attn_weights_value_einsum
):
raise ValueError(
'precision/einsum_dot_general and'
' qk_attn_weights_einsum/attn_weights_value_einsum are mutually'
' exclusive. Please specify only one of them.'
)

query, key, value = promote_dtype(query, key, value, dtype=dtype)
dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
@@ -231,16 +276,21 @@ def dot_product_attention(
precision,
module,
force_fp32_for_softmax,
einsum_dot_general=einsum_dot_general,
einsum=qk_attn_weights_einsum,
)

if not attn_weights_value_einsum:
attn_weights_value_einsum = functools.partial(
jnp.einsum,
precision=precision,
_dot_general=einsum_dot_general
if einsum_dot_general
else jax.lax.dot_general,
)
# return weighted sum over values for each query position
return jnp.einsum(
return attn_weights_value_einsum(
'...hqk,...khd->...qhd',
attn_weights,
value,
precision=precision,
_dot_general=einsum_dot_general,
)


@@ -320,6 +370,10 @@ class MultiHeadDotProductAttention(Module):
num_heads, value_channels]``
decode: Whether to prepare and use an autoregressive cache.
normalize_qk: Should QK normalization be applied (arxiv.org/abs/2302.05442).
qk_attn_weights_einsum_cls: factory function to create the einsum for
computing the attention weights.
attn_weights_value_einsum_cls: factory function to create the einsum for
computing the product of the attention weights and the values.
"""

num_heads: int
@@ -345,6 +399,10 @@ class MultiHeadDotProductAttention(Module):
out_dot_general: DotGeneralT | None = None
qkv_dot_general_cls: Any = None
out_dot_general_cls: Any = None
qk_attn_weights_einsum_cls: Callable[..., Callable[..., Array]] | None = None
attn_weights_value_einsum_cls: Callable[..., Callable[..., Array]] | None = (
None
)

@overload
def __call__(
@@ -575,6 +633,19 @@ def __call__(
else:
m_deterministic = True

# `qk_attn_weights_einsum` and `attn_weights_value_einsum` are optional
# arguments that can be used to override the default `jnp.einsum`. They
# exist for quantized einsum support in AQT.
qk_attn_weights_einsum = (
self.qk_attn_weights_einsum_cls()
if self.qk_attn_weights_einsum_cls
else None
)
attn_weights_value_einsum = (
self.attn_weights_value_einsum_cls()
if self.attn_weights_value_einsum_cls
else None
)
# apply attention
attn_args = (query, key, value)
# This kwargs list matches the default nn.dot_product_attention.
@@ -588,6 +659,8 @@ def __call__(
dtype=self.dtype,
precision=self.precision,
force_fp32_for_softmax=self.force_fp32_for_softmax,
qk_attn_weights_einsum=qk_attn_weights_einsum,
attn_weights_value_einsum=attn_weights_value_einsum,
)
attn_kwargs = {
k: v
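A companion sketch (not part of the commit) of the functional path: `nn.dot_product_attention` takes the two einsum overrides directly, and they must be supplied together and without `precision`/`einsum_dot_general`. Plain `jnp.einsum` again stands in for a quantized einsum.

```python
# Minimal sketch of the functional API with the new einsum overrides.
import jax.numpy as jnp
import flax.linen as nn

# Shapes follow the docstring: [batch..., length, num_heads, head_dim].
q = k = v = jnp.ones((1, 4, 2, 8))
out = nn.dot_product_attention(
    q,
    k,
    v,
    qk_attn_weights_einsum=jnp.einsum,     # stand-in for a quantized einsum
    attn_weights_value_einsum=jnp.einsum,  # must be passed together with the above
)
print(out.shape)  # (1, 4, 2, 8)
```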
41 changes: 41 additions & 0 deletions tests/linen/linen_attention_test.py
@@ -578,6 +578,47 @@ def test_mixed_precision_multihead_attention(
attn_weights_dtype,
)

@parameterized.parameters(
(lax.Precision.DEFAULT, None),
(None, jax.lax.dot_general),
)
def test_dot_product_attention_precision_and_einsum_override(
self, precision, einsum_dot_general
):
# Test that we raise a ValueError if the user specifies `precision` and/or
# `einsum_dot_general` together with the einsum overrides.
einsum_cls = lambda: jnp.einsum
self.assertRaises(
ValueError,
nn.dot_product_attention,
query=jnp.ones((1, 4, 2)),
key=jnp.ones((1, 4, 2)),
value=jnp.ones((1, 4, 2)),
precision=precision,
einsum_dot_general=einsum_dot_general,
qk_attn_weights_einsum=einsum_cls,
attn_weights_value_einsum=einsum_cls,
)

@parameterized.parameters(
(lambda: jax.lax.dot_general, None),
(None, lambda: jax.lax.dot_general),
)
def test_dot_product_attention_specify_einsums_together(
self, qk_attn_weights_einsum, attn_weights_value_einsum
):
# Test that we raise a ValueError if the user specifies only one of
# `qk_attn_weights_einsum` and `attn_weights_value_einsum`.
self.assertRaises(
ValueError,
nn.dot_product_attention,
query=jnp.ones((1, 4, 2)),
key=jnp.ones((1, 4, 2)),
value=jnp.ones((1, 4, 2)),
qk_attn_weights_einsum=qk_attn_weights_einsum,
attn_weights_value_einsum=attn_weights_value_einsum,
)


if __name__ == '__main__':
absltest.main()
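
Mirroring the new tests, a short sketch (same assumptions as above) of how the validation surfaces to callers:

```python
# Minimal sketch: combining `precision` with the einsum overrides raises a
# ValueError, as exercised by the tests above.
import jax.numpy as jnp
from jax import lax
import flax.linen as nn

try:
  nn.dot_product_attention(
      jnp.ones((1, 4, 2, 8)),
      jnp.ones((1, 4, 2, 8)),
      jnp.ones((1, 4, 2, 8)),
      precision=lax.Precision.DEFAULT,
      qk_attn_weights_einsum=jnp.einsum,
      attn_weights_value_einsum=jnp.einsum,
  )
except ValueError as err:
  print(err)
```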
