Cleanup MaskFnAttentionBias.target_positions. #895

Merged · 1 commit · Dec 17, 2024
33 changes: 25 additions & 8 deletions axlearn/common/attention_bias.py
@@ -473,30 +473,47 @@ class MaskFnAttentionBias(BoolAttentionBias):
shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
# The positions in the query sequence that the mask should be computed for.
# I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
# `target_positions[batch, num_heads i]` may attend to.
# If None, set `target_positions[batch, num_heads, i] = i`.
# Shape: [batch].
# `target_positions[batch, i]` may attend to.
# If None, set `target_positions[batch, i] = i`.
# Shape: [batch] or [batch, target_len].
# This is typically used during decoding to specify the locations in the sequence being
# decoded. E.g., if we are decoding positions 5 and 7 of the first and second batch
# entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
# The motivation for supporting such shapes is for use cases where time_step in transformers
# is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
# various papers that need it.
target_positions: Optional[Tensor] = None

def _bool_value(self) -> Optional[Tensor]:
"""Return a tensor with the boolean values from `self.mask` before they have been converted
to biases.

Shape:
- If `target_positions` is None: [target_len, source_len]
- Else: [batch, target_len, source_len].
Shape: [batch, target_len, source_len].

Raises:
NotImplementedError. If `target_positions.ndim not in [1,2]`.
"""
target_positions, source_positions = jnp.indices(self.shape, sparse=True)
# Shape: [batch, target_len, source_len].
target_positions, source_positions = target_positions[None], source_positions[None]
if self.target_positions is not None:
target_positions = self.target_positions
if target_positions.ndim not in [1, 2]:
raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
if target_positions.ndim == 1:
# Shape: [batch, target_len].
# pylint: disable-next=unsubscriptable-object
target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
while target_positions.ndim < 3:
target_positions = target_positions[..., None]
elif target_positions.ndim == 2:
shape_with_batch_dim = (1, *self.shape)
# Raise an exception if shapes aren't compatible. We don't use the output.
jnp.broadcast_shapes(
(target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
)
else:
raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
target_positions = target_positions[..., None] # Shape: [batch, target_len, 1].

return self.mask(target_positions, source_positions) # pylint: disable=not-callable

@classmethod
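For context on what the updated `_bool_value` computes, below is a minimal standalone sketch (not part of the diff) of the 1-D and 2-D `target_positions` handling. `causal` and `bool_mask` are illustrative stand-ins for the configured `mask` fn and the method body, assuming only JAX is available; the sketch omits the `jnp.broadcast_shapes` compatibility check that the diff performs for the 2-D case.

```python
import jax.numpy as jnp

def causal(query_position, key_position):
    # Same convention as `attention_bias.causal_mask`: True where attention is allowed.
    return query_position >= key_position

def bool_mask(shape, target_positions=None):
    target_len, _ = shape
    target, source = jnp.indices(shape, sparse=True)
    # Add a leading batch dim: target is [1, target_len, 1], source is [1, 1, source_len].
    target, source = target[None], source[None]
    if target_positions is not None:
        if target_positions.ndim == 1:
            # [batch] decode steps -> [batch, target_len] contiguous positions.
            target_positions = target_positions[:, None] + jnp.arange(target_len)
        elif target_positions.ndim != 2:
            raise NotImplementedError(f"{target_positions.ndim=}")
        target = target_positions[..., None]  # Shape: [batch, target_len, 1].
    return causal(target, source)  # Broadcasts to [batch, target_len, source_len].

# 1-D case: decoding positions 5 and 7 of two batch entries (target_len == 1).
print(bool_mask((1, 8), jnp.asarray([5, 7])).shape)                   # (2, 1, 8)
# 2-D case: explicit, possibly non-contiguous, positions per entry.
print(bool_mask((3, 8), jnp.asarray([[0, 2, 4], [7, 6, 5]])).shape)   # (2, 3, 8)
```

The design keeps the 1-D form as a convenience for contiguous decoding, while the 2-D form passes the caller's positions through unchanged, which covers speculative decoding and non-contiguous prompts.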
20 changes: 20 additions & 0 deletions axlearn/common/attention_bias_test.py
@@ -267,6 +267,26 @@ def test_mask_fn_attention_bias(self):
expected = attention_bias.bool_to_bias(expected)[:, None, :]
self.assertNestedEqual(bias.value(), expected)

def test_mask_fn_attention_bias_target_positions_ndim(self):
"""Tests mask_fn_attention_bias` when `target_positions.ndim == 2."""
bias = attention_bias.MaskFnAttentionBias(
mask=attention_bias.causal_mask,
shape=(5, 5),
target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
)
expected = jnp.asarray(
[
[
attention_bias.causal_mask(*jnp.indices([5, 5])),
],
[
attention_bias.causal_mask(*jnp.indices([5, 5]))[::-1, :],
],
],
dtype=bool,
)
self.assertNestedEqual(bias.bool_value(), expected)

def test_bool_tensor_attention_bias(self):
bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
self.assertNestedEqual(
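As a usage note, the 2-D path added in `attention_bias_test.py` can also be exercised directly. A minimal sketch, assuming `axlearn` and JAX are importable, mirroring the new test:

```python
import jax.numpy as jnp
from axlearn.common import attention_bias

# One explicit position per (batch, target) index; the second entry reverses the order.
bias = attention_bias.MaskFnAttentionBias(
    mask=attention_bias.causal_mask,
    shape=(5, 5),
    target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
)
# Matches the `expected` tensor in the new test: [batch, num_heads, target_len, source_len].
print(bias.bool_value().shape)  # (2, 1, 5, 5)
```

The 1-D form remains the common decoding path (e.g. `target_positions=jnp.asarray([5, 7])`, as in the code comment), with each entry expanded to a contiguous run of `target_len` positions.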