diff --git a/flax/linen/attention.py b/flax/linen/attention.py index b7e9de6c77..6e91be996d 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -16,6 +16,7 @@ from __future__ import annotations import functools +import inspect import warnings from typing import Any, Callable, Optional, Union, overload @@ -574,33 +575,28 @@ def __call__( m_deterministic = True # apply attention + attn_args = (query, key, value) + # This kwargs list match the default nn.dot_product_attention. + # For custom `attention_fn`s, invalid kwargs will be filtered. + attn_kwargs = dict( + mask=mask, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + broadcast_dropout=self.broadcast_dropout, + deterministic=m_deterministic, + dtype=self.dtype, + precision=self.precision, + force_fp32_for_softmax=self.force_fp32_for_softmax, + ) + attn_kwargs = { + k: v + for k, v in attn_kwargs.items() + if k in inspect.signature(self.attention_fn).parameters + } if sow_weights: - x = self.attention_fn( - query, - key, - value, - mask=mask, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - broadcast_dropout=self.broadcast_dropout, - deterministic=m_deterministic, - dtype=self.dtype, - precision=self.precision, - module=self, - ) # pytype: disable=wrong-keyword-args + x = self.attention_fn(*attn_args, **attn_kwargs, module=self) else: - x = self.attention_fn( - query, - key, - value, - mask=mask, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - broadcast_dropout=self.broadcast_dropout, - deterministic=m_deterministic, - dtype=self.dtype, - precision=self.precision, - ) + x = self.attention_fn(*attn_args, **attn_kwargs) # back to the original inputs dimensions out = DenseGeneral( features=features, diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 7dd078783c..e7d69a43fa 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -14,7 +14,6 @@ """Tests for flax.linen.attention.""" -import functools from absl.testing import absltest, parameterized from flax import errors, jax_utils from flax import linen as nn @@ -565,9 +564,7 @@ def test_mixed_precision_multihead_attention( qkv_features=4, kernel_init=initializers.lecun_normal(), bias_init=initializers.uniform(), - attention_fn=functools.partial( - nn.dot_product_attention, force_fp32_for_softmax=force_fp32 - ), + force_fp32_for_softmax=force_fp32, deterministic=False, dtype=jnp.bfloat16, )