Fix RMSNorm for very large inputs (fairinternal/xformers#1051)
__original_commit__ = fairinternal/xformers@603dbe8
danthe3rd authored and xFormers Bot committed Mar 11, 2024
1 parent 1c1c328 commit 8c7d37f
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - fMHA: Updated Flash-Attention to v2.5.6: this has a performance improvement for multiquery.
 - fMHA: triton_splitk changed and expanded. Now amalgamates using LSE. Can autotune, supports causal with a small number of queries - not just 1. Experimental support for paged attention.
 - `rope_padded`: Fixed CUDA error with many queries (more than 65k)
+- `rmsnorm`: Fixed CUDA error with large inputs (enables 512k+ sequence length on Llama2 70B)
 ### Removed
 - fMHA: Removed triton operator (`fmha.triton.*`, `xformers.ops.MemoryEfficientAttentionTritonFwdFlashBwOp`, `xformers.ops.TritonFlashAttentionOp`), as it has correctness issues under some conditions, and is slower than other implementations.

2 changes: 1 addition & 1 deletion xformers/ops/_triton/rmsnorm_kernels.py
@@ -23,7 +23,7 @@ def _rms_norm_kernel(
     BLOCK_SIZE: tl.constexpr,
     INCLUDE_WEIGHT: tl.constexpr,
 ):
-    row = tl.program_id(0)
+    row = tl.program_id(0).to(tl.int64)
     x_ptr += row * stride
     h1_ptr += row * stride

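For context on the one-line fix: in Triton, `tl.program_id(0)` is a 32-bit integer, so the pointer offset `row * stride` is computed in 32-bit arithmetic unless the row index is widened first. For very large inputs the product exceeds 2**31 - 1 and wraps around, which surfaces as the CUDA error the commit title describes; casting the row index to `tl.int64` before the multiply avoids the overflow. Below is a minimal standalone sketch of the same pattern, assuming a simple row-wise kernel rather than the actual xformers `_rms_norm_kernel` (the names `_row_scale_kernel` and `row_scale` are illustrative):

import torch
import triton
import triton.language as tl


@triton.jit
def _row_scale_kernel(x_ptr, out_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # Hypothetical example kernel, not the xformers one.
    # Widen the row index to 64 bits *before* multiplying by the row
    # stride; otherwise the offset arithmetic is done in int32 and wraps
    # around for very large inputs, causing an illegal memory access.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0)
    tl.store(out_ptr + row * stride + cols, x * 2.0, mask=mask)


def row_scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _row_scale_kernel[(n_rows,)](x, out, x.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out

As a rough illustration of the scale involved: with a row stride of 8192 elements (a Llama2-70B-sized hidden dimension), `row * stride` exceeds the int32 maximum once `row` passes 262,143, which is consistent with the CHANGELOG note that the fix enables 512k+ sequence lengths.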
