From 87ae1f2db9d099c0cf1bcbe67936a75d1d115f96 Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Mon, 24 Jul 2023 18:24:18 +0000
Subject: [PATCH 1/5] pass config args while patching

---
 model/model_training/models/patching.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index b76baa6e4f..430346c855 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -202,6 +202,10 @@ def __init__(self, model_name, **kwargs):
         self.args = kwargs
         rope_type = self.args.pop("type")
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        if hasattr(config, "max_position_embeddings"):
+            self.args["max_position_embeddings"] = config.max_position_embeddings
+        if hasattr(config, "base"):
+            self.args["base"] = config.base
         architecture = config.architectures
         if architecture:
             self.model_name = architecture[0]

From 1cdf4fc0aad1af077d77ceb7f0f94e719aca0722 Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Mon, 24 Jul 2023 18:56:02 +0000
Subject: [PATCH 2/5] rmv custom scaling

---
 model/model_training/models/rope.py | 40 -----------------------------
 1 file changed, 40 deletions(-)

diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
index 31ca01f61c..2e92e54a7b 100644
--- a/model/model_training/models/rope.py
+++ b/model/model_training/models/rope.py
@@ -64,46 +64,6 @@ def forward(self, q, k, past_key_values_length=0):
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
-class LlamaLinearScaledRope(torch.nn.Module):
-    """
-    reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
-        super().__init__()
-        self.scale = 1 / scale
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        t *= self.scale
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            t *= self.scale
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
 class LlamaNTKScaledRope(torch.nn.Module):
 
     """

From dee7ef3cc967bb9bfceb23bb33e36db7e1a64c7e Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Mon, 24 Jul 2023 18:56:14 +0000
Subject: [PATCH 3/5] migrate to HF impl

---
 model/model_training/models/patching.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index 430346c855..582a24d3fe 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -15,13 +15,14 @@
     LlamaForCausalLM,
     LlamaModel,
 )
+from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding
 from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
 
 from .patching_falcon import falcon_forward_with_flash_attn
 from .patching_llama import llama_forward_with_flash_attn
 from .patching_neox import neox_forward_with_flash_attn
 from .reward_model import GPTNeoXRewardModel
-from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
+from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaNTKScaledRope, RWNTKScaledRope
 
 SUPPORTED_MODELS = [
     GPTNeoXModel,
@@ -218,7 +219,7 @@ def __init__(self, model_name, **kwargs):
             elif "LlamaForCausalLM" in architecture:
                 self.architecture = "LlamaForCausalLM"
                 if rope_type == "linear":
-                    self.patch_fun = LlamaLinearScaledRope
+                    self.patch_fun = LlamaLinearScalingRotaryEmbedding
                 elif rope_type == "ntk":
                     self.patch_fun = LlamaNTKScaledRope
                 elif rope_type == "dynamic-ntk":

From d5fec710f761db97a204f70eabce89bcc5f77b2c Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Mon, 24 Jul 2023 19:00:02 +0000
Subject: [PATCH 4/5] update param

---
 model/model_training/configs/config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index adbb01354a..a7c1f70c5d 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -786,7 +786,7 @@ rope_scaling_test:
   dtype: bf16
   log_dir: "llama_log_7b"
   learning_rate: 1e-5
-  model_name: "huggyllama/llama-7b"
+  model_name: "meta-llama/Llama-2-13b-hf"
   deepspeed_config: configs/zero_config.json
   output_dir: llama
   weight_decay: 0.0
@@ -811,7 +811,7 @@ rope_scaling_test:
   superhot: true
   superhot_config:
     type: linear
-    scale: 2
+    scaling_factor: 2
   datasets:
     - dolly15k
 

From e67cfd87ece8406796da727a60edd0bfdb74478e Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Tue, 25 Jul 2023 06:24:26 +0000
Subject: [PATCH 5/5] replace llama rope with HF

---
 model/model_training/models/patching.py | 24 ++++++---
 model/model_training/models/rope.py     | 88 ------------------------
 2 files changed, 15 insertions(+), 97 deletions(-)

diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index 582a24d3fe..bda61501a3 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -15,14 +15,17 @@
     LlamaForCausalLM,
     LlamaModel,
 )
-from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding
+from transformers.models.llama.modeling_llama import (
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+)
 from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
 
 from .patching_falcon import falcon_forward_with_flash_attn
 from .patching_llama import llama_forward_with_flash_attn
 from .patching_neox import neox_forward_with_flash_attn
 from .reward_model import GPTNeoXRewardModel
-from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaNTKScaledRope, RWNTKScaledRope
+from .rope import RWNTKScaledRope
 
 SUPPORTED_MODELS = [
     GPTNeoXModel,
@@ -201,7 +204,7 @@ def patch_model(
 class RopePatch:
     def __init__(self, model_name, **kwargs):
         self.args = kwargs
-        rope_type = self.args.pop("type")
+        self.rope_type = self.args.pop("type")
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         if hasattr(config, "max_position_embeddings"):
             self.args["max_position_embeddings"] = config.max_position_embeddings
@@ -212,18 +215,16 @@ def __init__(self, model_name, **kwargs):
             self.model_name = architecture[0]
             if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
                 self.architecture = "FalconForCausalLM"
-                if rope_type == "ntk":
+                if self.rope_type == "ntk":
                     self.patch_fun = RWNTKScaledRope
                 else:
                     raise NotImplementedError()
             elif "LlamaForCausalLM" in architecture:
                 self.architecture = "LlamaForCausalLM"
-                if rope_type == "linear":
+                if self.rope_type == "linear":
                     self.patch_fun = LlamaLinearScalingRotaryEmbedding
-                elif rope_type == "ntk":
-                    self.patch_fun = LlamaNTKScaledRope
-                elif rope_type == "dynamic-ntk":
-                    self.patch_fun = LlamaDynamicScaledRotaryEmbedding
+                elif self.rope_type == "dynamic":
+                    self.patch_fun = LlamaDynamicNTKScalingRotaryEmbedding
                 else:
                     raise NotImplementedError()
             else:
@@ -235,6 +236,9 @@ def from_config(cls, config):
         args = config.superhot_config
         return cls(model_name, **args)
 
+    def update_config(self, model, scaling_factor):
+        model.config["rope_scaling"] = {"type": self.rope_type, "factor": scaling_factor}
+
     def patch(self, model):
         if self.architecture == "FalconForCausalLM":
             self.patch_falcon_model(model, **self.args)
@@ -243,6 +247,8 @@ def patch(self, model):
         else:
             raise NotImplementedError()
 
+        self.update_config(model, self.args.get("scaling_factor"))
+
     def patch_falcon_model(self, model, **kwargs):
         for each in model.transformer.h:
             each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
index 2e92e54a7b..9bc68f6f9b 100644
--- a/model/model_training/models/rope.py
+++ b/model/model_training/models/rope.py
@@ -62,91 +62,3 @@ def forward(self, q, k, past_key_values_length=0):
         batch, seq_len, head_dim = q.shape
         cos, sin = self.cos_sin(seq_len, past_key_values_length, q.device, q.dtype)
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-
-
-class LlamaNTKScaledRope(torch.nn.Module):
-
-    """
-    reference: https://github.com/jquesnelle/scaled-rope
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
-        super().__init__()
-        base = base * alpha ** (dim / (dim - 2))
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
-    """
-    reference: https://github.com/jquesnelle/scaled-rope
-    """
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
-        super().__init__()
-        self.ntk = ntk
-        self.base = base
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        dtype = torch.get_default_dtype()
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            if self.ntk:
-                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
-                    self.dim / (self.dim - 2)
-                )
-                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
-                self.register_buffer("inv_freq", inv_freq)
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            if not self.ntk:
-                t *= self.max_position_embeddings / seq_len
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
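Usage sketch (illustrative only, not part of the patch series): after PATCH 5/5 the rope_scaling_test entry drives RopePatch.from_config, which pops the "type" key, copies max_position_embeddings (and base, when present) off the model config, and forwards the remaining keyword arguments to Hugging Face's scaling rotary embeddings in place of the removed custom classes. The snippet below assumes transformers ~4.31, whose LlamaLinearScalingRotaryEmbedding accepts a scaling_factor keyword and whose forward(x, seq_len) returns cos/sin tables; stock LlamaConfig() defaults stand in for the gated Llama-2 checkpoint config.

# Illustrative sketch, not part of the patches above. Assumes transformers ~4.31,
# where LlamaLinearScalingRotaryEmbedding(dim, max_position_embeddings, base, device,
# scaling_factor) exists and forward(x, seq_len) returns cos/sin tables.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

# Mirrors the superhot_config block of rope_scaling_test in config.yaml.
superhot_config = {"type": "linear", "scaling_factor": 2}

# Stock LlamaConfig defaults stand in for the real (gated) Llama-2 checkpoint config:
# hidden_size=4096, num_attention_heads=32, max_position_embeddings=2048.
config = LlamaConfig()
head_dim = config.hidden_size // config.num_attention_heads

# RopePatch pops "type" and adds max_position_embeddings from the model config, then
# forwards the remaining kwargs, so the constructed module looks like this:
rotary_emb = LlamaLinearScalingRotaryEmbedding(
    head_dim,
    max_position_embeddings=config.max_position_embeddings,
    scaling_factor=superhot_config["scaling_factor"],
)

# Bookkeeping comparable to update_config: record the scaling on the config object.
config.rope_scaling = {"type": superhot_config["type"], "factor": superhot_config["scaling_factor"]}

# cos/sin tables for a 2x-extended context; linear scaling compresses positions by 1/factor.
x = torch.zeros(1, config.num_attention_heads, 4096, head_dim)
cos, sin = rotary_emb(x, seq_len=4096)
print(cos.shape, sin.shape)  # torch.Size([1, 1, 4096, 128]) for each under transformers 4.31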