From 47049e31bda3d61aefecfdd7d30981f5505b9b66 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Fri, 13 Sep 2024 09:54:55 -0700 Subject: [PATCH] Updated LLM samples with support for Phi 3.5 mini and Llama 3.1 --- PyTorch/llm/README.md | 72 ++++++++-------- PyTorch/llm/app.py | 23 +++-- PyTorch/llm/models/base.py | 11 ++- PyTorch/llm/models/configs.py | 50 ++++++++++- PyTorch/llm/models/layers.py | 95 +++++++++++++++++++-- PyTorch/llm/models/phi2.py | 6 +- PyTorch/llm/scripts/download_and_convert.py | 5 +- PyTorch/llm/utils.py | 4 +- 8 files changed, 208 insertions(+), 58 deletions(-) diff --git a/PyTorch/llm/README.md b/PyTorch/llm/README.md index b762af29..9f04abdb 100644 --- a/PyTorch/llm/README.md +++ b/PyTorch/llm/README.md @@ -8,8 +8,7 @@ This sample is extracted from [pytorch-labs/gpt-fast](https://github.com/pytorch - [Setup](#setup) - [Run the App](#run-the-app) - [App Settings](#app-settings) -- [External Links](#external-links) -- [Model Licenses](#model-licenses) +- [External Links & Model Licenses](#external-links-and-model-licenses) ## Supported Models @@ -17,21 +16,25 @@ The following models are currently supported by this sample: - [Phi-2](https://huggingface.co/microsoft/phi-2): Small Language Model with 2.7 billion parameters. Best suited for prompts using QA format, chat format, and code format. - [Phi-3 Mini 4K](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct): Small Language Model with 3.8 billion parameters using a 4k context window. The Instruct version has been fine-tuned to follow instructions and adhere to safety measures. +- [Phi-3.5 Mini](https://huggingface.co/microsoft/Phi-3.5-mini-instruct): A lightweight, state-of-the-art open model with 3.8 billion parameters using a 128k context window. The Instruct version has been fine-tuned to ensure precise instruction adherence and robust safety measures. - [LLaMA 2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf): Large Language Model with 7 billion parameters optimized specifically for dialogue use cases. - [LLaMA 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct): Large Language Model with 8 billion parameters. The Llama 3 instruction tuned models are optimized for dialogue use cases. +- [LLaMA 3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct): Large Language Model with 8 billion parameters. The Llama 3.1 instruction tuned models are an improvement over Llama-3 and are optimized for dialogue use cases. - [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1): Large Language Model with 7 billion parameters. The Mistral-7B-Instruct-v0.1 Large Language Model is a instruct fine-tuned version of the Mistral-7B-v0.1 generative text model using a variety of publicly available conversation datasets. >⚠️ **NOTE**: Other variants of these models may work but they were not tested. The various models have different VRAM requirements, the following table lists the memory requirements for the tested models.
-| Model | fp16 | fp32 | -| --------------- | ------| ----- | -| Phi-2 | 6GB | 12GB | -| Phi-3-mini-4k | 8GB | >16GB | -| Llama-2-7b | 14GB | 28GB | -| Meta-Llama-3-8B | >16GB | 32GB | -| Mistral-7B | 15GB | 30GB | +| Model | fp16 | fp32 | +| ----------------- | ------| ----- | +| Phi-2 | 6GB | 12GB | +| Phi-3-mini-4k | 8GB | >16GB | +| Phi-3.5-mini | 8GB | >16GB | +| Llama-2-7b | 14GB | 28GB | +| Meta-Llama-3-8B | >16GB | 32GB | +| Meta-Llama-3.1-8B | >16GB | 32GB | +| Mistral-7B | 15GB | 30GB | ## Setup Once you've setup `torch-directml` following our [Windows](https://learn.microsoft.com/windows/ai/directml/pytorch-windows) or [WSL 2](https://learn.microsoft.com/windows/ai/directml/pytorch-wsl) guidance, install the following requirements for running app: @@ -44,6 +47,7 @@ To use the Llama and Mistral models, you will need to go through an extra step t 1. Visit - LLaMA 2: [https://huggingface.co/meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) - LLaMA 3: [https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) + - LLaMA 3.1: [https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) - Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 2. Follow the steps on the Hugging Face page to obtain access 3. Run `huggingface-cli login` @@ -104,7 +108,7 @@ To run the model using `float32` precision, pass `--precision float32` to `app.p ### Change the model -You can also select another model to run (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-2`, `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.1`). +You can also select another model to run (`microsoft/Phi-3.5-mini-instruct`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-2`, `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.1`). For example to run `Mistral-7B-Instruct-v0.1` use the following command: @@ -151,16 +155,20 @@ Following is a list of the basic settings supported by `app.py`: | `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{hf_model}/model.pth` | | `--max_context_length` | Max prompt length including the history. If exceeded, history is clipped starting from the first (user, assistant) pair. | `1500` | | `--disable_history` | Disable the chat history during generation. | Enabled | +| `--max_pos_emb` | Maximum number of position embeddings used to scale the Phi-3.5 position encodings. | `8192` | >⚠️ **NOTE**: The app uses the checkpoint path to determine the correct transformer model to load. The model path must specify the Hugging Face model ID included in the path name.
For example: - `checkpoints/microsoft/phi-2/model.pth` - `checkpoints/microsoft/Phi-3-mini-4k-instruct/model.pth` +- `checkpoints/microsoft/Phi-3.5-mini-instruct/model.pth` - `checkpoints/mistralai/Mistral-7B-v0.1/model.pth` - `checkpoints/mistralai/Mistral-7B-Instruct-v0.1/model.pth` - `checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth` - `checkpoints/meta-llama/Meta-Llama-3-8B/model.pth` - `checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth` +- `checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth` +- `checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct/model.pth` ## _[Optional]_ Prepare the Supported Models This step is optional as `app.py` script in [Run the App](#run-the-app) section handles both downloading and optimizing a PyTorch model with DirectML. @@ -179,36 +187,26 @@ After the model is downloaded and converted, you can pass the following paramete > python app.py --hf_model "microsoft/Phi-3-mini-4k-instruct" ``` -### Download a DirectML optimized PyTorch model from the [Microsoft Hugging Face repo](https://huggingface.co/microsoft): - 1. cd checkpoints - 2. git clone https://huggingface.co/{hf_model} {hf_model} - 3. cd ../ -After the model is downloaded, you can pass the following parameter to `app.py` to run the language model: - -``` -> python app.py --checkpoint_path "checkpoints/{hf_model}/model.pth" -``` - -## External Links -- [Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2) -- [Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) -- [LLaMA 2 Hugging Face Repository](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) -- [LLaMA 3 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) -- [Mistral 7B Hugging Face Repository](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) +## External Links and Model Licenses +- [Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2) +This sample uses the Phi-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/phi-2/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. +- [Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) +This sample uses the Phi-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. +- [Phi-3.5 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3.5-mini-instruct) +This sample uses the Phi-3.5 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3.5-mini-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. +- [LLaMA 2 Hugging Face Repository](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) +This sample uses the Llama-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the[ LLAMA 2 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/LICENSE.txt). For terms of use, please visit: Llama 2 - [Acceptable Use Policy - Meta AI](https://ai.meta.com/llama/use-policy/). 
If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms. +- [LLaMA 3 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +This sample uses the Llama-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the +[LLAMA 3 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE). For terms of use, please visit: [Meta Llama 3 Acceptable Use Policy](https://llama.meta.com/llama3/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms. +- [LLaMA 3.1 Hugging Face Repository](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) +This sample uses the Llama-3.1 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the +[LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE). For terms of use, please visit: [Meta Llama 3.1 Acceptable Use Policy](https://llama.meta.com/llama3_1/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms. +- [Mistral 7B Hugging Face Repository](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) +This sample uses the Mistral model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [Apache-2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. - [PyTorch gpt-fast Source Code](https://github.com/pytorch-labs/gpt-fast/) -## Model Licenses - -- [DirectML-Optimized Phi-2 Hugging Face Repository](https://huggingface.co/microsoft/phi-2-pytdml) -This sample uses the phi-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/phi-2/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. -- [DirectML-Optimized Phi-3 Hugging Face Repository](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-pytdml) -This sample uses the phi-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/LICENSE). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. - -- [DirectML-Optimized LLaMA 2 Hugging Face Repository](https://huggingface.co/microsoft/Llama-2-7b-chat-hf-pytdml) -This sample uses the Llama-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the[ LLAMA 2 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/LICENSE.txt). For terms of use, please visit: Llama 2 - [Acceptable Use Policy - Meta AI](https://ai.meta.com/llama/use-policy/). If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms. -- [DirectML-Optimized Mistral 7B Hugging Face Repository](https://huggingface.co/microsoft/Mistral-7B-Instruct-v0.1-pytdml) -This sample uses the Mistral model, which has been optimized to work with PyTorch-DirectML. 
This model is licensed under the [Apache-2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms. \ No newline at end of file diff --git a/PyTorch/llm/app.py b/PyTorch/llm/app.py index 420245c3..0356c7c8 100644 --- a/PyTorch/llm/app.py +++ b/PyTorch/llm/app.py @@ -125,7 +125,8 @@ def __init__( precision: str = 'float32', stream_every_n: int = 7, max_context_length: int = 3500, - use_history: bool = False + use_history: bool = False, + max_pos_emb: int = 8192 ): self.prompt = prompt self.interactive = interactive @@ -139,6 +140,7 @@ def __init__( self.stream_every_n = stream_every_n self.max_context_length = max_context_length self.use_history = use_history + self.max_pos_emb = max_pos_emb self.tokenizer = None self.model = None @@ -177,8 +179,7 @@ def format_prompt_and_encode( messages.append(assistant) messages.append({"role": "user", "content": prompt}) tokens = self.tokenizer.apply_chat_template( - messages, return_tensors="pt", add_generation_prompt=self.is_llama_3)[0].to(dtype=torch.int, device=device) - + messages, return_tensors="pt", add_generation_prompt=True)[0].to(dtype=torch.int, device=device) if self.use_history: while tokens.size(0) > max_context_length: print("Clipping history of conversation as it exceeds the max context length.") @@ -188,7 +189,7 @@ def format_prompt_and_encode( else: break tokens = self.tokenizer.apply_chat_template( - messages, return_tensors="pt", add_generation_prompt=self.is_llama_3)[0].to(dtype=torch.int, device=device) + messages, return_tensors="pt", add_generation_prompt=True)[0].to(dtype=torch.int, device=device) return tokens @@ -274,7 +275,7 @@ def load_model(self) -> None: if self.is_phi_2: self.precision = torch.float32 - self.model = _load_model(self.checkpoint_path, device, self.precision) + self.model = _load_model(self.checkpoint_path, device, self.precision, max_pos_emb=self.max_pos_emb) self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint_path.parent) if self.max_context_length > self.model.config.block_size - (self.max_new_tokens+1): raise ValueError( @@ -288,6 +289,7 @@ def chat( **sampling_kwargs ) -> Iterator[str]: torch.manual_seed(1235) + encoded = self.encode_tokens( prompt, history, @@ -348,7 +350,15 @@ def chat(message: str, history: List[List[str]]) -> Iterator[str]: choices=['float16', 'float32'], help='Precision to run the generation with.' ) + parser.add_argument( + '--max_pos_emb', + type=int, + default=8192, + help='Maximum Position to scale Phi-3.5 position encodings.' 
+ ) args = parser.parse_args() + if args.max_pos_emb > 131072: + args.max_pos_emb = 131072 llm_model = LLM_Model(prompt = "Hello", interactive = False, @@ -360,7 +370,8 @@ def chat(message: str, history: List[List[str]]) -> Iterator[str]: checkpoint_path = args.checkpoint_path, precision = args.precision, max_context_length = args.max_context_length, - use_history = not args.disable_history) + use_history = not args.disable_history, + max_pos_emb=args.max_pos_emb) llm_model.load_model() demo = gr.ChatInterface(chat).queue() diff --git a/PyTorch/llm/models/base.py b/PyTorch/llm/models/base.py index 1e74a783..34f0b340 100644 --- a/PyTorch/llm/models/base.py +++ b/PyTorch/llm/models/base.py @@ -24,6 +24,9 @@ def __init__(self, config: ModelArgs) -> None: self.max_batch_size = -1 self.max_seq_length = -1 + def set_max_position_embeddings(self, max_pos_emb): + self.max_position_embeddings = max_pos_emb + def setup_caches(self, max_batch_size, max_seq_length): head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) @@ -49,5 +52,9 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return logits @classmethod - def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) + def from_name(cls, name: str, max_pos_emb: int = 8192): + model = cls(ModelArgs.from_name(name)) + if "phi-3.5" in name.lower(): + model.set_max_position_embeddings(max_pos_emb) + + return model \ No newline at end of file diff --git a/PyTorch/llm/models/configs.py b/PyTorch/llm/models/configs.py index 8f11756a..382a7837 100644 --- a/PyTorch/llm/models/configs.py +++ b/PyTorch/llm/models/configs.py @@ -4,7 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional, Dict, Any def find_multiple(n: int, k: int) -> int: @@ -25,6 +26,8 @@ class ModelArgs: rope_base: float = 10000 norm_eps: float = 1e-5 partial_rotary_factor: float = 1.0 + rope_scaling: Optional[Dict[str, Any]] = field(default=None) + original_max_position_embeddings: int = None def __post_init__(self): if self.n_local_heads == -1: @@ -59,6 +62,48 @@ def from_name(cls, name: str): "Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064), "Mistral-7B": dict(block_size=4096, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000.0), + "Llama-3.1-8B": dict( + block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, + intermediate_size=14336, vocab_size=128256, rope_base=500000.0, + rope_scaling={ + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + } + ), + "Phi-3.5-mini-instruct": dict( + block_size=8192, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, + rope_base=10000, vocab_size=32064, original_max_position_embeddings=4096, + rope_scaling={ + "long_factor": [ + 1.0800000429153442, 1.1100000143051147, 1.1399999856948853, 1.340000033378601, 1.5899999141693115, + 1.600000023841858, 1.6200000047683716, 2.620000123977661, 3.2300000190734863, 3.2300000190734863, + 4.789999961853027, 7.400000095367432, 7.700000286102295, 9.09000015258789, 12.199999809265137, + 17.670000076293945, 24.46000099182129, 28.57000160217285, 30.420001983642578, 30.840002059936523, + 32.590003967285156, 32.93000411987305, 42.320003509521484, 44.96000289916992, 50.340003967285156, + 50.45000457763672, 57.55000305175781, 57.93000411987305, 58.21000289916992, 60.1400032043457, + 62.61000442504883, 62.62000274658203, 62.71000289916992, 63.1400032043457, 63.1400032043457, + 63.77000427246094, 63.93000411987305, 63.96000289916992, 63.970001220703125, 64.02999877929688, + 64.06999969482422, 64.08000183105469, 64.12000274658203, 64.41000366210938, 64.4800033569336, + 64.51000213623047, 64.52999877929688, 64.83999633789062 + ], + "short_factor": [ + 1.0, 1.0199999809265137, 1.0299999713897705, 1.0299999713897705, 1.0499999523162842, 1.0499999523162842, + 1.0499999523162842, 1.0499999523162842, 1.0499999523162842, 1.0699999332427979, 1.0999999046325684, + 1.1099998950958252, 1.1599998474121094, 1.1599998474121094, 1.1699998378753662, 1.2899998426437378, + 1.339999794960022, 1.679999828338623, 1.7899998426437378, 1.8199998140335083, 1.8499997854232788, + 1.8799997568130493, 1.9099997282028198, 1.9399996995925903, 1.9899996519088745, 2.0199997425079346, + 2.0199997425079346, 2.0199997425079346, 2.0199997425079346, 2.0199997425079346, 2.0199997425079346, + 2.0299997329711914, 2.0299997329711914, 2.0299997329711914, 2.0299997329711914, 2.0299997329711914, + 2.0299997329711914, 2.0299997329711914, 2.0299997329711914, 2.0299997329711914, 2.0799996852874756, + 2.0899996757507324, 2.189999580383301, 2.2199995517730713, 2.5899994373321533, 2.729999542236328, + 2.749999523162842, 2.8399994373321533 + ], + "rope_type": "longrope" + }, + ), } default_models = { @@ -66,5 +111,6 @@ def from_name(cls, name: str): "llama-2": "meta-llama/Llama-2-7b-chat-hf", "phi-3": 
"microsoft/Phi-3-mini-4k-instruct", "phi-2": "microsoft/phi-2", - "mistral": "mistralai/Mistral-7B-Instruct-v0.1" + "mistral": "mistralai/Mistral-7B-Instruct-v0.1", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct" } diff --git a/PyTorch/llm/models/layers.py b/PyTorch/llm/models/layers.py index 67114b9b..9c207648 100644 --- a/PyTorch/llm/models/layers.py +++ b/PyTorch/llm/models/layers.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import math from typing import Optional, Union, Tuple, Dict import torch_directml @@ -32,8 +33,10 @@ def __init__( base: Union[int, float] = 10000, dtype: torch.dtype = torch.float16, device: Optional[torch.device] = None, + config: ModelArgs = None, ): super().__init__() + self.config = config self.dtype = dtype device = device if device is not None else torch_directml.device(torch_directml.default_device()) @@ -42,6 +45,25 @@ def __init__( self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + if self.config.rope_scaling and self.config.rope_scaling["rope_type"] == "llama3": + self.attention_scaling = 1.0 + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) self.register_buffer("inv_freq", inv_freq, persistent=False) self._set_cos_sin_cache( @@ -57,16 +79,73 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device) -> None: self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False) - def forward(self) -> Tuple[torch.Tensor, torch.Tensor] : + def forward(self, seq_len) -> Tuple[torch.Tensor, torch.Tensor] : return ( self.cos_cached, self.sin_cached ) +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=131072, base=10000, dtype=torch.float16, device=None, config=None): + super().__init__() + self.device = device if device is not None else torch_directml.device(torch_directml.default_device()) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.short_factor = torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32, device=self.device) + self.long_factor = torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32, device=self.device) + self.original_max_position_embeddings = config.original_max_position_embeddings + self.dtype = dtype 
+ scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + self.scaling_factor = 1.0 + else: + self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + # Precompute cos and sin for short and long factors + self._set_precomputed_caches() + + def _set_precomputed_caches(self): + # Compute inv_freq for short and long factors + inv_freq_short = self._compute_inv_freq(self.short_factor) + inv_freq_long = self._compute_inv_freq(self.long_factor) + + # Precompute the cos and sin caches for short and long factors + self.cos_cache_short, self.sin_cache_short = self._compute_cos_sin_cache(inv_freq_short) + self.cos_cache_long, self.sin_cache_long = self._compute_cos_sin_cache(inv_freq_long) + + def _compute_inv_freq(self, factor): + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.float32, device=self.device) / self.dim + inv_freq = 1.0 / (factor * self.base**inv_freq_shape) + return inv_freq + + def _compute_cos_sin_cache(self, inv_freq): + t = torch.arange(self.max_position_embeddings, device=self.device, dtype=self.dtype) + + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_cache = emb.cos().to(self.dtype) + sin_cache = emb.sin().to(self.dtype) + + return cos_cache.unsqueeze(0).unsqueeze(0), sin_cache.unsqueeze(0).unsqueeze(0) + + @torch.no_grad() + def forward(self, seq_len): + if seq_len > self.original_max_position_embeddings: + cos_cached = self.cos_cache_long + sin_cached = self.sin_cache_long + else: + cos_cached = self.cos_cache_short + sin_cached = self.sin_cache_short + + cos = cos_cached * self.scaling_factor + sin = sin_cached * self.scaling_factor + return (cos, sin) + class LlamaAttention(nn.Module): def __init__(self, config: ModelArgs): super().__init__() assert config.dim % config.n_head == 0 + self.config = config total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) @@ -77,7 +156,6 @@ def __init__(self, config: ModelArgs): self.head_dim = config.head_dim self.n_local_heads = config.n_local_heads self.dim = config.dim - self._register_load_state_dict_pre_hook(self.load_hook) def load_hook(self, state_dict: Dict[str, torch.Tensor], prefix: str, *argspy): @@ -91,9 +169,14 @@ def _init_rope( self, max_position_embeddings: int = 4096, rope_base: Union[int, self.min_position = 0 self.past_key_tensor = None self.past_value_tensor = None - self.rotary_emb = RotaryEmbedding( - self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype - ) + if self.config.rope_scaling and self.config.rope_scaling["rope_type"] == "longrope": + self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype, config=self.config + ) + else: + self.rotary_emb = RotaryEmbedding( + self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype, config=self.config + ) def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: bsz, seqlen, _ = x.shape @@ -105,7 +188,7 @@ def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) - k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2) v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb() + cos, sin = self.rotary_emb(self.min_position + seqlen) q, k = 
torch_directml.apply_rotary_position_emb( q, k, cos, sin, self.min_position, seqlen, self.head_dim) diff --git a/PyTorch/llm/models/phi2.py b/PyTorch/llm/models/phi2.py index 2a87784b..6901a4e1 100644 --- a/PyTorch/llm/models/phi2.py +++ b/PyTorch/llm/models/phi2.py @@ -40,6 +40,7 @@ class Attention(nn.Module): def __init__(self, config: ModelArgs): super().__init__() assert config.dim % config.n_head == 0 + self.config = config self.wqkv = nn.Linear(config.dim, 3 * config.dim) self.wo = nn.Linear(config.dim, config.dim) @@ -73,7 +74,8 @@ def _init_rope( int(self.head_dim * self.partial_rotary_factor), max_position_embeddings=max_position_embeddings, base=rope_base, - dtype=dtype + dtype=dtype, + config=self.config, ) def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: @@ -84,7 +86,7 @@ def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) - q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1,2) k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1,2) - cos, sin = self.rotary_emb() + cos, sin = self.rotary_emb(self.min_position + seqlen) q, k = torch_directml.apply_rotary_position_emb( q, k, cos, sin, self.min_position, seqlen, self.rotary_emb.dim) diff --git a/PyTorch/llm/scripts/download_and_convert.py b/PyTorch/llm/scripts/download_and_convert.py index 0dd67e76..51ab07b5 100644 --- a/PyTorch/llm/scripts/download_and_convert.py +++ b/PyTorch/llm/scripts/download_and_convert.py @@ -79,7 +79,10 @@ def convert_hf_checkpoint( weight_maps = json.load(file) model_name = checkpoint_dir.name - model_name = "llama" if "phi" not in model_name.lower() else model_name + if "phi-3" in model_name.lower(): + model_name = "Phi-3-mini-4k-instruct" + elif "phi" not in model_name.lower(): + model_name = "llama" weight_map = weight_maps[model_name] # Load the json file containing weight mapping diff --git a/PyTorch/llm/utils.py b/PyTorch/llm/utils.py index c875b944..6dad028a 100644 --- a/PyTorch/llm/utils.py +++ b/PyTorch/llm/utils.py @@ -74,13 +74,13 @@ def decode_with_overlap(tokenizer: PreTrainedTokenizerFast, tokens: List[Tensor] text_output = current_decoded return text_output -def _load_model(checkpoint_path: str, device: torch.device, precision: torch.dtype) -> torch.nn.Module: +def _load_model(checkpoint_path: str, device: torch.device, precision: torch.dtype, max_pos_emb=8192) -> torch.nn.Module: model_name = checkpoint_path.parent.name with torch.device('meta'): if 'phi-2' in model_name.lower(): model = Phi2Transformer.from_name(model_name) elif 'phi-3' in model_name.lower(): - model = Phi3Transformer.from_name(model_name) + model = Phi3Transformer.from_name(model_name, max_pos_emb=max_pos_emb) else: model = LlamaTransformer.from_name(model_name)
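A quick way to exercise the options added by this patch is to mirror the README's existing `--hf_model` examples. The commands below are illustrative only: `--max_pos_emb` takes effect solely for Phi-3.5 (it feeds the LongRoPE scaling in `layers.py`), and `app.py` clamps it to 131072.

```
> python app.py --hf_model "microsoft/Phi-3.5-mini-instruct" --max_pos_emb 131072
> python app.py --hf_model "meta-llama/Meta-Llama-3.1-8B-Instruct"
```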