diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml new file mode 100644 index 00000000..b8e44b74 --- /dev/null +++ b/.github/workflows/ci-tk.yaml @@ -0,0 +1,74 @@ +name: "TK CI" + +on: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test: + name: "Unit Tests and Type Checking" + strategy: + fail-fast: false + matrix: + version: [3.11] + os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + runs-on: ${{matrix.os}} + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@v3 + + - name: Cache Pip Packages + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-cache-dir -r iree-requirements-ci.txt + pip install -r requirements.txt -e . + + - name: Run unit tests + if: ${{ !cancelled() }} + run: | + pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + + - name: Run e2e tests on MI300 + if: "contains(matrix.os, 'mi300') && !cancelled()" + run: | + export WAVE_RUN_E2E_TESTS=1 + pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + + - name: Run LIT tests + if: ${{ !cancelled() }} + run: | + lit lit_tests/ -v + + - name: MyPy Type Checking + if: ${{ !cancelled() }} + run: | + mypy diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8f938a0e..bfb1175e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -18,9 +18,10 @@ jobs: test: name: "Unit Tests and Type Checking" strategy: + fail-fast: false matrix: version: [3.11] - os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + os: [ubuntu-latest] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -54,14 +55,7 @@ jobs: - name: Run unit tests if: ${{ !cancelled() }} run: | - pytest -n 4 . - - - name: Run e2e tests on MI300 - if: "contains(matrix.os, 'mi300') && !cancelled()" - run: | - export WAVE_RUN_E2E_TESTS=1 - export TEST_PARAMS_PATH=./tests/kernel/wave/test_param.json - pytest -n 4 ./tests/kernel/wave/ + pytest -n 4 --capture=tee-sys -vv . - name: Run LIT tests if: ${{ !cancelled() }} diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf.yaml index 87adb13d..1b8d8271 100644 --- a/.github/workflows/perf.yaml +++ b/.github/workflows/perf.yaml @@ -21,9 +21,10 @@ jobs: test: name: "Unit Tests and Type Checking" strategy: + fail-fast: false matrix: version: [3.11] - os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + os: [nodai-amdgpu-mi300-x86-64] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -53,21 +54,10 @@ jobs: pip install --no-compile -r pytorch-cpu-requirements.txt pip install --no-cache-dir -r iree-requirements-ci.txt pip install -r requirements.txt -e . - - name: Run unit tests - if: ${{ !cancelled() }} - run: | - pytest -n 4 . + - name: Run e2e tests on MI300 if: "contains(matrix.os, 'mi300') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" - pytest -n 1 ./tests/kernel/wave/ - - name: Run LIT tests - if: ${{ !cancelled() }} - run: | - lit lit_tests/ -v - - name: MyPy Type Checking - if: ${{ !cancelled() }} - run: | - mypy + pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ diff --git a/.github/workflows/test_build_release.yml b/.github/workflows/test_build_release.yml index c0546365..aea5a16a 100644 --- a/.github/workflows/test_build_release.yml +++ b/.github/workflows/test_build_release.yml @@ -19,6 +19,7 @@ jobs: test: name: "Test Build Release Process" strategy: + fail-fast: false matrix: version: [3.11] os: [ubuntu-latest] diff --git a/MANIFEST.in b/MANIFEST.in index 97971bba..65338637 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,4 @@ include README.md include requirements.txt include pytorch-cpu-requirements.txt include version_info.json -include shark_turbine/ops/templates/*.mlir +include iree/turbine/ops/templates/*.mlir diff --git a/README.md b/README.md index 4d0d0c22..aa01b826 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Turbine provides a collection of tools: * *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py)) - for simple models and an underlying [advanced API](shark_turbine/aot/compiled_module.py) for complicated models + for simple models and an underlying [advanced API](iree/turbine/aot/compiled_module.py) for complicated models and accessing the full features of the runtime. * *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device is available for more native, interactive use within a PyTorch session. diff --git a/build_tools/build_release.py b/build_tools/build_release.py index 5a6ef98d..5a90a7cf 100755 --- a/build_tools/build_release.py +++ b/build_tools/build_release.py @@ -159,10 +159,8 @@ def main(): print("Downloading remaining requirements") download_requirements(REPO_ROOT / "requirements.txt") - print("Building shark-turbine") - build_wheel(REPO_ROOT) print("Building iree-turbine") - build_wheel(REPO_ROOT, env={"TURBINE_PACKAGE_NAME": "iree-turbine"}) + build_wheel(REPO_ROOT) if __name__ == "__main__": diff --git a/examples/aot_mlp/mlp_export_dynamic.py b/examples/aot_mlp/mlp_export_dynamic.py index cd863655..3bedd7c1 100644 --- a/examples/aot_mlp/mlp_export_dynamic.py +++ b/examples/aot_mlp/mlp_export_dynamic.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py index fed4795d..30d7ae95 100644 --- a/examples/aot_mlp/mlp_export_simple.py +++ b/examples/aot_mlp/mlp_export_simple.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/llama2_inference/README.md b/examples/llama2_inference/README.md deleted file mode 100644 index 50bc6537..00000000 --- a/examples/llama2_inference/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# LLAMA 2 Inference - -This example require some extra dependencies. Here's an easy way to get it running on a fresh server. - -Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens - -```bash -#!/bin/bash - - -# if you don't insert it, you will be prompted to log in later; -# you may need to rerun this script after logging in -YOUR_HF_TOKEN="insert token for headless" - -# clone and install dependencies -sudo apt install -y git -git clone https://github.com/nod-ai/SHARK-Turbine.git -cd SHARK-Turbine -pip install -r requirements.txt -pip install --update "huggingface_hub[cli]" transformers sentencepiece protobuf - -# do an editable install from the cloned SHARK-Turbine -pip install --editable . - -# Log in with Hugging Face CLI if token setup is required -if [[ $YOUR_HF_TOKEN == hf_* ]]; then - huggingface login --token $YOUR_HF_TOKEN - echo "Logged in with YOUR_HF_TOKEN." -elif [ -f ~/.cache/huggingface/token ]; then - # Read token from the file - TOKEN_CONTENT=$(cat ~/.cache/huggingface/token) - - # Check if the token starts with "hf_" - if [[ $TOKEN_CONTENT == hf_* ]]; then - echo "Already logged in with a Hugging Face token." - else - echo "Token in file does not start with 'hf_'. Please log into huggingface to download models." - huggingface-cli login - fi -else - echo "Please log into huggingface to download models." - huggingface-cli login -fi - -# Step 7: Run the Python script -python examples/llama2_inference/stateless_llama.py -``` diff --git a/examples/llama2_inference/llama2.ipynb b/examples/llama2_inference/llama2.ipynb deleted file mode 100644 index b008bbd2..00000000 --- a/examples/llama2_inference/llama2.ipynb +++ /dev/null @@ -1,503 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "c0c9f034-7af1-4dc2-bbfb-5bb9e27c07ca", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "import torch\n", - "from torch.utils import _pytree as pytree\n", - "from shark_turbine.aot import *\n", - "from iree.compiler.ir import Context\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4d92bb47-2b93-4f32-a445-c0ad2adc37ad", - "metadata": {}, - "outputs": [], - "source": [ - "#set some config values\n", - "\n", - "hf_auth_token = \"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\"\n", - "hf_model_name = \"meta-llama/Llama-2-7b-chat-hf\"\n", - "state_schema_path = \"llama2_state_schema.json\"\n", - "with open(state_schema_path, \"r+\") as f:\n", - " state_schema = pytree.treespec_loads(f.read())\n", - "prompt = \"\"\"\n", - "[INST] <>\n", - "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST]\n", - "\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d4664585-5e15-45c7-8c5c-c8eaf6381435", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:640: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n", - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:479: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5e411acda19c4228b008ff622bdf110e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00.5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:26 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,234] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:72 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,409] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s2, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:118 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:33,707] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s3 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:189 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,845] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s3, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:228 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:33,878] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s4, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:235 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,188] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s5 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:306 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,326] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s5, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:345 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,359] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s6, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:352 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,661] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s7 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:423 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,800] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s7, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:462 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,832] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s8, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:469 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,130] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s9 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:540 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,271] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s9, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:579 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,305] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s10, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:586 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,611] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s11 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:657 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,762] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s11, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:696 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,795] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s12, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:703 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,107] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s13 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:774 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s13, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:813 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,282] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s14, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:820 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,589] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s15 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:891 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s15, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:930 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s16, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:937 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,105] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s17 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1008 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s17, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1047 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,286] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s18, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1054 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,595] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s19 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1125 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,744] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s19, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1164 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,778] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s20, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1171 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,090] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s21 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1242 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,238] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s21, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1281 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,272] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s22, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1288 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,584] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s23 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1359 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s23, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1398 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s24, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1405 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,086] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s25 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1476 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,239] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s25, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1515 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,274] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s26, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1522 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,597] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s27 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1593 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,759] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s27, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1632 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,812] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s28, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1639 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:40,330] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s29 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1710 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:40,534] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s29, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1749 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:40,582] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s30, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1756 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,068] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s31 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1827 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,242] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s31, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1866 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:41,280] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s32, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1873 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,686] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s33 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1944 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,968] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s33, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1983 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,004] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s34, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1990 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:42,419] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s35 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2061 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:42,580] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s35, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2100 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,618] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s36, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2107 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,002] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s37 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2178 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,174] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s37, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2217 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,215] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s38, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2224 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,566] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s39 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2295 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,738] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s39, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2334 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,776] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s40, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2341 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,116] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s41 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2412 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,281] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s41, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2451 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,320] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s42, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2458 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,656] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s43 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2529 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,822] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s43, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2568 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,860] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s44, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2575 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,218] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s45 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2646 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,387] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s45, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2685 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,426] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s46, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2692 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,772] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s47 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2763 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,943] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s47, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2802 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,983] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s48, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2809 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,376] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s49 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2880 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:46,563] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s49, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2919 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:46,605] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s50, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2926 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,962] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s51 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2997 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,136] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s51, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3036 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,176] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s52, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3043 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:47,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s53 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3114 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,718] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s53, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3153 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,758] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s54, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3160 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,125] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s55 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3231 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,308] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s55, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3270 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,349] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s56, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3277 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,715] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s57 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3348 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,897] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s57, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3387 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,937] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s58, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3394 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,317] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s59 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3465 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:49,499] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s59, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3504 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:49,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s60, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3511 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,915] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s61 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3582 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,113] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s61, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3621 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,155] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s62, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3628 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:50,515] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s63 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3699 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,697] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s63, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3738 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,737] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s64, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3745 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:53,791] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards\n", - "[2023-10-09 18:49:54,155] torch.fx.experimental.symbolic_shapes: [WARNING] Ignored guard s0 + s1 > 4096 == False, this could result in accuracy problems\n", - "[2023-10-09 18:49:54,157] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] (_decomp/decompositions.py:725 in slice_forward)\n" - ] - } - ], - "source": [ - "#Run the export pipeline\n", - "inst = StateUpdateModule(context=Context(), import_to=\"IMPORT\")\n", - "module_str = str(CompiledModule.get_mlir_module(inst))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "bc04e1db-a8cc-4182-884d-ba3d8ae5adeb", - "metadata": {}, - "outputs": [], - "source": [ - "#Output a torch-ir mlir file\n", - "with open(\"llama2_torch.mlir\", \"w+\") as f:\n", - " f.write(module_str)\n", - "#TODO: run the rest of the compile pipeline and do inference" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/llama2_inference/llama2_state_schema.json b/examples/llama2_inference/llama2_state_schema.json deleted file mode 100644 index b5506055..00000000 --- a/examples/llama2_inference/llama2_state_schema.json +++ /dev/null @@ -1 +0,0 @@ -[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}] diff --git a/examples/llama2_inference/requirements.txt b/examples/llama2_inference/requirements.txt deleted file mode 100644 index acbc93ca..00000000 --- a/examples/llama2_inference/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -protobuf -sentencepiece -shark_turbine -transformers @ git+https://github.com/huggingface/transformers.git@7d8ff3629b2725ec43ace99c1a6e87ac1978d433 diff --git a/examples/resnet-18/requirements.txt b/examples/resnet-18/requirements.txt index a5123e97..b7428649 100644 --- a/examples/resnet-18/requirements.txt +++ b/examples/resnet-18/requirements.txt @@ -1,2 +1,2 @@ transformers -shark_turbine==0.9.2 +iree_turbine==0.9.2 diff --git a/examples/resnet-18/resnet-18.py b/examples/resnet-18/resnet-18.py index 20340013..2b3fce56 100644 --- a/examples/resnet-18/resnet-18.py +++ b/examples/resnet-18/resnet-18.py @@ -1,6 +1,6 @@ from transformers import AutoFeatureExtractor, AutoModelForImageClassification import torch -from shark_turbine.aot import * +from iree.turbine.aot import * import iree.runtime as rt # Loading feature extractor and pretrained model from huggingface diff --git a/examples/runtime_torture/launchable_torture.py b/examples/runtime_torture/launchable_torture.py index 56f92a99..d58c6a80 100644 --- a/examples/runtime_torture/launchable_torture.py +++ b/examples/runtime_torture/launchable_torture.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/iree/turbine/__init__.py b/iree/turbine/__init__.py index c59e85c2..d95aa54f 100644 --- a/iree/turbine/__init__.py +++ b/iree/turbine/__init__.py @@ -8,15 +8,3 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# TODO: This redirection layer exists while we are migrating from the -# shark_turbine top-level package name to iree.turbine. It exports the -# public API but not the internal details. In a future switch, all code -# will be directly located here and the redirect will be done in the -# shark_turbine namespace. - -from shark_turbine import aot -from shark_turbine import dynamo -from shark_turbine import kernel -from shark_turbine import ops -from shark_turbine import runtime diff --git a/shark_turbine/aot/__init__.py b/iree/turbine/aot/__init__.py similarity index 100% rename from shark_turbine/aot/__init__.py rename to iree/turbine/aot/__init__.py diff --git a/shark_turbine/aot/builtins/__init__.py b/iree/turbine/aot/builtins/__init__.py similarity index 100% rename from shark_turbine/aot/builtins/__init__.py rename to iree/turbine/aot/builtins/__init__.py diff --git a/shark_turbine/aot/builtins/globals.py b/iree/turbine/aot/builtins/globals.py similarity index 100% rename from shark_turbine/aot/builtins/globals.py rename to iree/turbine/aot/builtins/globals.py diff --git a/shark_turbine/aot/builtins/jittable.py b/iree/turbine/aot/builtins/jittable.py similarity index 100% rename from shark_turbine/aot/builtins/jittable.py rename to iree/turbine/aot/builtins/jittable.py diff --git a/shark_turbine/aot/compiled_module.py b/iree/turbine/aot/compiled_module.py similarity index 96% rename from shark_turbine/aot/compiled_module.py rename to iree/turbine/aot/compiled_module.py index 3f44c8b9..5fffd6a0 100644 --- a/shark_turbine/aot/compiled_module.py +++ b/iree/turbine/aot/compiled_module.py @@ -41,6 +41,7 @@ from .support.ir_utils import ( ModuleBuilder, + ModuleBuilderOptions, ) @@ -162,11 +163,13 @@ class CompiledModuleClassInfo: __slots__ = [ "all_exports", "ir_module_name", + "options", ] - def __init__(self, *, ir_module_name: str): + def __init__(self, *, ir_module_name: str, options: ModuleBuilderOptions): self.ir_module_name = ir_module_name self.all_exports: Dict[str, Exportable] = dict() + self.options = options def add_export(self, key: str, value: Exportable): if key in self.all_exports: @@ -370,13 +373,23 @@ class CompiledModuleMeta(type): # It is passed the dictionary of declared attributes and any keyword # arguments from the class declaration: # class Foo(Bar, kwarg="you probably just learned this is possible"): - def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): + def __new__( + mcls, + name: str, + bases, + dct, + *, + export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, + ): if not _metaclass_setup_complete: return type.__new__(mcls, name, bases, dct) ir_module_name = _derive_ir_module_name(name, export_name) logger.debug("Create new CompiledModule: %s", ir_module_name) - info = CompiledModuleClassInfo(ir_module_name=ir_module_name) + info = CompiledModuleClassInfo( + ir_module_name=ir_module_name, options=options or ModuleBuilderOptions() + ) # Process that attributes that were set as part of class definition. # Any attributes that we decide are part of the compiled module @@ -436,6 +449,7 @@ def create_from_dict( dct: dict, *, export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, ) -> CompiledModuleMeta: """Creates a CompiledModule subclass with an explicit dictionary of members. @@ -446,7 +460,9 @@ class Foo(CompiledModule, export_name="bar"): def member(): ... ``` """ - return CompiledModuleMeta(name, (cls,), dct, export_name=export_name) + return CompiledModuleMeta( + name, (cls,), dct, export_name=export_name, options=options + ) @staticmethod def get_class_info(cls: CompiledModuleMeta) -> CompiledModuleClassInfo: @@ -596,7 +612,7 @@ def __new__( module_op.attributes["sym_name"] = StringAttr.get( class_info.ir_module_name, context=context ) - module_builder = ModuleBuilder(module_op) + module_builder = ModuleBuilder(module_op, options=class_info.options) info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder) _all_compiled_module_instance_infos[self] = info diff --git a/shark_turbine/aot/decompositions.py b/iree/turbine/aot/decompositions.py similarity index 100% rename from shark_turbine/aot/decompositions.py rename to iree/turbine/aot/decompositions.py diff --git a/shark_turbine/aot/exporter.py b/iree/turbine/aot/exporter.py similarity index 94% rename from shark_turbine/aot/exporter.py rename to iree/turbine/aot/exporter.py index 4c0e0160..c1adb527 100644 --- a/shark_turbine/aot/exporter.py +++ b/iree/turbine/aot/exporter.py @@ -26,6 +26,7 @@ from .builtins import * from .compiled_module import ( CompiledModule, + ModuleBuilderOptions, ImportPhase, ) from .fx_programs import FxPrograms @@ -175,6 +176,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Exports a torch.nn.Module. @@ -223,6 +225,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Generic export of supported entities. @@ -270,11 +273,19 @@ def export( "LambdaCompiledModule", {(function_name or "main"): mdl}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, FxPrograms): TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", mdl.programs, export_name=module_name or "module" + "LambdaCompiledModule", + mdl.programs, + export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, torch.nn.Module): # Normalize arguments for torch.export. @@ -302,6 +313,9 @@ def export( "LambdaCompiledModule", {(function_name or "main"): exported_program}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif issubclass(mdl, CompiledModule): TransformedModule = mdl diff --git a/shark_turbine/aot/fx_programs.py b/iree/turbine/aot/fx_programs.py similarity index 100% rename from shark_turbine/aot/fx_programs.py rename to iree/turbine/aot/fx_programs.py diff --git a/shark_turbine/aot/params.py b/iree/turbine/aot/params.py similarity index 100% rename from shark_turbine/aot/params.py rename to iree/turbine/aot/params.py diff --git a/shark_turbine/aot/passes/__init__.py b/iree/turbine/aot/passes/__init__.py similarity index 100% rename from shark_turbine/aot/passes/__init__.py rename to iree/turbine/aot/passes/__init__.py diff --git a/shark_turbine/aot/passes/functorch.py b/iree/turbine/aot/passes/functorch.py similarity index 100% rename from shark_turbine/aot/passes/functorch.py rename to iree/turbine/aot/passes/functorch.py diff --git a/shark_turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py similarity index 97% rename from shark_turbine/aot/support/ir_utils.py rename to iree/turbine/aot/support/ir_utils.py index a662c15c..e1eb9d56 100644 --- a/shark_turbine/aot/support/ir_utils.py +++ b/iree/turbine/aot/support/ir_utils.py @@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Sequence, Tuple +from dataclasses import dataclass from pathlib import Path import tempfile @@ -148,6 +149,12 @@ def infer_external_from_tensor( ############################################################################### +@dataclass +class ModuleBuilderOptions: + # Whether to import torch symbolic shape expressions for ExportedPrograms. + import_symbolic_shape_expressions: bool = False + + class ModuleBuilder: """Wrapper around module and IR accounting for a module being built.""" @@ -159,14 +166,18 @@ class ModuleBuilder: "last_global_op", "ip", "module_op", + "options", "symbol_table", "global_ref_tracker", "native_type_converter", "_auto_symbol_counts", ] - def __init__(self, module_op: Operation): + def __init__( + self, module_op: Operation, *, options: Optional[ModuleBuilderOptions] = None + ): self.module_op = module_op + self.options = options or ModuleBuilderOptions() self.context = module_op.context self.body = module_op.regions[0].blocks[0] self.symbol_table = SymbolTable(module_op) diff --git a/shark_turbine/aot/support/procedural/__init__.py b/iree/turbine/aot/support/procedural/__init__.py similarity index 100% rename from shark_turbine/aot/support/procedural/__init__.py rename to iree/turbine/aot/support/procedural/__init__.py diff --git a/shark_turbine/aot/support/procedural/base.py b/iree/turbine/aot/support/procedural/base.py similarity index 100% rename from shark_turbine/aot/support/procedural/base.py rename to iree/turbine/aot/support/procedural/base.py diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/iree/turbine/aot/support/procedural/exported_program.py similarity index 95% rename from shark_turbine/aot/support/procedural/exported_program.py rename to iree/turbine/aot/support/procedural/exported_program.py index 331a7345..f6540bab 100644 --- a/shark_turbine/aot/support/procedural/exported_program.py +++ b/iree/turbine/aot/support/procedural/exported_program.py @@ -181,7 +181,10 @@ def import_exported_program( ) -> ExportedProgramIntrinsic: fx_importer = _create_fx_importer(module_builder) entry_func_op = fx_importer.import_program( - exported_program, func_name=symbol_name, func_visibility=symbol_visibility + exported_program, + func_name=symbol_name, + func_visibility=symbol_visibility, + import_symbolic_shape_expressions=module_builder.options.import_symbolic_shape_expressions, ) module_call_graph = exported_program.module_call_graph @@ -234,6 +237,8 @@ def store_produced_value( raise ValueError(f"Cannot store value to unmapped global for: {info}") logger.debug("Resolved global for store %r", mapping) materialized_global: MaterializedGlobal = mapping.value # type: ignore + assert isinstance(materialized_global.global_op, util_d.GlobalOp) + materialized_global.global_op.is_mutable = True converted_value = Operation.create( "torch_c.to_builtin_tensor", results=[materialized_global.ir_type], @@ -251,7 +256,7 @@ def resolve_literal( return None # See if we know about it. - materialized_global = self._lift_tensor_to_global(literal) + materialized_global = self._lift_tensor_to_global(literal, info) if not materialized_global: # If it is unknown, just let the default importer take it on. return None @@ -269,7 +274,7 @@ def resolve_literal( return converted_value def _lift_tensor_to_global( - self, literal: torch.Tensor + self, literal: torch.Tensor, info: InputInfo | None ) -> Optional[MaterializedGlobal]: module_builder = self.module_builder mapping = module_builder.global_ref_tracker.track(literal) @@ -282,7 +287,7 @@ def _lift_tensor_to_global( # Policy check: Should we auto-import? Generally, we keep "small" # tensors as inline as they can be optimized. external_trait = ExternalTensorTrait.get(literal) - if not self._should_lift_tensor_to_global(literal, external_trait): + if not self._should_lift_tensor_to_global(literal, external_trait, info): return None # If it is a tensor we haven't seen yet, materialize it @@ -304,8 +309,13 @@ def _lift_tensor_to_global( return materialized_global def _should_lift_tensor_to_global( - self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait] + self, + literal: torch.Tensor, + external_trait: Optional[ExternalTensorTrait], + info: InputInfo | None, ) -> bool: + if info is not None and info.store_producer_node: + return True if external_trait is not None: return True volume = math.prod(literal.shape) diff --git a/shark_turbine/aot/support/procedural/globals.py b/iree/turbine/aot/support/procedural/globals.py similarity index 100% rename from shark_turbine/aot/support/procedural/globals.py rename to iree/turbine/aot/support/procedural/globals.py diff --git a/shark_turbine/aot/support/procedural/iree_emitter.py b/iree/turbine/aot/support/procedural/iree_emitter.py similarity index 100% rename from shark_turbine/aot/support/procedural/iree_emitter.py rename to iree/turbine/aot/support/procedural/iree_emitter.py diff --git a/shark_turbine/aot/support/procedural/primitives.py b/iree/turbine/aot/support/procedural/primitives.py similarity index 100% rename from shark_turbine/aot/support/procedural/primitives.py rename to iree/turbine/aot/support/procedural/primitives.py diff --git a/shark_turbine/aot/support/procedural/tracer.py b/iree/turbine/aot/support/procedural/tracer.py similarity index 100% rename from shark_turbine/aot/support/procedural/tracer.py rename to iree/turbine/aot/support/procedural/tracer.py diff --git a/shark_turbine/aot/tensor_traits.py b/iree/turbine/aot/tensor_traits.py similarity index 100% rename from shark_turbine/aot/tensor_traits.py rename to iree/turbine/aot/tensor_traits.py diff --git a/shark_turbine/dynamo/__init__.py b/iree/turbine/dynamo/__init__.py similarity index 100% rename from shark_turbine/dynamo/__init__.py rename to iree/turbine/dynamo/__init__.py diff --git a/shark_turbine/dynamo/backends/cpu.py b/iree/turbine/dynamo/backends/cpu.py similarity index 100% rename from shark_turbine/dynamo/backends/cpu.py rename to iree/turbine/dynamo/backends/cpu.py diff --git a/shark_turbine/dynamo/decompositions.py b/iree/turbine/dynamo/decompositions.py similarity index 100% rename from shark_turbine/dynamo/decompositions.py rename to iree/turbine/dynamo/decompositions.py diff --git a/shark_turbine/dynamo/executor.py b/iree/turbine/dynamo/executor.py similarity index 100% rename from shark_turbine/dynamo/executor.py rename to iree/turbine/dynamo/executor.py diff --git a/shark_turbine/dynamo/passes.py b/iree/turbine/dynamo/passes.py similarity index 100% rename from shark_turbine/dynamo/passes.py rename to iree/turbine/dynamo/passes.py diff --git a/shark_turbine/dynamo/tensor.py b/iree/turbine/dynamo/tensor.py similarity index 99% rename from shark_turbine/dynamo/tensor.py rename to iree/turbine/dynamo/tensor.py index cd1de1ea..bdf1cb83 100644 --- a/shark_turbine/dynamo/tensor.py +++ b/iree/turbine/dynamo/tensor.py @@ -474,8 +474,8 @@ def _get_device_state() -> DeviceState: return DeviceState(driver="local-task") -# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 -# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py +# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/iree/turbine/aot/builtins/jittable.py#L212-L237 +# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/iree/turbine/dynamo/backends/cpu.py # TODO: Try to generalize for other devices. def compute_method(super_fn, *args, **kwargs): # Compute factory fns reserve the last arg as src_op diff --git a/shark_turbine/dynamo/type_conversion.py b/iree/turbine/dynamo/type_conversion.py similarity index 99% rename from shark_turbine/dynamo/type_conversion.py rename to iree/turbine/dynamo/type_conversion.py index 8206e10f..e829bafc 100644 --- a/shark_turbine/dynamo/type_conversion.py +++ b/iree/turbine/dynamo/type_conversion.py @@ -32,7 +32,7 @@ # 1. Local name (int, float, vtensor) # 2. Parameter block ("<...>"), including the delimitters # 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") +DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$") # Decomposes a vtensor parameter block into a dimension list and dtype. Groups: # 1. Dimension list diff --git a/shark_turbine/importers/README.md b/iree/turbine/importers/README.md similarity index 100% rename from shark_turbine/importers/README.md rename to iree/turbine/importers/README.md diff --git a/shark_turbine/importers/ir.py b/iree/turbine/importers/ir.py similarity index 100% rename from shark_turbine/importers/ir.py rename to iree/turbine/importers/ir.py diff --git a/shark_turbine/importers/utils.py b/iree/turbine/importers/utils.py similarity index 100% rename from shark_turbine/importers/utils.py rename to iree/turbine/importers/utils.py diff --git a/shark_turbine/kernel/__init__.py b/iree/turbine/kernel/__init__.py similarity index 100% rename from shark_turbine/kernel/__init__.py rename to iree/turbine/kernel/__init__.py diff --git a/shark_turbine/kernel/_support/context.py b/iree/turbine/kernel/_support/context.py similarity index 100% rename from shark_turbine/kernel/_support/context.py rename to iree/turbine/kernel/_support/context.py diff --git a/shark_turbine/kernel/_support/dtype.py b/iree/turbine/kernel/_support/dtype.py similarity index 100% rename from shark_turbine/kernel/_support/dtype.py rename to iree/turbine/kernel/_support/dtype.py diff --git a/shark_turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py similarity index 96% rename from shark_turbine/kernel/_support/indexing.py rename to iree/turbine/kernel/_support/indexing.py index 3f092278..b99d7b5b 100644 --- a/shark_turbine/kernel/_support/indexing.py +++ b/iree/turbine/kernel/_support/indexing.py @@ -99,6 +99,7 @@ class IndexingContext: __slots__ = [ "subs", + "special_subs", "shaped_bindings", "dyn_dims", "frozen_subs", @@ -109,6 +110,7 @@ class IndexingContext: def __init__(self): self.subs: dict[IndexSymbol, int] = {} + self.special_subs: dict[IndexSymbol, Any] = {} # Indexed by .instance self.shaped_bindings: dict[Any, _ShapedBinding] = {} self.dyn_dims: list[IndexSymbol] = [] @@ -245,6 +247,20 @@ def get_static_value(self, expr: IndexExpr | int) -> Optional[int]: except TypeError: return None + def iota(self, n: int) -> IndexExpr: + sym = index_symbol(f"$IOTA{n}") + if sym not in self.special_subs: + self.special_subs[sym] = tuple(range(n)) + + return sym + + def get_val(self, sym: IndexSymbol) -> Any: + res = self.subs.get(sym, None) + if res is None: + res = self.special_subs.get(sym, None) + + return res + ##### Context management. @staticmethod def current() -> "IndexingContext": diff --git a/shark_turbine/kernel/_support/regions.py b/iree/turbine/kernel/_support/regions.py similarity index 100% rename from shark_turbine/kernel/_support/regions.py rename to iree/turbine/kernel/_support/regions.py diff --git a/shark_turbine/kernel/_support/shaped_type.py b/iree/turbine/kernel/_support/shaped_type.py similarity index 100% rename from shark_turbine/kernel/_support/shaped_type.py rename to iree/turbine/kernel/_support/shaped_type.py diff --git a/shark_turbine/kernel/_support/tracing.py b/iree/turbine/kernel/_support/tracing.py similarity index 99% rename from shark_turbine/kernel/_support/tracing.py rename to iree/turbine/kernel/_support/tracing.py index 42424257..857cdb34 100644 --- a/shark_turbine/kernel/_support/tracing.py +++ b/iree/turbine/kernel/_support/tracing.py @@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str): def get_subgraph(self, name: str) -> fx.Graph: return self.region_graph.subgraphs[name] + def add_subgraph(self, name: str, graph: fx.Graph): + self.region_graph.subgraphs[name] = graph + def get_root_graph(self) -> fx.Graph: return self.get_subgraph(self.root_graph) diff --git a/shark_turbine/kernel/compiler/base.py b/iree/turbine/kernel/compiler/base.py similarity index 100% rename from shark_turbine/kernel/compiler/base.py rename to iree/turbine/kernel/compiler/base.py diff --git a/shark_turbine/kernel/compiler/builder.py b/iree/turbine/kernel/compiler/builder.py similarity index 100% rename from shark_turbine/kernel/compiler/builder.py rename to iree/turbine/kernel/compiler/builder.py diff --git a/shark_turbine/kernel/compiler/dispatch_codegen.py b/iree/turbine/kernel/compiler/dispatch_codegen.py similarity index 77% rename from shark_turbine/kernel/compiler/dispatch_codegen.py rename to iree/turbine/kernel/compiler/dispatch_codegen.py index 0fccf39c..32dab88c 100644 --- a/shark_turbine/kernel/compiler/dispatch_codegen.py +++ b/iree/turbine/kernel/compiler/dispatch_codegen.py @@ -7,9 +7,7 @@ from typing import Any, Callable, Optional, Type -from .._support.indexing import ( - IndexingContext, -) +from .._support.indexing import IndexingContext, IndexSymbol, IndexExpr from .base import ( CodegenError, @@ -99,6 +97,7 @@ def define_entrypoint( grid: Grid, workgroup_size: list[int] = None, subgroup_size: int = None, + dynamic_symbols: list[IndexSymbol] = [], ) -> "DispatchEntrypoint": """Defines a dispatch function with a signature like: @@ -119,7 +118,6 @@ def define_entrypoint( The given name is not uniqued (must be unique as given by the caller). """ kb_input_bindings = sig.kernel_buffer_input_bindings - kb_temp_bindings = sig.kernel_buffer_temporary_bindings kb_output_bindings = sig.kernel_buffer_output_bindings # TODO: The way we are doing grid bindings is wrong. The Grid type # should be paramerized with special grid axis symbols which are @@ -127,18 +125,17 @@ def define_entrypoint( # just assuming that the grid dims can be resolved to constants , when # in reality, we should pass the workload and parameterize the grid # dims on the workloads. - workload_axis_bindings = [] + dynamic_dim_bindings = sig.dynamic_dim_bindings # Input bindings are always user specified. - # Grid/workgroup bindings are in the inputs section but are implied. - # Temp bindings are a special kind of output bindings. # Output bindings are the real outputs. - linear_bindings = ( - kb_input_bindings - + workload_axis_bindings - + kb_temp_bindings - + kb_output_bindings - ) + # Dynamic dim bindings are the dynamic dims of the input and output tensors. + linear_bindings = kb_input_bindings + dynamic_dim_bindings + kb_output_bindings + + dynamic_dim_indices = { + "begin": len(kb_input_bindings), + "end": len(linear_bindings) - len(kb_output_bindings), + } # TODO: This is sloppy. This assert will hit on some user errors for # unsupported type combinations and is just a last resort right now. @@ -177,7 +174,7 @@ def abi_type(binding: BindingDesc): with InsertionPoint.at_block_begin(self._exe_block): export_op = stream_d.ExecutableExportOp(name, name) export_block = export_op.workgroup_count.blocks.append( - *([b.as_mlir_type() for b in workload_axis_bindings]) + *([b.as_mlir_type() for b in dynamic_dim_bindings]) ) workgroup_builder = WorkgroupBuilder( @@ -185,12 +182,30 @@ def abi_type(binding: BindingDesc): ) # TODO: Support passing workload to the dispatch function. + from ..wave.codegen import gen_sympy_index + + # Map dynamic symbols to block arguments. + dynamic_symbols_mapping = { + k: v + for k, v in zip( + dynamic_symbols, workgroup_builder.entry_block.arguments + ) + } + with InsertionPoint(workgroup_builder.entry_block): result_type = IndexType.get() - workgroup_values = [ - arith_d.constant(result_type, IntegerAttr.get(result_type, dim)) - for dim in grid.dims - ] + workgroup_values = [] + for dim in grid.dims: + if isinstance(dim, IndexExpr): + workgroup_values.append( + gen_sympy_index(dynamic_symbols_mapping, dim) + ) + else: + workgroup_values.append( + arith_d.constant( + result_type, IntegerAttr.get(result_type, dim) + ) + ) while len(workgroup_values) < 3: workgroup_values.append( @@ -198,7 +213,20 @@ def abi_type(binding: BindingDesc): ) workgroup_builder.terminate(workgroup_values) - return DispatchEntrypoint(sig, def_func_block, linear_bindings) + # Map dynamic symbols to func arguments for dispatch entrypoint. + dynamic_symbols_mapping = { + k: v + for k, v in zip( + dynamic_symbols, + def_func_args[ + dynamic_dim_indices["begin"] : dynamic_dim_indices["end"] + ], + ) + } + + return DispatchEntrypoint( + sig, def_func_block, linear_bindings, dynamic_symbols_mapping + ) class WorkgroupBuilder: @@ -231,8 +259,10 @@ def __init__( sig: KernelSignature, entry_block: Block, linear_bindings: list[BindingDesc], + dynamic_symbols_mapping: dict[IndexSymbol, Value], ): super().__init__(sig, entry_block) + self.dynamic_symbols_mapping = dynamic_symbols_mapping self._abi_value_by_reference: dict[tuple[str, Any], Value] = { b.reference: value for value, b in zip(entry_block.arguments, linear_bindings) @@ -250,12 +280,15 @@ def resolve(self, binding: BindingDesc) -> Value: result_type = IndexType.get() zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0)) linear_arg_value = self._abi_value_by_reference[binding.reference] - # TODO: Need to also look up dynamic symbol values. return stream_d.binding_subspan( binding.as_mlir_type(), linear_arg_value, byte_offset=zero_value, - dynamic_dims=[], + dynamic_dims=[ + self.dynamic_symbols_mapping[dim] + for dim in binding.kernel_buffer_type.symbolic_shape + if dim in self.dynamic_symbols_mapping + ], ) raise ValidationError(f"Unhandled binding type: {binding}") diff --git a/shark_turbine/kernel/compiler/host_codegen.py b/iree/turbine/kernel/compiler/host_codegen.py similarity index 52% rename from shark_turbine/kernel/compiler/host_codegen.py rename to iree/turbine/kernel/compiler/host_codegen.py index 9225d831..d74af490 100644 --- a/shark_turbine/kernel/compiler/host_codegen.py +++ b/iree/turbine/kernel/compiler/host_codegen.py @@ -8,6 +8,7 @@ from .ir import ( Block, FunctionType, + IndexType, InsertionPoint, IrType, Location, @@ -19,6 +20,9 @@ func_d, ) +from .._support.indexing import IndexSymbol +from .kernel_codegen import BindingDesc + def memref_to_tensor(memrefs: list[IrType]): tensors = [] @@ -29,22 +33,47 @@ def memref_to_tensor(memrefs: list[IrType]): return tensors +def get_dynamic_dims(bindings: list[BindingDesc], dynamic_symbols: list[IndexSymbol]): + dynamic_dims: list[IndexSymbol] = [] + for b in bindings: + for dim in b.kernel_buffer_type.symbolic_shape: + if dim in dynamic_symbols: + dynamic_dims.append(dim) + return dynamic_dims + + def isolated_test_call( - mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str + mb: ModuleBuilder, + exe: StreamExecutable, + sig: KernelSignature, + entrypoint: str, + dynamic_symbols: list[IndexSymbol] = [], ): with InsertionPoint(mb.body_block), Location.unknown(): input_types = [b.as_mlir_type() for b in sig.kernel_buffer_input_bindings] input_tensors = memref_to_tensor(input_types) + argument_dims = get_dynamic_dims( + sig.kernel_buffer_input_bindings, dynamic_symbols + ) + input_tensors += [IndexType.get() for _ in argument_dims] + output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings] output_tensors = memref_to_tensor(output_types) + result_dims = get_dynamic_dims( + sig.kernel_buffer_output_bindings, dynamic_symbols + ) ftype = FunctionType.get(input_tensors, output_tensors) func_op = func_d.FuncOp("isolated_benchmark", ftype) arg_locs = [ (Location.name(b.name) if b.name is not None else Location.unknown()) - for b in sig.kernel_buffer_input_bindings + for b in sig.kernel_buffer_input_bindings + sig.dynamic_dim_bindings ] entry_block = func_op.add_entry_block(arg_locs) + offset = len(sig.kernel_buffer_input_bindings) + dynamic_argument_map = { + k: v for k, v in zip(dynamic_symbols, entry_block.arguments[offset:]) + } with InsertionPoint(entry_block): assert isinstance(entry_block, Block) # Create a flow.dispatch op to the kernel @@ -52,7 +81,12 @@ def isolated_test_call( entrypoints = ArrayAttr.get([dispatch]) out = flow_d.DispatchOp( - output_tensors, [], entrypoints, entry_block.arguments, [], [] + output_tensors, + [dynamic_argument_map[dim] for dim in dynamic_symbols], + entrypoints, + entry_block.arguments, + [dynamic_argument_map[dim] for dim in argument_dims], + [dynamic_argument_map[dim] for dim in result_dims], ) func_d.ReturnOp(out) diff --git a/shark_turbine/kernel/compiler/ir.py b/iree/turbine/kernel/compiler/ir.py similarity index 100% rename from shark_turbine/kernel/compiler/ir.py rename to iree/turbine/kernel/compiler/ir.py diff --git a/shark_turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py similarity index 95% rename from shark_turbine/kernel/compiler/kernel_codegen.py rename to iree/turbine/kernel/compiler/kernel_codegen.py index 0069630c..0ca1fa5a 100644 --- a/shark_turbine/kernel/compiler/kernel_codegen.py +++ b/iree/turbine/kernel/compiler/kernel_codegen.py @@ -177,6 +177,22 @@ def kernel_buffer_temporary_bindings(self) -> list[BindingDesc]: and b.kernel_buffer_type.usage == KernelBufferUsage.TEMPORARY ] + @property + def dynamic_dim_bindings(self) -> list[BindingDesc]: + """Gets all dynamic dimension bindings.""" + return [b for b in self.bindings if b.binding_type == BindingType.SYMBOL_VALUE] + + def add_from_dynamic_symbols(self, dynamic_symbols: list[IndexSymbol]): + for symbol in dynamic_symbols: + self.bindings.append( + BindingDesc( + ("symbol", symbol), + BindingType.SYMBOL_VALUE, + name=symbol.name, + symbol_type=symbol, + ) + ) + def add_from_graph_placeholders(self, graph: fx.Graph): # Extract all placeholder nodes. placeholder_nodes = filter_fx_graph(graph, is_placeholder) diff --git a/shark_turbine/kernel/compiler/op_matchers.py b/iree/turbine/kernel/compiler/op_matchers.py similarity index 100% rename from shark_turbine/kernel/compiler/op_matchers.py rename to iree/turbine/kernel/compiler/op_matchers.py diff --git a/shark_turbine/kernel/compiler/utils.py b/iree/turbine/kernel/compiler/utils.py similarity index 100% rename from shark_turbine/kernel/compiler/utils.py rename to iree/turbine/kernel/compiler/utils.py diff --git a/shark_turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/vector_codegen.py rename to iree/turbine/kernel/compiler/vector_codegen.py diff --git a/shark_turbine/kernel/gen/__init__.py b/iree/turbine/kernel/gen/__init__.py similarity index 100% rename from shark_turbine/kernel/gen/__init__.py rename to iree/turbine/kernel/gen/__init__.py diff --git a/shark_turbine/kernel/gen/kernel.py b/iree/turbine/kernel/gen/kernel.py similarity index 100% rename from shark_turbine/kernel/gen/kernel.py rename to iree/turbine/kernel/gen/kernel.py diff --git a/shark_turbine/kernel/gen/thread.py b/iree/turbine/kernel/gen/thread.py similarity index 100% rename from shark_turbine/kernel/gen/thread.py rename to iree/turbine/kernel/gen/thread.py diff --git a/shark_turbine/kernel/lang/__init__.py b/iree/turbine/kernel/lang/__init__.py similarity index 100% rename from shark_turbine/kernel/lang/__init__.py rename to iree/turbine/kernel/lang/__init__.py diff --git a/shark_turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py similarity index 100% rename from shark_turbine/kernel/lang/global_symbols.py rename to iree/turbine/kernel/lang/global_symbols.py diff --git a/shark_turbine/kernel/lang/grid.py b/iree/turbine/kernel/lang/grid.py similarity index 100% rename from shark_turbine/kernel/lang/grid.py rename to iree/turbine/kernel/lang/grid.py diff --git a/shark_turbine/kernel/lang/kernel_buffer.py b/iree/turbine/kernel/lang/kernel_buffer.py similarity index 100% rename from shark_turbine/kernel/lang/kernel_buffer.py rename to iree/turbine/kernel/lang/kernel_buffer.py diff --git a/shark_turbine/kernel/lang/prims.py b/iree/turbine/kernel/lang/prims.py similarity index 100% rename from shark_turbine/kernel/lang/prims.py rename to iree/turbine/kernel/lang/prims.py diff --git a/shark_turbine/kernel/lang/types.py b/iree/turbine/kernel/lang/types.py similarity index 100% rename from shark_turbine/kernel/lang/types.py rename to iree/turbine/kernel/lang/types.py diff --git a/shark_turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py similarity index 100% rename from shark_turbine/kernel/lang/wave_types.py rename to iree/turbine/kernel/lang/wave_types.py diff --git a/shark_turbine/kernel/ops/__init__.py b/iree/turbine/kernel/ops/__init__.py similarity index 100% rename from shark_turbine/kernel/ops/__init__.py rename to iree/turbine/kernel/ops/__init__.py diff --git a/shark_turbine/kernel/ops/base.py b/iree/turbine/kernel/ops/base.py similarity index 100% rename from shark_turbine/kernel/ops/base.py rename to iree/turbine/kernel/ops/base.py diff --git a/shark_turbine/kernel/ops/control_flow.py b/iree/turbine/kernel/ops/control_flow.py similarity index 100% rename from shark_turbine/kernel/ops/control_flow.py rename to iree/turbine/kernel/ops/control_flow.py diff --git a/shark_turbine/kernel/ops/core.py b/iree/turbine/kernel/ops/core.py similarity index 100% rename from shark_turbine/kernel/ops/core.py rename to iree/turbine/kernel/ops/core.py diff --git a/shark_turbine/kernel/ops/math.py b/iree/turbine/kernel/ops/math.py similarity index 100% rename from shark_turbine/kernel/ops/math.py rename to iree/turbine/kernel/ops/math.py diff --git a/shark_turbine/kernel/ops/memory.py b/iree/turbine/kernel/ops/memory.py similarity index 100% rename from shark_turbine/kernel/ops/memory.py rename to iree/turbine/kernel/ops/memory.py diff --git a/shark_turbine/kernel/ops/reduction.py b/iree/turbine/kernel/ops/reduction.py similarity index 100% rename from shark_turbine/kernel/ops/reduction.py rename to iree/turbine/kernel/ops/reduction.py diff --git a/shark_turbine/kernel/ops/shape_manipulation.py b/iree/turbine/kernel/ops/shape_manipulation.py similarity index 100% rename from shark_turbine/kernel/ops/shape_manipulation.py rename to iree/turbine/kernel/ops/shape_manipulation.py diff --git a/shark_turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py similarity index 86% rename from shark_turbine/kernel/ops/wave_ops.py rename to iree/turbine/kernel/ops/wave_ops.py index 0298a065..731c8678 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from ..wave.constraints import Constraint + from ..wave.scheduling.resources import Operation T = TypeVar("T", bound=Type[Any]) AccT = TypeVar("AccT") @@ -96,6 +97,12 @@ def maximum(lhs: "Register", rhs: "Register") -> "Register": ... +def broadcast( + arg: "Register", target_shape: Optional[IndexExpr | int] = None +) -> "Register": + ... + + def sum( src: "Register", acc: Optional["Register"] = None, @@ -365,12 +372,15 @@ def copy( new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None, arg_transform: Optional[Callable[[Any], Any]] = lambda x: x, + anchor: Optional[fx.Node] = None, ) -> Self: """Returns a duplicate of this node.""" graph = new_graph if new_graph is None: graph = self.graph - graph.inserting_after(self.fx_node) + if anchor is None: + anchor = self.fx_node + graph.inserting_after(anchor) new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name @@ -450,6 +460,8 @@ def index(self, value: Any): self.fx_node.index = {} for dim, key in value.items(): self.fx_node.index[dim] = key + elif isinstance(value, list): + self.fx_node.index = value else: raise ValueError("Index must be a dict") @@ -483,7 +495,6 @@ def post_expansion(self, constraints: list["Constraint"]) -> None: pass -@define_py_op(operator.getitem) @define_py_op(operator.add) @define_py_op(operator.sub) @define_py_op(operator.mul) @@ -492,7 +503,12 @@ def post_expansion(self, constraints: list["Constraint"]) -> None: @dataclass class BinaryPyOp(CustomOp, ABC): """ - Represents a binary python operator. + Represents an elementwise binary python operator. + + DTYPE requirement: lhs and rhs needs to have the same dtpye. + Shape requirement: lhs and rhs either have same shape or + their shape must be broadcastable to + one another. """ lhs: Any @@ -518,9 +534,16 @@ def type(self) -> Memory: lhs_type = get_custom(self.lhs).type rhs_type = get_custom(self.rhs).type has_same_type = has_same_custom_type(lhs_type, rhs_type) - if not has_same_type: - raise ValueError("Expected lhs and rhs to have same type post-expansion") - return lhs_type + if has_same_type: + return lhs_type + lhs_dim_set = set(lhs_type.symbolic_shape) + rhs_dim_set = set(rhs_type.symbolic_shape) + if lhs_dim_set.isdisjoint(rhs_dim_set): + raise ValueError( + "BinaryPyOp requires lhs and rhs shape to be at least broadcastable." + ) + broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype + return broadcasted_type @define_interface_op("exp2") @@ -647,6 +670,13 @@ class IterArg(Placeholder): a reduction node. """ + def parent_op(self): + return get_custom(self.graph.parent_op) + + def get_iter_idx(self): + src_reduction = self.parent_op() + return src_reduction.iter_args(self.graph).index(self.fx_node) + # Ops modeling TKW operations in the kernel language @@ -689,11 +719,37 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool: prev_node, found_src = prev_node.prev, prev_node == src if not found_src: return False - while next_node and not found_dst: + while next_node.next.op != "root" and not found_dst: next_node, found_dst = next_node.next, next_node == dst return found_dst +@define_op("scheduling_barrier") +@dataclass +class SchedulingBarrier(CustomOp): + """ + Represents a scheduling barrier in the graph. + Takes in a list of operations that are allowed to cross + the barrier. + """ + + operations: list[Operation] + + +@define_op("scheduling_group_barrier") +@dataclass +class SchedulingGroupBarrier(CustomOp): + """ + Represents a scheduling group barrier in the graph. + The scheduling group barrier defines scheduling groups. + Each scheduling group contains different instructions in a specific order. + The sync_id identifies scheduling groups that need to be aware of each other. + """ + + instructions: dict[Operation, int] + sync_id: int + + @define_op("register") @dataclass class NewRegister(CustomOp): @@ -777,16 +833,6 @@ def custom_string(self, value_map: dict[str, str]) -> str: custom_str += f"acc={self.acc} (index = {self.acc_index}))" return custom_str - def post_expansion(self, constraints: list["Constraint"]) -> None: - """ - Once the arguments have been expanded, we set their indices, - ensuring that the LHS and RHS indices are consistent with their - corresponding address spaces. - """ - self.lhs.index = self.lhs_index - self.rhs.index = self.rhs_index - self.acc.index = self.acc_index - @define_op("read") @dataclass @@ -853,17 +899,29 @@ def wrapper(f): node._add_proxy_to_graph(graph) node.fx_node.node.tkw_op = cls node.fx_node.node.tkw_op_name = cls.tkw_op_name + graph.subgraphs[subgraph_name].parent_op = node.fx_node.node return node.fx_node return wrapper @property - def indexing_dims(self) -> list[IndexSymbol]: + def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]: expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) + return_node = [ + nested_node + for nested_node in self.graph.subgraphs[self.subgraph_name].nodes + if isinstance(get_custom(nested_node), Output) + ] + assert len(return_node) == 1 + return_vals = get_custom(return_node[0]).return_vals[0] + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for return_val in return_vals: + return_dims = get_custom(return_val).indexing_dims + reduced_dims = [dims for dims in return_dims if dims != self.axis] + expand_dims.append(reduced_dims) + if len(expand_dims) == 1: + expand_dims = expand_dims[0] return expand_dims def iter_args(self, graph: fx.Graph) -> list[fx.Node]: @@ -886,8 +944,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: return captured_vars @property - def type(self) -> list[Memory | Register]: - return [get_custom(x).type for x in self.init_args] + def type(self) -> Memory | Register | list[Memory | Register]: + res_types = [get_custom(x).type for x in self.init_args] + if len(res_types) == 1: + res_types = res_types[0] + return res_types def outputs(self, graph: fx.Graph) -> list[fx.Node]: for node in graph.nodes: @@ -907,6 +968,20 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]: else None ) + @index.setter + def index(self, value: Any): + CustomOp.index.fset(self, value) + + @property + def count(self) -> int: + if hasattr(self.fx_node, "count"): + return self.fx_node.count + return None + + @count.setter + def count(self, value: int): + self.fx_node.count = value + @define_op("write") @dataclass @@ -941,6 +1016,7 @@ def register_index(self) -> dict[IndexSymbol, IndexSequence]: return custom.index +@define_py_op(operator.getitem) @define_op("get_result") @dataclass class GetResult(CustomOp): @@ -949,16 +1025,24 @@ class GetResult(CustomOp): @property def type(self) -> "Memory": - return get_custom(self.value).type[self.res_idx] + src_type = get_custom(self.value).type + if isinstance(src_type, list): + return src_type[self.res_idx] + else: + return src_type @property - def indexing_dims(self) -> list[IndexSymbol]: - expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) - return expand_dims + def indexing_dims(self) -> list[IndexExpr]: + has_multiple_value = lambda x: all(isinstance(el, list) for el in x) + is_valid_indexing_dim = lambda x: isinstance(src_indexing, list) and all( + isinstance(el, IndexExpr) for el in x + ) + src_indexing = get_custom(self.value).indexing_dims + if has_multiple_value(src_indexing): + assert self.res_idx <= len(src_indexing) - 1 + src_indexing = src_indexing[self.res_idx] + assert is_valid_indexing_dim(src_indexing) + return src_indexing @property def index(self) -> dict[IndexSymbol, IndexSequence]: @@ -986,6 +1070,34 @@ def type(self) -> "Register": return get_custom(self.register_).type +@define_op("broadcast") +@dataclass +class Broadcast(CustomOp, ABC): + """ + Represents a Broadcast operation. + + arg: Source tensor/value to broadcast + target_shape: symbolic target broadcast shape. + """ + + arg: fx.Node + target_type: Sequence[IndexSymbol] = None + + @property + def target_shape(self): + return self.target_type.symbolic_shape + + @property + def indexing_dims(self) -> list[IndexSymbol]: + return self.target_shape + + @property + def type(self) -> Memory: + src_dtype = get_custom(self.arg).type.dtype + dst_type = Register[*self.target_shape, src_dtype] + return dst_type + + @define_interface_op("max") @define_interface_op("sum") @dataclass diff --git a/shark_turbine/kernel/wave/README.md b/iree/turbine/kernel/wave/README.md similarity index 100% rename from shark_turbine/kernel/wave/README.md rename to iree/turbine/kernel/wave/README.md diff --git a/shark_turbine/kernel/wave/__init__.py b/iree/turbine/kernel/wave/__init__.py similarity index 100% rename from shark_turbine/kernel/wave/__init__.py rename to iree/turbine/kernel/wave/__init__.py diff --git a/shark_turbine/kernel/wave/barriers.py b/iree/turbine/kernel/wave/barriers.py similarity index 100% rename from shark_turbine/kernel/wave/barriers.py rename to iree/turbine/kernel/wave/barriers.py diff --git a/shark_turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py similarity index 73% rename from shark_turbine/kernel/wave/codegen.py rename to iree/turbine/kernel/wave/codegen.py index e4a8cf72..bc6e54ed 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools import operator import sympy import math @@ -11,6 +12,7 @@ from dataclasses import dataclass import torch.fx as fx import torch.utils._pytree as pytree +from collections import namedtuple from ..compiler.ir import ( Attribute, @@ -37,13 +39,15 @@ stream_d, scf_d, vector_d, + llvm_d, ) -from shark_turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type +from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type # TK infrastructure imports. -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.lang.global_symbols import * from ..ops.wave_ops import ( write, + broadcast, register, mma, shuffle, @@ -57,6 +61,8 @@ shared_memory_barrier, extract_slice, CustomOp, + scheduling_barrier, + scheduling_group_barrier, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -77,10 +83,11 @@ WorkgroupConstraint, TilingConstraint, ) -from .utils import subs_idxc +from .utils import subs_idxc, find_index_bounds, get_hardware_vector_map # Indexing imports. from .._support.indexing import IndexingContext, IndexExpr, IndexSequence +from .scheduling.resources import get_scheduling_mask @dataclass @@ -90,6 +97,7 @@ class WaveEmitter: root_sig: BoundKernelSignature trace: CapturedTrace constraints: list[Constraint] + dynamic_symbols: list[IndexSymbol] ip: InsertionPoint = None OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {} _node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {} @@ -109,6 +117,11 @@ def emit_program_invariants(self): gpu_d.thread_id(gpu_d.Dimension.z), ] self.induction_vars: dict[IndexSymbol, Value] = {} + self.dynamic_dims: dict[IndexSymbol, Value] = {} + symbol_iterator = iter(self.dynamic_symbols) + for arg in self.root_sig.entry_block.arguments: + if arg.type == IndexType.get(): + self.dynamic_dims[next(symbol_iterator)] = arg def emit(self, graph: Optional[fx.Graph] = None): with self.ip, Location.unknown(): @@ -168,54 +181,17 @@ def get_type_or_element_type(operand_type: IrType): return operand_type -def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult: - stack: list[OpResult] = [] - - def _process_mul_add_ops(term, is_mul): - args = [] - callables = [] - for _ in range(len(term.args)): - val = stack.pop() - if callable(val): - callables.append(val) - else: - args.append(val) - operation = None - for arg in args: - if operation is None: - operation = arg - continue - - if is_mul: - operation = arith_d.MulIOp(operation, arg) - else: - operation = arith_d.AddIOp(operation, arg) - - for arg in callables: - operation = arg(operation, is_mul) - - stack.append(operation) - - def _get_mul(numerator): - return lambda x: arith_d.MulIOp(x, numerator) - - def _get_add(numerator, denominator): - return lambda x: arith_d.AddIOp(arith_d.MulIOp(x, denominator), numerator) - - def _get_div(mul, add, denominator): - return lambda x, is_mul: arith_d.DivSIOp( - mul(x) if is_mul else add(x), denominator - ) - +def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]: induction_var_syms = [] induction_vars = [] - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint): - assert ( - constraint.dim in emitter.induction_vars - ), f"Could not find induction var for {constraint.dim} dimension" - induction_var_syms.append(constraint.induction_var) - induction_vars.append(emitter.induction_vars[constraint.dim]) + if emitter.induction_vars: + for constraint in emitter.constraints: + if isinstance(constraint, TilingConstraint): + assert ( + constraint.dim in emitter.induction_vars + ), f"Could not find induction var for {constraint.dim} dimension" + induction_var_syms.append(constraint.induction_var) + induction_vars.append(emitter.induction_vars[constraint.dim]) # TODO: factor this out all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars @@ -226,6 +202,150 @@ def _get_div(mul, add, denominator): all_symbols, ) ) + dynamics.update(emitter.dynamic_dims) + return dynamics + + +_Rational = namedtuple("_Rational", ["numerator", "denominator"]) + + +def gen_sympy_index(dynamics: dict[IndexSymbol, Any], expr: sympy.Expr) -> OpResult: + stack: list[OpResult] = [] + + def _get_ir_value(arg): + if not isinstance(arg, (Value, OpResult)): + arg = arg.result + + return arg + + def _check_vec_scalar(a, b): + return isinstance(a.type, VectorType) and a.type.element_type == b.type + + def _broadcast(a, b): + a = _get_ir_value(a) + b = _get_ir_value(b) + + if a.type == b.type: + return a, b + + if _check_vec_scalar(a, b): + b = vector_d.splat(a.type, b) + return a, b + + if _check_vec_scalar(b, a): + a = vector_d.splat(b.type, a) + return a, b + + raise CodegenError(f"Cannot broadcast {a.type} and {b.type}") + + def get_const_val(arg): + if isinstance(arg, OpResult): + arg = arg.owner.opview + + if isinstance(arg, arith_d.ConstantOp): + value = arg.attributes["value"] + if isinstance(value, IntegerAttr): + return int(value) + + return None + + def muli_fold(lhs, rhs): + if get_const_val(lhs) == 1: + return rhs + + if get_const_val(rhs) == 1: + return lhs + + return arith_d.muli(lhs, rhs) + + # `x + (a/b)` transformed into `(x*b + a) / b` + def _add(lhs, rhs): + is_rational_lhs = isinstance(lhs, _Rational) + is_rational_rhs = isinstance(rhs, _Rational) + if is_rational_lhs and not is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.denominator, rhs)) + numerator = arith_d.addi(*_broadcast(numerator, lhs.numerator)) + return _Rational(numerator, lhs.denominator) + elif not is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs, rhs.denominator)) + numerator = arith_d.addi(*_broadcast(numerator, rhs.numerator)) + return _Rational(numerator, rhs.denominator) + elif is_rational_lhs and is_rational_rhs: + lhs_numerator = muli_fold(*_broadcast(lhs.numerator, rhs.denominator)) + rhs_numerator = muli_fold(*_broadcast(rhs.numerator, lhs.denominator)) + numerator = arith_d.addi(*_broadcast(lhs_numerator, rhs_numerator)) + denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator)) + return _Rational(numerator, denominator) + else: + return arith_d.addi(*_broadcast(lhs, rhs)) + + # `x * (a/b)` transformed into `(x * a) / b` + def _mul(lhs, rhs): + is_rational_lhs = isinstance(lhs, _Rational) + is_rational_rhs = isinstance(rhs, _Rational) + if is_rational_lhs and not is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.numerator, rhs)) + return _Rational(numerator, lhs.denominator) + elif not is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs, rhs.numerator)) + return _Rational(numerator, rhs.denominator) + elif is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.numerator, rhs.numerator)) + denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator)) + return _Rational(numerator, denominator) + else: + return muli_fold(*_broadcast(lhs, rhs)) + + def _floor(value): + if isinstance(value, _Rational): + value = arith_d.divsi(*_broadcast(value.numerator, value.denominator)) + + return value + + def _ceiling(value): + if isinstance(value, _Rational): + value = arith_d.ceildivsi(*_broadcast(value.numerator, value.denominator)) + + return value + + def _group_rationals(stack, count): + """Group rationals and non-rationals args into 2 contiguous sets. + + This allows to mul/add all non-rationals first, reducing total number of ops. + """ + rationals = [] + non_rationals = [] + for _ in range(count): + val = stack.pop() + if isinstance(val, _Rational): + rationals.append(val) + else: + non_rationals.append(val) + + return non_rationals + rationals + + def _apply(args, func): + assert len(args) > 0 + value = args[0] + for val in args[1:]: + value = func(value, val) + + return value + + def _enforce_non_rational(val, term): + if isinstance(val, _Rational): + raise CodegenError(f"Rational is not supported yet in '{type(term)}'") + + def _get_const(val): + if isinstance(val, int): + return arith_d.constant(IndexType.get(), val) + + if isinstance(val, (tuple, list)): + vec_type = VectorType.get([len(val)], IndexType.get()) + vals = [IntegerAttr.get(IndexType.get(), v) for v in val] + return arith_d.constant(vec_type, DenseElementsAttr.get(vals, vec_type)) + + raise CodegenError(f"Unsupported const val {val} : {type(val)}") idxc = IndexingContext.current() # Substitute in frozen vars to simplify expression. @@ -237,49 +357,58 @@ def _get_div(mul, add, denominator): for term in sympy.postorder_traversal(expr): match term: case sympy.Symbol(): - if term in idxc.subs.keys(): - cst = arith_d.constant(IndexType.get(), idxc.subs[term]) - stack.append(cst) + res = idxc.get_val(term) + if res is not None: + stack.append(_get_const(res)) elif term in dynamics.keys(): stack.append(dynamics[term]) else: raise CodegenError(f"Unknown symbol {term}") case sympy.Integer(): - stack.append(arith_d.constant(IndexType.get(), int(term))) + stack.append(_get_const(int(term))) case sympy.Mul(): - _process_mul_add_ops(term, is_mul=True) + args = _group_rationals(stack, len(term.args)) + stack.append(_apply(args, _mul)) case sympy.Add(): - _process_mul_add_ops(term, is_mul=False) + args = _group_rationals(stack, len(term.args)) + stack.append(_apply(args, _add)) case sympy.Mod(): rhs = stack.pop() lhs = stack.pop() - mod = arith_d.RemSIOp(lhs, rhs) + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) + mod = arith_d.remsi(*_broadcast(lhs, rhs)) stack.append(mod) case sympy.floor(): - # TODO: Since divsi rounds to zero, this seems to work. - # But check whether floordivsi is needed. - stack.append(stack.pop()) + stack.append(_floor(stack.pop())) + case sympy.ceiling(): + stack.append(_ceiling(stack.pop())) case sympy.Rational(): - # `x * (a/b)` transformed into `(x * a) / b` - # `x + (a/b)` transformed into `(x*b + a) / b` - numerator = arith_d.constant(IndexType.get(), abs(term.p)) - denominator = arith_d.constant(IndexType.get(), abs(term.q)) - # Assumes that the negative term is always carried on the numerator - if abs(term.p) > term.p: - zero = arith_d.constant(IndexType.get(), int(0)) - numerator = arith_d.SubIOp(zero, numerator) - mul = lambda x: x - if abs(term.p) != 1: - mul = _get_mul(numerator) - add = _get_add(numerator, denominator) - operation = _get_div(mul, add, denominator) - stack.append(operation) + numerator = _get_const(term.p) + denominator = _get_const(term.q) + stack.append(_Rational(numerator, denominator)) + case sympy.StrictLessThan(): + rhs = stack.pop() + lhs = stack.pop() + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) + res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs)) + stack.append(res) + case sympy.And(): + rhs = stack.pop() + lhs = stack.pop() + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) + res = arith_d.andi(*_broadcast(lhs, rhs)) + stack.append(res) case sympy.UnevaluatedExpr(): continue case _: - raise CodegenError(f"Can not handle {term} yet") - if len(stack) != 1: + raise CodegenError(f"Can not handle {type(term)} : {term}") + + if len(stack) != 1 or isinstance(stack[0], _Rational): raise CodegenError(f"Expected single result, got {len(stack)}") + return stack[0] @@ -312,8 +441,8 @@ def handle_register(emitter: WaveEmitter, node: fx.Node): shape, dtype, value = node.args except ValueError as e: raise ValidationError("Malformed arguments") from e - if hasattr(node, "thread_shape"): - shape = [node.thread_shape] + get_thread_shape = lambda index: max(x.size for x in index.values()) + shape = [get_thread_shape(get_custom(node).index)] vector_shape = cast_py_literal(emitter, shape) element_type = IrType.parse(dtype.ir_type_asm()) vector_type = VectorType.get(vector_shape, element_type) @@ -361,7 +490,10 @@ def _get_start_indices( def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr] ) -> list[OpResult]: - return [gen_sympy_index(emitter, i) for i in _get_start_indices(src_indices)] + return [ + gen_sympy_index(add_emitter_subs(emitter), i) + for i in _get_start_indices(src_indices) + ] def _compute_offset(indices: list[IndexExpr], strides: list[IndexExpr]) -> IndexExpr: @@ -392,44 +524,24 @@ def _is_identity_mapping( def _build_mask( emitter: WaveEmitter, index: Dict[IndexExpr, IndexExpr], elements_per_thread: int ) -> Optional[OpResult]: - bounds = [] - for constraint in emitter.constraints: - if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): - continue - - dim = constraint.dim - if dim not in index: - continue - - work_size = constraint.count * constraint.tile_size - if subs_idxc(work_size) == subs_idxc(dim): - continue - - bounds.append((dim, gen_sympy_index(emitter, dim))) - - if len(bounds) == 0: + bounds = find_index_bounds(emitter.constraints, index) + if bounds is None: return None - mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) - mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) - + idxc = IndexingContext.current() last_dim = tuple(index.keys())[-1] new_index = {k: _get_start_index(v) for k, v in index.items()} - for i in range(elements_per_thread): - cond = None - for dim, bound in bounds: - idx = gen_sympy_index(emitter, new_index[dim]) - lt = arith_d.cmpi(arith_d.CmpIPredicate.slt, idx, bound) - if cond is None: - cond = lt - else: - cond = arith_d.andi(cond, lt) + new_index[last_dim] = new_index[last_dim] + idxc.iota(elements_per_thread) - pos = arith_d.ConstantOp(IndexType.get(), i) - mask = vector_d.insertelement(cond, mask, position=pos) + mask_expr = functools.reduce( + lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds) + ) + mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) - new_index[last_dim] = new_index[last_dim] + 1 + mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) + if mask.type != mask_vec_type: + mask = vector_d.splat(mask_vec_type, mask) return mask @@ -503,7 +615,7 @@ def _construct_gather_scatter_indices( # arith ops and then `vector.insertelement` them into offsets vec. offset = int(offset) else: - dyn_offset = gen_sympy_index(emitter, offset) + dyn_offset = gen_sympy_index(add_emitter_subs(emitter), offset) dynamic_offsets.append((i, dyn_offset)) offset = 0 @@ -910,18 +1022,11 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node): flat_init_args, _ = pytree.tree_flatten((init_args)) flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args] - # Without scheduling, we assume that we always start at 0. start = arith_d.constant(IndexType.get(), int(0)) - count = None - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint) and constraint.dim == axis: - count = subs_idxc(constraint.count) - assert count is not None, "Could not find tiling constraint for reduction axis." - # For now, we assume that dimensions that have tiling constraints on them, # do not have any other constraints. - end = arith_d.constant(IndexType.get(), int(count)) + end = arith_d.constant(IndexType.get(), int(node.count)) # Since we divide the end by the tile size, we need to make sure that the # step is 1. @@ -970,6 +1075,38 @@ def handle_shared_memory_barrier(emitter: WaveEmitter, node: fx.Node): amdgpu_d.lds_barrier() +@handle_op(scheduling_barrier) +def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node): + try: + operations = node.args[0] + except ValueError as e: + raise ValidationError("Malformed arguments") from e + mask = 0 + for operation in operations: + mask |= get_scheduling_mask(operation) + + mask = arith_d.constant(IntegerType.get_signless(32), mask) + llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask]) + + +@handle_op(scheduling_group_barrier) +def handle_scheduling_group_barrier(emitter: WaveEmitter, node: fx.Node): + try: + instructions, sync_id = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + sync_id = arith_d.constant(IntegerType.get_signless(32), sync_id) + for instruction, counts in instructions.items(): + mask = get_scheduling_mask(instruction) + if mask is None: + continue + mask = arith_d.constant(IntegerType.get_signless(32), mask) + counts = arith_d.constant(IntegerType.get_signless(32), counts) + llvm_d.call_intrinsic( + None, "llvm.amdgcn.sched.group.barrier", [mask, counts, sync_id] + ) + + ############################################################################### # Slicing ops ############################################################################### @@ -995,6 +1132,42 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): emitter.bind_node_proxy(node, IRProxyValue(element)) +############################################################################### +# Reshape ops +############################################################################### + + +@handle_op(broadcast) +def handle_broadcast(emitter: WaveEmitter, node: fx.Node): + try: + register, target_type = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + # Get thread_shape/size for broadcast. + get_thread_shape = lambda index: max(x.size for x in index.values()) + bcast_dim_lane_dim_size = get_thread_shape(node.index) + + # Check MLIR shape + vector_src = cast_vector(emitter, register) + vector_type = vector_src.type + # Only support broadcasting vector<1xdtype> for now. + if not VectorType.isinstance(vector_type): + raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.") + assert vector_type.rank == 1 + assert vector_type.shape[0] == 1 + + # Extract and Splat + # If by chance broadcast size matches current size, we can return src. + if bcast_dim_lane_dim_size == vector_type.shape[0]: + emitter.bind_node_proxy(node, IRProxyValue(vector_src)) + + result_type = VectorType.get([bcast_dim_lane_dim_size], vector_type.element_type) + element = vector_d.extract(vector_src, static_position=[0], dynamic_position=[]) + splat = vector_d.splat(result_type, element) + emitter.bind_node_proxy(node, IRProxyValue(splat)) + + ############################################################################### # Miscellanous ops ############################################################################### diff --git a/shark_turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py similarity index 100% rename from shark_turbine/kernel/wave/constraints.py rename to iree/turbine/kernel/wave/constraints.py diff --git a/shark_turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py similarity index 94% rename from shark_turbine/kernel/wave/decompose_reduce_ops.py rename to iree/turbine/kernel/wave/decompose_reduce_ops.py index 1dac06cc..9916bb50 100644 --- a/shark_turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -20,9 +20,10 @@ ShuffleOp, CustomOp, ExtractSlice, + Reduction, ) -from .utils import DCE +from .utils import DCE, subs_idxc import torch.fx as fx import math from typing import Callable @@ -103,9 +104,11 @@ def decompose_reduce_ops( raise NotImplementedError( "Only implemented reduction on fastest dimension." ) - reduction_block_size = constraint_tile_size[reduction_dim] - reduction_size = reduction_block_size.subs(index_map) - local_reduction_size = reduction_size / subgroup_size + + get_thread_shape = lambda index: max( + subs_idxc(x.size) for x in index.values() + ) + local_reduction_size = get_thread_shape(get_custom(custom.arg).index) local_reduction = emit_local_reduction( binary_fn, reduction_src, custom.graph, local_reduction_size ) diff --git a/shark_turbine/kernel/wave/docs/gemm_example.md b/iree/turbine/kernel/wave/docs/gemm_example.md similarity index 100% rename from shark_turbine/kernel/wave/docs/gemm_example.md rename to iree/turbine/kernel/wave/docs/gemm_example.md diff --git a/shark_turbine/kernel/wave/docs/mlsys/.gitignore b/iree/turbine/kernel/wave/docs/mlsys/.gitignore similarity index 86% rename from shark_turbine/kernel/wave/docs/mlsys/.gitignore rename to iree/turbine/kernel/wave/docs/mlsys/.gitignore index f2e31fe2..b4c7d64b 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/.gitignore +++ b/iree/turbine/kernel/wave/docs/mlsys/.gitignore @@ -3,3 +3,4 @@ *.out *.pdf *.synctex.gz +*.blg diff --git a/shark_turbine/kernel/wave/docs/mlsys/algorithm.sty b/iree/turbine/kernel/wave/docs/mlsys/algorithm.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/algorithm.sty rename to iree/turbine/kernel/wave/docs/mlsys/algorithm.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/algorithmic.sty b/iree/turbine/kernel/wave/docs/mlsys/algorithmic.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/algorithmic.sty rename to iree/turbine/kernel/wave/docs/mlsys/algorithmic.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/fancyhdr.sty b/iree/turbine/kernel/wave/docs/mlsys/fancyhdr.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/fancyhdr.sty rename to iree/turbine/kernel/wave/docs/mlsys/fancyhdr.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/mlsys2024.bst b/iree/turbine/kernel/wave/docs/mlsys/mlsys2024.bst similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/mlsys2024.bst rename to iree/turbine/kernel/wave/docs/mlsys/mlsys2024.bst diff --git a/shark_turbine/kernel/wave/docs/mlsys/mlsys2024.sty b/iree/turbine/kernel/wave/docs/mlsys/mlsys2024.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/mlsys2024.sty rename to iree/turbine/kernel/wave/docs/mlsys/mlsys2024.sty diff --git a/iree/turbine/kernel/wave/docs/mlsys/tkw.bbl b/iree/turbine/kernel/wave/docs/mlsys/tkw.bbl new file mode 100644 index 00000000..5ca46234 --- /dev/null +++ b/iree/turbine/kernel/wave/docs/mlsys/tkw.bbl @@ -0,0 +1,208 @@ +\begin{thebibliography}{6} +\providecommand{\natexlab}[1]{#1} +\providecommand{\url}[1]{\texttt{#1}} +\expandafter\ifx\csname urlstyle\endcsname\relax + \providecommand{\doi}[1]{doi: #1}\else + \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi + +\bibitem[Chetlur et~al.(2014)Chetlur, Woolley, Vandermersch, Cohen, Tran, + Catanzaro, and Shelhamer]{chetlur_cudnn_2014} +Chetlur, S., Woolley, C., Vandermersch, P., Cohen, J., Tran, J., Catanzaro, B., + and Shelhamer, E. +\newblock {cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}, December + 2014. +\newblock URL \url{http://arxiv.org/abs/1410.0759}. +\newblock arXiv:1410.0759 [cs]. + +\bibitem[Dubey et~al.(2024)Dubey, Jauhri, Pandey, Kadian, Al-Dahle, Letman, + Mathur, Schelten, Yang, Fan, Goyal, Hartshorn, Yang, Mitra, Sravankumar, + Korenev, Hinsvark, Rao, Zhang, Rodriguez, Gregerson, Spataru, Roziere, Biron, + Tang, Chern, Caucheteux, Nayak, Bi, Marra, McConnell, Keller, Touret, Wu, + Wong, Ferrer, Nikolaidis, Allonsius, Song, Pintz, Livshits, Esiobu, + Choudhary, Mahajan, Garcia-Olano, Perino, Hupkes, Lakomkin, AlBadawy, + Lobanova, Dinan, Smith, Radenovic, Zhang, Synnaeve, Lee, Anderson, Nail, + Mialon, Pang, Cucurell, Nguyen, Korevaar, Xu, Touvron, Zarov, Ibarra, + Kloumann, Misra, Evtimov, Copet, Lee, Geffert, Vranes, Park, Mahadeokar, + Shah, van~der Linde, Billock, Hong, Lee, Fu, Chi, Huang, Liu, Wang, Yu, + Bitton, Spisak, Park, Rocca, Johnstun, Saxe, Jia, Alwala, Upasani, Plawiak, + Li, Heafield, Stone, El-Arini, Iyer, Malik, Chiu, Bhalla, Rantala-Yeary, + van~der Maaten, Chen, Tan, Jenkins, Martin, Madaan, Malo, Blecher, Landzaat, + de~Oliveira, Muzzi, Pasupuleti, Singh, Paluri, Kardas, Oldham, Rita, Pavlova, + Kambadur, Lewis, Si, Singh, Hassan, Goyal, Torabi, Bashlykov, Bogoychev, + Chatterji, Duchenne, Çelebi, Alrassy, Zhang, Li, Vasic, Weng, Bhargava, + Dubal, Krishnan, Koura, Xu, He, Dong, Srinivasan, Ganapathy, Calderer, + Cabral, Stojnic, Raileanu, Girdhar, Patel, Sauvestre, Polidoro, Sumbaly, + Taylor, Silva, Hou, Wang, Hosseini, Chennabasappa, Singh, Bell, Kim, Edunov, + Nie, Narang, Raparthy, Shen, Wan, Bhosale, Zhang, Vandenhende, Batra, + Whitman, Sootla, Collot, Gururangan, Borodinsky, Herman, Fowler, Sheasha, + Georgiou, Scialom, Speckbacher, Mihaylov, Xiao, Karn, Goswami, Gupta, + Ramanathan, Kerkez, Gonguet, Do, Vogeti, Petrovic, Chu, Xiong, Fu, Meers, + Martinet, Wang, Tan, Xie, Jia, Wang, Goldschlag, Gaur, Babaei, Wen, Song, + Zhang, Li, Mao, Coudert, Yan, Chen, Papakipos, Singh, Grattafiori, Jain, + Kelsey, Shajnfeld, Gangidi, Victoria, Goldstand, Menon, Sharma, Boesenberg, + Vaughan, Baevski, Feinstein, Kallet, Sangani, Yunus, Lupu, Alvarado, Caples, + Gu, Ho, Poulton, Ryan, Ramchandani, Franco, Saraf, Chowdhury, Gabriel, + Bharambe, Eisenman, Yazdan, James, Maurer, Leonhardi, Huang, Loyd, De~Paola, + Paranjape, Liu, Wu, Ni, Hancock, Wasti, Spence, Stojkovic, Gamido, Montalvo, + Parker, Burton, Mejia, Wang, Kim, Zhou, Hu, Chu, Cai, Tindal, Feichtenhofer, + Civin, Beaty, Kreymer, Li, Wyatt, Adkins, Xu, Testuggine, David, Parikh, + Liskovich, Foss, Wang, Le, Holland, Dowling, Jamil, Montgomery, Presani, + Hahn, Wood, Brinkman, Arcaute, Dunbar, Smothers, Sun, Kreuk, Tian, Ozgenel, + Caggioni, Guzmán, Kanayet, Seide, Florez, Schwarz, Badeer, Swee, Halpern, + Thattai, Herman, Sizov, Guangyi, Zhang, Lakshminarayanan, Shojanazeri, Zou, + Wang, Zha, Habeeb, Rudolph, Suk, Aspegren, Goldman, Damlaj, Molybog, Tufanov, + Veliche, Gat, Weissman, Geboski, Kohli, Asher, Gaya, Marcus, Tang, Chan, + Zhen, Reizenstein, Teboul, Zhong, Jin, Yang, Cummings, Carvill, Shepard, + McPhie, Torres, Ginsburg, Wang, Wu, U, Saxena, Prasad, Khandelwal, Zand, + Matosich, Veeraraghavan, Michelena, Li, Huang, Chawla, Lakhotia, Huang, Chen, + Garg, A, Silva, Bell, Zhang, Guo, Yu, Moshkovich, Wehrstedt, Khabsa, Avalani, + Bhatt, Tsimpoukelli, Mankus, Hasson, Lennie, Reso, Groshev, Naumov, Lathi, + Keneally, Seltzer, Valko, Restrepo, Patel, Vyatskov, Samvelyan, Clark, Macey, + Wang, Hermoso, Metanat, Rastegari, Bansal, Santhanam, Parks, White, Bawa, + Singhal, Egebo, Usunier, Laptev, Dong, Zhang, Cheng, Chernoguz, Hart, + Salpekar, Kalinli, Kent, Parekh, Saab, Balaji, Rittner, Bontrager, Roux, + Dollar, Zvyagina, Ratanchandani, Yuvraj, Liang, Alao, Rodriguez, Ayub, + Murthy, Nayani, Mitra, Li, Hogan, Battey, Wang, Maheswari, Howes, Rinott, + Bondu, Datta, Chugh, Hunt, Dhillon, Sidorov, Pan, Verma, Yamamoto, Ramaswamy, + Lindsay, Lindsay, Feng, Lin, Zha, Shankar, Zhang, Zhang, Wang, Agarwal, + Sajuyigbe, Chintala, Max, Chen, Kehoe, Satterfield, Govindaprasad, Gupta, + Cho, Virk, Subramanian, Choudhury, Goldman, Remez, Glaser, Best, Kohler, + Robinson, Li, Zhang, Matthews, Chou, Shaked, Vontimitta, Ajayi, Montanez, + Mohan, Kumar, Mangla, Albiero, Ionescu, Poenaru, Mihailescu, Ivanov, Li, + Wang, Jiang, Bouaziz, Constable, Tang, Wang, Wu, Wang, Xia, Wu, Gao, Chen, + Hu, Jia, Qi, Li, Zhang, Zhang, Adi, Nam, Yu, Wang, Hao, Qian, He, Rait, + DeVito, Rosnbrick, Wen, Yang, and Zhao]{dubey_llama_2024} +Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., + Mathur, A., Schelten, A., Yang, A., Fan, A., Goyal, A., Hartshorn, A., Yang, + A., Mitra, A., Sravankumar, A., Korenev, A., Hinsvark, A., Rao, A., Zhang, + A., Rodriguez, A., Gregerson, A., Spataru, A., Roziere, B., Biron, B., Tang, + B., Chern, B., Caucheteux, C., Nayak, C., Bi, C., Marra, C., McConnell, C., + Keller, C., Touret, C., Wu, C., Wong, C., Ferrer, C.~C., Nikolaidis, C., + Allonsius, D., Song, D., Pintz, D., Livshits, D., Esiobu, D., Choudhary, D., + Mahajan, D., Garcia-Olano, D., Perino, D., Hupkes, D., Lakomkin, E., + AlBadawy, E., Lobanova, E., Dinan, E., Smith, E.~M., Radenovic, F., Zhang, + F., Synnaeve, G., Lee, G., Anderson, G.~L., Nail, G., Mialon, G., Pang, G., + Cucurell, G., Nguyen, H., Korevaar, H., Xu, H., Touvron, H., Zarov, I., + Ibarra, I.~A., Kloumann, I., Misra, I., Evtimov, I., Copet, J., Lee, J., + Geffert, J., Vranes, J., Park, J., Mahadeokar, J., Shah, J., van~der Linde, + J., Billock, J., Hong, J., Lee, J., Fu, J., Chi, J., Huang, J., Liu, J., + Wang, J., Yu, J., Bitton, J., Spisak, J., Park, J., Rocca, J., Johnstun, J., + Saxe, J., Jia, J., Alwala, K.~V., Upasani, K., Plawiak, K., Li, K., Heafield, + K., Stone, K., El-Arini, K., Iyer, K., Malik, K., Chiu, K., Bhalla, K., + Rantala-Yeary, L., van~der Maaten, L., Chen, L., Tan, L., Jenkins, L., + Martin, L., Madaan, L., Malo, L., Blecher, L., Landzaat, L., de~Oliveira, L., + Muzzi, M., Pasupuleti, M., Singh, M., Paluri, M., Kardas, M., Oldham, M., + Rita, M., Pavlova, M., Kambadur, M., Lewis, M., Si, M., Singh, M.~K., Hassan, + M., Goyal, N., Torabi, N., Bashlykov, N., Bogoychev, N., Chatterji, N., + Duchenne, O., Çelebi, O., Alrassy, P., Zhang, P., Li, P., Vasic, P., Weng, + P., Bhargava, P., Dubal, P., Krishnan, P., Koura, P.~S., Xu, P., He, Q., + Dong, Q., Srinivasan, R., Ganapathy, R., Calderer, R., Cabral, R.~S., + Stojnic, R., Raileanu, R., Girdhar, R., Patel, R., Sauvestre, R., Polidoro, + R., Sumbaly, R., Taylor, R., Silva, R., Hou, R., Wang, R., Hosseini, S., + Chennabasappa, S., Singh, S., Bell, S., Kim, S.~S., Edunov, S., Nie, S., + Narang, S., Raparthy, S., Shen, S., Wan, S., Bhosale, S., Zhang, S., + Vandenhende, S., Batra, S., Whitman, S., Sootla, S., Collot, S., Gururangan, + S., Borodinsky, S., Herman, T., Fowler, T., Sheasha, T., Georgiou, T., + Scialom, T., Speckbacher, T., Mihaylov, T., Xiao, T., Karn, U., Goswami, V., + Gupta, V., Ramanathan, V., Kerkez, V., Gonguet, V., Do, V., Vogeti, V., + Petrovic, V., Chu, W., Xiong, W., Fu, W., Meers, W., Martinet, X., Wang, X., + Tan, X.~E., Xie, X., Jia, X., Wang, X., Goldschlag, Y., Gaur, Y., Babaei, Y., + Wen, Y., Song, Y., Zhang, Y., Li, Y., Mao, Y., Coudert, Z.~D., Yan, Z., Chen, + Z., Papakipos, Z., Singh, A., Grattafiori, A., Jain, A., Kelsey, A., + Shajnfeld, A., Gangidi, A., Victoria, A., Goldstand, A., Menon, A., Sharma, + A., Boesenberg, A., Vaughan, A., Baevski, A., Feinstein, A., Kallet, A., + Sangani, A., Yunus, A., Lupu, A., Alvarado, A., Caples, A., Gu, A., Ho, A., + Poulton, A., Ryan, A., Ramchandani, A., Franco, A., Saraf, A., Chowdhury, A., + Gabriel, A., Bharambe, A., Eisenman, A., Yazdan, A., James, B., Maurer, B., + Leonhardi, B., Huang, B., Loyd, B., De~Paola, B., Paranjape, B., Liu, B., Wu, + B., Ni, B., Hancock, B., Wasti, B., Spence, B., Stojkovic, B., Gamido, B., + Montalvo, B., Parker, C., Burton, C., Mejia, C., Wang, C., Kim, C., Zhou, C., + Hu, C., Chu, C.-H., Cai, C., Tindal, C., Feichtenhofer, C., Civin, D., Beaty, + D., Kreymer, D., Li, D., Wyatt, D., Adkins, D., Xu, D., Testuggine, D., + David, D., Parikh, D., Liskovich, D., Foss, D., Wang, D., Le, D., Holland, + D., Dowling, E., Jamil, E., Montgomery, E., Presani, E., Hahn, E., Wood, E., + Brinkman, E., Arcaute, E., Dunbar, E., Smothers, E., Sun, F., Kreuk, F., + Tian, F., Ozgenel, F., Caggioni, F., Guzmán, F., Kanayet, F., Seide, F., + Florez, G.~M., Schwarz, G., Badeer, G., Swee, G., Halpern, G., Thattai, G., + Herman, G., Sizov, G., Guangyi, Zhang, Lakshminarayanan, G., Shojanazeri, H., + Zou, H., Wang, H., Zha, H., Habeeb, H., Rudolph, H., Suk, H., Aspegren, H., + Goldman, H., Damlaj, I., Molybog, I., Tufanov, I., Veliche, I.-E., Gat, I., + Weissman, J., Geboski, J., Kohli, J., Asher, J., Gaya, J.-B., Marcus, J., + Tang, J., Chan, J., Zhen, J., Reizenstein, J., Teboul, J., Zhong, J., Jin, + J., Yang, J., Cummings, J., Carvill, J., Shepard, J., McPhie, J., Torres, J., + Ginsburg, J., Wang, J., Wu, K., U, K.~H., Saxena, K., Prasad, K., Khandelwal, + K., Zand, K., Matosich, K., Veeraraghavan, K., Michelena, K., Li, K., Huang, + K., Chawla, K., Lakhotia, K., Huang, K., Chen, L., Garg, L., A, L., Silva, + L., Bell, L., Zhang, L., Guo, L., Yu, L., Moshkovich, L., Wehrstedt, L., + Khabsa, M., Avalani, M., Bhatt, M., Tsimpoukelli, M., Mankus, M., Hasson, M., + Lennie, M., Reso, M., Groshev, M., Naumov, M., Lathi, M., Keneally, M., + Seltzer, M.~L., Valko, M., Restrepo, M., Patel, M., Vyatskov, M., Samvelyan, + M., Clark, M., Macey, M., Wang, M., Hermoso, M.~J., Metanat, M., Rastegari, + M., Bansal, M., Santhanam, N., Parks, N., White, N., Bawa, N., Singhal, N., + Egebo, N., Usunier, N., Laptev, N.~P., Dong, N., Zhang, N., Cheng, N., + Chernoguz, O., Hart, O., Salpekar, O., Kalinli, O., Kent, P., Parekh, P., + Saab, P., Balaji, P., Rittner, P., Bontrager, P., Roux, P., Dollar, P., + Zvyagina, P., Ratanchandani, P., Yuvraj, P., Liang, Q., Alao, R., Rodriguez, + R., Ayub, R., Murthy, R., Nayani, R., Mitra, R., Li, R., Hogan, R., Battey, + R., Wang, R., Maheswari, R., Howes, R., Rinott, R., Bondu, S.~J., Datta, S., + Chugh, S., Hunt, S., Dhillon, S., Sidorov, S., Pan, S., Verma, S., Yamamoto, + S., Ramaswamy, S., Lindsay, S., Lindsay, S., Feng, S., Lin, S., Zha, S.~C., + Shankar, S., Zhang, S., Zhang, S., Wang, S., Agarwal, S., Sajuyigbe, S., + Chintala, S., Max, S., Chen, S., Kehoe, S., Satterfield, S., Govindaprasad, + S., Gupta, S., Cho, S., Virk, S., Subramanian, S., Choudhury, S., Goldman, + S., Remez, T., Glaser, T., Best, T., Kohler, T., Robinson, T., Li, T., Zhang, + T., Matthews, T., Chou, T., Shaked, T., Vontimitta, V., Ajayi, V., Montanez, + V., Mohan, V., Kumar, V.~S., Mangla, V., Albiero, V., Ionescu, V., Poenaru, + V., Mihailescu, V.~T., Ivanov, V., Li, W., Wang, W., Jiang, W., Bouaziz, W., + Constable, W., Tang, X., Wang, X., Wu, X., Wang, X., Xia, X., Wu, X., Gao, + X., Chen, Y., Hu, Y., Jia, Y., Qi, Y., Li, Y., Zhang, Y., Zhang, Y., Adi, Y., + Nam, Y., Yu, Wang, Hao, Y., Qian, Y., He, Y., Rait, Z., DeVito, Z., + Rosnbrick, Z., Wen, Z., Yang, Z., and Zhao, Z. +\newblock The {Llama} 3 {Herd} of {Models}, August 2024. +\newblock URL \url{http://arxiv.org/abs/2407.21783}. +\newblock arXiv:2407.21783 [cs]. + +\bibitem[Paszke et~al.(2019)Paszke, Gross, Massa, Lerer, Bradbury, Chanan, + Killeen, Lin, Gimelshein, Antiga, Desmaison, Köpf, Yang, DeVito, Raison, + Tejani, Chilamkurthy, Steiner, Fang, Bai, and Chintala]{paszke_pytorch_2019} +Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, + T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., + DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., + Bai, J., and Chintala, S. +\newblock {PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} + {Learning} {Library}, December 2019. +\newblock URL \url{http://arxiv.org/abs/1912.01703}. +\newblock arXiv:1912.01703 [cs, stat]. + +\bibitem[Podell et~al.(2023)Podell, English, Lacey, Blattmann, Dockhorn, + Müller, Penna, and Rombach]{podell_sdxl_2023} +Podell, D., English, Z., Lacey, K., Blattmann, A., Dockhorn, T., Müller, J., + Penna, J., and Rombach, R. +\newblock {SDXL}: {Improving} {Latent} {Diffusion} {Models} for + {High}-{Resolution} {Image} {Synthesis}, July 2023. +\newblock URL \url{http://arxiv.org/abs/2307.01952}. +\newblock arXiv:2307.01952 [cs]. + +\bibitem[Sun et~al.(2023)Sun, Li, Geng, Stuijk, and + Corporaal]{sun_dissecting_2023} +Sun, W., Li, A., Geng, T., Stuijk, S., and Corporaal, H. +\newblock Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, + {Throughput} and {Numeric} {Behaviors}. +\newblock \emph{IEEE Transactions on Parallel and Distributed Systems}, + 34\penalty0 (1):\penalty0 246--261, January 2023. +\newblock ISSN 1045-9219, 1558-2183, 2161-9883. +\newblock \doi{10.1109/TPDS.2022.3217824}. +\newblock URL \url{https://ieeexplore.ieee.org/document/9931992/}. + +\bibitem[Tillet et~al.(2019)Tillet, Kung, and Cox]{tillet_triton_2019} +Tillet, P., Kung, H.~T., and Cox, D. +\newblock Triton: an intermediate language and compiler for tiled neural + network computations. +\newblock In \emph{Proceedings of the 3rd {ACM} {SIGPLAN} {International} + {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, pp.\ + 10--19, Phoenix AZ USA, June 2019. ACM. +\newblock ISBN 978-1-4503-6719-6. +\newblock \doi{10.1145/3315508.3329973}. +\newblock URL \url{https://dl.acm.org/doi/10.1145/3315508.3329973}. + +\end{thebibliography} diff --git a/iree/turbine/kernel/wave/docs/mlsys/tkw.bib b/iree/turbine/kernel/wave/docs/mlsys/tkw.bib new file mode 100644 index 00000000..da5bcf3d --- /dev/null +++ b/iree/turbine/kernel/wave/docs/mlsys/tkw.bib @@ -0,0 +1,111 @@ + +@inproceedings{tillet_triton_2019, + address = {Phoenix AZ USA}, + title = {Triton: an intermediate language and compiler for tiled neural network computations}, + isbn = {978-1-4503-6719-6}, + shorttitle = {Triton}, + url = {https://dl.acm.org/doi/10.1145/3315508.3329973}, + doi = {10.1145/3315508.3329973}, + abstract = {The validation and deployment of novel research ideas in the field of Deep Learning is often limited by the availability of efficient compute kernels for certain basic primitives. In particular, operations that cannot leverage existing vendor libraries (e.g., cuBLAS, cuDNN) are at risk of facing poor device utilization unless custom implementations are written by experts – usually at the expense of portability. For this reason, the development of new programming abstractions for specifying custom Deep Learning workloads at a minimal performance cost has become crucial.}, + language = {en}, + urldate = {2024-09-25}, + booktitle = {Proceedings of the 3rd {ACM} {SIGPLAN} {International} {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, + publisher = {ACM}, + author = {Tillet, Philippe and Kung, H. T. and Cox, David}, + month = jun, + year = {2019}, + pages = {10--19}, + file = {PDF:/Users/harsh/Zotero/storage/FMLLYK4M/Tillet et al. - 2019 - Triton an intermediate language and compiler for tiled neural network computations.pdf:application/pdf}, +} + +@misc{podell_sdxl_2023, + title = {{SDXL}: {Improving} {Latent} {Diffusion} {Models} for {High}-{Resolution} {Image} {Synthesis}}, + shorttitle = {{SDXL}}, + url = {http://arxiv.org/abs/2307.01952}, + abstract = {We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared to previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators. In the spirit of promoting open research and fostering transparency in large model training and evaluation, we provide access to code and model weights.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Podell, Dustin and English, Zion and Lacey, Kyle and Blattmann, Andreas and Dockhorn, Tim and Müller, Jonas and Penna, Joe and Rombach, Robin}, + month = jul, + year = {2023}, + note = {arXiv:2307.01952 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition}, + file = {PDF:/Users/harsh/Zotero/storage/ARJZQZ42/Podell et al. - 2023 - SDXL Improving Latent Diffusion Models for High-Resolution Image Synthesis.pdf:application/pdf}, +} + +@misc{dubey_llama_2024, + title = {The {Llama} 3 {Herd} of {Models}}, + url = {http://arxiv.org/abs/2407.21783}, + abstract = {Modern artificial intelligence (AI) systems are powered by foundation models. This paper presents a new set of foundation models, called Llama 3. It is a herd of language models that natively support multilinguality, coding, reasoning, and tool usage. Our largest model is a dense Transformer with 405B parameters and a context window of up to 128K tokens. This paper presents an extensive empirical evaluation of Llama 3. We find that Llama 3 delivers comparable quality to leading language models such as GPT-4 on a plethora of tasks. We publicly release Llama 3, including pre-trained and post-trained versions of the 405B parameter language model and our Llama Guard 3 model for input and output safety. The paper also presents the results of experiments in which we integrate image, video, and speech capabilities into Llama 3 via a compositional approach. We observe this approach performs competitively with the state-of-the-art on image, video, and speech recognition tasks. The resulting models are not yet being broadly released as they are still under development.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Dubey, Abhimanyu and Jauhri, Abhinav and Pandey, Abhinav and Kadian, Abhishek and Al-Dahle, Ahmad and Letman, Aiesha and Mathur, Akhil and Schelten, Alan and Yang, Amy and Fan, Angela and Goyal, Anirudh and Hartshorn, Anthony and Yang, Aobo and Mitra, Archi and Sravankumar, Archie and Korenev, Artem and Hinsvark, Arthur and Rao, Arun and Zhang, Aston and Rodriguez, Aurelien and Gregerson, Austen and Spataru, Ava and Roziere, Baptiste and Biron, Bethany and Tang, Binh and Chern, Bobbie and Caucheteux, Charlotte and Nayak, Chaya and Bi, Chloe and Marra, Chris and McConnell, Chris and Keller, Christian and Touret, Christophe and Wu, Chunyang and Wong, Corinne and Ferrer, Cristian Canton and Nikolaidis, Cyrus and Allonsius, Damien and Song, Daniel and Pintz, Danielle and Livshits, Danny and Esiobu, David and Choudhary, Dhruv and Mahajan, Dhruv and Garcia-Olano, Diego and Perino, Diego and Hupkes, Dieuwke and Lakomkin, Egor and AlBadawy, Ehab and Lobanova, Elina and Dinan, Emily and Smith, Eric Michael and Radenovic, Filip and Zhang, Frank and Synnaeve, Gabriel and Lee, Gabrielle and Anderson, Georgia Lewis and Nail, Graeme and Mialon, Gregoire and Pang, Guan and Cucurell, Guillem and Nguyen, Hailey and Korevaar, Hannah and Xu, Hu and Touvron, Hugo and Zarov, Iliyan and Ibarra, Imanol Arrieta and Kloumann, Isabel and Misra, Ishan and Evtimov, Ivan and Copet, Jade and Lee, Jaewon and Geffert, Jan and Vranes, Jana and Park, Jason and Mahadeokar, Jay and Shah, Jeet and van der Linde, Jelmer and Billock, Jennifer and Hong, Jenny and Lee, Jenya and Fu, Jeremy and Chi, Jianfeng and Huang, Jianyu and Liu, Jiawen and Wang, Jie and Yu, Jiecao and Bitton, Joanna and Spisak, Joe and Park, Jongsoo and Rocca, Joseph and Johnstun, Joshua and Saxe, Joshua and Jia, Junteng and Alwala, Kalyan Vasuden and Upasani, Kartikeya and Plawiak, Kate and Li, Ke and Heafield, Kenneth and Stone, Kevin and El-Arini, Khalid and Iyer, Krithika and Malik, Kshitiz and Chiu, Kuenley and Bhalla, Kunal and Rantala-Yeary, Lauren and van der Maaten, Laurens and Chen, Lawrence and Tan, Liang and Jenkins, Liz and Martin, Louis and Madaan, Lovish and Malo, Lubo and Blecher, Lukas and Landzaat, Lukas and de Oliveira, Luke and Muzzi, Madeline and Pasupuleti, Mahesh and Singh, Mannat and Paluri, Manohar and Kardas, Marcin and Oldham, Mathew and Rita, Mathieu and Pavlova, Maya and Kambadur, Melanie and Lewis, Mike and Si, Min and Singh, Mitesh Kumar and Hassan, Mona and Goyal, Naman and Torabi, Narjes and Bashlykov, Nikolay and Bogoychev, Nikolay and Chatterji, Niladri and Duchenne, Olivier and Çelebi, Onur and Alrassy, Patrick and Zhang, Pengchuan and Li, Pengwei and Vasic, Petar and Weng, Peter and Bhargava, Prajjwal and Dubal, Pratik and Krishnan, Praveen and Koura, Punit Singh and Xu, Puxin and He, Qing and Dong, Qingxiao and Srinivasan, Ragavan and Ganapathy, Raj and Calderer, Ramon and Cabral, Ricardo Silveira and Stojnic, Robert and Raileanu, Roberta and Girdhar, Rohit and Patel, Rohit and Sauvestre, Romain and Polidoro, Ronnie and Sumbaly, Roshan and Taylor, Ross and Silva, Ruan and Hou, Rui and Wang, Rui and Hosseini, Saghar and Chennabasappa, Sahana and Singh, Sanjay and Bell, Sean and Kim, Seohyun Sonia and Edunov, Sergey and Nie, Shaoliang and Narang, Sharan and Raparthy, Sharath and Shen, Sheng and Wan, Shengye and Bhosale, Shruti and Zhang, Shun and Vandenhende, Simon and Batra, Soumya and Whitman, Spencer and Sootla, Sten and Collot, Stephane and Gururangan, Suchin and Borodinsky, Sydney and Herman, Tamar and Fowler, Tara and Sheasha, Tarek and Georgiou, Thomas and Scialom, Thomas and Speckbacher, Tobias and Mihaylov, Todor and Xiao, Tong and Karn, Ujjwal and Goswami, Vedanuj and Gupta, Vibhor and Ramanathan, Vignesh and Kerkez, Viktor and Gonguet, Vincent and Do, Virginie and Vogeti, Vish and Petrovic, Vladan and Chu, Weiwei and Xiong, Wenhan and Fu, Wenyin and Meers, Whitney and Martinet, Xavier and Wang, Xiaodong and Tan, Xiaoqing Ellen and Xie, Xinfeng and Jia, Xuchao and Wang, Xuewei and Goldschlag, Yaelle and Gaur, Yashesh and Babaei, Yasmine and Wen, Yi and Song, Yiwen and Zhang, Yuchen and Li, Yue and Mao, Yuning and Coudert, Zacharie Delpierre and Yan, Zheng and Chen, Zhengxing and Papakipos, Zoe and Singh, Aaditya and Grattafiori, Aaron and Jain, Abha and Kelsey, Adam and Shajnfeld, Adam and Gangidi, Adithya and Victoria, Adolfo and Goldstand, Ahuva and Menon, Ajay and Sharma, Ajay and Boesenberg, Alex and Vaughan, Alex and Baevski, Alexei and Feinstein, Allie and Kallet, Amanda and Sangani, Amit and Yunus, Anam and Lupu, Andrei and Alvarado, Andres and Caples, Andrew and Gu, Andrew and Ho, Andrew and Poulton, Andrew and Ryan, Andrew and Ramchandani, Ankit and Franco, Annie and Saraf, Aparajita and Chowdhury, Arkabandhu and Gabriel, Ashley and Bharambe, Ashwin and Eisenman, Assaf and Yazdan, Azadeh and James, Beau and Maurer, Ben and Leonhardi, Benjamin and Huang, Bernie and Loyd, Beth and De Paola, Beto and Paranjape, Bhargavi and Liu, Bing and Wu, Bo and Ni, Boyu and Hancock, Braden and Wasti, Bram and Spence, Brandon and Stojkovic, Brani and Gamido, Brian and Montalvo, Britt and Parker, Carl and Burton, Carly and Mejia, Catalina and Wang, Changhan and Kim, Changkyu and Zhou, Chao and Hu, Chester and Chu, Ching-Hsiang and Cai, Chris and Tindal, Chris and Feichtenhofer, Christoph and Civin, Damon and Beaty, Dana and Kreymer, Daniel and Li, Daniel and Wyatt, Danny and Adkins, David and Xu, David and Testuggine, Davide and David, Delia and Parikh, Devi and Liskovich, Diana and Foss, Didem and Wang, Dingkang and Le, Duc and Holland, Dustin and Dowling, Edward and Jamil, Eissa and Montgomery, Elaine and Presani, Eleonora and Hahn, Emily and Wood, Emily and Brinkman, Erik and Arcaute, Esteban and Dunbar, Evan and Smothers, Evan and Sun, Fei and Kreuk, Felix and Tian, Feng and Ozgenel, Firat and Caggioni, Francesco and Guzmán, Francisco and Kanayet, Frank and Seide, Frank and Florez, Gabriela Medina and Schwarz, Gabriella and Badeer, Gada and Swee, Georgia and Halpern, Gil and Thattai, Govind and Herman, Grant and Sizov, Grigory and Guangyi and Zhang and Lakshminarayanan, Guna and Shojanazeri, Hamid and Zou, Han and Wang, Hannah and Zha, Hanwen and Habeeb, Haroun and Rudolph, Harrison and Suk, Helen and Aspegren, Henry and Goldman, Hunter and Damlaj, Ibrahim and Molybog, Igor and Tufanov, Igor and Veliche, Irina-Elena and Gat, Itai and Weissman, Jake and Geboski, James and Kohli, James and Asher, Japhet and Gaya, Jean-Baptiste and Marcus, Jeff and Tang, Jeff and Chan, Jennifer and Zhen, Jenny and Reizenstein, Jeremy and Teboul, Jeremy and Zhong, Jessica and Jin, Jian and Yang, Jingyi and Cummings, Joe and Carvill, Jon and Shepard, Jon and McPhie, Jonathan and Torres, Jonathan and Ginsburg, Josh and Wang, Junjie and Wu, Kai and U, Kam Hou and Saxena, Karan and Prasad, Karthik and Khandelwal, Kartikay and Zand, Katayoun and Matosich, Kathy and Veeraraghavan, Kaushik and Michelena, Kelly and Li, Keqian and Huang, Kun and Chawla, Kunal and Lakhotia, Kushal and Huang, Kyle and Chen, Lailin and Garg, Lakshya and A, Lavender and Silva, Leandro and Bell, Lee and Zhang, Lei and Guo, Liangpeng and Yu, Licheng and Moshkovich, Liron and Wehrstedt, Luca and Khabsa, Madian and Avalani, Manav and Bhatt, Manish and Tsimpoukelli, Maria and Mankus, Martynas and Hasson, Matan and Lennie, Matthew and Reso, Matthias and Groshev, Maxim and Naumov, Maxim and Lathi, Maya and Keneally, Meghan and Seltzer, Michael L. and Valko, Michal and Restrepo, Michelle and Patel, Mihir and Vyatskov, Mik and Samvelyan, Mikayel and Clark, Mike and Macey, Mike and Wang, Mike and Hermoso, Miquel Jubert and Metanat, Mo and Rastegari, Mohammad and Bansal, Munish and Santhanam, Nandhini and Parks, Natascha and White, Natasha and Bawa, Navyata and Singhal, Nayan and Egebo, Nick and Usunier, Nicolas and Laptev, Nikolay Pavlovich and Dong, Ning and Zhang, Ning and Cheng, Norman and Chernoguz, Oleg and Hart, Olivia and Salpekar, Omkar and Kalinli, Ozlem and Kent, Parkin and Parekh, Parth and Saab, Paul and Balaji, Pavan and Rittner, Pedro and Bontrager, Philip and Roux, Pierre and Dollar, Piotr and Zvyagina, Polina and Ratanchandani, Prashant and Yuvraj, Pritish and Liang, Qian and Alao, Rachad and Rodriguez, Rachel and Ayub, Rafi and Murthy, Raghotham and Nayani, Raghu and Mitra, Rahul and Li, Raymond and Hogan, Rebekkah and Battey, Robin and Wang, Rocky and Maheswari, Rohan and Howes, Russ and Rinott, Ruty and Bondu, Sai Jayesh and Datta, Samyak and Chugh, Sara and Hunt, Sara and Dhillon, Sargun and Sidorov, Sasha and Pan, Satadru and Verma, Saurabh and Yamamoto, Seiji and Ramaswamy, Sharadh and Lindsay, Shaun and Lindsay, Shaun and Feng, Sheng and Lin, Shenghao and Zha, Shengxin Cindy and Shankar, Shiva and Zhang, Shuqiang and Zhang, Shuqiang and Wang, Sinong and Agarwal, Sneha and Sajuyigbe, Soji and Chintala, Soumith and Max, Stephanie and Chen, Stephen and Kehoe, Steve and Satterfield, Steve and Govindaprasad, Sudarshan and Gupta, Sumit and Cho, Sungmin and Virk, Sunny and Subramanian, Suraj and Choudhury, Sy and Goldman, Sydney and Remez, Tal and Glaser, Tamar and Best, Tamara and Kohler, Thilo and Robinson, Thomas and Li, Tianhe and Zhang, Tianjun and Matthews, Tim and Chou, Timothy and Shaked, Tzook and Vontimitta, Varun and Ajayi, Victoria and Montanez, Victoria and Mohan, Vijai and Kumar, Vinay Satish and Mangla, Vishal and Albiero, Vítor and Ionescu, Vlad and Poenaru, Vlad and Mihailescu, Vlad Tiberiu and Ivanov, Vladimir and Li, Wei and Wang, Wenchen and Jiang, Wenwen and Bouaziz, Wes and Constable, Will and Tang, Xiaocheng and Wang, Xiaofang and Wu, Xiaojian and Wang, Xiaolan and Xia, Xide and Wu, Xilun and Gao, Xinbo and Chen, Yanjun and Hu, Ye and Jia, Ye and Qi, Ye and Li, Yenda and Zhang, Yilin and Zhang, Ying and Adi, Yossi and Nam, Youngjin and Yu and Wang and Hao, Yuchen and Qian, Yundi and He, Yuzi and Rait, Zach and DeVito, Zachary and Rosnbrick, Zef and Wen, Zhaoduo and Yang, Zhenyu and Zhao, Zhiwei}, + month = aug, + year = {2024}, + note = {arXiv:2407.21783 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition, Computer Science - Computation and Language}, + file = {PDF:/Users/harsh/Zotero/storage/BQKY8VZZ/Dubey et al. - 2024 - The Llama 3 Herd of Models.pdf:application/pdf}, +} + +@article{sun_dissecting_2023, + title = {Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, {Throughput} and {Numeric} {Behaviors}}, + volume = {34}, + copyright = {https://ieeexplore.ieee.org/Xplorehelp/downloads/license-information/IEEE.html}, + issn = {1045-9219, 1558-2183, 2161-9883}, + shorttitle = {Dissecting {Tensor} {Cores} via {Microbenchmarks}}, + url = {https://ieeexplore.ieee.org/document/9931992/}, + doi = {10.1109/TPDS.2022.3217824}, + abstract = {Tensor Cores have been an important unit to accelerate Fused Matrix Multiplication Accumulation (MMA) in all NVIDIA GPUs since Volta Architecture. To program Tensor Cores, users have to use either legacy wmma APIs or current mma APIs. Legacy wmma APIs are more easy-to-use but can only exploit limited features and power of Tensor Cores. Specifically, wmma APIs support fewer operand shapes and can not leverage the new sparse matrix multiplication feature of the newest Ampere Tensor Cores. However, the performance of current programming interface has not been well explored. Furthermore, the computation numeric behaviors of lowprecision floating points (TF32, BF16, and FP16) supported by the newest Ampere Tensor Cores are also mysterious. In this paper, we explore the throughput and latency of current programming APIs. We also intuitively study the numeric behaviors of Tensor Cores MMA and profile the intermediate operations including multiplication, addition of inner product, and accumulation. All codes used in this work can be found in https://github.com/sunlex0717/DissectingTensorCores.}, + language = {en}, + number = {1}, + urldate = {2024-09-25}, + journal = {IEEE Transactions on Parallel and Distributed Systems}, + author = {Sun, Wei and Li, Ang and Geng, Tong and Stuijk, Sander and Corporaal, Henk}, + month = jan, + year = {2023}, + pages = {246--261}, + file = {PDF:/Users/harsh/Zotero/storage/NZD3FJUB/Sun et al. - 2023 - Dissecting Tensor Cores via Microbenchmarks Latency, Throughput and Numeric Behaviors.pdf:application/pdf}, +} + +@misc{paszke_pytorch_2019, + title = {{PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} {Learning} {Library}}, + shorttitle = {{PyTorch}}, + url = {http://arxiv.org/abs/1912.01703}, + abstract = {Deep learning frameworks have often focused on either usability or speed, but not both. PyTorch is a machine learning library that shows that these two goals are in fact compatible: it provides an imperative and Pythonic programming style that supports code as a model, makes debugging easy and is consistent with other popular scientific computing libraries, while remaining efficient and supporting hardware accelerators such as GPUs.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and Köpf, Andreas and Yang, Edward and DeVito, Zach and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith}, + month = dec, + year = {2019}, + note = {arXiv:1912.01703 [cs, stat]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Statistics - Machine Learning}, + annote = {Comment: 12 pages, 3 figures, NeurIPS 2019}, + file = {PDF:/Users/harsh/Zotero/storage/D72HUVME/Paszke et al. - 2019 - PyTorch An Imperative Style, High-Performance Deep Learning Library.pdf:application/pdf}, +} + +@misc{chetlur_cudnn_2014, + title = {{cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}}, + shorttitle = {{cuDNN}}, + url = {http://arxiv.org/abs/1410.0759}, + doi = {10.48550/arXiv.1410.0759}, + abstract = {We present a library of efficient implementations of deep learning primitives. Deep learning workloads are computationally intensive, and optimizing their kernels is difficult and time-consuming. As parallel architectures evolve, kernels must be reoptimized, which makes maintaining codebases difficult over time. Similar issues have long been addressed in the HPC community by libraries such as the Basic Linear Algebra Subroutines (BLAS). However, there is no analogous library for deep learning. Without such a library, researchers implementing deep learning workloads on parallel processors must create and optimize their own implementations of the main computational kernels, and this work must be repeated as new parallel processors emerge. To address this problem, we have created a library similar in intent to BLAS, with optimized routines for deep learning workloads. Our implementation contains routines for GPUs, although similarly to the BLAS library, these routines could be implemented for other platforms. The library is easy to integrate into existing frameworks, and provides optimized performance and memory usage. For example, integrating cuDNN into Caffe, a popular framework for convolutional networks, improves performance by 36\% on a standard model while also reducing memory consumption.}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Chetlur, Sharan and Woolley, Cliff and Vandermersch, Philippe and Cohen, Jonathan and Tran, John and Catanzaro, Bryan and Shelhamer, Evan}, + month = dec, + year = {2014}, + note = {arXiv:1410.0759 [cs]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Computer Science - Neural and Evolutionary Computing}, +} + +@article{reed2022torch, + title={torch. fx: Practical program capture and transformation for deep learning in python}, + author={Reed, James and DeVito, Zachary and He, Horace and Ussery, Ansley and Ansel, Jason}, + journal={Proceedings of Machine Learning and Systems}, + volume={4}, + pages={638--651}, + year={2022} +} diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex b/iree/turbine/kernel/wave/docs/mlsys/tkw.tex similarity index 54% rename from shark_turbine/kernel/wave/docs/mlsys/tkw.tex rename to iree/turbine/kernel/wave/docs/mlsys/tkw.tex index cb56cab1..81cb5111 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex +++ b/iree/turbine/kernel/wave/docs/mlsys/tkw.tex @@ -20,6 +20,65 @@ % Use the following line for the initial blind version submitted for review: \usepackage{mlsys2024} +% For code listings +\usepackage{sourcecodepro} +\usepackage[T1]{fontenc} +\usepackage{listings} +\usepackage[dvipsnames]{xcolor} +\definecolor{commentgreen}{RGB}{2,112,10} +\definecolor{eminence}{RGB}{108,48,130} +\definecolor{weborange}{RGB}{255,165,0} +\definecolor{frenchplum}{RGB}{129,20,83} +%Define Colors +\definecolor{gray}{RGB}{102,102,102} %#666666 +\definecolor{lightblue}{RGB}{0,102,153} %#006699 +\definecolor{lightgreen}{RGB}{102,153,0} %#669900 +\definecolor{bluegreen}{RGB}{51,153,126} %#33997e +\definecolor{magenta}{RGB}{217,74,122} %#d94a7a +\definecolor{orange}{RGB}{226,102,26} %#e2661a +\definecolor{purple}{RGB}{125,71,147} %#7d4793 +\definecolor{green}{RGB}{113,138,98} %#718a62 + +\usepackage{tikz} +\usetikzlibrary{positioning, fit, backgrounds} + +\usepackage[framemethod=tikz]{mdframed} +\lstdefinelanguage{Wave}{ + language=Python, + classoffset=1, + morekeywords={WorkgroupConstraint, TilingConstraint, WaveConstraint, HardwareConstraint}, + keywordstyle=\color{lightblue}, + classoffset=2, + morekeywords={Memory, Register}, + keywordstyle=\color{lightgreen}, + classoffset=3, + morekeywords={reduction, read, write, mma}, + keywordstyle=\color{magenta}, + classoffset=4, + morekeywords={@wave, @reduction}, + keywordstyle=\color{orange}, + sensitive=false, % keywords are not case-sensitive +} + +\lstset{ + language={Wave}, + basicstyle={\scriptsize\ttfamily}, + identifierstyle={\scriptsize\ttfamily}, + commentstyle={\scriptsize\itshape\ttfamily}, + keywordstyle={\scriptsize\bfseries\ttfamily}, + ndkeywordstyle={\scriptsize\ttfamily}, + stringstyle={\scriptsize\ttfamily}, + frame={tb}, + breaklines=true, + columns=[l]{fullflexible}, + xrightmargin=0em, + xleftmargin=0em, + numberstyle={\scriptsize}, + stepnumber=1, + numbersep=1em, + lineskip=-0.5ex, +} + % If accepted, instead use the following line for the camera-ready submission: % \usepackage[accepted]{mlsys2024} @@ -30,7 +89,7 @@ \begin{document} \twocolumn[ -\mlsystitle{Submission and Formatting Instructions for MLSys 2024} +\mlsystitle{Wave : A Symbolic Python DSL and Compiler for High Performance Machine Learning} % It is OKAY to include author information, even for blind % submissions: the style file will automatically remove it for you @@ -94,6 +153,311 @@ %\printAffiliationsAndNotice{} % leave blank if no need to mention equal contribution \printAffiliationsAndNotice{\mlsysEqualContribution} % otherwise use the standard text. +\section{Introduction} +Generative models have seen tremendous success in a wide variety of +domains ranging from image generation to natural language processing and beyond. +\cite{podell_sdxl_2023,dubey_llama_2024}. Much of this success is being +driven by graphics processing units (GPUs) which while originally +designed for graphics, are now being optimized for machine learning. +Both datacenter and consumer grade GPUs feature powerful matrix multiplication hardware units +and specialized instructions to enable high performance inference and training \cite{sun_dissecting_2023}. +\\ \\ +Given the importance of GPUs in machine learning, significant +effort has been put into developing frameworks that allow developers to +write high performance machine learning models with a low barrier to entry. Frameworks such +as Pytorch \cite{paszke_pytorch_2019} have become extremely popular +because they expose a Python based approach to programming GPUs. Prior +to the advent of these frameworks, developers had to write CUDA or OpenCL +kernels by hand which required significant expertise to achieve +good performance and did not scale well to new operators. +\\ \\ +Under the hood, these machine learning frameworks rely heavily +on vendor-specific libraries such as cuDNN \cite{chetlur_cudnn_2014} to achieve high performance. +These libraries are performant but are black boxes consisting of +hand-written kernels and often do not support the full set of +operators encountered in machine learning models. +To address these limitations, recent work has focused on developing +Python domain specific languages (DSL) that allow developers to get high performance +while reducing the kernel complexity. Triton \cite{tillet_triton_2019}. +is a popular Python DSL that exposes a workgroup level programming +model and allows developers to author high performance kernels. +In the programmability versus performance tradeoff, Triton demonstrated that it is possible to +achieve high performance while maintaining a high level of programmability. However, +Triton kernels often get quite complex as the kernel complexity grows. Most of this complexity +comes from exposing a pointer based approach to access and manipulate memory. +\\ \\ +In this paper, we introduce Wave, a Python DSL and compiler for high performance machine learning. +Wave exposes a subgroup (wave or warp) level programming model and uses constraints +to specify the distribution strategy for the kernel. This allows for a separation between +the kernel and distribution strategy and results in simpler kernels. The language +and compiler make extensive use of symbolic data types to represent tensor shapes and memory access patterns +that make it easier to reason about the kernel. We demonstrate that Wave can achieve competitive performance +with Triton and hand-tuned kernels on core machine learning operators such as matrix multiplication, +convolutions and attention. +In summary, the contributions of this paper are as follows: +\begin{itemize} + \item \textbf{Wave language} (Section \ref{section:wave_language}): A Python DSL that exposes a subgroup programming model for GPUs. The language + defines constraints that separate distribution strategies from the description of the core computation. Tensor shapes and address spaces + are represented using symbolic types (using sympy). + \item \textbf{Wave compiler} (Section \ref{section:wave_compiler}): A Python compiler that uses symbolic types + to represent memory access patterns and reason about them. The compiler uses torch.fx to trace the kernel, + then runs a series of compiler optimization passes and finally lowers the computation graph to MLIR and LLVM. + \item \textbf{Numerical Experiments} (Section \ref{section:numerical_experiments}): Numerical experiments on + matrix multiplication, convolutions and attention that demonstrate the performance of Wave kernels and show + that it is on par with existing DSLs and hand-tuned libraries. + +\end{itemize} + +\section{Wave Language} +In this section, we will go through the Wave language and its features using matrix multiplication as an example. See Listing \ref{lst:gemm} for the full code listing. + +\label{section:wave_language} +\subsection{Wave Programming Model} +Wave programs follow the single-program multiple data (SPMD) programming model where the kernel is written +at the level of execution of a single wave or warp. While waves are core to how programs are executed on +GPUs, most GPU languages do not expose them directly to the developer. CUDA, for example, allows workgroup +and thread level programming but does not expose the wave level while Triton only exposes workgroup level programming. +The advantages of the wave programming model are that it allows developers to write kernels at the same level of +abstraction as the native hardware matrix multiply accumulate (MMA) instructions which operate at the granularity of waves +giving them low-level control from a high-level abstraction. +% Should we compare to the Triton programming model here? +% Possibly we could have roughly something like figure 3 of the Triton paper + +\subsection{Syntax \& Semantics} +The Wave language partitions programs into two distinct regions as can be seen in Listing \ref{lst:gemm}. +The first part of the program consists of constraints which are new constructs introduced by the language. + +\subsubsection{Constraints} +Constraints are used to represent the distribution strategy of a kernel. Each constraint operates on a particular +dimension and specifies how that dimension is to be distributed. In the matrix multiplication example of Listing \ref{lst:gemm}, +the \texttt{WorkgroupConstraint} on symbolic dimension \texttt{M/N} states that the \texttt{M/N} dimension is distributed among work group dimension 0/1 +with a tile size of \texttt{BLOCK\_M/BLOCK\_N}. The \texttt{WaveConstraint} on the \texttt{M/N} dimension states that the $M/N$ dimension is then further distributed among waves +with a tile size of \texttt{BLOCK\_M / 2 / BLOCK\_N / 2}. The \texttt{TilingConstraint} on \texttt{K} specifies that the \texttt{K} dimension is tiled with a tile size of \texttt{BLOCK\_K} +in a sequential for loop. Finally, the \texttt{HardwareConstraint} specifies hardware specific parameters such as the number of threads per wave, the number of waves per block and +the canonical shape of the program. +\\ \\ +The canonical shape of the program specifies the minimum granularity of the operations in the program. In the matrix multiplication example and for programs using MMA instructions, the canonical shape is +the shape of the MMA instruction which is \texttt{M = 16, N = 16, K = 16}. For programs that do not use MMA instructions, users can +explicitly specify the canonical shape of the program in the \texttt{HardwareConstraint} by using the \texttt{vector\_shapes} keyword. For more examples of this, +see Appendix \ref{appendix:samples}. +\\ \\ +The constraints of the program serve multiple purposes. First, they separate out the distribution strategy from the kernel. +Second, they result in much simpler kernels because kernel authors do not need to keep track the offsets to different memory locations +as the compiler takes care of this. + +\subsubsection{Kernel} +The kernel is the second part of the program and consists of the core computation. It is annotated with the \texttt{@wave} decorator which +is used by the compiler to trace the kernel. In the matrix multiplication example, the inputs to the kernel are of type \texttt{Memory} which +represents a memory buffer with a symbolic shape and address space (shared memory or global memory). In the kernel, +even though the inputs are specified as type \texttt{Memory} with shape \texttt{[M, K], [N, K], [M, N]}, the actual shape of the memory +is determined by the constraints. In order to simplify the kernel, we write the kernel using the original symbolic shapes and let the compiler +determine the actual shapes of the memory buffers. +% Note: The question that comes up with this phrasing: Could these shapes then just be left out. In reality they are at least important to figure out the indexing dimensions. + +\begin{lstlisting}[language=Wave, frame=single, breaklines, caption={Mixed-precision $C = A \times B^{T}$ expressed in Wave.}, captionpos=b, label={lst:gemm}] +constraints = [WorkgroupConstraint(M, BLOCK_M, 0)] +constraints += [WorkgroupConstraint(N, BLOCK_N, 1)] +constraints += [TilingConstraint(K, BLOCK_K)] +constraints += [WaveConstraint(M, BLOCK_M / 2)] +constraints += [WaveConstraint(N, BLOCK_N / 2)] +constraints += [ + HardwareConstraint(threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=MMAType.F32_16x16x16_F16 +] + +@wave(constraints) +def gemm( + a: Memory[M, K, ADDRESS_SPACE, f16], + b: Memory[N, K, ADDRESS_SPACE, f16], + c: Memory[M, N, GLOBAL_ADDRESS_SPACE, f32], +): + c_reg = Register[M, N, f32](0.0) + + @reduction(K, init_args=[c_reg]) + def loop(acc: Register[M, N, f32]) -> Register[M, N, f32]: + a_reg = read(a, elements_per_thread=ELEMS) + b_reg = read(b, elements_per_thread=ELEMS) + acc = mma(a_reg, b_reg, acc) + return acc + + write(loop, c, elements_per_thread=ELEMS) +\end{lstlisting} + +\newpage + +\section{Wave Compiler} +\label{section:wave_compiler} +The Wave compiler is a Python-based compiler designed to process and optimize kernels written in the Wave language. It leverages symbolic types to represent memory access patterns and perform reasoning about them. The compilation process involves several key steps: + +\begin{tikzpicture}[node distance=0.5cm, auto] + % Define styles + \tikzstyle{block} = [rectangle, draw, fill=blue!20, + text width=5em, text centered, rounded corners, minimum height=3em] + \tikzstyle{mlir} = [rectangle, draw, fill=gray!20, + text width=20em, text centered, minimum height=7em] + \tikzstyle{wave_compiler} = [rectangle, draw, fill=gray!20, + text width=20em, text centered, minimum height=19em] + \tikzstyle{dialect} = [rectangle, draw, fill=yellow!20, + text width=5em, text centered, rounded corners, minimum height=2em] + \tikzstyle{line} = [draw, ->] + + % Place nodes + + \node [text width=10em, text centered] (wave) {Wave Kernel}; + + \node [wave_compiler, below=of wave] (wave_compiler) {}; + \node [anchor=north west, inner sep=0.5em] at (wave_compiler.north west) {Wave Compiler}; + + \node [block, below=1cm of wave] (tracing) {Tracing (torch.fx)}; + \node [block, below=of tracing] (symbolic) {Symbolic Analysis}; + \node [block, below=of symbolic] (optimization) {Optimization Passes}; + \node [block, below=of optimization] (mlir_codegen) {MLIR Codegen}; + + % MLIR box with dialects + \node [mlir, below=of mlir_codegen] (mlir) {}; + \node [anchor=north west, inner sep=0.5em] at (mlir.north west) {MLIR}; + \node [dialect, left=3em of mlir.center] (dialect1) {amdgpu Dialect}; + \node [dialect, at=(mlir.center)] (gpu_dialect) {GPU Dialect}; + \node [dialect, right=3em of mlir.center] (llvm_dialect) {LLVM Dialect}; + \node [block, below=of mlir] (llvm) {LLVM IR}; + + % Draw edges + \path [line] (wave) -- (tracing); + \path [line] (tracing) -- (symbolic); + \path [line] (symbolic) -- (optimization); + \path [line] (optimization) -- (mlir_codegen); + \path [line] (mlir_codegen) -- (dialect1); + \path [line] (mlir_codegen) -- (gpu_dialect); + \path [line] (mlir_codegen) -- (llvm_dialect); + + \path [line] (mlir) -- (llvm); + + % Add a note about constraints + \node [text width=5em, text centered, below right=-0.25cm and 0.5cm of symbolic] (constraints) {Constraint Processing}; + \draw [dashed, ->] (constraints) -- (symbolic); + \draw [dashed, ->] (constraints) -- (optimization); +\end{tikzpicture} + +\subsection{Tracing with torch.fx} + +The Wave compiler utilizes torch.fx~\cite{reed2022torch} for symbolic tracing of the kernel. This process involves executing the Python kernel program with special \emph{Proxy} objects that act as placeholders for actual values. As the program runs, torch.fx records all definitions and function calls, effectively capturing the computational structure of the kernel. The result is a comprehensive representation of the kernel's logic in the form of the torch.fx graph IR. By leveraging torch.fx, Wave strikes a balance between the flexibility of a Python-embedded DSL and the power of a custom compiler, allowing for specialized optimizations while maintaining an accessible and extensible programming model. + +\subsection{Intermediate Representation} +We extended the torch.fx intermediate representation with custom primitives and types to accurately model the Wave programming model. Our enhancements include custom primitives representing Wave-specific constructs such as wave-level operations and memory accesses. Representing them explicitly in the IR simplifies type inference and enables easier transformations on the IR. At the same time we keep compatibility with torch.fx IR in order to reuse existing tooling for e.g. visualization. + + +\subsection{Lowering Wave programming to thread programming} +% Note: A description of the programming model will already be in the previous chapter. So I can take that as a given. +A wave kernel is expressed following an SPMD programming model in the granularity of a single wave. While the target programming model for the output IR follows SPMD as well, it operates on the granularity of a single thread with explicit data movement and synchronization. +In consequence the computation graph needs to be expanded according to the input sizes and the distribution of data to threads to model the instructions for each thread. +\smallskip +As each thread has to load different data depending on its position in the launch grid %or do we call this differently? +we prepend the IR with operations to get the \texttt{\footnotesize thread\_id} and \texttt{\footnotesize workgroup\_id} for the relevant dimensions. + +First, we determine to which level each dimension has to expand according to the input sizes and constraints. +% small example? +We start at the leaf nodes of the kernel and follow def-uses upward until we reach the kernel inputs. We each node we reach we: +\begin{enumerate} + \item Determine the dimensions this node needs to index in + \item \ldots +\end{enumerate} +% TODO: Preliminary, maybe better to just express this as an algorihm? + +% TODO: Can we produce a figure (or graph) of pre-expansion and post-expansion? Maybe only expansion in a single dimension if this gets too large otherwise? + + +% Thought: +% We name this programming model \emph{SIWT} (Single Instruction, Wave of Threads) denoting that a single instruction is executed by a wave of threads. This is inspired by the SIMT execution model where instructions are executed by all threads in lockstep. + + +\subsection{Instruction Scheduling} +% Describe the deep type of decisions we can take already on this level with the vast information we still have available + + +\subsection{Optimization Passes} +After tracing, the compiler executes a series of optimization passes on the computational graph, such as: +\textbf{barrier insertion} % We could give more specifics here on when we insert barriers +\ldots + +\subsection{Lowering to MLIR} + +% TODO: mention symbolic types, their optimization and lowering here +% Also mention for which sizes we use them: Tensor shapes, ... +% basically the sympy walker we have. + +The final stage of the compilation process involves lowering the optimized computational graph to several MLIR dialects: % TODO: possibly present this differently, for now a list is fine. +%We target the amdgpu and gpu dialects to model GPU + +% Possibly simplify if this is too specific +\begin{itemize} + \item Intrinsics used in the kernel are directly mapped to operations of the \texttt{\footnotesize amdgpu} dialect + \item The \texttt{\footnotesize gpu} dialect is used to model general GPU concepts, such as \texttt{\footnotesize thread\_id} + \item The \texttt{\footnotesize scf} dialect is used to model loops + \item The \texttt{\footnotesize llvm} dialect is used to emit scheduling barriers that preserve our instruction scheduling decisions + \item We use the \texttt{\footnotesize scf}, \texttt{\footnotesize arith}, and \texttt{\footnotesize math} dialects to model loops and arithmetic. + \item Furthermore we use the \texttt{\footnotesize memref} and \texttt{\footnotesize func} dialects +\end{itemize} + + + +% mention LLVM scheduling barriers + +\subsection{Integration} +TODO: briefly describe integration into Pytorch (compile to vmfb + call with torch tensors) \& IREE (\ldots)? + + + +\bigskip +In summary, the Wave compiler combines symbolic computation, multi-stage lowering, and GPU-specific optimizations to translate high-level Wave kernels into efficient, low-level code suitable for execution on GPUs. Its use of symbolic types and constraint-based programming model allows for powerful optimizations while maintaining a high level of abstraction for kernel developers. + +\section{Numerical Experiments} +\label{section:numerical_experiments} + +\section{Related Work} +\label{section:related_work} + +\section{Conclusions \& Future Work} +\label{section:conclusions} + +\section{Acknowledgements} +\label{section:acknowledgements} + +\section{Appendix: Sample Wave Programs} +\label{section:samples} + + +\iffalse +It has a Python based compiler that uses torch.fx tracing to define +and trace operators written in the language. The torch.fx graphs are then run through a series of optimization passes +on the computation graph and are finally lowered to MLIR and subsequently LLVM. This code generation flow allows compiler writers +to blend high productivity in Python with high performance from the MLIR and LLVM +code generation flow. +\\ \\ + + +\section{Memory Access Patterns} +We represent memory access patterns in the language using the standard +triplet notation consisting of an offset, number of elements, and absolute stride and associate +a triplet with each tensor dimension. The memory access pattern for a given operation +is determined by the access patterns of the operands of the operation as well as +the user-specified constraints. For example, the memory access pattern for the output +of an elementwise operation is determined from the access patterns of the inputs, +whereas for a matrix-multiply accumulate operation, the memory access patterns of the operands are specified by +the hardware constraint. +\\ \\ +One of the advantages of the dimension based specification is that it obviates +the need for any propagation of memory access patterns through the computation graph, +as is commonly done in other frameworks. When setting the access pattern for a specific +dimension of a tensor, the access pattern is taken to be the union of all possible +access patterns with the determination of which access pattern to use based on +the minimization of an appropriate metric across the entire graph (see Section 3). + +\fi + +\newpage + +\iffalse \section{Electronic Submission} \label{submission} @@ -527,8 +891,9 @@ \section*{Acknowledgements} % In the unusual situation where you want a paper to appear in the % references without citing it in the main text, use \nocite \nocite{langley00} +\fi -\bibliography{example_paper} +\bibliography{tkw} \bibliographystyle{mlsys2024} diff --git a/shark_turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py similarity index 84% rename from shark_turbine/kernel/wave/expansion.py rename to iree/turbine/kernel/wave/expansion.py index b96f778c..75cc5551 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -19,15 +19,15 @@ from .._support.indexing import IndexingContext, IndexSequence from ...support.logging import get_logger from .._support.tracing import CapturedTrace -from .utils import get_mma_dimensional_mapping +from .utils import get_mma_dimensional_mapping, specialize_index_sequence from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") -# This represents a mapping of a node + indexing into the dimensions to the -# corresponding expanded node in these specific dimensions. An example for a -# record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)) -> read_0_0_1 +# This represents a mapping of a node + indexing + res_idx(output index for op with multiple results) +# of node into the dimensions to the corresponding expanded node in these specific dimensions. +# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1), 0) -> read_0_0_1. ExpandedNodeMap: TypeAlias = dict[ - tuple[CustomOp, tuple[tuple[IndexSymbol, int], ...]], CustomOp + tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp ] @@ -81,6 +81,11 @@ def get_indexed_dims( """ if isinstance(nodeOrDims, CustomOp): nodeOrDims = nodeOrDims.indexing_dims + # Flatten dims for node with multiple values or expanded Reduction. + if all(isinstance(el, Sequence) for el in nodeOrDims): + flattened_dims = list(itertools.chain.from_iterable(nodeOrDims)) + flatten_dims_set = dict.fromkeys(flattened_dims) + nodeOrDims = list(flatten_dims_set) return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims) @@ -141,6 +146,7 @@ def compute_stride( def set_node_index( constraints: Sequence[Constraint], mma_index: dict[IndexSymbol, int], + mma_slices: dict[IndexSymbol, list[fx.Node]], dim_tile_size: dict[IndexSymbol, int], custom: CustomOp, dim_scaling: dict[IndexSymbol, int], @@ -171,11 +177,7 @@ def set_node_index( for dim in custom.indexing_dims: index_seq = None for constraint in sorted_constraints: - mma_check = ( - isinstance(constraint, HardwareConstraint) - and dim in mma_index - and isinstance(custom, MMA) - ) + mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index vector_check = ( isinstance(constraint, HardwareConstraint) @@ -217,6 +219,8 @@ def set_node_index( index_seq = constraint.apply( constraint_index, dim, elements_per_thread, stride ) + if mma_index: + index_seq = specialize_index_sequence(index_seq, mma_slices, custom) else: if index_seq is None: @@ -246,10 +250,10 @@ def expand_graph( dim_scaling = constraints_or_scaling node_index_setter = lambda *args: None else: - mma_index = get_mma_dimensional_mapping(trace) + mma_index, mma_slices = get_mma_dimensional_mapping(trace) dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) node_index_setter = partial( - set_node_index, constraints_or_scaling, mma_index, dim_tile_size + set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size ) # Start from the back and expand in the corresponding indexing dimensions of a node @@ -298,6 +302,7 @@ def _expand_node( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" if isinstance(node, list): @@ -305,23 +310,31 @@ def _expand_node( for elem in node: expanded_nodes.append( _expand_node( - elem, trace, dim_query, dim_scaling, node_index_setter, context + elem, + trace, + dim_query, + dim_scaling, + node_index_setter, + context, + res_idx, ).fx_node ) return expanded_nodes # If we expanded a node in the same dimensions before, we can reuse it - if (node, get_indexed_dims(dim_query, node)) in context: + if (node, get_indexed_dims(dim_query, node), res_idx) in context: logger.debug(f"Already expanded node: {node} in {dim_query}") - return context[(node, get_indexed_dims(dim_query, node))] + return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, Reduction): return _expand_reduction( - node, trace, dim_query, dim_scaling, node_index_setter, context + node, trace, dim_query, dim_scaling, node_index_setter, context, res_idx ) - elif isinstance(node, GetResult): + elif isinstance(node, Getitem): + res_idx = node.res_idx + elif isinstance(node, GetResult) and not isinstance(node, Getitem): # The presence of a GetResult node indicates that the reduction has already # been expanded. Simply return the corresponding node. reduction = get_custom(node.value) - return context[(reduction, get_indexed_dims(dim_query, reduction))] + return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)] elif isinstance(node, Allocate): # Allocate nodes are not expanded. return node @@ -329,14 +342,28 @@ def _expand_node( # Filter out the dimensions that are not indexed by the node restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims) logger.debug(f"Expanding node: {node} in {restricted_dims}") + + # For iter args, we want to insert + if not hasattr(_expand_node, "last_expanded_iter_arg"): + _expand_node.last_expanded_iter_arg = None + # Clone the node for the new expansion. The original node is reused for the # case of all dimensions being zero. if expansion_needed(restricted_dims, node.indexing_dims): - new_node = node.copy() + new_node = node.copy( + anchor=( + _expand_node.last_expanded_iter_arg + if isinstance(node, IterArg) + else None + ) + ) else: new_node = node logger.debug(f"did not clone node: {node} in {restricted_dims}") + if isinstance(node, IterArg): + _expand_node.last_expanded_iter_arg = new_node.fx_node + new_node.fx_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) node_index_setter(new_node, restricted_dims) @@ -353,12 +380,13 @@ def _expand_node( dim_scaling, node_index_setter, context, + res_idx, ) new_node.update_arg(i, new_arg) new_node.post_expansion(constraints) - context[(node, get_indexed_dims(restricted_dims, node))] = new_node + context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -369,6 +397,7 @@ def _expand_reduction( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" # Determine the dimensions to expand the reduction from the indexing of its users @@ -391,32 +420,41 @@ def _expand_reduction( new_output_args = [] new_init_args = [] for dim_vals in get_dim_combinations(dim_scaling, expand_dims): - for arg_idx, arg in output.node_args.items(): - dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + return_vals = output.return_vals[0] + dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for arg_idx, arg in enumerate(return_vals): + arg = get_custom(arg) # Add GetResult nodes for the corresponding dimensions reduction.graph.inserting_after(reduction.fx_node) new_node = GetResult(reduction.fx_node, len(new_output_args)) new_node.add_to_graph(reduction.graph) new_node.fx_node.name = get_expanded_name(new_node, dims) - context[(reduction, get_indexed_dims(dims, expand_dims))] = new_node + context[ + (reduction, get_indexed_dims(dims, expand_dims), arg_idx) + ] = new_node # Proceed with expansion inside the reduction new_output_args.append( - _expand_node(arg, trace, dims, dim_scaling, node_index_setter, context) + _expand_node( + arg, trace, dims, dim_scaling, node_index_setter, context, res_idx + ) ) - # Proceed with expansion outside the reduction - for init_arg in reduction.init_args: - new_init_args.append( - _expand_node( - get_custom(init_arg), - trace, - dims, - dim_scaling, - node_index_setter, - context, - ) + # Proceed with expansion outside the reduction + for init_arg in reduction.init_args: + new_init_args.append( + _expand_node( + get_custom(init_arg), + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) + ) # Update init_args and return values reduction.update_arg( @@ -424,11 +462,17 @@ def _expand_reduction( ) output.update_arg("return_vals", [node.fx_node for node in new_output_args]) _handle_reduction_dim( - reduction, output, trace, dim_scaling, node_index_setter, context + reduction, + output, + trace, + dim_scaling, + node_index_setter, + context, + res_idx, ) # Even though we expanded the reduction in multiple dimensions, we only return # the node corresponding to the original query - return context[(reduction, get_indexed_dims(dim_query, expand_dims))] + return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)] def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: @@ -518,6 +562,7 @@ def _handle_reduction_dim( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int, ): # Rediscover iter args # TODO: Register iter args with the reduction initially so accessing them is easier @@ -554,7 +599,13 @@ def _handle_reduction_dim( saved_arg = user.node_args[index] user.update_arg(index, dummy) new_node = _expand_node( - user, trace, dims, dim_scaling, node_index_setter, context + user, + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) # This expansion always happens, user should never be reused diff --git a/shark_turbine/kernel/wave/hoisting.py b/iree/turbine/kernel/wave/hoisting.py similarity index 95% rename from shark_turbine/kernel/wave/hoisting.py rename to iree/turbine/kernel/wave/hoisting.py index df68c753..5a4773d7 100644 --- a/shark_turbine/kernel/wave/hoisting.py +++ b/iree/turbine/kernel/wave/hoisting.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ...support.logging import get_logger -from shark_turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.tracing import CapturedTrace import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py similarity index 98% rename from shark_turbine/kernel/wave/index_sequence_analysis.py rename to iree/turbine/kernel/wave/index_sequence_analysis.py index cec8b60b..b9212f01 100644 --- a/shark_turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -24,7 +24,7 @@ def get_vector_shape( hardware_constraint: HardwareConstraint, symbolic_shape: list[IndexSymbol], ) -> list[int]: - mma_indices = get_mma_dimensional_mapping(trace) + mma_indices, _ = get_mma_dimensional_mapping(trace) return [ get_hardware_vector_size(dim, hardware_constraint, mma_indices) for dim in symbolic_shape diff --git a/shark_turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py similarity index 57% rename from shark_turbine/kernel/wave/iree_utils.py rename to iree/turbine/kernel/wave/iree_utils.py index 6d612c91..39f67404 100644 --- a/shark_turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +from typing import Any from .utils import compile_and_invoke from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM @@ -23,6 +24,23 @@ def get_mmt_asm(lhs_type: str, rhs_type: str, acc_type: str) -> str: return matmul_function +def get_conv_asm( + conv_type: str, lhs_type: str, rhs_type: str, res_type: str, stride: int +) -> str: + res_dtype = res_type.split("x")[-1] + return f""" + func.func @conv_{conv_type}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{res_type}> {{ + %c0 = arith.constant 0.0 : {res_dtype} + %init = tensor.empty() : tensor<{res_type}> + %inital_result = linalg.fill ins(%c0 : {res_dtype}) outs(%init : tensor<{res_type}>) -> tensor<{res_type}> + %result = linalg.conv_{conv_type} + {{dilations = dense<1> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} + ins(%lhs, %rhs : tensor<{lhs_type}>, tensor<{rhs_type}>) + outs(%inital_result : tensor<{res_type}>) -> tensor<{res_type}> + return %result : tensor<{res_type}> + }}""" + + def dtype_str(dtype: torch.dtype) -> str: dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None) if dtype_str is None: @@ -39,20 +57,36 @@ def generate_iree_ref( kernel_inputs: list[torch.Tensor], kernel_outputs: list[torch.Tensor], config: dict[str, str], + **kwargs: dict[str, Any], ): """ Generate a reference output for the given kernel type and arguments. """ asm = None + conv_str = "conv_" if kernel_type == "mmt": lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) asm = get_mmt_asm(lhs_type, rhs_type, acc_type) + elif kernel_type.startswith(conv_str): + lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) + rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) + acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) + conv_type = kernel_type[len(conv_str) :] + asm = get_conv_asm( + conv_type, lhs_type, rhs_type, acc_type, int(kwargs["stride"]) + ) else: raise ValueError(f"Unknown kernel type: {kernel_type}") compile_and_invoke( - asm, kernel_type, config, kernel_inputs, kernel_outputs, True, False + asm, + kernel_type, + config, + kernel_inputs, + kernel_outputs, + run=True, + run_bench=kwargs.get("run_bench", False), ) diff --git a/shark_turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py similarity index 97% rename from shark_turbine/kernel/wave/minimize_global_loads.py rename to iree/turbine/kernel/wave/minimize_global_loads.py index 3ea1a3d0..17971354 100644 --- a/shark_turbine/kernel/wave/minimize_global_loads.py +++ b/iree/turbine/kernel/wave/minimize_global_loads.py @@ -63,12 +63,11 @@ def materialize_shape( constraint_tile_size: dict[IndexSymbol, int], symbolic_shape: list[IndexSymbol] ) -> list[int]: materialized_shape = [] - idxc = IndexingContext.current() for dim in symbolic_shape: if dim in constraint_tile_size: - materialized_shape.append(constraint_tile_size[dim].subs(idxc.subs)) + materialized_shape.append(subs_idxc(constraint_tile_size[dim])) else: - materialized_shape.append(dim.subs(idxc.subs)) + materialized_shape.append(subs_idxc(dim)) return materialized_shape diff --git a/iree/turbine/kernel/wave/packaging/build_package.py b/iree/turbine/kernel/wave/packaging/build_package.py new file mode 100644 index 00000000..6da47254 --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/build_package.py @@ -0,0 +1,49 @@ +import pathlib +from typing import Any +import jinja2 +import shutil + + +def build_folders(kernel_info: dict[str, Any], output_dir: str): + package_path = pathlib.Path(output_dir) / kernel_info["package_name"] + package_path.mkdir(parents=True, exist_ok=True) + subfolder = package_path / kernel_info["package_name"] + subfolder.mkdir(parents=True, exist_ok=True) + init_file = subfolder / "__init__.py" + with open(init_file, "w") as f: + f.write(f"from .main import {kernel_info['kernel_name']}\n") + return subfolder + + +def copy_artifacts(kernel_info: dict[str, Any], output_dir: str): + shutil.copy(kernel_info["vmfb_path"], output_dir) + + +def render_templates(kernel_info: dict[str, Any], output_dir: str): + parent_dir = pathlib.Path(__file__).resolve().parent + template_loader = jinja2.FileSystemLoader(searchpath=parent_dir / "templates") + template_env = jinja2.Environment(loader=template_loader) + main_template = template_env.get_template("main.py.j2") + updated_template = main_template.render( + kernel_function_name=kernel_info["kernel_name"], + kernel_num_inputs=kernel_info["num_inputs"], + kernel_dispatch_name=kernel_info["dispatch_name"], + vmfb_path=pathlib.Path(kernel_info["vmfb_path"]).name, + ) + with open(output_dir / "main.py", "w") as f: + f.write(updated_template) + setup_template = template_env.get_template("setup.py.j2") + updated_template = setup_template.render( + kernel_package_name=kernel_info["package_name"], + kernel_version=kernel_info["kernel_version"], + ) + with open(output_dir.parents[0] / "setup.py", "w") as f: + f.write(updated_template) + + +def create_pip_package(kernel_info: dict[str, Any], output_dir: str): + """Builds a pip package from the current directory.""" + + subfolder = build_folders(kernel_info, output_dir) + copy_artifacts(kernel_info, subfolder) + render_templates(kernel_info, subfolder) diff --git a/iree/turbine/kernel/wave/packaging/templates/main.py.j2 b/iree/turbine/kernel/wave/packaging/templates/main.py.j2 new file mode 100644 index 00000000..20f6add2 --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/templates/main.py.j2 @@ -0,0 +1,497 @@ +#!/usr/bin/env python3 +# Do not modify this file. +# This file is automatically generated from a template in iree/turbine/kernel/wave/packaging/templates/main.py. +# ========================================================================================== + +from functools import lru_cache +import iree.runtime as rt +from typing import Callable, Optional, Union +from threading import local, Lock +import warnings +from iree.runtime import ( + BufferUsage, + HalBufferView, + HalDevice, + HalDriver, + MemoryType, + VmInstance, + VmModule, + create_hal_module, + get_driver, +) +import torch + +_CURRENT_THREAD = local() +_CONFIG_LOCK = Lock() +_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None + + +class MismatchedDeviceSetClearError(AssertionError): + def __init__(self): + super().__init__("Calls to Device.set()/clear() are mismatched or unbalanced.") + + +class UnsupportedTorchDeviceError(Exception): + def __init__(self, torch_device): + super().__init__( + f"Attempt to use turbine with a torch.device that is not supported by this build: {torch_device}" + ) + + +class NoCurrentDeviceError(Exception): + def __init__(self): + super().__init__( + "You accessed a method which requires a current device but none was set on this thread. " + "Either pass an explicit 'device=' or set a current device via " + "`with device:`" + ) + + +def get_vm_instance() -> VmInstance: + global _GLOBAL_VM_INSTANCE + if not _GLOBAL_VM_INSTANCE: + with _CONFIG_LOCK: + if not _GLOBAL_VM_INSTANCE: + _GLOBAL_VM_INSTANCE = VmInstance() + return _GLOBAL_VM_INSTANCE + + +class DeviceState: + """State for an instantiated HAL device. + + Note that the IREE runtime internally manages a global cache of drivers for + standard named-access (not custom-constructed) drivers. + """ + + __slots__ = [ + "device", + "driver", + "instance", + "enumerated_info", + "torch_device", + "dlpack_device_type_code", + ] + + def __init__( + self, + *, + driver: Union[str, HalDriver], + device: Optional[HalDevice] = None, + vm_instance: Optional[VmInstance] = None, + enumerated_info: Optional[dict] = None, + torch_device: Optional[torch.device] = None, + dlpack_device_type_code: int = 0, + ): + self.instance = vm_instance or get_vm_instance() + self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) + self.device = device if device else self.driver.create_default_device() + self.enumerated_info = enumerated_info or {} + self.torch_device = torch_device + self.dlpack_device_type_code = dlpack_device_type_code + + @property + def enumerated_device_id(self) -> int: + try: + return self.enumerated_info["device_id"] + except KeyError as e: + raise RuntimeError("No enumerated device_id for device") from e + + @property + def enumerated_path(self) -> str: + try: + return self.enumerated_info["path"] + except KeyError as e: + raise RuntimeError("No enumerated path for device") from e + + @property + def enumerated_name(self) -> str: + try: + return self.enumerated_info["name"] + except KeyError as e: + raise RuntimeError("No enumerated name for device") from e + + @staticmethod + @lru_cache(maxsize=None) + def from_uri(uri: str) -> "DeviceState": + driver = get_driver(uri) + return DeviceState(driver=driver, device=driver.create_device_by_uri(uri)) + + +class Device: + """Represents a low-level device (HalDriver/HalDevice) and scheduling data. + + This is the type that user's interact with as a 'Device'. Devices can be handled + loose-leaf or bound to a thread with a context manager. + """ + + __slots__ = [ + "_s", + "_main_timeline", + "_main_timepoint", + "_tx_timeline", + "_tx_timepoint", + "_fence_capacity", + "compile_target_flags", + "driver_id", + "export_torch_tensor", + "import_torch_tensor", + "instance_cache_key", + "type_cache_key", + ] + + _s: DeviceState + + # Each device will have a function attached to import a torch.tensor + # *that is already on that device* directly from device memory. + # This is unsafe and relatively unchecked. If criss-crossing devices, + # it is undefined behavior. + import_torch_tensor: Callable[[torch.Tensor], HalBufferView] + + # Devices can also export a torch tensor from a HalBufferView, given + # a meta tensor that describes it. + export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor] + + # Unique name of the IREE runtime driver associated with this device. + driver_id: str + + # Cache key that uniquely identifies this device. + instance_cache_key: str + + # Cache key that uniquely identifies this type of device (currently + # based on its driver). + type_cache_key: str + + # Compiler flags to use to target this device. + # TODO: We should replace this with a target attribute but need an API + # to derive that. + compile_target_flags: tuple[str, ...] + + def __new__( + cls, + uri: Optional[str] = None, + *, + device_state: Optional[DeviceState] = None, + ): + if uri is not None: + # Construction by URI is cached on the thread. + assert not device_state, "device_state= cannot be given with explicit URI" + try: + existing = _CURRENT_THREAD.device_by_uri[uri] + except (AttributeError, KeyError): + ... + else: + return existing + + # New instance. + device_state = DeviceState.from_uri(uri) + new_inst = super().__new__(cls) + new_inst._s = device_state + try: + _CURRENT_THREAD.device_by_uri[uri] = new_inst + except AttributeError: + _CURRENT_THREAD.device_by_uri = {uri: new_inst} + new_inst._initialize() + return new_inst + else: + # Explicit construction with a device_state is assumed that you know what you + # are doing and an uncached instance will be returned. This will be unsychronized + # relative to any cached instance. + assert device_state, "device_state= must be given if URI ommitted" + new_inst = super().__new__(cls) + new_inst._s = device_state + new_inst._initialize() + return new_inst + + def _initialize(self): + d = self._s.device + self._main_timeline = d.create_semaphore(0) + self._main_timepoint = 0 + self._tx_timeline = d.create_semaphore(0) + self._tx_timepoint = 0 + # Maximum number of semaphores the device uses. Can be increased if doing out of the + # ordinary scheduling. + self._fence_capacity = 2 + + # Perform driver specific augmentations. + # TODO: Add a HalDriver.id property to get the driver name instead of parsing + # the device repr. + driver_id = repr(d) + colon_pos = driver_id.find(":") + if colon_pos >= 0: + driver_id = driver_id[0:colon_pos] + self.driver_id = driver_id + try: + import_fn = TORCH_TENSOR_IMPORTERS[driver_id] + export_fn = TORCH_TENSOR_EXPORTERS[driver_id] + self.import_torch_tensor = lambda t: import_fn(self, t) + self.export_torch_tensor = lambda bv, t: export_fn(self, bv, t) + self.compile_target_flags = () + except KeyError as e: + raise AssertionError( + f"Unsupported TORCH_TENSOR_IMPORTERS for iree driver '{driver_id}'" + ) from e + + # Cache keys. + # TODO: The type cache key should actually be based on the driver id + # and device characteristics hash. + self.instance_cache_key = repr(d) + self._recompute_target_keys() + + def _recompute_target_keys(self): + self.type_cache_key = f"{self.driver_id}:{';'.join(self.compile_target_flags)}" + + @property + def hal_device(self) -> HalDevice: + return self._s.device + + @property + def vm_instance(self) -> VmInstance: + return self._s.instance + + def create_hal_module(self) -> VmModule: + s = self._s + return create_hal_module(s.instance, s.device) + + @staticmethod + def current() -> "Device": + try: + return _CURRENT_THREAD.stack[-1] + except (AttributeError, IndexError): + raise NoCurrentDeviceError() + + def set(self) -> "Device": + """Sets this device as the current device without a context manager.""" + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + return self + + def clear(self): + """Clears the current device without a context manager.""" + try: + c = _CURRENT_THREAD.stack[-1] + if _CURRENT_THREAD.stack[-1] is self: + _CURRENT_THREAD.stack.pop() + return + except (AttributeError, IndexError): + ... + raise MismatchedDeviceSetClearError() + + def dump_device_info(self) -> str: + return self._s.driver.dump_device_info(self._s.enumerated_device_id) + + def __repr__(self): + return f"" + + def __enter__(self): + try: + _CURRENT_THREAD.stack.append(self) + except AttributeError: + _CURRENT_THREAD.stack = [self] + + def __exit__(self, type, value, traceback): + _CURRENT_THREAD.stack.pop() + + +################################################################################ +# CUDA and HIP import/export +################################################################################ + + +def _device_import_torch_tensor_cuda_hip( + device: Device, t: torch.Tensor +) -> HalBufferView: + # We currently only support contiguous, so ensure that. + if not t.is_contiguous(): + t = t.contiguous() + # TODO: The 'None' here tells the producer to synchronize on the default + # stream. For async, we should advance our timeline and signal when an + # event is raised on Torch's stream at the current position. + capsule = t.__dlpack__(None) + bv = device.hal_device.from_dlpack_capsule(capsule) + return bv + + +def _device_export_torch_tensor_cuda_hip( + device: Device, bv: HalBufferView, like: torch.Tensor +) -> torch.Tensor: + state = device._s + device_type_code = state.dlpack_device_type_code + assert device_type_code > 0 + torch_device = state.torch_device + assert torch_device is not None + device_index = torch_device.index + t = torch.from_dlpack( + device.hal_device.create_dlpack_capsule(bv, device_type_code, device_index) + ) + if t.dtype != like.dtype: + t = t.view(like.dtype) + # TODO: For async, we should enqueue an event on Torch's stream which will + # signal when this tensor is produced (i.e. at the current point in our + # timeline). + return t + + +# Mapping of torch tensor importers keyed by driver name. +TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { + "cuda": _device_import_torch_tensor_cuda_hip, + "hip": _device_import_torch_tensor_cuda_hip, +} + +TORCH_TENSOR_EXPORTERS: dict[ + str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] +] = { + "cuda": _device_export_torch_tensor_cuda_hip, + "hip": _device_export_torch_tensor_cuda_hip, +} + +############################################################################### +# torch.device to Device mapping +############################################################################### + + +def lookup_device_from_torch( + torch_device: torch.device, *, create: bool = True +) -> Optional[Device]: + """Gets a shared Device corresponding to the given torch.device. + + This will return None if the device is wholly unsupported or if + create=False. Otherwise, faults in setting up the device are + reported as an appropriate exception. + """ + try: + mapping = _CURRENT_THREAD.device_by_torch_device + except AttributeError: + _CURRENT_THREAD.device_by_torch_device = mapping = {} + device = mapping.get(torch_device) + if device is not None or not create: + return device + device = _create_device_from_torch(torch_device) + if device is not None: + mapping[torch_device] = device + return device + + +def get_device_from_torch(torch_device: torch.device) -> Device: + """Gets a shared Device corresponding to the given torch.device. + + Raises an exception if the device cannot be created. + """ + device = lookup_device_from_torch(torch_device) + if device is None: + raise UnsupportedTorchDeviceError(torch_device) + return device + + +def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: + torch_type = torch_device.type + if torch_type == "cuda": + # Fork based on HIP or real CUDA. + props = torch.cuda.get_device_properties(torch_device) + if not hasattr(props, "gcnArchName"): + # Real CUDA. + return _create_cuda_device(torch_device, props) + else: + # HIP as CUDA. + return _create_hip_device(torch_device, props) + + return None + + +def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for real CUDA ROCM is 2. + device = _create_cuda_like_device(torch_device, props, "hip", 2) + if device: + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", + ) + device._recompute_target_keys() + return device + + +def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: + # Note that the dlpack device type code for ROCM is 10. + device = _create_cuda_like_device(torch_device, props, "hip", 10) + # The gcnArchName comes back like gfx90a:sramecc+:xnack- for a fully + # specified target. However the IREE target-chip flag only expects the + # prefix. See: https://github.com/iree-org/iree/issues/17402 + # This should be changed to tunnel through target information unmolested. + gcn_arch_name: str = props.gcnArchName + colon_pos = gcn_arch_name.find(":") + if colon_pos >= 0: + gcn_arch_name = gcn_arch_name[0:colon_pos] + if device: + gcn_arch_name = gcn_arch_name + device.compile_target_flags = device.compile_target_flags + ( + f"--iree-rocm-target-chip={gcn_arch_name}", + ) + device._recompute_target_keys() + return device + + +def _create_cuda_like_device( + torch_device: torch.device, props, driver_name: str, dlpack_device_type_code: int +) -> Optional[Device]: + if torch.cuda.device_count() > 1: + warnings.warn( + f"Multiple {driver_name} devices detected: Turbine does not yet " + f"guarantee stable device mapping" + ) + + requested_index = torch_device.index + driver = get_driver(driver_name) + available_infos = driver.query_available_devices() + if requested_index >= len(available_infos): + return None + device_info = available_infos[requested_index] + hal_device = driver.create_device(device_info) + device_state = DeviceState( + driver=driver, + device=hal_device, + vm_instance=get_vm_instance(), + enumerated_info=device_info, + torch_device=torch_device, + dlpack_device_type_code=dlpack_device_type_code, + ) + device = Device(device_state=device_state) + return device + + +@lru_cache(maxsize=None) +def module(device: Device): + return rt.VmModule.mmap(device._s.instance, "{{vmfb_path}}") + +@lru_cache(maxsize=None) +def context(device: Device): + return rt.VmContext( + device._s.instance, + (create_hal_module(device._s.instance, device._s.device), module(device)), + ) + + +@lru_cache(maxsize=None) +def func(device, name): + return module(device).lookup_function(name) + + +def {{kernel_function_name}}(*args): + arg_list, ret_list = [], [] + device = None + num_inputs = {{kernel_num_inputs}} + num_outputs = len(args) - num_inputs + for i, arg in enumerate(args): + if device is None: + device = get_device_from_torch(args[0].device) + assert device is not None, "Device not found" + if i < num_inputs: + arg_list.append(TORCH_TENSOR_IMPORTERS[arg.dtype](device, arg)) + else: + ret_list.append(TORCH_TENSOR_IMPORTERS[arg.dtype](device, arg)) + context(device).vm_context.invoke( + func(device, {{kernel_dispatch_name}}), arg_list, ret_list + ) + return_values = [] + for ret in ret_list: + return_values.append(TORCH_TENSOR_EXPORTERS[ret.dtype](device, ret)) + return return_values[0] if num_outputs == 1 else return_values diff --git a/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 b/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 new file mode 100644 index 00000000..62365cf6 --- /dev/null +++ b/iree/turbine/kernel/wave/packaging/templates/setup.py.j2 @@ -0,0 +1,12 @@ +from setuptools import setup, find_packages + +setup( + name="{{kernel_package_name}}", + version="{{kernel_version}}", + packages=find_packages(), + include_package_data=True, + package_data={"": ["*.vmfb"]}, + install_requires=[ + "iree-runtime==20240918.1020", + ], +) diff --git a/shark_turbine/kernel/wave/promotion.py b/iree/turbine/kernel/wave/promotion.py similarity index 75% rename from shark_turbine/kernel/wave/promotion.py rename to iree/turbine/kernel/wave/promotion.py index fd1aa541..3711436f 100644 --- a/shark_turbine/kernel/wave/promotion.py +++ b/iree/turbine/kernel/wave/promotion.py @@ -15,6 +15,25 @@ logger = get_logger("turbine.wave.promotion") +def apply_padding( + shape: tuple[IndexSymbol | int], dtype: DataType +) -> tuple[IndexSymbol | int]: + """ + When accessing shared memory, we need to be cognizant of bank conflicts + that can have a significant impact on performance. One way to mitigate + these conflicts is by applying padding to the shared memory allocation. + This function applies padding of 64 bits to the shared memory allocation. + While this approach accomplishes the goal of reducing bank conflicts, it + is inefficient in terms of memory usage. A more sophisticated approach + would involve swizzling of the shared memory access patterns. + """ + padding = 64 // dtype.bitwidth() + return tuple( + value + padding if i == len(shape) - 1 else value + for i, value in enumerate(shape) + ) + + def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate): match custom_node: case Read(memory, elements_per_thread) if get_custom( @@ -47,9 +66,10 @@ def promote_node( assert isinstance(node, Read) or isinstance(node, Write) with node.graph.inserting_before(node.fx_node.next): constrained_shape = get_constrained_shape(node.type.symbolic_shape, constraints) + padded_shape = apply_padding(constrained_shape, node.type.dtype) allocate_node = Allocate( node.type.symbolic_shape, - constrained_shape, + padded_shape, node.type.dtype, address_space, ) diff --git a/shark_turbine/kernel/wave/scheduling/__init__.py b/iree/turbine/kernel/wave/scheduling/__init__.py similarity index 90% rename from shark_turbine/kernel/wave/scheduling/__init__.py rename to iree/turbine/kernel/wave/scheduling/__init__.py index 19879f4b..65b7ec28 100644 --- a/shark_turbine/kernel/wave/scheduling/__init__.py +++ b/iree/turbine/kernel/wave/scheduling/__init__.py @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .schedule import * +from .resources import * diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py similarity index 98% rename from shark_turbine/kernel/wave/scheduling/graph_utils.py rename to iree/turbine/kernel/wave/scheduling/graph_utils.py index e625b666..af398af3 100644 --- a/shark_turbine/kernel/wave/scheduling/graph_utils.py +++ b/iree/turbine/kernel/wave/scheduling/graph_utils.py @@ -213,12 +213,13 @@ def topological_sort_nodes( Perform a topological sort on the nodes in the strongly connected component that have an edge in edges, excluding certain nodes. """ - scc_nodes = set(scc) - set(exclude) + scc_nodes = set(scc) filtered_nodes = set() for edge in edges: if edge._from in scc_nodes and edge._to in scc_nodes: filtered_nodes.add(edge._to) filtered_nodes.add(edge._from) + filtered_nodes -= set(exclude) if exclude is not None else set() sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f) return sorted_nodes diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py new file mode 100644 index 00000000..db456ec9 --- /dev/null +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -0,0 +1,557 @@ +from ..constraints import Constraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import ( + Reduction, + IterArg, + Placeholder, + Output, + GetResult, + get_custom, + SchedulingGroupBarrier, +) +from .modulo_scheduling import ModuloScheduler +from ..utils import ( + get_induction_variable, + replace_uses_in, +) +import torch.fx as fx +from collections import deque, defaultdict +from ..visualization import visualize_mapped_graphs, visualize_graph +from ....support.logging import get_logger +from .loop_reconstruction_utils import ( + ArgumentContext, + create_fill_stage_schedule, + create_drain_stage_schedule, + liveness_analysis, + partition_graph_by_stage, + interleave_instructions, +) +from .resources import get_custom_operation_type +from enum import Enum + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction") + + +class PipelineStage(Enum): + PROLOGUE = 0 + KERNEL = 1 + EPILOGUE = 2 + + +def add_nodes_by_schedule( + reduction_graph: fx.Graph, + partitioned_graph: list[dict[int, fx.Node]], + arg_context: ArgumentContext, + stages: list[int], + initiation_interval: int, + induction_variable: IndexSymbol, + current_induction_variables: list[int], + rotating_registers: dict[fx.Node, list[fx.Node]], + pipelining_stage: PipelineStage = PipelineStage.KERNEL, + use_scheduling_barriers: bool = False, +): + """ + Interleave the instructions in the partitioned graph by stage + for a single initiation interval, updating the argument maps + per stage starting at the provided start times and indices. + """ + fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE] + + for cycle in range(initiation_interval): + logger.debug(f"Cycle: {cycle}") + # Interleave the instructions that are scheduled at the same cycle. + interleaved_instructions = [] + for iteration, stage in enumerate(stages): + if stage is None: + continue + if cycle not in partitioned_graph[stage]: + continue + for node in partitioned_graph[stage][cycle]: + interleaved_instructions.append((iteration, stage, node)) + interleave_instructions(interleaved_instructions) + + instructions = defaultdict(int) + for iteration, stage, node in interleaved_instructions: + logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}") + custom_node = get_custom(node) + logger.debug(f"Node args: {node.args}") + for arg in node.args: + if arg_context.contains_in_iteration(iteration, arg): + logger.debug( + f"Found arg: {arg} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}." + ) + continue + new_node = custom_node.copy( + new_graph=reduction_graph, + arg_transform=lambda x: ( + arg_context.get_from_iteration(iteration, x) + if arg_context.contains_in_iteration(iteration, x) + else x + ), + ) + instructions[get_custom_operation_type(new_node)] += 1 + # Update the argument context. + arg_context[(iteration, stage, node)] = new_node.fx_node + logger.debug( + f"Copying Node: {node}, Stage: {stage}, Iteration: {iteration} -> {new_node.fx_node}" + ) + # Set the index for the new node by substituting the induction variable + # for the current iteration. + new_node.index = node.index + for dim in new_node.index: + new_node.index[dim] = new_node.index[dim].subs( + {induction_variable: current_induction_variables[iteration]} + ) + # Add scheduling parameters for debugging. + new_node.scheduling_parameters = node.scheduling_parameters + # Update the rotating registers and argument context for the current node (if applicable). + if node in rotating_registers: + rotating_registers[node].append(new_node.fx_node) + rotating_registers[node].popleft() + # If draining, then override the rotating registers and update the argument context. + if fill_or_drain: + for next_stage in range(stage + 1, len(stages)): + arg_context[(iteration, next_stage, node)] = new_node.fx_node + + # Update the init args in the argument context whenever a result is computed. + if node in arg_context.results: + if ( + pipelining_stage == PipelineStage.KERNEL + or pipelining_stage == PipelineStage.EPILOGUE + ): + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_iter_arg[node], new_node.fx_node + ) + if pipelining_stage == PipelineStage.PROLOGUE: + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_init_arg[node], new_node.fx_node + ) + + if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers: + SchedulingGroupBarrier(instructions, 0).add_to_graph(reduction_graph) + + +def push_placeholders( + implicit_captures: list[fx.Node], + reduction_subgraph: fx.Node, + arg_context: ArgumentContext, +): + """ + Push placeholders into the argument context for the reduction graph. + """ + for node in reduction_subgraph.nodes: + custom = get_custom(node) + if isinstance(custom, Placeholder) and not isinstance(custom, IterArg): + root_node = [x for x in implicit_captures if x.name == node.name][0] + assert root_node is not None + arg_context.map_arg_all(node, root_node) + + +def construct_prologue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], +): + """ + Construct the prologue of the pipelined loop. + For this, we need to copy nodes from the reduction_graph and insert them + before the reduction operator in the root graph in the appropriate order. + We also need to initialize the rotating registers and update the indices + of the nodes to use the appropriate values of the induction variable. + """ + logger.debug("=====================================") + logger.debug("Constructing prologue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + # Map iter args to init args in the prologue. + for iter_arg, init_arg in zip( + reduction.iter_args(reduction_subgraph), reduction.init_args + ): + arg_context.map_arg_all(iter_arg, init_arg) + + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + with reduction.graph.inserting_before(reduction.fx_node): + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.PROLOGUE, + ) + + # During the prologue, we may have computed results that need to be passed as init args + # to the kernel. + new_init_args: list[fx.Node] = [] + for init_arg in reduction.init_args: + mapped_init_arg = arg_context.lookup(init_arg) + if mapped_init_arg is None: + mapped_init_arg = init_arg + new_init_args.append(mapped_init_arg) + reduction.init_args = new_init_args + + +def flatten_dict_values( + rotating_registers: dict[fx.Node, list[fx.Node]] +) -> list[fx.Node]: + """ + Flatten the values of the rotating registers into a list. + """ + return [ + register for registers in rotating_registers.values() for register in registers + ] + + +def unflatten_dict_values( + rotating_registers_shapes: dict[fx.Node, int], values: list[fx.Node] +) -> dict[fx.Node, list[fx.Node]]: + """ + Unflatten the values of the rotating registers into a dictionary + using the provided shapes. + """ + rotating_registers = {} + count = 0 + for node, shape in rotating_registers_shapes.items(): + rotating_registers[node] = deque(values[count : count + shape]) + count += shape + assert count == sum(rotating_registers_shapes.values()) + return rotating_registers + + +def push_rotating_registers( + arg_context: ArgumentContext, + rotating_registers: dict[fx.Node, list[fx.Node]], + graph: fx.Graph, + node_map: dict[fx.Node, fx.Node], + create_new_nodes: bool = False, +) -> dict[fx.Node, deque[fx.Node]]: + """ + Pushes the rotating registers into the argument map + at the appropriate stages. Create new nodes in the + specified graph if requested. + + For each rotating register, + we evaluate which stage it belongs to and update the argument + context for the next stage and n - 1 stages after it, where + n is the total number of rotating registers. + If var a has [a, b, c] as rotating registers, then in a 3-stage schedule + a is used in stage 2, (iteration 0) + b in stage 1, (iteration 1) + c in stage 0. (iteration 2) + """ + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = {} + count = 0 + for node, registers in rotating_registers.items(): + new_registers: deque[fx.Node] = deque() + custom = get_custom(node) + stage = custom.scheduling_parameters["stage"] + iteration = arg_context.get_kernel_iteration(stage) + arg_context[(iteration, stage, node)] = registers[-1] + for i, register in enumerate(registers): + mapped_stage = stage + len(registers) - i + mapped_iteration = arg_context.get_kernel_iteration(mapped_stage) + if create_new_nodes: + iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg + new_registers.append(iter_arg) + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}." + ) + else: + arg_context[(mapped_iteration, mapped_stage, node)] = register + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage}." + ) + count += 1 + if new_registers: + new_rotating_registers[node] = new_registers + return new_rotating_registers + + +def construct_kernel( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, + use_scheduling_barriers: bool = False, +) -> tuple[Reduction, fx.Graph]: + """ + Construct the kernel of the pipelined loop. + First, we construct a new reduction op with an empty graph. + Then, we set the init args, construct the iter args and add the ops. + Finally, we create the output node with the return values. + The iter args/results of the pipelined reduction are always: + [results0, result1, ..., resultN, rotating_reg0, rotating_reg1, ..., rotating_regN] + """ + logger.debug("=====================================") + logger.debug("Constructing kernel.") + logger.debug("=====================================") + + with reduction.graph.inserting_before(reduction.fx_node): + pipelined_reduction = Reduction( + reduction.axis, + init_args=reduction.init_args + flatten_dict_values(rotating_registers), + subgraph_name="pipelined_reduction", + implicit_captures=reduction.implicit_captures, + ).add_to_graph(reduction.graph) + pipelined_reduction.index = reduction.index + pipelined_reduction_graph = fx.Graph() + reduction.graph.subgraphs["pipelined_reduction"] = pipelined_reduction_graph + + # Update the argument map for the new reduction. + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + # For the original iter args, we just map the old ones to the new ones. + # Do this for all stages, since the original iter args are "dummy" nodes + # during scheduling. + for node in arg_context.iter_args: + iter_arg = IterArg(node.name).add_to_graph(pipelined_reduction_graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context.map_arg_all(node, iter_arg) + + # Push the rotating registers into the argument context. + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = push_rotating_registers( + arg_context, + rotating_registers, + pipelined_reduction_graph, + node_map, + create_new_nodes=True, + ) + + add_nodes_by_schedule( + pipelined_reduction_graph, + partitioned_graph, + arg_context, + list(reversed(range(scheduler.num_stages))), + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + new_rotating_registers, + PipelineStage.KERNEL, + use_scheduling_barriers, + ) + + # Create output node (last node in the graph). + return_vals: list[fx.Node] = arg_context.get_kernel_results() + for registers in new_rotating_registers.values(): + return_vals.extend(registers) + + Output(return_vals).add_to_graph(pipelined_reduction_graph) + reduction.replace_all_uses_with(pipelined_reduction) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction_graph, + new_rotating_registers, + arg_context.argument_map, + "kernel.png", + ) + + return pipelined_reduction, pipelined_reduction_graph + + +def construct_epilogue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + pipelined_reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], + num_rotating_registers: dict[fx.Node, int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, +): + """ + Construct the epilogue of the pipelined loop. + The difference from the prologue is that we need to map the results + of the pipelined reduction to the remaining stages. (In the prologue, + no iteration is every completed and so we don't compute the final results) + We emit GetResult nodes for the rotating registers and map them to + the different epilogue stages. + """ + logger.debug("=====================================") + logger.debug("Constructing epilogue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + existing_get_results: list[GetResult] = sorted( + [x for x in pipelined_reduction.users if isinstance(x, GetResult)], + key=lambda x: x.res_idx, + ) + existing_users = {x: x.users for x in existing_get_results} + + # Map the results from the kernel to the init args (for stages). + for iter_arg, get_result in zip( + reduction.iter_args(reduction_subgraph), existing_get_results + ): + arg_context.map_arg_all(iter_arg, get_result.fx_node) + + with pipelined_reduction.graph.inserting_before( + existing_get_results[0].fx_node.next + ): + # Add get result nodes for the rotating registers and update the + # argument map with them. + rotating_registers_get_results = [] + offset = len(existing_get_results) + for i in range(len(flatten_dict_values(rotating_registers))): + rotating_registers_get_results.append( + GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph( + pipelined_reduction.graph + ) + ) + rotating_registers = unflatten_dict_values( + num_rotating_registers, rotating_registers_get_results + ) + + # Push the rotating registers onto the argument map. + push_rotating_registers(arg_context, rotating_registers, None, node_map, False) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + pipelined_reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.EPILOGUE, + ) + + # Replace the existing uses with the new results. + new_results = arg_context.get_mapped_results(existing_get_results) + assert len(new_results) == len(existing_get_results) + for i, get_result in enumerate(existing_get_results): + replace_uses_in(existing_users, get_result, new_results[i]) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction.graph, + rotating_registers, + arg_context.argument_map, + "epilogue.png", + ) + + +def construct_pipelined_loop( + trace: CapturedTrace, + reduction: Reduction, + graph: fx.Graph, + constraints: list[Constraint], + scheduler: ModuloScheduler, + node_map: dict[fx.Node, fx.Node], + max_induction_variable: int, + visualize: bool = False, + use_scheduling_barriers: bool = False, +) -> fx.Node: + """ + Given a graph annotated with scheduling parameters, construct a pipelined loop + with a prologue, kernel and epilogue. + """ + induction_variable = get_induction_variable(reduction, constraints) + num_rotating_registers = liveness_analysis(graph, constraints, scheduler) + rotating_registers: dict[fx.Node, deque[fx.Node]] = { + k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items() + } + partitioned_graph = partition_graph_by_stage(graph, scheduler) + # Construct prologue. + construct_prologue( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + list(range(scheduler.num_stages)), + create_fill_stage_schedule(scheduler.num_stages), + ) + # Construct kernel. + pipelined_reduction, pipelined_reduction_graph = construct_kernel( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [induction_variable + i for i in range(scheduler.num_stages)], + node_map, + visualize, + use_scheduling_barriers, + ) + trace.add_subgraph( + get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph + ) + # Construct epilogue. + construct_epilogue( + graph, + reduction, + get_custom(pipelined_reduction), + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [ + max_induction_variable - scheduler.num_stages + i + for i in range(scheduler.num_stages) + ], + create_drain_stage_schedule(scheduler.num_stages), + num_rotating_registers, + node_map, + visualize, + ) + + # Remove the unpipelined reduction. + reduction.graph.erase_node(reduction.fx_node) + + if visualize: + visualize_graph(pipelined_reduction.graph, "pipelined.png") + + return pipelined_reduction diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py new file mode 100644 index 00000000..b6993a21 --- /dev/null +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py @@ -0,0 +1,285 @@ +from ..constraints import Constraint, TilingConstraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import Reduction, IterArg, Output, Write, GetResult, get_custom +from .modulo_scheduling import ModuloScheduler +from ..utils import graph_copy, erase_graph +from ..utils import subs_idxc +import torch.fx as fx +import math +from collections import defaultdict, deque, ChainMap +from ..visualization import visualize_mapped_graphs +from ....support.logging import get_logger +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import random +from typing import Optional + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction_utils") + + +class ArgumentContext: + """ + The argument context is used to store the mapping of arguments + for each modulo pipelining stage. + """ + + def __init__( + self, + results: list[fx.Node], + iter_args: list[fx.Node], + init_args: list[fx.Node], + num_stages: int, + ) -> None: + self.argument_map: list[list[dict[fx.Node, fx.Node]]] = [ + [{} for _ in range(num_stages)] for _ in range(num_stages) + ] + self.results = results + self.iter_args = iter_args + self.init_args = init_args + self.num_stages = num_stages + self.num_iterations = num_stages + self.result_to_iter_arg: dict[fx.Node, fx.Node] = {} + self.result_to_init_arg: dict[fx.Node, fx.Node] = {} + + for result, iter_arg in zip(results, iter_args): + self.result_to_iter_arg[result] = iter_arg + for result, init_arg in zip(results, init_args): + self.result_to_init_arg[result] = init_arg + + def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + self.argument_map[iteration][stage][from_] = to_ + + def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + self.argument_map[iteration][stage][from_] = to_ + + def get_mapped_results(self, get_results: list[GetResult]) -> list[fx.Node]: + """ + Gets the mapped results from the last iteration. If the result is not + in the last iteration, then get it from the get result nodes. + """ + mapped_results = [] + for result, get_result in zip(self.results, get_results): + stage = result.scheduling_parameters["stage"] + if result not in self.argument_map[self.num_iterations - 1][stage]: + mapped_results.append(get_result.fx_node) + else: + mapped_results.append( + self.argument_map[self.num_iterations - 1][stage][result] + ) + return mapped_results + + def get_kernel_iteration(self, stage: int) -> int: + """ + Get the iteration from the stage for the kernel. + """ + return self.num_stages - 1 - stage + + def get_kernel_results(self) -> list[fx.Node]: + """ + Gets the mapped results for the kernel. Here there + exists a fixed relationship between the iteration and stage. + """ + mapped_results = [] + for result in self.results: + stage = result.scheduling_parameters["stage"] + iteration = self.get_kernel_iteration(stage) + mapped_results.append(self.argument_map[iteration][stage][result]) + return mapped_results + + def __setitem__(self, key: tuple[int, fx.Node], value: fx.Node) -> None: + """ + Sets the argument mapping for the given stage. + """ + assert isinstance(key, tuple), "Argument context key must be a tuple" + iteration, stage, from_ = key + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + self.argument_map[iteration][stage][from_] = value + + def __getitem__(self, value: tuple[int, fx.Node]) -> fx.Node: + """ + Gets the argument mapping for the given stage. + """ + assert isinstance(value, tuple), "Argument context key must be a tuple" + iteration, stage, key = value + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + return self.argument_map[iteration][stage].get(key, None) + + def __contains__(self, key: fx.Node | tuple[int, fx.Node]) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration and stage or at all iterations and stages. + """ + if isinstance(key, tuple): + iteration, stage, key = key + return key in self.argument_map[iteration][stage] + return any( + key in self.argument_map[iteration][stage] + for iteration in range(self.num_iterations) + for stage in range(self.num_stages) + ) + + def lookup(self, key: fx.Node) -> Optional[fx.Node]: + """ + Looks up the argument mapping for the given node. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration. + """ + return any( + key in self.argument_map[iteration][stage] + for stage in range(self.num_stages) + ) + + def get_from_iteration(self, iteration: int, key: fx.Node) -> fx.Node: + """ + Gets the argument mapping for the given iteration. + """ + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def dump(self): + """ + Dump the argument context to the logger. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + logger.debug(f"Iteration: {iteration}, Stage: {stage}") + for key, value in self.argument_map[iteration][stage].items(): + logger.debug(f" {key} -> {value}") + + +def create_fill_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the prologue (fill). + This looks like: + [0 None None None] + [1 0 None None] + [2 1 0 None] + """ + schedule = [] + for i in range(n - 1): + row = list(range(i, -1, -1)) + row.extend([None] * (n - i - 1)) + schedule.append(row) + return schedule + + +def create_drain_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the epilogue (drain). + This looks like: + [None 3 2 1] + [None None 3 2] + [None None None 3] + """ + schedule = [] + for i in range(n - 1): + row = [None] * (i + 1) + row.extend(range(n - 1, i, -1)) + schedule.append(row) + return schedule + + +def liveness_analysis( + graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler +) -> dict[fx.Node, int]: + """ + Perform liveness analysis on the graph to determine the live ranges of + variables and use that to deduce how many rotating registers we need. + """ + lifetime: dict[fx.Node, int] = {} + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if node not in lifetime: + lifetime[node] = 0 + for user in custom.users: + if user.scheduling_parameters is None: + continue + logger.debug( + f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}" + ) + lifetime[node] = max( + user.scheduling_parameters["stage"] + - custom.scheduling_parameters["stage"], + lifetime[node], + ) + + # Determine how many copies we need for each node. If the lifetime of a node + # is l clocks and the initiation interval is T, then only ceil(l/T) values + # of the node can be live at the same time. We need to create copies of only + # those nodes that are live at more than one stage. + num_rotating_registers: dict[fx.Node, int] = {} + for node, l in lifetime.items(): + if node in num_rotating_registers: + continue + custom = get_custom(node) + if ( + isinstance(custom, Write) + and custom.memory_type.address_space == SHARED_ADDRESS_SPACE + ): + continue + if l > 0: + num_rotating_registers[node] = l + + return num_rotating_registers + + +def partition_graph_by_stage( + graph: fx.Graph, scheduler: ModuloScheduler +) -> list[dict[int, list[fx.Node]]]: + """ + Partition the graph into stages based on the scheduling parameters. + """ + partitioned_graph: list[dict[int, list[fx.Node]]] = [ + defaultdict(list) for _ in range(scheduler.num_stages) + ] + for stage in range(scheduler.num_stages): + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if isinstance(custom, IterArg): + continue + if custom.scheduling_parameters["stage"] == stage: + cycle = custom.scheduling_parameters["cycle"] + partitioned_graph[stage][cycle].append(node) + return partitioned_graph + + +def interleave_instructions(instructions: list[tuple[int, int, fx.Node]]): + """ + Interleave the instructions that are scheduled in the same cycle. + Currently, we just randomly shuffle them, but we could also sort + them based on some criteria. + """ + rng = random.Random(0) + # rng.shuffle(instructions) diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py similarity index 97% rename from shark_turbine/kernel/wave/scheduling/modulo_scheduling.py rename to iree/turbine/kernel/wave/scheduling/modulo_scheduling.py index f2abbd13..82940113 100644 --- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -18,6 +18,7 @@ ) from typing import Callable import numpy as np +import math logger = get_logger("turbine.wave.modulo_scheduling") @@ -263,3 +264,11 @@ def resource_reservations(self) -> np.array: Returns the resource reservations of the schedule. """ return self.RT + + @property + def num_stages(self) -> int: + """ + Returns the number of stages in the kernel of the pipelined loop. + """ + max_cycle = max([t for t in self.schedule.values()]) + return math.ceil(max_cycle / self.initiation_interval) diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py similarity index 64% rename from shark_turbine/kernel/wave/scheduling/resources.py rename to iree/turbine/kernel/wave/scheduling/resources.py index 13e80687..346833f2 100644 --- a/shark_turbine/kernel/wave/scheduling/resources.py +++ b/iree/turbine/kernel/wave/scheduling/resources.py @@ -6,7 +6,15 @@ from ...lang.global_symbols import * from ..utils import subs_idxc -from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom +from ...ops.wave_ops import ( + Read, + Write, + MMA, + IterArg, + Output, + get_custom, + CustomOp, +) import torch.fx as fx from enum import Enum import numpy as np @@ -24,6 +32,9 @@ class Operation(Enum): READ_GLOBAL = "read_global" WRITE_GLOBAL = "write_global" MMA = "mma" + ALU = "alu" + VALU = "valu" + SALU = "salu" NOOP = "noop" @@ -49,6 +60,29 @@ class Operation(Enum): } +def get_custom_operation_type(custom: CustomOp) -> Operation: + if isinstance(custom, Read): + return ( + Operation.READ_GLOBAL + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else Operation.READ_SHARED + ) + elif isinstance(custom, Write): + return ( + Operation.WRITE_GLOBAL + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else Operation.WRITE_SHARED + ) + elif isinstance(custom, MMA): + return Operation.MMA + elif isinstance(custom, IterArg): + return Operation.NOOP + elif isinstance(custom, Output): + return Operation.NOOP + else: + return None + + def annotate_resource_usage( graph: fx.Graph, ) -> tuple[set[fx.Node], list[fx.Node], fx.Node]: @@ -79,3 +113,27 @@ def annotate_resource_usage( else: ignore_nodes.add(node) return ignore_nodes, iter_args, output + + +def get_scheduling_mask(operation: Operation) -> int: + """ + Returns the scheduling mask for the given operation. + """ + match operation: + case Operation.READ_GLOBAL: + return int("0x20", 0) + case Operation.WRITE_GLOBAL: + return int("0x40", 0) + case Operation.READ_SHARED: + return int("0x100", 0) + case Operation.WRITE_SHARED: + return int("0x200", 0) + case Operation.MMA: + return int("0x8", 0) + case Operation.ALU: + return int("0x1", 0) + case Operation.VALU: + return int("0x2", 0) + case Operation.SALU: + return int("0x4", 0) + return None diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py similarity index 64% rename from shark_turbine/kernel/wave/scheduling/schedule.py rename to iree/turbine/kernel/wave/scheduling/schedule.py index a03ad082..9cf6eb19 100644 --- a/shark_turbine/kernel/wave/scheduling/schedule.py +++ b/iree/turbine/kernel/wave/scheduling/schedule.py @@ -11,8 +11,12 @@ from .graph_utils import create_scheduling_edges, Edge from .resources import get_available_resources, annotate_resource_usage from ..visualization import visualize_edges, visualize_graph, visualize_schedule -from ..utils import subs_idxc, graph_copy, erase_graph +from .loop_reconstruction import construct_pipelined_loop +from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc import torch.fx as fx +from ....support.logging import get_logger + +logger = get_logger("turbine.wave.scheduling.schedule") def visualize_scheduling_graph(edges: list[Edge]): @@ -20,7 +24,10 @@ def visualize_scheduling_graph(edges: list[Edge]): def schedule_reduction( - reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint] + reduction: Reduction, + trace: CapturedTrace, + constraints: list[Constraint], + use_scheduling_barriers: bool = False, ): """ Clones the reduction graph and does the following: @@ -68,8 +75,39 @@ def schedule_reduction( erase_graph(graph) + # After scheduling has completed, we have enough information to decide + # whether to pipeline the loop. For pipelining to be possible, we need + # to have atleast N iterations of the loop where N > num_stages - 1 (because + # we will be peeling off num_stages iterations from the loop). + tiling_constraint = get_tiling_constraint(reduction, constraints) + max_induction_variable = int( + subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size) + ) + if max_induction_variable <= scheduler.num_stages - 1: + logger.warn("Not enough iterations to pipeline the loop. Skipping pipelining.") + return {} + + new_reduction = construct_pipelined_loop( + trace, + reduction, + reduction_graph, + constraints, + scheduler, + node_map, + max_induction_variable, + visualize, + use_scheduling_barriers, + ) + + # Update new reduction count. + new_reduction.count = max_induction_variable - (scheduler.num_stages - 1) -def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): + +def schedule_graph( + trace: CapturedTrace, + constraints: list[Constraint], + use_scheduling_barriers: bool = False, +): """ Given a graph, pipelines the reductions in the graph. """ @@ -82,4 +120,6 @@ def is_reduction(node: fx.Node) -> bool: return for reduction_node in reduction_nodes: - schedule_reduction(get_custom(reduction_node), trace, constraints) + schedule_reduction( + get_custom(reduction_node), trace, constraints, use_scheduling_barriers + ) diff --git a/shark_turbine/kernel/wave/shared_memory_indexing.py b/iree/turbine/kernel/wave/shared_memory_indexing.py similarity index 100% rename from shark_turbine/kernel/wave/shared_memory_indexing.py rename to iree/turbine/kernel/wave/shared_memory_indexing.py diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py new file mode 100644 index 00000000..927bd363 --- /dev/null +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -0,0 +1,181 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...support.logging import get_logger +from iree.turbine.kernel._support.tracing import CapturedTrace +import torch.fx as fx +from ..ops.wave_ops import * +from ..lang.global_symbols import * +from .utils import capture_forward_slice, capture_backward_slice, subs_idxc + +logger = get_logger("turbine.wave.thread_shape_analysis") + + +@dataclass(order=True) +class DimSize: + dim: IndexSymbol + size: int + + def __hash__(self): + return hash((self.dim, self.size)) + + +def get_dim_sizes(indices: list[IndexSequence]): + dims = frozenset( + [DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()] + ) + return dims + + +def get_custom_dim_sizes(custom: CustomOp): + return get_dim_sizes(custom.index) + + +def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): + for target in target_dim_sizes: + if target.dim not in custom.index: + raise NotImplementedError( + "NYI: Handle when source target index size is not found in target/user index." + ) + custom.index[target.dim].size = target.size + + +def handle_binaryop_conflict(custom_node: CustomOp): + # Analyze if we can resolve conflict with broadcast. + lhs = get_custom(custom_node.lhs) + rhs = get_custom(custom_node.rhs) + lhs_dim_set = set(lhs.type.symbolic_shape) + rhs_dim_set = set(rhs.type.symbolic_shape) + if lhs_dim_set == rhs_dim_set: + raise ValueError("Cannot broadcast if lhs and rhs is already same.") + if lhs_dim_set.isdisjoint(rhs_dim_set): + raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.") + # Determine the correct indexSize for binaryOp and insert broadcasting. + dst_op = lhs if lhs_dim_set > rhs_dim_set else rhs + broadcast_idx, broadcast_src = (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs) + broadcast = Broadcast(broadcast_src.fx_node, dst_op.type) + with custom_node.graph.inserting_before(custom_node.fx_node): + broadcast.add_to_graph(custom_node.graph) + setattr(broadcast.fx_node, "index", dst_op.index) + custom_node.index = dst_op.index + custom_node.update_arg(broadcast_idx, broadcast.fx_node) + return True + + +# Returns True iff all conflicts are handled succesfully. +def handle_conflicts(conflicted_ops: set[CustomOp]): + for conflict in conflicted_ops: + custom = get_custom(conflict) + if isinstance(custom, BinaryPyOp): + handle_binaryop_conflict(custom) + else: + return False + return True + + +def determine_thread_shapes(trace: CapturedTrace): + """ + This function does analysis and propagation of thread shape. It does by such: + 1. Look for "anchor" ops who has information of it's elem_per_thread. + 2. Do a forward/backward slice on these anchor ops to get ops that + who's shapes depends on these anchor ops. + 3. We bucket these ops to Variadic(Index->elem_per_thread) mapping. + 4. At every bucket of (index -> elem_per_thread), we apply these information + by updating their indexSequence size. + + We stored the buckets above in a variable/dict called `thread_size_to_ops`. + + `thread_size_to_ops` is a dict that uses thread_shapes as key and for every + key/thread_shape will map to a set of fx.nodes that needs to have that + thread_shape in it's indexSequence. + + `thread_shapes` is used to store thread_size at every dimension that the op + cares about. We use a frozenset[DimSize] to represent it, where DimSize + is essentially a pair. we are using + frozen_set since we do not care about the order of dims for the shape/size + propagation. + + We use sets[CustomOp] to represent the values of `thread_size_ops` S.T we can + easily find any conflicting of index using set operations and handle/resolve it + if required. + + For better illustration, here's an example: + Kernel: + imm = tkw.mul(x, y) + lhs = tkw.neg(imm) + a = tkw.mma(lhs, rhs, acc) + b = tkw.exp2(a) + Anchors: + mma.lhs: {IndexSize(index=M, size=1), IndexSize(index=K, size=4)} + mma.rhs: {IndexSize(index=K, size=4), IndexSize(index=N, size=1)} + mma.acc: {IndexSize(index=M, size=4), IndexSize(index=N, size=1)} + Bucket Entry: + thread_sizes_to_ops[frozenset({IndexSize(index=M, size=1), IndexSize(index=K, size=4)}] = set(lhs, imm, x, y) + thread_sizes_to_ops[frozenset({IndexSize(index=M, size=4), IndexSize(index=N, size=1)}] = set(acc, exp2_0) + thread_sizes_to_ops[frozenset({IndexSize(index=K, size=4), IndexSize(index=N, size=1)}] = set(rhs, ...) + + """ + + # Anchor ops are ops who's thread shape are predetermined. + anchorOpTypes = (Read, Write, MMA, ReduceOp) + noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) + nonPropagatableTypes = anchorOpTypes + noHandleTypes + + def is_anchor_op(node: fx.Node): + return isinstance(get_custom(node), anchorOpTypes) + + def propagatable_op(node: fx.Node): + return not isinstance(get_custom(node), nonPropagatableTypes) + + anchor_ops = trace.walk(is_anchor_op) + thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {} + for anchor_op in anchor_ops: + custom = get_custom(anchor_op) + index_sizes = get_custom_dim_sizes(custom) + if isinstance(custom, (Read, ReduceOp)): + fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) + thread_size_to_ops[index_sizes] = thread_size_to_ops.get( + index_sizes, set([]) + ).union(fwd_slice) + elif isinstance(custom, Write): + bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op) + thread_size_to_ops[index_sizes] = thread_size_to_ops.get( + index_sizes, set([]) + ).union(bwd_slice) + elif isinstance(custom, MMA): + lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op) + rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op) + acc_slice = capture_forward_slice(custom.acc, propagatable_op) + acc_slice = acc_slice.union( + capture_backward_slice(custom.acc, propagatable_op) + ) + acc_index = get_dim_sizes(custom.acc_index) + lhs_index = get_dim_sizes(custom.lhs_index) + rhs_index = get_dim_sizes(custom.rhs_index) + thread_size_to_ops[acc_index] = thread_size_to_ops.get( + acc_index, set([]) + ).union(acc_slice) + thread_size_to_ops[lhs_index] = thread_size_to_ops.get( + lhs_index, set([]) + ).union(lhs_bwd_slice) + thread_size_to_ops[rhs_index] = thread_size_to_ops.get( + rhs_index, set([]) + ).union(rhs_bwd_slice) + + # Go through each index-size buckets, and apply the index-size to ops in the bucket. + cummulative_set = set() + for target_index_size, target_ops in thread_size_to_ops.items(): + # Try to handle conflicts and remove from target set if successfully handled. + if not cummulative_set.isdisjoint(target_ops): + conflicted_ops = cummulative_set.intersection(target_ops) + if handle_conflicts(conflicted_ops) == False: + raise NotImplementedError("Failed to handle conflicting thread shape.") + target_ops = target_ops.difference(conflicted_ops) + cummulative_set = cummulative_set.union(target_ops) + # Set target ops's indexSize to be the determined from analysis. + for user in target_ops: + custom_user = get_custom(user) + set_index_size(custom_user, target_index_size) diff --git a/shark_turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py similarity index 54% rename from shark_turbine/kernel/wave/utils.py rename to iree/turbine/kernel/wave/utils.py index affd5fef..11b3cbe2 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -1,5 +1,4 @@ # Copyright 2024 The IREE Authors -# # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -11,16 +10,34 @@ Operation, transform_d, UnitAttr, + Value, ) -from typing import Callable, Any, List, Tuple +from typing import Optional, Callable, Any, List, Tuple from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * -from ..ops.wave_ops import get_custom, Output, Write, MMA -from .constraints import Constraint, HardwareConstraint, TilingConstraint +from ..ops.wave_ops import ( + get_custom, + Output, + Write, + MMA, + CustomOp, + Reduction, + GetResult, + IterArg, +) +from .constraints import ( + Constraint, + WorkgroupConstraint, + HardwareConstraint, + TilingConstraint, +) import torch.fx as fx -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl + +import tempfile +from ...support.conversions import TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM from iree.compiler.dialects.transform import ( interpreter as transform_interpreter, any_op_t, @@ -90,6 +107,21 @@ def print_trace(trace: CapturedTrace, custom_print: bool = True): print(get_custom(node)) +def print_subgraph(trace: CapturedTrace, subgraph_name: str, custom_print: bool = True): + """ + Prints a specific subgraphs of a trace. + The graphs are printed first in the torch printing format and + then using our custom node format. + """ + # The root graph is at the back so we print the subgraphs in reverse order + for name, subgraph in trace.region_graph.subgraphs.items(): + if name == subgraph_name: + print(subgraph) + if custom_print: + for node in subgraph.nodes: + print(get_custom(node)) + + def DCE(trace: CapturedTrace): """ Removes all operators that are not used in the graph, @@ -115,6 +147,19 @@ def is_removable_operator(node: fx.Node) -> bool: get_custom(node).graph.erase_node(node) +def remove_chained_getresult(trace: CapturedTrace): + def is_chained_getresult(node: fx.Node) -> bool: + custom = get_custom(node) + return isinstance(custom, GetResult) and isinstance( + get_custom(custom.value), GetResult + ) + + while removable_nodes := trace.walk(is_chained_getresult): + for node in removable_nodes: + get_custom(node).replace_all_uses_with(get_custom(node).value) + get_custom(node).graph.erase_node(node) + + def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]: """ Delinearizes a 1D index into a multi-dimensional index @@ -145,7 +190,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) -def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]: +def get_mma_dimensional_mapping( + trace: CapturedTrace, +) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have @@ -159,7 +206,8 @@ def is_mma(node): return isinstance(get_custom(node), MMA) mapping: dict[IndexSymbol, int] = {} - for node in trace.walk(is_mma): + mma_nodes = trace.walk(is_mma) + for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] lhs_shape = custom.lhs_type.symbolic_shape @@ -170,7 +218,7 @@ def is_mma(node): mapping[n] = 1 mapping[k] = 2 - return mapping + return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) def get_hardware_vector_size( @@ -223,9 +271,14 @@ def _invoke(vm_context, device, entry_function, inputs, outputs): ret_list = rt.VmVariantList(len(outputs)) for input in inputs: - input_cpu = input.cpu().contiguous() - device_array = rt.asdevicearray(device, input_cpu) - arg_list.push_ref(device_array._buffer_view) + if isinstance(input, torch.Tensor): + input_cpu = input.cpu().contiguous() + device_array = rt.asdevicearray(device, input_cpu) + arg_list.push_ref(device_array._buffer_view) + elif isinstance(input, int): + arg_list.push_int(input) + else: + raise ValueError(f"Unsupported input type: {type(input)}") vm_context.invoke(entry_function, arg_list, ret_list) @@ -281,6 +334,16 @@ def compile_and_invoke( if config.get("print_ir_after_all", False): flags.append("--mlir-print-ir-after-all") + preprocessing_pipeline = config.get("iree_preprocessing_pass_pipeline", None) + if preprocessing_pipeline is not None: + flags.append(f"--iree-preprocessing-pass-pipeline={preprocessing_pipeline}") + + if "dump_intermediates" in config: + intermediates_path = config.get("dump_intermediates") + flags.append( + f"--iree-hal-dump-executable-intermediates-to={intermediates_path}" + ) + if run_bench: bench_batch_size = config.get("benchmark_batch_size", None) bench_repetitions = config.get("benchmark_repetitions", None) @@ -322,7 +385,24 @@ def compile_and_invoke( _invoke(ctx.vm_context, device, func, kernel_inputs, kernel_outputs) if run_bench: - inputs = [inp.numpy() for inp in kernel_inputs] + bench_with_constant_weights = config.get("bench_with_constant_weights", False) + tempfiles = [] + inputs = [] + if bench_with_constant_weights: + for inp in kernel_inputs: + inputs.append( + "x".join( + [str(x) for x in inp.shape] + + [TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM[inp.dtype]] + ) + ) + else: + for inp in kernel_inputs: + tf = tempfile.NamedTemporaryFile(suffix=".npy") + numpy.save(tf, inp.numpy()) + tempfiles.append(tf) + inputs.append("@" + tf.name) + benchmark_results = bench.benchmark_module( mod, entry_function=func_name, @@ -378,3 +458,204 @@ def erase_graph(graph: fx.Graph): for user in node.users: graph.erase_node(user) graph.erase_node(node) + + +def get_users( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the users of a node, propagating through reductions. + """ + users = [] + for user in node.users: + custom = get_custom(user) + if isinstance(custom, Reduction): + # Map init arg to iter arg + reduction = custom + init_arg_idx = custom.init_args.index(node) + users.append(custom.iter_args[init_arg_idx]) + continue + if isinstance(custom, Output) and reduction: + # Map output to get result + return_vals = custom.return_vals[0] + get_results = sorted( + [x for x in reduction.users if isinstance(get_custom(x), GetResult)], + lambda x: get_custom(x).res_idx, + ) + if isinstance(return_vals, list): + output_idx = return_vals.index(node) + users.append(get_results[output_idx]) + else: + users.append(get_results[0]) + continue + users.append(user) + return users, reduction + + +def get_inputs( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the inputs of a node, propagating through reductions. + """ + inputs = [] + custom = get_custom(node) + if isinstance(custom, IterArg): + # Map iter args to init args + local_reduction = reduction + if reduction is None: + local_reduction = custom.parent_op() + iter_arg_idx = custom.get_iter_idx() + inputs.append(local_reduction.init_args[iter_arg_idx]) + elif isinstance(custom, GetResult): + reduction = get_custom(custom.value) + assert isinstance( + get_custom(reduction), Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] + inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + else: + # Default handling for other ops. + for input in node.all_input_nodes: + inputs.append(input) + return inputs, reduction + + +def bfs( + node: fx.Node, + get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], + filter_fn: Callable[[fx.node], bool], +) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + visited: set[fx.Node] = set() + queue: list[fx.Node] = [] + visited.add(node) + queue.append(node) + reduction = None + while queue: + s = queue.pop(0) + neighbors, reduction = get_neighbors(s, reduction) + for neighbor in neighbors: + if neighbor not in visited and filter_fn(neighbor): + visited.add(neighbor) + queue.append(neighbor) + return visited + + +def capture_forward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + return bfs(node, lambda x, y: get_users(x, y), filter_fn) + + +def capture_backward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: + """ + Capture backward slice from a node and return the tree. + Assumes graph is directed. + """ + return bfs(node, lambda x, y: get_inputs(x, y), filter_fn) + + +def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + """ + mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} + for mma in mma_nodes: + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) + mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + return mma_slices + + +def specialize_index_sequence( + index_seq: IndexSequence, + mma_slices: dict[IndexSymbol, list[fx.Node]], + custom: CustomOp, +) -> IndexSequence: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + If the node is not used as any of the operands, return the original index sequence + with all the MMA symbols zeroed out. + """ + if isinstance(custom, MMA): + return index_seq + operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} + for key in mma_slices: + if custom.fx_node in mma_slices[key]: + operand_map[key] = 1 + return index_seq.subs(operand_map) + return index_seq.subs(operand_map) + + +def find_index_bounds( + constraints: list[Constraint], index: dict[IndexExpr, IndexExpr] +) -> Optional[list[IndexExpr]]: + bounds = [] + for constraint in constraints: + if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): + continue + + dim = constraint.dim + if dim not in index: + continue + + work_size = constraint.count * constraint.tile_size + if subs_idxc(work_size) == subs_idxc(dim): + continue + + bounds.append(dim) + + if len(bounds) == 0: + return None + + return bounds + + +def get_induction_variable( + reduction: Reduction, constraints: list[Constraint] +) -> IndexSymbol: + induction_var = None + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + induction_var = constraint.induction_var + break + else: + raise ValueError(f"Could not find induction variable for reduction {reduction}") + return induction_var + + +def get_tiling_constraint( + reduction: Reduction, constraints: list[Constraint] +) -> TilingConstraint: + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + return constraint + else: + raise ValueError(f"Could not find tiling constraint for reduction {reduction}") + + +def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node): + """ + Replace all uses of `old` with `new` in the list of users. + """ + for user in users[old]: + for i, arg in enumerate(user.fx_node.args): + if arg == old.fx_node: + user.update_arg(i, new) diff --git a/iree/turbine/kernel/wave/visualization.py b/iree/turbine/kernel/wave/visualization.py new file mode 100644 index 00000000..d6438bfc --- /dev/null +++ b/iree/turbine/kernel/wave/visualization.py @@ -0,0 +1,190 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +graphviz_disabled = False +try: + import pygraphviz as pgv +except: + graphviz_disabled = True +from torch import fx +from .scheduling.graph_utils import Edge +from ..ops.wave_ops import Output, Placeholder, IterArg, get_custom +from collections import ChainMap +import math + + +def number_nodes(graph: fx.Graph) -> dict[int, int]: + return {id(node): i for i, node in enumerate(graph.nodes)} + + +def visualize_graph(graph: fx.Graph, file_name: str): + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + node_numbering = number_nodes(graph) + G = pgv.AGraph(directed=True) + for node in graph.nodes: + G.add_node(node_numbering[id(node)], label=node.name) + for node in graph.nodes: + for user in node.users.keys(): + # Handle scenario where nodes are shared across graphs. + if user not in graph.nodes: + continue + G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) + G.layout(prog="dot") + G.draw(file_name) + + +def visualize_edges(edges: list[Edge], file_name: str): + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + G = pgv.AGraph(directed=True) + node_map = {} + count = 0 + for edge in edges: + if edge._from not in node_map: + node_map[edge._from] = count + count += 1 + G.add_node(node_map[edge._from], label=f"{edge._from}") + if edge._to not in node_map: + node_map[edge._to] = count + count += 1 + G.add_node(node_map[edge._to], label=f"{edge._to}") + G.add_edge( + node_map[edge._from], + node_map[edge._to], + label=f"({edge.weight.iteration_difference}, {edge.weight.delay})", + ) + G.layout(prog="dot") + G.draw(file_name) + + +def visualize_schedule( + schedule: dict[fx.Graph, int], initiation_interval: int, file_name: str +): + import pandas as pd + + max_time = max(schedule.values()) + max_stage = math.ceil(max_time / initiation_interval) + rows = max_time + 1 + max_stage * initiation_interval + cols = max_stage + + table = [["" for _ in range(cols)] for _ in range(rows)] + for stage in range(max_stage): + for key, value in schedule.items(): + table[value + stage * initiation_interval][stage] += f"{key}
" + + df = pd.DataFrame(table, columns=[f"Iteration {i}" for i in range(cols)]) + s = df.style.set_properties(**{"text-align": "center"}) + s = s.set_table_styles( + [ + {"selector": "", "props": [("border", "1px solid grey")]}, + {"selector": "tbody td", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("min-width", "300px")]}, + ] + ) + output = s.apply( + lambda x: [ + ( + "background: lightgreen" + if int(x.name) >= (max_stage - 1) * initiation_interval + and int(x.name) < max_stage * initiation_interval + else "" + ) + for _ in x + ], + axis=1, + ).to_html() + with open(f"{file_name}", "w") as f: + f.write(output) + + +def visualize_mapped_graphs( + second: fx.Graph, + rotating_registers: dict[fx.Node, list[fx.Node]], + mappings: list[list[dict[fx.Node, fx.Node]]], + file_name: str, +): + """ + Given the pipelined graph and a list of mappings of nodes from the original + graph to the pipelined graph (per stage), visualize the pipelined graph (with their original labels) + + """ + + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + second_numbering = number_nodes(second) + + flat_inverse_map: dict[fx.Node, fx.Node] = {} + flat_map: dict[fx.Node, fx.Node] = {} + for iteration_mapping in mappings: + for mapping in iteration_mapping: + flat_inverse_map.update({v: k for k, v in mapping.items()}) + flat_map.update(mapping) + flat_inverse_map = ChainMap(flat_inverse_map) + flat_map = ChainMap(flat_map) + + # Draw nodes and edges in the pipelined graph. + G = pgv.AGraph(directed=True) + G0 = G.add_subgraph(name="pipelined") + stage: dict[fx.Node, int] = {} + for node in second.nodes: + if hasattr(node, "scheduling_parameters"): + if node in flat_inverse_map: + name = flat_inverse_map[node].name + else: + name = node.name + else: + name = node.name + G0.add_node( + second_numbering[id(node)], + label=name, + color="lightblue", + style="filled", + ) + for user in node.users.keys(): + if user not in second.nodes: + continue + if isinstance(get_custom(user), Output): + continue + G0.add_edge( + second_numbering[id(node)], + second_numbering[id(user)], + color="black", + ) + + # Draw nodes and edges in the original graph. + colors = ["red", "green", "orange", "purple", "orange", "cyan", "magenta"] + max_stage = len(mappings) + for node, mapped_node in flat_map.items(): + for user in node.users.keys(): + if user not in flat_map: + continue + mapped_user = flat_map[user] + if mapped_user not in second.nodes or mapped_node not in second.nodes: + continue + stage = "" + if hasattr(user, "scheduling_parameters"): + stage = user.scheduling_parameters["stage"] + G.add_edge( + second_numbering[id(mapped_node)], + second_numbering[id(mapped_user)], + label=f"{stage}", + color=colors[stage % max_stage], + ) + + # Draw edges between rotating registers for the same variable. + for node in rotating_registers: + all_registers = [k for k, v in flat_inverse_map.items() if v == node] + for second, first in zip(all_registers[:-1], all_registers[1:]): + G.add_edge( + second_numbering[id(first)], + second_numbering[id(second)], + color="blue", + ) + + G.layout(prog="dot") + G.draw(file_name) diff --git a/shark_turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py similarity index 80% rename from shark_turbine/kernel/wave/wave.py rename to iree/turbine/kernel/wave/wave.py index eb6003de..177d9867 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -23,7 +23,13 @@ from .expansion import expand_graph from .promotion import promote_placeholders from .hoisting import hoist_allocs -from .utils import canonicalize_module, compile_and_invoke, safe_subs +from .utils import ( + canonicalize_module, + compile_and_invoke, + safe_subs, + remove_chained_getresult, + subs_idxc, +) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops from .barriers import add_shared_memory_barriers @@ -33,10 +39,10 @@ from ..ops.wave_ops import Reduction, CustomOp, get_custom from .index_sequence_analysis import partition_strided_operators from .shared_memory_indexing import apply_shared_memory_indexing_corrections -from .register_analysis import determine_register_shape +from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, CompiledContext, @@ -179,6 +185,18 @@ def initialize_wave_constraints(self, trace: CapturedTrace) -> None: / hardware_constraint.threads_per_wave ) + def initialize_reductions(self, trace: CapturedTrace) -> None: + """ + For each reduction, initializes the reduction count by looking at the + tiling constraints associated with the reduction. + + """ + is_reduction = lambda node: isinstance(get_custom(node), Reduction) + for reduction in trace.walk(is_reduction): + for tiling_constraint in self.tiling_constraints: + if tiling_constraint.dim == get_custom(reduction).axis: + reduction.count = subs_idxc(tiling_constraint.count) + def _trace_and_get_kernel_signature( self, args, @@ -191,6 +209,7 @@ def _trace_and_get_kernel_signature( self.create_induction_vars(graph) self.initialize_wave_constraints(graph) + self.initialize_reductions(graph) idxc = IndexingContext.current() idxc.finalize() @@ -205,8 +224,8 @@ def _trace_and_get_kernel_signature( # Expansion expand_graph(graph, self.constraints) - # Register analysis to determine register shapes. - determine_register_shape(graph, self.constraints) + # Clean up chains of GetResults + remove_chained_getresult(graph) # Optimizations. minimize_global_loads(graph, self.constraints) @@ -217,12 +236,25 @@ def _trace_and_get_kernel_signature( # Partition strided operators. partition_strided_operators(graph, self.constraints) + # Analyze Thread Shapes per Op. + determine_thread_shapes(graph) + # Decompose reduce Ops. decompose_reduce_ops(graph, self.constraints, idxc.subs) # Schedule the reduction ops. + # Scheduling should always be used with use_scheduling_barriers=True, + # as this is the only way we can ensure that LLVM enforces our desired schedule. + # However, due a bug in LLVM, you will need to patch your local LLVM repo + # with the following PR: https://github.com/kerbowa/llvm-project/commit/ee52732cddae42deed2e3387a83b20ec05860b4e + # Specifically: + # git remote add sched_fixes https://github.com/kerbowa/llvm-project.git + # git fetch sched_fixes + # git cherry-pick ee52732cddae42deed2e3387a83b20ec05860b4e + # [Manually resolve conflicts consistent with the PR] if kwargs.get("schedule", False): - schedule_graph(graph, self.constraints) + use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False) + schedule_graph(graph, self.constraints, use_scheduling_barriers) # Add shared memory barriers. add_shared_memory_barriers(graph) @@ -238,6 +270,8 @@ def _trace_and_get_kernel_signature( root_graph = graph.get_root_graph() kernel_sig = kernel_codegen.KernelSignature() kernel_sig.add_from_graph_placeholders(root_graph) + dynamic_symbols = kwargs.get("dynamic_symbols", []) + kernel_sig.add_from_dynamic_symbols(dynamic_symbols) kernel_sig.add_grid(self.grid_type) kernel_sig.determine_input_output_buffers(root_graph) @@ -247,10 +281,17 @@ def _trace_and_get_kernel_signature( workgroup_size = self.hardware_constraints[0].threads_per_block subgroup_size = self.hardware_constraints[0].threads_per_wave dispatch_entrypoint = exe.define_entrypoint( - entrypoint_name, kernel_sig, grid, workgroup_size, subgroup_size + entrypoint_name, + kernel_sig, + grid, + workgroup_size, + subgroup_size, + dynamic_symbols, ) - emitter = WaveEmitter(dispatch_entrypoint, graph, self.constraints) + emitter = WaveEmitter( + dispatch_entrypoint, graph, self.constraints, dynamic_symbols + ) emitter.emit(graph.get_root_graph()) emitter.finish() @@ -272,7 +313,10 @@ def test_execute(self, args, kwargs): run_bench = kwargs.get("run_bench", False) if run or run_bench: # TODO: cache compiled code - host_codegen.isolated_test_call(mb, exe, kernel_sig, entrypoint_name) + dynamic_symbols = kwargs.get("dynamic_symbols", []) + host_codegen.isolated_test_call( + mb, exe, kernel_sig, entrypoint_name, dynamic_symbols + ) asm = mb.module_op.get_asm() kernel_inputs = [] @@ -285,6 +329,10 @@ def test_execute(self, args, kwargs): if usage == kernel_codegen.KernelBufferUsage.OUTPUT: kernel_outputs.append(arg) + dynamic_symbols_map = kwargs.get("dynamic_symbols_map", {}) + if dynamic_symbols: + kernel_inputs += [dynamic_symbols_map[sym] for sym in dynamic_symbols] + config = kwargs.get("run_config", None) if not config: raise ValueError("no config provided") diff --git a/shark_turbine/kernel/wave/wave_sim.py b/iree/turbine/kernel/wave/wave_sim.py similarity index 100% rename from shark_turbine/kernel/wave/wave_sim.py rename to iree/turbine/kernel/wave/wave_sim.py diff --git a/shark_turbine/ops/__init__.py b/iree/turbine/ops/__init__.py similarity index 100% rename from shark_turbine/ops/__init__.py rename to iree/turbine/ops/__init__.py diff --git a/shark_turbine/ops/_jinja_test_ops.py b/iree/turbine/ops/_jinja_test_ops.py similarity index 100% rename from shark_turbine/ops/_jinja_test_ops.py rename to iree/turbine/ops/_jinja_test_ops.py diff --git a/shark_turbine/ops/_str_format_test_ops.py b/iree/turbine/ops/_str_format_test_ops.py similarity index 100% rename from shark_turbine/ops/_str_format_test_ops.py rename to iree/turbine/ops/_str_format_test_ops.py diff --git a/shark_turbine/ops/iree.py b/iree/turbine/ops/iree.py similarity index 97% rename from shark_turbine/ops/iree.py rename to iree/turbine/ops/iree.py index 1609db2b..b4d79aee 100644 --- a/shark_turbine/ops/iree.py +++ b/iree/turbine/ops/iree.py @@ -83,7 +83,8 @@ class transfer_to_logical_device(CustomOp): def select(self, ksel: KernelSelection): ksel.attr_str(0) ta = ksel.arg_tensor(1) - ksel.return_tensor(ta.t) + ta.specialize_all_dims() + ksel.return_tensor(ta.t).specialize_all_dims() def eager_execute(self, device_moniker, tensor): return tensor diff --git a/shark_turbine/ops/templates/test_add_jinja.mlir b/iree/turbine/ops/templates/test_add_jinja.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_jinja.mlir rename to iree/turbine/ops/templates/test_add_jinja.mlir diff --git a/shark_turbine/ops/templates/test_add_strformat.mlir b/iree/turbine/ops/templates/test_add_strformat.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_strformat.mlir rename to iree/turbine/ops/templates/test_add_strformat.mlir diff --git a/shark_turbine/ops/templates/test_syntax_error.mlir b/iree/turbine/ops/templates/test_syntax_error.mlir similarity index 100% rename from shark_turbine/ops/templates/test_syntax_error.mlir rename to iree/turbine/ops/templates/test_syntax_error.mlir diff --git a/shark_turbine/runtime/__init__.py b/iree/turbine/runtime/__init__.py similarity index 100% rename from shark_turbine/runtime/__init__.py rename to iree/turbine/runtime/__init__.py diff --git a/shark_turbine/runtime/device.py b/iree/turbine/runtime/device.py similarity index 100% rename from shark_turbine/runtime/device.py rename to iree/turbine/runtime/device.py diff --git a/shark_turbine/runtime/launch.py b/iree/turbine/runtime/launch.py similarity index 100% rename from shark_turbine/runtime/launch.py rename to iree/turbine/runtime/launch.py diff --git a/shark_turbine/runtime/op_reg/__init__.py b/iree/turbine/runtime/op_reg/__init__.py similarity index 100% rename from shark_turbine/runtime/op_reg/__init__.py rename to iree/turbine/runtime/op_reg/__init__.py diff --git a/shark_turbine/runtime/op_reg/base.py b/iree/turbine/runtime/op_reg/base.py similarity index 100% rename from shark_turbine/runtime/op_reg/base.py rename to iree/turbine/runtime/op_reg/base.py diff --git a/shark_turbine/runtime/op_reg/compiler.py b/iree/turbine/runtime/op_reg/compiler.py similarity index 100% rename from shark_turbine/runtime/op_reg/compiler.py rename to iree/turbine/runtime/op_reg/compiler.py diff --git a/shark_turbine/runtime/op_reg/eager.py b/iree/turbine/runtime/op_reg/eager.py similarity index 100% rename from shark_turbine/runtime/op_reg/eager.py rename to iree/turbine/runtime/op_reg/eager.py diff --git a/shark_turbine/runtime/op_reg/impl_helper.py b/iree/turbine/runtime/op_reg/impl_helper.py similarity index 100% rename from shark_turbine/runtime/op_reg/impl_helper.py rename to iree/turbine/runtime/op_reg/impl_helper.py diff --git a/shark_turbine/runtime/tracing.py b/iree/turbine/runtime/tracing.py similarity index 100% rename from shark_turbine/runtime/tracing.py rename to iree/turbine/runtime/tracing.py diff --git a/shark_turbine/support/__init__.py b/iree/turbine/support/__init__.py similarity index 100% rename from shark_turbine/support/__init__.py rename to iree/turbine/support/__init__.py diff --git a/shark_turbine/support/conversions.py b/iree/turbine/support/conversions.py similarity index 100% rename from shark_turbine/support/conversions.py rename to iree/turbine/support/conversions.py diff --git a/shark_turbine/support/debugging.py b/iree/turbine/support/debugging.py similarity index 100% rename from shark_turbine/support/debugging.py rename to iree/turbine/support/debugging.py diff --git a/shark_turbine/support/exceptions.py b/iree/turbine/support/exceptions.py similarity index 100% rename from shark_turbine/support/exceptions.py rename to iree/turbine/support/exceptions.py diff --git a/shark_turbine/support/ir_imports.py b/iree/turbine/support/ir_imports.py similarity index 100% rename from shark_turbine/support/ir_imports.py rename to iree/turbine/support/ir_imports.py diff --git a/shark_turbine/support/logging.py b/iree/turbine/support/logging.py similarity index 100% rename from shark_turbine/support/logging.py rename to iree/turbine/support/logging.py diff --git a/shark_turbine/tools/__init__.py b/iree/turbine/tools/__init__.py similarity index 100% rename from shark_turbine/tools/__init__.py rename to iree/turbine/tools/__init__.py diff --git a/shark_turbine/tools/interpreter.py b/iree/turbine/tools/interpreter.py similarity index 66% rename from shark_turbine/tools/interpreter.py rename to iree/turbine/tools/interpreter.py index 5e4a0b15..5022933d 100644 --- a/shark_turbine/tools/interpreter.py +++ b/iree/turbine/tools/interpreter.py @@ -4,32 +4,33 @@ import re from typing import Callable from collections import namedtuple +import numpy as np logger = get_logger("turbine.wave.interpreter") from ..kernel.compiler.ir import ( - amdgpu_d, - builtin_d, Context, + F16Type, + F32Type, IndexType, - Value, - VectorType, + IntegerAttr, + IntegerType, Module, Operation, + Value, + VectorType, + amdgpu_d, + arith_d, + builtin_d, flow_d, func_d, gpu_d, llvm_d, - scf_d, - vector_d, memref_d, - IntegerAttr, - IndexType, - arith_d, + scf_d, stream_d, - F32Type, - F16Type, + vector_d, ) @@ -53,17 +54,12 @@ def get_dtype(self, dtype): return torch.float32 if type(dtype) == F16Type: return torch.float16 + if type(dtype) == IndexType: + return torch.int64 + if dtype == IntegerType.get_signless(1): + return torch.bool raise NotImplementedError(f"Unsupported dtype: {dtype}") - def create_tensor(self, shape: list[int], dtype, value) -> torch.Tensor: - """ - Creates a constant tensor with the given shape, dtype and value. - The tensor is filled with ones. - """ - if type(dtype) == F32Type or type(dtype) == F16Type: - value = float(value) - return torch.ones(*shape, dtype=self.get_dtype(dtype)) * value - def callback(self, op: Operation) -> None: if ( op.operation.parent.name == "func.func" @@ -80,11 +76,13 @@ def callback(self, op: Operation) -> None: elif vtype == VectorType: shape = op.value.type.shape dtype = op.value.type.element_type - value = self.create_tensor( - shape, - dtype, - op.attributes["value"].get_splat_value(), - ) + val = op.attributes["value"] + dtype = self.get_dtype(dtype) + if val.is_splat: + val = val.get_splat_value().value + value = torch.full(shape, val, dtype=dtype) + else: + value = torch.from_numpy(np.array(val)).type(dtype=dtype) else: raise NotImplementedError(f"Unsupported constant type: {vtype}") case arith_d.MulIOp: @@ -112,6 +110,21 @@ def callback(self, op: Operation) -> None: self.symbol_table[op.operands[0]] // self.symbol_table[op.operands[1]] ) + case arith_d.AndIOp: + value = ( + self.symbol_table[op.operands[0]] + & self.symbol_table[op.operands[1]] + ) + case arith_d.CmpIOp: + lhs = self.symbol_table[op.lhs] + rhs = self.symbol_table[op.rhs] + pred = int(op.predicate) + if pred == int(arith_d.CmpIPredicate.slt): + value = lhs < rhs + else: + raise NotImplementedError( + f"Unsupported predicate: {op.predicate}" + ) case amdgpu_d.LDSBarrierOp: return case amdgpu_d.MFMAOp: @@ -136,11 +149,10 @@ def callback(self, op: Operation) -> None: ) # Row-major load offset = [0 for _ in range(len(load_indices))] - offset[-1] += 1 for i in range(*result_shape): - value[i] = memref[ - *[int(x) + y for x, y in zip(load_indices, offset)] - ] + ind = [int(x) + y for x, y in zip(load_indices, offset)] + value[i] = memref[*ind] + offset[-1] += 1 case vector_d.ExtractStridedSliceOp: vector = self.symbol_table[op.vector] value = vector[[int(x) for x in op.offsets]] @@ -154,11 +166,69 @@ def callback(self, op: Operation) -> None: result_shape = vector.shape # Row-major store offset = [0 for _ in range(len(store_indices))] - offset[-1] += 1 for i in range(*result_shape): memref[ *[int(x) + y for x, y in zip(store_indices, offset)] ] = vector[i] + offset[-1] += 1 + case vector_d.MaskedStoreOp: + store_indices = [] + for index in op.indices: + store_indices.append(self.symbol_table[index]) + vector = self.symbol_table[op.valueToStore] + memref = self.symbol_table[op.base] + mask = self.symbol_table[op.mask] + result_type = vector.type + result_shape = vector.shape + # Row-major store + offset = [0 for _ in range(len(store_indices))] + for i in range(*result_shape): + if mask[i]: + ind = [int(x) + y for x, y in zip(store_indices, offset)] + memref[*ind] = vector[i] + + offset[-1] += 1 + case vector_d.ConstantMaskOp: + shape = op.result.type.shape + value = torch.ones(shape, dtype=torch.bool) + case vector_d.GatherOp: + load_indices = [] + for index in op.indices: + load_indices.append(self.symbol_table[index]) + logger.debug("Gather indices:", load_indices) + memref = self.symbol_table[op.base] + mask = self.symbol_table[op.mask] + index_vec = self.symbol_table[op.index_vec] + pass_thru = self.symbol_table[op.pass_thru] + result_type = op.result.type + result_shape = result_type.shape + result_dtype = result_type.element_type + value = torch.zeros( + *result_shape, dtype=self.get_dtype(result_dtype) + ) + # Row-major load + offset = [0 for _ in range(len(load_indices))] + for i in range(*result_shape): + if mask[i]: + off = [ + slice(int(x) + y, None) + for x, y in zip(load_indices, offset) + ] + m = memref[off].flatten() + value[i] = m[index_vec[i]] + else: + value[i] = pass_thru[i] + case vector_d.InsertElementOp: + source = self.symbol_table[op.source] + value = self.symbol_table[op.dest].clone() + position = self.symbol_table[op.position] + value[int(position[0])] = source + case vector_d.SplatOp: + mtype = op.result.type + shape = mtype.shape + dtype = mtype.element_type + input = self.symbol_table[op.input][0] + value = torch.full(shape, input, dtype=self.get_dtype(dtype)) case stream_d.DispatchWorkgroupIDOp: index = int(op.attributes["dimension"]) value = self.workgroup_ids[index] @@ -214,7 +284,7 @@ def callback(self, op: Operation) -> None: case _: raise NotImplementedError(f"Unsupported operation: {op}") - if type(op) != vector_d.StoreOp: + if type(op) not in (vector_d.StoreOp, vector_d.MaskedStoreOp): self.symbol_table[op.result] = value def walk_operations(self, operation: Operation, callback: Callable) -> None: @@ -237,6 +307,14 @@ def interpret(self, asm: str) -> None: operation = module.operation self.walk_operations(operation, self.callback) + @staticmethod + def interpret_ndrange( + asm: str, workgroup_count: list[int], workgroup_size: list[int] + ): + for wg in np.ndindex(*workgroup_count): + for t in np.ndindex(*workgroup_size): + Interpreter([*wg], [*t]).interpret(asm) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="MLIR Interpreter") diff --git a/shark_turbine/transforms/builder.py b/iree/turbine/transforms/builder.py similarity index 100% rename from shark_turbine/transforms/builder.py rename to iree/turbine/transforms/builder.py diff --git a/shark_turbine/transforms/general/add_metadata.py b/iree/turbine/transforms/general/add_metadata.py similarity index 97% rename from shark_turbine/transforms/general/add_metadata.py rename to iree/turbine/transforms/general/add_metadata.py index 44aa2413..340169ec 100644 --- a/shark_turbine/transforms/general/add_metadata.py +++ b/iree/turbine/transforms/general/add_metadata.py @@ -12,7 +12,7 @@ import re -from shark_turbine.support.ir_imports import * +from iree.turbine.support.ir_imports import * from ..rewriter import * from iree.compiler.ir import Context, DictAttr diff --git a/shark_turbine/transforms/general/custom_op_expansion.py b/iree/turbine/transforms/general/custom_op_expansion.py similarity index 100% rename from shark_turbine/transforms/general/custom_op_expansion.py rename to iree/turbine/transforms/general/custom_op_expansion.py diff --git a/shark_turbine/transforms/general/rename_parameters.py b/iree/turbine/transforms/general/rename_parameters.py similarity index 100% rename from shark_turbine/transforms/general/rename_parameters.py rename to iree/turbine/transforms/general/rename_parameters.py diff --git a/shark_turbine/transforms/merger.py b/iree/turbine/transforms/merger.py similarity index 100% rename from shark_turbine/transforms/merger.py rename to iree/turbine/transforms/merger.py diff --git a/shark_turbine/transforms/quantization/mm_group_quant.py b/iree/turbine/transforms/quantization/mm_group_quant.py similarity index 100% rename from shark_turbine/transforms/quantization/mm_group_quant.py rename to iree/turbine/transforms/quantization/mm_group_quant.py diff --git a/shark_turbine/transforms/rewriter.py b/iree/turbine/transforms/rewriter.py similarity index 100% rename from shark_turbine/transforms/rewriter.py rename to iree/turbine/transforms/rewriter.py diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 1b446dc0..14eb2e60 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -3,18 +3,18 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: @@ -98,7 +98,7 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: %read_0_1 # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_shared_0_0 # CHECK-SAME: (%read_0_0, %allocate, 4, None) # CHECK-NEXT: %write_shared_1_1 @@ -182,9 +182,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 @@ -207,9 +207,9 @@ def test_gemm(): # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 # CHECK-NEXT: %read_0_0_1 diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 0bba2384..b9fc81ed 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -2,11 +2,11 @@ import pytest from typing import Callable -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test import torch M = tkl.sym.M @@ -21,18 +21,24 @@ ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 -def codegen_test_context(canonicalize: bool = False): +def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]): + bindings = { + M: 16, + N: 16, + K: 16, + BLOCK_M: 16, + BLOCK_N: 16, + BLOCK_K: 16, + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + } + + # Remove dynamic symbols from the bindings. + for sym in dynamic_symbols: + if sym in bindings: + del bindings[sym] + return tk.gen.TestLaunchContext( - { - M: 16, - N: 16, - K: 16, - BLOCK_M: 16, - BLOCK_N: 16, - BLOCK_K: 16, - ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, - }, - canonicalize=canonicalize, + bindings, canonicalize=canonicalize, dynamic_symbols=dynamic_symbols ) @@ -231,51 +237,40 @@ def test( print(test(a, b).module_op) # CHECK: func.func @test(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding) - # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16> - # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index - # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index - # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index - # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index - # CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index - # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index - # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index - # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index - # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x - # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y - # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index - # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C4]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index - # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index - # CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index - # CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index - # CHECK: %[[D9:.+]] = vector.constant_mask [4] : vector<4xi1> - # CHECK: %[[D10:.+]] = arith.cmpi slt, %[[D5]], %[[C1]] : index - # CHECK: %[[D11:.+]] = arith.cmpi slt, %[[D8]], %[[C3]] : index - # CHECK: %[[D12:.+]] = arith.andi %[[D10]], %[[D11]] : i1 - # CHECK: %[[D13:.+]] = vector.insertelement %[[D12]], %[[D9]][%[[C0]] : index] : vector<4xi1> - # CHECK: %[[D14:.+]] = arith.addi %[[D8]], %[[C1]] : index - # CHECK: %[[D15:.+]] = arith.cmpi slt, %[[D14]], %[[C3]] : index - # CHECK: %[[D16:.+]] = arith.andi %[[D10]], %[[D15]] : i1 - # CHECK: %[[D17:.+]] = vector.insertelement %[[D16]], %[[D13]][%[[C1]] : index] : vector<4xi1> - # CHECK: %[[D18:.+]] = arith.addi %[[D8]], %[[C2]] : index - # CHECK: %[[D19:.+]] = arith.cmpi slt, %[[D18]], %[[C3]] : index - # CHECK: %[[D20:.+]] = arith.andi %[[D10]], %[[D19]] : i1 - # CHECK: %[[D21:.+]] = vector.insertelement %[[D20]], %[[D17]][%[[C2]] : index] : vector<4xi1> - # CHECK: %[[D22:.+]] = arith.addi %[[D8]], %[[C3]] : index - # CHECK: %[[D23:.+]] = arith.cmpi slt, %[[D22]], %[[C3]] : index - # CHECK: %[[D24:.+]] = arith.andi %[[D10]], %[[D23]] : i1 - # CHECK: %[[D25:.+]] = vector.insertelement %[[D24]], %[[D21]][%[[C3]] : index] : vector<4xi1> - # CHECK: %[[D26:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D25]], %[[CST]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> - # CHECK: %[[D27:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: vector.maskedstore %[[D27]][%[[D5]], %[[D8]]], %[[D25]], %[[D26]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> + # CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16> + # CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<3> : vector<4xindex> + # CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + # CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + # CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + # CHECK: %[[WORKGROUP_ID_0:.*]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.*]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.*]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y + # CHECK: %[[D0:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: %[[D1:.*]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index + # CHECK: %[[D2:.*]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.*]] = arith.muli %[[D2]], %[[C4]] : index + # CHECK: %[[D4:.*]] = arith.addi %[[D3]], %[[D1]] : index + # CHECK: %[[D5:.*]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index + # CHECK: %[[D6:.*]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index + # CHECK: %[[D7:.*]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index + # CHECK: %[[D8:.*]] = arith.addi %[[D7]], %[[D6]] : index + # CHECK: %[[D9:.*]] = vector.splat %[[D8]] : vector<4xindex> + # CHECK: %[[D10:.*]] = arith.addi %[[D9]], %[[CST_1]] : vector<4xindex> + # CHECK: %[[D11:.*]] = arith.cmpi slt, %[[D10]], %[[CST_0]] : vector<4xindex> + # CHECK: %[[D12:.*]] = arith.cmpi slt, %[[D5]], %[[C1]] : index + # CHECK: %[[D13:.*]] = vector.splat %[[D12]] : vector<4xi1> + # CHECK: %[[D14:.*]] = arith.andi %[[D11]], %[[D13]] : vector<4xi1> + # CHECK: %[[D15:.*]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D14]], %[[CST]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> + # CHECK: %[[D16:.*]] = stream.binding.subspan %arg1[%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: vector.maskedstore %[[D16]][%[[D5]], %[[D8]]], %[[D14]], %[[D15]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> @run_test @@ -339,6 +334,72 @@ def test( # CHECK-SAME: strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16> +@run_test +def test_dynamic_copy(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): + b = tkw.read(a, elements_per_thread=16) + tkw.write(b, a, elements_per_thread=16) + + with codegen_test_context(canonicalize=True, dynamic_symbols=[M, N]): + a = torch.randn(16, 16, dtype=torch.float16) + print(test(a).module_op) + + # CHECK: stream.executable.export public @test workgroups(%[[ARG0:.*]]: index, %[[ARG1:.*]]: + # CHECK-SAME: index) -> (index, index, index) { + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + # CHECK: %[[D0:.+]] = arith.ceildivsi %[[ARG0]], %[[C16]] : index + # CHECK: %[[D1:.+]] = arith.ceildivsi %[[ARG1]], %[[C16]] : index + # CHECK: stream.return %[[D0]], %[[D1]], %[[C1]] : index, index, index + # CHECK: } + # CHECK: func.func @test(%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) + # CHECK-SAME: attributes {translation_info = #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<16xf16> + # CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : + # CHECK-SAME: vector<16xindex> + # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + # CHECK-DAG: %[[C16]] = arith.constant 16 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y + # CHECK: %[[D0]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref{%[[ARG1]], + # CHECK-SAME: %[[ARG2]]} + # CHECK: %[[D1]] = arith.muli %[[WORKGROUP_ID_0]], %[[C16]] : index + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index + # CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C16]] : index + # CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] : index + # CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index + # CHECK: %[[D9:.+]] = vector.splat %[[D8]] : vector<16xindex> + # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[CST_0]] : vector<16xindex> + # CHECK: %[[D11:.+]] = vector.splat %[[ARG2]] : vector<16xindex> + # CHECK: %[[D12:.+]] = arith.cmpi slt, %[[D10]], %[[D11]] : vector<16xindex> + # CHECK: %[[D13:.+]] = arith.cmpi slt, %[[D5]], %[[ARG1]] : index + # CHECK: %[[D14:.+]] = vector.splat %[[D13]] : vector<16xi1> + # CHECK: %[[D15:.+]] = arith.andi %[[D12]], %[[D14]] : vector<16xi1> + # CHECK: %[[D16:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[CST]] : memref, + # CHECK-SAME: vector<16xi1>, vector<16xf16> into vector<16xf16> + # CHECK: vector.maskedstore %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[D16]] : memref, vector<16xi1>, + # CHECK-SAME: vector<16xf16> + # CHECK: return + + @run_test def test_mma(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] @@ -386,7 +447,7 @@ def mma( print(mma(a, b, c).module_op) # CHECK: func.func @mma(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index @@ -405,60 +466,63 @@ def mma( # CHECK: %[[D1:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index # CHECK: %[[D2:.+]] = arith.muli %[[D1]], %[[C16]] : index # CHECK: %[[D3:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index - # CHECK: %[[D5:.+]] = vector.load %[[D0]][%[[D4]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, - # CHECK-SAME: vector<4xf16> - # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D2]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D4:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[D3]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D2]] : index + # CHECK: %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D8:.+]] = arith.divsi %[[D7]], %[[C16]] : index + # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index + # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space> + # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index + # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D6:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D2]] : index - # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index - # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index - # CHECK: %[[D11:.+]] = vector.load %[[ALLOC]][%[[D7]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, + # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, # CHECK-SAME: strided<[16, 1], offset: ?>> - # CHECK: %[[D13:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D13]] : index - # CHECK: %[[D16:.+]] = vector.load %[[D12]][%[[D15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: + # CHECK: %[[D14:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D15:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D4]], %[[D15]] : index + # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index + # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> - # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D13]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index + # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D17:.+]] = arith.addi %[[D6]], %[[D13]] : index - # CHECK: %[[D18:.+]] = vector.load %[[ALLOC_0]][%[[D17]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D19:.+]] = amdgpu.mfma %[[D11]] * %[[D18]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D22:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D22:.+]] = arith.addi %[[D4]], %[[D10]] : index - # CHECK: %[[D23:.+]] = arith.addi %[[D6]], %[[D14]] : index - # CHECK: %[[D24:.+]] = arith.addi %[[D23]], %[[D13]] : index - # CHECK: vector.store %[[D20]], %[[D21]][%[[D22]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D24:.+]] = arith.addi %[[D3]], %[[D2]] : index + # CHECK: %[[D25:.+]] = arith.addi %[[D24]], %[[D9]] : index + # CHECK: vector.store %[[D22]], %[[D23]][%[[D25]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D25:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D26:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D26:.+]] = arith.addi %[[D22]], %[[C1]] : index - # CHECK: vector.store %[[D25]], %[[D21]][%[[D26]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D27:.+]] = arith.addi %[[D25]], %[[C1]] : index + # CHECK: vector.store %[[D26]], %[[D23]][%[[D27]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D27:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D28:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D28:.+]] = arith.addi %[[D22]], %[[C2]] : index - # CHECK: vector.store %[[D27]], %[[D21]][%[[D28]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D29:.+]] = arith.addi %[[D25]], %[[C2]] : index + # CHECK: vector.store %[[D28]], %[[D23]][%[[D29]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D29:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D30:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D30:.+]] = arith.addi %[[D22]], %[[C3]] : index - # CHECK: vector.store %[[D29]], %[[D21]][%[[D30]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D31:.+]] = arith.addi %[[D25]], %[[C3]] : index + # CHECK: vector.store %[[D30]], %[[D23]][%[[D31]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: return @run_test @@ -515,7 +579,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(gemm(a, b, c).module_op) # CHECK: func.func @gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index @@ -529,83 +593,254 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y - # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: %[[D22:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D24:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D25:.+]] = arith.muli %[[D24]], %[[C16]] : index - # CHECK: %[[D26:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D25]] : index - # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D25]] : index - # CHECK: %[[D32:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D33:.+]] = arith.divsi %[[D32]], %[[C16]] : index - # CHECK: %[[D34:.+]] = arith.muli %[[D33]], %[[C4]] : index - # CHECK: %[[D36:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D37:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D38:.+]] = arith.addi %[[D37]], %[[D36]] : index - # CHECK: %[[D40:.+]] = arith.addi %[[D30]], %[[D36]] : index - # CHECK: %[[D0:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space> + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D4]] : index + # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D3]] : index + # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index + # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index + # CHECK: %[[D11:.+]] = arith.addi %[[D5]], %[[D3]] : index + # CHECK: %[[D12:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D13:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D14:.+]] = arith.addi %[[D5]], %[[D13]] : index + # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D12]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D5]], %[[D12]] : index + # CHECK: %[[D17:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<4xf32>) { - # CHECK: %[[D28:.+]] = arith.muli %[[ARG3]], %[[C16]] : index - # CHECK: %[[D29:.+]] = vector.load %[[D22]][%[[D27]], %[[D28]]] : memref<64x64xf16, strided<[64, 1], - # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D25]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D39:.+]] = arith.muli %[[ARG3]], %[[C16]] : index + # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index + # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D35:.+]] = vector.load %[[ALLOC]][%[[D31]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D39:.+]] = vector.load %[[D23]][%[[D38]], %[[D28]]] : memref<128x64xf16, strided<[64, 1], + # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1], # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D36]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D41:.+]] = vector.load %[[ALLOC_0]][%[[D40]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D42:.+]] = amdgpu.mfma %[[D35]] * %[[D41]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: scf.yield %[[D42]] : vector<4xf32> + # CHECK: scf.yield %[[D45]] : vector<4xf32> # CHECK: } - # CHECK: %[[D1:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D19:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D3:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D4:.+]] = arith.divsi %[[D3]], %[[C16]] : index - # CHECK: %[[D5:.+]] = arith.muli %[[D4]], %[[C4]] : index - # CHECK: %[[D6:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[D6]], %[[C16]] : index - # CHECK: %[[D8:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D9:.+]] = arith.addi %[[D8]], %[[D7]] : index - # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[D5]] : index - # CHECK: %[[D11:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D12:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D13:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.addi %[[D13]], %[[D12]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D11]] : index - # CHECK: vector.store %[[D1]], %[[D2]][%[[D10]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D20:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D21:.+]] = arith.divsi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = arith.muli %[[D21]], %[[C4]] : index + # CHECK: %[[D23:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D24:.+]] = arith.muli %[[D23]], %[[C16]] : index + # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index + # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D22]] : index + # CHECK: %[[D28:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D29:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index + # CHECK: %[[D32:.+]] = arith.addi %[[D31]], %[[D28]] : index + # CHECK: vector.store %[[D18]], %[[D19]][%[[D27]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D16:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D33:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D17:.+]] = arith.addi %[[D10]], %[[C1]] : index - # CHECK: vector.store %[[D16]], %[[D2]][%[[D17]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D34:.+]] = arith.addi %[[D27]], %[[C1]] : index + # CHECK: vector.store %[[D33]], %[[D19]][%[[D34]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D35:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D19:.+]] = arith.addi %[[D10]], %[[C2]] : index - # CHECK: vector.store %[[D18]], %[[D2]][%[[D19]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D36:.+]] = arith.addi %[[D27]], %[[C2]] : index + # CHECK: vector.store %[[D35]], %[[D19]][%[[D36]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D37:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = arith.addi %[[D10]], %[[C3]] : index - # CHECK: vector.store %[[D20]], %[[D2]][%[[D21]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D38:.+]] = arith.addi %[[D27]], %[[C3]] : index + # CHECK: vector.store %[[D37]], %[[D19]][%[[D38]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> # CHECK: return +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 128, + N: 128, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + }, + canonicalize=True, + schedule=True, + use_scheduling_barriers=True, + ): + a = torch.randn(64, 32, dtype=torch.float16) + b = torch.randn(128, 32, dtype=torch.float16) + c = torch.zeros(64, 128, dtype=torch.float32) + print(gemm_pipelined(a, b, c).module_op) + + # CHECK: func.func @gemm_pipelined + # CHECK-COUNT-2: vector.load + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-10: vector.load + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-1: scf.for + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-6: vector.load + # CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" + # CHECK-COUNT-4: vector.load + # CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" + # CHECK-COUNT-1: scf.yield + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-8: vector.load + # CHECK-COUNT-8: amdgpu.mfma + + +# This test is used to check two things +# 1. Reduction with multiple different types(MMA, ReduceOp) of iterArg works +# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape). +@run_test +def test_gemm_and_reduce(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.f16], + d: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + init_max = tkl.Register[M, tkl.f16](-1e6) + + @tkw.reduction(K, init_args=[init_max, c_reg]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], acc: tkl.Register[M, N, tkl.f32] + ) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + partial_max = tkw.max(a_reg, partial_max, dim=K) + acc = tkw.mma(a_reg, b_reg, acc) + return partial_max, acc + + res_max, res_mm = repeat + tkw.write(res_max, c, elements_per_thread=1) + tkw.write(res_mm, d, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 16, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + a = torch.randn(64, 32, dtype=torch.float16) + b = torch.randn(128, 32, dtype=torch.float16) + c = torch.zeros(64, dtype=torch.float16) + d = torch.zeros(64, 128, dtype=torch.float32) + print(gemm(a, b, c, d).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + + # Tile Reduction Loop + # Note: Shape is 32x20 instead of 32x16 because of padding to avoid bank conflicts + # CHECK: %{{.*}}:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}}) + # CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space>, vector<4xf16> + # CHECK-COUNT-6: gpu.shuffle xor + # CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + # CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]] + # CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32> + + @run_test def test_add_float(): constraints: list[tkw.Constraint] = [ @@ -756,6 +991,94 @@ def test( # CHECK: arith.addf {{.*}} : vector<1xf16> +# This test is to ensure that the propagation of indexing_dims between reduction and operations +# outside the reduction is working properly. +@run_test +def test_reduction_and_elemwise(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + + @tkw.reduction(N, init_args=[init_max]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + return partial_max + + result = repeat + repeat + tkw.write(result, c, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT]], %[[ACC1:.+]] = %[[INIT]]) -> (vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_1:.+]] = arith.maximumf %[[ACC1]], %{{.*}} + + # CHECK: scf.yield %[[ACC_REDUCE_0]], %[[ACC_REDUCE_1]] : vector<1xf16>, vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_0:.+]] = arith.addf %[[TILED]]#0, %[[TILED]]#0 : vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_1:.+]] = arith.addf %[[TILED]]#1, %[[TILED]]#1 : vector<1xf16> + # CHECK: vector.store %[[POST_TILE_ELEMWISE_0:.+]], %{{.*}} + # CHECK: vector.store %[[POST_TILE_ELEMWISE_1:.+]], %{{.*}} + + @run_test def test_tiled_reduce_max(): M = tkl.sym.M @@ -851,6 +1174,183 @@ def repeat( # CHECK: scf.yield %[[ACC_REDUCE]] : vector<1xf16> +# This test is to ensure that the we can handle multiple IV in reduction properly. +@run_test +def test_multiple_reduction_iv(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + init_sum = tkl.Register[M, tkl.f16](0) + + @tkw.reduction(N, init_args=[init_max, init_sum]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + partial_sum: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + partial_sum = tkw.sum(lhs, partial_sum, dim=N) + return partial_max, partial_sum + + res_max, res_sum = repeat + tkw.write(res_max, c, elements_per_thread=1) + tkw.write(res_sum, d, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + d = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT_MAX:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + # CHECK-DAG: %[[INIT_SUM:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]]) + # CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}} + + # 1st Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 1st Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}} + + # 2nd Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 2nd Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}} + + # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]] + + +@run_test +def test_broadcast_add(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + rhs = tkw.read(b, elements_per_thread=1) + res = lhs + rhs + tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 128) + a = torch.ones(shape, dtype=torch.float16) + b = torch.ones(shape[0], dtype=torch.float16) + c = torch.zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + LOAD_ELEMS_PER_THREAD: 2, + STORE_ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + run=False, + run_config=config, + ): + print(test(a, b, c).module_op) + # CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %{{.+}}: !stream.binding) + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + + # Slicing LHS + # CHECK: %[[LHS:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<256x128xf16 + # CHECK: %[[LHS_0:.+]] = vector.load %[[LHS]][%[[X_SLICE_0:.+]], %[[Y_SLICE:.+]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16> + # CHECK: %[[X_SLICE_1:.+]] = arith.addi %[[X_SLICE_0]], %c1 : index + # CHECK: %[[LHS_1:.+]] = vector.load %[[LHS]][%[[X_SLICE_1]], %[[Y_SLICE]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16> + + # Slicing RHS + # CHECK: %[[RHS:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<256xf16 + # CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> + # CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> + + # 1st Broadcast-ADD RHS + # CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16> + # CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16> + # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16> + + # 2nd Broadcast-ADD RHS + # CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16> + # CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16> + # CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16> + + @run_test def test_binary_lowerings(): constraints: list[tkw.Constraint] = [ diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index a20965f3..efcdd582 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -2,13 +2,13 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test, print_trace import sympy # Input sizes @@ -243,31 +243,31 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 @@ -305,13 +305,13 @@ def test_gemm(): # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0) # CHECK-NEXT: %mma_0_1_1 # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0) - # CHECK-NEXT: return [mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1] + # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] # Custom format: # CHECK-NEXT: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -346,7 +346,7 @@ def test_gemm(): # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: rhs=read_0_1_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: acc=mma_0_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],)) + # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) # CHECK-NEXT: ----- @@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index d49ee3b2..f0149b70 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -2,22 +2,22 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) -from shark_turbine.kernel.wave.index_sequence_analysis import ( +from iree.turbine.kernel.wave.index_sequence_analysis import ( partition_strided_operators, ) @@ -98,9 +98,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 @@ -182,13 +182,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -210,16 +210,16 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_4, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_5, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_6, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_7, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_8, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 32}) @@ -234,22 +234,22 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_12, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_13, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_14, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_15, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -303,9 +303,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 310e9ef4..329a9ccf 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -3,21 +3,21 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) @@ -103,9 +103,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 @@ -131,13 +131,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -146,19 +146,19 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -215,9 +215,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 01db88cc..c3836f4f 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -2,16 +2,16 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: @@ -74,7 +74,7 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: %read # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_1 # CHECK-SAME: (%read, %allocate, 4, None) # CHECK-NEXT: %read_1 @@ -123,7 +123,7 @@ def test_read_write_equal_sizes_different_address_spaces(): # CHECK-NEXT: %read # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_1 # CHECK-SAME: (%read, %allocate, 4, None) # CHECK-NEXT: %read_1 @@ -181,9 +181,9 @@ def test_gemm(): # CHECK-NEXT: %c # CHECK-NEXT: %register # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-NEXT: %write # CHECK-SAME: (%reduction, %c, 4, None) diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py new file mode 100644 index 00000000..55de465e --- /dev/null +++ b/lit_tests/kernel/wave/scheduling.py @@ -0,0 +1,237 @@ +# RUN: python %s | FileCheck %s + +import logging +import unittest +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_subgraph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( + apply_shared_memory_indexing_corrections, +) +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph + + +# Input sizes +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K + +# Workgroup tile sizes +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K + +# Address space +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + +# Induction variable for dimension K +ARGK = tkl.sym.ARGK + + +@tkw.wave_trace_only() +def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], +): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=4) + + +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + with tk.gen.TestLaunchContext( + { + M: 128, + N: 256, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 2, + GLOBAL_MEMORY_UNITS: 2, + MMA_UNITS: 2, + } + ): + trace: CapturedTrace = gemm_pipelined() + IndexingContext.current().finalize() + promote_placeholders(trace, constraints) + hoist_allocs(trace) + expand_graph(trace, constraints) + minimize_global_loads(trace, constraints) + apply_shared_memory_indexing_corrections(trace, constraints) + schedule_graph(trace, constraints, True) + + print_subgraph(trace, "pipelined_reduction", False) + # CHECK: %acc_0_0_0 + # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 + # CHECK-NEXT: %rotating_reg_0 + # CHECK-NEXT: %rotating_reg_1 + # CHECK-NEXT: %rotating_reg_2 + # CHECK-NEXT: %rotating_reg_3 + # CHECK-NEXT: %rotating_reg_4 + # CHECK-NEXT: %rotating_reg_5 + # CHECK-NEXT: %rotating_reg_6 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6) + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2, Operation.READ_GLOBAL: 2}, 0) + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0) + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.READ_SHARED: 2, Operation.MMA: 2}, 0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0) + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 2, Operation.WRITE_SHARED: 2}, 0) + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 2, Operation.READ_SHARED: 2}, 0) + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1) + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2}, 0) + # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0] + + print_subgraph(trace, "region_1", False) + # CHECK: %a + # CHECK-NEXT: %b + # CHECK-NEXT: %c + # CHECK-NEXT: %register_0_0_0 + # CHECK-NEXT: %register_1_1_0 + # CHECK-NEXT: %register_1_0_0 + # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %allocate + # CHECK-NEXT: %allocate_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: %read_6 + # CHECK-NEXT: %read_7 + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0) + # CHECK-NEXT: %write_4 + # CHECK-NEXT: %write_5 + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_2 + # CHECK-NEXT: %read_shared_0_1_3 + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0) + # CHECK-NEXT: %read_shared_0_0_4 + # CHECK-NEXT: %read_shared_0_0_5 + # CHECK-NEXT: %reduction_1 + # CHECK-NEXT: %getresult_1_1_0 + # CHECK-NEXT: %getresult_1_0_0 + # CHECK-NEXT: %getresult_0_1_0 + # CHECK-NEXT: %getresult_0_0_0 + # CHECK-NEXT: %get_result_4 + # CHECK-NEXT: %get_result_5 + # CHECK-NEXT: %get_result_6 + # CHECK-NEXT: %get_result_7 + # CHECK-NEXT: %get_result_8 + # CHECK-NEXT: %get_result_9 + # CHECK-NEXT: %get_result_10 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10) + # CHECK-NEXT: %read_shared_0_0_6 + # CHECK-NEXT: %read_shared_0_0_7 + # CHECK-NEXT: %read_shared_1_0_2 + # CHECK-NEXT: %read_shared_1_0_3 + # CHECK-NEXT: %mma_0_0_2 + # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0) + # CHECK-NEXT: %mma_0_1_2 + # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0) + # CHECK-NEXT: %mma_0_0_3 + # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2) + # CHECK-NEXT: %mma_1_0_2 + # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0) + # CHECK-NEXT: %mma_1_0_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2) + # CHECK-NEXT: %mma_0_1_3 + # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2) + # CHECK-NEXT: %mma_1_1_2 + # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1) + # CHECK-NEXT: %mma_1_1_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2) + # CHECK-NEXT: %write_0_0_0 + # CHECK-NEXT: %write_1_1_0 + # CHECK-NEXT: %write_1_0_0 + # CHECK-NEXT: %write_0_1_0 + # CHECK-NEXT: return None + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/lit_tests/kernel/wave/tracing.py b/lit_tests/kernel/wave/tracing.py index 283b6436..f6c9306b 100644 --- a/lit_tests/kernel/wave/tracing.py +++ b/lit_tests/kernel/wave/tracing.py @@ -1,11 +1,11 @@ # RUN: python %s | FileCheck %s from typing import Callable -from shark_turbine.kernel._support.tracing import CapturedTrace -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.ops.wave_ops import get_custom, Read, Write -from shark_turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel._support.tracing import CapturedTrace +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.ops.wave_ops import get_custom, Read, Write +from iree.turbine.kernel.wave.utils import run_test, print_trace M = tkl.sym.M N = tkl.sym.N diff --git a/lit_tests/lit.cfg.py b/lit_tests/lit.cfg.py index 5b40c7eb..614383fc 100644 --- a/lit_tests/lit.cfg.py +++ b/lit_tests/lit.cfg.py @@ -7,7 +7,7 @@ import lit.llvm -from shark_turbine.support.logging import get_logger +from iree.turbine.support.logging import get_logger logger = get_logger("turbine.lit_tests") diff --git a/mypy.ini b/mypy.ini index 29c35b65..5638faef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ explicit_package_bases = True mypy_path = $MYPY_CONFIG_FILE_DIR -packages = shark_turbine +packages = iree.turbine # Missing typing stubs for iree.compiler. [mypy-iree.compiler.*] @@ -13,11 +13,15 @@ ignore_missing_imports = True ignore_missing_imports = True # fx_importer needs to be fixed upstream. -[mypy-shark_turbine.importers.fx_importer.*] +[mypy-iree.turbine.importers.fx_importer.*] ignore_errors = True # TODO: Fix all typing errors in TK. -[mypy-shark_turbine.kernel.*] +[mypy-iree.turbine.kernel.*] +ignore_errors = True + +# TODO: Some pytorch errors. +[mypy-iree.turbine.tools.interpreter] ignore_errors = True # Ignore all typing errors in tests/tools (these depend on TK). diff --git a/setup.py b/setup.py index 63a028cb..c73c3532 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,7 @@ REPO_DIR = THIS_DIR VERSION_INFO_FILE = os.path.join(REPO_DIR, "version_info.json") -# Transitional as we migrate from shark-turbine -> iree-turbine. -TURBINE_PACKAGE_NAME = os.getenv("TURBINE_PACKAGE_NAME", "shark-turbine") +TURBINE_PACKAGE_NAME = "iree-turbine" with open( os.path.join( @@ -81,12 +80,12 @@ def initialize_options(self): setup( name=f"{TURBINE_PACKAGE_NAME}", version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK Turbine Machine Learning Deployment Tools", + author="IREE Authors", + author_email="iree-technical-discussion@lists.lfaidata.foundation", + description="IREE Turbine Machine Learning Deployment Tools", long_description=README, long_description_content_type="text/markdown", - url="https://github.com/nod-ai/SHARK-Turbine", + url="https://github.com/iree-org/iree-turbine/", license="Apache-2.0", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -96,11 +95,11 @@ def initialize_options(self): packages=packages, include_package_data=True, package_data={ - "shark_turbine": ["ops/templates/*.mlir"], # Include MLIR templates + "iree.turbine": ["ops/templates/*.mlir"], # Include MLIR templates }, entry_points={ "torch_dynamo_backends": [ - "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend", + "turbine_cpu = iree.turbine.dynamo.backends.cpu:backend", ], }, install_requires=[ diff --git a/shark_turbine/__init__.py b/shark_turbine/__init__.py new file mode 100644 index 00000000..f1e1c318 --- /dev/null +++ b/shark_turbine/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +# Temp redirect from old shark_turbine namespace. +from iree.turbine import aot +from iree.turbine import dynamo +from iree.turbine import kernel +from iree.turbine import ops +from iree.turbine import runtime diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl b/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl deleted file mode 100644 index 1295a6d4..00000000 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl +++ /dev/null @@ -1,54 +0,0 @@ -\begin{thebibliography}{8} -\providecommand{\natexlab}[1]{#1} -\providecommand{\url}[1]{\texttt{#1}} -\expandafter\ifx\csname urlstyle\endcsname\relax - \providecommand{\doi}[1]{doi: #1}\else - \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi - -\bibitem[Author(2018)]{anonymous} -Author, N.~N. -\newblock Suppressed for anonymity, 2018. - -\bibitem[Duda et~al.(2000)Duda, Hart, and Stork]{DudaHart2nd} -Duda, R.~O., Hart, P.~E., and Stork, D.~G. -\newblock \emph{Pattern Classification}. -\newblock John Wiley and Sons, 2nd edition, 2000. - -\bibitem[Kearns(1989)]{kearns89} -Kearns, M.~J. -\newblock \emph{Computational Complexity of Machine Learning}. -\newblock PhD thesis, Department of Computer Science, Harvard University, 1989. - -\bibitem[Langley(2000)]{langley00} -Langley, P. -\newblock Crafting papers on machine learning. -\newblock In Langley, P. (ed.), \emph{Proceedings of the 17th International - Conference on Machine Learning (ICML 2000)}, pp.\ 1207--1216, Stanford, CA, - 2000. Morgan Kaufmann. - -\bibitem[Michalski et~al.(1983)Michalski, Carbonell, and - Mitchell]{MachineLearningI} -Michalski, R.~S., Carbonell, J.~G., and Mitchell, T.~M. (eds.). -\newblock \emph{Machine Learning: An Artificial Intelligence Approach, Vol. I}. -\newblock Tioga, Palo Alto, CA, 1983. - -\bibitem[Mitchell(1980)]{mitchell80} -Mitchell, T.~M. -\newblock The need for biases in learning generalizations. -\newblock Technical report, Computer Science Department, Rutgers University, - New Brunswick, MA, 1980. - -\bibitem[Newell \& Rosenbloom(1981)Newell and Rosenbloom]{Newell81} -Newell, A. and Rosenbloom, P.~S. -\newblock Mechanisms of skill acquisition and the law of practice. -\newblock In Anderson, J.~R. (ed.), \emph{Cognitive Skills and Their - Acquisition}, chapter~1, pp.\ 1--51. Lawrence Erlbaum Associates, Inc., - Hillsdale, NJ, 1981. - -\bibitem[Samuel(1959)]{Samuel59} -Samuel, A.~L. -\newblock Some studies in machine learning using the game of checkers. -\newblock \emph{IBM Journal of Research and Development}, 3\penalty0 - (3):\penalty0 211--229, 1959. - -\end{thebibliography} diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib b/shark_turbine/kernel/wave/docs/mlsys/tkw.bib deleted file mode 100644 index 6bd0e3ee..00000000 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib +++ /dev/null @@ -1,75 +0,0 @@ -@inproceedings{langley00, - author = {P. Langley}, - title = {Crafting Papers on Machine Learning}, - year = {2000}, - pages = {1207--1216}, - editor = {Pat Langley}, - booktitle = {Proceedings of the 17th International Conference - on Machine Learning (ICML 2000)}, - address = {Stanford, CA}, - publisher = {Morgan Kaufmann} -} - -@TechReport{mitchell80, - author = "T. M. Mitchell", - title = "The Need for Biases in Learning Generalizations", - institution = "Computer Science Department, Rutgers University", - year = "1980", - address = "New Brunswick, MA", -} - -@phdthesis{kearns89, - author = {M. J. Kearns}, - title = {Computational Complexity of Machine Learning}, - school = {Department of Computer Science, Harvard University}, - year = {1989} -} - -@Book{MachineLearningI, - editor = "R. S. Michalski and J. G. Carbonell and T. - M. Mitchell", - title = "Machine Learning: An Artificial Intelligence - Approach, Vol. I", - publisher = "Tioga", - year = "1983", - address = "Palo Alto, CA" -} - -@Book{DudaHart2nd, - author = "R. O. Duda and P. E. Hart and D. G. Stork", - title = "Pattern Classification", - publisher = "John Wiley and Sons", - edition = "2nd", - year = "2000" -} - -@misc{anonymous, - title= {Suppressed for Anonymity}, - author= {Author, N. N.}, - year= {2018} -} - -@InCollection{Newell81, - author = "A. Newell and P. S. Rosenbloom", - title = "Mechanisms of Skill Acquisition and the Law of - Practice", - booktitle = "Cognitive Skills and Their Acquisition", - pages = "1--51", - publisher = "Lawrence Erlbaum Associates, Inc.", - year = "1981", - editor = "J. R. Anderson", - chapter = "1", - address = "Hillsdale, NJ" -} - - -@Article{Samuel59, - author = "A. L. Samuel", - title = "Some Studies in Machine Learning Using the Game of - Checkers", - journal = "IBM Journal of Research and Development", - year = "1959", - volume = "3", - number = "3", - pages = "211--229" -} diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg b/shark_turbine/kernel/wave/docs/mlsys/tkw.blg deleted file mode 100644 index ef864a1b..00000000 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg +++ /dev/null @@ -1,46 +0,0 @@ -This is BibTeX, Version 0.99d (TeX Live 2020) -Capacity: max_strings=200000, hash_size=200000, hash_prime=170003 -The top-level auxiliary file: example_paper.aux -The style file: mlsys2024.bst -Database file #1: example_paper.bib -You've used 8 entries, - 2773 wiz_defined-function locations, - 645 strings with 5916 characters, -and the built_in function-call counts, 3248 in all, are: -= -- 293 -> -- 140 -< -- 9 -+ -- 49 -- -- 41 -* -- 223 -:= -- 507 -add.period$ -- 25 -call.type$ -- 8 -change.case$ -- 36 -chr.to.int$ -- 8 -cite$ -- 16 -duplicate$ -- 174 -empty$ -- 295 -format.name$ -- 51 -if$ -- 691 -int.to.chr$ -- 1 -int.to.str$ -- 1 -missing$ -- 6 -newline$ -- 47 -num.names$ -- 37 -pop$ -- 81 -preamble$ -- 1 -purify$ -- 29 -quote$ -- 0 -skip$ -- 127 -stack$ -- 0 -substring$ -- 100 -swap$ -- 24 -text.length$ -- 3 -text.prefix$ -- 0 -top$ -- 0 -type$ -- 78 -warning$ -- 0 -while$ -- 34 -width$ -- 0 -write$ -- 113 diff --git a/shark_turbine/kernel/wave/register_analysis.py b/shark_turbine/kernel/wave/register_analysis.py deleted file mode 100644 index cbd42fbe..00000000 --- a/shark_turbine/kernel/wave/register_analysis.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from ..wave.constraints import Constraint -from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr -from .._support.tracing import CapturedTrace -from ...support.logging import get_logger -from ..ops.wave_ops import get_custom, NewRegister, CustomOp, MMA, Reduction, ReduceOp -from .utils import get_hardware_vector_map -import torch.fx as fx - -logger = get_logger("turbine.wave.register_analysis") - - -def set_register_shape( - trace: CapturedTrace, custom: CustomOp, vector_map: dict[IndexSymbol, int] -) -> None: - for custom_user in custom.users: - if isinstance(custom_user, MMA): - arg_index = custom_user.fx_node.args.index(custom.fx_node) - get_thread_shape = lambda index: max(x.size for x in index.values()) - match arg_index: - case 0: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.lhs_index - ) - case 1: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.rhs_index - ) - case 2: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.acc_index - ) - break - - elif isinstance(custom_user, Reduction): - idx = custom_user.init_args.index(custom.fx_node) - iter_arg = get_custom( - custom_user.iter_args(trace.get_subgraph(custom_user.subgraph_name))[ - idx - ] - ) - set_register_shape(trace, iter_arg, vector_map) - custom.fx_node.thread_shape = iter_arg.fx_node.thread_shape - break - elif isinstance(custom_user, ReduceOp): - # Check that dim is non-reduction && in hw_constraint.vector_shape. - is_parallel_dim = lambda dim: dim != custom_user.dim and dim in vector_map - # TODO: Modify num_reduction_dims once we add support for multi-dim reduction. - num_reduction_dims = 1 - register_shape = [ - vector_map[dim] - for dim in custom_user.type.symbolic_shape - if is_parallel_dim(dim) - ] - expected_result_rank = ( - len(custom_user.type.symbolic_shape) - custom_user.num_reduction_dims - ) - # If rank do not match => some dims not found in hw_constraint.vector_shape. - if len(register_shape) != expected_result_rank: - raise NotImplementedError( - "NYI: Handling of dim not in vector_shapes during register analysis." - ) - non_unit_dims = sum(1 for dim in register_shape if dim > 1) - if non_unit_dims > 1: - raise NotImplementedError( - "NYI: Currently Register semantic only support 0-D vector." - ) - custom.fx_node.thread_shape = max(register_shape) - else: - raise NotImplementedError( - f"Register shape propagation not implemented for {custom_user}" - ) - - -def determine_register_shape( - trace: CapturedTrace | fx.Graph, constraints: list[Constraint] -) -> None: - """ - Each register op is annotated with the wave shape of the register. This - function determines the thread shape of the register based on the uses - of the register in the graph. - """ - register_nodes = trace.walk(lambda node: isinstance(get_custom(node), NewRegister)) - if not register_nodes: - return - vector_map = get_hardware_vector_map(constraints) - for node in register_nodes: - set_register_shape(trace, get_custom(node), vector_map) diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py deleted file mode 100644 index 924c36bd..00000000 --- a/shark_turbine/kernel/wave/visualization.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -graphviz_disabled = False -try: - import pygraphviz as pgv -except: - graphviz_disabled = True -from torch import fx -from .scheduling.graph_utils import Edge -import math - - -def number_nodes(graph: fx.Graph) -> dict[int, int]: - return {id(node): i for i, node in enumerate(graph.nodes)} - - -def visualize_graph(graph: fx.Graph, file_name: str): - if graphviz_disabled: - raise ImportError("pygraphviz not installed, cannot visualize graph") - node_numbering = number_nodes(graph) - G = pgv.AGraph(directed=True) - for node in graph.nodes: - G.add_node(node_numbering[id(node)], label=node.name) - for node in graph.nodes: - for user in node.users.keys(): - G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) - G.layout(prog="dot") - G.draw(file_name) - - -def visualize_edges(edges: list[Edge], file_name: str): - if graphviz_disabled: - raise ImportError("pygraphviz not installed, cannot visualize graph") - G = pgv.AGraph(directed=True) - node_map = {} - count = 0 - for edge in edges: - if edge._from not in node_map: - node_map[edge._from] = count - count += 1 - G.add_node(node_map[edge._from], label=f"{edge._from}") - if edge._to not in node_map: - node_map[edge._to] = count - count += 1 - G.add_node(node_map[edge._to], label=f"{edge._to}") - G.add_edge( - node_map[edge._from], - node_map[edge._to], - label=f"({edge.weight.iteration_difference}, {edge.weight.delay})", - ) - G.layout(prog="dot") - G.draw(file_name) - - -def visualize_schedule( - schedule: dict[fx.Graph, int], initiation_interval: int, file_name: str -): - import pandas as pd - - max_time = max(schedule.values()) - max_stage = math.ceil(max_time / initiation_interval) - rows = max_time + 1 + max_stage * initiation_interval - cols = max_stage - - table = [["" for _ in range(cols)] for _ in range(rows)] - for stage in range(max_stage): - for key, value in schedule.items(): - table[value + stage * initiation_interval][stage] += f"{key}
" - - df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)]) - s = df.style.set_properties(**{"text-align": "center"}) - s = s.set_table_styles( - [ - {"selector": "", "props": [("border", "1px solid grey")]}, - {"selector": "tbody td", "props": [("border", "1px solid grey")]}, - {"selector": "th", "props": [("border", "1px solid grey")]}, - {"selector": "th", "props": [("min-width", "300px")]}, - ] - ) - output = s.apply( - lambda x: [ - ( - "background: lightgreen" - if int(x.name) >= (max_stage - 1) * initiation_interval - and int(x.name) < max_stage * initiation_interval - else "" - ) - for _ in x - ], - axis=1, - ).to_html() - with open(f"{file_name}", "w") as f: - f.write(output) diff --git a/tests/aot/api_test.py b/tests/aot/api_test.py index e038704d..0d5f4215 100644 --- a/tests/aot/api_test.py +++ b/tests/aot/api_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn diff --git a/tests/aot/args_test.py b/tests/aot/args_test.py index d7ec458d..efbce489 100644 --- a/tests/aot/args_test.py +++ b/tests/aot/args_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class ArgsTest(unittest.TestCase): diff --git a/tests/aot/compiled_exported_program_test.py b/tests/aot/compiled_exported_program_test.py index baaeb9bb..6b86b185 100644 --- a/tests/aot/compiled_exported_program_test.py +++ b/tests/aot/compiled_exported_program_test.py @@ -14,8 +14,8 @@ Context, ) -from shark_turbine.aot import * -from shark_turbine.aot.builtins import * +from iree.turbine.aot import * +from iree.turbine.aot.builtins import * class TorchExportTests(unittest.TestCase): diff --git a/tests/aot/decompositions_test.py b/tests/aot/decompositions_test.py index baf96604..f186cf12 100644 --- a/tests/aot/decompositions_test.py +++ b/tests/aot/decompositions_test.py @@ -9,7 +9,7 @@ import logging import unittest -from shark_turbine.aot import decompositions +from iree.turbine.aot import decompositions class DecompTest(unittest.TestCase): diff --git a/tests/aot/dynamic_shape_export_test.py b/tests/aot/dynamic_shape_export_test.py new file mode 100644 index 00000000..8f53df27 --- /dev/null +++ b/tests/aot/dynamic_shape_export_test.py @@ -0,0 +1,50 @@ +import torch + +import pytest + +from iree.turbine.aot import * + + +@pytest.mark.parametrize( + "import_symbolic_shape_expressions", + [ + True, + False, + ], +) +def test_exported_program_dynamic_shapes(import_symbolic_shape_expressions): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) + + example_args = (torch.randn(32, 64), torch.randn(32, 128)) + + # Create a dynamic batch size + batch = torch.export.Dim("batch") + # Specify that the first dimension of each input is that batch size + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + output = export( + M(), + args=example_args, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + output.print_readable() + asm = str(output.mlir_module) + + if import_symbolic_shape_expressions: + assert "bind_symbolic_shape" in asm + else: + assert "bind_symbolic_shape" not in asm diff --git a/tests/aot/functionalize_test.py b/tests/aot/functionalize_test.py index 0cad8e93..2a2ea309 100644 --- a/tests/aot/functionalize_test.py +++ b/tests/aot/functionalize_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class FunctionalizeTests(unittest.TestCase): diff --git a/tests/aot/fx_programs_test.py b/tests/aot/fx_programs_test.py index c54f1851..f2c70456 100644 --- a/tests/aot/fx_programs_test.py +++ b/tests/aot/fx_programs_test.py @@ -10,7 +10,7 @@ import pytest import torch -from shark_turbine.aot import ( +from iree.turbine.aot import ( FxPrograms, FxProgramsBuilder, ) diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 26bab1a6..7a250531 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn @@ -425,6 +425,68 @@ def testUnsupportedCombinations(self): export_global(AbstractF32, external=True, uninitialized=True) +class SimpleCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + + return cache + + +class ReadWriteReadCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + cache_value_0 = self.cache[2].clone() + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + cache_value_1 = cache[2].clone() + return cache, cache_value_0, cache_value_1 + + +class BufferTest(unittest.TestCase): + def testMutableBuffer(self): + max_size = 10 + simple_cache = SimpleCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + def testReadWriteReadMutableBuffer(self): + max_size = 10 + simple_cache = ReadWriteReadCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py index 9f479921..251c8f12 100644 --- a/tests/aot/iree_procedural_test.py +++ b/tests/aot/iree_procedural_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class CompiledModuleAPI(unittest.TestCase): diff --git a/tests/aot/jittable_test.py b/tests/aot/jittable_test.py index 9c87fb11..d19988bc 100644 --- a/tests/aot/jittable_test.py +++ b/tests/aot/jittable_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class JittableTests(unittest.TestCase): diff --git a/tests/aot/non_strict_export_test.py b/tests/aot/non_strict_export_test.py index ece961dc..2ed1b603 100644 --- a/tests/aot/non_strict_export_test.py +++ b/tests/aot/non_strict_export_test.py @@ -3,7 +3,7 @@ from torch import nn import torch -from shark_turbine.aot import * +from iree.turbine.aot import * logger = logging.getLogger(__file__) diff --git a/tests/aot/params_test.py b/tests/aot/params_test.py index a1d64206..895cb2b9 100644 --- a/tests/aot/params_test.py +++ b/tests/aot/params_test.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from shark_turbine.aot import ( +from iree.turbine.aot import ( export, externalize_module_parameters, save_module_parameters, diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ccbd0088 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runperf", action="store_true", default=False, help="run performance tests" + ) + parser.addoption( + "--dump-perf-files-path", + action="store", + default=None, + help="save performance info into provided directory, filename based on current test name", + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "perf_only: performance test, runs only with '--runperf'" + ) + config.addinivalue_line( + "markers", "validate_only: validation test, never runs with '--runperf'" + ) + + +def _has_marker(item, marker): + return next(item.iter_markers(marker), None) is not None + + +def pytest_collection_modifyitems(config, items): + run_perf = config.getoption("--runperf") + for item in items: + is_validate_only = _has_marker(item, "validate_only") + is_perf_only = _has_marker(item, "perf_only") + if run_perf: + if not is_perf_only or is_validate_only: + item.add_marker(pytest.mark.skip("skip non-perf test")) + else: + if is_perf_only: + item.add_marker(pytest.mark.skip("skip perf test")) diff --git a/tests/dynamo/importer_dynamic_test.py b/tests/dynamo/importer_dynamic_test.py index 72ff4f82..682aa140 100644 --- a/tests/dynamo/importer_dynamic_test.py +++ b/tests/dynamo/importer_dynamic_test.py @@ -14,7 +14,7 @@ # from torch._export.constraints import constrain_as_size, constrain_as_value from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index fcd40660..0562c071 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -12,8 +12,8 @@ import torch # Public API imports. -from shark_turbine.runtime import Device -from shark_turbine.dynamo import TurbineMode, DeviceTensor +from iree.turbine.runtime import Device +from iree.turbine.dynamo import TurbineMode, DeviceTensor class TensorTest(unittest.TestCase): diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index dfc3de25..70375efb 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -12,7 +12,7 @@ Type as IrType, ) -import shark_turbine.dynamo.type_conversion as tc +import iree.turbine.dynamo.type_conversion as tc class TypeConversionTest(unittest.TestCase): @@ -32,6 +32,7 @@ def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") + self._compareNative("!torch.vtensor<[],complex>", "tensor>") def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: diff --git a/tests/generated/evaluate.py b/tests/generated/evaluate.py index 3184930d..a971e23c 100644 --- a/tests/generated/evaluate.py +++ b/tests/generated/evaluate.py @@ -2,7 +2,7 @@ import logging from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/kernel/aot_kernel_test.py b/tests/kernel/aot_kernel_test.py index 690e366a..16363048 100644 --- a/tests/kernel/aot_kernel_test.py +++ b/tests/kernel/aot_kernel_test.py @@ -8,9 +8,9 @@ import unittest import torch -from shark_turbine.aot import export -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +from iree.turbine.aot import export +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl def export_softmax_kernel(): diff --git a/tests/kernel/arith_test.py b/tests/kernel/arith_test.py index 1631454c..ce9e659e 100644 --- a/tests/kernel/arith_test.py +++ b/tests/kernel/arith_test.py @@ -8,15 +8,15 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/compiler/utils_test.py b/tests/kernel/compiler/utils_test.py index be084613..6f2db310 100644 --- a/tests/kernel/compiler/utils_test.py +++ b/tests/kernel/compiler/utils_test.py @@ -1,9 +1,9 @@ import logging import pytest import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel._support.indexing import IndexSymbol, IndexingContext -from shark_turbine.kernel.compiler.utils import strides_from_symbolic_shape +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel._support.indexing import IndexSymbol, IndexingContext +from iree.turbine.kernel.compiler.utils import strides_from_symbolic_shape class UtilsTest(unittest.TestCase): diff --git a/tests/kernel/dispatch_codegen_test.py b/tests/kernel/dispatch_codegen_test.py index be17a86d..b76ed2e1 100644 --- a/tests/kernel/dispatch_codegen_test.py +++ b/tests/kernel/dispatch_codegen_test.py @@ -8,16 +8,16 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, dispatch_codegen, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/fused_attention_test.py b/tests/kernel/fused_attention_test.py index 89883780..abc9d7ad 100644 --- a/tests/kernel/fused_attention_test.py +++ b/tests/kernel/fused_attention_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl BATCH = tkl.sym.BATCH N_HEADS = tkl.sym.N_HEADS diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index 8bc27c50..677bbf09 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -9,8 +9,8 @@ import torch -from shark_turbine.kernel._support.indexing import * -from shark_turbine.kernel.lang import * +from iree.turbine.kernel._support.indexing import * +from iree.turbine.kernel.lang import * M = sym.M N = sym.N diff --git a/tests/kernel/simple_kernel_test.py b/tests/kernel/simple_kernel_test.py index 87cf3ed2..bffe723c 100644 --- a/tests/kernel/simple_kernel_test.py +++ b/tests/kernel/simple_kernel_test.py @@ -9,8 +9,8 @@ import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K @@ -37,9 +37,9 @@ def iota_kernel(out: tk.lang.KernelBuffer[M, tkl.index]): print(iota_kernel._trace().region_graph) # Prints: # .graph(): - # %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_global_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) + # %out : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_global_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) # return None def testSoftmax(self): @@ -76,17 +76,17 @@ def softmax(x): print(softmax_kernel._trace().region_graph) # Prints: # graph(): - # %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] - # %output : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %input_1 : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] + # %output : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) # %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%input_1, (%program_id, slice(None, None, None))), kwargs = {}) # %max_1 : [num_users=1] = call_function[target=torch.max](args = (%getitem,), kwargs = {}) # %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem, %max_1), kwargs = {}) # %exp : [num_users=2] = call_function[target=torch.exp](args = (%sub,), kwargs = {}) # %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%exp,), kwargs = {}) # %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%exp, %sum_1), kwargs = {}) - # %program_id_1 : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) + # %program_id_1 : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_kernel_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) # return None diff --git a/tests/kernel/types_test.py b/tests/kernel/types_test.py index 87dc6536..e355db31 100644 --- a/tests/kernel/types_test.py +++ b/tests/kernel/types_test.py @@ -7,7 +7,7 @@ import logging import unittest -from shark_turbine.kernel.lang import ( +from iree.turbine.kernel.lang import ( Index, ) diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index fcd33462..696852c0 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K diff --git a/tests/kernel/wave/constraints_test.py b/tests/kernel/wave/constraints_test.py index 418c3c8b..f2915ac4 100644 --- a/tests/kernel/wave/constraints_test.py +++ b/tests/kernel/wave/constraints_test.py @@ -8,8 +8,8 @@ import pytest import unittest from sympy import ceiling -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.constraints import ( +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.constraints import ( WorkgroupConstraint, get_grid_shape, TilingConstraint, diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 93d9cb6c..bb7cbc25 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -6,32 +6,32 @@ import unittest import logging -from shark_turbine.kernel.wave.scheduling.modulo_scheduling import ( +from iree.turbine.kernel.wave.scheduling.modulo_scheduling import ( ModuloScheduler, EdgeWeight, Edge, ) import torch.fx as fx import numpy as np -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.scheduling.graph_utils import ( +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.scheduling.graph_utils import ( find_strongly_connected_components, find_cycles_in_scc, all_pairs_longest_paths, evaluate_all_pairs_longest_paths, ) -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph -from shark_turbine.kernel.ops.wave_ops import get_custom +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.ops.wave_ops import get_custom class SchedulingTest(unittest.TestCase): diff --git a/tests/kernel/wave/types_test.py b/tests/kernel/wave/types_test.py index d27c4c47..cdb05c3e 100644 --- a/tests/kernel/wave/types_test.py +++ b/tests/kernel/wave/types_test.py @@ -9,9 +9,9 @@ import sympy import unittest -from shark_turbine.kernel.lang import Memory, Register, sym, f16 -from shark_turbine.kernel.lang.wave_types import AddressSpace -from shark_turbine.kernel.lang.kernel_buffer import KernelBufferUsage +from iree.turbine.kernel.lang import Memory, Register, sym, f16 +from iree.turbine.kernel.lang.wave_types import AddressSpace +from iree.turbine.kernel.lang.kernel_buffer import KernelBufferUsage M = sym.M N = sym.N diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py index 17cce11c..ebe6a75f 100644 --- a/tests/kernel/wave/visualization_test.py +++ b/tests/kernel/wave/visualization_test.py @@ -9,15 +9,15 @@ import unittest import os import pytest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import get_custom -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.visualization import visualize_graph +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import get_custom +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.visualization import visualize_graph def run(func: Callable[[], None]) -> Callable[[], None]: diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 4c2f04db..a3effd46 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -4,11 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim -from shark_turbine.kernel.lang.global_symbols import * +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import torch from numpy.testing import assert_allclose, assert_equal import pytest @@ -44,9 +45,19 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: return default_test_shapes +def xfail_unaligned(func): + def wrapper(shape): + if shape[-1] % 2 != 0: + pytest.xfail("Unaligned shape is not expected to work on this test yet.") + func(shape) + + return wrapper + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) -def test_copy(shape): +def test_copy(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -57,7 +68,8 @@ def test_copy(shape): # elements. wave_size = 64 BLOCK_M = 1 - BLOCK_N = sympy.Max(sympy.Min(N, 256), wave_size) + # Tile size cannot be dynamic, so we use a fixed value here. + BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) ELEMS_PER_THREAD = BLOCK_N / wave_size constraints: list[tkw.Constraint] = [ @@ -92,6 +104,64 @@ def test( }, canonicalize=True, run=True, + run_bench=run_bench, + run_config=config, + ): + test(a, b) + assert_allclose(a, b) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")) +def test_dynamic_copy(shape, request): + run_bench = request.config.getoption("--runperf") + M = tkl.sym.M + N = tkl.sym.N + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on single row of input data, and rows are further + # split into blocks of size up to 256. We have single wave per WG, + # and with default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. + BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + ELEMS_PER_THREAD = BLOCK_N / wave_size + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + waves_per_block=(1, 1, 1), + vector_shapes={M: BLOCK_M, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + a = torch.randn(shape, dtype=torch.float16) + b = torch.zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + dynamic_symbols=(M, N), + dynamic_symbols_map={M: shape[0], N: shape[1]}, + canonicalize=True, + run=True, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -100,7 +170,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_transpose_read")) -def test_transpose_read(shape): +def test_transpose_read(shape, request): + run_bench = request.config.getoption("--runperf") shape = shape[::-1] M = tkl.sym.M N = tkl.sym.N @@ -149,6 +220,7 @@ def test( }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -157,7 +229,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_transpose_write")) -def test_transpose_write(shape): +def test_transpose_write(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -205,6 +278,7 @@ def test( }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -213,7 +287,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_reduce_sum")) -def test_reduce_sum(shape): +def test_reduce_sum(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N wave_size = 64 @@ -261,6 +336,7 @@ def test( }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): test(a, b, c) @@ -269,13 +345,14 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_tiled_reduce_max")) -def test_tiled_reduce_max(shape): +@xfail_unaligned +def test_toy_online_softmax(shape): M = tkl.sym.M N = tkl.sym.N wave_size = 64 BLOCK_M = 1 BLOCK_N = tkl.sym.BLOCK_N - ELEMS_PER_THREAD = BLOCK_N / wave_size + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE constraints: list[tkw.Constraint] = [ @@ -293,35 +370,44 @@ def test_tiled_reduce_max(shape): @tkw.wave(constraints) def test( - a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32], ): - init_max = tkl.Register[M, tkl.f16](-1e6) + init_max = tkl.Register[M, tkl.f32](-1e6) + init_sum = tkl.Register[M, tkl.f32](0) - @tkw.reduction(N, init_args=[init_max]) + @tkw.reduction(N, init_args=[init_max, init_sum]) def repeat( - partial_max: tkl.Register[M, tkl.f16], - ) -> tkl.Register[M, tkl.f16]: + partial_max: tkl.Register[M, tkl.f32], + partial_sum: tkl.Register[M, tkl.f32], + ) -> tkl.Register[M, tkl.f32]: lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD) res = lhs * rhs partial_max = tkw.max(res, partial_max, dim=N) - return partial_max + partial_sum = tkw.sum(res, partial_sum, dim=N) + return partial_max, partial_sum - tkw.write(repeat, c, elements_per_thread=1) + res_max, res_sum = repeat + result = res_max / res_sum + tkw.write(result, c, elements_per_thread=1) config = {"backend": "rocm", "device": "hip", "target": "gfx942"} - a = torch.randn(shape, dtype=torch.float16) - b = torch.randn(shape, dtype=torch.float16) - c = torch.zeros((shape[0],), dtype=torch.float16) - ref = torch.max((a * b), dim=-1) + torch.manual_seed(1) + a = torch.randn(shape, dtype=torch.float32) + b = torch.randn(shape, dtype=torch.float32) + c = torch.zeros((shape[0],), dtype=torch.float32) + ref_max = torch.max((a * b), dim=-1).values + ref_sum = torch.sum((a * b), dim=-1) + ref = ref_max / ref_sum with tk.gen.TestLaunchContext( { M: shape[0], N: shape[1], BLOCK_N: min(128, shape[1]), + ELEMS_PER_THREAD: min(128, shape[1]) // wave_size, ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, }, canonicalize=True, @@ -332,11 +418,12 @@ def repeat( # Assert equal does cast to boolean on torch.Tensor # which causes issues, hence we cast to numpy before # checking. - assert_equal(c, ref.values.numpy()) + assert_allclose(ref, c, atol=0.015) @require_e2e -def test_im2col(): +def test_im2col(request): + run_bench = request.config.getoption("--runperf") # TODO: we don't support unaligned access at the moment so all sizes must # be aligned to WG/Wave sizes, c * hw * wf == 8 and number of windows == 64. n, c, h, w = 1, 2, 9, 9 # Image. @@ -430,6 +517,7 @@ def test( }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -437,7 +525,8 @@ def test( @require_e2e -def test_im2col_mma(): +def test_im2col_mma(request): + run_bench = request.config.getoption("--runperf") # igemm without final col2im n, c, h, w = 1, 4, 9, 9 # Image. nf, cf, hf, wf = 64, c, 2, 2 # Filters. @@ -560,20 +649,85 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): gpu_func(x, we, out) assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05) +_igemm_cases = [ + (1, 5, 5, 10, 2, 2, 2, 2), + (2, 5, 5, 3, 2, 2, 1, 1), + (4, 5, 5, 10, 2, 2, 2, 1), + (2, 5, 5, 10, 2, 2, 1, 1), + (2, 5, 5, 10, 2, 2, 2, 1), + (1, 5, 5, 10, 2, 2, 16, 1), + (1, 5, 5, 10, 2, 2, 1, 2), + (1, 5, 5, 4, 2, 2, 2, 1), + (4, 5, 5, 10, 2, 2, 2, 3), + (4, 5, 5, 10, 2, 2, 1, 3), + (4, 5, 5, 10, 2, 2, 16, 2), + (1, 5, 5, 3, 2, 2, 2, 2), + (4, 5, 5, 10, 2, 2, 16, 1), + (4, 5, 5, 4, 2, 2, 16, 1), + (2, 5, 5, 4, 2, 2, 1, 3), + (2, 5, 5, 4, 2, 2, 2, 1), + (1, 5, 5, 10, 2, 2, 16, 3), + (4, 5, 5, 4, 2, 2, 16, 2), + (4, 5, 5, 10, 2, 2, 2, 1), + (4, 5, 5, 3, 2, 2, 1, 1), + (4, 5, 5, 4, 2, 2, 2, 1), + (4, 5, 5, 3, 2, 2, 2, 1), + (2, 5, 5, 1, 2, 2, 1, 3), + (2, 5, 5, 4, 2, 2, 2, 1), + (2, 5, 5, 10, 2, 2, 16, 1), + (1, 5, 5, 1, 3, 3, 1, 1), +] + +perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only) +validation_test = lambda *a: pytest.param(*a, marks=pytest.mark.validate_only) + +_igemm_cases += [ + perf_test(2, 128, 128, 16, 3, 3, 320, 1), + perf_test(2, 128, 128, 320, 1, 1, 640, 1), + perf_test(2, 128, 128, 320, 1, 1, 960, 1), + perf_test(2, 128, 128, 320, 3, 3, 16, 1), + perf_test(2, 128, 128, 320, 3, 3, 320, 1), + perf_test(2, 32, 32, 1280, 1, 1, 1920, 1), + perf_test(2, 32, 32, 1280, 1, 1, 2560, 1), + perf_test(2, 32, 32, 1280, 1, 1, 640, 1), + perf_test(2, 32, 32, 1280, 3, 3, 1280, 1), + perf_test(2, 32, 32, 1280, 3, 3, 1920, 1), + perf_test(2, 32, 32, 1280, 3, 3, 2560, 1), + perf_test(2, 32, 32, 1280, 3, 3, 640, 1), + perf_test(2, 32, 32, 640, 3, 3, 640, 1), + perf_test(2, 64, 64, 320, 3, 3, 320, 1), + perf_test(2, 64, 64, 640, 1, 1, 1280, 1), + perf_test(2, 64, 64, 640, 1, 1, 1920, 1), + perf_test(2, 64, 64, 640, 1, 1, 320, 1), + perf_test(2, 64, 64, 640, 1, 1, 960, 1), + perf_test(2, 64, 64, 640, 3, 3, 320, 1), + perf_test(2, 64, 64, 640, 3, 3, 640, 1), +] + +_mem_spaces = [ + pytest.param(GLOBAL_ADDRESS_SPACE, id="global", marks=pytest.mark.validate_only), + pytest.param(SHARED_ADDRESS_SPACE, id="shared"), +] + +_layouts = [ + pytest.param("nchw_fchw", marks=pytest.mark.validate_only), + pytest.param("nhwc_hwcf"), +] + + @require_e2e -@pytest.mark.parametrize("n", [1, 2, 4]) -@pytest.mark.parametrize("c", [1, 3, 4, 10]) -@pytest.mark.parametrize("nf", [1, 2, 16]) -@pytest.mark.parametrize("stride", [1, 2, 3]) -def test_igemm_conv(n, c, nf, stride): - h, w = 5, 5 # Image. - cf, hf, wf = c, 2, 2 # Filters. +@pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases) +@pytest.mark.parametrize("mem_space", _mem_spaces) +@pytest.mark.parametrize("layout", _layouts) +def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request): + cf = c padding = 0 # TODO: only pad=0 is supported for now torch.manual_seed(1) @@ -633,6 +787,20 @@ def test_igemm_conv(n, c, nf, stride): # Other hyperparameters ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + if layout == "nchw_fchw": + x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16] + we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16] + out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32] + elif layout == "nhwc_hwcf": + x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16] + we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16] + out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32] + x = torch.permute(x, (0, 2, 3, 1)).contiguous() + we = torch.permute(we, (2, 3, 1, 0)).contiguous() + out_ref = torch.permute(out_ref, (0, 2, 3, 1)).contiguous() + else: + raise ValueError(f"Invalid layout: {layout}") + # Expose user-constraints constraints: list[tkw.Constraint] = [] constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] @@ -650,9 +818,9 @@ def test_igemm_conv(n, c, nf, stride): @tkw.wave(constraints) def conv( - x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16], - we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16], - out: tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32], + x: x_type, + we: we_type, + out: out_type, ): c_reg = tkl.Register[M, NF, tkl.f32](0.0) @@ -675,10 +843,20 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD ) - out = torch.zeros_like(out_ref) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + with tk.gen.TestLaunchContext( { N: n, @@ -691,11 +869,32 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: BLOCK_M: 16, BLOCK_N: 16, ELEMS_PER_THREAD: 4, - ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE: mem_space, }, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): + out = torch.zeros_like(out_ref) conv(x, we, out) assert_allclose(out, out_ref, rtol=1e-03, atol=1e-03) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + + config[ + "iree_preprocessing_pass_pipeline" + ] = "builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)" + iree_ref = torch.zeros_like(out_ref) + generate_iree_ref( + "conv_2d_" + layout, + [x, we], + [iree_ref], + config, + stride=stride, + run_bench=True, + ) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 344032a4..78d26e3b 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -8,21 +8,32 @@ import pytest import torch import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.iree_utils import generate_iree_ref +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import os import json +from torch.testing import assert_close _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) +# Whether to use scheduling group barriers (needs LLVM fix). +enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0)) default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)] +perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only) + +default_test_shapes += [ + perf_test((1024, 5120, 640)), + perf_test((2048, 10240, 1280)), + perf_test((4096, 20480, 2560)), +] + user_specified_test_shapes = "" test_params_path = os.environ.get("TEST_PARAMS_PATH", None) @@ -40,8 +51,10 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) -def testGemm(shape: tuple[int]): - +@pytest.mark.parametrize("enable_scheduling", [False, True]) +def testGemm(shape: tuple[int], enable_scheduling: bool, request): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") # Input sizes M = tkl.sym.M N = tkl.sym.N @@ -106,10 +119,33 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: M: shape[0], N: shape[1], K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + with tk.gen.TestLaunchContext( - hyperparams, canonicalize=True, run=True, run_config=config + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, ): a = torch.randn(shape[0], shape[2], dtype=torch.float16) b = torch.randn(shape[1], shape[2], dtype=torch.float16) @@ -121,11 +157,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with open(filename, "w") as f: f.write(mb.module_op.get_asm()) + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32) - generate_iree_ref("mmt", [a, b], [iree_ref], config) - assert torch.equal(c, iree_ref) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref) diff --git a/tests/kernel/wave/wave_packaging_test.py b/tests/kernel/wave/wave_packaging_test.py new file mode 100644 index 00000000..4c4f71f2 --- /dev/null +++ b/tests/kernel/wave/wave_packaging_test.py @@ -0,0 +1,140 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import pytest +import torch +import unittest +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.packaging.build_package import create_pip_package +import os +import json +from torch.testing import assert_close + + +def packageTest(): + shape = (2048, 1280, 1280) + enable_scheduling = True + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + # a_reg: tkw.Register[M, K, tkl.f16] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[N, K, tkl.f16] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[M, N, tkl.f32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + config = { + "backend": "rocm", + "device": "hip", + "target": "gfx942", + "dump_vmfb_file": "artifacts/kernel.vmfb", + } + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=False, + run_config=config, + schedule=enable_scheduling, + ): + a = torch.randn(shape[0], shape[2], dtype=torch.float16) + b = torch.randn(shape[1], shape[2], dtype=torch.float16) + c = torch.zeros(shape[0], shape[1], dtype=torch.float32) + gemm(a, b, c) + + # Create the pip package + kernel_info = { + "package_name": "libtkw", + "kernel_name": "gemm_f32_2048x1280x1280_f16", + "num_inputs": 2, + "dispatch_name": "isolated_benchmark", + "vmfb_path": "artifacts/kernel.vmfb", + "kernel_version": "0.0.1", + } + create_pip_package(output_dir="pip_package/", kernel_info=kernel_info) + # Run python setup.py bdist_wheel in pip_package/ to build the wheel. + # Once the wheel is built, it can be installed using + # pip install .whl --find-links https://iree.dev/pip-release-links.html + # The kernel can then be invoked from Python as follows: + # import libtkw + # import torch + # a = torch.randn(2048, 1280, dtype=torch.float16, device="cuda") + # b = torch.randn(1280, 1280, dtype=torch.float16, device="cuda") + # c = torch.empty(2048, 1280, dtype=torch.float32, device="cuda") + # libtkw.gemm_f32_2048x1280x1280_f16(a, b, c) + assert os.path.exists("pip_package/libtkw-0.0.1-py3-none-any.whl") + + +packageTest() diff --git a/tests/kernel/wave/wave_sim_test.py b/tests/kernel/wave/wave_sim_test.py index 5fa5695a..58ec1255 100644 --- a/tests/kernel/wave/wave_sim_test.py +++ b/tests/kernel/wave/wave_sim_test.py @@ -6,9 +6,9 @@ import pytest import torch -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim from numpy.testing import assert_allclose diff --git a/tests/kernel/wave/wave_utils_test.py b/tests/kernel/wave/wave_utils_test.py index ec1198fd..bce6de9f 100644 --- a/tests/kernel/wave/wave_utils_test.py +++ b/tests/kernel/wave/wave_utils_test.py @@ -6,8 +6,8 @@ import logging import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.utils import delinearize_index +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.utils import delinearize_index import numpy as np M = sym.M diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py index b06a7910..facbf545 100644 --- a/tests/ops/iree_test.py +++ b/tests/ops/iree_test.py @@ -10,8 +10,8 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops # See runtime/op_reg/kernel_aot_test.py for additional tests of the trace diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index e78aff8e..89cd83c8 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -14,17 +14,17 @@ from iree.runtime import HalElementType # Public API imports. -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Device, ) # Internals. -from shark_turbine.runtime.device import ( +from iree.turbine.runtime.device import ( _CURRENT_THREAD, get_device_from_torch, ) -from shark_turbine.support.exceptions import * +from iree.turbine.support.exceptions import * class DeviceTest(unittest.TestCase): @@ -151,7 +151,7 @@ def testFromTorchDevice(self): print(device.dump_device_info()) def testJit(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") result = test_ops.test_add(t, t) @@ -161,7 +161,7 @@ def testJit(self): class TorchCPUInterop(unittest.TestCase): def testJitStrFormat(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) @@ -169,7 +169,7 @@ def testJitStrFormat(self): torch.testing.assert_close(result, expected) def testJitJinja(self): - from shark_turbine.ops import _jinja_test_ops as test_ops + from iree.turbine.ops import _jinja_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) diff --git a/tests/runtime/launch_test.py b/tests/runtime/launch_test.py index 1a142161..ad12b2e3 100644 --- a/tests/runtime/launch_test.py +++ b/tests/runtime/launch_test.py @@ -8,11 +8,11 @@ import torch import unittest -from shark_turbine.aot.params import ( +from iree.turbine.aot.params import ( ParameterArchiveBuilder, ) -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/tests/runtime/op_reg/impl_helper_test.py b/tests/runtime/op_reg/impl_helper_test.py index b0797c2d..2661dc5b 100644 --- a/tests/runtime/op_reg/impl_helper_test.py +++ b/tests/runtime/op_reg/impl_helper_test.py @@ -9,7 +9,7 @@ import torch -from shark_turbine.ops import _str_format_test_ops +from iree.turbine.ops import _str_format_test_ops class KernelRegTest(unittest.TestCase): diff --git a/tests/runtime/op_reg/kernel_aot_test.py b/tests/runtime/op_reg/kernel_aot_test.py index 4aa04857..4533326a 100644 --- a/tests/runtime/op_reg/kernel_aot_test.py +++ b/tests/runtime/op_reg/kernel_aot_test.py @@ -10,10 +10,10 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass class MLP(nn.Module): diff --git a/tests/runtime/op_reg/kernel_reg_test.py b/tests/runtime/op_reg/kernel_reg_test.py index 75554b04..dfc88d83 100644 --- a/tests/runtime/op_reg/kernel_reg_test.py +++ b/tests/runtime/op_reg/kernel_reg_test.py @@ -9,9 +9,9 @@ import torch -from shark_turbine.runtime.op_reg import * +from iree.turbine.runtime.op_reg import * -from shark_turbine.runtime.op_reg.compiler import _testing_get_cache_size +from iree.turbine.runtime.op_reg.compiler import _testing_get_cache_size class KernelRegTest(unittest.TestCase): diff --git a/tests/tools/interpreter_test.py b/tests/tools/interpreter_test.py index 0513b10b..2152c701 100644 --- a/tests/tools/interpreter_test.py +++ b/tests/tools/interpreter_test.py @@ -1,8 +1,8 @@ -from shark_turbine.tools.interpreter import Interpreter -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.tools.interpreter import Interpreter +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * import torch diff --git a/tests/top_level_package_test.py b/tests/top_level_package_test.py index 52ea796b..b2c04cdd 100644 --- a/tests/top_level_package_test.py +++ b/tests/top_level_package_test.py @@ -11,8 +11,8 @@ class TopLevelPackageTest(unittest.TestCase): def testIreeTurbineRedirect(self): # We have a temporary redirect of the top-level API to the - # iree.turbine namespace. - from iree.turbine import aot, dynamo, kernel, ops, runtime + # shark-turbine namespace. + from shark_turbine import aot, dynamo, kernel, ops, runtime if __name__ == "__main__": diff --git a/tests/transforms/general/add_metadata_test.py b/tests/transforms/general/add_metadata_test.py index 8055fa26..da5d0207 100644 --- a/tests/transforms/general/add_metadata_test.py +++ b/tests/transforms/general/add_metadata_test.py @@ -11,7 +11,7 @@ from iree.compiler.ir import Context, Operation, Module -from shark_turbine.transforms.general import add_metadata +from iree.turbine.transforms.general import add_metadata SIMPLE_FUNC_ASM = r""" func.func @list_func(%arg0 : !iree_input.list) -> !iree_input.list { diff --git a/tests/transforms/general/custom_op_expansion_test.py b/tests/transforms/general/custom_op_expansion_test.py index b94e2750..f621320d 100644 --- a/tests/transforms/general/custom_op_expansion_test.py +++ b/tests/transforms/general/custom_op_expansion_test.py @@ -9,15 +9,15 @@ import torch import unittest -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass -from shark_turbine.runtime.op_reg import ( +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.runtime.op_reg import ( def_library, CustomOp, KernelBuilder, KernelSelection, ) -from shark_turbine.support.ir_imports import ( +from iree.turbine.support.ir_imports import ( Context, Module, ) diff --git a/tests/transforms/general/rename_parameters_test.py b/tests/transforms/general/rename_parameters_test.py index 74fc6753..a14dbcbd 100644 --- a/tests/transforms/general/rename_parameters_test.py +++ b/tests/transforms/general/rename_parameters_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.general import rename_parameters +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.general import rename_parameters SIMPLE_GLOBALS_ASM = r""" module { diff --git a/tests/transforms/quantization/mm_group_quant_test.py b/tests/transforms/quantization/mm_group_quant_test.py index c6870d2c..b465301b 100644 --- a/tests/transforms/quantization/mm_group_quant_test.py +++ b/tests/transforms/quantization/mm_group_quant_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.quantization import mm_group_quant +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.quantization import mm_group_quant MM_F32_TO_INT4_CONTENTS = ( Path(__file__).resolve().parent / "mm_f32_to_int4.mlir"