Add support for Keras mask & causal mask to MultiHeadAttention #16619

Merged · 6 commits · Jul 11, 2022
Changes from 3 commits
102 changes: 100 additions & 2 deletions keras/layers/attention/multi_head_attention.py
@@ -215,6 +215,9 @@ class MultiHeadAttention(Layer):
training mode (adding dropout) or in inference mode (no dropout).
Defaults to either using the training mode of the parent layer/model,
or False (inference) if there is no parent layer.
use_causal_mask: A boolean to indicate whether to apply a causal mask to
prevent tokens from attending to future tokens (e.g., used in a decoder
Transformer).

Returns:
attention_output: The result of the computation, of shape `(B, T, E)`,
@@ -244,6 +247,7 @@ def __init__(
**kwargs
):
super().__init__(**kwargs)
self.supports_masking = True
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
@@ -449,7 +453,7 @@ def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.

This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product
customize attention computation to replace the default dot-product
attention.

Args:
@@ -502,7 +506,8 @@ def _compute_attention(
key: Projected key `Tensor` of shape `(B, T, N, key_dim)`.
value: Projected value `Tensor` of shape `(B, T, N, value_dim)`.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
attention to certain positions. It is generally not needed if the
`query` and `value` (and/or `key`) are masked.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).

@@ -543,7 +548,16 @@ def call(
attention_mask=None,
return_attention_scores=False,
training=None,
use_causal_mask=False,
):
attention_mask = self._compute_attention_mask(
query,
value,
key=key,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
@@ -592,3 +606,87 @@ def call(
if return_attention_scores:
return attention_output, attention_scores
return attention_output

def _compute_attention_mask(
self, query, value, key=None, attention_mask=None, use_causal_mask=False
):
"""Computes the attention mask, using the Keras masks of the inputs.

* The `query`'s mask is reshaped from [B, T] to [B, T, 1].
* The `value`'s mask is reshaped from [B, S] to [B, 1, S].
* The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
mask is ignored if `key` is `None` or if `key is value`.
* If `use_causal_mask=True`, then the causal mask is computed. Its shape
is [1, T, S].

All defined masks are merged using a logical AND operation (`&`).

In general, if the `query` and `value` are masked, then there is no need
to define the `attention_mask`.

Args:
query: Query `Tensor` of shape `(B, T, dim)`.
key: Key `Tensor` of shape `(B, S, dim)` (optional, defaults to
`value`).
value: Value `Tensor` of shape `(B, S, dim)`.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
use_causal_mask: A boolean to indicate whether to apply a causal mask
to prevent tokens from attending to future tokens (e.g., used in a
decoder Transformer).
Returns:
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions, based on the Keras masks of the
`query`, `key`, `value`, and `attention_mask` tensors, and the
causal mask if `use_causal_mask=True`.
"""
query_mask = getattr(query, "_keras_mask", None)
value_mask = getattr(value, "_keras_mask", None)
key_mask = getattr(key, "_keras_mask", None)
auto_mask = None
if query_mask is not None:
# B = batch size, T = max query length
auto_mask = query_mask[:, :, tf.newaxis] # shape is [B, T, 1]
if value_mask is not None:
# B = batch size, S == max value length
mask = value_mask[:, tf.newaxis, :] # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if key_mask is not None:
# B == batch size, S == max key length == max value length
mask = key_mask[:, tf.newaxis, :] # shape is [B, 1, S]
auto_mask = mask if auto_mask is None else auto_mask & mask
if use_causal_mask:
# the shape of the causal mask is [1, T, S]
mask = self._compute_causal_mask(query, value)
auto_mask = mask if auto_mask is None else auto_mask & mask
if auto_mask is not None:
# merge attention_mask & automatic mask, to shape [B, T, S]
attention_mask = (
auto_mask
if attention_mask is None
else attention_mask & auto_mask
)
return attention_mask

def _compute_causal_mask(self, query, value=None):
"""Computes a causal mask (e.g., for masked self-attention layers).

For example, if query and value both contain sequences of length 4,
this function returns a boolean `Tensor` equal to:
[[[True, False, False, False],
[True, True, False, False],
[True, True, True, False],
[True, True, True, True]]]

Args:
query: query `Tensor` of shape `(B, T, ...)`.
value: value `Tensor` of shape `(B, S, ...)` (optional, defaults to
query).
Returns:
mask: a boolean `Tensor` of shape [1, T, S] containing a lower
triangular matrix of shape [T, S].
"""
q_seq_length = tf.shape(query)[1]
v_seq_length = q_seq_length if value is None else tf.shape(value)[1]
return tf.linalg.band_part( # creates a lower triangular matrix
tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0
)
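Note (not part of the diff): a minimal usage sketch of the behavior this change enables, mirroring the tests below. The vocabulary sizes, embedding dimensions, and token IDs are illustrative assumptions.

```python
import numpy as np
from tensorflow import keras

# Token IDs where 0 is padding; mask_zero=True makes Embedding attach a Keras mask.
query_ids = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2]])
value_ids = np.array([[5, 4, 0], [3, 0, 0]])

masked_query = keras.layers.Embedding(10, 8, mask_zero=True)(query_ids)  # (2, 5, 8)
masked_value = keras.layers.Embedding(10, 8, mask_zero=True)(value_ids)  # (2, 3, 8)

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)

# The padding masks carried by query and value are now picked up automatically,
# and use_causal_mask=True additionally blocks attention to future positions.
output = layer(query=masked_query, value=masked_value, use_causal_mask=True)

# The query's mask is propagated to the output (supports_masking=True).
print(output._keras_mask.shape)  # (2, 5)
```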
43 changes: 43 additions & 0 deletions keras/layers/attention/multi_head_attention_test.py
@@ -328,6 +328,49 @@ def test_ragged_tensor(self, ragged_query, ragged_value, ragged_key):
results = test_layer(query, value, key)
self.assertAllEqual(results.shape.as_list(), query.shape.as_list())

def test_query_mask_propagation(self):
"""Test automatic propagation of the query's mask."""
test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
self.assertTrue(test_layer.supports_masking)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query)
value = np.random.random((3, 3, 8))
output = test_layer(query=masked_query, value=value)
self.assertTrue(hasattr(output, "_keras_mask"))
self.assertAllEqual(masked_query._keras_mask, output._keras_mask)

@parameterized.named_parameters(("causal", True), ("not_causal", False))
def test_value_mask(self, use_causal_mask):
"""Test that the value and causal masks are taken into account."""
test_layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
masked_query = keras.layers.Embedding(4, 8, mask_zero=True)(query)
value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
masked_value = keras.layers.Embedding(6, 8, mask_zero=True)(value)
output = test_layer(
query=masked_query,
value=masked_value,
use_causal_mask=use_causal_mask,
)
mask = np.array(
[[[True, True, False]] * 3 + [[False, False, False]] * 2]
+ [[[True, False, False]] * 5]
+ [[[True, True, True]] + [[False, False, False]] * 4]
)
if use_causal_mask:
mask = mask & np.array(
[
[[True, False, False], [True, True, False]]
+ [[True, True, True]] * 3
]
)
del masked_query._keras_mask
del masked_value._keras_mask
output_with_manual_mask = test_layer(
query=masked_query, value=masked_value, attention_mask=mask
)
self.assertAllClose(output, output_with_manual_mask)


class SubclassAttention(keras.layers.MultiHeadAttention):
def _build_attention(self, qkv_rank):
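As a closing note on the shape arithmetic in `_compute_attention_mask` above: each per-input mask is expanded so that a logical AND broadcasts everything to a common `(B, T, S)` shape. A minimal standalone sketch with illustrative values (B=1, T=S=3):

```python
import tensorflow as tf

# Hypothetical padding masks: True = real token, False = padding.
query_mask = tf.constant([[True, True, False]])   # shape (B, T) = (1, 3)
value_mask = tf.constant([[True, False, False]])  # shape (B, S) = (1, 3)

q = query_mask[:, :, tf.newaxis]  # (B, T, 1)
v = value_mask[:, tf.newaxis, :]  # (B, 1, S)

# Lower-triangular causal mask of shape (1, T, S), built the same way as
# _compute_causal_mask does in the diff.
causal = tf.linalg.band_part(tf.ones((1, 3, 3), tf.bool), -1, 0)

# Broadcasting the logical AND yields the combined (B, T, S) attention mask.
combined = q & v & causal
print(combined)
```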