diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index b49e16d95408..a5b5aaf31799 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -786,10 +786,14 @@ def _get_causal_mask(T, S): return mask[None, None, :, :] def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): - q_indices = jnp.arange(0, T)[None, :, None] - kv_indices = jnp.arange(0, S)[None, None, :] - q_mask = q_indices < q_seqlen[:, None, None] - kv_mask = kv_indices < kv_seqlen[:, None, None] + q_mask = True + kv_mask = True + if q_seqlen is not None: + q_indices = jnp.arange(0, T)[None, :, None] + q_mask = q_indices < q_seqlen[:, None, None] + if kv_seqlen is not None: + kv_indices = jnp.arange(0, S)[None, None, :] + kv_mask = kv_indices < kv_seqlen[:, None, None] mask = jnp.logical_and(q_mask, kv_mask) return mask[:, None, :, :] @@ -813,7 +817,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): mask = _get_causal_mask(T, S) combined_mask = jnp.logical_and(combined_mask, mask) - if q_seqlen is not None and kv_seqlen is not None: + if q_seqlen is not None or kv_seqlen is not None: mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) combined_mask = jnp.logical_and(combined_mask, mask) @@ -1001,12 +1005,22 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], kv_seqlen=key_value_seq_lengths, ) case 'cudnn': + use_padding = ( + query_seq_lengths is not None or key_value_seq_lengths is not None + ) + if use_padding: + if query_seq_lengths is None: + T = query_arr.shape[1] + query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32) + if key_value_seq_lengths is None: + key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32) + mask_type = MaskType.NO_MASK - if query_seq_lengths is not None and is_causal: + if use_padding and is_causal: mask_type = MaskType.PADDING_CAUSAL elif is_causal: mask_type = MaskType.CAUSAL - elif query_seq_lengths is not None: + elif use_padding: mask_type = MaskType.PADDING out = cudnn_dot_product_attention( query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,