Skip to content

Commit

Permalink
Dynamic rope scaling (fairinternal/xformers#1171)
Browse files Browse the repository at this point in the history
__original_commit__ = fairinternal/xformers@d8f1fb3
  • Loading branch information
bottler authored and xFormers Bot committed Aug 14, 2024
1 parent 596cfcf commit 5e2b537
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 6 deletions.
60 changes: 56 additions & 4 deletions tests/test_rope_padded.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from functools import partial
from typing import Optional

import pytest
Expand All @@ -21,13 +23,45 @@
)


def apply_scaling(
freqs: torch.Tensor,
old_context_len: float,
low_freq_factor: float,
high_freq_factor: float,
dynamic_scale_factor: float,
):
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen >= high_freq_wavelen

for idx, freq in enumerate(freqs):
wavelen = 2 * math.pi / freq
if wavelen > low_freq_wavelen:
freqs[idx] = freq / dynamic_scale_factor

if high_freq_wavelen <= wavelen and wavelen <= low_freq_wavelen:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
freqs[idx] = (1 - smooth) * freqs[
idx
] / dynamic_scale_factor + smooth * freqs[idx]
return freqs


def _slow_rope(
x: torch.Tensor,
*,
seqpos: Optional[torch.Tensor] = None,
theta=10000,
linear_scale=1,
adjacents: bool = True,
use_dynamic_scaling: bool = False,
dynamic_old_context_len: float = 8192.0,
dynamic_scale_factor: float = 16.0,
dynamic_low_freq_factor: float = 1.0,
dynamic_high_freq_factor: float = 32.0,
):
"""
Simple rope calculation of rope of one tensor
Expand All @@ -45,7 +79,15 @@ def _slow_rope(
if seqpos is None:
seqpos = torch.arange(M, device=x.device)
power = torch.arange(0, dim, 2, device=x.device)[: (dim // 2)].float() / dim
freqs = 1.0 / (theta**power)
freqs: torch.Tensor = 1.0 / (theta**power) # type: ignore
if use_dynamic_scaling:
freqs = apply_scaling(
freqs,
dynamic_old_context_len,
dynamic_low_freq_factor,
dynamic_high_freq_factor,
dynamic_scale_factor,
)
all_freqs = torch.outer(seqpos / linear_scale, freqs)
freqs_cis = torch.polar(torch.ones_like(all_freqs), all_freqs) # complex64
for _ in range(x.ndim - seq_dim - 2):
Expand Down Expand Up @@ -118,7 +160,9 @@ def _slow_rope2(
@pytest.mark.parametrize("dim", [100, 4098])
@pytest.mark.parametrize("padding", [87, 18300])
@pytest.mark.parametrize("groups", [1, 3])
@pytest.mark.parametrize("linear_scale", [1.0, 4.0])
@pytest.mark.parametrize(
"linear_scale, use_dynamic_scaling", [(1.0, False), (4.0, False), (1.0, True)]
)
def test_consistency(
adjacents: bool,
dim: int,
Expand All @@ -127,6 +171,7 @@ def test_consistency(
internal_dtype: str,
dtype_str: str,
linear_scale: float,
use_dynamic_scaling: bool,
):
torch.manual_seed(1)
heads, kvheads = 10, 2
Expand Down Expand Up @@ -181,6 +226,7 @@ def test_consistency(
linear_scale=linear_scale,
adjacents=adjacents,
internal_dtype=internal_dtype,
use_dynamic_scaling=use_dynamic_scaling,
)

seqpos = torch.tensor(
Expand All @@ -189,7 +235,9 @@ def test_consistency(
)
cache_locs = [seqpos[0], seqpos[1], padding + seqpos[2], 2 * padding + seqpos[3]]
baseline = _slow_rope if dtype_str == "f32" else _slow_rope2
expected_out = baseline(
if use_dynamic_scaling:
baseline = partial(_slow_rope, use_dynamic_scaling=True) # type: ignore
expected_out = baseline( # type: ignore
xq, linear_scale=linear_scale, seqpos=seqpos, adjacents=adjacents
)
atol, rtol = ROPE_ATOL_RTOL[dtype_str]
Expand All @@ -200,7 +248,11 @@ def test_consistency(
assert torch.allclose(cache_v, cache_v_orig)

slow_roped_xk = _slow_rope(
xk, linear_scale=linear_scale, seqpos=seqpos, adjacents=adjacents
xk,
linear_scale=linear_scale,
seqpos=seqpos,
adjacents=adjacents,
use_dynamic_scaling=use_dynamic_scaling,
)
assert_allclose(
cache_k[:, cache_locs],
Expand Down
29 changes: 27 additions & 2 deletions xformers/ops/_triton/rope_padded_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def _rope_padded_kernel(
seqlenk,
theta,
linear_scale,
use_dynamic_scaling: tl.constexpr,
dynamic_old_context_len: tl.constexpr,
dynamic_scale_factor: tl.constexpr,
dynamic_low_freq_factor: tl.constexpr,
dynamic_high_freq_factor: tl.constexpr,
first_seqpos,
seqpos,
k_start: tl.constexpr,
Expand Down Expand Up @@ -182,8 +187,28 @@ def _rope_padded_kernel(
re_x = tl.load(x_in + cols_re, mask=mask)
im_x = tl.load(x_in + cols_im, mask=mask)
# freqs = seq_pos / (theta ** (powers / dim))
freqs = seq_pos * pow(theta, powers / (-dim))
freqs = freqs / linear_scale
freqs = pow(theta, powers / (-dim))

if use_dynamic_scaling:
lo_freq_wavelen = dynamic_old_context_len / dynamic_low_freq_factor
hi_freq_wavelen = dynamic_old_context_len / dynamic_high_freq_factor

wavelens = 6.28318530718 / freqs # 2*pi
is_low_freq = wavelens > lo_freq_wavelen
freqs = tl.where(is_low_freq, freqs / dynamic_scale_factor, freqs)

is_mid_freq = hi_freq_wavelen <= wavelens and wavelens <= lo_freq_wavelen

smooth = (dynamic_old_context_len / wavelens - dynamic_low_freq_factor) / (
dynamic_high_freq_factor - dynamic_low_freq_factor
)
freqs = tl.where(
is_mid_freq,
(1 - smooth) * freqs / dynamic_scale_factor + smooth * freqs,
freqs,
)

freqs = seq_pos * freqs / linear_scale
sines = tl.sin(freqs)
cosines = tl.cos(freqs)
re_out = re_x * cosines - im_x * sines
Expand Down
16 changes: 16 additions & 0 deletions xformers/ops/rope_padded.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def rope_padded(
*,
theta: float = 10000.0,
linear_scale: float = 1.0,
use_dynamic_scaling: bool = False,
dynamic_old_context_len: float = 8192.0,
dynamic_scale_factor: float = 16.0,
dynamic_low_freq_factor: float = 1.0,
dynamic_high_freq_factor: float = 32.0,
out_q: Optional[torch.Tensor] = None,
first_seqpos: Optional[torch.Tensor] = None,
seqpos: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -80,6 +85,12 @@ def rope_padded(
linear_scale: A scaling factor to apply to the sequence ids when computing
the RoPE frequencies. When set to K, all sequence indices
are divided by K.
use_dynamic_scaling: If true, dynamic scaling in use, using a scaling like
“YaRN: Efficient Context Window Extension of Large Language Models”
dynamic_old_context_len: used with use_dynamic_scaling
dynamic_scale_factor: used with use_dynamic_scaling
dynamic_low_freq_factor: used with use_dynamic_scaling
dynamic_high_freq_factor: used with use_dynamic_scaling
internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation
"""
if torch.is_grad_enabled() and (
Expand Down Expand Up @@ -245,6 +256,11 @@ def rope_padded(
seqlenk,
theta,
linear_scale,
use_dynamic_scaling,
dynamic_old_context_len if use_dynamic_scaling else 0,
dynamic_scale_factor if use_dynamic_scaling else 0,
dynamic_low_freq_factor if use_dynamic_scaling else 0,
dynamic_high_freq_factor if use_dynamic_scaling else 0,
first_seqpos,
seqpos,
k_start,
Expand Down

0 comments on commit 5e2b537

Please sign in to comment.