From 366ae3e129582475b9449ec87cc07f58fbe17096 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 7 Mar 2023 10:33:33 +0100 Subject: [PATCH] Add xPos embeddings --- finetune_t0_non_causal_decoder.py | 6 +- megatron/arguments.py | 2 +- megatron/enums.py | 1 + megatron/model/positional_embeddings.py | 109 +++++++++++++++++++++++- megatron/model/transformer.py | 25 +++++- 5 files changed, 136 insertions(+), 7 deletions(-) diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py index 14650a6e5..7a15bb735 100644 --- a/finetune_t0_non_causal_decoder.py +++ b/finetune_t0_non_causal_decoder.py @@ -94,7 +94,11 @@ def get_batch_pipe(data): segment_ids=segment_ids.long(), ) - if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: + if args.position_embedding_type not in [ + PositionEmbeddingType.alibi, + PositionEmbeddingType.rotary, + PositionEmbeddingType.xpos, + ]: raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") return (tokens, position_ids, attention_mask), (labels, loss_mask) diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..bc8913feb 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -398,7 +398,7 @@ def _add_network_size_args(parser): group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), default=PositionEmbeddingType.absolute, - help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.' + help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.' ) group.add_argument('--glu-activation', type=str, choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(), diff --git a/megatron/enums.py b/megatron/enums.py index 90d00a071..2c7d42035 100644 --- a/megatron/enums.py +++ b/megatron/enums.py @@ -33,3 +33,4 @@ class PositionEmbeddingType(enum.Enum): rotary = 1 absolute = 2 alibi = 3 + xpos = 4 diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 3494f9e4e..8c3609c72 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) \ No newline at end of file + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +# Original implementation adjusted from https://github.com/sunyt32/torchscale + +def fixed_pos_embedding(x, base): + seq_len, dim = x.shape + inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp) + + +class XPosEmbedding(torch.nn.Module): + """ + xPos positional embeddings from https://arxiv.org/abs/2212.10554. + """ + + def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half): + super().__init__() + self.scale_base = scale_base + self.register_buffer( + "scale", + ( + (torch.arange(0, head_dim, 2) + gamma * head_dim) + / ((1.0 + gamma) * head_dim) + ), + ) + self.max_seq_len_cached = None + self.precision = precision + self.freq_base = freq_base + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if ( + self.max_seq_len_cached is None + or (seq_len > self.max_seq_len_cached) + ): + self.max_seq_len_cached = seq_len + scale = ( + self.scale + ** ( + torch.arange(0, seq_len, 1) - seq_len // 2 + ).to(self.scale).div(self.scale_base)[:, None] + ) + cos, sin = fixed_pos_embedding(scale, self.freq_base) + self.cos_cached = cos + self.sin_cached = sin + self.scale_cached = scale + if self.precision == torch.bfloat16: + self.cos_cached = self.cos_cached.bfloat16() + self.sin_cached = self.sin_cached.bfloat16() + return ( + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], + self.scale_cached[:seq_len], + ) + + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m.unsqueeze(1) + + +def _apply_xpos_emb(x, cos, sin, scale): + # x is assumed to be (seq_len, batch_size, dim) here. + cos = duplicate_interleave(cos * scale) + sin = duplicate_interleave(sin * scale) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +@torch.jit.script +def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. + cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(q, cos, sin, 1.0 / scale), + ) + + +def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. + cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(q, cos, sin, 1.0 / scale), + ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 03e6faaec..42d0f0384 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -31,7 +31,14 @@ import deepspeed from .glu_activations import GLU_ACTIVATIONS -from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb +from .positional_embeddings import ( + apply_rotary_pos_emb, + apply_rotary_pos_emb_torch, + apply_xpos_emb, + apply_xpos_emb_torch, + RotaryEmbedding, + XPosEmbedding, +) # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -204,6 +211,8 @@ def __init__(self, init_method, if self.position_embedding_type == PositionEmbeddingType.rotary: self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) + elif self.position_embedding_type == PositionEmbeddingType.xpos: + self.xpos_emb = XPosEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None, alibi=None): @@ -291,16 +300,24 @@ def forward(self, hidden_states, attention_mask, layer_past=None, matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]] # Rotary embeddings - if self.position_embedding_type == PositionEmbeddingType.rotary: - apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb - + if self.position_embedding_type in [ + PositionEmbeddingType.rotary, PositionEmbeddingType.xpos]: seq_len = key_layer.shape[0] offset = 0 if layer_past is not None and layer_past.numel() > 0: offset = layer_past[0].shape[0] seq_len += offset + + if self.position_embedding_type == PositionEmbeddingType.rotary: + apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset) + elif self.position_embedding_type == PositionEmbeddingType.xpos: + apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb + cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_xpos_fn( + query_layer, key_layer, cos, sin, scale, offset=offset) + # Raw attention scores. [b * np, sq, sk] if alibi is None: