diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index f6a292136..478f00d1c 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -473,30 +473,47 @@ class MaskFnAttentionBias(BoolAttentionBias):
     shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
     # The positions in the query sequence that the mask should be computed for.
     # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
-    # `target_positions[batch, num_heads i]` may attend to.
-    # If None, set `target_positions[batch, num_heads, i] = i`.
-    # Shape: [batch].
+    # `target_positions[batch, i]` may attend to.
+    # If None, set `target_positions[batch, i] = i`.
+    # Shape: [batch] or [batch, target_len].
     # This is typically used during decoding to specify the locations in the sequence being
     # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
     # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
+    # The motivation for supporting a [batch, target_len] shape is use cases where the time_step
+    # in decoding is not necessarily contiguous, e.g., speculative decoding and non-contiguous
+    # prompts.
     target_positions: Optional[Tensor] = None
 
     def _bool_value(self) -> Optional[Tensor]:
         """Return a tensor with the boolean values from `self.mask` before they have been
         converted to biases.
 
-        Shape:
-        - If `target_positions` is None: [target_len, source_len]
-        - Else: [batch, target_len, source_len].
+        Shape: [batch, target_len, source_len].
+
+        Raises:
+            NotImplementedError: If `target_positions.ndim` is not in [1, 2].
         """
         target_positions, source_positions = jnp.indices(self.shape, sparse=True)
+        # Shape: [batch, target_len, source_len].
+        target_positions, source_positions = target_positions[None], source_positions[None]
         if self.target_positions is not None:
             target_positions = self.target_positions
+            if target_positions.ndim not in [1, 2]:
+                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
             if target_positions.ndim == 1:
+                # Shape: [batch, target_len].
                 # pylint: disable-next=unsubscriptable-object
                 target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
-        while target_positions.ndim < 3:
-            target_positions = target_positions[..., None]
+            elif target_positions.ndim == 2:
+                shape_with_batch_dim = (1, *self.shape)
+                # Raise an exception if shapes aren't compatible. We don't use the output.
+                jnp.broadcast_shapes(
+                    (target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
+                )
+            else:
+                raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
+            target_positions = target_positions[..., None]  # Shape: [batch, target_len, 1].
+
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 0932df100..531d6d229 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -267,6 +267,26 @@ def test_mask_fn_attention_bias(self):
         expected = attention_bias.bool_to_bias(expected)[:, None, :]
         self.assertNestedEqual(bias.value(), expected)
 
+    def test_mask_fn_attention_bias_target_positions_ndim(self):
+        """Tests `MaskFnAttentionBias` when `target_positions.ndim == 2`."""
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=attention_bias.causal_mask,
+            shape=(5, 5),
+            target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
+        )
+        expected = jnp.asarray(
+            [
+                [
+                    attention_bias.causal_mask(*jnp.indices([5, 5])),
+                ],
+                [
+                    attention_bias.causal_mask(*jnp.indices([5, 5]))[::-1, :],
+                ],
+            ],
+            dtype=bool,
+        )
+        self.assertNestedEqual(bias.bool_value(), expected)
+
     def test_bool_tensor_attention_bias(self):
         bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
         self.assertNestedEqual(
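
A minimal usage sketch (not part of the patch) of the 2-D `target_positions` path added above. It uses only names that appear in the diff (`MaskFnAttentionBias`, `causal_mask`, `bool_value`); the concrete sizes (batch 2, target_len 3, source_len 8) and the chosen positions are illustrative assumptions, and the commented output shape follows the [batch, num_heads, target_len, source_len] layout exercised by the new test.

import jax.numpy as jnp

from axlearn.common import attention_bias

# Batch of 2, each entry decoding 3 (possibly non-contiguous) target positions
# against a source sequence of length 8.
bias = attention_bias.MaskFnAttentionBias(
    mask=attention_bias.causal_mask,
    shape=(3, 8),  # [target_len, source_len]; hypothetical sizes for illustration.
    target_positions=jnp.asarray([[0, 2, 5], [1, 4, 7]]),  # Shape: [batch, target_len].
)

# `_bool_value` checks that [batch, 1, target_len] broadcasts against [1, target_len, source_len]
# via jnp.broadcast_shapes, then evaluates the mask per batch entry.
print(bias.bool_value().shape)  # Expected (2, 1, 3, 8), matching the layout in the new test.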