From 5f776a93fe5747f4f09d959782d843c8107fe15c Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 15 Aug 2024 10:11:29 +0200 Subject: [PATCH 01/17] Add functional transformer demo --- ...nder_friendly_functional_transformer.ipynb | 292 ++++++++++++++++++ 1 file changed, 292 insertions(+) create mode 100644 notebooks/thunder_friendly_functional_transformer.ipynb diff --git a/notebooks/thunder_friendly_functional_transformer.ipynb b/notebooks/thunder_friendly_functional_transformer.ipynb new file mode 100644 index 0000000000..56c0898c13 --- /dev/null +++ b/notebooks/thunder_friendly_functional_transformer.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we'll explore an example of a \"functional\" transformer that can be easily executed using Thunder.\n", + "\n", + "Thanks to the rich set of utilities provided by Python, we have a very flexible and convenient way to integrate Python with PyTorch to build transformer architectures like the one used by Hugging Face. However, this flexibility also poses challenges for deep learning compilers, which may have difficulty tracing all the Python utilities involved.\n", + "\n", + "On the other hand, using a \"functional\" Python function - one that is free of side effects - can greatly simplify this process.\n", + "\n", + "In this notebook, we will demonstrate an example of a \"functional\" transformer. This will give us some insight into how to convert a PyTorch module into a simple \"functional\" Python function, allowing for seamless integration with Thunder.\n", + "\n", + "**NOTE**: The code used is from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To convert the PyTorch module into a \"functional\" Python function, we need to pass the module's parameters as inputs, rather than relying on data members. Here are some helper classes that group the parameters together for cleaner and more organized notation." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from typing import List, NamedTuple\n", + "\n", + "# Helper classes that group the parameters together\n", + "class LayerWeights(NamedTuple):\n", + " input_norm: torch.Tensor # (hidden)\n", + " post_attn_norm: torch.Tensor # (hidden)\n", + " q_proj: torch.Tensor # (hidden, q_intermediate)\n", + " k_proj: torch.Tensor # (hidden, kv_intermediate)\n", + " v_proj: torch.Tensor # (hidden, kv_intermediate)\n", + " o_proj: torch.Tensor # (q_intermediate, hidden)\n", + " gate_proj: torch.Tensor # (hidden, intermediate)\n", + " up_proj: torch.Tensor # (hidden, intermediate)\n", + " down_proj: torch.Tensor # (intermediate, hidden)\n", + "\n", + "\n", + "class TransformerWeights(NamedTuple):\n", + " layers: List[LayerWeights]\n", + " token_emb: torch.Tensor # (vocab_size, hidden)\n", + " final_norm: torch.Tensor # (hidden)\n", + " lm_head: torch.Tensor # (hidden, vocab_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we will rewrite all the layers used in the transformer as functional functions. This involves converting each layer into a function that takes both the input data and the relevant parameters as arguments." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "NUM_Q_HEADS = 32 # Llama numbers\n", + "NUM_KV_HEADS = 8 # Llama numbers\n", + "SLIDING_WINDOW_SIZE = 4096\n", + "\n", + "def norm(x: torch.Tensor, weight: torch.Tensor):\n", + " in_dtype = x.dtype\n", + " x = x.float()\n", + " out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) # eps might change depending on the model\n", + " return weight * out.to(in_dtype)\n", + "\n", + "\n", + "def ffn(x: torch.Tensor, weights: LayerWeights):\n", + " gate = F.silu(x @ weights.gate_proj)\n", + " fused = gate * (x @ weights.up_proj)\n", + " return fused @ weights.down_proj\n", + "\n", + "\n", + "def rope(x: torch.Tensor, freqs_cis: torch.Tensor):\n", + " def rotate(x):\n", + " \"\"\"\n", + " rotate_half(torch.arange(4))\n", + " > tensor([-2, -3, 0, 1])\n", + " \"\"\"\n", + " x1 = x[..., : x.shape[-1] // 2]\n", + " x2 = x[..., x.shape[-1] // 2 :]\n", + " return torch.cat((-x2, x1), dim=-1)\n", + "\n", + " cos, sin = freqs_cis\n", + " cos, sin = cos.type_as(x), sin.type_as(x)\n", + " right = rotate(x.reshape(*x.shape[:-1], -1, 2)).reshape(x.shape)\n", + " out = x * cos + right * sin\n", + " return out.to(x.dtype)\n", + "\n", + "\n", + "def attn(\n", + " x: torch.Tensor,\n", + " weights: LayerWeights,\n", + " freqs_cis: tuple,\n", + " sliding_window_size=None,\n", + "):\n", + " bs, seq_len, d_model = x.shape\n", + " xq, xk, xv = x @ weights.q_proj, x @ weights.k_proj, x @ weights.v_proj\n", + " xq = xq.view(bs, seq_len, NUM_Q_HEADS, -1).transpose(1, 2) # (bs, NUM_Q_HEADS, seq_len, q_intermediate)\n", + " xk = xk.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2) # (bs, NUM_KV_HEADS, seq_len, kv_intermediate)\n", + " xv = xv.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2) # (bs, NUM_KV_HEADS, seq_len, kv_intermediate)\n", + " head_dim = xq.shape[-1]\n", + "\n", + " # Treat GQA as MHA and just repeat along the head dimension\n", + " xk = torch.repeat_interleave(xk, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)\n", + " xv = torch.repeat_interleave(xv, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)\n", + " xq = rope(xq, freqs_cis)\n", + " xk = rope(xk, freqs_cis)\n", + "\n", + " attn_scores = (xq @ xk.transpose(2, 3)) * (head_dim**-0.5)\n", + " mask = torch.triu(torch.full((bs, seq_len, seq_len), -2.3819763e38), diagonal=1) # This number is taken from Gemma\n", + " if sliding_window_size is not None: # Sliding window attention\n", + " all_ones = torch.ones((seq_len, seq_len))\n", + " sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) * torch.tril(all_ones, sliding_window_size - 1)\n", + " mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)\n", + " mask = mask.to(x.device, x.dtype)\n", + " attn_scores = attn_scores + mask\n", + " attn_probs = F.softmax(attn_scores, dim=-1)\n", + " attn_out = attn_probs @ xv\n", + " attn_out = attn_out.transpose(1, 2).contiguous().view(bs, seq_len, -1)\n", + " return attn_out @ weights.o_proj\n", + "\n", + "\n", + "# for efficiency, should precompute for 0..max_length * 2 then select [:curr_length]\n", + "def precompute_freqs_cis(head_dim: int, seq_len: int, base_theta: float = 500000.0):\n", + " inv_freqs = 1.0 / (base_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) # Eq 15: theta_{1} ... theta_{dim/2}. Shape: (dim/2)\n", + " m = torch.arange(seq_len) # all possible position indices\n", + " freqs = torch.outer(m, inv_freqs).float() # [m_i * theta_j] for all i (positions) and j (frequencies). Shape: (seq_len, dim/2) | freqs[i][j] == m[i] * inv_freqs[j]\n", + " cos = torch.cos(freqs) # Shape: (seq_len, dim/2)\n", + " cos = torch.repeat_interleave(cos, 2, dim=-1) # Shape: (seq_len, dim)\n", + " sin = torch.sin(freqs) # Shape: (seq_len, dim/2)\n", + " sin = torch.repeat_interleave(sin, 2, dim=-1) # Shape: (seq_len, dim)\n", + " return (cos, sin)\n", + "\n", + "\n", + "def transformer(in_tokens: torch.Tensor, weights: TransformerWeights):\n", + " x = torch.nn.functional.embedding(in_tokens, weights.token_emb)\n", + " b, t, d = x.shape\n", + " q_intermediate = weights.layers[0].q_proj.shape[1]\n", + " freqs_cis = precompute_freqs_cis(q_intermediate // NUM_Q_HEADS, t) # (cos, sin)\n", + " for i, layer in enumerate(weights.layers):\n", + " residual = x\n", + " hidden = norm(x, layer.input_norm)\n", + " hidden = attn(hidden, layer, freqs_cis, sliding_window_size=SLIDING_WINDOW_SIZE if i % 6 != 0 else None) # Follows https://research.character.ai/optimizing-inference/\n", + " hidden = residual + hidden\n", + "\n", + " residual = hidden\n", + " hidden = norm(hidden, layer.post_attn_norm)\n", + " hidden = ffn(hidden, layer)\n", + " hidden = residual + hidden\n", + " x = hidden\n", + "\n", + " x = norm(x, weights.final_norm)\n", + " x = x @ weights.lm_head\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the functional versions of the layers and transformer are ready, we'll load the weights for each layer into our `LayerWeights` and `TransformerWeights` container classes. These classes will store the parameters so that they can be easily passed as inputs to the functional transformer.\n", + "\n", + "**NOTE**: To run the cells below, you'll need access to the Hugging Face Meta-Llama-3-8B model. Be sure to download the model weights and place them in the \"Meta-Llama-3-8B/consolidated.00.pth\". See [here](https://huggingface.co/meta-llama/Meta-Llama-3-8B) to learn more about Hugging Face Meta-Llama-3-8B." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "/tmp/ipykernel_2627731/4043548868.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " state_dict = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\", map_location=\"cuda\")\n" + ] + } + ], + "source": [ + "from transformers import AutoTokenizer\n", + "import thunder\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n", + "# Download the official repo weights\n", + "state_dict = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\", map_location=\"cuda\")\n", + "layers = []\n", + "n_layers = 32\n", + "for i in range(n_layers):\n", + " layer = LayerWeights(\n", + " input_norm=state_dict[f\"layers.{i}.attention_norm.weight\"],\n", + " post_attn_norm=state_dict[f\"layers.{i}.ffn_norm.weight\"],\n", + " q_proj=state_dict[f\"layers.{i}.attention.wq.weight\"].t(),\n", + " k_proj=state_dict[f\"layers.{i}.attention.wk.weight\"].t(),\n", + " v_proj=state_dict[f\"layers.{i}.attention.wv.weight\"].t(),\n", + " o_proj=state_dict[f\"layers.{i}.attention.wo.weight\"].t(),\n", + " gate_proj=state_dict[f\"layers.{i}.feed_forward.w1.weight\"].t(),\n", + " up_proj=state_dict[f\"layers.{i}.feed_forward.w3.weight\"].t(),\n", + " down_proj=state_dict[f\"layers.{i}.feed_forward.w2.weight\"].t(),\n", + " )\n", + " layers.append(layer)\n", + "\n", + "weights = TransformerWeights(\n", + " layers=layers,\n", + " token_emb=state_dict[\"tok_embeddings.weight\"],\n", + " final_norm=state_dict[\"norm.weight\"],\n", + " lm_head=state_dict[\"output.weight\"].t(),\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we’ll use Thunder to execute the \"functional\" transformer and observe the results. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ours: <|begin_of_text|>the answer to the ultimate question of life 42\n", + "the answer to the ultimate question of life\n" + ] + } + ], + "source": [ + "prompt = \"the answer to the ultimate question of life \"\n", + "in_tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(\"cuda\")\n", + "\n", + "# Use thunder on the \"functional\" transformer\n", + "jitted_transformer = thunder.jit(transformer)\n", + "for _ in range(10):\n", + " out = jitted_transformer(in_tokens, weights)\n", + " next_token = torch.argmax(out[:, -1, :])\n", + " in_tokens = torch.cat((in_tokens, next_token.unsqueeze(0).unsqueeze(0)), dim=1)\n", + "\n", + "del weights\n", + "del state_dict\n", + "\n", + "print(\"Ours:\", tokenizer.decode(in_tokens[0].tolist()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And that's it! Converting a PyTorch module into a functional Python function is easier than you might think. This simple modification allows you to take full advantage of Thunder's capabilities." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 5c83b41fb84cdfaeb4fd1e1aae2f02d287faa01f Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 26 Aug 2024 14:43:46 +0200 Subject: [PATCH 02/17] rephrase --- ...transformer.ipynb => functional_transformer_example.ipynb} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename notebooks/{thunder_friendly_functional_transformer.ipynb => functional_transformer_example.ipynb} (97%) diff --git a/notebooks/thunder_friendly_functional_transformer.ipynb b/notebooks/functional_transformer_example.ipynb similarity index 97% rename from notebooks/thunder_friendly_functional_transformer.ipynb rename to notebooks/functional_transformer_example.ipynb index 56c0898c13..6364b9c2b8 100644 --- a/notebooks/thunder_friendly_functional_transformer.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -10,7 +10,7 @@ "\n", "On the other hand, using a \"functional\" Python function - one that is free of side effects - can greatly simplify this process.\n", "\n", - "In this notebook, we will demonstrate an example of a \"functional\" transformer. This will give us some insight into how to convert a PyTorch module into a simple \"functional\" Python function, allowing for seamless integration with Thunder.\n", + "In this notebook, we will demonstrate an example of a \"functional\" transformer. Additionally, we will explore how Thunder can be applied to this version of the transformer.\n", "\n", "**NOTE**: The code used is from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 " ] @@ -264,7 +264,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And that's it! Converting a PyTorch module into a functional Python function is easier than you might think. This simple modification allows you to take full advantage of Thunder's capabilities." + "And that's it! Converting a PyTorch module into a functional Python function is easier than you might think. Moreover, Thunder seamlessly operates with this functional version of the transformer." ] } ], From ade62463946cdf9a8ed45a94f409d6343c29225c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:48:42 +0300 Subject: [PATCH 03/17] Edit the intro text --- notebooks/functional_transformer_example.ipynb | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 6364b9c2b8..4583904c98 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -4,15 +4,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this notebook, we'll explore an example of a \"functional\" transformer that can be easily executed using Thunder.\n", + "In this notebook, we'll explore a \"functional\" transformer implementation that can be efficiently executed using Thunder. This approach offers several advantages in terms of optimization and code clarity.\n", "\n", - "Thanks to the rich set of utilities provided by Python, we have a very flexible and convenient way to integrate Python with PyTorch to build transformer architectures like the one used by Hugging Face. However, this flexibility also poses challenges for deep learning compilers, which may have difficulty tracing all the Python utilities involved.\n", + "Python's expressivity and flexibility make it an excellent choice for developing for deep learning models. However, these same qualities can sometimes hinder performance optimization and make it challenging to understand or modify existing codebases.\n", + "Recently, there has been a growing trend towards developing models in a simpler, more transparent style that facilitates optimization and comprehension. Projects like [LitGPT](https://github.com/Lightning-AI/litgpt) and [nanoGPT](https://github.com/karpathy/nanoGPT) are examples of this trend.\n", + "One such style is the \"functional\" programming style, which is free of side effects and can be easily understood and optimized by both developers and compilers.\n", "\n", - "On the other hand, using a \"functional\" Python function - one that is free of side effects - can greatly simplify this process.\n", + "We'll cover the following key points:\n", + "* The structure and implementation of a functional transformer\n", + "* Advantages of this approach compared to traditional implementations\n", + "* How Thunder can be applied to optimize and execute this functional transformer\n", "\n", - "In this notebook, we will demonstrate an example of a \"functional\" transformer. Additionally, we will explore how Thunder can be applied to this version of the transformer.\n", + "By the end of this notebook, you'll have a clear understanding of how functional programming principles can be leveraged to create more efficient and compiler-friendly transformer models.\n", "\n", - "**NOTE**: The code used is from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 " + "**Credit**: The code used in this notebook is adapted from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 " ] }, { From 499b229bc8c6a80cdcbf323bfe8cf03cbffa0201 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:50:00 +0300 Subject: [PATCH 04/17] Use list instead of List --- notebooks/functional_transformer_example.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 4583904c98..1f0370b62e 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -34,7 +34,7 @@ "outputs": [], "source": [ "import torch\n", - "from typing import List, NamedTuple\n", + "from typing import NamedTuple\n", "\n", "# Helper classes that group the parameters together\n", "class LayerWeights(NamedTuple):\n", @@ -50,7 +50,7 @@ "\n", "\n", "class TransformerWeights(NamedTuple):\n", - " layers: List[LayerWeights]\n", + " layers: list[LayerWeights]\n", " token_emb: torch.Tensor # (vocab_size, hidden)\n", " final_norm: torch.Tensor # (hidden)\n", " lm_head: torch.Tensor # (hidden, vocab_size)" From 808a787f16bb089e33c76f9217fccbc18f0c667b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:53:04 +0300 Subject: [PATCH 05/17] Edit named tuples intro --- notebooks/functional_transformer_example.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 1f0370b62e..825266aad1 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -24,7 +24,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To convert the PyTorch module into a \"functional\" Python function, we need to pass the module's parameters as inputs, rather than relying on data members. Here are some helper classes that group the parameters together for cleaner and more organized notation." + "To transform a PyTorch module into a \"functional\" Python implementation, we need to restructure how neural network parameters are handled. Instead of relying on class members and `nn.Module`, we'll pass the module's parameters explicitly as function inputs. This approach enhances transparency and makes the code more amenable to optimization.\n", + "To maintain clean and organized code, we'll use named tuples to group related parameters together.\n", + "Let's examine the helper named tuples we'll use to organize our transformer's parameters:" ] }, { From 8cf9271e4880528b02d3c4d73ec375ecb2173fc7 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:57:24 +0300 Subject: [PATCH 06/17] Edit intro to the functional implementation --- notebooks/functional_transformer_example.ipynb | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 825266aad1..912b574eef 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -62,7 +62,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we will rewrite all the layers used in the transformer as functional functions. This involves converting each layer into a function that takes both the input data and the relevant parameters as arguments." + "Next, we'll write all the layers used in the transformer using functional implementations. This process involves converting each layer from a PyTorch module into a pure function (a function with no side effects) that takes both the input data and the relevant parameters as explicit arguments. \n", + "In the following sections, we'll walk through the functional implementations of key transformer components, including:\n", + "* Layer normalization\n", + "* Feed-forward network\n", + "* Embedding layer\n", + "* Attention mechanism\n", + "\n", + "Each function will clearly define its inputs, including both the data to be processed and the necessary parameters. This approach will provide a comprehensive view of how data flows through the transformer architecture." ] }, { From dd0b3a75bf52976fe7bdfbb493fa27c132098925 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:58:42 +0300 Subject: [PATCH 07/17] .float() -> .to(float32) --- notebooks/functional_transformer_example.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 912b574eef..58ae8ce6af 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -86,7 +86,8 @@ "\n", "def norm(x: torch.Tensor, weight: torch.Tensor):\n", " in_dtype = x.dtype\n", - " x = x.float()\n", + " compute_dtype = torch.float32\n", + " x = x.to(compute_dtype)\n", " out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) # eps might change depending on the model\n", " return weight * out.to(in_dtype)\n", "\n", From c0a4e6d760a11db7e522f6b0a0bc0936c9dd76e9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 15:59:19 +0300 Subject: [PATCH 08/17] Define eps variable in norm function --- notebooks/functional_transformer_example.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 58ae8ce6af..17ad196324 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -88,7 +88,8 @@ " in_dtype = x.dtype\n", " compute_dtype = torch.float32\n", " x = x.to(compute_dtype)\n", - " out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) # eps might change depending on the model\n", + " eps = 1e-5 # eps might change depending on the model\n", + " out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)\n", " return weight * out.to(in_dtype)\n", "\n", "\n", From 8b2844924c726a31657a18370a7bdc03c69c23cc Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:00:36 +0300 Subject: [PATCH 09/17] Add header comments for functions in transformer --- notebooks/functional_transformer_example.ipynb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 17ad196324..01dd98ec5d 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -84,6 +84,7 @@ "NUM_KV_HEADS = 8 # Llama numbers\n", "SLIDING_WINDOW_SIZE = 4096\n", "\n", + "# Layer normalization\n", "def norm(x: torch.Tensor, weight: torch.Tensor):\n", " in_dtype = x.dtype\n", " compute_dtype = torch.float32\n", @@ -93,12 +94,14 @@ " return weight * out.to(in_dtype)\n", "\n", "\n", + "# Feed-forward network\n", "def ffn(x: torch.Tensor, weights: LayerWeights):\n", " gate = F.silu(x @ weights.gate_proj)\n", " fused = gate * (x @ weights.up_proj)\n", " return fused @ weights.down_proj\n", "\n", "\n", + "# Rotary Positional Encoding\n", "def rope(x: torch.Tensor, freqs_cis: torch.Tensor):\n", " def rotate(x):\n", " \"\"\"\n", @@ -116,6 +119,7 @@ " return out.to(x.dtype)\n", "\n", "\n", + "# Attention\n", "def attn(\n", " x: torch.Tensor,\n", " weights: LayerWeights,\n", From 47be7a4ceb2f596a1c9b901dfb112cab29c23a86 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:01:38 +0300 Subject: [PATCH 10/17] Import silu and softmax directly --- notebooks/functional_transformer_example.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 01dd98ec5d..559d28ac3f 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "import torch.nn.functional as F\n", + "from torch.nn.functional import silu, softmax\n", "\n", "NUM_Q_HEADS = 32 # Llama numbers\n", "NUM_KV_HEADS = 8 # Llama numbers\n", @@ -96,7 +96,7 @@ "\n", "# Feed-forward network\n", "def ffn(x: torch.Tensor, weights: LayerWeights):\n", - " gate = F.silu(x @ weights.gate_proj)\n", + " gate = silu(x @ weights.gate_proj)\n", " fused = gate * (x @ weights.up_proj)\n", " return fused @ weights.down_proj\n", "\n", @@ -147,7 +147,7 @@ " mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)\n", " mask = mask.to(x.device, x.dtype)\n", " attn_scores = attn_scores + mask\n", - " attn_probs = F.softmax(attn_scores, dim=-1)\n", + " attn_probs = softmax(attn_scores, dim=-1)\n", " attn_out = attn_probs @ xv\n", " attn_out = attn_out.transpose(1, 2).contiguous().view(bs, seq_len, -1)\n", " return attn_out @ weights.o_proj\n", From dcdbdd26d25d40fc9634f58cc4c8ef26cb4c21d0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:04:05 +0300 Subject: [PATCH 11/17] Use .to(x.dtype) instead of .type_as(x) --- notebooks/functional_transformer_example.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 559d28ac3f..3e2555d33c 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -113,7 +113,8 @@ " return torch.cat((-x2, x1), dim=-1)\n", "\n", " cos, sin = freqs_cis\n", - " cos, sin = cos.type_as(x), sin.type_as(x)\n", + " cos = cos.to(x.dtype)\n", + " sin = sin.to(x.dtype)\n", " right = rotate(x.reshape(*x.shape[:-1], -1, 2)).reshape(x.shape)\n", " out = x * cos + right * sin\n", " return out.to(x.dtype)\n", From 4330d3761f8a07202e1bdd4c7932977bc81019bf Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:10:27 +0300 Subject: [PATCH 12/17] Edit intro to weights loading --- notebooks/functional_transformer_example.ipynb | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 3e2555d33c..0a4327dd1c 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -192,7 +192,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Once the functional versions of the layers and transformer are ready, we'll load the weights for each layer into our `LayerWeights` and `TransformerWeights` container classes. These classes will store the parameters so that they can be easily passed as inputs to the functional transformer.\n", + "With the functional versions of the layers and transformer components in place, our next crucial step is to properly initialize and organize the model weights. We'll use the `LayerWeights` and `TransformerWeights` classes to store the weights for normalization layers and the entire transformer, respectively. This organization will make it easier to pass the weights as explicit arguments to the functional transformer while working with real-world, pre-trained models.\n", + "\n", + "In the following sections, we'll demonstrate:\n", + "* How to load weights from a pre-trained model\n", + "* The structure of our `LayerWeights` and `TransformerWeights` classes\n", + "* How these weight containers integrate with our functional transformer implementation\n", "\n", "**NOTE**: To run the cells below, you'll need access to the Hugging Face Meta-Llama-3-8B model. Be sure to download the model weights and place them in the \"Meta-Llama-3-8B/consolidated.00.pth\". See [here](https://huggingface.co/meta-llama/Meta-Llama-3-8B) to learn more about Hugging Face Meta-Llama-3-8B." ] From 79a2b7a0ac2c2911c0105663b4d7fc455098d654 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:11:12 +0300 Subject: [PATCH 13/17] Move import thunder to the cell where it's used --- notebooks/functional_transformer_example.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 0a4327dd1c..d1e4dd3b5f 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -219,7 +219,6 @@ ], "source": [ "from transformers import AutoTokenizer\n", - "import thunder\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n", "# Download the official repo weights\n", @@ -270,6 +269,8 @@ } ], "source": [ + "import thunder\n", + "\n", "prompt = \"the answer to the ultimate question of life \"\n", "in_tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(\"cuda\")\n", "\n", From ed7bdcf6ef3dfa5b66a587ce0f82a58ef4b7db6c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:12:05 +0300 Subject: [PATCH 14/17] Move tokenizer init to the cell where it's used --- notebooks/functional_transformer_example.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index d1e4dd3b5f..d89d3433ec 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -218,9 +218,6 @@ } ], "source": [ - "from transformers import AutoTokenizer\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n", "# Download the official repo weights\n", "state_dict = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\", map_location=\"cuda\")\n", "layers = []\n", @@ -270,6 +267,9 @@ ], "source": [ "import thunder\n", + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n", "\n", "prompt = \"the answer to the ultimate question of life \"\n", "in_tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(\"cuda\")\n", From 869d8b48278efc2aa0ff9ded4836233ebded3d97 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:16:38 +0300 Subject: [PATCH 15/17] Edit text before Thunder execution --- notebooks/functional_transformer_example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index d89d3433ec..b741f966e8 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -248,7 +248,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, we’ll use Thunder to execute the \"functional\" transformer and observe the results. " + "Finally, we’ll use Thunder to execute our implementation and see what is the answer for the input text \"the answer to the ultimate question of life\"." ] }, { From 11e4165e97a6f49670a4a852b3e1f11b361e9058 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Sep 2024 16:24:16 +0300 Subject: [PATCH 16/17] Edit the last paragraph --- notebooks/functional_transformer_example.ipynb | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index b741f966e8..806a10fce7 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -291,7 +291,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And that's it! Converting a PyTorch module into a functional Python function is easier than you might think. Moreover, Thunder seamlessly operates with this functional version of the transformer." + "Congratulations! You've successfully navigated the process of writing a PyTorch model in a functional Python style. This journey has demonstrated that such a conversion, while different from the traditional PyTorch usage, is more accessible than many might initially assume.\n", + "\n", + "Moving forward, consider how you might apply these functional programming techniques to your own projects:\n", + "* Could your existing models benefit from a functional rewrite?\n", + "* How might this approach impact your model's performance and maintainability?\n", + "* What other deep learning architectures could be reimagined through a functional lens?\n", + "By mastering these techniques, you're well-equipped to develop more efficient, understandable, and optimizable deep learning models. The combination of functional programming and tools like Thunder represents a powerful approach to tackling complex deep learning projects.\n", + "\n", + "Check out the [Thunder step-by-step guide](https://lightning-thunder.readthedocs.io/en/latest/basic/inspecting_traces.html) to learn more about how Thunder maps your PyTorch code to the underlying computational graph and optimizes it for execution. Notice how Thunder's initial trace of the model is very similar to the original functional implementation, there's one less layer of abstraction to worry about!" ] } ], From cb44f07fc2aeb17e4fb002f4bf962778a10f9181 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 09:49:07 +0300 Subject: [PATCH 17/17] Apply suggestions from code review * Update the type signature of `rope` * Update the docstring of `rotate` Co-authored-by: beverlylytle <57254617+beverlylytle@users.noreply.github.com> --- notebooks/functional_transformer_example.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/functional_transformer_example.ipynb b/notebooks/functional_transformer_example.ipynb index 806a10fce7..2c919bf488 100644 --- a/notebooks/functional_transformer_example.ipynb +++ b/notebooks/functional_transformer_example.ipynb @@ -102,10 +102,10 @@ "\n", "\n", "# Rotary Positional Encoding\n", - "def rope(x: torch.Tensor, freqs_cis: torch.Tensor):\n", + "def rope(x: torch.Tensor, freqs_cis: tuple):\n", " def rotate(x):\n", " \"\"\"\n", - " rotate_half(torch.arange(4))\n", + " rotate(torch.arange(4))\n", " > tensor([-2, -3, 0, 1])\n", " \"\"\"\n", " x1 = x[..., : x.shape[-1] // 2]\n",