Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functional transformer demo #971

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 327 additions & 0 deletions notebooks/functional_transformer_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we'll explore a \"functional\" transformer implementation that can be efficiently executed using Thunder. This approach offers several advantages in terms of optimization and code clarity.\n",
"\n",
"Python's expressivity and flexibility make it an excellent choice for developing for deep learning models. However, these same qualities can sometimes hinder performance optimization and make it challenging to understand or modify existing codebases.\n",
"Recently, there has been a growing trend towards developing models in a simpler, more transparent style that facilitates optimization and comprehension. Projects like [LitGPT](https://github.com/Lightning-AI/litgpt) and [nanoGPT](https://github.com/karpathy/nanoGPT) are examples of this trend.\n",
"One such style is the \"functional\" programming style, which is free of side effects and can be easily understood and optimized by both developers and compilers.\n",
"\n",
"We'll cover the following key points:\n",
"* The structure and implementation of a functional transformer\n",
"* Advantages of this approach compared to traditional implementations\n",
"* How Thunder can be applied to optimize and execute this functional transformer\n",
"\n",
"By the end of this notebook, you'll have a clear understanding of how functional programming principles can be leveraged to create more efficient and compiler-friendly transformer models.\n",
"\n",
"**Credit**: The code used in this notebook is adapted from https://gist.github.com/nreHieW/a4ae05d216c5326c9fb9a70fcdda3274 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To transform a PyTorch module into a \"functional\" Python implementation, we need to restructure how neural network parameters are handled. Instead of relying on class members and `nn.Module`, we'll pass the module's parameters explicitly as function inputs. This approach enhances transparency and makes the code more amenable to optimization.\n",
"To maintain clean and organized code, we'll use named tuples to group related parameters together.\n",
"Let's examine the helper named tuples we'll use to organize our transformer's parameters:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from typing import NamedTuple\n",
"\n",
"# Helper classes that group the parameters together\n",
"class LayerWeights(NamedTuple):\n",
" input_norm: torch.Tensor # (hidden)\n",
" post_attn_norm: torch.Tensor # (hidden)\n",
" q_proj: torch.Tensor # (hidden, q_intermediate)\n",
" k_proj: torch.Tensor # (hidden, kv_intermediate)\n",
" v_proj: torch.Tensor # (hidden, kv_intermediate)\n",
" o_proj: torch.Tensor # (q_intermediate, hidden)\n",
" gate_proj: torch.Tensor # (hidden, intermediate)\n",
" up_proj: torch.Tensor # (hidden, intermediate)\n",
" down_proj: torch.Tensor # (intermediate, hidden)\n",
"\n",
"\n",
"class TransformerWeights(NamedTuple):\n",
" layers: list[LayerWeights]\n",
" token_emb: torch.Tensor # (vocab_size, hidden)\n",
" final_norm: torch.Tensor # (hidden)\n",
" lm_head: torch.Tensor # (hidden, vocab_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll write all the layers used in the transformer using functional implementations. This process involves converting each layer from a PyTorch module into a pure function (a function with no side effects) that takes both the input data and the relevant parameters as explicit arguments. \n",
"In the following sections, we'll walk through the functional implementations of key transformer components, including:\n",
"* Layer normalization\n",
"* Feed-forward network\n",
"* Embedding layer\n",
"* Attention mechanism\n",
"\n",
"Each function will clearly define its inputs, including both the data to be processed and the necessary parameters. This approach will provide a comprehensive view of how data flows through the transformer architecture."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from torch.nn.functional import silu, softmax\n",
"\n",
"NUM_Q_HEADS = 32 # Llama numbers\n",
"NUM_KV_HEADS = 8 # Llama numbers\n",
"SLIDING_WINDOW_SIZE = 4096\n",
"\n",
"# Layer normalization\n",
"def norm(x: torch.Tensor, weight: torch.Tensor):\n",
" in_dtype = x.dtype\n",
" compute_dtype = torch.float32\n",
" x = x.to(compute_dtype)\n",
" eps = 1e-5 # eps might change depending on the model\n",
" out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)\n",
" return weight * out.to(in_dtype)\n",
"\n",
"\n",
"# Feed-forward network\n",
"def ffn(x: torch.Tensor, weights: LayerWeights):\n",
" gate = silu(x @ weights.gate_proj)\n",
" fused = gate * (x @ weights.up_proj)\n",
" return fused @ weights.down_proj\n",
"\n",
"\n",
"# Rotary Positional Encoding\n",
"def rope(x: torch.Tensor, freqs_cis: tuple):\n",
" def rotate(x):\n",
" \"\"\"\n",
" rotate(torch.arange(4))\n",
" > tensor([-2, -3, 0, 1])\n",
" \"\"\"\n",
" x1 = x[..., : x.shape[-1] // 2]\n",
" x2 = x[..., x.shape[-1] // 2 :]\n",
" return torch.cat((-x2, x1), dim=-1)\n",
"\n",
" cos, sin = freqs_cis\n",
" cos = cos.to(x.dtype)\n",
" sin = sin.to(x.dtype)\n",
" right = rotate(x.reshape(*x.shape[:-1], -1, 2)).reshape(x.shape)\n",
" out = x * cos + right * sin\n",
" return out.to(x.dtype)\n",
"\n",
"\n",
"# Attention\n",
"def attn(\n",
" x: torch.Tensor,\n",
" weights: LayerWeights,\n",
" freqs_cis: tuple,\n",
" sliding_window_size=None,\n",
"):\n",
" bs, seq_len, d_model = x.shape\n",
" xq, xk, xv = x @ weights.q_proj, x @ weights.k_proj, x @ weights.v_proj\n",
" xq = xq.view(bs, seq_len, NUM_Q_HEADS, -1).transpose(1, 2) # (bs, NUM_Q_HEADS, seq_len, q_intermediate)\n",
" xk = xk.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2) # (bs, NUM_KV_HEADS, seq_len, kv_intermediate)\n",
" xv = xv.view(bs, seq_len, NUM_KV_HEADS, -1).transpose(1, 2) # (bs, NUM_KV_HEADS, seq_len, kv_intermediate)\n",
" head_dim = xq.shape[-1]\n",
"\n",
" # Treat GQA as MHA and just repeat along the head dimension\n",
" xk = torch.repeat_interleave(xk, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)\n",
" xv = torch.repeat_interleave(xv, NUM_Q_HEADS // NUM_KV_HEADS, dim=1)\n",
" xq = rope(xq, freqs_cis)\n",
" xk = rope(xk, freqs_cis)\n",
"\n",
" attn_scores = (xq @ xk.transpose(2, 3)) * (head_dim**-0.5)\n",
" mask = torch.triu(torch.full((bs, seq_len, seq_len), -2.3819763e38), diagonal=1) # This number is taken from Gemma\n",
" if sliding_window_size is not None: # Sliding window attention\n",
" all_ones = torch.ones((seq_len, seq_len))\n",
" sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) * torch.tril(all_ones, sliding_window_size - 1)\n",
" mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)\n",
" mask = mask.to(x.device, x.dtype)\n",
" attn_scores = attn_scores + mask\n",
" attn_probs = softmax(attn_scores, dim=-1)\n",
" attn_out = attn_probs @ xv\n",
" attn_out = attn_out.transpose(1, 2).contiguous().view(bs, seq_len, -1)\n",
" return attn_out @ weights.o_proj\n",
"\n",
"\n",
"# for efficiency, should precompute for 0..max_length * 2 then select [:curr_length]\n",
"def precompute_freqs_cis(head_dim: int, seq_len: int, base_theta: float = 500000.0):\n",
" inv_freqs = 1.0 / (base_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) # Eq 15: theta_{1} ... theta_{dim/2}. Shape: (dim/2)\n",
" m = torch.arange(seq_len) # all possible position indices\n",
" freqs = torch.outer(m, inv_freqs).float() # [m_i * theta_j] for all i (positions) and j (frequencies). Shape: (seq_len, dim/2) | freqs[i][j] == m[i] * inv_freqs[j]\n",
" cos = torch.cos(freqs) # Shape: (seq_len, dim/2)\n",
" cos = torch.repeat_interleave(cos, 2, dim=-1) # Shape: (seq_len, dim)\n",
" sin = torch.sin(freqs) # Shape: (seq_len, dim/2)\n",
" sin = torch.repeat_interleave(sin, 2, dim=-1) # Shape: (seq_len, dim)\n",
" return (cos, sin)\n",
"\n",
"\n",
"def transformer(in_tokens: torch.Tensor, weights: TransformerWeights):\n",
" x = torch.nn.functional.embedding(in_tokens, weights.token_emb)\n",
" b, t, d = x.shape\n",
" q_intermediate = weights.layers[0].q_proj.shape[1]\n",
" freqs_cis = precompute_freqs_cis(q_intermediate // NUM_Q_HEADS, t) # (cos, sin)\n",
" for i, layer in enumerate(weights.layers):\n",
" residual = x\n",
" hidden = norm(x, layer.input_norm)\n",
" hidden = attn(hidden, layer, freqs_cis, sliding_window_size=SLIDING_WINDOW_SIZE if i % 6 != 0 else None) # Follows https://research.character.ai/optimizing-inference/\n",
" hidden = residual + hidden\n",
"\n",
" residual = hidden\n",
" hidden = norm(hidden, layer.post_attn_norm)\n",
" hidden = ffn(hidden, layer)\n",
" hidden = residual + hidden\n",
" x = hidden\n",
"\n",
" x = norm(x, weights.final_norm)\n",
" x = x @ weights.lm_head\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the functional versions of the layers and transformer components in place, our next crucial step is to properly initialize and organize the model weights. We'll use the `LayerWeights` and `TransformerWeights` classes to store the weights for normalization layers and the entire transformer, respectively. This organization will make it easier to pass the weights as explicit arguments to the functional transformer while working with real-world, pre-trained models.\n",
"\n",
"In the following sections, we'll demonstrate:\n",
"* How to load weights from a pre-trained model\n",
"* The structure of our `LayerWeights` and `TransformerWeights` classes\n",
"* How these weight containers integrate with our functional transformer implementation\n",
"\n",
"**NOTE**: To run the cells below, you'll need access to the Hugging Face Meta-Llama-3-8B model. Be sure to download the model weights and place them in the \"Meta-Llama-3-8B/consolidated.00.pth\". See [here](https://huggingface.co/meta-llama/Meta-Llama-3-8B) to learn more about Hugging Face Meta-Llama-3-8B."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"/tmp/ipykernel_2627731/4043548868.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\", map_location=\"cuda\")\n"
]
}
],
"source": [
"# Download the official repo weights\n",
"state_dict = torch.load(\"Meta-Llama-3-8B/consolidated.00.pth\", map_location=\"cuda\")\n",
"layers = []\n",
"n_layers = 32\n",
"for i in range(n_layers):\n",
" layer = LayerWeights(\n",
" input_norm=state_dict[f\"layers.{i}.attention_norm.weight\"],\n",
" post_attn_norm=state_dict[f\"layers.{i}.ffn_norm.weight\"],\n",
" q_proj=state_dict[f\"layers.{i}.attention.wq.weight\"].t(),\n",
" k_proj=state_dict[f\"layers.{i}.attention.wk.weight\"].t(),\n",
" v_proj=state_dict[f\"layers.{i}.attention.wv.weight\"].t(),\n",
" o_proj=state_dict[f\"layers.{i}.attention.wo.weight\"].t(),\n",
" gate_proj=state_dict[f\"layers.{i}.feed_forward.w1.weight\"].t(),\n",
" up_proj=state_dict[f\"layers.{i}.feed_forward.w3.weight\"].t(),\n",
" down_proj=state_dict[f\"layers.{i}.feed_forward.w2.weight\"].t(),\n",
" )\n",
" layers.append(layer)\n",
"\n",
"weights = TransformerWeights(\n",
" layers=layers,\n",
" token_emb=state_dict[\"tok_embeddings.weight\"],\n",
" final_norm=state_dict[\"norm.weight\"],\n",
" lm_head=state_dict[\"output.weight\"].t(),\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we’ll use Thunder to execute our implementation and see what is the answer for the input text \"the answer to the ultimate question of life\"."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ours: <|begin_of_text|>the answer to the ultimate question of life 42\n",
"the answer to the ultimate question of life\n"
]
}
],
"source": [
"import thunder\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"\n",
"prompt = \"the answer to the ultimate question of life \"\n",
"in_tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"].to(\"cuda\")\n",
"\n",
"# Use thunder on the \"functional\" transformer\n",
"jitted_transformer = thunder.jit(transformer)\n",
"for _ in range(10):\n",
" out = jitted_transformer(in_tokens, weights)\n",
" next_token = torch.argmax(out[:, -1, :])\n",
" in_tokens = torch.cat((in_tokens, next_token.unsqueeze(0).unsqueeze(0)), dim=1)\n",
"\n",
"del weights\n",
"del state_dict\n",
"\n",
"print(\"Ours:\", tokenizer.decode(in_tokens[0].tolist()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Congratulations! You've successfully navigated the process of writing a PyTorch model in a functional Python style. This journey has demonstrated that such a conversion, while different from the traditional PyTorch usage, is more accessible than many might initially assume.\n",
"\n",
"Moving forward, consider how you might apply these functional programming techniques to your own projects:\n",
"* Could your existing models benefit from a functional rewrite?\n",
"* How might this approach impact your model's performance and maintainability?\n",
"* What other deep learning architectures could be reimagined through a functional lens?\n",
"By mastering these techniques, you're well-equipped to develop more efficient, understandable, and optimizable deep learning models. The combination of functional programming and tools like Thunder represents a powerful approach to tackling complex deep learning projects.\n",
"\n",
"Check out the [Thunder step-by-step guide](https://lightning-thunder.readthedocs.io/en/latest/basic/inspecting_traces.html) to learn more about how Thunder maps your PyTorch code to the underlying computational graph and optimizes it for execution. Notice how Thunder's initial trace of the model is very similar to the original functional implementation, there's one less layer of abstraction to worry about!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading