diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index badb292aa23c..e9c9fbbc3745 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -114,25 +114,35 @@ def paged_flash_attention_kernel(
   b = b * b_step + b_start
   length = lengths_ref[b]
 
-  def advance_to_next_non_zero_length(b):
-    return lax.fori_loop(
-        lax.div(b, b_step),
-        lax.div(batch_size, b_step),
-        lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
-        b,
-    )
-
   def compute_block_indices(b, h, i):
-    length = lengths_ref[b]
-    not_done = i * bk < length
-    i_next = jnp.where(not_done, i, 0)
-    h_next = jnp.where(not_done, h, h + h_step)
-    is_last_head = h_next >= num_kv_heads
-    h_next = jnp.where(is_last_head, h_start, h_next)
-    b_next = jnp.where(
-        is_last_head, advance_to_next_non_zero_length(b + b_step), b
-    )
-    return b_next, h_next, i_next
+
+    def advance_b():
+      next_b = b + b_step
+
+      def advance_to_next_non_zero_length():
+        next_next_b = next_b + b_step
+        return lax.fori_loop(
+            lax.div(next_next_b, b_step),
+            lax.div(batch_size, b_step),
+            lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
+            next_next_b,
+        )
+
+      return (
+          lax.cond(
+              jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0),
+              advance_to_next_non_zero_length,
+              lambda: next_b,
+          ),
+          h_start,
+          0,
+      )
+
+    def advance_h():
+      next_h = h + h_step
+      return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b)
+
+    return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h)
 
   def create_kv_async_copy_descriptors(b, h, i, buffer_index):
     page_offset = b * pages_per_sequence + i * pages_per_compute_block