From a32993d91a38012eaa55932efe7e04a47bae7fb5 Mon Sep 17 00:00:00 2001
From: Shahul ES
Date: Tue, 25 Jul 2023 20:53:29 +0530
Subject: [PATCH] Fix Rope scaling (#3598)

## What

Fixed rope scaling for all models.

## Why

Earlier, the model config was ignored during patching, which could cause issues such as initializing with the wrong `max_position_embeddings`.

## How

Gets the required args from the model config and passes them to the patching functions.
---
 model/model_training/configs/config.yaml |   4 +-
 model/model_training/models/patching.py  |  29 +++--
 model/model_training/models/rope.py      | 128 -----------------------
 3 files changed, 22 insertions(+), 139 deletions(-)

diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index adbb01354a..a7c1f70c5d 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -786,7 +786,7 @@ rope_scaling_test:
   dtype: bf16
   log_dir: "llama_log_7b"
   learning_rate: 1e-5
-  model_name: "huggyllama/llama-7b"
+  model_name: "meta-llama/Llama-2-13b-hf"
   deepspeed_config: configs/zero_config.json
   output_dir: llama
   weight_decay: 0.0
@@ -811,7 +811,7 @@ rope_scaling_test:
   superhot: true
   superhot_config:
     type: linear
-    scale: 2
+    scaling_factor: 2
   datasets:
     - dolly15k
 
diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index b76baa6e4f..bda61501a3 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -15,13 +15,17 @@
     LlamaForCausalLM,
     LlamaModel,
 )
+from transformers.models.llama.modeling_llama import (
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+)
 from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
 
 from .patching_falcon import falcon_forward_with_flash_attn
 from .patching_llama import llama_forward_with_flash_attn
 from .patching_neox import neox_forward_with_flash_attn
 from .reward_model import GPTNeoXRewardModel
-from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
+from .rope import RWNTKScaledRope
 
 SUPPORTED_MODELS = [
     GPTNeoXModel,
@@ -200,25 +204,27 @@ def patch_model(
 class RopePatch:
     def __init__(self, model_name, **kwargs):
         self.args = kwargs
-        rope_type = self.args.pop("type")
+        self.rope_type = self.args.pop("type")
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        if hasattr(config, "max_position_embeddings"):
+            self.args["max_position_embeddings"] = config.max_position_embeddings
+        if hasattr(config, "base"):
+            self.args["base"] = config.base
         architecture = config.architectures
         if architecture:
             self.model_name = architecture[0]
             if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
                 self.architecture = "FalconForCausalLM"
-                if rope_type == "ntk":
+                if self.rope_type == "ntk":
                     self.patch_fun = RWNTKScaledRope
                 else:
                     raise NotImplementedError()
             elif "LlamaForCausalLM" in architecture:
                 self.architecture = "LlamaForCausalLM"
-                if rope_type == "linear":
-                    self.patch_fun = LlamaLinearScaledRope
-                elif rope_type == "ntk":
-                    self.patch_fun = LlamaNTKScaledRope
-                elif rope_type == "dynamic-ntk":
-                    self.patch_fun = LlamaDynamicScaledRotaryEmbedding
+                if self.rope_type == "linear":
+                    self.patch_fun = LlamaLinearScalingRotaryEmbedding
+                elif self.rope_type == "dynamic":
+                    self.patch_fun = LlamaDynamicNTKScalingRotaryEmbedding
                 else:
                     raise NotImplementedError()
             else:
@@ -230,6 +236,9 @@ def from_config(cls, config):
         args = config.superhot_config
         return cls(model_name, **args)
 
+    def update_config(self, model, scaling_factor):
+        model.config["rope_scaling"] = {"type": self.rope_type, "factor": scaling_factor}
+
     def patch(self, model):
         if self.architecture == "FalconForCausalLM":
             self.patch_falcon_model(model, **self.args)
@@ -238,6 +247,8 @@ def patch(self, model):
         else:
             raise NotImplementedError()
 
+        self.update_config(model, self.args.get("scaling_factor"))
+
     def patch_falcon_model(self, model, **kwargs):
         for each in model.transformer.h:
             each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
index 31ca01f61c..9bc68f6f9b 100644
--- a/model/model_training/models/rope.py
+++ b/model/model_training/models/rope.py
@@ -62,131 +62,3 @@ def forward(self, q, k, past_key_values_length=0):
         batch, seq_len, head_dim = q.shape
         cos, sin = self.cos_sin(seq_len, past_key_values_length, q.device, q.dtype)
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-
-
-class LlamaLinearScaledRope(torch.nn.Module):
-    """
-    reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
-        super().__init__()
-        self.scale = 1 / scale
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        t *= self.scale
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            t *= self.scale
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-class LlamaNTKScaledRope(torch.nn.Module):
-
-    """
-    reference: https://github.com/jquesnelle/scaled-rope
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
-        super().__init__()
-        base = base * alpha ** (dim / (dim - 2))
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
-    """
-    reference: https://github.com/jquesnelle/scaled-rope
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
-        super().__init__()
-        self.ntk = ntk
-        self.base = base
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            if self.ntk:
-                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
-                    self.dim / (self.dim - 2)
-                )
-                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
-                self.register_buffer("inv_freq", inv_freq)
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            if not self.ntk:
-                t *= self.max_position_embeddings / seq_len
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
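
For reference, a minimal sketch of what the patched code path now does for a Llama model, assuming `transformers>=4.31` (where `LlamaLinearScalingRotaryEmbedding` is available) and a `superhot_config` like the one in `rope_scaling_test` above. The `rope_kwargs` dict, the layer loop, and the attribute-style config update below are illustrative, not the exact `patch_llama_model` implementation from this patch.

```python
# Illustrative sketch only -- not part of the patch. Roughly what
# RopePatch.from_config(conf).patch(model) does for a Llama model with
# superhot_config = {"type": "linear", "scaling_factor": 2}.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

model_name = "meta-llama/Llama-2-13b-hf"
scaling_factor = 2

config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# The fix: size the rotary cache from the model config instead of the
# embedding class default (2048), then apply the user-supplied scaling factor.
rope_kwargs = {
    "max_position_embeddings": config.max_position_embeddings,
    "scaling_factor": scaling_factor,
}

head_dim = config.hidden_size // config.num_attention_heads
for layer in model.model.layers:
    # Swap each attention layer's rotary embedding for the linearly scaled variant.
    layer.self_attn.rotary_emb = LlamaLinearScalingRotaryEmbedding(
        head_dim, device=model.device, **rope_kwargs
    )

# Record the scaling on the config, which is what RopePatch.update_config is for.
model.config.rope_scaling = {"type": "linear", "factor": scaling_factor}
```

Pulling `max_position_embeddings` from the loaded config is the core of the fix described above: previously the patching could fall back to the embedding defaults and build the cos/sin cache for the wrong context length.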