diff --git a/.github/Dockerfile.base b/.github/Dockerfile.base index c0a01e6d6..e6fc33757 100644 --- a/.github/Dockerfile.base +++ b/.github/Dockerfile.base @@ -28,7 +28,11 @@ RUN apt-get update && apt-get install -y \ graphviz \ patchelf \ libyaml-cpp-dev \ - libboost-all-dev + libboost-all-dev \ + curl \ + jq \ + sudo \ + gh # Install clang 17 RUN wget https://apt.llvm.org/llvm.sh && \ diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index ade377c06..8ec0c93dc 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -47,8 +47,9 @@ jobs: fail-fast: false matrix: build: [ - {runs-on: ubuntu-latest, enable_perf: OFF, name: "run", ttrt_flags: ""}, - {runs-on: ubuntu-latest, enable_perf: ON, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: OFF, name: "run", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: ON, enable_op_model: OFF, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} ] name: Build tt-mlir @@ -66,11 +67,22 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "Build tt-mlir (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.enable_op_model }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -78,7 +90,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} # Build project @@ -97,6 +109,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=ON \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build @@ -143,18 +156,19 @@ jobs: run: | source env/activate cmake --build ${{ steps.strings.outputs.build-output-dir }} -- check-ttmlir + cp build/test/report.xml ${{ steps.strings.outputs.test_report_path }} - name: Upload Test Report uses: actions/upload-artifact@v4 with: - name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }} - path: build/test/report.xml + name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }} + path: ${{ steps.strings.outputs.test_report_path }} - name: Show Test Report uses: mikepenz/action-junit-report@v4 if: success() || failure() with: - report_paths: build/test/report.xml + report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: MLIR Tests # Build and upload ttrt @@ -214,6 +228,7 @@ jobs: {runs-on: n300, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"}, {runs-on: n300, enable_perf: ON, name: "perf"}, ] + name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" runs-on: - in-service @@ -237,11 +252,23 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -303,19 +330,27 @@ jobs: run: | source env/activate ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/perf_unit + cp ttrt_report.xml ${{ steps.strings.outputs.test_report_path }} - - name: Upload ttrt test report + - name: Upload ttrt test report json if: always() uses: actions/upload-artifact@v4 with: name: ${{ matrix.build.runs-on }}_${{ matrix.build.name }}_results.json path: ${{ matrix.build.name }}_results.json + - name: Upload Test Report xml + uses: actions/upload-artifact@v4 + if: success() || failure() + with: + name: test-reports-${{ matrix.build.runs-on }}-${{ matrix.test_group_id }} + path: ${{ steps.strings.outputs.test_report_path }} + - name: Show Test Report uses: mikepenz/action-junit-report@v4 if: success() || failure() with: - report_paths: ttrt_report.xml + report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: TTRT ${{ matrix.build.runs-on }} ${{ matrix.build.name }} Tests run-ttrt-tests: @@ -346,6 +381,7 @@ jobs: - /opt/tt_metal_infra/provisioning/provisioning_env:/opt/tt_metal_infra/provisioning/provisioning_env steps: + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -353,11 +389,22 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "${{ github.job }} (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -411,31 +458,22 @@ jobs: shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_read.py - - - name: ttrt query tests - shell: bash - run: | - source env/activate - pytest -ssv runtime/tools/python/test/test_query.py - - - name: ttrt check tests - shell: bash - run: | - source env/activate - pytest -ssv runtime/tools/python/test/test_check.py + pytest -ssv runtime/tools/python/test \ + --junit-xml=${{ steps.strings.outputs.test_report_path }} - - name: ttrt run tests - shell: bash - run: | - source env/activate - pytest -ssv runtime/tools/python/test/test_run.py + - name: Upload Test Report + uses: actions/upload-artifact@v4 + if: success() || failure() + with: + name: test-reports-${{ matrix.build.runs-on }}-${{ matrix.build.name }} + path: ${{ steps.strings.outputs.test_report_path }} - - name: ttrt perf tests - shell: bash - run: | - source env/activate - pytest -ssv runtime/tools/python/test/test_perf.py + - name: Show Test Report + uses: mikepenz/action-junit-report@v4 + if: success() || failure() + with: + report_paths: ${{ steps.strings.outputs.test_report_path }} + check_name: Run ttrt tests build-and-test-explorer: needs: build-image @@ -472,6 +510,7 @@ jobs: run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -480,7 +519,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} - name: Configure CMake shell: bash @@ -496,6 +535,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=OFF \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=OFF \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build tt-explorer @@ -509,3 +549,4 @@ jobs: run: | source env/activate pytest tools/explorer/test/run_tests.py + # collect results diff --git a/.github/workflows/issue-last-updated.yml b/.github/workflows/issue-last-updated.yml index 61a235aff..f79d16c2c 100644 --- a/.github/workflows/issue-last-updated.yml +++ b/.github/workflows/issue-last-updated.yml @@ -21,6 +21,7 @@ jobs: echo "project_id=PVT_kwDOA9MHEM4AjeTl" >> $GITHUB_ENV echo "field_id=PVTF_lADOA9MHEM4AjeTlzgiiU18" >> $GITHUB_ENV + - name: Get Issue ID id: get_issue_id run: | @@ -31,18 +32,94 @@ jobs: - name: Get Item ID for Issue - id: get_item_by_issue_id + id: get_item_id_by_issue_id run: | - ITEM_ID=$(curl -X POST -H "Authorization: Bearer $GITHUB_TOKEN" \ - -H "Content-Type: application/json" \ - -d '{ - "query": "query($projectId: ID!) { node(id: $projectId) { ... on ProjectV2 { items(first: 100) { nodes { id content { ... on Issue { id } } } } } } }", - "variables": { - "projectId": "'"${{ env.project_id }}"'" - } - }' \ - https://api.github.com/graphql | jq -r '.data.node.items.nodes[] | select(.content.id=="'"${{ env.issue_id }}"'") | .id') - echo "ITEM_ID=$ITEM_ID" >> $GITHUB_ENV + # Initialize variables + CURSOR=null + ITEM_ID="" + + + # Define the GraphQL query as a string + QUERY='query($projectId: ID!, $cursor: String) { + node(id: $projectId) { + ... on ProjectV2 { + items(first: 100, after: $cursor) { + nodes { + id + content { + ... on Issue { + id + } + } + } + pageInfo { + hasNextPage + endCursor + } + } + } + } + }' + + + while : ; do + # Construct JSON payload using jq for proper formatting + JSON_PAYLOAD=$(jq -n \ + --arg query "$QUERY" \ + --arg projectId "${{ env.project_id }}" \ + --arg cursor "$CURSOR" \ + '{ query: $query, variables: { projectId: $projectId, cursor: $cursor }}') + + + # Make the GraphQL request + RESPONSE=$(curl -s -X POST -H "Authorization: Bearer $GITHUB_TOKEN" \ + -H "Content-Type: application/json" \ + -d "$JSON_PAYLOAD" \ + https://api.github.com/graphql) + + + # Debug: print entire response + echo "RESPONSE: $RESPONSE" + + + # Check if the response contains `items` data + ITEMS_DATA=$(echo "$RESPONSE" | jq -r '.data.node.items.nodes' 2>/dev/null) + if [[ "$ITEMS_DATA" == "null" ]]; then + echo "Error: Items data not found. Please check your PROJECT_ID and GITHUB_TOKEN permissions." + break + fi + + + # Parse the item ID if it matches the issue_id + ITEM_ID=$(echo "$RESPONSE" | jq -r --arg issue_id "$issue_id" \ + '.data.node.items.nodes[] | select(.content.id==$issue_id) | .id') + + + # If ITEM_ID is found, output it and stop the loop + if [[ -n "$ITEM_ID" && "$ITEM_ID" != "null" ]]; then + echo "Found ITEM_ID: $ITEM_ID" + echo "ITEM_ID=$ITEM_ID" >> $GITHUB_ENV # Save ITEM_ID to environment for future steps + break + fi + + + # Extract pagination information + HAS_NEXT_PAGE=$(echo "$RESPONSE" | jq -r '.data.node.items.pageInfo.hasNextPage') + CURSOR=$(echo "$RESPONSE" | jq -r '.data.node.items.pageInfo.endCursor') + + + # If no more pages, exit loop + if [[ "$HAS_NEXT_PAGE" != "true" ]]; then + echo "Issue not found in project items." + break + fi + done + + + - name: Use Found ITEM_ID + if: env.ITEM_ID # Only runs if ITEM_ID was set + run: echo "The ITEM_ID is ${{ env.ITEM_ID }}" + - name: Update Project Field run: | diff --git a/.github/workflows/nightly-uplift.yml b/.github/workflows/nightly-uplift.yml index a0f6eb534..54dd758ae 100644 --- a/.github/workflows/nightly-uplift.yml +++ b/.github/workflows/nightly-uplift.yml @@ -5,7 +5,7 @@ name: Nighty Uplift on: schedule: - - cron: '0 8 * * *' # Runs at 08:00 UTC every day + - cron: '0 6 * * *' # Runs at 06:00 UTC every day workflow_dispatch: # Manual trigger jobs: @@ -13,25 +13,30 @@ jobs: runs-on: ubuntu-latest env: - SUBMODULE_PATH: third_party/tt-metal - TT_METAL_VERSION: origin/main + TT_METAL_SUBMODULE_PATH: third_party/tt-metal steps: - - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + ref: main - - name: Set env variable + - name: Set env variable for today's date run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - - name: Update tt-metal reference + - name: Fetch latest SHA of tt-metal submodule env: GH_TOKEN: ${{ github.token }} run: | - # Fetch the latest SHA using GitHub CLI - LATEST_SHA=$(gh api repos/tenstorrent/tt-metal/commits/main --jq '.sha') - # Update the third_party/CMakeLists.txt file with the new SHA - sed -i "s/set(TT_METAL_VERSION \".*\")/set(TT_METAL_VERSION \"${LATEST_SHA}\")/" third_party/CMakeLists.txt + LATEST_TT_METAL_VERSION=$(gh api repos/tenstorrent/tt-metal/commits/main --jq '.sha') + echo "LATEST_TT_METAL_VERSION=$LATEST_TT_METAL_VERSION" >> $GITHUB_ENV + + - name: Update tt-metal reference in third_party/CMakeLists.txt + run: | + echo "Updating tt-metal to SHA: ${{ env.LATEST_TT_METAL_VERSION }}" + sed -i "s/set(TT_METAL_VERSION \".*\")/set(TT_METAL_VERSION \"${{ env.LATEST_TT_METAL_VERSION }}\")/" third_party/CMakeLists.txt - name: Create Pull Request uses: peter-evans/create-pull-request@v7 @@ -41,9 +46,9 @@ jobs: committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> author: ${{ github.actor }} <${{ github.actor_id }}+${{ github.actor }}@users.noreply.github.com> base: main - commit-message: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}" - title: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}" - body: "This PR uplifts the ${{ env.SUBMODULE_PATH }} to the ${{ env.SUBMODULE_VERSION }}" + commit-message: "Uplift ${{ env.TT_METAL_SUBMODULE_PATH }} to ${{ env.LATEST_TT_METAL_VERSION }} ${{ env.TODAY }}" + title: "Uplift ${{ env.TT_METAL_SUBMODULE_PATH }} to ${{ env.LATEST_TT_METAL_VERSION }} ${{ env.TODAY }}" + body: "This PR uplifts the ${{ env.TT_METAL_SUBMODULE_PATH }} to the ${{ env.LATEST_TT_METAL_VERSION }}" labels: uplift delete-branch: true token: ${{ secrets.GH_TOKEN }} @@ -57,8 +62,11 @@ jobs: echo "Pull Request URL - ${{ steps.create-pr.outputs.pull-request-url }}" gh pr review ${{ steps.create-pr.outputs.pull-request-number }} --approve - - name: Enable Pull Request Automerge - if: ${{ steps.create-pr.outputs.pull-request-number }} - run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" - env: - GH_TOKEN: ${{ secrets.GH_TOKEN }} + # Note: Dissable auto-merge for now until we are more confident + # that uplift won't break the downstream projects + # + # - name: Enable Pull Request Automerge + # if: ${{ steps.create-pr.outputs.pull-request-number }} + # run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" + # env: + # GH_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/.github/workflows/produce_data.yml b/.github/workflows/produce_data.yml new file mode 100644 index 000000000..e53ccc0f6 --- /dev/null +++ b/.github/workflows/produce_data.yml @@ -0,0 +1,28 @@ +name: "[internal] Collect workflow data" + +on: + workflow_run: + workflows: # List workflow that we want to collect data for + - "On PR" + - "On push" + - "Build on macos-latest" + - "Build and Test" + types: + - completed + +jobs: + produce-cicd-data: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + steps: + - name: Collect CI/CD data + uses: tenstorrent/tt-github-actions/.github/actions/collect_data@main + if: ${{ github.event_name == 'workflow_run' }} + with: + repository: ${{ github.repository }} + run_id: ${{ github.event.workflow_run.id }} + run_attempt: ${{ github.event.workflow_run.run_attempt }} + sftp_host: ${{ secrets.SFTP_CICD_WRITER_HOSTNAME }} + sftp_user: ${{ secrets.SFTP_CICD_WRITER_USERNAME }} + ssh-private-key: ${{ secrets.SFTP_CICD_WRITER_KEY }} diff --git a/.gitignore b/.gitignore index 8663a2ff0..274c39c1f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,8 @@ ttrt-artifacts/* query_results.json run_results.json ttrt_report.xml +cluster_descriptor.yaml + +# TTNN and TTMetal flatbuffers +*.ttnn +*.ttm diff --git a/CMakeLists.txt b/CMakeLists.txt index 54fcc89d4..2927fb560 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ endif() option(TT_RUNTIME_ENABLE_PERF_TRACE "Enable performance mode" OFF) option(TTMLIR_ENABLE_RUNTIME "Enable runtime" OFF) option(TTMLIR_ENABLE_STABLEHLO "Enable StableHLO support" OFF) +option(TTMLIR_ENABLE_OP_MODEL "Enable OpModel support" OFF) if (TTMLIR_ENABLE_STABLEHLO) add_compile_definitions(TTMLIR_ENABLE_STABLEHLO) @@ -20,6 +21,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(TTMLIR_ENABLE_BINDINGS_PYTHON ON CACHE BOOL "Enable Python bindings") +if (APPLE) + set(TTMLIR_ENABLE_OP_MODEL OFF) + message(WARNING "TTNNOpModelLib is disabled on Apple platforms. Optimizer will not get true performance.") +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake/modules) if (TT_RUNTIME_ENABLE_PERF_TRACE) diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index beeb35883..41ca83528 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -5,7 +5,7 @@ # User Guide - [Building](./build.md) - - [Internal Build Notes / IRD](./internal-build.md) + - [Docker Notes](./docker-notes.md) - [Tools](./tools.md) - [ttmlir-opt](./ttmlir-opt.md) - [ttmlir-translate](./ttmlir-translate.md) diff --git a/docs/src/dialects-overview.md b/docs/src/dialects-overview.md index e886fb90c..0dbf5fbed 100644 --- a/docs/src/dialects-overview.md +++ b/docs/src/dialects-overview.md @@ -3,7 +3,7 @@ Here is a brief overview of the dialects in the project, please refer to the individual dialect documentation for more details.: -- `tt`: Common types such as, `tt.tile`, `tt.layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. +- `tt`: Common types such as, `tt.tile`, `tt.metal_layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. - `ttir`: A high level dialect that models the tensor compute graph on tenstorrent devices. Accepts `tosa` and `linalg` input. - `ttir.generic`: Generically describe compute work. - `ttir.to_layout`: Convert between different tensor memory layouts and transfer between different memory spaces. diff --git a/docs/src/internal-build.md b/docs/src/docker-notes.md similarity index 72% rename from docs/src/internal-build.md rename to docs/src/docker-notes.md index 11d2fb864..1674bf2ef 100644 --- a/docs/src/internal-build.md +++ b/docs/src/docker-notes.md @@ -1,21 +1,11 @@ -# Internal Build Notes / IRD - -- When building the runtime we must use Ubuntu 22.04 docker image - - When making an IRD reservation use `--docker-image - yyz-gitlab.local.tenstorrent.com:5005/tenstorrent/infra/ird-ubuntu-22-04-amd64:latest` -- You'll have to manaully install a newer version of cmake, at least 3.22, the easiest way to do this is to `pip install cmake` and make sure this one is in your path -- You'll want LLVM installation to persist IRD reservations, you can achieve this by: - - mkdir /localdev/$USER/ttmlir-toolchain - - When requesting an IRD use `--volumes /localdev/$USER/ttmlir-toolchain:/opt/ttmlir-toolchain` - -## Working with Docker Images +# Working with Docker Images Components: - Dockerfile - Workflow for building Docker image - Project build using Docker image -### Overview +## Overview We use docker images to prepare project enviroment, install dependancies, tooling and prebuild toolchain. Project builds four docker images: @@ -29,11 +19,11 @@ Base image starts with a supported base image (Ubuntu 22.04) and installs depend During the CI Docker build, the project is built and tests are run to ensure that everything is set up correctly. If any dependencies are missing, the Docker build will fail. -### Building the Docker Image using GitHub Actions +## Building the Docker Image using GitHub Actions The GitHub Actions workflow [Build and Publish Docker Image](.github/workflows/build-image.yml) builds the Docker images and uploads them to GitHub Packages at https://github.com/orgs/tenstorrent/packages?repo_name=tt-mlir. We use the git SHA we build from as the tag. -### Building the Docker Image Locally +## Building the Docker Image Locally To test the changes and build the image locally, use the following command: ```bash @@ -43,7 +33,7 @@ docker build -f .github/Dockerfile.ird -build-args FROM_IMAGE=base -t ghcr.io/te docker build -f .github/Dockerfile.ird -build-args FROM_IMAGE=ci -t ghcr.io/tenstorrent/tt-mlir/tt-mlir-ird-ubuntu-22-04:latest . ``` -### Using the Image in GitHub Actions Jobs +## Using the Image in GitHub Actions Jobs The GitHub Actions workflow [Build in Docker](.github/workflows/docker-build.yml) uses a Docker container for building: ```yaml diff --git a/docs/src/specs/device.md b/docs/src/specs/device.md index ae72fe638..64bc91cfa 100644 --- a/docs/src/specs/device.md +++ b/docs/src/specs/device.md @@ -135,7 +135,7 @@ the logical device grid: ```mlir tensor<16x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<8x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -170,7 +170,7 @@ the logical device grid: ```mlir tensor<256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x16>, memref<2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -205,7 +205,7 @@ We can consider the following tensor to map onto this grid: ```mlir tensor<64x256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x4x16>, memref<32x2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> diff --git a/docs/src/specs/tensor-layout.md b/docs/src/specs/tensor-layout.md index d523f51ed..52c693189 100644 --- a/docs/src/specs/tensor-layout.md +++ b/docs/src/specs/tensor-layout.md @@ -33,7 +33,7 @@ been used by the TT dialect to encode the tensor's layout. This looks like: ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, @@ -76,7 +76,7 @@ topics: ### Dimension Collapsing -Probably the most important concept in `tt.layout` is dimension collapsing. +Probably the most important concept in `tt.metal_layout` is dimension collapsing. This is captured by the affine map `linear` property which provides a mapping from tensor dim space to a reduced physical dimensional space. This single-handedly touches on most of the tensor layout goals mentioned at the @@ -106,7 +106,7 @@ to get our remapped offset: This remapped offset `(262, 100)` corresponds to the row and column index of the collapsed physical memory. -By default, the dim range `[0, -1)` is collapsed, but the `tt.layout` contructor +By default, the dim range `[0, -1)` is collapsed, but the `tt.metal_layout` contructor can actually take a programmable range called `collapseIntervals`. `collapseIntervals` is a list of pairs, where each pair is a dim range interval, left inclusive, right exclusive. Let's consider a few examples: @@ -137,7 +137,7 @@ Let's consider the original example again, but on a larger grid than `1x1`, say ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, @@ -173,7 +173,7 @@ Here's a few more example mlir snippets: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -181,7 +181,7 @@ tensor<8x300xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -189,7 +189,7 @@ tensor<8x96x32xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -197,7 +197,7 @@ tensor<8x96x32xf32, > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -226,7 +226,7 @@ A tilized tensor is one with a memref that has a tile element type. Given some tensor with scalar layout: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -238,7 +238,7 @@ tensor<3x64x128xf32, After tilizing we'll have: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -256,7 +256,7 @@ intact. Padding can be a bit of an overloaded term, but in this context it refers to an out of bounds area in the physical memory allocation that has no real tensor data in it. The contents of this area is tracked by `oob_val` and the padding -area can be automatically derived from the attributes of `tt.layout`. +area can be automatically derived from the attributes of `tt.metal_layout`. Padding is a necessary evil that arises when a tensor is not evenly divisible by a grid shape or tile shape. It can also arise due to minimum Noc addressing @@ -265,7 +265,7 @@ requirements. Example of non-divisible grid: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -284,7 +284,7 @@ cores and 1 scalar column of padding on the last column of cores. Taking the above example a step further, we could tilize it: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -308,7 +308,7 @@ stride between dimensions. Consider tensor (w/ batch dim `2`): ```mlir tensor<2x8x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 8 + d1, d2), undef, <1x2>, @@ -356,7 +356,7 @@ consider the following example with a 3d grid and `collapseIntervals=[(1, -1)]`. ```mlir tensor<2x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<1x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -387,7 +387,7 @@ under the same grid primitive that also divides tensor rows and columns. ## Concerns -- `tt.layout` is deliberately flexible and tries to capture as many problematic +- `tt.metal_layout` is deliberately flexible and tries to capture as many problematic use-cases we've ran into in the past in a single, succinct representation. This flexibility will need to be further constrained by backends to avoid unsupported programming of this attribute. diff --git a/docs/src/ttmlir-translate.md b/docs/src/ttmlir-translate.md index c82f7ee8f..ba9c69b3c 100644 --- a/docs/src/ttmlir-translate.md +++ b/docs/src/ttmlir-translate.md @@ -5,15 +5,15 @@ The `ttmlir-translate` translation utility. Unlike `ttmlir-opt` tool which is us ```bash # First, let's run `ttmlir-opt` to convert to proper dialect -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn --convert-ttnn-to-emitc test/ttmlir/Dialect/TTNN/simple_multiply.mlir -o c.mlir +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Dialect/TTNN/simple_multiply.mlir -o c.mlir # Now run `ttmlir-translate` to produce C++ code -./build/bin/ttmlir-translate -mlir-to-cpp c.mlir -allow-unregistered-dialect +./build/bin/ttmlir-translate --mlir-to-cpp c.mlir ``` Bonus: These two commands can be piped, to avoid writing a `mlir` file to disk, like so: ```bash -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn --convert-ttnn-to-emitc test/ttmlir/Dialect/TTNN/simple_multiply.mlir | ./build/bin/ttmlir-translate -mlir-to-cpp -allow-unregistered-dialect +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Dialect/TTNN/simple_multiply.mlir | ./build/bin/ttmlir-translate -mlir-to-cpp ``` ## Generate flatbuffer file from MLIR diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index fbbe8de4b..263cd1d8e 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -50,9 +50,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet( size_t chipCoordsSize, MlirAttribute *chipChannels, size_t chipChannelsSize); -MLIR_CAPI_EXPORTED MlirAttribute -ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, unsigned oobVal, - MlirAttribute grid, MlirType memref, unsigned memLayout); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMetalLayoutAttrGet( + MlirContext ctx, MlirAffineMap linear, unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, uint32_t memorySpace); @@ -84,6 +84,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( MlirAttribute *dram, size_t dramSize, MlirAttribute *eth, size_t ethSize, MlirAttribute *eth_inactive, size_t eth_inactiveSize); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx, + int64_t y, int64_t x); + #ifdef __cplusplus } #endif diff --git a/include/ttmlir-c/TTNNAttrs.h b/include/ttmlir-c/TTNNAttrs.h index a7f5a8170..ea3e333c2 100644 --- a/include/ttmlir-c/TTNNAttrs.h +++ b/include/ttmlir-c/TTNNAttrs.h @@ -5,6 +5,7 @@ #ifndef TTMLIR_C_TTNNATTRS_H #define TTMLIR_C_TTNNATTRS_H +#include "mlir-c/AffineMap.h" #include "ttmlir-c/Dialects.h" #ifdef __cplusplus @@ -44,6 +45,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y, int64_t x); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNTTNNLayoutAttrGet( + MlirContext ctx, MlirAffineMap linear, MlirAttribute grid, MlirType memref, + unsigned memLayout); + #ifdef __cplusplus } #endif diff --git a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h index acd5373c9..5f1feb08b 100644 --- a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h +++ b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h @@ -7,11 +7,15 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir::tt { +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter); + std::unique_ptr> createConvertTosaToTTIRPass(); } // namespace mlir::tt -#endif +#endif // TTMLIR_CONVERSION_TOSATOTTIR_TOSATOTTIR_H diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td index b82c71c3f..aee19f63c 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td @@ -137,6 +137,7 @@ def TT_OperandConstraintSingleBank : I32BitEnumAttrCaseBit<"SingleBank", 7, "sin def TT_OperandConstraintHeightSharded : I32BitEnumAttrCaseBit<"HeightSharded", 8, "height_sharded">; def TT_OperandConstraintWidthSharded : I32BitEnumAttrCaseBit<"WidthSharded", 9, "width_sharded">; def TT_OperandConstraintBlockSharded : I32BitEnumAttrCaseBit<"BlockSharded", 10, "block_sharded">; +def TT_OperandConstraintSystemScalar : I32BitEnumAttrCaseGroup<"SystemScalar", [TT_OperandConstraintSystem, TT_OperandConstraintScalar], "system_scalar">; def TT_OperandConstraintAnyLayout : I32BitEnumAttrCaseGroup<"AnyLayout", [TT_OperandConstraintNone, TT_OperandConstraintInterleaved, TT_OperandConstraintSingleBank, TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded], "any_layout">; def TT_OperandConstraintAny : I32BitEnumAttrCaseGroup<"Any", [TT_OperandConstraintSystem, TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any">; def TT_OperandConstraintAnyDevice : I32BitEnumAttrCaseGroup<"AnyDevice", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any_device">; @@ -155,6 +156,7 @@ def TT_OperandConstraint : I32BitEnumAttr<"OperandConstraint", "TT Operand Const TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded, + TT_OperandConstraintSystemScalar, TT_OperandConstraintAnyLayout, TT_OperandConstraintAny, TT_OperandConstraintAnyDevice, @@ -189,6 +191,54 @@ def TT_BufferAccess : I32BitEnumAttr<"BufferAccess", "TT Buffer Access", let cppNamespace = "::mlir::tt"; } +def TT_ReduceType_Sum : I32EnumAttrCase<"Sum", 0, "sum">; +def TT_ReduceType_Mean : I32EnumAttrCase<"Mean", 1, "mean">; +def TT_ReduceType_Max : I32EnumAttrCase<"Max", 2, "max">; +def TT_ReduceType_Min : I32EnumAttrCase<"Min", 3, "min">; +def TT_ReduceType_Std : I32EnumAttrCase<"Std", 4, "std">; +def TT_ReduceType_Var : I32EnumAttrCase<"Var", 5, "var">; + +def TT_ReduceType: I32EnumAttr<"ReduceType", "TT Reduce Type", + [ + TT_ReduceType_Sum, + TT_ReduceType_Mean, + TT_ReduceType_Max, + TT_ReduceType_Min, + TT_ReduceType_Std, + TT_ReduceType_Var, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + +def TT_MeshShardDirection_FullToShard : I32EnumAttrCase<"FullToShard", 0, "full_to_shard">; +def TT_MeshShardDirection_ShardToFull : I32EnumAttrCase<"ShardToFull", 1, "shard_to_full">; + +def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirection", + [ + TT_MeshShardDirection_FullToShard, + TT_MeshShardDirection_ShardToFull, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + +def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">; +def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">; +def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">; +def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">; + +def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType", + [ + TT_MeshShardType_Manual, + TT_MeshShardType_Replicate, + TT_MeshShardType_Maximal, + TT_MeshShardType_Devices, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + def TT_CPURoleHost : I32EnumAttrCase<"Host", 0, "host">; def TT_CPURoleDevice : I32EnumAttrCase<"Device", 1, "device">; diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index d9ff13164..d5dc22e28 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -214,7 +214,7 @@ def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> { }]; } -def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { +def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> { let summary = "Tensor layout attribute"; let description = [{ The tensor layout attribute captures how tensor data is sharded across a grid of devices, cores, and @@ -241,7 +241,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { Examples: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -249,7 +249,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -257,7 +257,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -265,7 +265,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -284,7 +284,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref (`,` $mem_layout^)? `>`"; let extraClassDeclaration = [{ - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace = MemorySpace::System, @@ -292,28 +292,28 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace = MemorySpace::System, GridAttr grid = {}, ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, Type elementType, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withGrid(::mlir::MLIRContext *context, + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); - LayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); - LayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); - LayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); + MetalLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); + MetalLayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); + MetalLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); + MetalLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); uint64_t getMemrefSizeBytes() const; MemorySpace getMemorySpace() const; @@ -400,7 +400,7 @@ def TT_DeviceAttr : TT_Attr<"Device", "device", []> { // - DeviceL1: This ends up being exactly the shard size // - DeviceDRAM: Is more nuanced because the whole tensor size gets paged and interleaved between all dram channels, // due to paging and rounding the footprint ends up being close to: the_whole_tensor / num_dram_channels - uint64_t getLayoutSizeBytes(ArrayRef tensorShape, LayoutAttr layout, MemorySpace memorySpace) const; + uint64_t getLayoutSizeBytes(ArrayRef tensorShape, MetalLayoutAttr layout, MemorySpace memorySpace) const; // Returns the footprint size in bytes of the tensor distributed across the given memory space. // Forwards to getLayoutSizeBytes, see comment there for more info. @@ -443,6 +443,20 @@ def TT_ArgumentAllocationAttr : TT_Attr<"ArgumentAllocation", "arg_alloc", []> { let assemblyFormat = "`<` $address `,` $size `,` $memorySpace `>`"; } +def TT_ReduceTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TT_ReduceTypeArrayAttr : TypedArrayAttrBase; + +def TT_MeshShardDirectionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TT_MeshShardTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // TT type definitions //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h index 16fafe551..4a44e883d 100644 --- a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h +++ b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h @@ -27,18 +27,23 @@ struct MemoryLayoutAnalysisPolicyTypeParser return false; } - static void print(llvm::raw_ostream &os, - const MemoryLayoutAnalysisPolicyType &value) { - llvm::StringRef policy; + static std::string toString(const MemoryLayoutAnalysisPolicyType &value) { + std::string res; switch (value) { case MemoryLayoutAnalysisPolicyType::DFSharding: - policy = "DFSharding"; + res += "DFSharding"; break; case MemoryLayoutAnalysisPolicyType::L1Interleaved: - policy = "L1Interleaved"; + res += "L1Interleaved"; break; } - os << "memory-layout-analysis-policy=" << policy << "\n"; + return res; + } + + static void print(llvm::raw_ostream &os, + const MemoryLayoutAnalysisPolicyType &value) { + os << "memory-layout-analysis-policy=" + << MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n"; } }; diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 506d07ce0..9afe67dac 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -114,8 +114,8 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI - Some combination of the above ```llvm - #layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> - #layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> + #layout = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> + #layout1 = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> %1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1> ``` }]; @@ -172,8 +172,12 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> { // TTIR top level named ops //===----------------------------------------------------------------------===// +def TwoOperands : ParamNativeOpTrait<"NOperands", "2">; +def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">; +def FourOperands : ParamNativeOpTrait<"NOperands", "4">; + class TTIR_ElementwiseOp traits = []> : - TTIR_DPSOp { + TTIR_DPSOp { let description = [{ Base class for elementwise operations. Elementwise operations can take inputs with different shape, @@ -187,7 +191,7 @@ class TTIR_ElementwiseOp traits = []> : } class TTIR_ElementwiseTernaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise ternary op."; let description = [{ Eltwise ternary op. @@ -210,7 +214,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> { } class TTIR_ElementwiseUnaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise unary op."; let description = [{ Eltwise unary op. @@ -424,7 +428,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { } class TTIR_ElementwiseBinaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise binary op."; let description = [{ Eltwise binary op. @@ -502,18 +506,6 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> { }]; } -def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> { - let summary = "Eltwise maximum OP."; - let description = [{ - Calculates maximum of input tensors' values element-wise and stores result in output tensor. - - Example: - %lhs: [[3, 2, 7], [1, 4, 4]] - %rhs: [[1, 4, 2], [1, 2, 3]] - "ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]] - }]; -} - def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> { let summary = "Eltwise minimum OP."; let description = [{ @@ -719,27 +711,6 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> { }]; } -// CCL ops -def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { - let summary = "All gather operation."; - let description = [{ - All gather op. - }]; - - let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - SI32Attr:$dim, - TT_OperandConstraintArrayAttr:$operand_constraints); - - let results = (outs AnyRankedTensor:$result); - - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - - let hasVerifier = 1; -} - def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> { let summary = "Conv2d operation."; let description = [{ @@ -927,6 +898,33 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> { let hasVerifier = 1; } +def TTIR_SelectOp: TTIR_DPSOp<"select"> { + let summary = "Select op."; + let description = [{ + Extracts a sub-tensor (slice) from the input tensor along a specified dimension in few steps defined by the + `begin`, `length`, and `stride` attributes. + The `begin` specifies the start index for the selected dimension of the tensor. + The `length` specifies the number of elements to extract from the input tensor along the selected dimension. + The `stride` specifies the step size for the start index. The default value is 0. 0 means no stride. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + SI32Attr:$dim, + SI32Attr:$begin, + SI32Attr:$length, + DefaultValuedOptionalAttr:$stride, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: decomposing_an_op_index_ttir def TTIR_IndexOp: TTIR_DPSOp<"index"> { let summary = "Index op."; @@ -1023,6 +1021,48 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> { let hasVerifier = 1; } +def TTIR_ArangeOp : TTIR_Op<"arange"> { + let summary = "Arange operation."; + let description = [{ + Tensor arange operation. + + Produces a tensor with values from `start` to `end` (exclusive) with a step size of `step`, along the dimension specified by `arange_dimension`. + + Examples: + %0 = "ttir.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5xi64> + // %0: [0, 1, 2, 3, 4] + + %1 = "ttir.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64, arange_dimension = 0 : i64} : () -> tensor<5xf32> + // %1: [0.0, 2.0, 4.0, 6.0, 8.0] + + %2 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5x3xi64> + // %2: [ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [3, 3, 3], + [4, 4, 4] + ] + + %3 = "ttir.arange"() {start = 0 : i64, end = 3 : i64, step = 1 : i64, arange_dimension = 1 : i64} : () -> tensor<5x3xi64> + // %3: [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + }]; + + let arguments = (ins SI64Attr:$start, + SI64Attr:$end, + SI64Attr:$step, + I64Attr:$arange_dimension); + + let results = (outs AnyRankedTensor:$result); + let hasVerifier = 1; +} + def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike, AllShapesMatch<["value", "result"]>]> { let summary = "Constant op."; @@ -1066,6 +1106,34 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> { }]; } +def TTIR_LinearOp : TTIR_DPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. + + Example: + %a = tensor.empty() : () -> tensor<10x64x32xbf16> + %b = tensor.empty() : () -> tensor<32x128xbf16> + %bias = tensor.empty() : () -> tensor<128xbf16> + %output = tensor.empty() : () -> tensor<10x64x128xbf16> + %0 = "ttir.linear"(%a, %b, %bias, %output) : (tensor<10x64x32xbf16>, tensor<32x128xbf16>, tensor<128xbf16>, tensor<10x64x128xbf16>) -> tensor<10x64x128xbf16> + }]; + + let arguments = (ins AnyRankedTensor:$a, + AnyRankedTensor:$b, + Optional:$bias, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: adding_an_op_matmul_ttir def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { let summary = "Matrix multiply operation."; @@ -1101,11 +1169,10 @@ class TTIR_GenericElementwiseUnaryOp traits = []> : void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block); std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) { - assert(getNumOperands() == 2 && "Input and output operand must have the same rank"); - assert(sameRank(getOperands()) && - "Elementwise unary op must have only one input and one output operand."); + assert(sameRank(getOperation()->getOperands()) && + "Input and output operand must have the same rank"); - auto rank = mlir::cast(getOperand(0).getType()).getRank(); + auto rank = mlir::cast(getOperation()->getOperand(0).getType()).getRank(); SmallVector indexingMaps(2, builder.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( @@ -1114,19 +1181,6 @@ class TTIR_GenericElementwiseUnaryOp traits = []> : return {builder.getAffineMapArrayAttr(indexingMaps), builder.getArrayAttr(iteratorTypes)}; } - - static bool sameRank(mlir::OperandRange operands) { - if (operands.empty()) { - return true; - } - auto rank = mlir::cast(operands[0].getType()).getRank(); - for (auto operand : operands) { - if (mlir::cast(operand.getType()).getRank() != rank) { - return false; - } - } - return true; - } }]; } @@ -1146,29 +1200,16 @@ class TTIR_GenericElementwiseBinaryOp traits = []> void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block); std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) { - assert(sameRank(getOperands()) && + assert(sameRank(getOperation()->getOperands()) && "For now all operands must have the same rank"); - auto rank = mlir::cast(getOperand(0).getType()).getRank(); - SmallVector indexingMaps(getNumOperands(), + auto rank = mlir::cast(getOperation()->getOperand(0).getType()).getRank(); + SmallVector indexingMaps(getOperation()->getNumOperands(), builder.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, builder.getAttr(IteratorType::Parallel)); return {builder.getAffineMapArrayAttr(indexingMaps), builder.getArrayAttr(iteratorTypes)}; } - - static bool sameRank(mlir::OperandRange operands) { - if (operands.empty()) { - return true; - } - auto rank = mlir::cast(operands[0].getType()).getRank(); - for (auto operand : operands) { - if (mlir::cast(operand.getType()).getRank() != rank) { - return false; - } - } - return true; - } }]; } @@ -1193,6 +1234,53 @@ def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> { }]; } +def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum"> { + let summary = "Eltwise maximum."; + let description = [{ + Calculates maximum of input tensors' values element-wise and stores result in output tensor. + + Example: + %lhs: [[3, 2, 7], [1, 4, 4]] + %rhs: [[1, 4, 2], [1, 2, 3]] + "ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]] + }]; +} + +//===----------------------------------------------------------------------===// + +def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> { + let summary = "Scatter operation"; + let description = [{ + Produces a 'result' tensor which are equal to `input` tensor except that + several slices specified by `scatter_indices` are updated with the values + `updates`. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$scatter_indices, + AnyRankedTensor:$update, + DenseI32ArrayAttr:$update_window_dims, + DenseI32ArrayAttr:$inserted_window_dims, + DenseI32ArrayAttr:$input_batching_dims, + DenseI32ArrayAttr:$scatter_indices_batching_dims, + DenseI32ArrayAttr:$scatter_dims_to_operand_dims, + I32Attr:$index_vector_dim, + BoolAttr:$indices_are_sorted, + BoolAttr:$unique_indices, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let regions = (region SizedRegion<1>:$update_computation); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; +} + //===----------------------------------------------------------------------===// // TTIR region ops (ops that may appear inside of ttir.generic region) //===----------------------------------------------------------------------===// @@ -1222,4 +1310,102 @@ def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator]> { let arguments = (ins Variadic:$values); } +//===----------------------------------------------------------------------===// +// TTIR ccl ops +//===----------------------------------------------------------------------===// + +def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { + let summary = "All gather operation."; + let description = [{ + All gather op. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + SI32Attr:$dim, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + +def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> { + let summary = "AllReduce operation."; + let description = [{ + AllReduce op. + }]; + + let arguments = (ins + Variadic:$inputs, + AnyRankedTensor:$output, + I64ElementsAttr:$replica_groups, + SI32Attr:$dim, + OptionalAttr:$channel_handle, + UnitAttr:$use_global_device_ids, + TT_ReduceTypeAttr:$reduce_type, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs Variadic:$results); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + +def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> { + let summary = "Mesh shard operation."; + let description = [{ + MeshShard op shards the inputs (FullToShard) or concatnates the outputs (ShardToFull) for ccl ops. + + shard_direction attribute determines whether to shard or concat. + + shard_type attribute determines how to shard or concat. + manual: no sharding + replicate: all devices have identical data + maximal: only one device contains full data + devices: shard_shape determines sharded dimensions + + For example, on 2x4 mesh hardware, following op shards arg0 to 8 slices, row divided by 2 + and col divided by 4. + + %1 = "ttir.mesh_shard"(%arg0, %0) < + {... shard_direction = #tt.shard_direction, + shard_shape = #tt.grid<2x4>, + shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, ...) -> tensor<4096x196xf32> + + On the other hand, this op concatnates %4 to single tensor by concatnating + one of the top row tensor with one of the bottom row tensor. + + %6 = "ttir.mesh_shard"(%4, %5) < + {..., shard_direction = #tt.shard_direction, + shard_shape = #tt.grid<2x1>, + shard_type = #tt.shard_type}> : (tensor<4096x16384xf32>, ...) -> tensor<8192x16384xf32> + }]; + + let arguments = (ins + AnyRankedTensor:$input, + AnyRankedTensor:$output, + TT_MeshShardTypeAttr:$shard_type, + TT_MeshShardDirectionAttr:$shard_direction, + TT_GridAttr:$shard_shape, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + #endif diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h index 1d88e8a65..01b677297 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h @@ -12,7 +12,7 @@ namespace mlir { namespace tt { namespace ttir { namespace detail { -mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op); +mlir::LogicalResult verifyBroadcastable(mlir::Operation *op); } // namespace detail } // namespace ttir } // namespace tt diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td index cbc005673..a130332f0 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td @@ -64,11 +64,13 @@ def TTIROpInterface : OpInterface<"TTIROp"> { ]; } -def TTIR_ElementwiseOpInterface : OpInterface<"ElementwiseOp"> { +def TTIR_Broadcastable : OpInterface<"Broadcastable"> { let cppNamespace = "::mlir::tt::ttir"; + let dependentTraits = [AttrSizedOperandSegments]; + let verify = [{ - return detail::verifyElementwiseOp($_op); + return detail::verifyBroadcastable($_op); }]; } @@ -105,6 +107,20 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> { /*methodBody=*/"", /*defaultImplementation=*/"" >, + StaticInterfaceMethod< + /*desc=*/[{ + Return if the given operands have the same rank. + }], + /*retTy=*/"bool", + /*methodName=*/"sameRank", + /*args=*/(ins "::mlir::OperandRange":$operands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::all_equal(llvm::map_range(operands, [](Value operand) { + return mlir::cast(operand.getType()).getRank(); + })); + }] + > ]; } diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 63ccb0d28..b6269f715 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,4 +112,17 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { + let summary = "Broadcast operation is folded to all the consumers."; + let description = [{ + This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops. + Example: + %1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32> + %2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> + + This above broadcast is folded as: + %1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32> + }]; +} + #endif diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td index 4b6da4b68..ed70d7da6 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td @@ -180,6 +180,15 @@ def TTKernel_MulOp : TTKernel_Op<"mul"> { let arguments = (ins I32:$dst_index); } +def TTKernel_MaxOp : TTKernel_Op<"max"> { + let summary = "Max operation"; + let description = [{ + Max operation + }]; + + let arguments = (ins I32:$dst_index); +} + def TTKernel_MatmulOp : TTKernel_Op<"matmul"> { let summary = "Matmul operation"; let description = [{ @@ -333,6 +342,31 @@ def TTKernel_ReduceTileOp : TTKernel_Op<"reduce_tile"> { TTKernel_ReduceDimAttr:$reduce_dim); } +//===----------------------------------------------------------------------===// +// TTKernel SFPU operations +//===----------------------------------------------------------------------===// + +def TTKernel_MaxTilesInitOp : TTKernel_Op<"max_tile_init"> { + let summary = "Short init function"; + let description = [{ + Must be run before max_tile. + }]; + + let arguments = (ins); +} + +def TTKernel_MaxTilesOp : TTKernel_Op<"max_tile"> { + let summary = "Max operation"; + let description = [{ + Performs element-wise computation of maximum operation + DST[dst0_index] <- max(DST[dst0_index], DST[dst1_index]) + on DST register operands. The DST register buffer must be in + acquired state via *tile_regs_acquire* call. + }]; + + let arguments = (ins I32:$dst0_index, I32:$dst1_index); +} + //===----------------------------------------------------------------------===// // TTKernel CB operations //===----------------------------------------------------------------------===// @@ -503,6 +537,68 @@ def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> { }]; } +//===----------------------------------------------------------------------===// +// TTKernel Multicast NoC operations +//===----------------------------------------------------------------------===// + +def TTKernel_GetNocMulticastAddrOp : TTKernel_Op<"get_noc_multicast_addr"> { + let summary = "GetNocMulticastAddr"; + let description = [{ + GetNocMulticastAddr + }]; + + let arguments = (ins I32:$noc_x_start, I32:$noc_y_start, I32:$noc_x_end, I32:$noc_y_end, I32:$addr, Optional:$noc); + let results = (outs TTKernel_NocAddr:$mcastNocAddr); +} + +def TTKernel_NocAsyncWriteMulticastOnePacketOp : TTKernel_Op<"noc_async_write_multicast_one_packet"> { + let summary = "NocAsyncWriteMulticastOnePacket"; + let description = [{ + NocAsyncWriteMulticastOnePacket + this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + +def TTKernel_NocAsyncWriteMulticastOp : TTKernel_Op<"noc_async_write_multicast"> { + let summary = "NocAsyncWriteMulticast"; + let description = [{ + Initiates an asynchronous write from a source address in L1 memory on the + Tensix core executing this function call to a rectangular destination grid. + The destinations are specified using a uint64_t encoding referencing an + on-chip grid of nodes located at NOC coordinate range + (x_start,y_start,x_end,y_end) and a local address created using + *get_noc_multicast_addr* function. Also, *see noc_async_write_barrier*. + + The destination nodes can only be a set of Tensix cores + L1 memory address. + The destination nodes must form a rectangular grid. The destination L1 + memory address must be the same on all destination nodes. + + With this API, the multicast sender cannot be part of the multicast + destinations. If the multicast sender has to be in the multicast + destinations (i.e. must perform a local L1 write), the other API variant + *noc_async_write_multicast_loopback_src* can be used. + + Note: The number of destinations needs to be non-zero. Besides that, + there is no restriction on the number of destinations, i.e. the + multicast destinations can span the full chip. However, as mentioned + previously, the multicast source cannot be part of the destinations. So, the + maximum number of destinations is 119. + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + +def TTKernel_NocAsyncWriteMulticastLoopbackSrcOp : TTKernel_Op<"noc_async_write_multicast_loopback_src"> { + let summary = "NocAsyncWriteMulticastLoopbackSrc"; + let description = [{ + NocAsyncWriteMulticastLoopbackSrc + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + //===----------------------------------------------------------------------===// // TTKernel Misc operations //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 910ed7dfd..57383c007 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -636,6 +636,34 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> { let hasVerifier = 1; } +def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. + + Example: + // %a = [[1., 2.]], [2., 1.]] + // %b = [[0., 1.], [1., 0.]] + // %bias = [[1.]] + "ttnn.linear"(%a, %b, %bias, %result) : (tensor<2x2xf16>, tensor<2x2xf16>, tensor<1xf16>, tensor<2x2xf16>) -> tensor<2x2xf16> + // %result = [[3., 2.], [2., 3.]] + }]; + + let arguments = (ins AnyRankedTensor:$a, + AnyRankedTensor:$b, + Optional:$bias, + AnyRankedTensor:$output); + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + + // ANCHOR: adding_an_op_matmul_ttnn def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> { let arguments = (ins AnyRankedTensor:$a, @@ -759,6 +787,32 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> { let hasVerifier = 1; } +def TTNN_ArangeOp : TTNN_Op<"arange"> { + let summary = "Arange operation."; + let description = [{ + Tensor arange operation. + + Produces a (1, 1, 1, N)-shaped tensor with values from `start` to `end` (exclusive) with a step size of `step`. + + Examples: + %0 = "ttnn.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64} : () -> tensor<1x1x1x5xi64> + // %0: [[[[0, 1, 2, 3, 4]]]] + + %1 = "ttnn.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64} : () -> tensor<1x1x1x5xf32> + // %1: [[[[0.0, 2.0, 4.0, 6.0, 8.0]]]] + }]; + + let arguments = (ins I64Attr:$start, + I64Attr:$end, + I64Attr:$step, + OptionalAttr:$dtype, + Optional:$device, + OptionalAttr:$memory_config); + + let results = (outs AnyRankedTensor:$result); + let hasVerifier = 1; +} + def TTNN_FullOp : TTNN_Op<"full"> { let summary = "Full op."; let description = [{ @@ -806,6 +860,13 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> { let hasVerifier = 1; } +def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> { + let summary = "Scatter op."; + let description = [{ + Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor. + }]; +} + def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> { let summary = "Reduce scatter op."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index bba7fe6f2..e45fba003 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -109,6 +109,13 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { let summary = "Tensor encoding attribute used for types in ttnn"; let description = [{ Layout attribute in ttnn. This attribute is used to encode different information about tensor memory layout. + Here is how tensor will look like after layout tensor<32x32x64xf32, #ttnn.ttnn_layout> + Lets break down what each parameter means: + - linear: An affine map that defines how the logical tensor dimensions map to physical space. + - grid: The grid shape (of tensix cores) where tensor is divided onto. + - memref: A memref is used to describe shard size and memory space. Shard size is calculated by dividing the tensor size by grid size. + - mem_layout: The layout of the tensor in memory. For tensor on host it should be None. For tensor on device + it can be interleaved or sharded. }]; let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear, @@ -142,15 +149,15 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { bool hasShardedL1TensorMemoryLayout() const; bool hasInterleavedL1TensorMemoryLayout() const; bool isTiled() const; + Layout getLayout() const; Type getElementType() const; - DataType getDataTypeFromMemRef() const; + DataType getDataType() const; uint64_t getElementSizeBytes() const; int64_t getTensorSizeInBytes(ArrayRef tensorShape, ::mlir::tt::DeviceAttr device) const; llvm::SmallVector getStride(ArrayRef logicalShape) const; - llvm::SmallVector getPhysicalShape(ArrayRef logicalShape) const; - llvm::SmallVector getShardShape(bool convertTileToScalar = true) const; + llvm::SmallVector getShardShape() const; + llvm::SmallVector getScalarShardShape() const; AffineMap replaceMemoryMapSymbolsWithShardShape(AffineMap physicalMemoryMap) const; - AffineMap projectOnto(AffineMap linearMap, AffineMap physicalMemoryMap) const; AffineMap getIdentityTileLinearMap() const; llvm::SmallVector getTiledShape(ArrayRef logicalTensorShape) const; }]; diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 48c723e1c..636d5f623 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -6,7 +6,8 @@ #define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H #include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" -#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "mlir/Pass/PassOptions.h" diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index db24eeb28..c474106e3 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -5,50 +5,98 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H #define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H -#include - -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" +#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" namespace mlir::tt::ttnn { -struct OutputLayoutOverrideParams { - SmallVector grid; - BufferType bufferType; - TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... - Layout memoryLayout; // ROW_MAJOR / TILE - tt::DataType dataType; -}; +class OptimizerOverridesHandler { +public: + OptimizerOverridesHandler() {}; + ~OptimizerOverridesHandler() {}; -struct InputLayoutOverrideParams { - SmallVector operandIdxes; -}; + // Setters for the overrides + // These are used to enable/disable the optimizer passes + void setEnableOptimizer(bool); + // These are used to enable/disable the memory configurations + void setMemoryReconfig(bool); + void setEnableMemoryLayoutAnalysis(bool); + void setEnableMemoryLayoutAnalysisPolicy(bool); + void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType); + // These are used to set the input/output layout overrides + void setInputLayoutOverrides(llvm::StringMap &); + void setOutputLayoutOverrides(llvm::StringMap &); + // These are used to add system descriptor path + void setSystemDescPath(std::string); + // These are used to set the maximum number of legal layouts for grid analysis + void setMaxLegalLayouts(int64_t); + // These are used to set the mesh shape + void setMeshShape(std::vector); -struct OutputLayoutOverrideParser - : public llvm::cl::parser> { -public: - OutputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} + // Getters for the overrides + // These are used to get the current state of the optimizer passes + bool getEnableOptimizer() const; + // These are used to get the current state of the memory configurations + bool getMemoryReconfig() const; + bool getEnableMemoryLayoutAnalysis() const; + bool getEnableMemoryLayoutAnalysisPolicy() const; + MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const; + // These are used to get the current input/output layout overrides + llvm::StringMap getInputLayoutOverrides() const; + llvm::StringMap getOutputLayoutOverrides() const; + // These are used to get the current system descriptor path + std::string getSystemDescPath() const; + // These are used to get the current maximum number of legal layouts for grid + // analysis + int64_t getMaxLegalLayouts() const; + // These are used to get the current mesh shape + std::vector getMeshShape() const; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Method that converts the overrides to a string + std::string toString() const; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; + // Fill input/output layout overrides maps. + // This is used from tt-forge frontend where we define and compile the models. + void addInputLayoutOverride(StringRef, InputLayoutOverrideParams); + void addInputLayoutOverride(StringRef, SmallVector &); + void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams); + void addOutputLayoutOverride(StringRef, SmallVector &, BufferType, + TensorMemoryLayout, tt::ttnn::Layout, + tt::DataType); -struct InputLayoutOverrideParser - : public llvm::cl::parser> { -public: - InputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} +private: + // Options for the TTIR to TTNN backend pipeline, + // we use them to extract the names and the deafulat values. + TTIRToTTNNBackendPipelineOptions pipelineOptions; + + // Flags for enabling/disabling the optimizer passes + bool enableOptimizer = false; + + // Flags for enabling/disabling the memory configurations + bool enableMemoryReconfig = true; + bool enableMemoryLayoutAnalysis = false; + + // Input layout overrides + llvm::StringMap inputLayoutOverrides; + + // Output layout overrides + llvm::StringMap outputLayoutOverrides; + + // Memory layout analysis policy + bool enableMemoryLayoutAnalysisPolicy = false; + MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy; + + // System descriptor path + std::string systemDescPath; + + // Maximum number of legal layouts for grid analysis + int64_t maxLegalLayouts = 0; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Mesh shape + std::vector meshShape; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; +}; // class OptimizerOverridesHandler } // namespace mlir::tt::ttnn diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h new file mode 100644 index 000000000..09e587c9c --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H +#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H + +#include + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + +namespace mlir::tt::ttnn { + +struct OutputLayoutOverrideParams { + + SmallVector grid; + BufferType bufferType; + TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + Layout memoryLayout; // ROW_MAJOR / TILE + mlir::tt::DataType dataType; + + bool operator==(const OutputLayoutOverrideParams rhs) const { + return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] && + bufferType == rhs.bufferType && + tensorMemoryLayout == rhs.tensorMemoryLayout && + memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; + } + + bool operator!=(const OutputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct InputLayoutOverrideParams { + + SmallVector operandIdxes; + + bool operator==(const InputLayoutOverrideParams &rhs) const { + if (operandIdxes.size() != rhs.operandIdxes.size()) { + return false; + } + for (std::size_t i = 0; i < operandIdxes.size(); i++) { + if (operandIdxes[i] != rhs.operandIdxes[i]) { + return false; + } + } + return true; + } + + bool operator!=(const InputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct OutputLayoutOverrideParser + : public llvm::cl::parser> { +public: + OutputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +struct InputLayoutOverrideParser + : public llvm::cl::parser> { +public: + InputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +} // namespace mlir::tt::ttnn + +#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index a6e10c099..d7d8fbdd3 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ -5,6 +5,8 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H #define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H +#include + #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" @@ -31,10 +33,6 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout( mlir::tt::MemorySpace toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType); -DataType getDataTypeFromMemRef(mlir::MemRefType memref); - -Layout getLayoutFromMemRef(mlir::MemRefType memref); - mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype); diff --git a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h new file mode 100644 index 000000000..31ac14984 --- /dev/null +++ b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include + +namespace mlir::tt::op_model::ttnn { + +struct ReluOpInterface { + static bool isLegal(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); + + static std::tuple + getOpL1Usage(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +}; + +} // namespace mlir::tt::op_model::ttnn +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 2d67ee1d1..3e7ed425f 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -11,67 +11,67 @@ struct Dim2dRange { } enum Arch: uint { - Grayskull = 0, - Wormhole_b0 = 1, - Blackhole = 2, + Grayskull, + Wormhole_b0, + Blackhole } enum DataType: uint16 { - Float32 = 0, - Float16 = 1, - BFloat16 = 2, - BFP_Float8 = 3, - BFP_BFloat8 = 4, - BFP_Float4 = 5, - BFP_BFloat4 = 6, - BFP_Float2 = 7, - BFP_BFloat2 = 8, - UInt32 = 9, - UInt16 = 10, - UInt8 = 11, + Float32, + Float16, + BFloat16, + BFP_Float8, + BFP_BFloat8, + BFP_Float4, + BFP_BFloat4, + BFP_Float2, + BFP_BFloat2, + UInt32, + UInt16, + UInt8, } enum OOBVal: ushort { - Undef = 0, - Zero = 1, - One = 2, - Inf = 3, - NegInf = 4, + Undef, + Zero, + One, + Inf, + NegInf, } enum MemorySpace: ushort { - System = 0, - SystemMMIO = 1, - DeviceDRAM = 2, - DeviceL1 = 3, + System, + SystemMMIO, + DeviceDRAM, + DeviceL1, } enum ChipCapability: uint32 (bit_flags) { - PCIE = 0, - HostMMIO = 1, + PCIE, + HostMMIO, } enum TensorMemoryLayout: ushort { - None = 0, - Interleaved = 1, - SingleBank = 2, - HeightSharded = 3, - WidthSharded = 4, - BlockSharded = 5, + None, + Interleaved, + SingleBank, + HeightSharded, + WidthSharded, + BlockSharded, } enum TensorLayout: ushort { - RowMajor = 0, - Tile = 1, - Invalid = 2, + RowMajor, + Tile, + Invalid, } enum BufferType: ushort { - DRAM = 0, - L1 = 1, - SystemMemory = 2, - L1Small = 3, - Trace = 4, + DRAM, + L1, + SystemMemory, + L1Small, + Trace, } // TODO (#620): Add other fields like core_ranges, shard orientation etc. @@ -197,8 +197,8 @@ table ChipPhysicalCores { enum CPURole: uint8 { - Host = 0, - Device = 1, + Host, + Device, } table CPUDesc { @@ -223,9 +223,11 @@ table EventRef { global_id: uint32; } +// Explicit non-sequential enumeration copied over from tt-metal definition of +// `enum class MathFidelity`. enum MathFidelity : uint8 { - LoFi = 0, - HiFi2 = 2, - HiFi3 = 3, - HiFi4 = 4, + LoFi = 0, + HiFi2 = 2, + HiFi3 = 3, + HiFi4 = 4, } diff --git a/include/ttmlir/Target/TTMetal/program.fbs b/include/ttmlir/Target/TTMetal/program.fbs index 4fcf96602..52451234b 100644 --- a/include/ttmlir/Target/TTMetal/program.fbs +++ b/include/ttmlir/Target/TTMetal/program.fbs @@ -3,18 +3,18 @@ include "Common/types.fbs"; namespace tt.target.metal; enum NocIndex : ushort { - Noc0 = 0, - Noc1 = 1, + Noc0, + Noc1, } enum EthType : ushort { - Sender = 0, - Receiver = 1, + Sender, + Receiver, } enum UnpackToDestMode : uint8 { - UnpackToDestFp32 = 0, - Default = 1, + UnpackToDestFp32, + Default, } table NocConfig { @@ -45,17 +45,17 @@ table KernelSource { } enum BinaryType : ushort { - BRISC = 0, - NCRISC = 1, - TRISC0 = 2, - TRISC1 = 3, - TRISC2 = 4, - ERISC = 5, + BRISC, + NCRISC, + TRISC0, + TRISC1, + TRISC2, + ERISC, } enum CoreType : ushort { - WORKER = 0, - ETH = 1, + WORKER, + ETH, } table KernelBinary { diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index ec493e649..19b1dbc92 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -61,47 +61,58 @@ table FullOp { out: tt.target.TensorRef; } +table ArangeOp { + start: float; + end: float; + step: float; + dtype: tt.target.DataType = null; // optional + device: tt.target.DeviceRef; // optional + memcfg: tt.target.MemoryConfigDesc; // optional + out: tt.target.TensorRef; +} + enum EltwiseOpType: uint32 { - Add = 0, - Multiply = 1, - Subtract = 2, - Relu = 3, - GreaterEqual = 4, - Sqrt = 5, - Div = 6, - Sigmoid = 7, - Reciprocal = 8, - Exp = 9, - Maximum = 10, - Abs = 11, - Neg = 12, - Rsqrt = 13, - Typecast = 14, - Equal = 15, - NotEqual = 16, - LessEqual = 17, - LessThan = 18, - GreaterThan = 19, - LogicalAnd = 20, - LogicalOr = 21, - LogicalNot = 22, - Cbrt = 23, - Minimum = 24, - Ceil = 25, - Sin = 26, - Cos = 27, - Log = 28, - Log1p = 29, - Expm1 = 30, - Sign = 31, - Remainder = 32, - IsFinite = 33, - Floor = 34, - Where = 35, - Gelu = 36, - LogicalXor = 37, - Clamp = 38, - LeakyRelu = 39, + Add, + Multiply, + Subtract, + Relu, + GreaterEqual, + Sqrt, + Div, + Sigmoid, + Reciprocal, + Exp, + Maximum, + Abs, + Neg, + Rsqrt, + Typecast, + Equal, + NotEqual, + LessEqual, + LessThan, + GreaterThan, + LogicalAnd, + LogicalOr, + LogicalNot, + Cbrt, + Minimum, + Ceil, + Sin, + Cos, + Log, + Log1p, + Expm1, + Sign, + Remainder, + IsFinite, + Floor, + Where, + Gelu, + LogicalXor, + Clamp, + LeakyRelu, + Scatter } table ClampOpParams { @@ -126,9 +137,9 @@ table EltwiseOp { } enum ReductionOpType: uint32 { - Sum = 0, - Mean = 1, - Max = 2, + Sum, + Mean, + Max, } table ReductionOp { @@ -178,6 +189,13 @@ table SliceOp { step: [int64]; } +table LinearOp { + in0: tt.target.TensorRef; + in1: tt.target.TensorRef; + bias: tt.target.TensorRef; + out: tt.target.TensorRef; +} + // ANCHOR: adding_an_op_matmul_fbs table MatmulOp { in0: tt.target.TensorRef; @@ -249,6 +267,7 @@ union OpType { EmptyOp, FullOp, EltwiseOp, + LinearOp, MatmulOp, ReductionOp, EmbeddingOp, @@ -261,11 +280,13 @@ union OpType { MaxPool2dOp, DeallocateOp, AllGatherOp, + ArangeOp, } table Operation { type: OpType; debug_info: string; + loc_info: string; } table Program { diff --git a/include/ttmlir/Target/Utils/FuncOpToProgram.h b/include/ttmlir/Target/Utils/FuncOpToProgram.h index d9e8d9820..a28f2f5e9 100644 --- a/include/ttmlir/Target/Utils/FuncOpToProgram.h +++ b/include/ttmlir/Target/Utils/FuncOpToProgram.h @@ -31,6 +31,13 @@ inline std::string getOpDebugString(mlir::Operation *op, return str; }; +inline std::string getOpLocInfo(mlir::Operation *op) { + std::string str; + llvm::raw_string_ostream os(str); + op->getLoc().print(os); + return str; +} + inline Value getOperandThroughDPSOps(Value value) { auto *op = value.getDefiningOp(); if (!op) { @@ -76,7 +83,8 @@ Program funcOpToProgram(FlatbufferObjectCache &cache, func::FuncOp entry, } } else { std::string debugStr = getOpDebugString(op, printFlags); - program.ops.push_back(fn(cache, op, debugStr)); + std::string locInfo = getOpLocInfo(op); + program.ops.push_back(fn(cache, op, debugStr, locInfo)); } }); diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index ac23b9bb0..cb9439d97 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -18,8 +18,9 @@ namespace mlir::tt { flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr attr, - ArrayRef logicalShape, DeviceAttr deviceAttr); +metalLayoutAttrToFlatbuffer(FlatbufferObjectCache &cache, MetalLayoutAttr attr, + ArrayRef logicalShape, + DeviceAttr deviceAttr); flatbuffers::Offset<::tt::target::LayoutDesc> ttnnLayoutAttrToFlatbuffer( FlatbufferObjectCache &cache, ttnn::TTNNLayoutAttr attr, @@ -438,9 +439,9 @@ toFlatbuffer(FlatbufferObjectCache &cache, ElementsAttr elementsAttr) { inline flatbuffers::Offset<::tt::target::LayoutDesc> encodingToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr, ArrayRef logicalShape, DeviceAttr deviceAttr) { - if (isa(attr)) { - return layoutAttrToFlatbuffer(cache, cast(attr), logicalShape, - deviceAttr); + if (isa(attr)) { + return metalLayoutAttrToFlatbuffer(cache, cast(attr), + logicalShape, deviceAttr); } assert(isa(attr) && "unsupported layout attr"); @@ -478,7 +479,11 @@ toDebugInfo(::flatbuffers::FlatBufferBuilder &fbb, std::string const &name, ModuleOp module) { std::string source; llvm::raw_string_ostream os(source); - module->print(os); + + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); // Enable the loc dumping + module->print(os, flags); + return ::tt::target::CreateMLIRDirect(fbb, name.c_str(), source.c_str()); } } // namespace mlir::tt diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index 40a3ada6f..c329f41d5 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -119,15 +119,15 @@ MlirAttribute ttmlirTTSystemDescAttrGet( chipCapabilitiesUnwrapped, chipCoordsUnwrapped, chipChannelsUnwrapped)); } -MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, - unsigned oobVal, MlirAttribute grid, - MlirType memref, unsigned memLayout) { +MlirAttribute ttmlirTTMetalLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, + unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout) { mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr); - return wrap(LayoutAttr::get(unwrap(ctx), affineMap, - static_cast(oobVal), - mlir::cast(unwrap(grid)), - mlir::cast(unwrap(memref)), - static_cast(memLayout))); + return wrap(MetalLayoutAttr::get(unwrap(ctx), affineMap, + static_cast(oobVal), + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), + static_cast(memLayout))); } MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, @@ -219,4 +219,8 @@ MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( ethVec, ethInactiveVec)); } +MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx, int64_t y, int64_t x) { + return wrap(CoreCoordAttr::get(unwrap(ctx), y, x)); +} + } // namespace mlir::tt diff --git a/lib/CAPI/TTNNAttrs.cpp b/lib/CAPI/TTNNAttrs.cpp index 0fb1066cb..677f22fb4 100644 --- a/lib/CAPI/TTNNAttrs.cpp +++ b/lib/CAPI/TTNNAttrs.cpp @@ -69,4 +69,14 @@ MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y, return wrap(MeshShapeAttr::get(unwrap(ctx), y, x)); } +MlirAttribute ttmlirTTNNTTNNLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, + MlirAttribute grid, MlirType memref, + unsigned memLayout) { + mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr); + return wrap(TTNNLayoutAttr::get(unwrap(ctx), affineMap, + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), + static_cast(memLayout))); +} + } // namespace mlir::tt::ttnn diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c3dc3a4b7..881d6545d 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo) include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build) +add_subdirectory(OpModel) add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 9db51169c..015702b24 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/Region.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -119,12 +120,10 @@ class StableHLOToTTIRReduceOpConversionPattern tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - mlir::ArrayAttr dimArg = - adaptor.getDimensionsAttr().size() > 0 - ? rewriter.getArrayAttr(SmallVector( - 1, - rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))) - : 0; + mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector( + 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0 + ? adaptor.getDimensionsAttr()[0] + : 1))); // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without @@ -283,30 +282,81 @@ class StableHLOToTTIRDotGeneralOpConversionPattern ::mlir::stablehlo::DotDimensionNumbersAttr dimensions = adaptor.getDotDimensionNumbers(); - if (dimensions.getLhsContractingDimensions().empty() || - dimensions.getRhsContractingDimensions().empty()) { - return rewriter.notifyMatchFailure(srcOp, - "Contracting dimension is missing."); + if (dimensions.getLhsContractingDimensions().size() != 1 || + dimensions.getRhsContractingDimensions().size() != 1) { + return rewriter.notifyMatchFailure( + srcOp, + "LHS and RHS must have exactly 1 contracting dimension each. " + "Received LHS contracting dims: " + + std::to_string(dimensions.getLhsContractingDimensions().size()) + + ", RHS contracting dims: " + + std::to_string(dimensions.getRhsContractingDimensions().size())); + } + + // Use negative indexing to determine if this is a valid matmul since math + // is done over the final two dimensions. + int64_t lhsContractingDim = dimensions.getLhsContractingDimensions()[0] - + srcOp.getLhs().getType().getRank(); + int64_t rhsContractingDim = dimensions.getRhsContractingDimensions()[0] - + srcOp.getRhs().getType().getRank(); + + if (lhsContractingDim != -1) { + return rewriter.notifyMatchFailure( + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. LHS contracting dimension must be " + + std::to_string(srcOp.getLhs().getType().getRank() - 1) + + ". Got " + std::to_string(lhsContractingDim)); } - if (dimensions.getLhsContractingDimensions()[0] != 1) { + if (rhsContractingDim != -2) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. RHS contracting dimension must be " + + std::to_string(srcOp.getRhs().getType().getRank() - 2) + + ". Got " + std::to_string(rhsContractingDim)); } - if (dimensions.getRhsContractingDimensions()[0] != 0) { + if (dimensions.getLhsBatchingDimensions() != + dimensions.getRhsBatchingDimensions()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "LHS and RHS must have same batching dimensions."); } - if (!dimensions.getLhsBatchingDimensions().empty()) { + // For the RHS, all dimensions which are not the row and column dimensions + // must be 1 OR they must be equal to the corresponding dimension in the + // LHS. If the RHS has less dimensions than the LHS we will assume that the + // missing dimensions are 1. + + auto lhsShape = srcOp.getLhs().getType().getShape().vec(); + auto rhsShape = srcOp.getRhs().getType().getShape().vec(); + + if (rhsShape.size() > lhsShape.size()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "RHS must not be a higher rank than LHS."); + } + + while (rhsShape.size() < lhsShape.size()) { + rhsShape.insert(rhsShape.begin(), 1); + } + + // Need only to check dims to the left of dim -2 on the RHS + bool allOnes = true; + bool mismatchedDims = false; + for (int32_t i = rhsShape.size() - 3; i >= 0; i--) { + if (rhsShape[i] != 1) { + allOnes = false; + } + + if (rhsShape[i] != lhsShape[i]) { + mismatchedDims = true; + } } - if (!dimensions.getRhsBatchingDimensions().empty()) { + if (mismatchedDims && !allOnes) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "All dimensions in the RHS that are not the row and column " + "dimensions must be 1 OR they must all be equal to the " + "corresponding dimensions in the LHS."); } return success(); @@ -803,7 +853,7 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern llvm::SmallVector broadcastedShape; auto srcType = - getTypeConverter()->convertType(srcOp.getOperand().getType()); + getTypeConverter()->convertType(adaptor.getOperand().getType()); auto inputShape = mlir::cast(srcType).getShape(); auto outputShape = mlir::cast(srcType).getShape(); @@ -949,8 +999,8 @@ class StableHLOToTTIRConcatOpConversionPattern "ConcatOp dimension is too large."); } - auto rankedTensorType = - mlir::dyn_cast(srcOp.getOperand(0).getType()); + auto rankedTensorType = mlir::dyn_cast( + adaptor.getOperands()[0].getType()); if (static_cast(adaptor.getDimension()) >= rankedTensorType.getRank()) { return rewriter.notifyMatchFailure(srcOp, @@ -1010,6 +1060,440 @@ class StableHLOToTTIROpLogicalOpConversionPattern } }; +template +LogicalResult getReduceType(SrcOpTy &srcOp, ReduceType &reduceType) { + if constexpr (!std::is_same::value) { + return failure(); + } + // Check operations in the first block and determine reduce type for now + // TODO(wooseoklee): This pattern matching mechanism may need to be updated as + // we see complicated patterns of reduce block in the future. + auto &block = srcOp.getRegion().front(); + for (Operation &op : block) { + if (isa(op)) { + reduceType = ReduceType::Sum; + return success(); + } + if (isa(op)) { + reduceType = ReduceType::Max; + return success(); + } + if (isa(op)) { + reduceType = ReduceType::Min; + return success(); + } + } + // Other reduce types are currently not supported + return failure(); +} + +// StalbeHLO spec.md defines following channel type for ccl ops +enum StableHLOChannelType { + // CHANNEL_TYPE_INVALID = 0 : Invalid primitive type to serve as + // default. + kChannelTypeInvalid = 0, + // DEVICE_TO_DEVICE = 1 : A channel for sending data between + // devices. + kChannelTypeDeviceToDevice = 1, + // DEVICE_TO_HOST = 2 : A channel for sending data from the + // device to the host. Can only be used with a Send operation. + kChannelTypeDeviceToHost = 2, + // HOST_TO_DEVICE = 3 : A channel for sending data from the host to + // the device. Can only be used with a Recv operation. + kChannelTypeHostToDevice = 3, +}; + +class StableHLOToTTIRAllReduceOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::AllReduceOp srcOp, + mlir::stablehlo::AllReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check legality of the operation + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (failed(err)) { + return err; + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp.getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector ttirTypes; + if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), + ttirTypes))) { + return failure(); + } + + auto ttirOperands = srcOp.getOperandsMutable(); + ttirOperands.append(ValueRange(outputTensor)); + + SmallVector srcAttrs = to_vector(srcOp->getAttrs()); + SmallVector ttirAttrs; + for (auto srcAttr : srcAttrs) { + StringAttr srcName = srcAttr.getName(); + if (srcName == "channel_handle") { + auto srcChannelHandleAttr = + dyn_cast(srcAttr.getValue()); + if (!srcChannelHandleAttr) { + return failure(); + } + + // channelType is supposed to be DEVICE_TO_DEVICE for CCL ops. + // Currently, we ensure if it is DEVICE_TO_DEVICE commmuincaiton. + // Consider preserving this information in the future if the attribute + // is non-DEVICE_TO_DEVICE values. + auto channelType = static_cast(srcChannelHandleAttr.getType()); + if (channelType != kChannelTypeDeviceToDevice) { + return failure(); + } + + IntegerAttr channelHandleAttr = rewriter.getSI32IntegerAttr( + static_cast(srcChannelHandleAttr.getHandle())); + if (!channelHandleAttr) { + return failure(); + } + ttirAttrs.push_back({srcName, channelHandleAttr}); + } else { + ttirAttrs.push_back(srcAttr); + } + } + + // Algorithm here is to search for the first non-one working dimension + auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape(); + size_t dim = 0; + for (auto s : replicaGroupsShape) { + if (s != 1) { + break; + } + ++dim; + } + if (dim > replicaGroupsShape.size()) { + // all one shape, then select the fastest dim + dim = replicaGroupsShape.size(); + } + StringAttr dimName = StringAttr::get(this->getContext(), "dim"); + IntegerAttr dimAttr = + rewriter.getSI32IntegerAttr(static_cast(dim)); + ttirAttrs.push_back({dimName, dimAttr}); + + // Parse computation in region and add it to ttirAttrs + ReduceType reduceType; + if (failed(getReduceType(srcOp, reduceType))) { + return rewriter.notifyMatchFailure( + srcOp, "AllReduceOp cannot specify reduce type."); + } + StringAttr reduceTypeAttrName = + StringAttr::get(this->getContext(), "reduce_type"); + Attribute reduceTypeAttr = rewriter.getAttr(reduceType); + ttirAttrs.push_back({reduceTypeAttrName, reduceTypeAttr}); + + StringAttr operationConstraintAttrName = + StringAttr::get(this->getContext(), "operand_constraints"); + Attribute operationConstraintAttr = rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + ttirAttrs.push_back({operationConstraintAttrName, operationConstraintAttr}); + + auto ttirAllReduceOp = rewriter.create( + srcOp.getLoc(), ttirTypes, ValueRange(ttirOperands.getAsOperandRange()), + ttirAttrs); + + rewriter.replaceOp(srcOp, ttirAllReduceOp); + + return success(); + } + +private: + LogicalResult + checkBasicLegality(mlir::stablehlo::AllReduceOp &srcOp, + mlir::stablehlo::AllReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (srcOp.getOperands().empty() || srcOp.getOperands().size() > 1) { + return rewriter.notifyMatchFailure( + srcOp, "AllReduceOp must have one input/output for now."); + } + + return success(); + } +}; // namespace + +class StableHLOToTTIRCustomCallOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::CustomCallOp srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check legality of the operation + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (failed(err)) { + return err; + } + + const std::string kShardingTarget = "Sharding"; + const std::string kSPMDFullToShardShapeTarget = "SPMDFullToShardShape"; + const std::string kSPMDShardToFullShapeTarget = "SPMDShardToFullShape"; + + auto callTargetName = adaptor.getCallTargetNameAttr(); + + // Currently stablehlo.custom_call with following functions from + // jax/openxla are supported + if (callTargetName != kShardingTarget && + callTargetName != kSPMDFullToShardShapeTarget && + callTargetName != kSPMDShardToFullShapeTarget) { + return failure(); + } + + auto shardingAttr = dyn_cast_or_null( + adaptor.getAttributes().get("mhlo.sharding")); + if (!shardingAttr) { + return failure(); + } + StringRef shardingStr = shardingAttr.getValue(); + if (!shardingStr.consume_front("{") || !shardingStr.consume_back("}")) { + return failure(); + } + SmallVector shardingStrAttrs; + shardingStr.split(shardingStrAttrs, " "); + struct ShardAttrValue shardAttrValue; + if (failed(parseShardingAttr(rewriter, shardingStrAttrs, shardAttrValue))) { + return failure(); + } + + if (callTargetName == kSPMDFullToShardShapeTarget) { + Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); + if (!shardingOp) { + return rewriter.notifyMatchFailure( + srcOp, "requires operand to be defined by an op"); + } + + // TODO(wooseoklee): a bit rough approach here to match output dim + shardingOp->getResult(0).setType(srcOp->getResult(0).getType()); + srcOp.getResult(0).replaceAllUsesWith(shardingOp->getResult(0)); + rewriter.eraseOp(srcOp); + } else if (callTargetName == kSPMDShardToFullShapeTarget) { + Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); + if (!shardingOp) { + return rewriter.notifyMatchFailure( + srcOp, "requires operand to be defined by an op"); + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp->getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector outputTypes; + if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), + outputTypes))) { + return failure(); + } + + shardAttrValue.shardDirection = mlir::tt::MeshShardDirection::ShardToFull; + if (failed(createMeshShardOp(srcOp, adaptor, outputTensor, outputTypes, + shardAttrValue, rewriter))) { + return failure(); + } + } else if (callTargetName == kShardingTarget) { + if (shardAttrValue.shardType == mlir::tt::MeshShardType::Manual) { + // "manual" sharding indicates match between input/output tensor shape + // and no sharding is required. + srcOp.getResult(0).replaceAllUsesWith(srcOp->getOperand(0)); + rewriter.eraseOp(srcOp); + } else { + auto *user = *srcOp.getResult(0).user_begin(); + auto userOp = dyn_cast_or_null(user); + if (!userOp) { + return failure(); + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(userOp->getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector outputTypes; + if (failed(this->getTypeConverter()->convertTypes( + userOp->getResultTypes(), outputTypes))) { + return failure(); + } + + shardAttrValue.shardDirection = + mlir::tt::MeshShardDirection::FullToShard; + if (failed(createMeshShardOp(srcOp, adaptor, outputTensor, outputTypes, + shardAttrValue, rewriter))) { + return failure(); + } + } + } + return success(); + } + +private: + struct ShardAttrValue { + mlir::tt::MeshShardDirection shardDirection; + mlir::tt::MeshShardType shardType; + bool lastTileDimReplicate; + std::vector shardShape; + }; + + // OpenXLA has its own lexer, but we will use simple string-based parser here + // This parsing is mainly based on "Sharding Attribute" section in + // https://github.com/sdasgup3/stablehlo/blob/80082431d1af0933e6202ecc8a6f8801e039235b/docs/spec.md + LogicalResult parseShardingAttr(ConversionPatternRewriter &rewriter, + SmallVector shardingStrAttrs, + struct ShardAttrValue &shardAttrValue) const { + MeshShardType shardType = mlir::tt::MeshShardType::Manual; + bool lastTileDimReplicate = false; + for (auto str : shardingStrAttrs) { + if (str.contains("replicated")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // replicated: all devices have whole data + shardType = mlir::tt::MeshShardType::Replicate; + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("maximal")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // maximal: one device has whole data + shardType = mlir::tt::MeshShardType::Maximal; + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("device=")) { + // maximal should followed by "device" to put data on + assert(shardType == mlir::tt::MeshShardType::Maximal && + "Fail to parse sharding info."); + int64_t d; + if (!str.consume_front("device=")) { + return failure(); + } + if (str.getAsInteger(10, d)) { + return failure(); + } + shardAttrValue.shardShape.push_back(d); + } else if (str.contains("manual")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // manual: already sharded, so no action is needed + assert(!lastTileDimReplicate && + "last time dim duplicate option shouldn't be set here."); + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("devices=")) { + // other: "devices" detail sharding plan + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + shardType = mlir::tt::MeshShardType::Devices; + if (!str.consume_front("devices=")) { + return failure(); + } + auto [devicesStr, restStr] = str.split("<="); + // parse devices ex. [4,2,1] + if (!devicesStr.consume_front("[") || !devicesStr.consume_back("]")) { + return failure(); + } + SmallVector dimsStr; + devicesStr.split(dimsStr, ","); + for (auto dim : dimsStr) { + int64_t d; + if (dim.getAsInteger(10, d)) { + return failure(); + } + shardAttrValue.shardShape.push_back(d); + } + } else if (str.contains("last_tile_dim_replicate")) { + assert(shardType == mlir::tt::MeshShardType::Devices && + "Fail to parse sharding info."); + // other: replicate last tile dim + lastTileDimReplicate = true; + } + } + shardAttrValue.shardType = shardType; + shardAttrValue.lastTileDimReplicate = lastTileDimReplicate; + return success(); + } + + LogicalResult + createMeshShardOp(mlir::stablehlo::CustomCallOp &srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + tensor::EmptyOp &outputTensor, + SmallVector &outputTypes, + ShardAttrValue &shardAttrValue, + ConversionPatternRewriter &rewriter) const { + + auto meshShardOperands = srcOp.getInputsMutable(); + meshShardOperands.append(ValueRange(outputTensor)); + SmallVector meshShardAttrs; + + StringAttr shardTypeAttrName = rewriter.getStringAttr("shard_type"); + Attribute shardTypeAttr = + rewriter.getAttr(shardAttrValue.shardType); + meshShardAttrs.push_back({shardTypeAttrName, shardTypeAttr}); + + StringAttr shardDirectionAttrName = + rewriter.getStringAttr("shard_direction"); + Attribute shardDirectionAttr = + rewriter.getAttr(shardAttrValue.shardDirection); + meshShardAttrs.push_back({shardDirectionAttrName, shardDirectionAttr}); + + StringAttr shardShapeAttrName = rewriter.getStringAttr("shard_shape"); + if (shardAttrValue.lastTileDimReplicate) { + shardAttrValue.shardShape.pop_back(); + } + GridAttr shardShape = + GridAttr::get(this->getContext(), shardAttrValue.shardShape); + meshShardAttrs.push_back({shardShapeAttrName, shardShape}); + + StringAttr operationConstraintAttrName = + StringAttr::get(this->getContext(), "operand_constraints"); + Attribute operationConstraintAttr = rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::SystemScalar))); + meshShardAttrs.push_back( + {operationConstraintAttrName, operationConstraintAttr}); + + auto meshShardOp = rewriter.create( + srcOp.getLoc(), outputTypes, + ValueRange(meshShardOperands.getAsOperandRange()), meshShardAttrs); + rewriter.replaceOp(srcOp, meshShardOp); + + return success(); + } + + LogicalResult + checkBasicLegality(mlir::stablehlo::CustomCallOp &srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Expect single input/output, otherwise do not convert + if (adaptor.getInputs().size() != 1 && srcOp->getResults().size() != 1) { + return failure(); + } + + return success(); + } +}; // namespace + class StableHLOToTTIRSliceOpConversionPattern : public OpConversionPattern { @@ -1138,8 +1622,8 @@ class StableHLOToTTIRGatherOpConversionPattern auto dimensionNumbers = srcOp.getDimensionNumbers(); rewriter.replaceOpWithNewOp( - srcOp, outputType, srcOp.getOperands()[0], - srcOp.getOperands()[1], // Start indices + srcOp, outputType, adaptor.getOperands()[0], + adaptor.getOperands()[1], // Start indices Value(outputTensor), dimensionNumbers.getOffsetDims(), dimensionNumbers.getCollapsedSliceDims(), dimensionNumbers.getOperandBatchingDims(), @@ -1154,6 +1638,167 @@ class StableHLOToTTIRGatherOpConversionPattern } }; +template +class StableHLOToTTIROpIotaOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcIotaOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResult().getType())); + rewriter.replaceOpWithNewOp( + srcOp, outputType, 0, outputType.getDimSize(adaptor.getIotaDimension()), + 1, adaptor.getIotaDimension()); + + // Dynamic Iota has an output_shape attribute but the output shape is + // already known by the result type This is to remove the operand that will + // become dead code + for (auto operand : adaptor.getOperands()) { + if (operand.getDefiningOp()) { + rewriter.eraseOp(operand.getDefiningOp()); + } + } + + return success(); + } +}; + +class StableHLOToTTIRScatterOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ScatterOp srcOp, + mlir::stablehlo::ScatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResults()[0].getType())); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + Value operand = srcOp.getInputs()[0]; + Value scatterIndices = srcOp.getScatterIndices(); + Value update = srcOp.getUpdates()[0]; + mlir::ArrayAttr binaryConstraints = rewriter.getArrayAttr( + SmallVector(4, rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + auto updateWindowsDims = + adaptor.getScatterDimensionNumbers().getUpdateWindowDims(); + auto insertedWindowDims = + adaptor.getScatterDimensionNumbers().getInsertedWindowDims(); + auto inputBatchingDims = + adaptor.getScatterDimensionNumbers().getInputBatchingDims(); + auto scatterIndicesBatchingDims = + adaptor.getScatterDimensionNumbers().getScatterIndicesBatchingDims(); + auto scatterDimsToOperandDims = + adaptor.getScatterDimensionNumbers().getScatterDimsToOperandDims(); + auto indexVectorDim = + adaptor.getScatterDimensionNumbers().getIndexVectorDim(); + auto indicesAreSorted = adaptor.getIndicesAreSorted(); + auto uniqueIndices = adaptor.getUniqueIndices(); + + auto newScatterOp = rewriter.create( + srcOp.getLoc(), outputType, operand, scatterIndices, update, + llvm::ArrayRef( + convertArrayRefToInt32vector(updateWindowsDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(insertedWindowDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(inputBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterIndicesBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterDimsToOperandDims)), + indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor, + binaryConstraints); + + // Replaces with different types do not work and will fail silently, so we + // manually set the second operand, since the type changes there from i32 to + // i64. + newScatterOp.setOperand( + 1, adaptor.getScatterIndices().getDefiningOp()->getResult(0)); + + newScatterOp->getRegion(0).takeBody(adaptor.getUpdateComputation()); + changeRegionTypes(newScatterOp->getRegion(0), *getTypeConverter(), + rewriter); + + rewriter.replaceOp(srcOp, newScatterOp); + + return success(); + } + +private: + std::vector + convertArrayRefToInt32vector(const llvm::ArrayRef &source) const { + std::vector converted; + converted.reserve(source.size()); + + for (int64_t value : source) { + converted.push_back(static_cast(value)); + } + + return converted; + } + + void changeRegionTypes(mlir::Region ®ion, + const mlir::TypeConverter &typeConverter, + mlir::PatternRewriter &rewriter) const { + Block &block = *region.getBlocks().begin(); + llvm::SmallVector oldArguments( + block.getArguments().begin(), block.getArguments().end()); + llvm::SmallVector newArguments; + + // Add new arguments with updated types to the block. + for (auto arg : oldArguments) { + if (auto newType = typeConverter.convertType(arg.getType())) { + mlir::BlockArgument newArg = block.addArgument(newType, arg.getLoc()); + newArguments.push_back(newArg); + } else { + newArguments.push_back(arg); // Type didn't change + } + } + + for (auto it : llvm::zip(oldArguments, newArguments)) { + mlir::BlockArgument oldArg = std::get<0>(it); + mlir::Value newArg = std::get<1>(it); + if (oldArg != newArg) { + oldArg.replaceAllUsesWith(newArg); + } + } + + for (auto arg : oldArguments) { + if (!llvm::is_contained(newArguments, arg)) { + block.eraseArgument(arg.getArgNumber()); + } + } + } +}; + +class StableHLOToTTIRReturnOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ReturnOp srcOp, + mlir::stablehlo::ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(srcOp, + srcOp.getResults()); + + return success(); + } +}; + void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1287,6 +1932,13 @@ void addReshapeOpConversionPattern(MLIRContext *ctx, patterns.add(typeConverter, ctx); } +void addCCLOpsConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, + ctx); +} + void addLogicalOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1318,6 +1970,27 @@ void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, patterns.add(typeConverter, ctx); } +void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns + .add>( + typeConverter, ctx); +} + +void addScatterOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + +void addReturnOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + } // namespace namespace mlir::tt { @@ -1339,9 +2012,13 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addConcatOpsConversionPatterns(ctx, patterns, typeConverter); addReshapeOpConversionPattern(ctx, patterns, typeConverter); addLogicalOpConversionPattern(ctx, patterns, typeConverter); + addCCLOpsConversionPattern(ctx, patterns, typeConverter); addSliceOpConversionPattern(ctx, patterns, typeConverter); addClampOpConversionPattern(ctx, patterns, typeConverter); addGatherOpConversionPattern(ctx, patterns, typeConverter); + addIotaOpConversionPattern(ctx, patterns, typeConverter); + addScatterOpConversionPatterns(ctx, patterns, typeConverter); + addReturnOpConversionPatterns(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9b8c634ad..ed7eb0be8 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -16,6 +16,7 @@ #include "mlir/Transforms/DialectConversion.h" #include +#include using namespace mlir; using namespace mlir::tt; @@ -407,6 +408,13 @@ struct GatherToEmbeddingConversionPattern // collapsed slice dims of the gather op auto collapsedSliceDims = op.getCollapsedSliceDims(); + RankedTensorType operandType = + mlir::cast(op->getOperand(0).getType()); + if (!operandType.getElementType().isBF16()) { + return rewriter.notifyMatchFailure( + op, "only supports bfloat16 input tensor."); + } + if (shape.size() > 1) { auto hiddenDim = shape[shape.size() - 1]; // check if sliceSizes has more than one element @@ -775,6 +783,257 @@ class GetDimensionSizeToConstantConversionPattern } }; +// SelectOp is converted to a series of SliceOp and potentially a ConcatOp if +// the sliced dimension is sliced multiple times. For example, if the input +// tensor is +// [[[1, 2, 3], +// [4, 5, 6], +// [7, 8, 9], +// [10, 11, 12], +// [13, 14, 15], +// [16, 17, 18]], +// [[19, 20, 21], +// [22, 23, 24], +// [25, 26, 27], +// [28, 29, 30], +// [31, 32, 33], +// [34, 35, 36]]], +// shape = [2, 6, 3] +// and the SelectOp is dim=1, begin=0, length=2, stride=4, the output tensor +// will be +// [[[1, 2, 3], +// [4, 5, 6], +// [13, 14, 15], +// [16, 17, 18]], +// [[19, 20, 21], +// [22, 23, 24], +// [31, 32, 33], +// [34, 35, 36]]], +// shape = [2, 4, 3] +// In this case 2 slices are created and concatenated to form the output tensor. +// First slice has begins=[0, 0, 0], ends=[2, 2, 3], steps=[1, 1, 1], and the +// second slice has begins=[0, 4, 0], ends=[2, 6, 3], steps=[1, 1, 1]. +struct SelectToSliceConversionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto inputType = mlir::cast(adaptor.getInput().getType()); + auto outputType = mlir::cast(op.getType()); + + auto inputShape = inputType.getShape(); + + int32_t dim = + op.getDim() < 0 ? inputType.getRank() + op.getDim() : op.getDim(); + + int32_t begin = op.getBegin(); + int32_t length = op.getLength(); + int32_t stride = op.getStride(); + + int32_t inputDimSize = inputType.getShape()[dim]; + int32_t numSlices = (inputDimSize - begin + stride - 1) / stride; + + llvm::SmallVector begins, ends, steps; + for (int32_t i = 0; i < inputType.getRank(); ++i) { + // Always slicing with step 1. + steps.push_back(1); + if (i == dim) { + // Push placeholder values for now which will be updated later. + begins.push_back(0); + ends.push_back(0); + continue; + } + + // For non-sliced dimensions, begin=0, end=dimSize, step=1. + begins.push_back(0); + ends.push_back(inputType.getDimSize(i)); + } + + // Create a slice for each slice of the input tensor. The slices are then + // concatenated. The slices are created by updating the begin and end values + // for the sliced dimension. + llvm::SmallVector slices; + for (int32_t i = 0; i < numSlices; ++i) { + int32_t newBegin = begin + i * stride; + int32_t newEnd = std::min(newBegin + length, inputDimSize); + + // Make a copy of the input shape and update the dim size. + llvm::SmallVector resultShape(inputShape); + resultShape[dim] = newEnd - newBegin; + auto resultType = + RankedTensorType::get(resultShape, inputType.getElementType()); + + auto sliceDpsResult = rewriter.create( + op.getLoc(), resultShape, inputType.getElementType()); + + begins[dim] = newBegin; + ends[dim] = newEnd; + + auto newOp = rewriter.create( + op.getLoc(), resultType, adaptor.getInput(), sliceDpsResult, + rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), + rewriter.getI32ArrayAttr(steps), adaptor.getOperandConstraints()); + slices.push_back(newOp->getResult(0)); + } + + assert(!slices.empty()); + if (slices.size() > 1) { + auto concatDpsResult = rewriter.create( + op.getLoc(), outputType.getShape(), outputType.getElementType()); + auto concatOp = rewriter.create( + op.getLoc(), outputType, slices, concatDpsResult, + rewriter.getSI32IntegerAttr(dim), adaptor.getOperandConstraints()); + + rewriter.replaceOp(op, concatOp.getResult()); + } else { + rewriter.replaceOp(op, slices[0]); + } + + return success(); + } +}; + +/* + * This pattern rewrites ArangeOp by forcing the arange_dimension to be + * rightmost dimension of the output tensor. This is done by replacing the + * ArangeOp with a new one that has this property, and then transposing out last + * dimension to the dimension specified by the original ArangeOp, and also + * inserting a reshape to match the rank of the intended output and broadcasts + * to repeat the data along the other dimensions. + * + * The ArangeOp that is generated here will be equivalent to how ttnn::ArangeOp + * behaves. The reason this pass is done in TTIR rather than generated when we + * want to lower to TTNN is because in the future we will want to consteval the + * ArangeOp, but have the option to not include repeated data in the constant + * tensor and broadcast at runtime instead. Consteval will be implemented for + * the TTIR dialect only and so this explication of the TMs implicit in ArangeOp + * must be done in TTIR. + */ +struct ArangeForceLastDimensionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + const RankedTensorType outputType = + mlir::cast(op.getResult().getType()); + + int64_t arangeDimension = adaptor.getArangeDimension(); + int64_t arangeDimensionNegative = arangeDimension - outputType.getRank(); + int64_t start = adaptor.getStart(); + int64_t end = adaptor.getEnd(); + int64_t step = adaptor.getStep(); + + int64_t arangeLength = (end - start) / step; + + ArrayRef ttnnShape = {1, 1, 1, arangeLength}; + if (ttnnShape == outputType.getShape()) { + return success(); + } + + RankedTensorType arangeOutputType = RankedTensorType::get( + SmallVector({1, 1, 1, arangeLength}), + outputType.getElementType(), outputType.getEncoding()); + + Value output = + rewriter + .create( // perform arange on the last dimension to + // match how ttnn behaves + op.getLoc(), arangeOutputType, start, end, step, 3) + .getResult(); + + std::vector outputShape = arangeOutputType.getShape().vec(); + // Must transpose the output so that the data changes along the axis defined + // by arangeDimension + if (arangeDimensionNegative != -1) { + std::vector transposeShape = outputShape; + transposeShape[arangeDimensionNegative + transposeShape.size()] = + arangeLength; + transposeShape[arangeOutputType.getRank() - 1] = 1; + RankedTensorType transposeType = RankedTensorType::get( + transposeShape, arangeOutputType.getElementType(), + arangeOutputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), transposeShape, transposeType.getElementType()); + + output = rewriter.create( + op.getLoc(), transposeType, output, dpsOutput, + arangeDimensionNegative + transposeShape.size(), + arangeOutputType.getRank() - 1, + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + outputShape = transposeShape; + } + + // Must match up the rank of the output with the rank of the intended output + // from the original arange, with the arangeDimension in the correct + // position + if (outputType.getRank() != static_cast(outputShape.size())) { + std::vector reshapeShape; + for (uint32_t i = 0; i < outputType.getRank(); i++) { + i == arangeDimension ? reshapeShape.push_back(end) + : reshapeShape.push_back(1); + } + + RankedTensorType reshapeType = RankedTensorType::get( + SmallVector(reshapeShape.begin(), reshapeShape.end()), + outputType.getElementType(), outputType.getEncoding()); + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), + SmallVector(reshapeShape.begin(), reshapeShape.end()), + reshapeType.getElementType()); + output = rewriter.create( + op.getLoc(), reshapeType, output, dpsOutput, + rewriter.getI32ArrayAttr(reshapeShape), + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + outputShape = + std::vector(reshapeShape.begin(), reshapeShape.end()); + } + + // Must broadcast the rest of the dimensions + SmallVector broadcastDims; + for (uint32_t i = 0; i < outputShape.size(); i++) { + if (i != arangeDimension && outputShape[i] != outputType.getShape()[i]) { + outputShape[i] = outputType.getShape()[i]; + broadcastDims.push_back(rewriter.getI64IntegerAttr(i)); + } + } + if (!broadcastDims.empty()) { + RankedTensorType broadcastType = RankedTensorType::get( + outputShape, outputType.getElementType(), outputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), outputShape, outputType.getElementType()); + + output = rewriter.create( + op.getLoc(), broadcastType, output, dpsOutput, + rewriter.getArrayAttr(broadcastDims), + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + assert(mlir::cast(output.getType()).getShape() == + outputType.getShape() && + "Output shape must match the shape of the input tensor"); + } + rewriter.replaceOp(op, output); + return success(); + } +}; + void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -783,6 +1042,8 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index 76cbae96e..e244eea8f 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -51,6 +51,15 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + + // These are the ops that must satisfy some conditions after this pass + target.addDynamicallyLegalOp([&](ttir::ArangeOp op) { + auto shape = op.getResult().getType().getShape(); + return (static_cast(op.getArangeDimension()) == 3 && + shape.size() == 4 && shape[0] == 1 && shape[1] == 1 && + shape[2] == 1); + }); TypeConverter typeConverter; // All types map 1:1. diff --git a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp index a3bbddc1d..09727e203 100644 --- a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp +++ b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp @@ -199,8 +199,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult relayout(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); tt::DeviceAttr device = op.getDevice(); assert(device); tt::SystemDescAttr systemDesc = op.getSystemDesc(); @@ -342,8 +342,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult reformat(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); bool shouldTilize = not inputLayout.isTiled() && outputLayout.isTiled(); bool shouldUntilize = inputLayout.isTiled() && not outputLayout.isTiled(); assert(shouldTilize ^ shouldUntilize); @@ -448,10 +448,10 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { return failure(); } assert(inputTy.getShape() == outputTy.getShape()); - assert(mlir::isa(inputTy.getEncoding())); - assert(mlir::isa(outputTy.getEncoding())); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + assert(mlir::isa(inputTy.getEncoding())); + assert(mlir::isa(outputTy.getEncoding())); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); auto components = op.compoundComponents(); bool isCompound = (static_cast(components.isLayoutChange) + @@ -799,6 +799,8 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { inCB1); } else if (mlir::isa(arithOrMathOp)) { builder.create(arithOrMathOp.getLoc()); + } else if (mlir::isa(arithOrMathOp)) { + builder.create(arithOrMathOp.getLoc()); } else { llvm_unreachable("Unhandled binary op init conversion."); } @@ -905,27 +907,13 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { assert(cbOperands.size() == 3 && "Expected two input and one output CB for binary op."); - auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB0 = cbOperands[0]; - auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; - - auto location = arithOrMathOp.getLoc(); - - // Perform computation C = A (*) B on tile A from inCB0 and tile B from - // inCB1 and store the result C in DST register on dstTileIndex. + // Perform computation C = A (*) B on tile A from cbOperands[0] and tile B + // from cbOperands[1] and store the result C in DST register on + // dstTileIndex. if (mlir::isa(arithOrMathOp)) { - Value dstIndex = i32(0, builder); - builder.create(location); - builder.create( - location, inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex); - builder.create(location); - builder.create(location); - builder.create(location, dstIndex, outCB, - outCBTileIndex); - builder.create(location); + convertComputeBinaryFPUOp( + arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, + builder); } else if (mlir::isa(arithOrMathOp)) { commonComputeMulOp(arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, builder); @@ -938,6 +926,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { blockArgIteratorMapping, builder, operandIndicesRecip); + auto inCB0 = cbOperands[0]; + auto inCB1 = cbOperands[1]; + auto location = arithOrMathOp.getLoc(); + Value one = i32(1, builder); builder.create(location, inCB1, one); @@ -947,12 +939,95 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { blockArgIteratorMapping, builder); builder.create(location, inCB1, one); + } else if (mlir::isa(arithOrMathOp)) { + convertComputeBinarySFPUOp( + arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, + builder); } else { llvm_unreachable("Unhandled conversion for operation which is neither " "unary nor binary."); } } + template + void convertComputeBinaryFPUOp( + Operation &arithOrMathOp, ArrayRef cbOperands, + ArrayRef iterators, + const SmallVector &blockArgIteratorMapping, + OpBuilder &builder) const { + auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; + auto inCB0 = cbOperands[0]; + auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; + auto inCB1 = cbOperands[1]; + auto outCB = cbOperands[2]; + auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; + + auto location = arithOrMathOp.getLoc(); + + Value dstIndex = i32(0, builder); + + // acquire DST register lock (MATH) + builder.create(location); + { + builder.create(location, inCB0, inCB1, inCB0TileIndex, + inCB1TileIndex, dstIndex); + } + builder.create(location); + // release DST register lock (MATH) + + // acquire DST register lock (PACK) + builder.create(location); + { + builder.create(location, dstIndex, outCB, + outCBTileIndex); + } + builder.create(location); + // release DST register lock (PACK) + } + + template + void convertComputeBinarySFPUOp( + Operation &arithOrMathOp, ArrayRef cbOperands, + ArrayRef iterators, + const SmallVector &blockArgIteratorMapping, + OpBuilder &builder) const { + auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; + auto inCB0 = cbOperands[0]; + auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; + auto inCB1 = cbOperands[1]; + auto outCB = cbOperands[2]; + auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; + + auto location = arithOrMathOp.getLoc(); + + Value dstLhsTileIndex = i32(0, builder); + Value dstRhsTileIndex = i32(1, builder); // note: rhs is always lhs+1 + + // acquire DST register lock (MATH) + builder.create(location); + { + // copy inCB0[inCB0TileIndex] and inCB1[inCB1TileIndex] to DST: + builder.create(location, inCB0, inCB0TileIndex, + dstLhsTileIndex); + builder.create(location, inCB1, inCB1TileIndex, + dstRhsTileIndex); + // SFPU ooperates on DST tiles: + builder.create(location, dstLhsTileIndex, + dstRhsTileIndex); + } + builder.create(location); + // release DST register lock (MATH) + + // acquire DST register lock (PACK) + builder.create(location); + { + builder.create(location, dstLhsTileIndex, outCB, + outCBTileIndex); + } + builder.create(location); + // release DST register lock (PACK) + } + void commonComputeMulOp(Operation &op, ArrayRef cbOperands, ArrayRef iterators, SmallVector blockArgIteratorMapping, @@ -1308,10 +1383,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { SmallVector> calculateDataMovement(ArrayAttr iteratorTypes, const RankedTensorType &src, const RankedTensorType &dst, DeviceAttr device) const { - auto srcLayout = mlir::cast(src.getEncoding()); + auto srcLayout = mlir::cast(src.getEncoding()); assert(srcLayout.isTiled()); - auto dstLayout = mlir::cast(dst.getEncoding()); + auto dstLayout = mlir::cast(dst.getEncoding()); assert(dstLayout.isTiled()); assert(iteratorTypes.size() >= 2 && "Expected at least 2 iterator types"); diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 12e29a960..789485eac 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -65,20 +65,15 @@ class TensorEmptyConversionPattern // Get the shape of the tensor, tensor layout, and data type // - mlir::MemRefType memref = layoutAttr.getMemref(); ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get( rewriter.getContext(), mlir::cast(op->getResult(0).getType()).getShape()); - Type elementType = memref.getElementType(); - DataType dtype = DataType::Float32; + DataType dtype = layoutAttr.getDataType(); ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - if (llvm::isa(elementType)) { + if (layoutAttr.isTiled()) { ttnnLayoutEnum = ttnn::Layout::Tile; - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); } else { ttnnLayoutEnum = ttnn::Layout::RowMajor; - dtype = elementTypeToDataType(elementType); } DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype); ttnn::LayoutAttr tensorLayoutAttr = @@ -101,13 +96,14 @@ class TensorEmptyConversionPattern // Create MemoryConfigAttr // auto device = getOrInsertDevice(rewriter, op); + llvm::SmallVector shardShape = layoutAttr.getShardShape(); ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get( op.getContext(), ttnn::TensorMemoryLayoutAttr::get(op.getContext(), memLayout), ttnn::BufferTypeAttr::get(op.getContext(), bufferType), ttnn::ShardSpecAttr::get( op.getContext(), - ttnn::ShapeAttr::get(op.getContext(), memref.getShape()))); + ttnn::ShapeAttr::get(op.getContext(), shardShape))); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), device, @@ -137,18 +133,15 @@ class ToLayoutOpConversionPattern auto outputLayoutAttr = mlir::cast( op.getResult().getType().getEncoding()); - auto outputMemref = outputLayoutAttr.getMemref(); - // Determine the output data type - DataType dtype = ttnn::utils::getDataTypeFromMemRef(outputMemref); + DataType dtype = outputLayoutAttr.getDataType(); DataTypeAttr outputDataType = DataTypeAttr::get(rewriter.getContext(), dtype); // Determine the output layout (tile or row major) ttnn::BufferType outputBufferType = outputLayoutAttr.getBufferType(); - ttnn::Layout outputLayoutEnum = - ttnn::utils::getLayoutFromMemRef(outputMemref); + ttnn::Layout outputLayoutEnum = outputLayoutAttr.getLayout(); bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory); @@ -176,13 +169,14 @@ class ToLayoutOpConversionPattern op.getResult().setType(result); outputLayoutAttr = mlir::cast(result.getEncoding()); - outputMemref = outputLayoutAttr.getMemref(); outputLayoutEnum = newOutputLayoutEnum; } } ttnn::LayoutAttr outputLayout = ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum); + llvm::SmallVector outputShardShape = + outputLayoutAttr.getShardShape(); // Determine output memory config attr ttnn::TensorMemoryLayout outputTensorMemoryLayout = @@ -193,8 +187,8 @@ class ToLayoutOpConversionPattern outputTensorMemoryLayout), ttnn::BufferTypeAttr::get(rewriter.getContext(), outputBufferType), ttnn::ShardSpecAttr::get( - op.getContext(), ttnn::ShapeAttr::get(rewriter.getContext(), - outputMemref.getShape()))); + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), outputShardShape))); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(result), adaptor.getInput(), @@ -222,15 +216,16 @@ class ToLayoutOpConversionPattern ttnn::Layout newOutputLayoutEnum) const { auto oldOutputLayoutAttr = mlir::cast(oldOutput.getEncoding()); - auto oldOutputMemref = oldOutputLayoutAttr.getMemref(); - DataType outputDtype = ttnn::utils::getDataTypeFromMemRef(oldOutputMemref); - llvm::ArrayRef oldShardShape = oldOutputMemref.getShape(); + DataType outputDtype = oldOutputLayoutAttr.getDataType(); + SmallVector oldShardShape = + oldOutputLayoutAttr.getShardShape(); size_t shardShapeSize = oldShardShape.size(); assert(shardShapeSize >= 2 && "expected at least 2D shape"); if (newOutputLayoutEnum == ttnn::Layout::RowMajor) { // Set shard shape to match convention of row major layout - auto tileType = mlir::cast(oldOutputMemref.getElementType()); + auto tileType = + mlir::cast(oldOutputLayoutAttr.getElementType()); llvm::SmallVector newShardShape(oldShardShape.begin(), oldShardShape.end()); newShardShape[shardShapeSize - 2] = @@ -579,7 +574,19 @@ class ConstantOpConversionPattern } }; -} // namespace +class LinearOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::LinearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(), + adaptor.getB(), adaptor.getBias(), adaptor.getOutput()); + return success(); + } +}; // ANCHOR: adding_an_op_matmul_op_rewriter class MatmulOpConversionPattern : public OpConversionPattern { @@ -792,9 +799,7 @@ class TypecastOpConversionPattern ttnn::TTNNLayoutAttr outputLayoutAttr = mlir::cast(result.getType().getEncoding()); - mlir::MemRefType outputMemref = outputLayoutAttr.getMemref(); - - DataType outputDataType = ttnn::utils::getDataTypeFromMemRef(outputMemref); + DataType outputDataType = outputLayoutAttr.getDataType(); if (op->getUsers().empty()) { return rewriter.notifyMatchFailure( @@ -807,46 +812,6 @@ class TypecastOpConversionPattern } }; -class BroadcastOpConversionPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - // Fold this operation into all consumer ops. It will only work with TTNN - // ops that support implicit broadcasting. We expect each Op's verify - // function to assert their arguments to verify that they can broadcast. - - if (srcOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - - mlir::Value input = srcOp.getOperand(0); - - mlir::Operation *nextOp = srcOp; - while (isa(*nextOp->getUsers().begin())) { - assert(nextOp->hasOneUse() && - "Broadcast with multiple uses are not supported"); - nextOp = *nextOp->getUsers().begin(); - if (nextOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - } - - rewriter.replaceAllOpUsesWith(nextOp, input); - rewriter.eraseOp(srcOp); - - return success(); - } -}; - class SubtractOpConversionPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -908,6 +873,63 @@ class AllGatherOpConversionPattern } }; +class ArangeOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = + mlir::cast(op.getResult().getType()); + assert(static_cast(adaptor.getArangeDimension()) == + outputType.getRank() - 1 && + "Arange dimension must be the final dimension of the output tensor " + "to convert to ttnn.arange"); + + // Get ttnn::TTNNLayoutAttr of the result type + // + ttnn::TTNNLayoutAttr layoutAttr = + mlir::cast(outputType.getEncoding()); + + DataTypeAttr dtypeAttr = rewriter.getAttr( + elementTypeToDataType(outputType.getElementType())); + Value device = getOrInsertDevice(rewriter, op); + + ttnn::MemoryConfigAttr memConfigAttr = + rewriter.getAttr( + rewriter.getAttr( + layoutAttr.getMemLayout()), + rewriter.getAttr(layoutAttr.getBufferType()), + rewriter.getAttr( + rewriter.getAttr(layoutAttr.getShardShape()))); + + rewriter.replaceOpWithNewOp( + op, outputType, adaptor.getStart(), adaptor.getEnd(), adaptor.getStep(), + dtypeAttr, device, memConfigAttr); + + return success(); + } +}; + +class ScatterOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The ttnn interface has the inverse inputs of the TTIR dialect op (which + // matches torch ops). + rewriter.replaceOpWithNewOp( + op, adaptor.getUpdate(), adaptor.getInput(), adaptor.getOutput()); + + return success(); + } +}; +} // namespace + namespace mlir::tt { void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, @@ -957,7 +979,6 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ReductionOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, - BroadcastOpConversionPattern, EmbeddingOpConversionPattern, SoftmaxOpConversionPattern, TransposeOpConversionPattern, @@ -969,11 +990,14 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, SqueezeOpConversionPattern, UnsqueezeOpConversionPattern, ConstantOpConversionPattern, + LinearOpConversionPattern, MatmulOpConversionPattern, Conv2dOpConversionPattern, MaxPool2dOpConversionPattern, SubtractOpConversionPattern, - AllGatherOpConversionPattern + AllGatherOpConversionPattern, + ArangeOpConversionPattern, + ScatterOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set // clang-format on diff --git a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp index 312377eb6..c265e8928 100644 --- a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp +++ b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp @@ -406,8 +406,10 @@ class ConvertTTKernelToEmitCPass TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, @@ -419,6 +421,12 @@ class ConvertTTKernelToEmitCPass TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastOnePacketOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastLoopbackSrcOp>, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, @@ -473,6 +481,8 @@ class ThreadConfigHelper { builder->create(loc, "compute_kernel_api/eltwise_binary.h", /*isStandard=*/false); + builder->create(loc, "compute_kernel_api.h", // max ops + /*isStandard=*/false); builder->create(loc, "compute_kernel_api/tile_move_copy.h", /*isStandard=*/false); diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 9b7cf7fe8..f04d5566b 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -618,6 +618,35 @@ class DeallocateOpConversionPattern } }; +// Module Op conversion pattern +// +// This conversion pattern removes attributes from the ModuleOp. Previously, +// ttmlir-translate would complain when translating to C++ if there were any +// attributes from "unregistered" dialects. +// +class ModuleOpConversionPattern + : public TTNNToEmitCBaseOpConversionPattern { + +public: + ModuleOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : TTNNToEmitCBaseOpConversionPattern(typeConverter, + context, benefit) {} + + LogicalResult + matchAndRewrite(mlir::ModuleOp srcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.modifyOpInPlace(srcOp, [&]() { + for (const NamedAttribute &attr : srcOp->getAttrs()) { + srcOp->removeAttr(attr.getName()); + } + }); + + return success(); + } +}; + } // namespace namespace mlir::tt { @@ -639,8 +668,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Tensor ops // patterns - .add>( - typeConverter, ctx); + .add, + DefaultOpConversionPattern>(typeConverter, ctx); // Eltwise unary ops // @@ -684,6 +713,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, + DefaultOpConversionPattern, DefaultOpConversionPattern>(typeConverter, ctx); @@ -696,7 +726,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Matmul ops // - patterns.add>(typeConverter, ctx); + patterns.add, + DefaultOpConversionPattern>(typeConverter, ctx); // Reduction ops // @@ -720,6 +751,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // patterns.add>(typeConverter, ctx); + + // Module op + // + patterns.add(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp index 71a7c52b6..bd0c9044f 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp @@ -4,6 +4,11 @@ #include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h" +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" @@ -12,11 +17,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "ttmlir/Dialect/TTNN/IR/TTNN.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" - using namespace mlir; using namespace mlir::tt; @@ -48,14 +48,20 @@ struct ConvertTTNNToEmitCPass void runOnOperation() override { mlir::ConversionTarget target(getContext()); + // EmitC is legal, TTNN is illegal + // target.addLegalDialect(); target.addIllegalDialect(); - target.addLegalOp(); + + // mlir::ModuleOp is legal only if no attributes are present on it + // + target.addDynamicallyLegalOp( + [&](mlir::ModuleOp op) { return op->getAttrs().empty(); }); // Add header imports to front of module // { - auto module = getOperation(); + mlir::ModuleOp module = getOperation(); OpBuilder builder(module); if (module.getBodyRegion().empty()) { @@ -107,7 +113,7 @@ struct ConvertTTNNToEmitCPass return; } } - }; + } }; } // namespace diff --git a/lib/Conversion/TosaToTTIR/CMakeLists.txt b/lib/Conversion/TosaToTTIR/CMakeLists.txt index 41baf75c6..56000eb65 100644 --- a/lib/Conversion/TosaToTTIR/CMakeLists.txt +++ b/lib/Conversion/TosaToTTIR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TTMLIRTosaToTTIR - TosaToTTIR.cpp + TosaToTTIRPass.cpp + TosaToTTIRPatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TosaToTTIR diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp deleted file mode 100644 index 6c6a7faf5..000000000 --- a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp +++ /dev/null @@ -1,122 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; -using namespace tt; - -namespace mlir::tt::ttir { - -#define GEN_PASS_DEF_CONVERTTOSATOTTIR -#include "ttmlir/Conversion/Passes.h.inc" - -} // namespace mlir::tt::ttir - -namespace { - -template -class TosaToTTIROpConversionPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(SrcOp srcOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if constexpr (std::is_same::value) { - assert(srcOp.getShift() == 0); - } - - auto outputType = mlir::cast(srcOp.getResult().getType()); - auto outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - rewriter.replaceOpWithNewOp( - srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), - ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); - return success(); - } -}; - -struct ConvertTosaToTTIRPass - : public ttir::impl::ConvertTosaToTTIRBase { - void runOnOperation() override { - mlir::ConversionTarget target(getContext()); - - target.addIllegalDialect(); - - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - // For now keep the same type assuming tosa ops operate on builtin tensor. - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { - assert(isa(type) && - "only ranked tensor type supported"); - return type; - }); - RewritePatternSet patterns(&getContext()); - - // Add conversion patterns. - patterns - .add>( - typeConverter, &getContext()); - patterns - .add>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add>( - typeConverter, &getContext()); - - // Apply conversion. - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - return; - } - } -}; - -} // namespace - -namespace mlir::tt { - -std::unique_ptr> createConvertTosaToTTIRPass() { - return std::make_unique(); -} - -} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp new file mode 100644 index 000000000..183d58cca --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_CONVERTTOSATOTTIR +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +struct ConvertTosaToTTIRPass + : public ttir::impl::ConvertTosaToTTIRBase { + void runOnOperation() override { + mlir::ConversionTarget target(getContext()); + + target.addIllegalDialect(); + + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + // For now keep the same type assuming tosa ops operate on builtin tensor. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { + assert(isa(type) && + "only ranked tensor type supported"); + return type; + }); + RewritePatternSet patterns(&getContext()); + + // Add conversion patterns. + populateTosaToTTIRPatterns(&getContext(), patterns, typeConverter); + + // Apply conversion. + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> createConvertTosaToTTIRPass() { + return std::make_unique(); +} + +} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp new file mode 100644 index 000000000..46eadb789 --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace { + +// TODO(sdjukic): extract this pattern into separate file and use it for both +// TOSA and StableHLO + +template +class TosaToTTIRDefaultDPSOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + LogicalResult legalityResult = + checkConversionLegality(srcOp, adaptor, rewriter); + if (!legalityResult.succeeded()) { + return legalityResult; + } + + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), + ValueRange(outputTensor), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } + +private: + virtual LogicalResult + checkConversionLegality(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + return success(); + } +}; + +class TosaToTTIRMultiplyOpConversionPattern + : public TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, mlir::tt::ttir::MultiplyOp> { + using TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, + mlir::tt::ttir::MultiplyOp>::TosaToTTIRDefaultDPSOpConversionPattern; + +private: + LogicalResult + checkConversionLegality(tosa::MulOp srcOp, tosa::MulOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (srcOp.getShift() != 0) { + return rewriter.notifyMatchFailure( + srcOp, "TTIR MultiplyOp doesn't support shifted multiply."); + } + return success(); + } +}; + +void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); +} + +void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); +} + +void addCompareOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>(typeConverter, + ctx); +} + +} // namespace + +namespace mlir::tt { + +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter); + addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); + addCompareOpsConversionPatterns(ctx, patterns, typeConverter); +} + +} // namespace mlir::tt diff --git a/lib/Dialect/TT/IR/TTDialect.cpp b/lib/Dialect/TT/IR/TTDialect.cpp index 6f629d697..1ac8a2223 100644 --- a/lib/Dialect/TT/IR/TTDialect.cpp +++ b/lib/Dialect/TT/IR/TTDialect.cpp @@ -13,13 +13,13 @@ using namespace mlir; using namespace mlir::tt; -// This is needed to hoist tt.layout attributes as named attributes declared at -// the module level. +// This is needed to hoist tt.metal_layout attributes as named attributes +// declared at the module level. struct TTOpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (llvm::isa(attr)) { + if (llvm::isa(attr)) { os << "layout"; return AliasResult::OverridableAlias; } diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp index bbdd4e259..12166e443 100644 --- a/lib/Dialect/TT/IR/TTOpsTypes.cpp +++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp @@ -466,7 +466,7 @@ calculateLogicalShardShape(mlir::ArrayRef tensorShape, return shardShape; } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -483,7 +483,7 @@ LayoutAttr LayoutAttr::get( return get(context, linear, oobVal, grid, memref, memLayout); } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -493,9 +493,11 @@ LayoutAttr LayoutAttr::get( collapseIntervals, oobVal, memLayout); } -LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, - MemorySpace memorySpace, GridAttr grid, - Type elementType, TensorMemoryLayout memLayout) { +MetalLayoutAttr MetalLayoutAttr::get(::mlir::MLIRContext *context, + RankedTensorType ty, + MemorySpace memorySpace, GridAttr grid, + Type elementType, + TensorMemoryLayout memLayout) { assert(ty); assert(grid); return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}}, @@ -506,7 +508,7 @@ LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, // compute the physical shape of the tensor, i.e the shape of the tensor // after the dimensions have been collapsed onto a grid. llvm::SmallVector -LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { +MetalLayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { llvm::SmallVector physicalShape(getGrid().getShape().size()); SmallVector logicalShapeExprs( llvm::map_range(logicalShape, [context = getContext()](std::int64_t e) { @@ -525,7 +527,7 @@ LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getStride(ArrayRef logicalShape) const { +MetalLayoutAttr::getStride(ArrayRef logicalShape) const { llvm::SmallVector stride(logicalShape.size()); @@ -574,7 +576,7 @@ LayoutAttr::getStride(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getShardShape(bool convertTileToScalar) const { +MetalLayoutAttr::getShardShape(bool convertTileToScalar) const { SmallVector shardShape(getMemref().getShape()); auto elementType = getElementType(); if (mlir::isa(elementType) && convertTileToScalar) { @@ -583,11 +585,11 @@ LayoutAttr::getShardShape(bool convertTileToScalar) const { return shardShape; } -mlir::Type LayoutAttr::getElementType() const { +mlir::Type MetalLayoutAttr::getElementType() const { return getMemref().getElementType(); } -mlir::Type LayoutAttr::getScalarElementType() const { +mlir::Type MetalLayoutAttr::getScalarElementType() const { auto elementType = getElementType(); if (mlir::isa(elementType)) { return mlir::cast(elementType).getElementType(); @@ -595,33 +597,33 @@ mlir::Type LayoutAttr::getScalarElementType() const { return elementType; } -bool LayoutAttr::hasShardedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::hasShardedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::isTiled() const { +bool MetalLayoutAttr::isTiled() const { return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } -uint64_t LayoutAttr::getElementSizeBytes() const { +uint64_t MetalLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); if (mlir::isa(elementType)) { auto tileType = mlir::cast(elementType); @@ -630,7 +632,7 @@ uint64_t LayoutAttr::getElementSizeBytes() const { return elementType.getIntOrFloatBitWidth() / 8; } -uint64_t LayoutAttr::getMemrefSizeBytes() const { +uint64_t MetalLayoutAttr::getMemrefSizeBytes() const { MemRefType ty = getMemref(); auto shape = ty.getShape(); uint64_t size = getElementSizeBytes(); @@ -638,57 +640,60 @@ uint64_t LayoutAttr::getMemrefSizeBytes() const { std::multiplies()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { return get(context, tensorShape, getElementType(), getMemorySpace(), grid, collapseIntervals, getOobVal(), getMemLayout()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { assert(ty); - return LayoutAttr::withGrid(context, ty.getShape(), grid, collapseIntervals); + return MetalLayoutAttr::withGrid(context, ty.getShape(), grid, + collapseIntervals); } -LayoutAttr LayoutAttr::withElementType(::mlir::MLIRContext *context, - Type elementType) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withElementType(::mlir::MLIRContext *context, + Type elementType) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), elementType, getMemorySpace()), getMemLayout()); } -LayoutAttr LayoutAttr::withMemorySpace(::mlir::MLIRContext *context, - MemorySpace memorySpace) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withMemorySpace(::mlir::MLIRContext *context, + MemorySpace memorySpace) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), getElementType(), memorySpace), getMemLayout()); } -LayoutAttr LayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, - TensorMemoryLayout memLayout) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayout memLayout) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, getShardShape(), getElementType(), getMemorySpace()), memLayout); } -LayoutAttr LayoutAttr::withShardShape(::mlir::MLIRContext *context, - llvm::SmallVector shardShape) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withShardShape(::mlir::MLIRContext *context, + llvm::SmallVector shardShape) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, shardShape, getElementType(), getMemorySpace()), getMemLayout()); } -MemorySpace LayoutAttr::getMemorySpace() const { +MemorySpace MetalLayoutAttr::getMemorySpace() const { return mlir::cast(getMemref().getMemorySpace()) .getValue(); } @@ -696,7 +701,7 @@ MemorySpace LayoutAttr::getMemorySpace() const { // Returns shape of the tensor after tilization is applied to the two inner most // dimensions. llvm::SmallVector -LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { +MetalLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); mlir::AffineMap linear = getLinear(); @@ -716,7 +721,7 @@ LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { return ttmlir::utils::evalShape(tiled, tensorShape); } -mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { +mlir::AffineMap MetalLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); return mlir::AffineMap::getMultiDimIdentityMap(getLinear().getNumResults(), @@ -735,7 +740,7 @@ mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { // (d0, d1)[2, 3] -> // (0, d0 floordiv 2, d1 floordiv 3, (d0 mod 2) * 3 + d1 mod 3) // -mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( +mlir::AffineMap MetalLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { mlir::SmallVector shardShape = getShardShape(false /*convertTileToScalar*/); @@ -763,8 +768,8 @@ mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( // grid. Then it composes the logical grid projection with physical memory // mapping. mlir::AffineMap -LayoutAttr::projectOnto(mlir::AffineMap linearMap, - mlir::AffineMap physicalMemoryMap) const { +MetalLayoutAttr::projectOnto(mlir::AffineMap linearMap, + mlir::AffineMap physicalMemoryMap) const { assert(getGrid().getShape().size() == physicalMemoryMap.getNumDims() && "Layout and device grids must have same number of dimensions"); assert(getLinear().getNumResults() == physicalMemoryMap.getNumDims() && @@ -1013,7 +1018,7 @@ DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context, // Sample the last index in the tensor to get the last addressable element of // the tensor to determine its footprint in memory. uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, - LayoutAttr layout, + MetalLayoutAttr layout, MemorySpace memorySpace) const { SmallVector shape = layout.isTiled() ? layout.getTiledShape(tensorScalarShape) @@ -1035,9 +1040,9 @@ uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, uint64_t DeviceAttr::getTensorSizeBytes(RankedTensorType tensorType, MemorySpace memorySpace) const { assert(tensorType.getEncoding()); - return getLayoutSizeBytes(tensorType.getShape(), - mlir::cast(tensorType.getEncoding()), - memorySpace); + return getLayoutSizeBytes( + tensorType.getShape(), + mlir::cast(tensorType.getEncoding()), memorySpace); } ::mlir::LogicalResult diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index feedc845c..bc1f02868 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -45,6 +45,37 @@ ::mlir::LogicalResult mlir::tt::ttir::ClampOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::ArangeOp::verify() { + int64_t start = getStart(); + int64_t end = getEnd(); + int64_t step = getStep(); + + if (step == 0) { + return emitOpError("Step value cannot be zero"); + } + + int64_t numValues = (end - start) / step; + + if (numValues <= 0) { + return emitOpError() << "Invalid range: start=" << start << ", end=" << end + << ", step=" << step; + } + + if (numValues != getType().getDimSize(getArangeDimension())) { + return emitOpError() << "Output tensor shape must be " << numValues + << " at dim " << getArangeDimension() + << " (since start=" << start << ", end=" << end + << ", step=" << step << "), but got " + << getType().getDimSize(getArangeDimension()); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -605,6 +636,100 @@ ::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() { } // ANCHOR_END: decomposing_an_op_index_ttir_verify +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +// SelectOp verification +::mlir::LogicalResult mlir::tt::ttir::SelectOp::verify() { + ::mlir::RankedTensorType inputType = getInput().getType(); + ::mlir::RankedTensorType outputType = getOutput().getType(); + + if (inputType.getRank() != outputType.getRank()) { + return emitOpError("Input and output tensors must have the same rank."); + } + + if (inputType.getElementType() != outputType.getElementType()) { + return emitOpError("Input and output tensors must have the same element " + "type."); + } + + int32_t dim = getDim(); + int32_t origDim = dim; + if (dim < 0) { + dim += inputType.getRank(); + } + + if (dim < 0 || dim >= inputType.getRank()) { + return emitOpError() << "Invalid dimension " << origDim + << " for select op with input tensor rank " + << inputType.getRank(); + } + + int32_t dimSize = inputType.getDimSize(dim); + + int32_t stride = getStride(); + if (stride == 0) { + stride = dimSize; + } + + if (stride < 0) { + return emitOpError() << "Invalid stride " << stride << " for dimension " + << dim << ", stride must be non-negative"; + } + + if (stride > dimSize) { + return emitOpError() << "Invalid stride " << stride << " for dimension " + << dim << " with size " << dimSize + << ". stride must be less than or equal to the " + "dimension size"; + } + + int32_t begin = getBegin(); + int32_t length = getLength(); + if (begin < 0 || begin >= dimSize) { + return emitOpError() << "Invalid begin index " << begin << " for dimension " + << dim << " with size " << dimSize + << ". begin must be " + "in the range [0, dimSize)"; + } + + if (length < 1 || length > stride) { + return emitOpError() << "Invalid length " << length << " for begin index " + << begin << " and stride " << stride + << " for dimension " << dim << " with size " << dimSize + << ". stride must be greater than or equal to length"; + } + + if (begin + length > dimSize) { + return emitOpError() << "Invalid length " << length << " for begin index " + << begin << " and dimension " << dim << " with size " + << dimSize + << ". begin + length must be less than or " + "equal to the dimension size"; + } + + // Get the number of slices as the number of times the stride fits in the + // dimension size starting from the begin index. + int32_t numSlices = (dimSize - begin + stride - 1) / stride; + int32_t totalLength = 0; + for (int32_t i = 0; i < numSlices; i++) { + int32_t newBegin = begin + i * stride; + int32_t newEnd = std::min(newBegin + length, dimSize); + totalLength += newEnd - newBegin; + } + + if (totalLength != outputType.getDimSize(dim)) { + return emitOpError() << "Sum of all slices must be equal to the output " + "dimension size for the given dimension. Expected " + "output dimension size: " + << outputType.getDimSize(dim) << ", but got " + << totalLength; + } + + return success(); +} + //===----------------------------------------------------------------------===// // SqueezeOp //===----------------------------------------------------------------------===// @@ -792,9 +917,9 @@ ::mlir::LogicalResult mlir::tt::ttir::ToLayoutOp::verify() { mlir::tt::ttir::ToLayoutOp::CompoundComponents mlir::tt::ttir::ToLayoutOp::compoundComponents() { auto inputLayout = - mlir::cast(getInput().getType().getEncoding()); + mlir::cast(getInput().getType().getEncoding()); auto outputLayout = - mlir::cast(getOutput().getType().getEncoding()); + mlir::cast(getOutput().getType().getEncoding()); bool isLayoutChange = inputLayout.getLinear() != outputLayout.getLinear(); bool isGridChange = inputLayout.getGrid() != outputLayout.getGrid(); bool isShardChange = @@ -810,6 +935,158 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() { isMemoryLayoutChange}; } +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +// LinearOp verification +::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { + ::mlir::RankedTensorType inputAType = getA().getType(); + ::mlir::RankedTensorType inputBType = getB().getType(); + std::optional<::mlir::RankedTensorType> biasType = + getBias() ? std::make_optional(getBias().getType()) : std::nullopt; + ::mlir::RankedTensorType outputType = getOutput().getType(); + + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector inputAShape(inputAType.getShape()); + llvm::SmallVector inputBShape(inputBType.getShape()); + + // Verify that the input A is at least 1D tensor. + if (inputAType.getRank() < 1) { + return emitOpError("Input A must be at least a 1D tensor"); + } + + // Verify that the input B is at least 1D tensor. + if (inputBType.getRank() < 1) { + return emitOpError("Input B must be at least a 1D tensor"); + } + + // If input A is a vector (1D tensor), 1 is prepended to its dimension for the + // purpose of the matrix multiplication. After the matrix multiplication, the + // prepended dimension is removed. + if (inputAType.getRank() == 1) { + inputAShape.insert(inputAShape.begin(), 1); + } + + // If input B is a vector (1D tensor), a 1 is appended to its dimension for + // the purpose of the matrix-vector product and removed afterwards. + if (inputBType.getRank() == 1) { + inputBShape.push_back(1); + } + + // Verify that the input A and input B has matching inner dimensions. + if (inputAShape[inputAShape.size() - 1] != + inputBShape[inputBShape.size() - 2]) { + return emitOpError( + "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + + ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + + ") must have matching inner dimensions"); + } + + llvm::SmallVector expectedOutputShape; + // Verify that the batch dimensions are broadcast compatible and construct the + // expected output shape. + if (inputAShape.size() > 2 || inputBShape.size() > 2) { + llvm::SmallVector inputABatchDims, inputBBatchDims; + + if (inputAShape.size() > 2) { + inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), + inputAShape.end() - 2); + } + + if (inputBShape.size() > 2) { + inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), + inputBShape.end() - 2); + } + + // Verify that the batch dimensions of input A and B are broadcast + // compatible. + llvm::SmallVector broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims, + broadcastedShape)) { + + return emitOpError("Batch dimensions of input A(" + + ttmlir::utils::join(inputABatchDims, ",") + + ") and B(" + + ttmlir::utils::join(inputBBatchDims, ",") + + ") are not broadcast compatible"); + } + + // Insert the broadcasted batch dimensions in the expected output shape. + expectedOutputShape.insert(expectedOutputShape.begin(), + broadcastedShape.begin(), + broadcastedShape.end()); + } + + // Insert the input A and B inner dimensions in expected output shape. + // Consider the case where input A and B are vectors. In that case, + // the dimension 1 is ommited from the output shape. + if (inputAType.getRank() > 1) { + expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]); + } + + if (inputBType.getRank() > 1) { + expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]); + } + + if (biasType) { + // Verify that the input bias is at least 1D tensor. + if (biasType.value().getRank() < 1) { + return emitOpError("Bias must be at least a 1D tensor"); + } + + llvm::SmallVector biasShape(biasType.value().getShape()); + + // Verify that the dimensions of the matmul of A and B are broadcast + // compatible with input bias. + llvm::SmallVector matmulShape = expectedOutputShape; + if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape, + expectedOutputShape)) { + return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") + + ") is not broadcast compatible with the matmul output " + "shape(" + + ttmlir::utils::join(matmulShape, ",") + ")"); + } + } + + // Check the case of a vector-vector product. At this moment we don't support + // scalars in IR, hence check that the output is at least 1D tensor of size 1. + if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// @@ -948,7 +1225,7 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() { // AllocOp verification ::mlir::LogicalResult mlir::tt::ttir::AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); @@ -1021,6 +1298,102 @@ ::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllReduceOp +//===----------------------------------------------------------------------===// + +// AllReduceOp verification +::mlir::LogicalResult mlir::tt::ttir::AllReduceOp::verify() { + ::mlir::RankedTensorType inputType = + mlir::cast(getInputs().front().getType()); + int32_t dim = getDim(); + + if (dim >= inputType.getRank()) { + return emitOpError("Invalid dimension for all_reduce op."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// MeshShardOp +//===----------------------------------------------------------------------===// + +// MeshShardOp verification +::mlir::LogicalResult mlir::tt::ttir::MeshShardOp::verify() { + auto shardType = getShardType(); + + // currently we are only supporting replicate or devices from StableHLO + if (shardType != mlir::tt::MeshShardType::Replicate && + shardType != mlir::tt::MeshShardType::Devices) { + return emitOpError("Invalid shard_type for mesh_shard op."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +bool matchSimpleBlock(mlir::Region ®ion) { + if (!region.hasOneBlock()) { + return false; + } + mlir::Block &block = region.front(); + if (block.getNumArguments() != 2) { + return false; + } + auto argType1 = + mlir::cast(block.getArgument(0).getType()); + auto argType2 = + mlir::cast(block.getArgument(1).getType()); + if (!argType1 || !argType2) { + return false; + } + if (block.getOperations().size() != 1) { + return false; + } + mlir::tt::ttir::YieldOp returnOp = + mlir::cast(&block.front()); + if (!returnOp) { + return false; + } + if (returnOp.getNumOperands() != 1 || + returnOp.getOperand(0) != block.getArgument(1)) { + return false; + } + return true; +} + +::mlir::LogicalResult mlir::tt::ttir::ScatterOp::verify() { + + ArrayRef inputShape = + mlir::cast(getInput().getType()).getShape(); + + if (getUpdateWindowDims().size() + getInsertedWindowDims().size() != + inputShape.size()) { + return emitOpError("Batching currently not supported"); + } + + for (uint64_t insertedWindowDims : getInsertedWindowDims()) { + if (inputShape[insertedWindowDims] != 1) { + return emitOpError("Dimension size to slice into must be 1"); + } + } + + // We currently do not support custom functions in the scatter function, + // which is a possbility in StableHLO dialect. See issue: + // https://github.com/tenstorrent/tt-mlir/issues/1278 + if (!matchSimpleBlock(getUpdateComputation())) { + return emitOpError( + "Currently not supporting custom scatter function in TTNN " + "dialect and TT-metal."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // GenericOp //===----------------------------------------------------------------------===// @@ -1102,6 +1475,13 @@ void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, block); } +// MaximumOp generic region builder +void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, + ::mlir::Block *block) { + buildGenericEltwiseBinaryRegion(getLoc(), opBuilder, + block); +} + //===----------------------------------------------------------------------===// // KernelOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp index 84409174a..10619f24b 100644 --- a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp +++ b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp @@ -17,37 +17,33 @@ #include "llvm/ADT/SmallVector.h" mlir::LogicalResult -mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) { +mlir::tt::ttir::detail::verifyBroadcastable(mlir::Operation *op) { + const auto getShape = [](const Value val) { + return mlir::cast(val.getType()).getShape(); + }; + + const auto operandSegmentSizes = + op->getAttrOfType("operandSegmentSizes"); + // DPS operands shouldn't affect the result shape. + const auto outputSegmentSize = + operandSegmentSizes[operandSegmentSizes.size() - 1]; + const auto operandShapes = llvm::map_range(op->getOperands(), getShape); llvm::SmallVector broadcastedShape; - mlir::OperandRange operands = op->getOperands(); - mlir::OperandRange::iterator operand_it = operands.begin(); - llvm::SmallVector prevOperandShape( - mlir::cast((*operand_it).getType()).getShape()); - - while (++operand_it != operands.end()) { - llvm::SmallVector nextOperandShape( - mlir::cast((*operand_it).getType()).getShape()); - - if (!OpTrait::util::getBroadcastedShape(prevOperandShape, nextOperandShape, + for (const auto operandShape : + llvm::drop_end(operandShapes, outputSegmentSize)) { + const auto prevBroadcastedShape = broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(prevBroadcastedShape, operandShape, broadcastedShape)) { return op->emitOpError("Operands are not broadcast compatible"); } - prevOperandShape = broadcastedShape; } - llvm::SmallVector resultShape( - mlir::cast(op->getResult(0).getType()) - .getShape()); + // Check that the result shape matches the broadcasted shape of the operands. + llvm::SmallVector resultShape(getShape(op->getResults().front())); if (broadcastedShape != resultShape) { return op->emitOpError( "Result shape must match operand shapes after broadcasting"); } - TypeID expectedBaseTy = op->getResultTypes().front().getTypeID(); - if (!llvm::all_of(op->getOperandTypes(), - [&](Type t) { return t.getTypeID() == expectedBaseTy; })) { - return op->emitOpError() << "All operands/results must have the same type"; - } - return success(); } diff --git a/lib/Dialect/TTIR/Transforms/Allocate.cpp b/lib/Dialect/TTIR/Transforms/Allocate.cpp index 37e788385..a643f041c 100644 --- a/lib/Dialect/TTIR/Transforms/Allocate.cpp +++ b/lib/Dialect/TTIR/Transforms/Allocate.cpp @@ -22,13 +22,13 @@ inline MemorySpace getMemorySpace(MemRefType memref) { return mlir::cast(memref.getMemorySpace()).getValue(); } -inline MemorySpace getMemorySpace(LayoutAttr layout) { +inline MemorySpace getMemorySpace(MetalLayoutAttr layout) { return getMemorySpace(layout.getMemref()); } inline MemorySpace getMemorySpace(RankedTensorType ty) { assert(ty.getEncoding()); - auto layout = mlir::cast(ty.getEncoding()); + auto layout = mlir::cast(ty.getEncoding()); return getMemorySpace(layout); } diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp new file mode 100644 index 000000000..7823b021e --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRBROADCASTFOLD +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Broadcast Folding pass +// Our backend supports implicit broadcast of operands, so explicit broadcast +// instructions are folded. +// +// For Example: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) -> +// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1, +// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> +// tensor<512xf32> +// +// After folding: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>, +// tensor<512xf32>) -> tensor<512xf32> +//===----------------------------------------------------------------------===// + +class TTIRBroadcastFoldRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp op, + PatternRewriter &rewriter) const final { + + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } +}; + +class TTIRBroadcastFold + : public impl::TTIRBroadcastFoldBase { +public: + using impl::TTIRBroadcastFoldBase::TTIRBroadcastFoldBase; + + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index f5fec45a8..597c55e3c 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Allocate.cpp + Broadcast.cpp Constant.cpp Generic.cpp Layout.cpp diff --git a/lib/Dialect/TTIR/Transforms/Generic.cpp b/lib/Dialect/TTIR/Transforms/Generic.cpp index 005e12c07..3bf96f3cd 100644 --- a/lib/Dialect/TTIR/Transforms/Generic.cpp +++ b/lib/Dialect/TTIR/Transforms/Generic.cpp @@ -257,7 +257,7 @@ class TTIRGenericRegionRewriter auto resEncoding = mlir::cast(op->getResult(0).getType()).getEncoding(); if (resEncoding) { - auto resLayout = mlir::cast(resEncoding); + auto resLayout = mlir::cast(resEncoding); gridAttr = resLayout.getGrid(); } @@ -339,7 +339,7 @@ struct TTIRGenericOperandsToMemrefRewriter auto matchingOperand = generic.getMatchingOperand(blockArgNumber); auto operandType = matchingOperand.getType(); - auto bufferLayout = mlir::cast( + auto bufferLayout = mlir::cast( mlir::cast(operandType).getEncoding()); auto bufferType = operandType; @@ -349,7 +349,7 @@ struct TTIRGenericOperandsToMemrefRewriter assert(static_cast(cbIndex) < generic.getCbs().size()); auto cb = generic.getCbs()[cbIndex]; auto cbType = cb.getType(); - auto cbLayout = mlir::cast( + auto cbLayout = mlir::cast( mlir::cast(cbType).getEncoding()); bufferLayout = cbLayout; bufferType = cbType; @@ -387,7 +387,7 @@ class TTIRGenericRegionMemrefTypeConverter : public TypeConverter { if (mlir::isa(encoding)) { return type; } - auto layout = mlir::cast(type.getEncoding()); + auto layout = mlir::cast(type.getEncoding()); auto buffer = BufferAttr::get(ctx, layout.getMemref(), BufferAccess::Alias); return RankedTensorType::get(buffer.getShape(), type.getElementType(), @@ -451,11 +451,11 @@ class TTIRGenericOpCBsRewriter : public OpRewritePattern { // Enforcing tiled layout as in kernel we always want to work with tiles. auto desiredElementType = rewriter.getType(ty.getElementType()); - auto desiredLayout = rewriter.getAttr( + auto desiredLayout = rewriter.getAttr( ty, MemorySpace::DeviceL1, generic.getGrid(), desiredElementType); auto operandTy = operand.getType(); - auto operandLayout = mlir::cast( + auto operandLayout = mlir::cast( mlir::cast(operandTy).getEncoding()); if (desiredLayout.getGrid() == operandLayout.getGrid()) { diff --git a/lib/Dialect/TTIR/Transforms/Layout.cpp b/lib/Dialect/TTIR/Transforms/Layout.cpp index d7eef6732..c3ccbf1a4 100644 --- a/lib/Dialect/TTIR/Transforms/Layout.cpp +++ b/lib/Dialect/TTIR/Transforms/Layout.cpp @@ -38,20 +38,21 @@ class TTIRLayoutTensorTypeConverter : public TypeConverter { TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace, GridAttr deviceGrid) { addConversion([](Type type) { return type; }); - addConversion([ctx, initMemorySpace, - deviceGrid](RankedTensorType type) -> Type { - auto layout = type.getEncoding(); - if (layout) { - return type; - } - std::int64_t deviceGridRank = deviceGrid.getShape().size(); - // Default to single core grid - auto tensorGrid = GridAttr::get(ctx, deviceGridRank); - // Default to initMemorySpace, the optimizer might decide otherwise - auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); - return RankedTensorType::get(type.getShape(), type.getElementType(), - newLayout); - }); + addConversion( + [ctx, initMemorySpace, deviceGrid](RankedTensorType type) -> Type { + auto layout = type.getEncoding(); + if (layout) { + return type; + } + std::int64_t deviceGridRank = deviceGrid.getShape().size(); + // Default to single core grid + auto tensorGrid = GridAttr::get(ctx, deviceGridRank); + // Default to initMemorySpace, the optimizer might decide otherwise + auto newLayout = + MetalLayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); + return RankedTensorType::get(type.getShape(), type.getElementType(), + newLayout); + }); } }; @@ -129,7 +130,7 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, TensorMemoryLayout desiredMemLayout, bool tiled) { auto ty = mlir::cast(input.getType()); - auto currLayout = mlir::cast(ty.getEncoding()); + auto currLayout = mlir::cast(ty.getEncoding()); auto currMemorySpace = currLayout.getMemorySpace(); auto currElementType = currLayout.getElementType(); auto currMemLayout = currLayout.getMemLayout(); @@ -142,9 +143,9 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, return std::nullopt; } - auto desiredLayout = - rewriter.getAttr(ty, desiredMemorySpace, currLayout.getGrid(), - desiredElementType, desiredMemLayout); + auto desiredLayout = rewriter.getAttr( + ty, desiredMemorySpace, currLayout.getGrid(), desiredElementType, + desiredMemLayout); tensor::EmptyOp existingEmpty = input.getDefiningOp(); if (existingEmpty) { @@ -343,7 +344,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; Value createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, - LayoutAttr desiredLayout) const { + MetalLayoutAttr desiredLayout) const { auto ty = mlir::cast(input.getType()); auto output = rewriter.create( loc, ty.getShape(), ty.getElementType(), desiredLayout); @@ -353,7 +354,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { } Value bounce(PatternRewriter &rewriter, ToLayoutOp op, - LayoutAttr bounceLayout) const { + MetalLayoutAttr bounceLayout) const { auto bounced = createToLayoutOp(rewriter, op.getLoc(), op.getInput(), bounceLayout); return rewriter.replaceOpWithNewOp( @@ -375,8 +376,8 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { auto inputType = mlir::cast(op.getInput().getType()); auto outputType = mlir::cast(op.getOutput().getType()); - auto inputLayout = mlir::cast(inputType.getEncoding()); - auto outputLayout = mlir::cast(outputType.getEncoding()); + auto inputLayout = mlir::cast(inputType.getEncoding()); + auto outputLayout = mlir::cast(outputType.getEncoding()); bool inputL1 = inputLayout.getMemorySpace() == MemorySpace::DeviceL1; bool outputL1 = outputLayout.getMemorySpace() == MemorySpace::DeviceL1; diff --git a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp index 49baf51e0..7f78c1afc 100644 --- a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp +++ b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp @@ -17,7 +17,7 @@ namespace mlir::tt::ttmetal { ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -30,7 +30,7 @@ ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::LogicalResult HostReadOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -41,7 +41,7 @@ ::mlir::LogicalResult HostReadOp::verify() { } ::mlir::LogicalResult AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); @@ -76,7 +76,7 @@ ::mlir::LogicalResult AllocOp::verify() { ::mlir::LogicalResult DispatchOp::verify() { // Assert inputs/outputs device memspace for (auto operand : getOperands()) { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( mlir::cast(operand.getType()).getEncoding()); if (not layout) { return emitOpError("Input tensor missing layout attribute"); diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 1620e96b5..4b7804a5f 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -11,10 +11,12 @@ add_mlir_dialect_library(MLIRTTNNDialect DEPENDS MLIRTTNNOpsIncGen MLIRTTOpsIncGen + TTNNOpModelLib LINK_LIBS PUBLIC TTMLIRTTNNUtils MLIRSCFToEmitC MLIRLinalgDialect MLIRMLProgramDialect + TTNNOpModelLib ) diff --git a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp index 9079a6019..344a4a483 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp @@ -5,6 +5,9 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.cpp.inc" +#include "ttmlir/OpModel/TTNN/TTNNOpModel.h" + +#include #include namespace mlir::tt::ttnn { @@ -22,14 +25,16 @@ size_t ReluOp::getOpPerfCycles(const std::vector &input_layouts, std::tuple ReluOp::getOpL1Usage(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return std::make_tuple(1024, 2048, 1024); + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::getOpL1Usage(input_layouts[0], + output_layout); } bool ReluOp::isOpLegal(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return true; + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::isLegal(input_layouts[0], + output_layout); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 8550b8796..8e41368cb 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -42,7 +42,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ClampOp::verify() { const RankedTensorType outputTensorType = mlir::cast(outputs.front().getType()); - if (inputTensorType != outputTensorType) { + if (inputTensorType.getShape() != outputTensorType.getShape()) { return emitOpError("input and output must have same shape."); } @@ -140,6 +140,32 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() { + + if (getStep() == 0) { + return emitOpError("Step cannot be zero."); + } + + int64_t numValues = (getEnd() - getStart()) / getStep(); + + if (numValues <= 0) { + return emitOpError("Invalid range: start=") + << getStart() << ", end=" << getEnd() << ", step=" << getStep(); + } + + std::vector expectedShape = {1, 1, 1, numValues}; + if (getType().getShape().vec() != expectedShape) { + return emitOpError() << "Output tensor shape must be " << expectedShape + << ", but got " << getType().getShape(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // EmptyOp //===----------------------------------------------------------------------===// @@ -164,25 +190,12 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { // DataType and Layout // - mlir::MemRefType memref = layoutAttr.getMemref(); - Type elementType = memref.getElementType(); if (getLayout().has_value()) { - ttnn::Layout ttnnLayoutEnum; - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } + ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout(); assert(ttnnLayoutEnum == getLayoutAttr().getValue()); } if (getDtype().has_value()) { - tt::DataType dtype; - if (llvm::isa(elementType)) { - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); - } + tt::DataType dtype = layoutAttr.getDataType(); assert(dtype == getDtype()); } @@ -592,6 +605,158 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +// LinearOp verification +::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() { + ::mlir::RankedTensorType inputAType = getA().getType(); + ::mlir::RankedTensorType inputBType = getB().getType(); + std::optional<::mlir::RankedTensorType> biasType = + getBias() ? std::make_optional(getBias().getType()) : std::nullopt; + ::mlir::RankedTensorType outputType = getOutput().getType(); + + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector inputAShape(inputAType.getShape()); + llvm::SmallVector inputBShape(inputBType.getShape()); + + // Verify that the input A is at least 1D tensor. + if (inputAType.getRank() < 1) { + return emitOpError("Input A must be at least a 1D tensor"); + } + + // Verify that the input B is at least 1D tensor. + if (inputBType.getRank() < 1) { + return emitOpError("Input B must be at least a 1D tensor"); + } + + // If input A is a vector (1D tensor), 1 is prepended to its dimension for the + // purpose of the matrix multiplication. After the matrix multiplication, the + // prepended dimension is removed. + if (inputAType.getRank() == 1) { + inputAShape.insert(inputAShape.begin(), 1); + } + + // If input B is a vector (1D tensor), a 1 is appended to its dimension for + // the purpose of the matrix-vector product and removed afterwards. + if (inputBType.getRank() == 1) { + inputBShape.push_back(1); + } + + // Verify that the input A and input B has matching inner dimensions. + if (inputAShape[inputAShape.size() - 1] != + inputBShape[inputBShape.size() - 2]) { + return emitOpError( + "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + + ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + + ") must have matching inner dimensions"); + } + + llvm::SmallVector expectedOutputShape; + // Verify that the batch dimensions are broadcast compatible and construct the + // expected output shape. + if (inputAShape.size() > 2 || inputBShape.size() > 2) { + llvm::SmallVector inputABatchDims, inputBBatchDims; + + if (inputAShape.size() > 2) { + inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), + inputAShape.end() - 2); + } + + if (inputBShape.size() > 2) { + inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), + inputBShape.end() - 2); + } + + // Verify that the batch dimensions of input A and B are broadcast + // compatible. + llvm::SmallVector broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims, + broadcastedShape)) { + + return emitOpError("Batch dimensions of input A(" + + ttmlir::utils::join(inputABatchDims, ",") + + ") and B(" + + ttmlir::utils::join(inputBBatchDims, ",") + + ") are not broadcast compatible"); + } + + // Insert the broadcasted batch dimensions in the expected output shape. + expectedOutputShape.insert(expectedOutputShape.begin(), + broadcastedShape.begin(), + broadcastedShape.end()); + } + + // Insert the input A and B inner dimensions in expected output shape. + // Consider the case where input A and B are vectors. In that case, + // the dimension 1 is ommited from the output shape. + if (inputAType.getRank() > 1) { + expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]); + } + + if (inputBType.getRank() > 1) { + expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]); + } + + if (biasType) { + // Verify that the input bias is at least 1D tensor. + if (biasType.value().getRank() < 1) { + return emitOpError("Bias must be at least a 1D tensor"); + } + + llvm::SmallVector biasShape(biasType.value().getShape()); + + // Verify that the dimensions of the matmul of A and B are broadcast + // compatible with input bias. + llvm::SmallVector matmulShape = expectedOutputShape; + if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape, + expectedOutputShape)) { + return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") + + ") is not broadcast compatible with the matmul output " + "shape(" + + ttmlir::utils::join(matmulShape, ",") + ")"); + } + } + + // Check the case of a vector-vector product. At this moment we don't support + // scalars in IR, hence check that the output is at least 1D tensor of size 1. + if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// @@ -785,6 +950,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllGatherOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult AllGatherOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); int32_t dim = getDim(); @@ -796,6 +965,10 @@ ::mlir::LogicalResult AllGatherOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ReduceScatterOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult ReduceScatterOp::verify() { // TODO(gfengTT) return success(); diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index d80815f91..8aaae1261 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -34,6 +34,11 @@ bool TTNNLayoutAttr::isTiled() const { return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } +// Get layout of the tensor (RowMajor/Tile) +Layout TTNNLayoutAttr::getLayout() const { + return isTiled() ? Layout::Tile : Layout::RowMajor; +} + // Check if the tensor memory layout is sharded bool TTNNLayoutAttr::hasShardedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::HeightSharded || @@ -119,19 +124,19 @@ mlir::Type TTNNLayoutAttr::getElementType() const { return getMemref().getElementType(); } -// Extract data type from the memref. Example: -// memref<2x2xf32> -> f32 -// memref<2x2x!tt.tile<32x32xf32>> -> f32 -mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { +// Get scalar element type. +// Example: memref<2x2xf32> -> f32 +// Example: memref<2x2x!tt.tile<32x32xf32>> -> f32 +// +// return The scalar element type. +mlir::tt::DataType TTNNLayoutAttr::getDataType() const { Type elementType = getElementType(); - DataType dtype = DataType::Float32; - if (llvm::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); + return tileType.getDataType(); } - return dtype; + + return elementTypeToDataType(elementType); } // Gets the size of shard in bytes @@ -139,10 +144,10 @@ mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { // This function returns the size of the shard in bytes. // Size is calculated by multiplying shard shape with element size. // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. uint64_t TTNNLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); - if (mlir::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); return tileType.getSizeBytes(); } @@ -151,21 +156,31 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const { // Get shard shape // -// This function returns the shape of the shard. If element type is TileType -// and convertTileToScalar is true, then the shape is converted to scalar shape. -// Example: (convertToScalar = true) memref<2x2x!tt.tile<32x32xf32>> -> {64, 64} -// Example: (convertToScalar = false) memref<2x2x!tt.tile<32x32xf32>> -> {2, 2} -// Example: memref<128x128xf32> -> {128, 128} +// Return the shape of the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 } // -// /param convertTileToScalar If true, convert tile shape to scalar shape. -// /return The shape of the shard. -llvm::SmallVector -TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { +// return The shape of the shard. +llvm::SmallVector TTNNLayoutAttr::getShardShape() const { + return SmallVector(getMemref().getShape()); +} + +// Get scalar shard shape +// +// If the element type is TileType, this function returns the scalar shape of +// the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 64, 64 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 64, 96 } +// +// return The scalar shape of the shard. +llvm::SmallVector TTNNLayoutAttr::getScalarShardShape() const { SmallVector shardShape(getMemref().getShape()); - Type elementType = getElementType(); - if (mlir::isa(elementType) && convertTileToScalar) { - return mlir::cast(elementType).getScalarShape(shardShape); + if (isTiled()) { + return mlir::cast(getElementType()).getScalarShape(shardShape); } + return shardShape; } @@ -178,8 +193,8 @@ TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { // d2) and tile shape (32, 32) The result is (90, 10) which is then divided by // tile shape (32, 32) -> (3, 1) // -// /param tensorShape The shape of the tensor -// /return The size of the tensor in tiles. +// param tensorShape The shape of the tensor +// return The size of the tensor in tiles. llvm::SmallVector TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); @@ -214,10 +229,9 @@ TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { // Element size for TileType is tile width * tile height * sizeof(element). // For scalar types, element size is sizeof(element). // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { - MemRefType ty = getMemref(); - ArrayRef shape = ty.getShape(); + SmallVector shape = getShardShape(); uint64_t size = getElementSizeBytes(); return std::accumulate(shape.begin(), shape.end(), size, std::multiplies()); @@ -228,7 +242,7 @@ uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { // This function returns a new identity affine map // with the same number of dimensions as the linear map. // -// /return The new identity affine map. +// return The new identity affine map. mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); @@ -241,12 +255,11 @@ mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { // This function takes a physical memory map and replaces the symbols with the // shard shape // -// /param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] -// /return New memory map with symbols replaced with shard shape. +// param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] +// return New memory map with symbols replaced with shard shape. mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { - mlir::SmallVector shardShape = - getShardShape(false /*convertTileToScalar*/); + mlir::SmallVector shardShape = getShardShape(); assert(physicalMemoryMap.getNumSymbols() == shardShape.size() && "Physical memory map must have same number of symbols as logical " "shard rank"); @@ -289,11 +302,11 @@ int64_t TTNNLayoutAttr::getTensorSizeInBytes(ArrayRef tensorShape, // This function creates a new TTNNLayoutAttr with the given parameters. // The element type, buffer type and memory layout are preserved. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { @@ -307,10 +320,10 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // The shape of the tensor, buffer type, element type and memory layout are // preserved. // -// /param context The MLIR context. -// /param grid The grid where the tensor will be placed. -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param grid The grid where the tensor will be placed. +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { @@ -324,14 +337,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the element type with the given one. // -// /param context The MLIR context. -// /param elementType The new element type. -// /return The new TTNNLayoutAttr with the given element type. +// param context The MLIR context. +// param elementType The new element type. +// return The new TTNNLayoutAttr with the given element type. TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, Type elementType) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), elementType, getBufferType()), getMemLayout()); } @@ -341,14 +354,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory space with the given one. // -// /param context The MLIR context. -// /param memorySpace The new memory space. -// /return The new TTNNLayoutAttr with the given memory space. +// param context The MLIR context. +// param memorySpace The new memory space. +// return The new TTNNLayoutAttr with the given memory space. TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, BufferType memorySpace) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), getElementType(), memorySpace), getMemLayout()); } @@ -358,15 +371,15 @@ TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory layout with the given one. // -// /param context The MLIR context. -// /param memLayout The new memory layout. -// /return The new TTNNLayoutAttr with the given memory layout. +// param context The MLIR context. +// param memLayout The new memory layout. +// return The new TTNNLayoutAttr with the given memory layout. TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), buildMemRef( - context, getShardShape(), getElementType(), getBufferType()), + context, getScalarShardShape(), getElementType(), getBufferType()), memLayout); } @@ -375,9 +388,9 @@ TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces shard shape with the given one. // -// /param context The MLIR context. -// /param shardShape The new shard shape. -// /return The new TTNNLayoutAttr with the given shard shape. +// param context The MLIR context. +// param shardShape The new shard shape. +// return The new TTNNLayoutAttr with the given shard shape. TTNNLayoutAttr TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape) { @@ -392,14 +405,14 @@ TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, // // This function constructs a new TTNNLayoutAttr with the given parameters. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param elementType The type of the element i.e TileType/FloatType/IntegerType -// /param bufferType The type of the buffer -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /param memLayout The memory layout of the tensor -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param elementType The type of the element i.e TileType/FloatType/IntegerType +// param bufferType The type of the buffer +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// param memLayout The memory layout of the tensor +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, BufferType bufferType, GridAttr grid, diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 24980fb7c..3ade96bf8 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -107,9 +107,22 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTNNPipelineTTIRBroadcastFoldPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold()); +} + +void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index 05ff417a6..e5d2f86d8 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -276,7 +276,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { EmptyOp emptyOp = mlir::cast(op->getOperands().back().getDefiningOp()); - emptyOp.setDtype(layoutAttr.getDataTypeFromMemRef()); + emptyOp.setDtype(layoutAttr.getDataType()); if (layoutAttr.isTiled()) { emptyOp.setLayout(ttnn::Layout::Tile); } else { @@ -449,16 +449,17 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { BufferType outputBufferType = consumerOpOutputLayout.getBufferType(); TensorMemoryLayout outputTensorMemoryLayout = consumerOpOutputLayout.getMemLayout(); - MemRefType outputMemref = consumerOpOutputLayout.getMemref(); + llvm::SmallVector shardShape = + consumerOpOutputLayout.getShardShape(); MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get( consumerOp->getContext(), TensorMemoryLayoutAttr::get(consumerOp->getContext(), outputTensorMemoryLayout), BufferTypeAttr::get(consumerOp->getContext(), outputBufferType), - ShardSpecAttr::get(consumerOp->getContext(), - ShapeAttr::get(consumerOp->getContext(), - outputMemref.getShape()))); + ShardSpecAttr::get( + consumerOp->getContext(), + ShapeAttr::get(consumerOp->getContext(), shardShape))); // If producerOp is a toLayoutOp, adjust its output layout(update // inplace) to reflect consumerOp's output layout. If producerOp is not a @@ -472,10 +473,9 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { } else { OpBuilder builder(consumerOp); - DataTypeAttr outputDataType = - DataTypeAttr::get(consumerOp->getContext(), - utils::getDataTypeFromMemRef(outputMemref)); - Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref); + DataTypeAttr outputDataType = DataTypeAttr::get( + consumerOp->getContext(), consumerOpOutputLayout.getDataType()); + Layout outputLayoutEnum = consumerOpOutputLayout.getLayout(); LayoutAttr outputLayout = LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum); Operation *memoryReconfigOp = builder.create( diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 79bfeb404..e22540a7d 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -198,24 +198,12 @@ class TTNNDecomposeLayouts } }; - ttnn::Layout getLayoutFromMemRef(mlir::MemRefType memref) const { - ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - Type elementType = memref.getElementType(); - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } - return ttnnLayoutEnum; - } - std::pair getInputOutputLayouts(ttnn::ToLayoutOp op) const { LayoutInfo input, output; auto inputLayoutAttr = mlir::cast(op.getInput().getType().getEncoding()); - auto inputMemref = inputLayoutAttr.getMemref(); assert(op.getMemoryConfig().has_value()); MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value(); @@ -223,10 +211,10 @@ class TTNNDecomposeLayouts input.bufferType = inputLayoutAttr.getBufferType(); output.bufferType = outputMemoryConfig.getBufferType().getValue(); - input.layoutEnum = getLayoutFromMemRef(inputMemref); + input.layoutEnum = inputLayoutAttr.getLayout(); output.layoutEnum = op.getLayout(); - input.dataType = ttnn::utils::getDataTypeFromMemRef(inputMemref); + input.dataType = inputLayoutAttr.getDataType(); assert(op.getDtype().has_value()); output.dataType = op.getDtype().value(); @@ -234,7 +222,7 @@ class TTNNDecomposeLayouts output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout().getValue(); - input.shardShape = inputMemref.getShape(); + input.shardShape = inputLayoutAttr.getShardShape(); output.shardShape = outputMemoryConfig.getShardShapeArray(); return {input, output}; } diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index eebfdc13f..2d4a2ff8f 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -214,6 +214,28 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, .getResult(); } + // If the input tensor is an arange, we want to set the desired layout just + // like the other creation ops. However, a caveat is that in ttnn, arange is + // hardcoded to be ROW_MAJOR. So we must ensure that the layout we assign to + // it is ROW_MAJOR - and to make it tile layout we still must insert + // ToLayoutOp on its output. We can do this by setting the element type to + // ty.getElementType() in case desiredElementType is a TileType. + ttir::ArangeOp existingArange = input.getDefiningOp(); + if (existingArange) { + TTNNLayoutAttr arangeLayout = rewriter.getAttr( + ty.getShape(), ty.getElementType(), desiredBufferType, + tensorConfig.getGrid(), desiredMemLayout, g_defaultCollapseDims); + input = + rewriter + .replaceOpWithNewOp( + existingArange, + mlir::RankedTensorType::get(ty.getShape(), ty.getElementType(), + arangeLayout), + existingArange.getStart(), existingArange.getEnd(), + existingArange.getStep(), existingArange.getArangeDimension()) + .getResult(); + } + // If the input tensor is not a constant or empty tensor, we need to create a // new tensor with the desired layout which will be used as the output of the // ToLayoutOp @@ -281,6 +303,13 @@ class TTNNLayoutDPSOperandsRewriter continue; } + // If the operand is a BroadcastOp or a ToLayout op do not put a + // ToLayoutOp on its output + if (operand.get().getDefiningOp() || + operand.get().getDefiningOp()) { + continue; + } + // Read operand constrait for current operand OperandConstraint operandConstraint = mlir::cast( diff --git a/lib/Dialect/TTNN/Utils/CMakeLists.txt b/lib/Dialect/TTNN/Utils/CMakeLists.txt index f49f829e6..f78f41864 100644 --- a/lib/Dialect/TTNN/Utils/CMakeLists.txt +++ b/lib/Dialect/TTNN/Utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(TTMLIRTTNNUtils Utils.cpp OptimizerOverrides.cpp + PassOverrides.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index 5ef306cdb..bbc456948 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -6,187 +6,173 @@ namespace mlir::tt::ttnn { -bool OutputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, +void OptimizerOverridesHandler::setEnableOptimizer(bool value) { + enableOptimizer = value; +} + +void OptimizerOverridesHandler::setMemoryReconfig(bool value) { + enableMemoryReconfig = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysis(bool value) { + enableMemoryLayoutAnalysis = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy( + bool value) { + enableMemoryLayoutAnalysisPolicy = value; +} +void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy( + MemoryLayoutAnalysisPolicyType value) { + memoryLayoutAnalysisPolicy = value; +} + +void OptimizerOverridesHandler::setInputLayoutOverrides( + llvm::StringMap &value) { + inputLayoutOverrides = value; +} +void OptimizerOverridesHandler::setOutputLayoutOverrides( llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kMaxGridSize = 2; - constexpr size_t kvPairSize = 2; - constexpr size_t kMaxLayoutOverrideParams = 5; - constexpr size_t iOpName = 0; - constexpr size_t iLayoutOverrideParams = 1; - constexpr size_t iGrid = 0; - constexpr size_t iMemorySpace = 1; - constexpr size_t iTensorMemoryLayout = 2; - constexpr size_t iMemoryLayout = 3; - constexpr size_t iDataType = 4; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char paramSepataor = ':'; - constexpr char gridSeparator = 'x'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for override grid sizes: " + override); - return true; - } + outputLayoutOverrides = value; +} - SmallVector layoutParamParts; - // Split into layout parameters. - opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, - paramSepataor); - if (layoutParamParts.size() != kMaxLayoutOverrideParams) { - opt.error("Invalid number of layout parameters: " + - std::to_string(layoutParamParts.size())); - return true; - } +void OptimizerOverridesHandler::setSystemDescPath(std::string value) { + systemDescPath = value; +} +void OptimizerOverridesHandler::setMaxLegalLayouts(int64_t value) { + maxLegalLayouts = value; +} +void OptimizerOverridesHandler::setMeshShape(std::vector value) { + meshShape = value; +} - // Parse grid. - SmallVector grid; - SmallVector gridParts; - layoutParamParts[iGrid].split(gridParts, gridSeparator); - for (const StringRef gridPart : gridParts) { - int64_t gridValue; - if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { - opt.error("Invalid grid size: " + gridPart); - return true; - } - grid.push_back(gridValue); - } +bool OptimizerOverridesHandler::getEnableOptimizer() const { + return enableOptimizer; +} - // Parse memory space. - std::optional bufferType = - symbolizeBufferType(layoutParamParts[iMemorySpace]); - if (!bufferType.has_value()) { - opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); - return true; - } +bool OptimizerOverridesHandler::getMemoryReconfig() const { + return enableMemoryReconfig; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysis() const { + return enableMemoryLayoutAnalysis; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysisPolicy() const { + return enableMemoryLayoutAnalysisPolicy; +} +MemoryLayoutAnalysisPolicyType +OptimizerOverridesHandler::getMemoryLayoutAnalysisPolicy() const { + return memoryLayoutAnalysisPolicy; +} - // Parse tensor memory layout. - std::optional tensorMemoryLayout = - symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); - if (!tensorMemoryLayout.has_value()) { - opt.error("Invalid tensor memory layout: " + - layoutParamParts[iTensorMemoryLayout]); - return true; - } +std::string OptimizerOverridesHandler::getSystemDescPath() const { + return systemDescPath; +} +int64_t OptimizerOverridesHandler::getMaxLegalLayouts() const { + return maxLegalLayouts; +} +std::vector OptimizerOverridesHandler::getMeshShape() const { + return meshShape; +} - // Parse memory layout. - std::optional memoryLayout = - mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); - if (!memoryLayout.has_value()) { - opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); - return true; - } +llvm::StringMap +OptimizerOverridesHandler::getInputLayoutOverrides() const { + return inputLayoutOverrides; +} +llvm::StringMap +OptimizerOverridesHandler::getOutputLayoutOverrides() const { + return outputLayoutOverrides; +} - // Parse data type. - std::optional dataType = - mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); - if (!dataType.has_value()) { - opt.error("Invalid data type: " + layoutParamParts[iDataType]); - return true; - } +std::string OptimizerOverridesHandler::toString() const { - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ - std::move(grid), bufferType.value(), tensorMemoryLayout.value(), - memoryLayout.value(), dataType.value()}; + std::string options = ""; + + if (enableOptimizer) { + options += std::string(pipelineOptions.optimizerPassEnabled.getArgStr()) + + "=true "; } - return false; -} - -void OutputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "override-output-layout="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const OutputLayoutOverrideParams ¶ms = entry.getValue(); - // Print grid values - for (size_t i = 0; i < params.grid.size(); ++i) { - os << params.grid[i]; - if (i < params.grid.size() - 1) { - os << "x"; - } - } - // Print memory space and memory layout - os << ":" << mlir::tt::ttnn::stringifyBufferType(params.bufferType); - os << ":" - << mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout); - os << ":" << mlir::tt::ttnn::stringifyLayout(params.memoryLayout); - os << ":" << mlir::tt::DataTypeEnumToString(params.dataType); - if (++count < value.size()) { - os << ","; - } + + if (enableMemoryReconfig) { + options += + std::string(pipelineOptions.memReconfigEnabled.getArgStr()) + "=true "; } - os << "\n"; -} -bool InputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kvPairSize = 2; - constexpr size_t iOpName = 0; - constexpr size_t iOperands = 1; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char opParamSeparator = ':'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for input layouts override: " + override); - return true; - } + if (enableMemoryLayoutAnalysis) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisEnabled.getArgStr()) + + "=true "; + } - SmallVector operandIndexes; - SmallVector operandIndexParts; - - // Parse operand indexes. - opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); - for (const StringRef operandIndexPart : operandIndexParts) { - int64_t operandIndexValue; - if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { - opt.error("Invalid operand index: " + operandIndexPart); - return true; - } - operandIndexes.push_back(operandIndexValue); - } + if (enableMemoryLayoutAnalysisPolicy) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisPolicy.getArgStr()) + + MemoryLayoutAnalysisPolicyTypeParser::toString( + memoryLayoutAnalysisPolicy) + + " "; + } - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = - InputLayoutOverrideParams{std::move(operandIndexes)}; + // Create input layout overrides. + // Example: insert-memreconfig=input0=0:1,input1=0,input2=0:1:2 + if (inputLayoutOverrides.size() > 0) { + options += std::string(pipelineOptions.overrideInputLayout.getArgStr()) + + "=" + InputLayoutOverrideParser::toString(inputLayoutOverrides) + + " "; } - return false; -} - -void InputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "insert-memreconfig="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const InputLayoutOverrideParams ¶ms = entry.getValue(); - for (int64_t operandIdx : params.operandIdxes) { - os << operandIdx - << (operandIdx < static_cast(params.operandIdxes.size()) - 1 - ? ':' - : char()); - } - if (++count < value.size()) { - os << ","; + + // Create output layout overrides. + // Example: + // override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16 + // Example: + // override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" + if (outputLayoutOverrides.size() > 0) { + options += + std::string(pipelineOptions.overrideOutputLayout.getArgStr()) + "=" + + OutputLayoutOverrideParser::toString(outputLayoutOverrides) + " "; + } + + if (systemDescPath.size() > 0) { + options += std::string(pipelineOptions.systemDescPath.getArgStr()) + + systemDescPath + " "; + } + + if (maxLegalLayouts > 0) { + options += std::string(pipelineOptions.maxLegalLayouts.getArgStr()) + + std::to_string(maxLegalLayouts) + " "; + } + + if (meshShape.size() > 0) { + options += std::string(pipelineOptions.meshShape.getArgStr()) + "="; + for (int64_t meshShapeValue : meshShape) { + options += std::to_string(meshShapeValue) + ","; } + // Remove the last comma. + options.pop_back(); + } + + if (options[options.size() - 1] == ' ') { + options.pop_back(); } - os << "\n"; + + return options; +} + +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, InputLayoutOverrideParams params) { + inputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, SmallVector &operandIdxes) { + inputLayoutOverrides[opName] = + InputLayoutOverrideParams{std::move(operandIdxes)}; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, OutputLayoutOverrideParams params) { + outputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, SmallVector &grid, BufferType bufferType, + TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, + tt::DataType dataType) { + outputLayoutOverrides[opName] = OutputLayoutOverrideParams{ + std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType}; } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp new file mode 100644 index 000000000..9c8ef2be1 --- /dev/null +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" + +namespace mlir::tt::ttnn { + +bool OutputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kMaxGridSize = 2; + constexpr size_t kvPairSize = 2; + constexpr size_t kMaxLayoutOverrideParams = 5; + constexpr size_t iOpName = 0; + constexpr size_t iLayoutOverrideParams = 1; + constexpr size_t iGrid = 0; + constexpr size_t iMemorySpace = 1; + constexpr size_t iTensorMemoryLayout = 2; + constexpr size_t iMemoryLayout = 3; + constexpr size_t iDataType = 4; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char paramSepataor = ':'; + constexpr char gridSeparator = 'x'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for override grid sizes: " + override); + return true; + } + + SmallVector layoutParamParts; + // Split into layout parameters. + opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, + paramSepataor); + if (layoutParamParts.size() != kMaxLayoutOverrideParams) { + opt.error("Invalid number of layout parameters: " + + std::to_string(layoutParamParts.size())); + return true; + } + + // Parse grid. + SmallVector grid; + SmallVector gridParts; + layoutParamParts[iGrid].split(gridParts, gridSeparator); + for (const StringRef gridPart : gridParts) { + int64_t gridValue; + if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { + opt.error("Invalid grid size: " + gridPart); + return true; + } + grid.push_back(gridValue); + } + + // Parse memory space. + std::optional bufferType = + symbolizeBufferType(layoutParamParts[iMemorySpace]); + if (!bufferType.has_value()) { + opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); + return true; + } + + // Parse tensor memory layout. + std::optional tensorMemoryLayout = + symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); + if (!tensorMemoryLayout.has_value()) { + opt.error("Invalid tensor memory layout: " + + layoutParamParts[iTensorMemoryLayout]); + return true; + } + + // Parse memory layout. + std::optional memoryLayout = + mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); + if (!memoryLayout.has_value()) { + opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); + return true; + } + + // Parse data type. + std::optional dataType = + mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); + if (!dataType.has_value()) { + opt.error("Invalid data type: " + layoutParamParts[iDataType]); + return true; + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ + std::move(grid), bufferType.value(), tensorMemoryLayout.value(), + memoryLayout.value(), dataType.value()}; + } + return false; +} + +std::string OutputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const OutputLayoutOverrideParams ¶ms = entry.getValue(); + // Print grid values + for (size_t i = 0; i < params.grid.size(); ++i) { + res += std::to_string(params.grid[i]); + if (i < params.grid.size() - 1) { + res += "x"; + } + } + // Print memory space and memory layout + res += ":" + + std::string(mlir::tt::ttnn::stringifyBufferType(params.bufferType)); + res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout)); + res += + ":" + std::string(mlir::tt::ttnn::stringifyLayout(params.memoryLayout)); + res += ":" + std::string(mlir::tt::DataTypeEnumToString(params.dataType)); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void OutputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "override-output-layout="; + os << OutputLayoutOverrideParser::toString(value); + os << "\n"; +} + +bool InputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iOperands = 1; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char opParamSeparator = ':'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for input layouts override: " + override); + return true; + } + + SmallVector operandIndexes; + SmallVector operandIndexParts; + + // Parse operand indexes. + opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); + for (const StringRef operandIndexPart : operandIndexParts) { + int64_t operandIndexValue; + if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { + opt.error("Invalid operand index: " + operandIndexPart); + return true; + } + operandIndexes.push_back(operandIndexValue); + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = + InputLayoutOverrideParams{std::move(operandIndexes)}; + } + return false; +} + +std::string InputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const InputLayoutOverrideParams ¶ms = entry.getValue(); + for (int64_t operandIdx : params.operandIdxes) { + res += std::to_string(operandIdx) + ":"; + } + // Remove the last colon. + res.pop_back(); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void InputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "insert-memreconfig="; + os << InputLayoutOverrideParser::toString(value); + os << "\n"; +} + +} // namespace mlir::tt::ttnn diff --git a/lib/OpModel/CMakeLists.txt b/lib/OpModel/CMakeLists.txt new file mode 100644 index 000000000..9c34667d0 --- /dev/null +++ b/lib/OpModel/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TTNN) diff --git a/lib/OpModel/TTNN/CMakeLists.txt b/lib/OpModel/TTNN/CMakeLists.txt new file mode 100644 index 000000000..094b9f1dd --- /dev/null +++ b/lib/OpModel/TTNN/CMakeLists.txt @@ -0,0 +1,40 @@ +set(LIB_NAME TTNNOpModelLib) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(SOURCES + TTNNOpModelLib.cpp +) +add_library(${LIB_NAME} STATIC ${SOURCES}) + +message(STATUS "TTMLIR_ENABLE_OP_MODEL[${TTMLIR_ENABLE_OP_MODEL}]") +if (TTMLIR_ENABLE_OPMODEL) + # Link to tt-metal libs and include directories + target_include_directories(${LIB_NAME} PUBLIC "$") + target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY) + target_compile_definitions(${LIB_NAME} PUBLIC TTMLIR_ENABLE_OPMODEL) +else() + # link stubs implementation when op model library is disabled + message(WARNING "TTNNOpModelLib is disabled. The optimizer will not achieve optimal performance.") +endif() + +# Specify the include directories for the library +target_include_directories(${LIB_NAME} + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ + ${PROJECT_SOURCE_DIR}/include/ttmlir/OpModel/TTNN/) + + +# Add TTNNOpModelLib to the export set +install(TARGETS ${LIB_NAME} + EXPORT TTNNOpModelLibTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# Export the targets +export(EXPORT TTNNOpModelLibTargets + FILE "${CMAKE_CURRENT_BINARY_DIR}/TTNNOpModelLibTargets.cmake" + NAMESPACE TTNN::) diff --git a/lib/OpModel/TTNN/TTNNOpModelLib.cpp b/lib/OpModel/TTNN/TTNNOpModelLib.cpp new file mode 100644 index 000000000..87bfc0415 --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib.cpp @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "TTNNOpModel.h" + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "TTNNOpModelLib_Impl.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" + +#include +#include + +#include +#include +#endif // TTMLIR_ENABLE_OPMODEL + +namespace mlir::tt::op_model::ttnn { + +#ifdef TTMLIR_ENABLE_OPMODEL +// alias to a common tt_metal types +using DataType = ::tt::tt_metal::DataType; +using Layout = ::tt::tt_metal::Layout; +using CoreRange = ::tt::tt_metal::CoreRange; +using CoreRangeSet = ::tt::tt_metal::CoreRangeSet; +using CoreCoord = ::tt::tt_metal::CoreCoord; +using ShardSpec = ::tt::tt_metal::ShardSpec; +using ShardOrientation = ::tt::tt_metal::ShardOrientation; +using TensorMemoryLayout = ::tt::tt_metal::TensorMemoryLayout; +using MemoryConfig = ::tt::tt_metal::MemoryConfig; + +namespace detail { + +DataType getDataType(const mlir::MemRefType &memref) { + + auto dataType = elementTypeToDataType(memref.getElementType()); + + switch (dataType) { + case tt::DataType::Float32: + return DataType::FLOAT32; + case tt::DataType::BFloat16: + return DataType::BFLOAT16; + case tt::DataType::BFP_BFloat8: + return DataType::BFLOAT8_B; + case tt::DataType::BFP_BFloat4: + return DataType::BFLOAT4_B; + case tt::DataType::UInt32: + return DataType::UINT32; + case tt::DataType::UInt16: + return DataType::UINT16; + case tt::DataType::UInt8: + return DataType::UINT8; + default: + throw std::runtime_error("Invalid element type"); + } +} + +::ttnn::SimpleShape getTensorShape(const mlir::MemRefType &memref) { + ::tt::tt_metal::SmallVector small_vector_shape( + memref.getShape().begin(), memref.getShape().end()); + return ::ttnn::SimpleShape(small_vector_shape); +} + +const std::array +getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + const auto layoutShardTile = layout.getShardShape(); + + if (layoutShardTile.size() != 2) { + llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n"; + return {0, 0}; + } + + std::array shardShape; + shardShape[0] = layoutShardTile[0]; + shardShape[1] = layoutShardTile[1]; + return shardShape; +} + +Layout getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return layout.isTiled() ? Layout::TILE : Layout::ROW_MAJOR; +} + +CoreRangeSet getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // TODO(mbezulj): handle more complex grid shapes + // assuming grid shape is one rect starting at (0,0) + + const auto layoutGrid = layout.getGrid(); + + const auto layoutGridShape = layoutGrid.getShape(); + if (layoutGridShape.size() != 2) { + llvm::errs() << "ERROR: layout_grid.getShape().size() == 2\n"; + return {}; + } + + return CoreRangeSet(CoreRange(CoreCoord(0, layoutGridShape[0]), + CoreCoord(0, layoutGridShape[1]))); +} + +std::optional +layout_get_shard_spec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr; + // defaulting to ROW_MAJOR. TODO: figure out if we need to expose this + return isShardedMemoryLayout(layout.getMemLayout()) + ? std::make_optional(ShardSpec(getCoreRangeSet(layout), + getShardShape(layout), + ShardOrientation::ROW_MAJOR, false)) + : std::nullopt; +} + +::tt::tt_metal::BufferType getBufferType(const mlir::MemRefType &memref) { + auto memorySpace = + mlir::cast(memref.getMemorySpace()).getValue(); + + switch (memorySpace) { + case tt::MemorySpace::DeviceDRAM: + return ::tt::tt_metal::BufferType::DRAM; + case tt::MemorySpace::DeviceL1: + return ::tt::tt_metal::BufferType::L1; + default: // TODO(mbezulj): handle other memory spaces + throw std::runtime_error("Unsupported memory space"); + } +} + +::tt::tt_metal::TensorMemoryLayout +getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + auto tensorMemoryLayout = layout.getMemLayout(); + + switch (tensorMemoryLayout) { + case mlir::tt::ttnn::TensorMemoryLayout::Interleaved: + return ::tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + case mlir::tt::ttnn::TensorMemoryLayout::SingleBank: + return ::tt::tt_metal::TensorMemoryLayout::SINGLE_BANK; + case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: + return ::tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: + return ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: + return ::tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; + default: + throw std::runtime_error("Unsupported tensor memory layout"); + } +} + +::tt::tt_metal::MemoryConfig +getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + + auto tensorMemoryLayout = getTensorMemoryLayout(layout); + auto bufferType = getBufferType(layout.getMemref()); + + auto shardSpec = layout_get_shard_spec(layout); + return ::tt::tt_metal::MemoryConfig(tensorMemoryLayout, bufferType, + shardSpec); +} + +} // namespace detail +#endif // TTMLIR_ENABLE_OPMODEL + +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +bool ReluOpInterface::isLegal( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + +#ifdef TTMLIR_ENABLE_OPMODEL + return true; // to wire into tt-metal with the next uplift +#else + return true; +#endif // TTMLIR_ENABLE_OPMODEL +} + +std::tuple ReluOpInterface::getOpL1Usage( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { +#ifdef TTMLIR_ENABLE_OPMODEL + return std::make_tuple(0, 0, 0); // to wire into tt-metal with the next uplift +#else + return std::make_tuple(0, 0, 0); +#endif // TTMLIR_ENABLE_OPMODEL +} + +} // namespace mlir::tt::op_model::ttnn diff --git a/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h new file mode 100644 index 000000000..ed39d881a --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H + +// This header resolves tt-metal warnings that would otherwise be treated as +// errors in the MLIR build. Ensure that this is the only place where tt-metal +// headers are included. + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wctad-maybe-unsupported" +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wignored-qualifiers" +#pragma clang diagnostic ignored "-Wvla-extension" +#pragma clang diagnostic ignored "-Wcovered-switch-default" +#pragma clang diagnostic ignored "-Wsign-compare" +#pragma clang diagnostic ignored "-Wc++20-extensions" +#pragma clang diagnostic ignored "-Wc++20-designator" +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#pragma clang diagnostic ignored "-Wsuggest-override" +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wnested-anon-types" +#pragma clang diagnostic ignored "-Wreorder-ctor" +#pragma clang diagnostic ignored "-Wmismatched-tags" +#pragma clang diagnostic ignored "-Wunused-lambda-capture" +#pragma clang diagnostic ignored "-Wmissing-field-initializers" +#pragma clang diagnostic ignored "-Wunused-private-field" +#pragma clang diagnostic ignored "-Wimplicit-fallthrough" +#pragma clang diagnostic ignored "-Wstring-conversion" +#pragma clang diagnostic ignored "-Wunneeded-internal-declaration" +#pragma clang diagnostic ignored "-Wunused-local-typedef" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wpessimizing-move" +#pragma clang diagnostic ignored "-Wparentheses" +#pragma clang diagnostic ignored "-Wdeprecated-volatile" +#pragma clang diagnostic ignored "-Wdeprecated-this-capture" +#pragma clang diagnostic ignored "-Wc++23-extensions" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" +#pragma clang diagnostic ignored "-Wlogical-op-parentheses" +#pragma clang diagnostic ignored "-Wundefined-inline" +#pragma clang diagnostic ignored "-Wc99-extensions" +#pragma clang diagnostic ignored "-Wc++11-narrowing" +#pragma clang diagnostic ignored "-Wzero-length-array" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +#define FMT_HEADER_ONLY + +#include "tt_metal/common/core_coord.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" + +#pragma clang diagnostic pop + +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index 47e15accf..e82deaf63 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -62,18 +62,18 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref, toFlatbuffer(cache, memLayout), size); } -flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr layoutAttr, - ArrayRef logicalShape, DeviceAttr deviceAttr) { - auto strideInt64 = layoutAttr.getStride(logicalShape); +flatbuffers::Offset<::tt::target::LayoutDesc> metalLayoutAttrToFlatbuffer( + FlatbufferObjectCache &cache, MetalLayoutAttr metalLayoutAttr, + ArrayRef logicalShape, DeviceAttr deviceAttr) { + auto strideInt64 = metalLayoutAttr.getStride(logicalShape); std::vector stride(strideInt64.begin(), strideInt64.end()); - auto coreRangeSet = - toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getWorkerGrid()); + auto coreRangeSet = toFlatbuffer(cache, metalLayoutAttr.getGrid(), + deviceAttr.getWorkerGrid()); return ::tt::target::CreateLayoutDescDirect( - *cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()), + *cache.fbb, &stride, toFlatbuffer(cache, metalLayoutAttr.getOobVal()), &coreRangeSet, - cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer, - layoutAttr.getMemLayout())); + cache.getOrCreate(metalLayoutAttr.getMemref(), memrefAttrToFlatbuffer, + metalLayoutAttr.getMemLayout())); } } // namespace mlir::tt @@ -277,7 +277,7 @@ static std::shared_ptr translateModuleToFlatbuffer( argumentAllocations[input.getArgNumber()]); assert( argAlloc.getMemorySpace() == - mlir::cast( + mlir::cast( mlir::cast(input.getType()).getEncoding()) .getMemorySpace() && "argument allocation memory space does not match tensor type " diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 30b83014d..9706880e3 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Support/LogicalResult.h" +#include "types_generated.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -162,10 +163,10 @@ createDeviceRef(FlatbufferObjectCache &cache, Value device) { template ::flatbuffers::Offset<::tt::target::ttnn::Operation> createOperation(FlatbufferObjectCache &cache, ::flatbuffers::Offset op, - std::string const &debugString) { + std::string const &debugString, std::string const &locInfo) { return CreateOperationDirect( *cache.fbb, ::tt::target::ttnn::OpTypeTraits::enum_value, op.Union(), - debugString.c_str()); + debugString.c_str(), locInfo.c_str()); } ::flatbuffers::Offset<::tt::target::ttnn::GetDeviceOp> @@ -333,6 +334,46 @@ createOp(FlatbufferObjectCache &cache, FullOp op) { kHostAllocatedSize)); } +::flatbuffers::Offset<::tt::target::ttnn::ArangeOp> +createOp(FlatbufferObjectCache &cache, ArangeOp op) { + + std::optional<::tt::target::DataType> dtype = + op.getDtype().has_value() + ? std::make_optional(toFlatbuffer(cache, op.getDtype().value())) + : std::nullopt; + auto device = + op.getDevice() ? cache.at<::tt::target::DeviceRef>(op.getDevice()) : 0; + + auto memoryConfigDesc = op.getMemoryConfig().has_value() + ? cache.getOrCreate(op.getMemoryConfig().value(), + memoryConfigToFlatbuffer) + : 0; + + auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, + kHostAllocatedAddress, kHostAllocatedSize); + + return ::tt::target::ttnn::CreateArangeOp( + *cache.fbb, static_cast(op.getStart()), + static_cast(op.getEnd()), static_cast(op.getStep()), + dtype /* optional */, device /* optional */, + memoryConfigDesc /* optional */, output); +} + +::flatbuffers::Offset<::tt::target::ttnn::LinearOp> +createOp(FlatbufferObjectCache &cache, LinearOp op) { + auto in0 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA())); + auto in1 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB())); + auto bias = op.getODSOperands(2).empty() + ? flatbuffers::Offset<::tt::target::TensorRef>() + : cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getBias())); + auto output = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateLinearOp(*cache.fbb, in0, in1, bias, output); +} + // ANCHOR: adding_an_op_matmul_serialize_to_binary ::flatbuffers::Offset<::tt::target::ttnn::MatmulOp> createOp(FlatbufferObjectCache &cache, MatmulOp op) { @@ -485,6 +526,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Div; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Sigmoid; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Scatter; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Log1p; } else if constexpr (std::is_same_v) { @@ -554,7 +597,6 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) { dim_arg, op.getKeepDim()); } -template ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp> createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { auto in = @@ -567,7 +609,6 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1); } -template ::flatbuffers::Offset<::tt::target::ttnn::ConcatOp> createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins; @@ -582,7 +623,6 @@ createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { return ::tt::target::ttnn::CreateConcatOpDirect(*cache.fbb, &ins, out, dim); } -template ::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp> createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { auto in0 = @@ -594,7 +634,6 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); } -template ::flatbuffers::Offset<::tt::target::ttnn::ReshapeOp> createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { auto in = @@ -607,7 +646,6 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape); } -template ::flatbuffers::Offset<::tt::target::ttnn::SliceOp> createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { auto in = @@ -625,7 +663,6 @@ createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { step); } -template ::flatbuffers::Offset<::tt::target::ttnn::MaxPool2dOp> createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { auto in = @@ -643,7 +680,6 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { op.getPaddingWidth()); } -template ::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp> createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) { auto in = @@ -655,7 +691,6 @@ createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) { return ::tt::target::ttnn::CreateSoftmaxOp(*cache.fbb, in, out, dimension); } -template ::flatbuffers::Offset<::tt::target::ttnn::DeallocateOp> createDeallocateOp(FlatbufferObjectCache &cache, DeallocateOp op) { auto in = @@ -666,208 +701,263 @@ createDeallocateOp(FlatbufferObjectCache &cache, DeallocateOp op) { ::flatbuffers::Offset<::tt::target::ttnn::Operation> emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, - std::string const &debugString) { + std::string const &debugString, std::string const &locInfo) { if (auto getDeviceOp = dyn_cast(op); getDeviceOp) { - return createOperation(cache, createOp(cache, getDeviceOp), debugString); + return createOperation(cache, createOp(cache, getDeviceOp), debugString, + locInfo); } if (auto toMemoryConfigOp = dyn_cast(op); toMemoryConfigOp) { return createOperation(cache, createOp(cache, toMemoryConfigOp), - debugString); + debugString, locInfo); } if (auto toLayoutOp = dyn_cast(op); toLayoutOp) { - return createOperation(cache, createOp(cache, toLayoutOp), debugString); + return createOperation(cache, createOp(cache, toLayoutOp), debugString, + locInfo); } if (auto typecastOp = dyn_cast(op); typecastOp) { - return createOperation(cache, createOp(cache, typecastOp), debugString); + return createOperation(cache, createOp(cache, typecastOp), debugString, + locInfo); } if (auto toDeviceOp = dyn_cast(op); toDeviceOp) { - return createOperation(cache, createOp(cache, toDeviceOp), debugString); + return createOperation(cache, createOp(cache, toDeviceOp), debugString, + locInfo); } if (auto fromDeviceOp = dyn_cast(op); fromDeviceOp) { - return createOperation(cache, createOp(cache, fromDeviceOp), debugString); + return createOperation(cache, createOp(cache, fromDeviceOp), debugString, + locInfo); } if (auto emptyOp = dyn_cast(op); emptyOp) { - return createOperation(cache, createOp(cache, emptyOp), debugString); + return createOperation(cache, createOp(cache, emptyOp), debugString, + locInfo); } if (auto fullOp = dyn_cast(op); fullOp) { - return createOperation(cache, createOp(cache, fullOp), debugString); + return createOperation(cache, createOp(cache, fullOp), debugString, + locInfo); } if (auto absOp = dyn_cast(op); absOp) { - return createOperation(cache, createEltwiseOp(cache, absOp), debugString); + return createOperation(cache, createEltwiseOp(cache, absOp), debugString, + locInfo); } if (auto addOp = dyn_cast(op); addOp) { - return createOperation(cache, createEltwiseOp(cache, addOp), debugString); + return createOperation(cache, createEltwiseOp(cache, addOp), debugString, + locInfo); } if (auto floorOp = dyn_cast(op); floorOp) { - return createOperation(cache, createEltwiseOp(cache, floorOp), debugString); + return createOperation(cache, createEltwiseOp(cache, floorOp), debugString, + locInfo); } if (auto isFiniteOp = dyn_cast(op); isFiniteOp) { return createOperation(cache, createEltwiseOp(cache, isFiniteOp), - debugString); + debugString, locInfo); } if (auto andOp = dyn_cast(op); andOp) { - return createOperation(cache, createEltwiseOp(cache, andOp), debugString); + return createOperation(cache, createEltwiseOp(cache, andOp), debugString, + locInfo); } if (auto cbrtOp = dyn_cast(op); cbrtOp) { - return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString, + locInfo); } if (auto notOp = dyn_cast(op); notOp) { - return createOperation(cache, createEltwiseOp(cache, notOp), debugString); + return createOperation(cache, createEltwiseOp(cache, notOp), debugString, + locInfo); } if (auto orOp = dyn_cast(op); orOp) { - return createOperation(cache, createEltwiseOp(cache, orOp), debugString); + return createOperation(cache, createEltwiseOp(cache, orOp), debugString, + locInfo); } if (auto xorOp = dyn_cast(op); xorOp) { - return createOperation(cache, createEltwiseOp(cache, xorOp), debugString); + return createOperation(cache, createEltwiseOp(cache, xorOp), debugString, + locInfo); } if (auto multiplyOp = dyn_cast(op); multiplyOp) { return createOperation(cache, createEltwiseOp(cache, multiplyOp), - debugString); + debugString, locInfo); } if (auto negOp = dyn_cast(op); negOp) { - return createOperation(cache, createEltwiseOp(cache, negOp), debugString); + return createOperation(cache, createEltwiseOp(cache, negOp), debugString, + locInfo); } if (auto subtractOp = dyn_cast(op); subtractOp) { return createOperation(cache, createEltwiseOp(cache, subtractOp), - debugString); + debugString, locInfo); } if (auto eqOp = dyn_cast(op); eqOp) { - return createOperation(cache, createEltwiseOp(cache, eqOp), debugString); + return createOperation(cache, createEltwiseOp(cache, eqOp), debugString, + locInfo); } if (auto neOp = dyn_cast(op); neOp) { - return createOperation(cache, createEltwiseOp(cache, neOp), debugString); + return createOperation(cache, createEltwiseOp(cache, neOp), debugString, + locInfo); } if (auto geOp = dyn_cast(op); geOp) { - return createOperation(cache, createEltwiseOp(cache, geOp), debugString); + return createOperation(cache, createEltwiseOp(cache, geOp), debugString, + locInfo); } if (auto gtOp = dyn_cast(op); gtOp) { - return createOperation(cache, createEltwiseOp(cache, gtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, gtOp), debugString, + locInfo); } if (auto leOp = dyn_cast(op); leOp) { - return createOperation(cache, createEltwiseOp(cache, leOp), debugString); + return createOperation(cache, createEltwiseOp(cache, leOp), debugString, + locInfo); } if (auto ltOp = dyn_cast(op); ltOp) { - return createOperation(cache, createEltwiseOp(cache, ltOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ltOp), debugString, + locInfo); } if (auto maximumOp = dyn_cast(op); maximumOp) { return createOperation(cache, createEltwiseOp(cache, maximumOp), - debugString); + debugString, locInfo); } if (auto minimumOp = dyn_cast(op); minimumOp) { return createOperation(cache, createEltwiseOp(cache, minimumOp), - debugString); + debugString, locInfo); } if (auto reluOp = dyn_cast(op); reluOp) { - return createOperation(cache, createEltwiseOp(cache, reluOp), debugString); + return createOperation(cache, createEltwiseOp(cache, reluOp), debugString, + locInfo); } if (auto sqrtOp = dyn_cast(op); sqrtOp) { - return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString, + locInfo); } if (auto rsqrtOp = dyn_cast(op); rsqrtOp) { - return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString, + locInfo); } if (auto signOp = dyn_cast(op); signOp) { - return createOperation(cache, createEltwiseOp(cache, signOp), debugString); + return createOperation(cache, createEltwiseOp(cache, signOp), debugString, + locInfo); } if (auto expOp = dyn_cast(op); expOp) { - return createOperation(cache, createEltwiseOp(cache, expOp), debugString); + return createOperation(cache, createEltwiseOp(cache, expOp), debugString, + locInfo); } if (auto logOp = dyn_cast(op); logOp) { - return createOperation(cache, createEltwiseOp(cache, logOp), debugString); + return createOperation(cache, createEltwiseOp(cache, logOp), debugString, + locInfo); } if (auto expm1Op = dyn_cast(op); expm1Op) { - return createOperation(cache, createEltwiseOp(cache, expm1Op), debugString); + return createOperation(cache, createEltwiseOp(cache, expm1Op), debugString, + locInfo); } if (auto sigmoidOp = dyn_cast(op); sigmoidOp) { return createOperation(cache, createEltwiseOp(cache, sigmoidOp), - debugString); + debugString, locInfo); } if (auto log1pOp = dyn_cast(op); log1pOp) { - return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString); + return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString, + locInfo); + } + if (auto scatterOp = dyn_cast(op); scatterOp) { + return createOperation(cache, createEltwiseOp(cache, scatterOp), + debugString, locInfo); } if (auto reciprocalOp = dyn_cast(op); reciprocalOp) { return createOperation(cache, createEltwiseOp(cache, reciprocalOp), - debugString); + debugString, locInfo); } if (auto divOp = dyn_cast(op); divOp) { - return createOperation(cache, createEltwiseOp(cache, divOp), debugString); + return createOperation(cache, createEltwiseOp(cache, divOp), debugString, + locInfo); } if (auto remainderOp = dyn_cast(op); remainderOp) { return createOperation(cache, createEltwiseOp(cache, remainderOp), - debugString); + debugString, locInfo); } if (auto leakyReluOp = dyn_cast(op); leakyReluOp) { return createOperation(cache, createEltwiseOp(cache, leakyReluOp), - debugString); + debugString, locInfo); + } + if (auto linearOp = dyn_cast(op); linearOp) { + return createOperation(cache, createOp(cache, linearOp), debugString, + locInfo); } if (auto matmulOp = dyn_cast(op); matmulOp) { - return createOperation(cache, createOp(cache, matmulOp), debugString); + return createOperation(cache, createOp(cache, matmulOp), debugString, + locInfo); } if (auto sumOp = dyn_cast(op); sumOp) { - return createOperation(cache, createReductionOp(cache, sumOp), debugString); + return createOperation(cache, createReductionOp(cache, sumOp), debugString, + locInfo); } if (auto meanOp = dyn_cast(op); meanOp) { - return createOperation(cache, createReductionOp(cache, meanOp), - debugString); + return createOperation(cache, createReductionOp(cache, meanOp), debugString, + locInfo); } if (auto maxOp = dyn_cast(op); maxOp) { - return createOperation(cache, createReductionOp(cache, maxOp), debugString); + return createOperation(cache, createReductionOp(cache, maxOp), debugString, + locInfo); } if (auto embeddingOp = dyn_cast(op); embeddingOp) { return createOperation(cache, createEmbeddingOp(cache, embeddingOp), - debugString); + debugString, locInfo); } if (auto softmaxOp = dyn_cast(op); softmaxOp) { return createOperation(cache, createSoftmaxOp(cache, softmaxOp), - debugString); + debugString, locInfo); } if (auto transposeOp = dyn_cast(op); transposeOp) { return createOperation(cache, createTransposeOp(cache, transposeOp), - debugString); + debugString, locInfo); } if (auto clampOp = dyn_cast(op); clampOp) { return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp), - debugString); + debugString, locInfo); } if (auto conv2dOp = dyn_cast(op); conv2dOp) { - return createOperation(cache, createOp(cache, conv2dOp), debugString); + return createOperation(cache, createOp(cache, conv2dOp), debugString, + locInfo); } if (auto allGatherOp = dyn_cast(op); allGatherOp) { - return createOperation(cache, createOp(cache, allGatherOp), debugString); + return createOperation(cache, createOp(cache, allGatherOp), debugString, + locInfo); } if (auto concatOp = dyn_cast(op); concatOp) { - return createOperation(cache, createConcatOp(cache, concatOp), debugString); + return createOperation(cache, createConcatOp(cache, concatOp), debugString, + locInfo); } if (auto reshapeOp = dyn_cast(op); reshapeOp) { return createOperation(cache, createReshapeOp(cache, reshapeOp), - debugString); + debugString, locInfo); } if (auto sliceOp = dyn_cast(op); sliceOp) { - return createOperation(cache, createSliceOp(cache, sliceOp), debugString); + return createOperation(cache, createSliceOp(cache, sliceOp), debugString, + locInfo); } if (auto max_pool2dOp = dyn_cast(op); max_pool2dOp) { return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp), - debugString); + debugString, locInfo); } if (auto deallocateOp = dyn_cast(op); deallocateOp) { return createOperation(cache, createDeallocateOp(cache, deallocateOp), - debugString); + debugString, locInfo); } if (auto ceilOp = dyn_cast(op); ceilOp) { - return createOperation(cache, createEltwiseOp(cache, ceilOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ceilOp), debugString, + locInfo); } if (auto cosOp = dyn_cast(op); cosOp) { - return createOperation(cache, createEltwiseOp(cache, cosOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cosOp), debugString, + locInfo); } if (auto sinOp = dyn_cast(op); sinOp) { - return createOperation(cache, createEltwiseOp(cache, sinOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sinOp), debugString, + locInfo); } if (auto whereOp = dyn_cast(op); whereOp) { - return createOperation(cache, createEltwiseOp(cache, whereOp), debugString); + return createOperation(cache, createEltwiseOp(cache, whereOp), debugString, + locInfo); } if (auto geluOp = dyn_cast(op); geluOp) { - return createOperation(cache, createEltwiseOp(cache, geluOp), debugString); + return createOperation(cache, createEltwiseOp(cache, geluOp), debugString, + locInfo); + } + if (auto arangeOp = dyn_cast(op); arangeOp) { + return createOperation(cache, createOp(cache, arangeOp), debugString, + locInfo); } llvm_unreachable("unhandled op in emitTTNNOperation"); diff --git a/python/TTModule.cpp b/python/TTModule.cpp index f631b0116..b8d543410 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -16,14 +16,14 @@ namespace mlir::ttmlir::python { void populateTTModule(py::module &m) { - tt_attribute_class(m, "LayoutAttr") + tt_attribute_class(m, "MetalLayoutAttr") .def_static("get", [](MlirContext ctx, MlirType rankedTensorType, uint32_t memorySpaceValue, MlirAttribute grid, std::vector> collapseIntervals, uint32_t oobValValue, uint32_t memLayoutValue) { - return wrap(tt::LayoutAttr::get( + return wrap(tt::MetalLayoutAttr::get( unwrap(ctx), mlir::cast(unwrap(rankedTensorType)), static_cast(memorySpaceValue), @@ -37,7 +37,7 @@ void populateTTModule(py::module &m) { std::vector> collapseIntervals) { return wrap( - mlir::cast(unwrap(self)) + mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals)); @@ -47,7 +47,7 @@ void populateTTModule(py::module &m) { std::vector tensorShape, MlirAttribute grid, std::vector> collapseIntervals) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals); @@ -55,13 +55,13 @@ void populateTTModule(py::module &m) { .def_static( "with_element_type", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return wrap(mlir::cast(unwrap(self)) + return wrap(mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType))); }) .def_static( "with_element_type_", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType)); }) .def("getLayout", @@ -73,35 +73,45 @@ void populateTTModule(py::module &m) { mlir::cast(unwrap(type)); assert(tensor.getEncoding()); // Make sure that this Tensor has an // encoding value - tt::LayoutAttr layout = - mlir::cast(tensor.getEncoding()); + tt::MetalLayoutAttr layout = + mlir::cast(tensor.getEncoding()); return layout; }) - .def("wrapped", [](tt::LayoutAttr const &self) { return wrap(self); }) - .def_property_readonly( - "stride", - [](tt::LayoutAttr const &self, std::vector logicalShape) { - auto stride = self.getStride(logicalShape); - return std::vector(stride.begin(), stride.end()); - }) - .def_property_readonly("oobval", &tt::LayoutAttr::getOobVal) + .def("wrapped", + [](tt::MetalLayoutAttr const &self) { return wrap(self); }) + .def_property_readonly("stride", + [](tt::MetalLayoutAttr const &self, + std::vector logicalShape) { + auto stride = self.getStride(logicalShape); + return std::vector(stride.begin(), + stride.end()); + }) + .def_property_readonly("oobval", &tt::MetalLayoutAttr::getOobVal) .def_property_readonly("oobval_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast(la.getOobVal()); }) - .def_property_readonly("grid_attr", &tt::LayoutAttr::getGrid) - .def_property_readonly("memref", &tt::LayoutAttr::getMemref) - .def_property_readonly("memory_space", &tt::LayoutAttr::getMemorySpace) + .def_property_readonly("grid_attr", &tt::MetalLayoutAttr::getGrid) + .def_property_readonly( + "memref", + [](tt::MetalLayoutAttr self) { return wrap(self.getMemref()); }) + .def_property_readonly("memory_space", + &tt::MetalLayoutAttr::getMemorySpace) .def_property_readonly("memory_space_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast( la.getMemorySpace()); }) - .def_property_readonly("shard_shape", &tt::LayoutAttr::getShardShape) - .def_property_readonly("memory_layout", &tt::LayoutAttr::getMemLayout) - .def_property_readonly("memory_layout_as_int", [](tt::LayoutAttr la) { - return static_cast(la.getMemLayout()); - }); + .def_property_readonly("shard_shape", &tt::MetalLayoutAttr::getShardShape) + .def_property_readonly("memory_layout", + &tt::MetalLayoutAttr::getMemLayout) + .def_property_readonly( + "linear", + [](tt::MetalLayoutAttr self) { return wrap(self.getLinear()); }) + .def_property_readonly("memory_layout_as_int", + [](tt::MetalLayoutAttr la) { + return static_cast(la.getMemLayout()); + }); tt_attribute_class(m, "GridAttr") .def_static("get", @@ -236,6 +246,14 @@ void populateTTModule(py::module &m) { return self.getEthInactive().vec(); }); + tt_attribute_class(m, "CoreCoordAttr") + .def_static("get", + [](MlirContext ctx, int64_t y, int64_t x) { + return wrap(tt::CoreCoordAttr::get(unwrap(ctx), y, x)); + }) + .def_property_readonly("y", &tt::CoreCoordAttr::getY) + .def_property_readonly("x", &tt::CoreCoordAttr::getX); + tt_attribute_class(m, "ChipCoordAttr") .def_static("get", [](MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, @@ -276,29 +294,29 @@ void populateTTModule(py::module &m) { }) .def_static( "get", - [](MlirContext ctx, std::vector cpuDescs, - std::vector chipDescs, - std::vector chipDescIndices, - std::vector chipCapabilities, - std::vector chipCoords, - std::vector chipChannels) { + [](MlirContext ctx, const std::vector &cpuDescs, + const std::vector &chipDescs, + const std::vector &chipDescIndices, + const std::vector &chipCapabilities, + const std::vector &chipCoords, + const std::vector &chipChannels) { std::vector chipDescsUnwrapped; - for (auto chipDesc : chipDescs) { + for (const auto &chipDesc : chipDescs) { chipDescsUnwrapped.push_back( mlir::cast(unwrap(chipDesc))); } std::vector chipCapabilitiesUnwrapped; - for (auto chipCapability : chipCapabilities) { + for (const auto &chipCapability : chipCapabilities) { chipCapabilitiesUnwrapped.push_back( mlir::cast(unwrap(chipCapability))); } std::vector chipCoordsUnwrapped; - for (auto chipCoord : chipCoords) { + for (const auto &chipCoord : chipCoords) { chipCoordsUnwrapped.push_back( mlir::cast(unwrap(chipCoord))); } std::vector chipChannelsUnwrapped; - for (auto chipChannel : chipChannels) { + for (const auto &chipChannel : chipChannels) { chipChannelsUnwrapped.push_back( mlir::cast(unwrap(chipChannel))); } @@ -430,8 +448,11 @@ void populateTTModule(py::module &m) { return mlir::cast(unwrap(self)); }) .def_property_readonly("grid_attr", &tt::DeviceAttr::getWorkerGrid) - .def_property_readonly("l1_map", &tt::DeviceAttr::getL1Map) - .def_property_readonly("dram_map", &tt::DeviceAttr::getDramMap) + .def_property_readonly( + "l1_map", [](tt::DeviceAttr self) { return wrap(self.getL1Map()); }) + .def_property_readonly( + "dram_map", + [](tt::DeviceAttr self) { return wrap(self.getDramMap()); }) .def_property_readonly( "mesh_shape", [](tt::DeviceAttr const &self) { return self.getMeshShape().vec(); }) @@ -447,7 +468,10 @@ void populateTTModule(py::module &m) { unwrap(ctx), SmallVector{height, width}, static_cast(dataType))); }) - .def_property_readonly("data_type", &tt::TileType::getDataType) + .def_property_readonly("data_type_as_int", + [](tt::TileType self) { + return static_cast(self.getDataType()); + }) .def_property_readonly("shape", [](tt::TileType const &tile) { return std::vector({tile.getHeight(), tile.getWidth()}); }); diff --git a/python/TTNNModule.cpp b/python/TTNNModule.cpp index 24bd05c8f..11e47982d 100644 --- a/python/TTNNModule.cpp +++ b/python/TTNNModule.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "mlir/CAPI/AffineMap.h" #include "ttmlir/Bindings/Python/TTMLIRModule.h" namespace mlir::ttmlir::python { @@ -127,5 +128,27 @@ void populateTTNNModule(py::module &m) { }) .def_property_readonly("y", &tt::ttnn::MeshShapeAttr::getY) .def_property_readonly("x", &tt::ttnn::MeshShapeAttr::getX); + + tt_attribute_class(m, "TTNNLayoutAttr") + .def_static("get", + [](MlirContext ctx, MlirAffineMap linear, MlirAttribute grid, + MlirType memref, unsigned memLayout) { + return wrap(tt::ttnn::TTNNLayoutAttr::get( + unwrap(ctx), mlir::cast(unwrap(linear)), + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), + static_cast(memLayout))); + }) + .def_property_readonly( + "linear", + [](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getLinear()); }) + .def_property_readonly("grid_attr", &tt::ttnn::TTNNLayoutAttr::getGrid) + .def_property_readonly( + "memref", + [](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); }) + .def_property_readonly( + "memory_layout_as_int", [](tt::ttnn::TTNNLayoutAttr self) { + return static_cast(self.getMemLayout()); + }); } } // namespace mlir::ttmlir::python diff --git a/python/ttmlir/dialects/ttnn.py b/python/ttmlir/dialects/ttnn.py index d81f58111..659938cf6 100644 --- a/python/ttmlir/dialects/ttnn.py +++ b/python/ttmlir/dialects/ttnn.py @@ -3,4 +3,5 @@ # SPDX-License-Identifier: Apache-2.0 from ._ttnn_ops_gen import * +from ._ttnn_enum_gen import * from .._mlir_libs._ttmlir import register_dialect, ttnn_ir as ir diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 7a68a7e94..5544e1d70 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -47,6 +47,8 @@ void wait(Event event); std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 6c55ac1de..67aa91a71 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -83,6 +83,8 @@ void wait(Event event); std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 1dc721f66..e4348da60 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -71,6 +71,8 @@ void wait(Event event); std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 586b8394e..a57ac3fcd 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -261,6 +261,21 @@ std::string getOpDebugString(OpContext opContextHandle) { throw std::runtime_error("runtime is not enabled"); } +std::string getOpLocInfo(OpContext opContextHandle) { +#ifdef TT_RUNTIME_ENABLE_TTNN + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getOpLocInfo(opContextHandle); + } +#endif + +#ifdef TT_RUNTIME_ENABLE_TTMETAL + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getOpLocInfo(opContextHandle); + } +#endif + throw std::runtime_error("runtime is not enabled"); +} + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle) { #if defined(TT_RUNTIME_ENABLE_TTNN) diff --git a/runtime/lib/ttmetal/CMakeLists.txt b/runtime/lib/ttmetal/CMakeLists.txt index 3706d7433..f31fad253 100644 --- a/runtime/lib/ttmetal/CMakeLists.txt +++ b/runtime/lib/ttmetal/CMakeLists.txt @@ -10,7 +10,7 @@ target_include_directories(TTRuntimeTTMetal PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) target_include_directories(TTRuntimeTTMetal SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY) -add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY tt-metal FBS_GENERATION) +target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY DEVICE_LIBRARY) +add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY DEVICE_LIBRARY tt-metal FBS_GENERATION) # Optionally compile profiling code and link tracy client for perf profiling. diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index ab343554e..22d43ba36 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -262,6 +262,12 @@ std::string getOpDebugString(OpContext opContextHandle) { return ""; } +std::string getOpLocInfo(OpContext opContextHandle) { + // Not implemented + LOG_WARNING("obtaining op location info for metal runtime not implemented"); + return ""; +} + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle) { // Not implemented diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h index ca50ad58b..75b22d114 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h @@ -6,6 +6,7 @@ #define TT_RUNTIME_TTNN_UTILS_H #include "flatbuffers/vector.h" +#include "tt_metal/impl/buffers/buffer.hpp" #include "ttmlir/Target/Common/types_generated.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttnn/types.hpp" diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index 4edc4780b..38115803f 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -5,6 +5,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp diff --git a/runtime/lib/ttnn/operations/creation/arange.cpp b/runtime/lib/ttnn/operations/creation/arange.cpp new file mode 100644 index 000000000..446cdf72a --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "arange.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include +#include +#include + +namespace tt::runtime::ttnn::operations::creation { +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + ::ttnn::DataType dtype = + ::ttnn::DataType::BFLOAT16; // Default in arange implementation + std::optional> device = std::nullopt; + ::ttnn::MemoryConfig memoryConfig = + ::ttnn::DRAM_MEMORY_CONFIG; // Default in arange implementation + + if (op->dtype()) { + dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(*(op->dtype())); + } + + if (op->memcfg()) { + memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out()); + } + + if (op->device()) { + // ttnn::arange supports no device (host) and single device + DeviceVariant targetDevice = + context.getTargetDevice(op->device()->global_id()); + + LOG_ASSERT(std::holds_alternative>( + targetDevice), + "ttnn::arange does not support MeshDevice."); + device = std::make_optional( + std::get>(targetDevice)); + } + ::ttnn::Tensor out = ::ttnn::arange(op->start(), op->end(), op->step(), dtype, + device, memoryConfig); + + utils::updateTensorPool(tensorPool, out, op->out()->global_id()); +} +} // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/arange.h b/runtime/lib/ttnn/operations/creation/arange.h new file mode 100644 index 000000000..157ee2dc6 --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::creation { + +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::creation + +#endif diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index 2a05d6246..5c1d056f9 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -41,6 +41,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::remainder); break; } + case ::tt::target::ttnn::EltwiseOpType::Scatter: { + runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::scatter); + break; + } default: LOG_FATAL("Unsupported Eltwise Binary Composite operation"); } diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h index 9be8bc6b7..bd497fe98 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h @@ -15,6 +15,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { case ::tt::target::ttnn::EltwiseOpType::Maximum: case ::tt::target::ttnn::EltwiseOpType::Minimum: case ::tt::target::ttnn::EltwiseOpType::Remainder: + case ::tt::target::ttnn::EltwiseOpType::Scatter: return true; default: return false; diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp index a54777ab2..f97f71e40 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp @@ -7,6 +7,15 @@ namespace tt::runtime::ttnn::operations::binary { +bool shouldSwapBinaryOperands(const ::tt::target::ttnn::EltwiseOp *op, + ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) { + // For scatter, we expect the left-hand side operator to be lesser or equal in + // volume to the right hand side, so we omit the swap. + return (op->type() != ::tt::target::ttnn::EltwiseOpType::Scatter && + workaround::Env::get().swapBinaryOperands && + (*lhs)->volume() < (*rhs)->volume()); +} + void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, ::ttnn::Tensor **lhs, @@ -21,8 +30,7 @@ void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, // TODO(bug #1124): We're currently swapping the operands for binary ops // in runtime if the lhs operand is smaller (and requires broadcast onto the // rhs operand). We should add this check in the compiler. - if (workaround::Env::get().swapBinaryOperands && - (*lhs)->volume() < (*rhs)->volume()) { + if (shouldSwapBinaryOperands(op, lhs, rhs)) { std::swap(*lhs, *rhs); } } diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index 435607b87..c595fe26b 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -125,7 +125,10 @@ createMemoryConfig(const ::tt::target::TensorRef *tensorRef) { ::tt::tt_metal::BufferType ttnnBufferType = ::tt::runtime::ttnn::utils::toTTNNBufferType(targetMemorySpace); - return {ttnnMemLayout, ttnnBufferType, shardSpec}; + return {ttnnMemLayout, ttnnBufferType, + ttnnMemLayout == tt_metal::TensorMemoryLayout::INTERLEAVED + ? std::nullopt + : std::make_optional(shardSpec)}; } // Prefer to use this method over the one above @@ -169,8 +172,11 @@ createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, ttnnCoreRangeSet, ttnnShardShape, ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); - ::ttnn::MemoryConfig memoryConfig = {tensorMemoryLayout, bufferType, - shardSpec}; + ::ttnn::MemoryConfig memoryConfig = { + tensorMemoryLayout, bufferType, + tensorMemoryLayout == tt_metal::TensorMemoryLayout::INTERLEAVED + ? std::nullopt + : std::make_optional(shardSpec)}; return memoryConfig; } diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index abe71f970..a25102d9a 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -8,8 +8,8 @@ #include "tt/runtime/ttnn/operations/utils.h" #include -// ANCHOR: adding_an_op_matmul_runtime_operations namespace tt::runtime::ttnn::operations::matmul { +// ANCHOR: adding_an_op_matmul_runtime_operations void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); @@ -20,10 +20,6 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { ::tt::tt_metal::MemoryConfig outputMemoryConfig = utils::createMemoryConfig(op->out()); - std::optional< - ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig> - programConfig = std::nullopt; - const std::optional memoryConfig = std::make_optional(outputMemoryConfig); @@ -37,5 +33,35 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { tensorPool.insert_or_assign(op->out()->global_id(), out); } -} // namespace tt::runtime::ttnn::operations::matmul // ANCHOR_END: adding_an_op_matmul_runtime_operations + +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); + const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id()); + std::optional<::ttnn::Tensor> bias = + op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id())) + : std::nullopt; + + DEBUG_ASSERT(lhs.is_allocated()); + DEBUG_ASSERT(rhs.is_allocated()); + DEBUG_ASSERT(!bias || bias->is_allocated()); + + ::ttnn::DataType outputDataType = utils::getDataType(op->out()); + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + utils::createMemoryConfig(op->out()); + + const std::optional memoryConfig = + std::make_optional(outputMemoryConfig); + + const std::optional dtype = + std::make_optional(outputDataType); + + ::ttnn::Tensor out = ::ttnn::linear( + lhs, rhs, bias, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, + dtype, /*programConfig*/ std::nullopt, /*activation*/ std::nullopt, + /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt); + + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::matmul diff --git a/runtime/lib/ttnn/operations/matmul/matmul.h b/runtime/lib/ttnn/operations/matmul/matmul.h index 5957a54a3..7b0583786 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.h +++ b/runtime/lib/ttnn/operations/matmul/matmul.h @@ -10,6 +10,7 @@ namespace tt::runtime::ttnn::operations::matmul { void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context); +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context); } // namespace tt::runtime::ttnn::operations::matmul #endif diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index dfd8b9375..4fc6fca87 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -31,11 +31,14 @@ preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, op->dilation_width() * (op->kernel_width() - 1) - 1) / op->stride_width(); + constexpr bool en_ch_padding = false; + auto parallel_config = ::ttnn::operations::conv::conv2d::determine_parallel_config( ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(), op->channels(), output_height, output_width, op->channels(), - device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR); + device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR, + en_ch_padding); auto sharded_memory_config = ::ttnn::operations::conv::conv2d:: create_sharded_memory_config_from_parallel_config(inputShape, parallel_config, 1); diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 8cfa01389..3aab3a94c 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -4,6 +4,7 @@ #include "operations/ccl/all_gather.h" #include "operations/context/get_device.h" #include "operations/conv/conv2d.h" +#include "operations/creation/arange.h" #include "operations/creation/empty.h" #include "operations/creation/full.h" #include "operations/data_movement/concat.h" @@ -32,9 +33,19 @@ #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" +#ifdef TT_RUNTIME_ENABLE_PERF_TRACE +#include "tracy/Tracy.hpp" +#endif + namespace tt::runtime::ttnn { using LogType = ::tt::runtime::logger::LogType; +void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) { +#ifdef TT_RUNTIME_ENABLE_PERF_TRACE + TracyMessage(op->loc_info()->c_str(), op->loc_info()->size()); +#endif +} + static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); @@ -73,6 +84,7 @@ class ProgramExecutor { for (const ::tt::target::ttnn::Operation *op : *program->operations()) { LOG_DEBUG(LogType::LogRuntimeTTNN, "Executing operation: ", op->debug_info()->c_str()); + tracyLogOpLocation(op); runOperation(op); runCallback(executableHandle, op, &context); } @@ -148,6 +160,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::EltwiseOp: { return runEltwiseOperation(op->type_as_EltwiseOp()); } + case ::tt::target::ttnn::OpType::LinearOp: { + return operations::matmul::run(op->type_as_LinearOp(), context); + } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: { return operations::matmul::run(op->type_as_MatmulOp(), context); @@ -186,6 +201,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::AllGatherOp: { return operations::ccl::run(op->type_as_AllGatherOp(), context); } + case ::tt::target::ttnn::OpType::ArangeOp: { + return operations::creation::run(op->type_as_ArangeOp(), context); + } default: { LOG_FATAL("Unsupported operation type"); } diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 86fd2d25c..2dfc07788 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -202,6 +202,12 @@ std::string getOpDebugString(OpContext opContextHandle) { return std::string(opContext.debug_info()->c_str()); } +std::string getOpLocInfo(OpContext opContextHandle) { + auto const &opContext = + opContextHandle.as<::tt::target::ttnn::Operation>(DeviceRuntime::TTNN); + return std::string(opContext.loc_info()->c_str()); +} + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle) { auto const &programContext = diff --git a/runtime/test/CMakeLists.txt b/runtime/test/CMakeLists.txt index 8a0d12ee3..e4a7adc40 100644 --- a/runtime/test/CMakeLists.txt +++ b/runtime/test/CMakeLists.txt @@ -37,6 +37,7 @@ target_include_directories(TTRuntimeTEST INTERFACE target_link_libraries(TTRuntimeTEST INTERFACE TTMETAL_LIBRARY + DEVICE_LIBRARY TTBinary TTRuntime TTRuntimeTTNN diff --git a/runtime/tools/python/CMakeLists.txt b/runtime/tools/python/CMakeLists.txt index a4c7a5191..353ebbe7d 100644 --- a/runtime/tools/python/CMakeLists.txt +++ b/runtime/tools/python/CMakeLists.txt @@ -12,6 +12,7 @@ add_custom_target(ttrt TT_RUNTIME_ENABLE_PERF_TRACE=${TT_RUNTIME_ENABLE_PERF_TRACE} TT_RUNTIME_DEBUG=${TT_RUNTIME_DEBUG} TT_RUNTIME_WORKAROUNDS=${TT_RUNTIME_WORKAROUNDS} + TTMLIR_BINARY_DIR=${TTMLIR_BINARY_DIR} TTMLIR_VERSION_MAJOR=${TTMLIR_VERSION_MAJOR} TTMLIR_VERSION_MINOR=${TTMLIR_VERSION_MINOR} TTMLIR_VERSION_PATCH=${TTMLIR_VERSION_PATCH} diff --git a/runtime/tools/python/setup.py b/runtime/tools/python/setup.py index ddbe3da9f..f5d148578 100644 --- a/runtime/tools/python/setup.py +++ b/runtime/tools/python/setup.py @@ -18,6 +18,11 @@ "SOURCE_ROOT", os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", ".."), ) +# Use 'src_dir/build' as default location if TTMLIR_BINARY_DIR env variable is not available. +ttmlir_build_dir = os.environ.get( + "TTMLIR_BINARY_DIR", + os.path.join(src_dir, "build"), +) toolchain = os.environ.get("TTMLIR_TOOLCHAIN_DIR", "/opt/ttmlir-toolchain") metaldir = f"{src_dir}/third_party/tt-metal/src/tt-metal-build" ttmetalhome = os.environ.get("TT_METAL_HOME", "") @@ -37,12 +42,12 @@ include_dirs=[ f"{toolchain}/include", f"{src_dir}/runtime/include", - f"{src_dir}/build/include", - f"{src_dir}/build/include/ttmlir/Target/Common", + f"{ttmlir_build_dir}/include", + f"{ttmlir_build_dir}/include/ttmlir/Target/Common", ], libraries=["TTBinary", "flatbuffers"], library_dirs=[ - f"{src_dir}/build/runtime/lib", + f"{ttmlir_build_dir}/runtime/lib", f"{toolchain}/lib", ], define_macros=[("VERSION_INFO", __version__)], @@ -80,13 +85,13 @@ for dylib in runlibs: shutil.copy( f"{metaldir}/lib/{dylib}", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", ) command = [ "patchelf", "--set-rpath", "$ORIGIN", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/{dylib}", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/{dylib}", ] try: @@ -103,7 +108,7 @@ for dylib in perflibs: shutil.copy( f"{metaldir}/tools/profiler/bin/{dylib}", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", ) shutil.copy( f"{metaldir}/tools/profiler/bin/{dylib}", @@ -169,7 +174,7 @@ def tt_metal_ignore_folders(folder, contents): # copy metal dir folder shutil.copytree( f"{ttmetalhome}/tt_metal", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tt_metal", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tt_metal", dirs_exist_ok=True, ignore=tt_metal_ignore_folders, ) @@ -177,14 +182,14 @@ def tt_metal_ignore_folders(folder, contents): # copy runtime dir folder shutil.copytree( f"{ttmetalhome}/runtime", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/runtime", dirs_exist_ok=True, ) # copy kernels shutil.copytree( f"{ttmetalhome}/ttnn", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/ttnn", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/ttnn", dirs_exist_ok=True, ) @@ -198,16 +203,16 @@ def package_files(directory): return paths extra_files_tt_metal = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tt_metal/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tt_metal/" ) extra_files_runtime = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/runtime/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/runtime/" ) extra_files_ttnn = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/ttnn/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/ttnn/" ) extra_files_tests = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tests/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tests/" ) metallibs += extra_files_tt_metal @@ -222,18 +227,18 @@ def package_files(directory): include_dirs=[ f"{toolchain}/include", f"{src_dir}/runtime/include", - f"{src_dir}/build/include", - f"{src_dir}/build/include/ttmlir/Target/Common", + f"{ttmlir_build_dir}/include", + f"{ttmlir_build_dir}/include/ttmlir/Target/Common", ], libraries=["TTRuntime"] + linklibs + ["flatbuffers"], library_dirs=[ - f"{src_dir}/build/runtime/lib", - f"{src_dir}/build/runtime/lib/common", - f"{src_dir}/build/runtime/lib/ttnn", - f"{src_dir}/build/runtime/lib/ttnn/operations", - f"{src_dir}/build/runtime/lib/ttmetal", + f"{ttmlir_build_dir}/runtime/lib", + f"{ttmlir_build_dir}/runtime/lib/common", + f"{ttmlir_build_dir}/runtime/lib/ttnn", + f"{ttmlir_build_dir}/runtime/lib/ttnn/operations", + f"{ttmlir_build_dir}/runtime/lib/ttmetal", f"{toolchain}/lib", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", f"{metaldir}/lib", ], define_macros=[ diff --git a/runtime/tools/python/ttrt/common/perf.py b/runtime/tools/python/ttrt/common/perf.py index a341c2b4f..f70defa31 100644 --- a/runtime/tools/python/ttrt/common/perf.py +++ b/runtime/tools/python/ttrt/common/perf.py @@ -17,11 +17,16 @@ import atexit import traceback from pathlib import Path +import csv from ttrt.common.util import * from ttrt.common.query import Query +def get_loc_data_hook(binary, programContext, opContext): + op_debug_str = ttrt.runtime.get_op_debug_str(opContext) + + class Perf: registered_args = {} @@ -456,6 +461,38 @@ def signal_handler(sig, frame): ) process_ops(None, None, False) + + # Add post-processing steps to insert location data into the ops_perf data file + with open(profiler_csv_file_path, "r") as perf_file: + perf_reader = csv.DictReader(perf_file) + headers = list(perf_reader.fieldnames) + ["LOC"] + perf_data = list(perf_reader) + + with open(profiler_csv_file_path, "w+") as perf_file, open( + tracy_ops_data_file_path, "r" + ) as message_file: + message_reader = csv.reader(message_file, delimiter=";") + ops_index = 0 + prev = None + for message in message_reader: + message = message[0] # Don't need timestamp information + if message.startswith("`"): + # This is a TTNN Message + # The location data is now in the previous message + # The order of data is maintained in perf_data so as the messages are received, they update the id last encountered. + # Now that we have a new message, we can update the location data from the previous message + if prev: + # Get the location data from the previous message and add it as new data for the perf_data (as a new col) + if len(perf_data) > ops_index: + perf_data[ops_index]["LOC"] = prev + ops_index += 1 + else: + prev = message + perf_writer = csv.DictWriter(perf_file, fieldnames=headers) + perf_writer.writeheader() + for row in perf_data: + perf_writer.writerow(row) + self.file_manager.copy_file( perf_folder_path, profiler_csv_file_path, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index dfc4a6820..c0378727c 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -100,6 +100,8 @@ PYBIND11_MODULE(_C, m) { "Get the input tensor of the op"); m.def("get_op_debug_str", &tt::runtime::getOpDebugString, "Get the debug string of the op"); + m.def("get_op_loc_info", &tt::runtime::getOpLocInfo, + "Get the location info of the op"); py::class_(m, "DebugEnv") .def_static("get", &tt::runtime::debug::Env::get) diff --git a/test/python/tensor_layout.py b/test/python/tensor_layout.py index 39a9a728b..2dbf249e9 100644 --- a/test/python/tensor_layout.py +++ b/test/python/tensor_layout.py @@ -34,7 +34,7 @@ def createTensorLayout( shape, F32Type.get(ctx), None, Location.unknown(ctx) ) memoryLayout = getTensorMemoryLayout(memorySpace) - layout = tt.ir.LayoutAttr.get( + layout = tt.ir.MetalLayoutAttr.get( ctx, tensorTy, memorySpace, grid, collapseIntervals, oobVal, memoryLayout ) return RankedTensorType.get(shape, F32Type.get(ctx), layout, Location.unknown(ctx)) @@ -42,7 +42,7 @@ def createTensorLayout( def tilize(tensor, dataType, tileShape=[32, 32]): assert len(tileShape) == 2 - return tt.ir.LayoutAttr.with_element_type_( + return tt.ir.MetalLayoutAttr.with_element_type_( ctx, tensor.encoding, tt.ir.TileType.get(ctx, tileShape[0], tileShape[1], dataType), @@ -52,15 +52,15 @@ def tilize(tensor, dataType, tileShape=[32, 32]): def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): if isinstance(grid, list) or isinstance(grid, tuple): grid = tt.ir.GridAttr.get(ctx, list(grid)) - return tt.ir.LayoutAttr.with_grid_( + return tt.ir.MetalLayoutAttr.with_grid_( ctx, tensor.encoding, tensor.shape, grid, collapseIntervals ) t0 = createTensorLayout([2, 3, 64, 128], [2, 4]) -# CHECK: tensor<2x3x64x128xf32, #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<2x3x64x128xf32, #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> print(t0) -# CHECK: #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t0, tt.DataType.BFP_BFloat8).wrapped()) print(parallelize(t0, [3, 2]).wrapped()) @@ -69,24 +69,24 @@ def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): print(parallelize(t1, [3, 2]).wrapped()) t2 = createTensorLayout([128], [4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> print(t2) -# CHECK: #tt.layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [2]).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [1, 2]).wrapped()) t3 = createTensorLayout([128], [1, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> print(t3) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t3, tt.DataType.BFP_BFloat8).wrapped()) t4 = createTensorLayout([128], [1, 2, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> print(t4) -# CHECK: #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t4, tt.DataType.BFP_BFloat8).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t4, [1, 2]).wrapped()) diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir index fa6cbb423..42a26ad15 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir @@ -8,3 +8,54 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic return %1 : tensor<512x512xf32> } } + +module { + func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<1x23x40x1xf32>) -> tensor<1x23x40x128xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<128xf32>) -> tensor<1x23x40x128xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %2 = stablehlo.divide %0, %1 : tensor<1x23x40x128xf32> + return %2 : tensor<1x23x40x128xf32> + } +} + +module { + func.func @main(%arg0: tensor<32xi64>, %arg1: tensor<32x1xi64>) -> tensor<32x32xi1> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<32xi64>) -> tensor<32x32xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<32x1xi64>) -> tensor<32x32xi64> + %2 = stablehlo.compare GT, %0, %1, SIGNED : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<32x32xi1> + } +} + +module { + func.func @main(%arg0: tensor<16x1xf32>, %arg1: tensor<1x1x32xi64>) -> tensor<1x16x32xf32> { + %0 = stablehlo.convert %arg1 : (tensor<1x1x32xi64>) -> tensor<1x1x32xf32> + %1 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<16x1xf32>) -> tensor<1x16x32xf32> + %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1x32xf32>) -> tensor<1x16x32xf32> + %3 = stablehlo.multiply %1, %2 : tensor<1x16x32xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %3 : tensor<1x16x32xf32> + } +} + +module { + func.func @main(%arg0: tensor<1x10xi64>, %arg1: tensor<10x1xi64>) -> tensor<10x10xi64> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x10xi64>) -> tensor<10x10xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<10x1xi64>) -> tensor<10x10xi64> + %2 = stablehlo.subtract %0, %1 : tensor<10x10xi64> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<10x10xi64> + } +} + +module { + func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<8xf32>) -> tensor<8xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<8xf32> + %2 = stablehlo.add %0, %1 : tensor<8xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<8xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir new file mode 100644 index 000000000..5fbab794c --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir @@ -0,0 +1,83 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s + +// jax/pjrt sharding target 1x2 for n300 +module @jit_matmul_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,2]<=[2]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x392xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<392x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<8192x392xf32>, tensor<392x16384xf32>) -> tensor<8192x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<8192x392xf32>, %arg1: tensor<392x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x392xf32>, tensor<392x16384xf32>) -> tensor<8192x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<8192x16384xf32> + } +} + +// jax/pjrt sharding target 2x4 for t3k +module @jit_matmul_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<4096x196xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<196x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<4096x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<4096x196xf32>, %arg1: tensor<196x16384xf32>) -> (tensor<4096x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<4096x16384xf32> + } +} + +// jax/pjrt sharding target 1x8 for t3k +module @jit_matmul_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x98xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<98x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<8192x98xf32>, tensor<98x16384xf32>) -> tensor<8192x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<8192x98xf32>, %arg1: tensor<98x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x98xf32>, tensor<98x16384xf32>) -> tensor<8192x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<8192x16384xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir similarity index 100% rename from test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir rename to test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir new file mode 100644 index 000000000..52e2d8001 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module { + func.func @main(%arg0: tensor<8x1x920xbf16>, %arg1: tensor<8x100x32xbf16>, %arg2: tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<8x32x920xbf16>) -> tensor<8x32x920xbf16> + // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] + %1 = stablehlo.dot_general %arg1, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x100x32xbf16>, tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> + return %1 : tensor<8x100x920xbf16> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir new file mode 100644 index 000000000..43241ac6f --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir @@ -0,0 +1,11 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_dnamic_iota attributes {} { + func.func public @test_dynamic_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %output_shape = stablehlo.constant dense<[1, 32, 128, 128]> : tensor<4xi64> + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 1: i64} : (tensor<4xi64>) -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir index ba29d123e..e80bb7588 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir @@ -8,6 +8,7 @@ module @jit_gather attributes {} { // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]] return %0 : tensor<1x32x1024xf32> } + func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> { %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] @@ -22,4 +23,20 @@ module @jit_gather attributes {} { return %0 : tensor<1x2x384xf32> } + func.func public @test_gather_3(%arg0: tensor<32128x512xbf16>, %arg1: tensor<1x15xi64>) -> tensor<1x15x512xbf16> { + // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<1x15x512xbf16> + // CHECK: %[[VAL:[0-9]+]] = "ttir.gather"(%arg0, %arg1, %[[EMPTY]]) + // CHECK-SAME: collapsed_slice_dims = array, + // CHECK-SAME: index_vector_dim = 2 : si64, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: offset_dims = array, + // CHECK-SAME: operand_batching_dims = array, + // CHECK-SAME: slice_sizes = array, + // CHECK-SAME: start_index_map = array, + // CHECK-SAME: start_indices_batching_dims = array + // CHECK-SAME: (tensor<32128x512xbf16>, tensor<1x15xi32>, tensor<1x15x512xbf16>) -> tensor<1x15x512xbf16> + %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16> + // CEHCK: return %[[VAL]] : tensor<1x15x512xbf16> + return %0 : tensor<1x15x512xbf16> + } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir new file mode 100644 index 000000000..857a621bb --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_iota attributes {} { + func.func public @test_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %0 = "stablehlo.iota"() {iota_dimension = 1: i64} : () -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir new file mode 100644 index 000000000..92cd8895f --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_scatter attributes {} { + func.func public @test_scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x1xi64>, %arg2: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE1:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + %result = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi64>, tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> + // CHECK: [[VAL1:%[0-9]+]] = "ttir.scatter"(%arg0, %arg1, %arg2, [[VAL0]]) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array} + // CHECK: ([[TENSOR_SIZE1]], tensor<1x1xi32>, tensor<1x3x32x32xf32>, [[TENSOR_SIZE1]]) -> tensor<1x3x320x320xf32> + return %result : tensor<1x3x320x320xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE1]] + } +} diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir new file mode 100644 index 000000000..6f72e56f1 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + return %1 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir new file mode 100644 index 000000000..8365bbddd --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir @@ -0,0 +1,26 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: %{{[0-9]+}} = "ttir.slice" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } + + func.func @select_multi_slice(%arg0: tensor<4x2x64x128xf32>) -> tensor<4x2x64x32xf32> { + %0 = tensor.empty() : tensor<4x2x64x32xf32> + + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.concat" + %1 = "ttir.select"(%arg0, %0) <{dim = -1: si32, begin = 0: si32, length = 4: si32, stride = 16: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x2x64x128xf32>, tensor<4x2x64x32xf32>) -> tensor<4x2x64x32xf32> + + return %1 : tensor<4x2x64x32xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir new file mode 100644 index 000000000..522628160 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir @@ -0,0 +1,194 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for linear operation + +// Verify that the parsing fails if either of operands is a scalar +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_a(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input A must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_b(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input B must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Bias must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// Verifty that the parsing fails if the output is a scalar +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { + // CHECK: error: 'ttir.linear' op Scalar output is not supported, output must be at least a 1D tensor + %0 = tensor.empty() : tensor + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor + return %1 : tensor + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_output_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { + // CHECK: error: 'ttir.linear' op Scalar output must be a 1D tensor of size 1 + %0 = tensor.empty() : tensor<2xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + return %1 : tensor<2xbf16> + } +} + +// Inner dimension mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { +func.func @linear_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](64) and B[-2](128) must have matching inner dimensions + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// Batch dimension mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(7) and B(2) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } +} + +// Bias shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64) is not broadcast compatible with the matmul output shape(64,64) + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<2x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_matmul_bias_broadcast_incompatible(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64x64xbf16>) -> tensor<3x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64,64) is not broadcast compatible with the matmul output shape(3,64,64) + %0 = tensor.empty() : tensor<3x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16> + return %1 : tensor<3x64x64xbf16> + } +} + +// Output shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Output shape rank(1) must match the expected output shape rank(2) + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { + // CHECK: error: 'ttir.linear' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) + %0 = tensor.empty() : tensor<64x128xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir new file mode 100644 index 000000000..f505bfcb7 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir @@ -0,0 +1,116 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_dim(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid dimension}} + %1 = "ttir.select"(%arg0, %0) <{dim = -3: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_stride(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid stride.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 7: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_stride_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid stride.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = -1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_begin(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid begin index.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = -3: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_begin_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid begin index.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 4: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 5: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 0: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length_3(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_total_size(%arg0: tensor<4x2x64x48xf32>) -> tensor<4x2x4x48xf32> { + %0 = tensor.empty() : tensor<4x2x4x48xf32> + // CHECK: {{.*error.*Sum of all slices.*}} + %1 = "ttir.select"( %arg0, %0) <{dim = 2: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x2x64x48xf32>, tensor<4x2x4x48xf32>) -> tensor<4x2x4x48xf32> + return %1 : tensor<4x2x4x48xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir new file mode 100644 index 000000000..b613c85bf --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir @@ -0,0 +1,44 @@ +// RUN: ttmlir-opt %s | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } + + func.func @select_half(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { + %0 = tensor.empty() : tensor<4x2xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> + return %1 : tensor<4x2xf32> + } + + func.func @select_single(%arg0: tensor<4x4xf32>) -> tensor<4x1xf32> { + %0 = tensor.empty() : tensor<4x1xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 3: si32, length = 1: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x1xf32>) -> tensor<4x1xf32> + return %1 : tensor<4x1xf32> + } + + func.func @select_half_2_no_stride(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { + %0 = tensor.empty() : tensor<4x2xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 2: si32, length = 2: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> + return %1 : tensor<4x2xf32> + } + + func.func @select_neg_dim(%arg0: tensor<10x3x128x64xf32>) -> tensor<10x3x8x64xf32> { + %0 = tensor.empty() : tensor<10x3x8x64xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = -2: si32, begin = 0: si32, length = 2: si32, stride = 32: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<10x3x128x64xf32>, tensor<10x3x8x64xf32>) -> tensor<10x3x8x64xf32> + return %1 : tensor<10x3x8x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir index 2335fb0df..42cab3d1f 100644 --- a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir +++ b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir @@ -3,21 +3,21 @@ #dram = #tt.memory_space #l1_ = #tt.memory_space -// CHECK-DAG: #[[row_major1x1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major2x2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major2x2:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> -#row_major1x1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major1x1_T = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -#tile1x1_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -#tile1x1_bf16 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -#tile1x1_f32_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -#tile2x2_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +#row_major1x1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major1x1_T = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +#tile1x1_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +#tile1x1_bf16 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +#tile1x1_f32_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +#tile2x2_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> func.func @noncompound_linear(%in: tensor<64x128xf32, #row_major1x1>) -> tensor<64x128xf32, #row_major1x1_T> { %out = tensor.empty() : tensor<64x128xf32, #row_major1x1_T> diff --git a/test/ttmlir/Dialect/TTIR/test_allocate.mlir b/test/ttmlir/Dialect/TTIR/test_allocate.mlir index a80a8c1c9..5888cf3f6 100644 --- a/test/ttmlir/Dialect/TTIR/test_allocate.mlir +++ b/test/ttmlir/Dialect/TTIR/test_allocate.mlir @@ -1,7 +1,7 @@ // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> module attributes {} { func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> { // CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]] diff --git a/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir new file mode 100644 index 000000000..e1454ad0a --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir @@ -0,0 +1,28 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for Broadcastable interface + +// CHECK: 'ttir.abs' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_unary(%arg0: tensor<1x64xbf16>) -> tensor<2x64xbf16> { + %0 = tensor.empty() : tensor<2x64xbf16> + %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16> + return %1 : tensor<2x64xbf16> +} + +// ----- +// CHECK: error: 'ttir.add' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<2x3x64xf32>, %arg1: tensor<64xf32>) -> tensor<4x2x3x64xf32> { + %0 = tensor.empty() : tensor<4x2x3x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32> + return %1 : tensor<4x2x3x64xf32> +} + +// ----- +// CHECK: error: 'ttir.where' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_ternary(%arg0: tensor<3x64xf32>, %arg1: tensor<1x3x64xf32>, %arg2: tensor<2x1x64xf32>) -> tensor<1x2x3x64xf32> { + %0 = tensor.empty() : tensor<1x2x3x64xf32> + %1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32> + return %1 : tensor<1x2x3x64xf32> +} diff --git a/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir new file mode 100644 index 000000000..a22dc2837 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir @@ -0,0 +1,37 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for NOperands trait + +// CHECK: error: 'ttir.abs' op expected 2 operands, but found 3 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_unary(%arg0: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> +} + +// ----- +// CHECK: error: 'ttir.add' op expected 3 operands, but found 4 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// ----- +// CHECK: error: 'ttir.add' op expected 3 operands, but found 2 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// ----- +// CHECK: error: 'ttir.where' op expected 4 operands, but found 5 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_ternary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir new file mode 100644 index 000000000..dc3f09fba --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir @@ -0,0 +1,12 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for matmul operation +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: error: 'ttir.arange' op Output tensor shape must be 16 at dim 1 (since start=0, end=32, step=2), but got 32 + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 2: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir new file mode 100644 index 000000000..16c396c00 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +// XFAIL: * +// https://github.com/tenstorrent/tt-mlir/issues/1448 +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir index dfbf99008..6404ee6e9 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir @@ -1,9 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { - func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { + func.func @gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x32x1024xf32> + %0 = tensor.empty() : tensor<1x32x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) { offset_dims = array, @@ -15,13 +15,13 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32> - return %1 : tensor<1x32x1024xf32> + } : (tensor<32000x1024xbf16>, tensor<1x32xi32>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> + return %1 : tensor<1x32x1024xbf16> } - func.func @gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> { + func.func @gather_1(%operand: tensor<448x384xbf16>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x2x384xf32> + %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) <{ offset_dims = array, @@ -33,13 +33,13 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - }> : (tensor<448x384xf32>, tensor<1x2x1xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32> - return %1 : tensor<1x2x384xf32> + }> : (tensor<448x384xbf16>, tensor<1x2x1xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16> + return %1 : tensor<1x2x384xbf16> } - func.func @gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> { + func.func @gather_2(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x2x384xf32> + %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) <{ offset_dims = array, @@ -51,7 +51,7 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - }> : (tensor<51864x384xf32>, tensor<1x2xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32> - return %1 : tensor<1x2x384xf32> + }> : (tensor<51864x384xbf16>, tensor<1x2xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16> + return %1 : tensor<1x2x384xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir index 2a06bf92b..44ffea73e 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir @@ -110,3 +110,25 @@ module attributes {} { return %1 : tensor<1x2x384xf32> } } + +// Verify that the parsing fails for data type other than bfloat16. +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { + %0 = tensor.empty() : tensor<1x32x1024xf32> + // CHECK: error: failed to legalize operation 'ttir.gather' that was explicitly marked illegal + %1 = "ttir.gather"(%operand, %start_indices, %0) { + offset_dims = array, + collapsed_slice_dims = array, + operand_batching_dims = array, + start_indices_batching_dims = array, + start_index_map = array, + index_vector_dim = 1 : si64, + slice_sizes = array, + indices_are_sorted = false, + operand_constraints = [#any_device, #any_device, #any_device] + } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32> + return %1 : tensor<1x32x1024xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir new file mode 100644 index 000000000..0e248623d --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir @@ -0,0 +1,216 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias_broadcast(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<128xbf16>) -> tensor<128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<128xbf16 + %0 = tensor.empty() : tensor<128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16> + return %1 : tensor<128xbf16> + } + + func.func @linear_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64xbf16 + %0 = tensor.empty() : tensor<64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } + + func.func @linear_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64xbf16 + %0 = tensor.empty() : tensor<12x7x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> + return %1 : tensor<12x7x64xbf16> + } + + func.func @linear_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128xbf16 + %0 = tensor.empty() : tensor<12x7x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> + return %1 : tensor<12x7x128xbf16> + } + + func.func @linear_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> + return %1 : tensor<12x7x64x64xbf16> + } + + func.func @linear_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128x128xbf16 + %0 = tensor.empty() : tensor<12x7x128x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> + return %1 : tensor<12x7x128x128xbf16> + } + + // linear nd - nd tests + func.func @linear_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<7x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<1x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x7x64x64xbf16 + %0 = tensor.empty() : tensor<7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } + + func.func @linear_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_bias(%arg0: tensor<14x7x32x32xbf16>, %arg1:tensor<14x1x32x64xbf16>, %bias: tensor<64xbf16>) -> tensor<14x7x32x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x7x32x64xbf16 + %0 = tensor.empty() : tensor<14x7x32x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<14x7x32x32xbf16 + // CHECK-SAME: tensor<14x1x32x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16> + return %1 : tensor<14x7x32x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_matmul(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<4x3x128x32xbf16>, %bias: tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %0 = tensor.empty() : tensor<14x4x3x64x32xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<3x64x128xbf16 + // CHECK-SAME: tensor<4x3x128x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> + return %1 : tensor<14x4x3x64x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir new file mode 100644 index 000000000..56728eb52 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint + +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir new file mode 100644 index 000000000..5991efeab --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, shape = #ttnn.shape<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+x[0-9]+]]>}> : (!tt.device<#device>) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+]], {{.*}}> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + // CHECK: {{[0-9]+}} = "ttnn.scatter"(%4, %2, %5) <{operandSegmentSizes = array}> : (tensor<1x3x32x32xf32, {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>) -> tensor<[[TENSOR_SHAPE1]], {{.*}}> + return %2 : tensor<1x3x320x320xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE1]], {{.*}}> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir new file mode 100644 index 000000000..d911ec6fe --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir new file mode 100644 index 000000000..01aa0e91b --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + %output_shape = stablehlo.constant dense<[1, 1, 32, 128]> : tensor<4xi64> + // CHECK: ttnn.arange + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 3: i64} : (tensor<4xi64>) -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir new file mode 100644 index 000000000..d911ec6fe --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir new file mode 100644 index 000000000..a231432ab --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 3: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir similarity index 82% rename from test/ttmlir/Silicon/StableHLO/dot_general_op.mlir rename to test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir index 57a0bdcd8..179f112b4 100644 --- a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir @@ -6,8 +6,8 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // RUN: FileCheck --input-file=%t.mlir %s -module @jit_dot_general attributes {} { - func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { +module @jit_dot_general_2d attributes {} { + func.func public @test_dot_general_2d(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { // CHECK-LABEL: func.func public @test_dot_general // CHECK: ttnn.empty // CHECK: ttnn.matmul diff --git a/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir new file mode 100644 index 000000000..f23ece73f --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir @@ -0,0 +1,21 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s + +module @jit_dot_general_4d attributes {} { + func.func public @test_dot_general_4d(%arg0 : tensor<1x128x16x32xf32>, %arg1 : tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> { + // CHECK-LABEL: func.func public @test_dot_general + // CHECK: ttnn.empty + // CHECK: ttnn.matmul + // CHECK-SAME: tensor<1x128x16x32xf32, + // CHECK-SAME: tensor<1x128x32x8xf32, + // CHECK-SAME: tensor<1x128x16x8xf32, + // CHECK-SAME: -> tensor<1x128x16x8xf32 + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<1x128x16x32xf32>, tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> + return %0 : tensor<1x128x16x8xf32> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/gather_op.mlir b/test/ttmlir/Silicon/StableHLO/gather_op.mlir new file mode 100644 index 000000000..9a4a90b1b --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/gather_op.mlir @@ -0,0 +1,45 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RU1N: FileCheck --input-file=%t.mlir %s + +module @jit_gather attributes {} { + func.func public @test_gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> { + // CHECK-LABEL: func.func public @test_gather_0 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x32xi32, + // CHECK-SAME: tensor<1x32x1024xbf16 + // CHECK-SAME: tensor<32000x1024xbf16, + // CHECK-SAME: -> tensor<1x32x1024xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32000x1024xbf16>, tensor<1x32xi32>) -> tensor<1x32x1024xbf16> + return %0 : tensor<1x32x1024xbf16> + } + + func.func public @test_gather_1(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> { + // CHECK-LABEL: func.func public @test_gather_1 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x2xi32, + // CHECK-SAME: tensor<1x2x384xbf16 + // CHECK-SAME: tensor<51864x384xbf16, + // CHECK-SAME: -> tensor<1x2x384xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<51864x384xbf16>, tensor<1x2xi32>) -> tensor<1x2x384xbf16> + return %0 : tensor<1x2x384xbf16> + } + + func.func public @test_gather_2(%operand: tensor<32128x512xbf16>, %start_indices: tensor<1x15xi64>) -> tensor<1x15x512xbf16> { + // CHECK-LABEL: func.func public @test_gather_2 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x16xi32, + // CHECK-SAME: tensor<1x15x512xbf16 + // CHECK-SAME: tensor<32128x512xbf16, + // CHECK-SAME: -> tensor<1x15x512xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16> + return %0 : tensor<1x15x512xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTMetal/simple_max.mlir b/test/ttmlir/Silicon/TTMetal/simple_max.mlir new file mode 100644 index 000000000..92bdbe72c --- /dev/null +++ b/test/ttmlir/Silicon/TTMetal/simple_max.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm + +#any_device = #tt.operand_constraint + +func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir index 1674ae0d3..cdde621c2 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir @@ -1,8 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> -#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> +#layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, #layout2> { %0 = tensor.empty() : tensor<256x32xf32, #layout2> @@ -15,7 +15,7 @@ func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, # return %1 : tensor<256x32xf32, #layout2> } -#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> +#layout3 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, #layout3> { %0 = tensor.empty() : tensor<32x384xf32, #layout3> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -27,7 +27,7 @@ func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, # return %1 : tensor<32x384xf32, #layout3> } -#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> +#layout4 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> func.func @reduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32, #layout4> { %0 = tensor.empty() : tensor<32x32xf32, #layout4> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir index 64cf5f57a..d7d3cea1d 100644 --- a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir +++ b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir @@ -4,10 +4,10 @@ #l1_ = #tt.memory_space -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32, #tilized> @@ -25,10 +25,10 @@ func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64 } -#untilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> -#tilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> +#untilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> +#tilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> tensor<2x3x64x128xf32, #untilized4D_2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<2x3x64x128xf32, #tilized4D> @@ -48,10 +48,10 @@ func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> ten return %5 : tensor<2x3x64x128xf32, #untilized4D_2x2> } -#untilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> -#tilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> +#untilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> +#tilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x6 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize_reblock_big(%arg0: tensor<96x192xf32, #untilized_big>) -> tensor<96x192xf32, #untilized_big> { // move to tilized 1x1 // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/to_layout.mlir b/test/ttmlir/Silicon/TTMetal/to_layout.mlir index 015e65175..e5318c6c1 100644 --- a/test/ttmlir/Silicon/TTMetal/to_layout.mlir +++ b/test/ttmlir/Silicon/TTMetal/to_layout.mlir @@ -5,8 +5,8 @@ #l1_ = #tt.memory_space #dram = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1> { %0 = tensor.empty() : tensor<4x16xf32, #layout1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -14,8 +14,8 @@ func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1 return %1 : tensor<4x16xf32, #layout1> } -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized> { %0 = tensor.empty() : tensor<64x128xf32, #tilized> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -26,11 +26,11 @@ func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, # return %3 : tensor<64x128xf32, #untilized> } -#untilized_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> -#untilized_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> -#untilized2x2_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> -#untilized2x2_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> -#untilized1x4_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> +#untilized_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> +#untilized_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> +#untilized2x2_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> +#untilized2x2_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> +#untilized1x4_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> func.func @dram_to_l1(%arg0: tensor<16x64xf32, #untilized_dram>) -> tensor<16x64xf32, #untilized_l1> { %0 = tensor.empty() : tensor<16x64xf32, #untilized_l1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir new file mode 100644 index 000000000..f3affc69d --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// UNSUPPORTED: true +// https://github.com/tenstorrent/tt-mlir/issues/1448 +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 64: si64, step = 2: si64, arange_dimension = 2: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir new file mode 100644 index 000000000..196e75709 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 128: si64, step = 1: si64, arange_dimension = 3: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir new file mode 100644 index 000000000..3f304969c --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = #tt.operand_constraint + +func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} + +func.func @subtract(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir index ba995925d..0193ec36b 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir @@ -4,8 +4,8 @@ #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> - // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_10:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_11:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir new file mode 100644 index 000000000..6da5d3910 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir @@ -0,0 +1,20 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir b/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir deleted file mode 100644 index 1d88725d1..000000000 --- a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - -func.func public @broadcast() -> (tensor<32xf32>) { - %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> - %1 = tensor.empty() : tensor<32xf32> - %2 = "ttir.broadcast"(%0, %1) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1xf32>, tensor<32xf32>) -> tensor<32xf32> - %3 = tensor.empty() : tensor<32xf32> - %4 = "ttir.broadcast"(%2, %3) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - // CHECK-NOT: %[[C:.*]] = "ttir.broadcast"[[C:.*]] - return %4 : tensor<32xf32> -} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 976f2867d..b7912d4c1 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -306,3 +306,13 @@ func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> ten %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> return %1 : tensor<64x128xi32> } + +func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + return %2 : tensor<1x3x320x320xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir new file mode 100644 index 000000000..f53de38cf --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir @@ -0,0 +1,33 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/unittests/Optimizer/CMakeLists.txt b/test/unittests/Optimizer/CMakeLists.txt index 681d78ff0..4e6ee799a 100644 --- a/test/unittests/Optimizer/CMakeLists.txt +++ b/test/unittests/Optimizer/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(OptimizerTests TestShardSolver.cpp + TestOptimizerOverrides.cpp ) target_link_libraries(OptimizerTests diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp new file mode 100644 index 000000000..c75fde21f --- /dev/null +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp @@ -0,0 +1,433 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" + +using namespace mlir::tt::ttnn; + +class TestOptimizerOverrides : public ::testing::Test { + +public: + OptimizerOverridesHandler optimizerOverridesHandler; + + void SetUp() override {} + + llvm::StringMap createInputLayoutOverrides() { + + // struct InputLayoutOverrideParams { + // SmallVector operandIdxes; + // }; + + llvm::StringMap inputLayoutOverrides; + + // Create input layout overrides for 3 input overrides. + inputLayoutOverrides["input0"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input1"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input2"] = createInputLayoutOverrideParams(); + + return inputLayoutOverrides; + } + + InputLayoutOverrideParams createInputLayoutOverrideParams() { + + InputLayoutOverrideParams inputLayoutOverrideParams; + + // Create input layout override params for 2 operands. + // Their operand indexes are 0 and 1, respectively. + inputLayoutOverrideParams.operandIdxes.push_back(0); + inputLayoutOverrideParams.operandIdxes.push_back(1); + + return inputLayoutOverrideParams; + } + + llvm::StringMap createOutputLayoutOverrides() { + + llvm::StringMap outputLayoutOverrides; + + // Create output layout overrides for 3 output overrides. + outputLayoutOverrides["output0"] = createOutputLayoutOverrideParams_0(); + outputLayoutOverrides["output1"] = createOutputLayoutOverrideParams_1(); + outputLayoutOverrides["output2"] = createOutputLayoutOverrideParams_2(); + + return outputLayoutOverrides; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_0() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 0 has + // - grid size 2x2, + // - buffer type dram + // - tensor memory layout interleaved + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(2); + outputLayoutOverrideParams.grid.push_back(2); + outputLayoutOverrideParams.bufferType = BufferType::DRAM; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::Interleaved; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_1() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 1 has + // - grid size 8x4, + // - buffer type l1 + // - tensor memory layout block_sharded + // - memory layout row_major + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(8); + outputLayoutOverrideParams.grid.push_back(4); + outputLayoutOverrideParams.bufferType = BufferType::L1; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::BlockSharded; + outputLayoutOverrideParams.memoryLayout = Layout::RowMajor; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_2() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 2 has + // - grid size 3x6, + // - buffer type system + // - tensor memory layout height_sharded + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(3); + outputLayoutOverrideParams.grid.push_back(6); + outputLayoutOverrideParams.bufferType = BufferType::SystemMemory; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::HeightSharded; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + bool + compareInputLayoutOverrides(llvm::StringMap in1, + llvm::StringMap in2) { + // Check if the sizes of the two input layout overrides are the same. + if (in1.size() != in2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = in1.begin(); it1 != in1.end(); it1++) { + // Check if the two input layout overrides have the same keys. + llvm::StringMap::iterator it2 = + in2.find(it1->getKey()); + if (it2 == in2.end()) { + return false; + } + // Check if the two input layout overrides have the same values. + // The structure InputLayoutOverrideParams has overloaded operators for == + // and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + bool compareOutputLayoutOverrides( + llvm::StringMap out1, + llvm::StringMap out2) { + // Check if the sizes of the two output layout overrides are the same. + if (out1.size() != out2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = out1.begin(); it1 != out1.end(); it1++) { + // Check if the two output layout overrides have the same keys. + llvm::StringMap::iterator it2 = + out2.find(it1->getKey()); + if (it2 == out2.end()) { + return false; + } + // Check if the two output layout overrides have the same values. + // The structure OutputLayoutOverrideParams has overloaded operators for + // == and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + void TearDown() override {} +}; + +// Test the setEnableOptimizer method +TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) { + + optimizerOverridesHandler.setEnableOptimizer(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer()); + + optimizerOverridesHandler.setEnableOptimizer(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer()); +} + +// Test the setMemoryConfig method +TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) { + + optimizerOverridesHandler.setMemoryReconfig(true); + ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig()); + + optimizerOverridesHandler.setMemoryReconfig(false); + ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig()); +} + +// Test the setMemoryLayoutAnalysis method +TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysis) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); +} + +// Test the setEnableMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrides, TestSetEnableMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); +} + +// Test the setMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); +} + +// Test the setInputLayoutOverrides method +TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) { + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.setInputLayoutOverrides(inputLayoutOverrides); + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the setOutputLayoutOverrides method +TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) { + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.setOutputLayoutOverrides(outputLayoutOverrides); + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.addInputLayoutOverride( + "input0", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input1", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input2", createInputLayoutOverrideParams()); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + llvm::SmallVector operandIdxes1 = {0, 1}; + llvm::SmallVector operandIdxes2 = {0, 1}; + llvm::SmallVector operandIdxes3 = {0, 1}; + + optimizerOverridesHandler.addInputLayoutOverride("input0", operandIdxes1); + optimizerOverridesHandler.addInputLayoutOverride("input1", operandIdxes2); + optimizerOverridesHandler.addInputLayoutOverride("input2", operandIdxes3); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", createOutputLayoutOverrideParams_0()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", createOutputLayoutOverrideParams_1()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", createOutputLayoutOverrideParams_2()); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + llvm::SmallVector grid1 = {2, 2}; + llvm::SmallVector grid2 = {8, 4}; + llvm::SmallVector grid3 = {3, 6}; + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", grid1, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::Tile, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", grid2, BufferType::L1, TensorMemoryLayout::BlockSharded, + Layout::RowMajor, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", grid3, BufferType::SystemMemory, + TensorMemoryLayout::HeightSharded, Layout::Tile, + mlir::tt::DataType::Float16); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the setSystemDescPath method +TEST_F(TestOptimizerOverrides, TestSetSystemDescPath) { + + optimizerOverridesHandler.setSystemDescPath("system_desc_path"); + ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path"); +} + +// Test the setMaxLegalLayouts method +TEST_F(TestOptimizerOverrides, TestSetMaxLegalLayouts) { + + optimizerOverridesHandler.setMaxLegalLayouts(10); + ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10); +} + +// Test the setMeshShape method +TEST_F(TestOptimizerOverrides, TestSetMeshShape) { + + std::vector meshShape; + meshShape.push_back(1); + meshShape.push_back(2); + + optimizerOverridesHandler.setMeshShape(meshShape); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[0], meshShape[0]); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[1], meshShape[1]); +} + +// Test the toString method +TEST_F(TestOptimizerOverrides, TestToString) { + + std::string options; + options += + "enable-optimizer=true "; // The optimizer pass is enabled by default. + options += "memreconfig-enabled=true "; + options += "memory-layout-analysis-enabled=true "; + options += "insert-memreconfig=add_0_1_2=0 "; + options += + "override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"; + + llvm::SmallVector operandIdxes = {0}; + llvm::SmallVector grid = {1, 1}; + + optimizerOverridesHandler.setEnableOptimizer(true); + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + optimizerOverridesHandler.setMemoryReconfig(true); + optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", operandIdxes); + optimizerOverridesHandler.addOutputLayoutOverride( + "add_1_2", grid, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::RowMajor, mlir::tt::DataType::Float32); + + ASSERT_EQ(optimizerOverridesHandler.toString(), options); +} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index b1c679a50..bf28aebc9 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -1,6 +1,6 @@ include(ExternalProject) -set(TT_METAL_VERSION "f16cadfabebd7654baef73e4ac2c3240b12b0d1d") +set(TT_METAL_VERSION "ab3dc0c4f5c3ce9722261c878970bfa92a212fc9") if ("$ENV{ARCH_NAME}" STREQUAL "grayskull") set(ARCH_NAME "grayskull") @@ -22,12 +22,13 @@ set(TTMETAL_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd - ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/device + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/device/api ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_NAME} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_EXTRA_DIR} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/src/firmware/riscv/${ARCH_NAME} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_eager + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/nanomsg/28cc32d5bdb6a858fe53b3ccf7e923957e53eada/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/fmt/73b5ec45edbd92babfd91c3777a9e1ab9cac8238/include @@ -39,6 +40,7 @@ set(TTMETAL_INCLUDE_DIRS set(TTMETAL_LIBRARY_DIR ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/lib) set(TTNN_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/_ttnn.so) set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libtt_metal.so) +set(DEVICE_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libdevice.so) if (TT_RUNTIME_ENABLE_PERF_TRACE) set(TRACY_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libtracy.so) else() @@ -48,6 +50,7 @@ endif() set(TTMETAL_LIBRARY_DIR ${TTMETAL_LIBRARY_DIR} PARENT_SCOPE) set(TTNN_LIBRARY_PATH ${TTNN_LIBRARY_PATH} PARENT_SCOPE) set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_PATH} PARENT_SCOPE) +set(DEVICE_LIBRARY_PATH ${DEVICE_LIBRARY_PATH} PARENT_SCOPE) set(TRACY_LIBRARY_PATH ${TRACY_LIBRARY_PATH} PARENT_SCOPE) ExternalProject_Add( @@ -65,13 +68,13 @@ ExternalProject_Add( GIT_REPOSITORY https://github.com/tenstorrent/tt-metal.git GIT_TAG ${TT_METAL_VERSION} GIT_PROGRESS ON - BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${TRACY_LIBRARY_PATH} + BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${TRACY_LIBRARY_PATH} ${DEVICE_LIBRARY_PATH} ) set_target_properties(tt-metal PROPERTIES EXCLUDE_FROM_ALL TRUE) -list(APPEND library_names TTNN_LIBRARY TTMETAL_LIBRARY) -list(APPEND library_paths ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH}) +list(APPEND library_names TTNN_LIBRARY TTMETAL_LIBRARY DEVICE_LIBRARY) +list(APPEND library_paths ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${DEVICE_LIBRARY_PATH}) if (TT_RUNTIME_ENABLE_PERF_TRACE) list(APPEND library_names TRACY_LIBRARY) diff --git a/tools/explorer/CMakeLists.txt b/tools/explorer/CMakeLists.txt index 7ad0791b8..44613b267 100644 --- a/tools/explorer/CMakeLists.txt +++ b/tools/explorer/CMakeLists.txt @@ -17,7 +17,7 @@ ExternalProject_Add( add_custom_target(explorer COMMENT "Building tt-explorer... ${TTMLIR_BIN_DIR}" - COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter + COMMAND pip install $<$:-e> ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer/src/model-explorer/src/server/package DEPENDS TTMLIRPythonModules model-explorer ttrt diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py index 2bb3ece81..d0c49b7af 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/main.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Dict import model_explorer -from . import ttir, runner, utils +from . import runner, utils, mlir import dataclasses import enum @@ -46,7 +46,7 @@ def convert( module = utils.parse_mlir_file(model_path) # Convert TTIR to Model Explorer Graphs and Display/Return - graph = ttir.ttir_to_graph(module) + graph = mlir.build_graph(module) return {"graphs": [graph]} def execute( diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py new file mode 100644 index 000000000..b9ae471ca --- /dev/null +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -0,0 +1,571 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# Utility library for parsing MLIR + +from collections import defaultdict +from model_explorer import graph_builder + +from ttmlir.dialects import tt, ttnn, ttir +from ttmlir import ir + + +def get_loc_str(loc): + try: + res = str(loc).split('"')[1] + except: + res = "unknown" + return res + + +class AttrHandler: + """ + A class that handles parsing and registering handlers for MLIR attribute types. + """ + + ATTR_HANDLERS = {} + + @staticmethod + def default_parser(attr): + return [graph_builder.KeyValue(key=attr.name, value=str(attr.attr))] + + @staticmethod + def parse_attr(attr): + if attr.name in AttrHandler.ATTR_HANDLERS: + return AttrHandler.ATTR_HANDLERS[attr.name](attr.attr) + else: + # Unknown Attr Type, return default parser + return AttrHandler.default_parser(attr) + + @staticmethod + def register_handler(attr_name): + """ + Decorator function to register a handler for a specific attribute name. + + Usage: + + @AttrHandler.register_handler("attr_name") + def parse_attr_name(attr: ir.Attribute) -> List[graph_builder.KeyValue]: + pass + + registers a handler for any NamedAttribute present in the MLIR module with the name "attr_name". + + The handler itself is the function that is decorated with this decorator. It must follow the function signature of + `parse_attr_name` as shown above. + """ + + def decorator(handler): + AttrHandler.ATTR_HANDLERS[attr_name] = handler + return handler + + return decorator + + +@AttrHandler.register_handler("tt.device") +def parse_tt_device(attr): + device = tt.ir.DeviceAttr.maybe_downcast(attr) + result = [] + result.append( + graph_builder.KeyValue( + key="device_chip_ids", value=", ".join(map(str, device.chip_ids)) + ) + ) + result.append( + graph_builder.KeyValue( + key="device_grid_shape", value=str(device.grid_attr.shape) + ) + ) + if device.mesh_shape: + result.append( + graph_builder.KeyValue( + key="device_mesh_shape", value=str(device.mesh_shape) + ) + ) + result.append(graph_builder.KeyValue(key="device_l1_map", value=str(device.l1_map))) + result.append( + graph_builder.KeyValue(key="device_dram_map", value=str(device.dram_map)) + ) + return result + + +@AttrHandler.register_handler("tt.system_desc") +def parse_tt_system_desc(attr): + system_desc = tt.ir.SystemDescAttr.maybe_downcast(attr) + result = [] + for i, chip_desc, chip_coord, chip_capability in zip( + system_desc.chip_desc_indices, + system_desc.chip_descs, + system_desc.chip_coords, + system_desc.chip_capabilities, + ): + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-arch", value=str(tt.Arch(chip_desc.arch.arch_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-capability", + value=str(tt.ChipCapability(chip_capability.capability_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-coord", + value="x".join( + map( + str, + (chip_coord.rack, chip_coord.shelf, chip_coord.y, chip_coord.x), + ) + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-channel-size", + value=str(chip_desc.dram_channel_size), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-unreserved-base", + value=str(chip_desc.dram_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-unreserved-end", + value=str(chip_desc.dram_unreserved_end), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-erisc-l1-unreserved-size", + value=str(chip_desc.erisc_l1_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-grid", value="x".join(map(str, chip_desc.grid)) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-l1-size", value=str(chip_desc.l1_size) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-l1-unreserved-base", + value=str(chip_desc.l1_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-noc-dram-address-align-bytes", + value=str(chip_desc.noc_dram_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-noc-l1-address-align-bytes", + value=str(chip_desc.noc_l1_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-num-cbs", value=str(chip_desc.num_cbs) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-num-dram-channels", + value=str(chip_desc.num_dram_channels), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-pcie-address-align-bytes", + value=str(chip_desc.pcie_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-usable-dram-channel-size", + value=str(chip_desc.usable_dram_channel_size), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-usable-l1-size", value=str(chip_desc.usable_l1_size) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-supported-data-types", + value=", ".join( + [ + str(tt.DataType(dt.data_type_as_int)) + for dt in chip_desc.supported_data_types + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-supported-tile-sizes", + value=", ".join( + [ + "x".join(map(str, (tsize.y, tsize.x))) + for tsize in chip_desc.supported_tile_sizes + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.dram + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-eth-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.eth + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-eth-inactive-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.eth_inactive + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-worker-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.worker + ] + ), + ) + ) + return result + + +@AttrHandler.register_handler("mesh_shape") +def parse_mesh_shape(attr): + mesh_shape = ttnn.ir.MeshShapeAttr.maybe_downcast(attr) + return [ + graph_builder.KeyValue( + key="mesh_shape", value="x".join(map(str, (mesh_shape.y, mesh_shape.x))) + ) + ] + + +@AttrHandler.register_handler("layout") +def parse_layout(attr): + # This is for parsing TTNN Layouts (Enum) + layout = ttnn.ir.LayoutAttr.maybe_downcast(attr) + return [graph_builder.KeyValue(key="layout", value=str(ttnn.Layout(layout.value)))] + + +@AttrHandler.register_handler("memory_config") +def parse_memory_config(attr): + memory_config = ttnn.ir.MemoryConfigAttr.maybe_downcast(attr) + result = [] + result.append( + graph_builder.KeyValue( + key="buffer-type", + value=str(ttnn.BufferType(memory_config.buffer_type.value)), + ) + ) + result.append( + graph_builder.KeyValue( + key="shard-shape", + value="x".join(map(str, memory_config.shard_spec.shard_shape.shape)), + ) + ) + result.append( + graph_builder.KeyValue( + key="tensor-memory-layout", + value=str( + ttnn.TensorMemoryLayout(memory_config.tensor_memory_layout.value) + ), + ) + ) + return result + + +@AttrHandler.register_handler("force") +def parse_force(attr): + return [graph_builder.KeyValue(key="force", value=str(attr.value))] + + +@AttrHandler.register_handler("dtype") +def parse_dtype(attr): + dtype = tt.ir.DataTypeAttr.maybe_downcast(attr) + return [ + graph_builder.KeyValue( + key="dtype", value=str(tt.DataType(dtype.data_type_as_int)) + ) + ] + + +@AttrHandler.register_handler("shape") +def parse_shape(attr): + shape = ttnn.ir.ShapeAttr.maybe_downcast(attr) + if not shape: + return [graph_builder.KeyValue(key="shape", value=str(attr))] + return [graph_builder.KeyValue(key="shape", value="x".join(map(str, shape.shape)))] + + +@AttrHandler.register_handler("operandSegmentSizes") +def parse_operandSegmentSizes(attr): + return [graph_builder.KeyValue(key="operandSegmentSizes", value=str(list(attr)))] + + +@AttrHandler.register_handler("dimension") +def parse_dimension(attr): + return [graph_builder.KeyValue(key="dimension", value=str(attr.value))] + + +@AttrHandler.register_handler("tt.layout") +def parse_tt_layout(attr): + layout = tt.ir.MetalLayoutAttr.maybe_downcast(attr) + result = [] + result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) + result.append( + graph_builder.KeyValue( + key="memory_space", value=str(tt.MemorySpace(layout.memory_space_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key="memory_layout", + value=str(tt.TensorMemoryLayout(layout.memory_layout_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ) + ) + result.append( + graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) + ) + result.append( + graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + ) + tile_type = tt.ir.TileType.maybe_downcast(layout.memref.element_type) + if tile_type is not None: + result.append( + graph_builder.KeyValue( + key="tile_datatype", value=str(tt.DataType(tile_type.data_type_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key="tile_shape", value="x".join(map(str, tile_type.shape)) + ) + ) + return result + + +@AttrHandler.register_handler("ttnn_layout") +def parse_ttnn_ttnn_layout(attr): + layout = ttnn.ir.TTNNLayoutAttr.maybe_downcast(attr) + result = [] + result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) + result.append( + graph_builder.KeyValue( + key="memory_layout", + value=str(ttnn.TensorMemoryLayout(layout.memory_layout_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ) + ) + result.append( + graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) + ) + result.append( + graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + ) + buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) + result.append( + graph_builder.KeyValue( + key="memref_memory_space", value=str(ttnn.BufferType(buffer_attr.value)) + ) + ) + return result + + +class OpHandler: + def __init__(self, op): + self.op = op + + def get_id(self, names: defaultdict): + name = get_loc_str(self.op.location) + name_num = names[name] + id = name + "__" + str(name_num) + names[name] += 1 + return id + + def get_namespace(self, parent_op=None): + op = self.op if not parent_op else parent_op + name = get_loc_str(op.location) + if op.parent and op.parent.name != "builtin.module": + return self.get_namespace(op.parent) + "/" + name + return name + + def get_attributes(self): + # Parse Op Attributes themselves + result = [] + for attr in self.op.attributes: + result.extend(AttrHandler.parse_attr(attr)) + return result + + def make_graph_node(self, name_dict): + return graph_builder.GraphNode( + id=self.get_id(name_dict), + label=self.op.name, + namespace=self.get_namespace(), + attrs=self.get_attributes(), + ) + + def make_constant_node(self, name_dict, constant_name): + return graph_builder.GraphNode( + id=self.get_id(name_dict), + label=constant_name, + namespace=self.get_namespace(), + ) + + +EMPTY_OPS = [ + "ttnn.empty", + "tensor.empty", +] + +FILTERED_OPS = [ + "ttnn.deallocate", + "ttnn.get_device", +] + + +def build_graph(module): + name_dict = defaultdict(int) + output_connections = defaultdict(int) + graph = graph_builder.Graph(id="tt-graph") + + op_to_graph_node = {} + + module_op = OpHandler(module.operation) + graph.nodes.append(module_op.make_graph_node(name_dict)) + + for op in module.body.operations: + append_later = [] + for region in op.regions: + for block in region.blocks: + for op in block.operations: + # Create all the nodes and constants in the first pass. + operation = OpHandler(op) + graph_node = operation.make_graph_node(name_dict) + + if op.name in EMPTY_OPS: + append_later.append(graph_node) + elif op.name not in FILTERED_OPS: + graph.nodes.append(graph_node) + + op_to_graph_node[op] = graph_node + + for operand in op.operands: + if isinstance(operand, ir.Value): + # This is a constant and we need to create a node for it. + operand_node = operation.make_constant_node( + name_dict, operand.get_name() + ) + graph.nodes.append(operand_node) + op_to_graph_node[operand] = operand_node + + # This puts the node at the far right when viewing which is a bit more consistant with it being the last operand. + for node in append_later: + graph.nodes.append(node) + + for op in block.operations: + # Create all the edges in the second pass. + for operand_index, operand in enumerate(op.operands): + if operand.owner == block: + source_node = op_to_graph_node[operand] + else: + source_node = op_to_graph_node[operand.owner] + + target_node = op_to_graph_node[op] + + target_node.incomingEdges.append( + graph_builder.IncomingEdge( + sourceNodeId=source_node.id, + sourceNodeOutputId=output_connections[source_node.id], + targetNodeInputId=operand_index, + ) + ) + + output_attrs = [] + if isinstance(operand.type, ir.RankedTensorType): + output_attrs = [ + graph_builder.KeyValue( + key="shape", value=str(operand.type.shape) + ), + graph_builder.KeyValue( + key="dtype", value=str(operand.type.element_type) + ), + graph_builder.KeyValue( + key="rank", value=str(operand.type.rank) + ), + ] + if hasattr(operand.type, "encoding") and operand.type.encoding: + if "ttnn_layout" in str(operand.type.encoding): + output_attrs.extend( + AttrHandler.parse_attr( + operand.type.encoding.get_named("ttnn_layout") + ) + ) + else: + # Parse as a standard layout + output_attrs.extend( + AttrHandler.parse_attr( + operand.type.encoding.get_named("tt.layout") + ) + ) + source_node.outputsMetadata.append( + graph_builder.MetadataItem( + id=str(output_connections[source_node.id]), + attrs=[ + graph_builder.KeyValue( + key="__tensor_tag", value=target_node.label + ), + ] + + output_attrs, + ) + ) + output_connections[source_node.id] += 1 + + return graph diff --git a/tools/explorer/tt_adapter/src/tt_adapter/ttir.py b/tools/explorer/tt_adapter/src/tt_adapter/ttir.py deleted file mode 100644 index 76cd470b0..000000000 --- a/tools/explorer/tt_adapter/src/tt_adapter/ttir.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 -# Library to manipulate TTIR Modules - -from model_explorer import graph_builder -from ttmlir.dialects import tt, ttir, ttkernel -from collections import defaultdict - - -def get_loc_str(loc): - # TODO(odjuricic) Need to expose this in python bindings, if possible. - try: - res = str(loc).split('"')[1] - except: - res = "unknown" - return res - - -def create_id(op, name_dict): - name = get_loc_str(op.location) - name_num = name_dict[name] - id = name + "__" + str(name_num) - name_dict[name] += 1 - return id - - -def get_attrs(op): - result = [] - for attr in op.attributes: - result.append(graph_builder.KeyValue(key=attr.name, value=str(attr.attr))) - return result - - -def create_namespace(op): - name = get_loc_str(op.location) - if op.parent and op.parent.name != "builtin.module": - return create_namespace(op.parent) + "/" + name - return name - - -def get_layout_attrs(tensor): - attrs = [ - graph_builder.KeyValue(key="shape", value=str(tensor.type.shape)), - graph_builder.KeyValue( - key="element_type", - value=str(tensor.type.element_type), - ), - graph_builder.KeyValue(key="rank", value=str(tensor.type.rank)), - ] - - if hasattr(tensor.type, "encoding") and tensor.type.encoding: - layout = tt.ir.LayoutAttr.getLayout(tensor.type) - attrs.extend( - [ - graph_builder.KeyValue( - key="Memory Space", - value=str(tt.MemorySpace(layout.memory_space_as_int)), - ), - graph_builder.KeyValue( - key="Memory Layout", - value=str(tt.TensorMemoryLayout(layout.memory_layout_as_int)), - ), - graph_builder.KeyValue( - key="Grid Shape", - value=str(list(layout.grid_attr.shape)), - ), - ] - ) - - return attrs - - -def ttir_to_graph(module): - # Can assume that to-layout pass has already been run on the module. - name_dict = defaultdict(int) - output_connections = defaultdict(int) - graph = graph_builder.Graph(id="ttir-graph") - - op_to_graph_node = dict() - - for op in module.body.operations: - append_later = [] - for region in op.regions: - for block in region.blocks: - for op in block.operations: - # Create all the nodes and constants in the first pass. - graph_node = graph_builder.GraphNode( - id=create_id(op, name_dict), - label=op.name, - namespace=create_namespace(op), - attrs=get_attrs(op), - ) - - if op.name == "tensor.empty": - append_later.append(graph_node) - else: - graph.nodes.append(graph_node) - - op_to_graph_node[op] = graph_node - - for operand in op.operands: - if operand.owner == block and operand not in op_to_graph_node: - # This is a constant and we need to create a node for it. - operand_node = graph_builder.GraphNode( - id=create_id(op, name_dict), - label=operand.get_name(), - namespace=create_namespace(op), - ) - graph.nodes.append(operand_node) - op_to_graph_node[operand] = operand_node - - # This puts the node at the far right when viewing which is a bit more consistant with it being the last operand. - for node in append_later: - graph.nodes.append(node) - - for op in block.operations: - # Create all the edges in the second pass. - for operand_index, operand in enumerate(op.operands): - if operand.owner == block: - source_node = op_to_graph_node[operand] - else: - source_node = op_to_graph_node[operand.owner] - - target_node = op_to_graph_node[op] - - target_node.incomingEdges.append( - graph_builder.IncomingEdge( - sourceNodeId=source_node.id, - sourceNodeOutputId=output_connections[source_node.id], - targetNodeInputId=operand_index, - ) - ) - - output_attrs = get_layout_attrs(operand) - source_node.outputsMetadata.append( - graph_builder.MetadataItem( - id=str(output_connections[source_node.id]), - attrs=[ - graph_builder.KeyValue( - key="__tensor_tag", value=target_node.label - ), - ] - + output_attrs, - ) - ) - output_connections[source_node.id] += 1 - - return graph diff --git a/tools/explorer/tt_adapter/src/tt_adapter/utils.py b/tools/explorer/tt_adapter/src/tt_adapter/utils.py index fe68d89ac..bca7e640b 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/utils.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/utils.py @@ -6,8 +6,8 @@ def parse_mlir_file(model_path): with ttmlir.ir.Context() as ctx, open(model_path, "r") as model_file: - ttmlir.dialects.ttkernel.register_dialect(ctx) ttmlir.dialects.ttir.register_dialect(ctx) ttmlir.dialects.tt.register_dialect(ctx) - module = ttmlir.ir.Module.parse("".join(model_file.readlines()), ctx) + ttmlir.dialects.ttnn.register_dialect(ctx) + module = ttmlir.ir.Module.parse(model_file.read(), ctx) return module diff --git a/tools/ttnn-standalone/README.md b/tools/ttnn-standalone/README.md index 816cfe1cf..619e52d1c 100644 --- a/tools/ttnn-standalone/README.md +++ b/tools/ttnn-standalone/README.md @@ -14,7 +14,7 @@ Third party ML models (PyTorch, Jax, ONNX, ...) can be compiled to a set of TTNN ```bash # Compile a model to C++ code -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --convert-ttir-to-ttnn --ttnn-decompose-layouts --ttnn-deallocate --convert-ttnn-to-emitc test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate --mlir-to-cpp -allow-unregistered-dialect +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate --mlir-to-cpp # Copy paste the generated function into `ttnn-standalone.cpp`. diff --git a/tools/ttnn-standalone/ttnn-standalone.cpp b/tools/ttnn-standalone/ttnn-standalone.cpp index dff9afff4..0dee60f13 100644 --- a/tools/ttnn-standalone/ttnn-standalone.cpp +++ b/tools/ttnn-standalone/ttnn-standalone.cpp @@ -5,11 +5,9 @@ #include "ttnn-precompiled.hpp" // To generate forward function, run: -// ./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-implicit-device -// --ttir-layout --convert-ttir-to-ttnn --ttnn-decompose-layouts -// --ttnn-deallocate --convert-ttnn-to-emitc +// ./build/bin/ttmlir-opt --ttir-to-emitc-pipeline // test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate -// --mlir-to-cpp -allow-unregistered-dialect +// --mlir-to-cpp ttnn::Tensor forward(ttnn::Tensor v1, ttnn::Tensor v2) { ttnn::Device *v3 = ttnn::DeviceGetter::getInstance();