forked from ayaka14732/llama-2-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrotary_embedding.py
59 lines (48 loc) · 2.48 KB
/
rotary_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import NamedTuple
import einops as op
import jax
from jax import Array
import jax.numpy as jnp
from .ModelConfig import ModelConfig
# TODO: Mostly taken from https://github.com/kingoflolz/mesh-transformer-jax/blob/master/mesh_transformer/layers.py
# and https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L92
def _make_weights(seq_len: int, d_k: int, base: float = 10000.) -> tuple[Array, Array]:
    """Build the sine and cosine tables used for rotary position embeddings.

    Args:
        seq_len: Number of positions; row ``p`` of each table holds the values
            for position ``p``.
        d_k: Width of the vectors the tables will be applied to. One
            frequency is generated per index in ``range(0, d_k, 2)``.
        base: Base of the geometric frequency progression. Defaults to
            10000, the value that used to be hard-coded here.

    Returns:
        A ``(sin_val, cos_val)`` pair. For even ``d_k`` each table has shape
        ``(seq_len, d_k)``; the second half of the last axis is a copy of the
        first half, matching the half-split layout used by ``_rotate_half``.
        NOTE: an odd ``d_k`` yields ``d_k + 1`` columns, since the half-width
        table is simply duplicated.
    """
    inv_freq = 1. / (base ** (jnp.arange(0, d_k, 2) / d_k))
    positions = jnp.arange(seq_len)

    # Outer product via broadcasting: one angle per (position, frequency) pair.
    angles = positions[:, None] * inv_freq[None, :]

    # Duplicate the half-width tables along the last axis: [t, t].
    sin_val = jnp.tile(jnp.sin(angles), (1, 2))
    cos_val = jnp.tile(jnp.cos(angles), (1, 2))
    return sin_val, cos_val
def _rotate_half(x: Array) -> Array:
    """Swap the two halves of the last axis and negate the new first half.

    Splitting the last axis of ``x`` at its midpoint into ``[x1, x2]``, the
    result is ``[-x2, x1]``.

    Args:
        x: Array (or array-like) whose last axis has even length; any number
            of leading axes is allowed.

    Returns:
        An array with the same shape and dtype as ``x``.

    Raises:
        ValueError: If the last axis cannot be split into two equal halves.
    """
    first_half, second_half = jnp.split(x, 2, axis=-1)
    # A plain negate + concatenate avoids the reshape/reverse/scatter-multiply
    # round trip and also works for inputs that lack the ``.at`` accessor.
    return jnp.concatenate((-second_half, first_half), axis=-1)
class RotaryValues(NamedTuple):
    """Pair of sine and cosine tables consumed by ``forward_rotary_embedding``.

    ``make_rotary_values`` builds both fields with a leading batch axis
    (batch, position, feature). ``get_rotary_values_at_position`` selects a
    position from that table and, for a scalar position, leaves a position
    axis of length 1.
    """
    # Sine table; forward_rotary_embedding asserts that it is float32.
    sin_val: Array
    # Cosine table; forward_rotary_embedding asserts that it is float32.
    cos_val: Array
def forward_rotary_embedding(m: Array, *, rotary_values: RotaryValues) -> Array:
    """Apply rotary position embeddings to ``m``.

    Computes ``m * cos + rotate_half(m) * sin``. The tables are matched to
    ``m`` on the batch axis and the last two axes; any axes of ``m`` in
    between are carried through unchanged.

    Args:
        m: Array laid out as ``(B, ..., L, K)``.
        rotary_values: Sine and cosine tables laid out as ``(B, L, K)``. Both
            must be float32.

    Returns:
        An array with the same dtype as ``m``, laid out as ``(B, ..., L, K)``.
    """
    sines, cosines = rotary_values
    assert sines.dtype == jnp.float32
    assert cosines.dtype == jnp.float32

    pattern = 'B ... L K, B L K -> B ... L K'
    out_dtype = m.dtype
    rotated = _rotate_half(m)

    # Each product is cast back to the input dtype before the final sum.
    cos_term = op.einsum(m, cosines, pattern).astype(out_dtype)
    sin_term = op.einsum(rotated, sines, pattern).astype(out_dtype)
    return cos_term + sin_term
def make_rotary_values(leftpad_len: Array | None, batch_size: int, seq_len: int, *, model_config: ModelConfig) -> RotaryValues:
    """Create batched sine and cosine tables for a sequence of ``seq_len`` positions.

    Args:
        leftpad_len: Optional per-row shift with one entry per batch row
            (presumably the number of left-padding tokens in that row;
            confirm against the caller). When given, row ``b`` of each table
            is rolled by ``leftpad_len[b]`` steps along the position axis, so
            the values for position 0 land at index ``leftpad_len[b]``.
        batch_size: Number of copies of the tables to stack on a new leading axis.
        seq_len: Number of positions in the tables.
        model_config: Supplies ``d_k``, the feature width of the tables.

    Returns:
        ``RotaryValues`` whose two fields carry a leading batch axis of length
        ``batch_size``.
    """
    tables = _make_weights(seq_len, model_config.d_k)

    # Give every batch row its own copy of the (L, K) tables.
    batched = [jnp.repeat(table[None], batch_size, axis=0) for table in tables]

    if leftpad_len is not None:
        def shift_positions(table: Array, shift: Array) -> Array:
            return jnp.roll(table, shift, axis=-2)  # -2: dimension L

        # vmap pairs each batch row with its own shift amount.
        per_row_roll = jax.vmap(shift_positions)
        batched = [per_row_roll(table, leftpad_len) for table in batched]

    sin_val, cos_val = batched
    return RotaryValues(sin_val, cos_val)
def get_rotary_values_at_position(rotary_values: RotaryValues, position: Array) -> RotaryValues:
    """Select the rotary tables for ``position``, keeping a length-1 axis in its place.

    Args:
        rotary_values: Sine and cosine tables indexed on axis 1 by position.
        position: Index applied to axis 1 of both tables.

    Returns:
        ``RotaryValues`` whose fields are the selected slices with a new
        length-1 axis inserted at axis 1.
    """
    def pick(table: Array) -> Array:
        # Index axis 1, then put a singleton axis back where it was.
        return jnp.expand_dims(table[:, position], axis=1)

    sin_at, cos_at = (pick(table) for table in rotary_values)
    return RotaryValues(sin_at, cos_at)