Skip to content

Commit

Permalink
Fix that the insufficient output HBM buffer init would cause the <unk…
Browse files Browse the repository at this point in the history
…> token generated for quantized int8 model.

PiperOrigin-RevId: 631235764
  • Loading branch information
jax authors committed May 7, 2024
1 parent eee2783 commit 4de3464
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ def paged_flash_attention_kernel_inline_seq_dim(
):
core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)

# Initialize the output HBM buffers to avoid accessing garbage memory inside

This comment has been minimized.

Copy link
@jon-chuang

jon-chuang May 12, 2024

Contributor

Don't you mean output VMEM buffers?

# the kernel body below.
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)

def body(i, _):
paged_flash_attention_kernel(
lengths_ref,
Expand Down

0 comments on commit 4de3464

Please sign in to comment.