
Commit

changed `return_weights` to `sow_weights` for attention layer
chiamp committed Dec 12, 2023
1 parent 2e51a4e commit 7f10ad8
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions flax/linen/attention.py
@@ -320,7 +320,7 @@ def __call__(
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
-return_weights: bool = False,
+sow_weights: bool = False,
):
...

@@ -333,7 +333,7 @@ def __call__(
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
-return_weights: bool = False,
+sow_weights: bool = False,
):
...

@@ -348,7 +348,7 @@ def __call__(
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
dropout_rng: Optional[PRNGKey] = None,
-return_weights: bool = False,
+sow_weights: bool = False,
):
"""Applies multi-head dot product attention on the input data.
@@ -375,7 +375,7 @@ def __call__(
dropout, whereas if true, the attention weights are deterministic.
dropout_rng: optional rng key to pass to the attention layer's dropout
mask. Otherwise, self.make_rng('dropout') is used instead.
-return_weights: if ``True``, the attention weights are sowed into the
+sow_weights: if ``True``, the attention weights are sowed into the
'intermediates' collection. Remember to mark 'intermediates' as
mutable via ``mutable=['intermediates']`` in order to have that
collection returned.
@@ -527,7 +527,7 @@ def __call__(
m_deterministic = True

# apply attention
-if return_weights:
+if sow_weights:
x = self.attention_fn(
query,
key,
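As the docstring above notes, `sow_weights=True` sows the attention weights into the 'intermediates' collection, which must be marked mutable for the collection to be returned. A minimal usage sketch follows; it is not part of this commit, and the layer sizes, input shape, and the 'attention_weights' key are illustrative assumptions based on the diff and Flax's `sow` convention.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical layer configuration, chosen only for illustration.
layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)

x = jnp.ones((4, 6, 5))  # (batch, length, features)
variables = layer.init(jax.random.PRNGKey(0), x)

# Pass the renamed `sow_weights` flag and mark 'intermediates' as mutable
# so the sown attention weights come back alongside the output.
out, intermediates = layer.apply(
    variables, x, mutable=['intermediates'], sow_weights=True
)

# With Flax's default sow() behavior, the value is stored as a tuple under
# an 'attention_weights' key (assumed name); shape is (batch, heads, q, kv).
weights = intermediates['intermediates']['attention_weights'][0]
print(weights.shape)  # (4, 8, 6, 6)

With `sow_weights=False` (the default), nothing is sown and the returned collection stays empty, which is what the test below checks via `assertNotIn`.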
10 changes: 5 additions & 5 deletions tests/linen/linen_attention_test.py
@@ -351,13 +351,13 @@ class Model(nn.Module):
attention_kwargs: dict

@nn.compact
-def __call__(self, x, return_weights=False):
+def __call__(self, x, sow_weights=False):
x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
-x, return_weights=return_weights
+x, sow_weights=sow_weights
)
x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x)
x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
-x, return_weights=return_weights
+x, sow_weights=sow_weights
)
return x

@@ -372,7 +372,7 @@ def __call__(self, x, return_weights=False):
)
v = module.init(rng, x)
_, intermediates = module.apply(
-v, x, mutable=['intermediates'], return_weights=True
+v, x, mutable=['intermediates'], sow_weights=True
)
self.assertEqual(
intermediates['intermediates']['MultiHeadDotProductAttention_0'][
@@ -390,7 +390,7 @@ def __call__(self, x, return_weights=False):
(4, 8, 6, 6),
)
_, intermediates = module.apply(
-v, x, mutable=['intermediates'], return_weights=False
+v, x, mutable=['intermediates'], sow_weights=False
)
self.assertNotIn('intermediates', intermediates)

