Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. #4029

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

import functools
import inspect
import warnings
from typing import Any, Callable, Optional, Union, overload

Expand Down Expand Up @@ -574,33 +575,28 @@ def __call__(
m_deterministic = True

# apply attention
attn_args = (query, key, value)
# This kwargs list match the default nn.dot_product_attention.
# For custom `attention_fn`s, invalid kwargs will be filtered.
attn_kwargs = dict(
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
precision=self.precision,
force_fp32_for_softmax=self.force_fp32_for_softmax,
)
attn_kwargs = {
k: v
for k, v in attn_kwargs.items()
if k in inspect.signature(self.attention_fn).parameters
}
if sow_weights:
x = self.attention_fn(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
precision=self.precision,
module=self,
) # pytype: disable=wrong-keyword-args
x = self.attention_fn(*attn_args, **attn_kwargs, module=self)
else:
x = self.attention_fn(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=m_deterministic,
dtype=self.dtype,
precision=self.precision,
)
x = self.attention_fn(*attn_args, **attn_kwargs)
# back to the original inputs dimensions
out = DenseGeneral(
features=features,
Expand Down
5 changes: 1 addition & 4 deletions tests/linen/linen_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Tests for flax.linen.attention."""

import functools
from absl.testing import absltest, parameterized
from flax import errors, jax_utils
from flax import linen as nn
Expand Down Expand Up @@ -565,9 +564,7 @@ def test_mixed_precision_multihead_attention(
qkv_features=4,
kernel_init=initializers.lecun_normal(),
bias_init=initializers.uniform(),
attention_fn=functools.partial(
nn.dot_product_attention, force_fp32_for_softmax=force_fp32
),
force_fp32_for_softmax=force_fp32,
deterministic=False,
dtype=jnp.bfloat16,
)
Expand Down
Loading