Add xPos embeddings
janEbert committed Mar 9, 2023
1 parent e52bdab commit 366ae3e
Showing 5 changed files with 136 additions and 7 deletions.
6 changes: 5 additions & 1 deletion finetune_t0_non_causal_decoder.py
@@ -94,7 +94,11 @@ def get_batch_pipe(data):
segment_ids=segment_ids.long(),
)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
if args.position_embedding_type not in [
PositionEmbeddingType.alibi,
PositionEmbeddingType.rotary,
PositionEmbeddingType.xpos,
]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

return (tokens, position_ids, attention_mask), (labels, loss_mask)
2 changes: 1 addition & 1 deletion megatron/arguments.py
@@ -398,7 +398,7 @@ def _add_network_size_args(parser):
group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
choices=list(PositionEmbeddingType),
default=PositionEmbeddingType.absolute,
help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.'
help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.'
)
group.add_argument('--glu-activation', type=str,
choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(),
1 change: 1 addition & 0 deletions megatron/enums.py
@@ -33,3 +33,4 @@ class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
alibi = 3
xpos = 4
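
Illustrative only (not part of the commit): the argparse converter `type=lambda x: PositionEmbeddingType[x]` in megatron/arguments.py resolves the string given on the command line, so `--position-embedding-type xpos` selects the new member added here. A minimal sketch, assuming the enum is importable from megatron/enums.py:

from megatron.enums import PositionEmbeddingType

# Mirrors the argparse `type` converter: look the CLI string up by name.
chosen = PositionEmbeddingType["xpos"]
assert chosen is PositionEmbeddingType.xpos
assert chosen.value == 4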
109 changes: 108 additions & 1 deletion megatron/model/positional_embeddings.py
@@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):

def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16
cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


# Original implementation adjusted from https://github.com/sunyt32/torchscale

def fixed_pos_embedding(x, base):
seq_len, dim = x.shape
inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
)
return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp)


class XPosEmbedding(torch.nn.Module):
"""
xPos positional embeddings from https://arxiv.org/abs/2212.10554.
"""

def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half):
super().__init__()
self.scale_base = scale_base
self.register_buffer(
"scale",
(
(torch.arange(0, head_dim, 2) + gamma * head_dim)
/ ((1.0 + gamma) * head_dim)
),
)
self.max_seq_len_cached = None
self.precision = precision
self.freq_base = freq_base

def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if (
self.max_seq_len_cached is None
or (seq_len > self.max_seq_len_cached)
):
self.max_seq_len_cached = seq_len
scale = (
self.scale
** (
torch.arange(0, seq_len, 1) - seq_len // 2
).to(self.scale).div(self.scale_base)[:, None]
)
cos, sin = fixed_pos_embedding(scale, self.freq_base)
self.cos_cached = cos
self.sin_cached = sin
self.scale_cached = scale
if self.precision == torch.bfloat16:
self.cos_cached = self.cos_cached.bfloat16()
self.sin_cached = self.sin_cached.bfloat16()
return (
self.cos_cached[:seq_len],
self.sin_cached[:seq_len],
self.scale_cached[:seq_len],
)


def rotate_every_two(x):
x1 = x[:, :, ::2]
x2 = x[:, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m.unsqueeze(1)


def _apply_xpos_emb(x, cos, sin, scale):
# x is assumed to be (seq_len, batch_size, dim) here.
cos = duplicate_interleave(cos * scale)
sin = duplicate_interleave(sin * scale)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)


@torch.jit.script
def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0):
# q/k are assumed to be (seq_len, batch_size, dim) here.
cos = cos[offset:q.shape[0] + offset]
sin = sin[offset:q.shape[0] + offset]
scale = scale[offset:q.shape[0] + offset]
return (
_apply_xpos_emb(q, cos, sin, scale),
_apply_xpos_emb(q, cos, sin, 1.0 / scale),
)


def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0):
# q/k are assumed to be (seq_len, batch_size, dim) here.
cos = cos[offset:q.shape[0] + offset]
sin = sin[offset:q.shape[0] + offset]
scale = scale[offset:q.shape[0] + offset]
return (
_apply_xpos_emb(q, cos, sin, scale),
_apply_xpos_emb(q, cos, sin, 1.0 / scale),
)
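
A standalone usage sketch of the new module and helpers (illustrative only, not part of the commit). It assumes tensors shaped (seq_len, batch, head_dim), matching the comments in apply_xpos_emb_torch above; queries are scaled by `scale` and keys by `1 / scale`, which is what gives xPos its distance-dependent decay:

import torch

from megatron.model.positional_embeddings import XPosEmbedding, apply_xpos_emb_torch

# Hypothetical shapes for illustration.
seq_len, batch, head_dim = 16, 2, 64
q = torch.randn(seq_len, batch, head_dim)
k = torch.randn(seq_len, batch, head_dim)

xpos = XPosEmbedding(head_dim, precision=torch.float)
cos, sin, scale = xpos(q, seq_len=seq_len)  # each of shape (seq_len, head_dim // 2)
q_rot, k_rot = apply_xpos_emb_torch(q, k, cos, sin, scale)

# Shapes are preserved; q is multiplied by `scale`, k by `1 / scale`.
assert q_rot.shape == q.shape and k_rot.shape == k.shape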
25 changes: 21 additions & 4 deletions megatron/model/transformer.py
@@ -31,7 +31,14 @@
import deepspeed

from .glu_activations import GLU_ACTIVATIONS
from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb
from .positional_embeddings import (
apply_rotary_pos_emb,
apply_rotary_pos_emb_torch,
apply_xpos_emb,
apply_xpos_emb_torch,
RotaryEmbedding,
XPosEmbedding,
)

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
@@ -204,6 +211,8 @@ def __init__(self, init_method,

if self.position_embedding_type == PositionEmbeddingType.rotary:
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)
elif self.position_embedding_type == PositionEmbeddingType.xpos:
self.xpos_emb = XPosEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)

def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False, encoder_output=None, alibi=None):
@@ -291,16 +300,24 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

# Rotary embeddings
if self.position_embedding_type == PositionEmbeddingType.rotary:
apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb

if self.position_embedding_type in [
PositionEmbeddingType.rotary, PositionEmbeddingType.xpos]:
seq_len = key_layer.shape[0]
offset = 0
if layer_past is not None and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset

if self.position_embedding_type == PositionEmbeddingType.rotary:
apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
elif self.position_embedding_type == PositionEmbeddingType.xpos:
apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb
cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_xpos_fn(
query_layer, key_layer, cos, sin, scale, offset=offset)


# Raw attention scores. [b * np, sq, sk]
if alibi is None:
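
A further illustrative sketch (not part of the commit) of how the `offset` path above could behave during incremental decoding with a key/value cache: `seq_len` covers past plus new positions, while `offset` skips the cached part so only the newest token is rotated. Shapes and values are hypothetical:

import torch

from megatron.model.positional_embeddings import XPosEmbedding, apply_xpos_emb_torch

batch, head_dim, past_len = 2, 64, 8
xpos = XPosEmbedding(head_dim, precision=torch.float)

# One new token arrives; layer_past would contribute `past_len` cached positions.
q_new = torch.randn(1, batch, head_dim)
k_new = torch.randn(1, batch, head_dim)

cos, sin, scale = xpos(q_new, seq_len=past_len + 1)
q_rot, k_rot = apply_xpos_emb_torch(q_new, k_new, cos, sin, scale, offset=past_len)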
