Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama/GPTNeoX: add RoPE scaling #24653

Merged
merged 9 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig):
use_parallel_residual (`bool`, *optional*, defaults to `True`):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales (e.g. 20B).
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.

Example:

```python
Expand Down Expand Up @@ -115,6 +124,7 @@ def __init__(
eos_token_id=2,
tie_word_embeddings=False,
use_parallel_residual=True,
rope_scaling=None,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
Expand All @@ -135,7 +145,32 @@ def __init__(
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
126 changes: 102 additions & 24 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _set_gradient_checkpointing(self, module, value=False):
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
Expand All @@ -94,28 +95,56 @@ def __init__(self, config):
)
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self._init_bias(config.max_position_embeddings)

self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
self._init_rope()

self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)

self.attention_dropout = nn.Dropout(config.attention_dropout)

def _init_bias(self, max_positions, device=None):
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
if device is not None:
self.bias = self.bias.to(device)

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = GPTNeoXRotaryEmbedding(
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

def forward(
self,
hidden_states: torch.FloatTensor,
Expand Down Expand Up @@ -210,6 +239,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)

# dynamically increase the causal mask with the key length, if needed.
if key_length > self.bias.shape[-1]:
self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
Expand Down Expand Up @@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask):
return attention_scores


class RotaryEmbedding(torch.nn.Module):
class GPTNeoXRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
Expand All @@ -275,18 +315,56 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None):

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]


class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len

if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,24 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights


# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)

def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
Expand All @@ -256,15 +264,8 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None):

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


Expand Down
34 changes: 34 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.

Example:

```python
Expand Down Expand Up @@ -97,6 +106,7 @@ def __init__(
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -109,10 +119,34 @@ def __init__(
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
Loading