Skip to content

Commit

Permalink
Fix Rope scaling (#3598)
Browse files Browse the repository at this point in the history
## What
Fixed rope scaling for all models.

## Why
Earlier the model config was being ignored during patching which might
cause issues like initializing with wrong `max_position_embeddings`

## How
Gets required args from model_config and passes it to patching
functions.
  • Loading branch information
shahules786 authored Jul 25, 2023
1 parent aed6f17 commit a32993d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 139 deletions.
4 changes: 2 additions & 2 deletions model/model_training/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ rope_scaling_test:
dtype: bf16
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
model_name: "meta-llama/Llama-2-13b-hf"
deepspeed_config: configs/zero_config.json
output_dir: llama
weight_decay: 0.0
Expand All @@ -811,7 +811,7 @@ rope_scaling_test:
superhot: true
superhot_config:
type: linear
scale: 2
scaling_factor: 2
datasets:
- dolly15k

Expand Down
29 changes: 20 additions & 9 deletions model/model_training/models/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
LlamaForCausalLM,
LlamaModel,
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
)
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_falcon import falcon_forward_with_flash_attn
from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
from .rope import RWNTKScaledRope

SUPPORTED_MODELS = [
GPTNeoXModel,
Expand Down Expand Up @@ -200,25 +204,27 @@ def patch_model(
class RopePatch:
def __init__(self, model_name, **kwargs):
self.args = kwargs
rope_type = self.args.pop("type")
self.rope_type = self.args.pop("type")
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if hasattr(config, "max_position_embeddings"):
self.args["max_position_embeddings"] = config.max_position_embeddings
if hasattr(config, "base"):
self.args["base"] = config.base
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
self.architecture = "FalconForCausalLM"
if rope_type == "ntk":
if self.rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
raise NotImplementedError()
elif "LlamaForCausalLM" in architecture:
self.architecture = "LlamaForCausalLM"
if rope_type == "linear":
self.patch_fun = LlamaLinearScaledRope
elif rope_type == "ntk":
self.patch_fun = LlamaNTKScaledRope
elif rope_type == "dynamic-ntk":
self.patch_fun = LlamaDynamicScaledRotaryEmbedding
if self.rope_type == "linear":
self.patch_fun = LlamaLinearScalingRotaryEmbedding
elif self.rope_type == "dynamic":
self.patch_fun = LlamaDynamicNTKScalingRotaryEmbedding
else:
raise NotImplementedError()
else:
Expand All @@ -230,6 +236,9 @@ def from_config(cls, config):
args = config.superhot_config
return cls(model_name, **args)

def update_config(self, model, scaling_factor):
model.config["rope_scaling"] = {"type": self.rope_type, "factor": scaling_factor}

def patch(self, model):
if self.architecture == "FalconForCausalLM":
self.patch_falcon_model(model, **self.args)
Expand All @@ -238,6 +247,8 @@ def patch(self, model):
else:
raise NotImplementedError()

self.update_config(model, self.args.get("scaling_factor"))

def patch_falcon_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
Expand Down
128 changes: 0 additions & 128 deletions model/model_training/models/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,131 +62,3 @@ def forward(self, q, k, past_key_values_length=0):
batch, seq_len, head_dim = q.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, q.device, q.dtype)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class LlamaLinearScaledRope(torch.nn.Module):
"""
reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
super().__init__()
self.scale = 1 / scale
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaNTKScaledRope(torch.nn.Module):

"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
super().__init__()
base = base * alpha ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
super().__init__()
self.ntk = ntk
self.base = base
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
if self.ntk:
base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
self.dim / (self.dim - 2)
)
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
if not self.ntk:
t *= self.max_position_embeddings / seq_len
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)

0 comments on commit a32993d

Please sign in to comment.