diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml
index b0e7f3c6c8af..76108aec29ca 100644
--- a/.github/actions/setup-build/action.yml
+++ b/.github/actions/setup-build/action.yml
@@ -11,26 +11,29 @@ inputs:
 runs:
   using: "composite"
+
   steps:
   - name: Set up Python
     uses: actions/setup-python@v2
     with:
       python-version: 3.9
+
   - name: Install MLIR Python depends
     run: |
       python -m pip install -r $GITHUB_WORKSPACE/externals/llvm-project/mlir/python/requirements.txt
     shell: bash
+
   - name: Install PyTorch nightly depends
     run: |
       python -m pip install -r requirements.txt
     shell: bash
+
   - name: Install Ninja
     uses: llvm/actions/install-ninja@55d844821959226fab4911f96f37071c1d4c3268
-  - name: Get Submodule Hash
-    id: get-submodule-hash
-    run: echo "::set-output name=hash::$(md5sum $(git submodule status))"
-    shell: bash
+
   - name: Ccache for C++ compilation
-    uses: hendrikmuhs/ccache-action@4687d037e4d7cf725512d9b819137a3af34d39b3
+    uses: hendrikmuhs/ccache-action@v1.2
     with:
-      key: ${{ runner.os }}-clangreleaseasserts-${{ steps.get-submodule-hash.outputs.hash }}${{ inputs.cache-suffix }}
+      key: ${{ runner.os }}-torch_mlir_build_assets-${{ inputs.cache-suffix }}
+      max-size: 2G
+      verbose: 2
diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml
index 9b578b06bfc7..dde3d22d3424 100644
--- a/.github/workflows/bazelBuildAndTest.yml
+++ b/.github/workflows/bazelBuildAndTest.yml
@@ -2,26 +2,68 @@ name: Bazel Build and Test
 on:
   push:
-    branches:
-      - main
+    branches: [ main ]
+  workflow_dispatch:
+
+# Ensure that only a single job or workflow using the same
+# concurrency group will run at a time. This would cancel
+# any in-progress jobs in the same github workflow and github
+# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
-  build:
-    name: Build and Test (Release Asserts)
-    runs-on: ubuntu-20.04
+  ubuntu-build:
+    name: ubuntu-x86_64
+    runs-on: ubuntu-22.04
+
     steps:
-    - name: Set up Python
-      uses: actions/setup-python@v2
-      with:
-        python-version: 3.9
-    - name: Get torch-mlir
-      uses: actions/checkout@v2
+    - name: Checkout torch-mlir
+      uses: actions/checkout@v3
       with:
         submodules: 'true'
-    - name: Build with bazel
+
+    - name: Setup cache for bazel
+      uses: actions/cache@v3
+      with:
+        path: ~/.cache/bazel
+        key: ubuntu_x86_64_torch_mlir_bazel_build_cache
+
+    # Change bazel cache directory to root ownership
+    # to allow writing to it from within the docker container.
+    # If no cache hits, this directory is not present
+    # so don't run chown (will error otherwise).
+    - name: Set bazel cache permissions
       run: |
-        cd $GITHUB_WORKSPACE/utils/bazel
-        bazel build @torch-mlir//...
+        if [ -d "${HOME}/.cache/bazel" ]; then
+          sudo chown -R root:root "${HOME}/.cache/bazel"
+        fi
+
+    - name: Build docker image
+      run: |
+        docker build -f utils/bazel/docker/Dockerfile \
+                     -t torch-mlir:ci \
+                     .
+
+    - name: Bazel build torch-mlir
+      run: |
+        docker run --rm \
+                   -v "$(pwd)":"/opt/src/torch-mlir" \
+                   -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \
+                   torch-mlir:ci \
+                   ./utils/bazel/docker/run_bazel_build.sh
+
+    # Switch back bazel cache directory to user ownership
+    # to allow GHA post-cache step to save cache without
+    # permissions issue.
+ - name: Switch bazel cache permissions + run: | + if [ -d "${HOME}/.cache/bazel" ]; then + sudo chown -R "$USER":"$USER" "${HOME}/.cache/bazel" + fi + - name: Send mail if: failure() uses: dawidd6/action-send-mail@v3 diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 6440d370c3a3..a5e97b65ba89 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -1,102 +1,201 @@ name: Build and Test on: - push: - branches: - - main pull_request: + branches: [ main ] + push: + branches: [ main ] workflow_dispatch: +# Ensure that only a single job or workflow using the same +# concurrency group will run at a time. This would cancel +# any in-progress jobs in the same github workflow and github +# ref (e.g. refs/heads/main or refs/pull//merge). +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + + +# Provisioned Jobs: +# ubuntu - x86_64 - llvm in-tree - pytorch binary - build+test # most used dev flow and fastest signal +# ubuntu - x86_64 - llvm out-of-tree - pytorch source - build+test # most elaborate build +# macos - arm64 - llvm in-tree - pytorch binary - build only # cross compile, can't test arm64 jobs: - build: - name: Build and Test (Release Asserts) - # Changes to the name of this job needs to be synced with releaseSnapshotPackage.yml. - runs-on: ubuntu-20.04 + build-test: + strategy: + fail-fast: true + matrix: + os-arch: [ubuntu-x86_64, macos-arm64] + llvm-build: [in-tree, out-of-tree] + torch-binary: [ON, OFF] + exclude: + # Exclude llvm in-tree and pytorch source + - llvm-build: in-tree + torch-binary: OFF + # Exclude llvm out-of-tree and pytorch binary + - llvm-build: out-of-tree + torch-binary: ON + # Exclude macos-arm64 and llvm out-of-tree altogether + - os-arch: macos-arm64 + llvm-build: out-of-tree + include: + # Specify OS versions + - os-arch: ubuntu-x86_64 + os: ubuntu-22.04 + - os-arch: macos-arm64 + os: macos-12 + runs-on: ${{ matrix.os }} + steps: - - name: Get torch-mlir + - name: Checkout torch-mlir uses: actions/checkout@v2 with: submodules: 'true' - - uses: ./.github/actions/setup-build + + - name: Setup ccache + uses: ./.github/actions/setup-build with: - cache-suffix: '' - - name: Build and Test torch-mlir (Assert) + cache-suffix: ${{ matrix.os-arch }}-${{ matrix.llvm-build }}-${{ matrix.torch-binary }} + + - name: Configure os-arch='ubuntu-x86_64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' + # Fastest build, most used dev flow + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} run: | - cd $GITHUB_WORKSPACE - mkdir build - cd build - cmake $GITHUB_WORKSPACE/externals/llvm-project/llvm -GNinja \ + cmake -GNinja -Bbuild \ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_LINKER=lld \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ - -DPython3_EXECUTABLE=$(which python) \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ - -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \ + -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/externals/llvm-external-projects/torch-mlir-dialects" \ + 
-DLLVM_TARGETS_TO_BUILD=host \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DLLVM_TARGETS_TO_BUILD=host - ninja check-torch-mlir-all - - name: RefBackend - TorchScript end-to-end tests - run: | - cd $GITHUB_WORKSPACE - export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" - python -m e2e_testing.torchscript.main --config=refbackend -v - - name: EagerMode - TorchScript end-to-end tests - run: | - cd $GITHUB_WORKSPACE - export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" - python -m e2e_testing.torchscript.main --config=eager_mode -v - - name: TOSA backend - TorchScript end-to-end tests - run: | - cd $GITHUB_WORKSPACE - export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" - python -m e2e_testing.torchscript.main --config=tosa -v + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE/externals/llvm-project/llvm - build-out-of-tree: - name: Build out-of-tree (Release Asserts) - runs-on: ubuntu-20.04 - steps: - - name: Get torch-mlir - uses: actions/checkout@v2 - with: - submodules: 'true' - - uses: ./.github/actions/setup-build - with: - cache-suffix: '-out-of-tree' - - name: Build LLVM (standalone) - # This build takes a while but is expected to almost always be cached. - # A cache invalidation occurs when the committed LLVM version is changed. + - name: Configure os-arch='ubuntu-x86_64' llvm-build='out-of-tree' torch-binary='${{ matrix.torch-binary }}' + # Most elaborate build, but cached + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'out-of-tree' }} run: | - cd $GITHUB_WORKSPACE - cmake -Bllvm-build -GNinja \ + cmake -GNinja -Bllvm-build \ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_LINKER=lld \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ - -DPython3_EXECUTABLE=$(which python) \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_PROJECTS=mlir \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DLLVM_TARGETS_TO_BUILD=host \ - externals/llvm-project/llvm - ninja -Cllvm-build + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE/externals/llvm-project/llvm + cmake --build llvm-build - - name: Build and test torch-mlir (out-of-tree) - run: | - cd $GITHUB_WORKSPACE cmake -GNinja -Bbuild \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_LINKER=lld \ + -DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm/" \ + -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir/" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE + + - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' + # cross compile, can't test arm64 + if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }} + run: | + # TODO: Reenable LTC after build on macOS-arm64 is fixed (https://github.com/llvm/torch-mlir/issues/1253) + cmake -GNinja -Bbuild_arm64 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + 
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_LINKER=lld \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ - -DMLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir/" \ - -DLLVM_DIR="$(pwd)/llvm-build/lib/cmake/llvm/" \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ + -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/externals/llvm-external-projects/torch-mlir-dialects" \ + -DLLVM_TARGETS_TO_BUILD=AArch64 \ + -DLLVM_USE_HOST_TOOLS=ON \ + -DLLVM_ENABLE_ZSTD=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DPython3_EXECUTABLE=$(which python) \ - . - ninja -Cbuild check-torch-mlir-all + -DTORCH_MLIR_ENABLE_MHLO=OFF \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ + -DMACOSX_DEPLOYMENT_TARGET=12.0 \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE/externals/llvm-project/llvm + + - name: Build torch-mlir + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + cmake --build build + + - name: Build torch-mlir (cross-compile) + if: ${{ matrix.os-arch == 'macos-arm64' }} + run: | + cmake --build build_arm64 + + - name: Run torch-mlir unit tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + cmake --build build --target check-torch-mlir-all + + - name: Ensure generated files are up to date + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + ./build_tools/update_torch_ods.sh + ./build_tools/update_shape_lib.sh + if ! git diff --quiet; then + echo "#######################################################" + echo "Generated files are not up to date (see diff below)" + echo ">>> Please run ./build_tools/update_torch_ods.sh and ./build_tools/update_shape_lib.sh <<<" + echo "#######################################################" + git diff --color=always + exit 1 + fi + + - name: Run refbackend e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.main --config=refbackend -v - # Don't run python tests, as check-torch-mlir-all already checks - # what we want. + - name: Run eager_mode e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.main --config=eager_mode -v + + - name: Run mhlo e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.main --config=mhlo -v + + - name: Run tosa e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.main --config=tosa -v + + - name: Run lazy_tensor_core e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + echo "LTC tests disabled temporarily. 
https://github.com/llvm/torch-mlir/pull/1292" + # python -m e2e_testing.main --config=lazy_tensor_core -v diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 7eb2c9301ec7..eb493699d8b2 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -54,7 +54,7 @@ jobs: build_macos: name: MacOS Build - runs-on: macos-latest + runs-on: macos-12 steps: - name: Get torch-mlir uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index dc506413e504..330a871b0efb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.swp .cache/ .vscode +.ccache .env *.code-workspace .ipynb_checkpoints @@ -11,6 +12,7 @@ libtorch* /build/ __pycache__ +*.pyc .pytype @@ -22,3 +24,11 @@ __pycache__ # Bazel bazel-* + +# Autogenerated files +/python/torch_mlir/csrc/base_lazy_backend/generated + +#Docker builds +build_oot/ +docker_venv/ +llvm-build/ diff --git a/.gitmodules b/.gitmodules index 62a290ea6000..81c66a441907 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ -[submodule "external/llvm-project"] +[submodule "externals/llvm-project"] path = externals/llvm-project url = https://github.com/llvm/llvm-project.git +[submodule "externals/mlir-hlo"] + path = externals/mlir-hlo + url = https://github.com/tensorflow/mlir-hlo.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b0f48c87805..8f6d4d932d19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ endif() project(torch-mlir LANGUAGES CXX C) set(CMAKE_C_STANDARD 11) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) macro(torch_mlir_add_llvm_external_project name identifier location) message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}") @@ -36,12 +36,26 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() +option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) +if(TORCH_MLIR_ENABLE_MHLO) + add_definitions(-DTORCH_MLIR_ENABLE_MHLO) +endif() + +option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) + +if(TORCH_MLIR_ENABLE_LTC) + set(ENV{TORCH_MLIR_ENABLE_LTC} 1) +else() + set(ENV{TORCH_MLIR_ENABLE_LTC} 0) +endif() + torch_mlir_add_llvm_external_project( torch-mlir-dialects TORCH_MLIR_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + message(STATUS "Torch-MLIR out-of-tree build.") # Out-of-tree build #------------------------------------------------------------------------------- @@ -82,10 +96,8 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}") add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects) else() + message(STATUS "Torch-MLIR in-tree build.") # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir - # FIXME: This should really be inherited from the LLVM tree. In particular, - # it's going to change when cross-compiling. 
- set(MLIR_TABLEGEN_EXE mlir-tblgen) option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF) option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) @@ -97,6 +109,15 @@ else() set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() +if (TORCH_MLIR_ENABLE_MHLO) + set(MHLO_BUILD_EMBEDDED ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo + ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo + EXCLUDE_FROM_ALL) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include) + include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include) +endif() + set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})") diff --git a/README.md b/README.md index 260636e0e4e4..cd6f43b5c4c2 100644 --- a/README.md +++ b/README.md @@ -26,15 +26,15 @@ We have few paths to lower down to the Torch MLIR Dialect. - TorchScript This is the most tested path down to Torch MLIR Dialect, and the PyTorch ecosystem is converging on using TorchScript IR as a lingua franca. - - LazyTensorCore (Based on the PyTorch [`lazy_tensor_staging` branch](https://github.com/pytorch/pytorch/tree/lazy_tensor_staging/lazy_tensor_core)) - This path provides the upcoming LTC path of capture. It is based of an unstable devel branch but is the closest way for you to adapt any existing `torch/xla` derivatives. - + - LazyTensorCore + Read more details [here](docs/ltc_backend.md). ## Project Communication - `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel - Github issues [here](https://github.com/llvm/torch-mlir/issues) - [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse - Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information. +- Weekly op office hours on Thursdays 8:30-9:30AM PST. See [here](https://discourse.llvm.org/t/announcing-torch-mlir-office-hours/63973/2) for more information. ## Install torch-mlir snapshot @@ -71,10 +71,9 @@ torch-mlir prediction [('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)] ``` -### LazyTensorCore +### Lazy Tensor Core -The LazyTensorCore integration is still in progress, and is being built on the -[`torch_mlir_ltc_backend` branch](https://github.com/llvm/torch-mlir/tree/torch_mlir_ltc_backend). +View examples [here](docs/ltc_examples.md). 
 ### Eager Mode
@@ -94,4 +93,4 @@ The project follows the conventions of typical MLIR-based projects:
 * `python` top level directory for Python code
 
 ## Developers
-If you would like to develop and build torch-mlir from source please look at [Development Notes](development.md)
+If you would like to develop and build torch-mlir from source please look at [Development Notes](docs/development.md)
diff --git a/Torch-MLIR.png b/Torch-MLIR.png
index 1312806c205e..4a800e4e6167 100644
Binary files a/Torch-MLIR.png and b/Torch-MLIR.png differ
diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py
new file mode 100644
index 000000000000..55ec80f6d8a7
--- /dev/null
+++ b/build_tools/autogen_ltc_backend.py
@@ -0,0 +1,528 @@
+import argparse
+import hashlib
+import importlib.util
+import logging
+import os
+import re
+import subprocess
+import warnings
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from shutil import which
+from textwrap import dedent, indent
+
+# PyTorch's LTC backend autogen script
+import torchgen
+import torchgen.dest.lazy_ir
+import torchgen.gen_lazy_tensor
+import yaml
+from torchgen.api.lazy import LazyIrSchema, setValueT
+from torchgen.api.types import BaseCppType
+from torchgen.dest import GenLazyShapeInferenceDefinition
+from torchgen.gen import get_grouped_native_functions, parse_native_yaml
+from torchgen.gen_backend_stubs import parse_backend_yaml
+
+TORCH_DIR = Path(importlib.util.find_spec("torch").origin).resolve().parent.parent
+TORCH_INCLUDE_DIR = TORCH_DIR.joinpath("torch", "include")
+if not TORCH_INCLUDE_DIR.is_dir():
+    TORCH_INCLUDE_DIR = TORCH_DIR
+TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
+TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
+
+
+def reindent(text, prefix=""):
+    return indent(dedent(text), prefix)
+
+
+@dataclass(frozen=True)
+class GenMlirLazyIr(torchgen.dest.GenLazyIR):
+    def isOptionalCType(self, arg):
+        return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
+
+    def lowering_function(self, schema: LazyIrSchema):
+        signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"
+
+        if schema.properties.LowerDeclOnly:
+            return f"{signature};"
+        elif not schema.properties.Lower:
+            return ""
+
+        emplace_arguments = []
+        for arg in schema.positional_args:
+            if arg.is_lazy_value:
+                if self.isOptionalCType(arg.lazy_type):
+                    emplace_arguments.append(
+                        f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
+                    )
+                else:
+                    emplace_arguments.append("loctx->GetOutputOp(operand(i++))")
+            else:
+                emplace_arguments.append(f'"{arg.name}", {arg.name}')
+
+        emplace_arguments_str = "\n    ".join(
+            f"arguments.emplace_back({a});" for a in emplace_arguments
+        )
+        emplace_kwarg_values = [
+            f'"{t.name}", loctx->GetOutputOp(operand(i++))'
+            for t in schema.keyword_values
+        ]
+        emplace_kwarg_scalars = [
+            f'"{t.name}", {t.name}' for t in schema.keyword_scalars
+        ]
+        emplace_kwarguments = "\n    ".join(
+            f"kwarguments.emplace_back({a});"
+            for a in emplace_kwarg_values + emplace_kwarg_scalars
+        )
+
+        # Only create this variable if it's used to avoid Wunused-variable
+        operand_idx_counter = "size_t i = 0;" if "i++" in (emplace_arguments_str + emplace_kwarguments) else ""
+
+        return reindent(
+            f"""
+            {signature} {{
+                PRINT_FUNCTION();
+                std::vector<torch::jit::NamedValue> arguments;
+                std::vector<torch::jit::NamedValue> kwarguments;
+                arguments.reserve({len(emplace_arguments)});
+                kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
+                {operand_idx_counter}
+                {emplace_arguments_str}
+                {emplace_kwarguments}
+                torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
+                TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
+
+                return {schema.aten_name}_out;
+            }}
+            """,
+            "  ",
+        )
+
+
+class GenTorchMlirLTC:
+    def __init__(self, binary_dir):
+        self.script_path = Path(__file__).resolve()
+        self.config_path = (
+            Path(__file__).resolve().parent.joinpath("autogen_ltc_backend.yaml")
+        )
+        self.torch_ops_file = TORCH_MLIR_DIR.joinpath(
+            # fmt: off
+            "include", "torch-mlir", "Dialect", "Torch", "IR", "GeneratedTorchOps.td",
+            # fmt: on
+        )
+        assert self.torch_ops_file.exists()
+        self.binary_dir = Path(binary_dir)
+        assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
+        self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
+        self.backend_path = TORCH_MLIR_DIR.joinpath(
+            "python", "torch_mlir", "csrc", "base_lazy_backend"
+        )
+        assert self.backend_path.is_dir()
+        self.generated_path = self.binary_dir.joinpath(
+            "python", "torch_mlir", "csrc", "base_lazy_backend", "generated"
+        )
+        self.generated_path.mkdir(parents=True, exist_ok=True)
+
+        # Create symlink to match doc structure
+        generated_path = self.backend_path.joinpath("generated").resolve()
+        if not generated_path.exists():
+            generated_path.symlink_to(
+                os.path.relpath(self.generated_path, generated_path.parent),
+                target_is_directory=True,
+            )
+
+        self.tensor_class = "torch::lazy::LazyTensor"
+
+        # Set the lazy value class
+        setValueT(BaseCppType("torch::lazy", "Value"))
+
+    def calculate_hash(self):
+        m = hashlib.sha256()
+
+        # Add file contents to hash
+        for path in (
+            self.script_path,
+            self.config_path,
+            self.torch_ops_file,
+            self.source_yaml,
+            self.backend_path.joinpath("shape_inference.cpp"),
+            TORCHGEN_DIR.joinpath("dest", "lazy_ir.py"),
+            TORCHGEN_DIR.joinpath("api", "lazy.py"),
+            TORCHGEN_DIR.joinpath("model.py"),
+        ):
+            if path.exists():
+                m.update(path.read_bytes())
+
+        return m.hexdigest().strip()
+
+    def generate_native_functions(self):
+        logging.info("Generating Native Functions Yaml")
+
+        native_path = TORCHGEN_DIR.joinpath("packaged", "ATen", "native")
+        native_yaml_path = native_path.joinpath("native_functions.yaml")
+        tags_yaml_path = native_path.joinpath("tags.yaml")
+
+        ts_native_yaml_path = TORCH_DIR.joinpath(
+            "aten", "src", "ATen", "native", "ts_native_functions.yaml"
+ ) + ts_native_yaml = None + if ts_native_yaml_path.exists(): + ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) + else: + logging.warning(f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}") + + + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path) + self.native_functions = parsed_yaml.native_functions + self.backend_indices = parsed_yaml.backend_indices + self.grouped_native_functions = get_grouped_native_functions( + self.native_functions + ) + + def get_native_function_name(f): + func = f if hasattr(f, "func") else f.functional + return str(func.func.name) + + self.native_functions = { + get_native_function_name(f): f for f in self.native_functions + } + + def get_opnames(ops): + opnames = defaultdict(set) + for op in ops: + opname = op.split(".")[0] + opnames[opname].add(op) + return opnames + + aten_funcs = get_opnames( + map(get_native_function_name, self.grouped_native_functions) + ) + + with self.config_path.open() as f: + config = yaml.load(f, yaml.CLoader) + + # List of unsupported ops in LTC autogen because of some error + blacklist = set(config.get("blacklist", [])) + + # List of supported ops that we don't want to do the full codegen for + # primarily view ops + supported = set(config.get("supported", [])) + + # List of non-native ops to do IR codegen for + non_native = config.get("non_native", []) + + # use ripgrep if available as its much faster + if which("rg") is not None: + cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"] + else: + cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"] + + torch_ops = set( + op[6:] + for op in subprocess.check_output( + cmd + [str(self.torch_ops_file)], + encoding="utf-8", + ) + .strip() + .split(os.linesep) + ) + torch_opnames = get_opnames(torch_ops) + + # process ops list + ops = set() + composite_implicit = set() + + for op in torch_ops: + if op not in self.native_functions: + continue + + func = self.native_functions[op] + base = func.func.name.name.base + + if base in blacklist or op in blacklist: + continue + if base in supported or op in supported: + continue + # Blacklist new_/_like ops since they are non-differentiable. 
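+            # (e.g. "new_zeros" matches the "new_" prefix and "zeros_like"
+            # matches the "_like" suffix, so both would be skipped here)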
+            if any(o.startswith("new_") or o.endswith("_like") for o in (base, op)):
+                continue
+
+            if func.has_composite_implicit_autograd_kernel:
+                composite_implicit.add(op)
+            elif func.func.name.name.inplace:
+                for autogen in func.autogen:
+                    if "functional" in autogen.overload_name:
+                        ops.add(str(autogen))
+            else:
+                ops.add(op)
+
+        skipped = set(torch_ops) - ops - supported - composite_implicit
+
+        # List of ops to autogen even if not supported by Torch-MLIR explicitly
+        ops |= set(config.get("whitelist", []))
+
+        # Additional ops to support that are not supported by Torch-MLIR explicitly
+        supported |= set(config.get("additional_ops", []))
+
+        self.ops = sorted(ops)
+
+        with self.source_yaml.open("w") as f:
+            source_yaml = {
+                "backend": "Lazy",
+                "cpp_namespace": "torch::lazy",
+                "full_codegen": self.ops,
+                "supported": sorted(supported),
+                "non_native": non_native,
+            }
+            yaml.dump(source_yaml, f, default_flow_style=False)
+            f.write(
+                dedent(
+                    """
+
+                    # Composite implicit ops (supported by Torch-MLIR but not differentiable)
+                    {composite_implicit}
+                    # Skipped ops (supported by Torch-MLIR but no equivalent native function)
+                    {skipped}
+                    """
+                ).format(
+                    composite_implicit=os.linesep.join(
+                        f"# - {op}" for op in sorted(composite_implicit)
+                    ),
+                    skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)),
+                )
+            )
+
+        if ts_native_yaml:
+            ts_full_codegen = set(ts_native_yaml["full_codegen"])
+            ts_supported = set(ts_native_yaml["supported"])
+            mlir_full_codegen = set(self.ops)
+
+            if ts_full_codegen - mlir_full_codegen:
+                logging.debug(
+                    "Full Codegen ops supported by the TorchScript backend "
+                    "but not by the Torch-MLIR backend:\n    {}".format(
+                        "\n    ".join(sorted(ts_full_codegen - mlir_full_codegen))
+                    )
+                )
+
+            if mlir_full_codegen - ts_full_codegen:
+                logging.debug(
+                    "Full Codegen ops supported by the Torch-MLIR backend "
+                    "but not by the TorchScript backend:\n    {}".format(
+                        "\n    ".join(sorted(mlir_full_codegen - ts_full_codegen))
+                    )
+                )
+
+            if ts_supported - supported:
+                logging.debug(
+                    "Ops supported by the TorchScript backend "
+                    "but not by the Torch-MLIR backend:\n    {}".format(
+                        "\n    ".join(sorted(ts_supported - supported))
+                    )
+                )
+
+            if supported - ts_supported:
+                logging.debug(
+                    "Ops supported by the Torch-MLIR backend "
+                    "but not by the TorchScript backend:\n    {}".format(
+                        "\n    ".join(sorted(supported - ts_supported))
+                    )
+                )
+
+    def generate_shape_inference(self):
+        parsed_backend_yaml = parse_backend_yaml(
+            self.source_yaml,
+            self.grouped_native_functions,
+            self.backend_indices,
+        )
+        backend_index = self.backend_indices[parsed_backend_yaml.backend_key]
+
+        shape_gen = GenLazyShapeInferenceDefinition(backend_index, self.tensor_class)
+
+        sig_re = re.compile(
+            r"std::vector<torch::lazy::Shape>\s+(?P<name>\w+)\((?P<args>[^\)]+)\)"
+        )
+        global_signatures = {}
+
+        def extract_signatures(text):
+            signatures = set()
+            for name, args in sig_re.findall(text):
+                signature = re.sub(r"\s+", "", f"{name}({args})")
+                global_signatures[signature] = (name, args)
+                signatures.add(signature)
+            return signatures
+
+        shape_inference_decls = []
+        for op in self.ops:
+            f = self.native_functions[op]
+            shape_sig = shape_gen(f)
+            shape_inference_decls.extend(shape_sig)
+
+        self.generated_path.joinpath("shape_inference.h").write_text(
+            dedent(
+                """
+                // This file contains autogenerated Lazy Shape Inference declarations
+                // for ops that dont have a corresponding structured kernel or shape definition
+
+                #include <ATen/Tensor.h>
+                #include <c10/core/ScalarType.h>
+                #include <c10/util/Optional.h>
+                #include <torch/csrc/lazy/core/ir.h>
+                #include <torch/csrc/lazy/core/shape.h>
+                #include <torch/csrc/lazy/core/shape_inference.h>
+                #include <vector>
+
+                namespace torch {{
+                namespace lazy {{
+
+                {}
+
+                }} // namespace lazy
+                }} // namespace torch
+                """
+            ).format(os.linesep.join(sorted(shape_inference_decls)))
+        )
+
+        shape_inference_decls = extract_signatures(
+            self.generated_path.joinpath("shape_inference.h").read_text()
+        )
+        assert len(shape_inference_decls) > 0
+        upstream_shape_inference_decls = extract_signatures(
+            TORCH_INCLUDE_DIR.joinpath(
+                "torch", "csrc", "lazy", "core", "shape_inference.h"
+            ).read_text()
+        )
+        assert len(upstream_shape_inference_decls) > 0
+        shape_inference_defs = extract_signatures(
+            self.backend_path.joinpath("shape_inference.cpp").read_text()
+        )
+        assert len(shape_inference_decls) > len(shape_inference_defs)
+
+        missing_defs = (
+            shape_inference_decls
+            - upstream_shape_inference_decls
+            - shape_inference_defs
+        )
+        if missing_defs:
+            self.generated_path.joinpath("shape_inference.cpp").write_text(
+                dedent(
+                    """
+                    // This file contains autogenerated Lazy Shape Inference placeholders
+                    // for ops that dont have a corresponding structured kernel or shape definition
+
+                    #include "shape_inference.h"
+                    #include "torch_mlir/csrc/base_lazy_backend/utils/exception.h"
+
+                    namespace torch {{
+                    namespace lazy {{
+                    {}
+                    }} // namespace lazy
+                    }} // namespace torch
+                    """
+                ).format(
+                    "".join(
+                        dedent(
+                            f"""
+                            std::vector<torch::lazy::Shape> {name}({args}) {{
+                                UNIMPLEMENTED_FUNCTION_ERROR();
+                            }}
+                            """
+                        )
+                        for name, args in map(
+                            global_signatures.get, sorted(missing_defs)
+                        )
+                    )
+                )
+            )
+
+        unnecessary_defs = shape_inference_defs - shape_inference_decls
+        if unnecessary_defs:
+            unnecessary_defs = "\n\t".join(
+                f"{name}({args})"
+                for name, args in map(global_signatures.get, unnecessary_defs)
+            )
+            warnings.warn(
+                f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
+            )
+
+    def generate_backend(self):
+        logging.info("Running Lazy Tensor Autogen")
+
+        # No fallback code allowed
+        def gen_fallback_code(*args, **kwargs):
+            return ""
+
+        torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
+
+        torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
+            backend_name="TorchMlir",
+            aten_path=str(TORCHGEN_DIR.joinpath("packaged", "ATen")),
+            source_yaml=str(self.source_yaml),
+            output_dir=str(self.generated_path),
+            dry_run=False,
+            impl_path=str(self.backend_path.joinpath("mlir_native_functions.cpp")),
+            node_base="torch::lazy::TorchMlirNode",
+            node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
+            tensor_class=self.tensor_class,
+            tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
+            shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
+            lazy_ir_generator=GenMlirLazyIr,
+        )
+
+    def __call__(self):
+        self.generate_native_functions()
+        self.generate_shape_inference()
+        self.generate_backend()
+
+
+def main(args):
+    generator = GenTorchMlirLTC(args.binary_dir)
+
+    hash_file = generator.binary_dir.joinpath("generated_backend.hash")
+
+    prev_hash = None
+    if hash_file.exists():
+        prev_hash = hash_file.read_text().strip()
+
+    new_hash = generator.calculate_hash()
+
+    if args.force or new_hash != prev_hash:
+        generator()
+        hash_file.write_text(new_hash)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-b",
+        "--binary_dir",
+        type=str,
+        default=os.getenv(
+            "TORCH_MLIR_BINARY_DIR",
+            TORCH_MLIR_DIR.joinpath("build"),
+        ),
+    )
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+    )
+    parser.add_argument(
+        "-d",
+        "--debug",
+        help="Print lots of debugging statements",
+        action="store_const",
+        dest="loglevel",
+        const=logging.DEBUG,
+        default=logging.WARNING,
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        help="Be verbose",
+        action="store_const",
+        dest="loglevel",
+        const=logging.INFO,
+    )
+    args = parser.parse_args()
+    logging.basicConfig(level=args.loglevel)
+    main(args)
diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
new file mode 100644
index 000000000000..dde1f0f014c2
--- /dev/null
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -0,0 +1,88 @@
+blacklist:
+# List of unsupported ops in LTC autogen because of some error
+- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
+- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat>)
+- index.Tensor # Error: TODO not sure if there are other valid types to handle here
+- index_put # Error: TODO not sure if there are other valid types to handle here
+- index_put_ # Error: TODO not sure if there are other valid types to handle here
+- stack # Error: TODO not sure if there are other valid types to handle here
+
+# Additional ops which autogen is supported for but don't compile yet
+- _convolution
+- detach
+- item
+- size
+- where
+- copy_
+
+# Disabled for consistency with TS backend
+- new_empty
+- rsub
+- slice.Tensor # Disabled in favour of slice_copy.Tensor
+- zeros
+
+# Disabled in favour of functionalized alternatives
+- _reshape_alias
+- expand
+- permute
+- select.int
+- squeeze
+- squeeze.dim
+- t
+- transpose.int
+- unsqueeze
+- view
+
+# whitelist:
+# List of ops to autogen even if not supported by Torch-MLIR explicitly
+#- split_copy.Tensor
+#- split_with_sizes_copy
+#- unbind_copy.int
+
+# List of supported ops that we don't want to do the full codegen for
+supported:
+# - bernoulli
+# - bernoulli_
+- _to_copy
+- clone
+- empty.memory_format
+- empty_strided
+- fill_.Scalar
+- _unsafe_view
+
+# ops required for functionalization
+- lift
+- lift_fresh
+# Below are all operators that are "composite" in core,
+# but require us to explicitly re-enable functionalization in order to use them.
+# Why? These operators are all CompositeExplicitAutograd, which means that they run
+# after functionalization,
+# but their implementations call view operators (which we need to functionalize away).
+- block_diag
+- new_empty_strided
+- pixel_shuffle
+- pixel_unshuffle
+- select_backward
+- slice_backward
+- diagonal_backward
+- _trilinear
+- linalg_pinv.atol_rtol_tensor
+- logsumexp.out
+
+
+additional_ops:
+# Additional ops to support that are not supported by Torch-MLIR explicitly
+- _copy_from
+- _copy_from_and_resize
+
+# List of non native ops that we only want to do IR node class generation for
+non_native:
+  - func: scalar(Scalar value, ScalarType type) -> Tensor
+    opkind: at::prim::Constant
+    properties:
+      - ShapeCompute
+      - TreatScalarsAsConstants
+  - func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
+    opkind: ltc_cast
+    properties:
+      - ShapeCompute
diff --git a/build_tools/build_libtorch.sh b/build_tools/build_libtorch.sh
index aa859eeaf795..bc3ec12bb4b1 100755
--- a/build_tools/build_libtorch.sh
+++ b/build_tools/build_libtorch.sh
@@ -5,11 +5,12 @@ set -xeu -o pipefail
 SRC_ROOT="$( cd "$(dirname "$0")" ; pwd -P)/.."
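+# Each variable below uses the ${VAR:-default} pattern, so any of them can be
+# overridden from the environment, e.g. (hypothetical paths):
+#   PYTORCH_ROOT=/tmp/pytorch WHEELHOUSE=/tmp/wheels ./build_tools/build_libtorch.sh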
PYTORCH_ROOT=${PYTORCH_ROOT:-$SRC_ROOT/externals/pytorch} PYTORCH_INSTALL_PATH=${PYTORCH_INSTALL_PATH:-$SRC_ROOT/libtorch} -PYTORCH_REPO="${PYTORCH_REPO:-pytorch/pytorch}" -PYTORCH_BRANCH="${PYTORCH_BRANCH:-master}" +TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}" +TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}" PT_C_COMPILER="${PT_C_COMPILER:-clang}" PT_CXX_COMPILER="${PT_CXX_COMPILER:-clang++}" -CMAKE_OSX_ARCHITECTURES="${CMAKE_OSX_ARCHITECTURES:-arm64;x86_64}" +CMAKE_OSX_ARCHITECTURES="${CMAKE_OSX_ARCHITECTURES:-x86_64}" +MACOSX_DEPLOYMENT_TARGET="${MACOSX_DEPLOYMENT_TARGET:-12.0}" WHEELHOUSE="${WHEELHOUSE:-$SRC_ROOT/build_tools/python_deploy/wheelhouse}" PYTHON_BIN="${TORCH_MLIR_PYTHON_VERSION:-python3}" PIP_BIN="${TORCH_MLIR_PIP_VERSION:-pip3}" @@ -23,10 +24,13 @@ NC='\033[0m' echo "SRC_ROOT=${SRC_ROOT}" echo "PYTORCH_ROOT=${PYTORCH_ROOT}" -echo "PYTORCH_REPO=${PYTORCH_REPO}" -echo "PYTORCH_BRANCH=${PYTORCH_BRANCH}" +echo "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" +echo "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" +echo "MACOSX_DEPLOYMENT_TARGET=${MACOSX_DEPLOYMENT_TARGET}" echo "CMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}" + export CMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES} +export MACOSX_DEPLOYMENT_TARGET=${MACOSX_DEPLOYMENT_TARGET} export CMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} export CMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} @@ -41,16 +45,12 @@ install_requirements() { checkout_pytorch() { if [[ ! -d "$PYTORCH_ROOT" ]]; then - git clone --depth 1 --single-branch --branch "${PYTORCH_BRANCH}" https://github.com/"$PYTORCH_REPO" "$PYTORCH_ROOT" + git clone --depth 1 --single-branch --branch "${TORCH_MLIR_SRC_PYTORCH_BRANCH}" https://github.com/"$TORCH_MLIR_SRC_PYTORCH_REPO" "$PYTORCH_ROOT" fi cd "$PYTORCH_ROOT" git reset --hard HEAD git clean -df - for dep in protobuf pocketfft cpuinfo FP16 psimd fmt sleef pybind11 onnx flatbuffers foxi; do - git submodule update --init --depth 1 -- third_party/$dep - done - # setup.py will try to re-fetch - sed -i.bak -E 's/^[[:space:]]+check_submodules()/#check_submodules()/g' setup.py + git submodule update --init --depth 1 --recursive } build_pytorch() { @@ -68,9 +68,14 @@ build_pytorch() { fi BUILD_SHARED_LIBS=ON \ + BUILD_CAFFE2_OPS=OFF \ + INTERN_BUILD_ATEN_OPS=OFF \ + ATEN_NO_TEST=OFF \ + USE_LITE_INTERPRETER_PROFILER=OFF \ BUILD_TEST=OFF \ GLIBCXX_USE_CXX11_ABI=1 \ CMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES} \ + MACOSX_DEPLOYMENT_TARGET=${MACOSX_DEPLOYMENT_TARGET} \ INTERN_BUILD_ATEN_OPS=OFF \ INTERN_DISABLE_ONNX=ON \ INTERN_USE_EIGEN_BLAS=ON \ @@ -78,11 +83,12 @@ build_pytorch() { ONNX_ML=OFF \ USE_BREAKPAD=OFF \ USE_CUDA=OFF \ + USE_ITT=OFF \ USE_DISTRIBUTED=OFF \ USE_EIGEN_FOR_BLAS=OFF \ - USE_FBGEMM=OFF \ + USE_FBGEMM=ON \ USE_GLOO=OFF \ - USE_KINETO=OFF \ + USE_KINETO=ON \ USE_MKL=OFF \ USE_MKLDNN=OFF \ USE_MPS=OFF \ @@ -90,9 +96,10 @@ build_pytorch() { USE_NNPACK=OFF \ USE_OBSERVERS=OFF \ USE_OPENMP=OFF \ - USE_PYTORCH_QNNPACK=OFF \ + USE_PYTORCH_QNNPACK=ON \ USE_QNNPACK=OFF \ USE_XNNPACK=OFF \ + USE_PRECOMPILED_HEADERS=1 \ ${PYTHON_BIN} setup.py bdist_wheel -d "$WHEELHOUSE" } @@ -123,10 +130,23 @@ install_pytorch() { ${PIP_BIN} install --force-reinstall $WHEELHOUSE/* } +unpack_pytorch() { + PYTHON_SITE=`${PYTHON_BIN} -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'` + echo "wheel unpacking Pytorch..into ${PYTHON_SITE}" + wheel unpack -d "$WHEELHOUSE"/unpack_tmp 
"$WHEELHOUSE"/*.whl + mv "$WHEELHOUSE"/unpack_tmp/* "$PYTHON_SITE"/ +} + #main echo "Building libtorch from source" checkout_pytorch install_requirements build_pytorch package_pytorch -install_pytorch +if [[ $CMAKE_OSX_ARCHITECTURES = "arm64" ]]; then + echo "${Yellow} Cross compiling for arm64 so unpacking PyTorch wheel for libs${NC}" + unpack_pytorch +else + echo "${Green} Installing the built PyTorch wheel ${NC}" + install_pytorch +fi diff --git a/build_tools/docker/Dockerfile b/build_tools/docker/Dockerfile new file mode 100644 index 000000000000..1027a11416f2 --- /dev/null +++ b/build_tools/docker/Dockerfile @@ -0,0 +1,54 @@ +ARG BASE_IMG=ubuntu:22.04 +FROM ${BASE_IMG} as dev-base + +# Disable apt-key parse waring. If someone knows how to do whatever the "proper" +# thing is then feel free. The warning complains about parsing apt-key output, +# which we're not even doing. +ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1 + +ARG ARCH="x86_64" +ARG REPO_NAME="deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy main" +RUN apt-get update && \ + apt-get install -y \ + ca-certificates \ + software-properties-common \ + wget \ + apt-transport-https \ + ccache \ + curl \ + cmake \ + ninja-build \ + git \ + gnupg \ + lsb-release \ + python3-pip \ + python3.10 \ + python3.10-dev \ + python3.10-venv \ + unzip && \ + echo $REPO_NAME >> /etc/apt/sources.list.d/llvm.list && \ + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key| apt-key add - && \ + apt-get update && \ + apt-get install -y \ + clang \ + lld + +######## Bazel ######## +WORKDIR /install-bazel +ARG BAZEL_VERSION=5.2.0 + +# https://bazel.build/install/ubuntu +RUN curl -fsSL https://bazel.build/bazel-release.pub.gpg \ + | gpg --dearmor >bazel-archive-keyring.gpg \ + && mv bazel-archive-keyring.gpg /usr/share/keyrings \ + && echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" \ + | tee /etc/apt/sources.list.d/bazel.list \ + && apt-get update \ + && apt-get install -y "bazel=${BAZEL_VERSION?}" \ + && rm -rf /install-bazel + +### Clean up +RUN apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /main_checkout/torch-mlir diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 958b63dff165..f3b9898a9f14 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -16,22 +16,21 @@ # ./build_tools/python_deploy/build_linux_packages.sh # # Build specific Python versions and packages to custom directory: -# python_versions="cp38-cp38 cp39-cp39" \ -# packages="torch-mlir" \ -# output_dir="/tmp/wheelhouse" \ +# TM_PYTHON_VERSIONS="cp38-cp38 cp39-cp39" \ +# TM_PACKAGES="torch-mlir" \ +# TM_OUTPUT_DIR="/tmp/wheelhouse" \ # ./build_tools/python_deploy/build_linux_packages.sh # # Valid Python versions match a subdirectory under /opt/python in the docker # image. Typically: -# cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310 +# cp38-cp38 cp39-cp39 cp310-cp310 # # Valid packages: -# torch-mlir +# torch-mlir, in-tree, out-of-tree # # Note that this script is meant to be run on CI and it will pollute both the -# output directory and in-tree build/ directories (under runtime/ and -# iree/compiler/) with docker created, root owned builds. Sorry - there is -# no good way around it. +# output directory and in-tree build/ directories with docker created, root owned builds. +# Sorry - there is no good way around it but TODO: move to using user UID/GID. 
# # It can be run on a workstation but recommend using a git worktree dedicated # to packaging to avoid stomping on development artifacts. @@ -39,54 +38,117 @@ set -eu -o errtrace this_dir="$(cd "$(dirname "$0")" && pwd)" repo_root="$(cd "$this_dir"/../../ && pwd)" -manylinux_docker_image="${manylinux_docker_image:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}" -python_versions="${TM_PYTHON_VERSIONS:-cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310}" -output_dir="${output_dir:-${this_dir}/wheelhouse}" -packages="${packages:-torch-mlir}" +# This needs to be a manylinux image so we can ship pip packages +TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}" +# This assumes an Ubuntu LTS like image. You can build your own with +# ./build_tools/docker/Dockerfile +TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}" +# Version of Python to use in Release builds. Ignored in CIs. +TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp39-cp39 cp310-cp310}" +# Location to store Release wheels +TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}" +# What "packages to build" +TM_PACKAGES="${TM_PACKAGES:-torch-mlir out-of-tree in-tree}" +# Use pre-built Pytorch +TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" +# Skip running tests if you want quick iteration +TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}" PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE" export TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}" echo "Setting torch-mlir Python Package version to: ${TORCH_MLIR_PYTHON_PACKAGE_VERSION}" function run_on_host() { - echo "Running on host" - echo "Launching docker image ${manylinux_docker_image}" - echo "Outputting to ${output_dir}" - rm -rf "${output_dir}" - mkdir -p "${output_dir}" + echo "Running on host for $1:$@" + echo "Outputting to ${TM_OUTPUT_DIR}" + rm -rf "${TM_OUTPUT_DIR}" + mkdir -p "${TM_OUTPUT_DIR}" + case "$package" in + torch-mlir) + TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} + export USERID=0 + export GROUPID=0 + ;; + out-of-tree) + TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE} + # CI uses only Python3.10 + TM_PYTHON_VERSIONS="cp310-cp310" + export USERID=$(id -u) + export GROUPID=$(id -g) + ;; + in-tree) + TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE} + # CI uses only Python3.10 + TM_PYTHON_VERSIONS="cp310-cp310" + export USERID=$(id -u) + export GROUPID=$(id -g) + ;; + *) + echo "Unrecognized package '$package'" + exit 1 + ;; + esac + echo "Launching docker image ${TM_CURRENT_DOCKER_IMAGE} with UID:${USERID} GID:${GROUPID}" docker run --rm \ -v "${repo_root}:/main_checkout/torch-mlir" \ - -v "${output_dir}:/wheelhouse" \ + -v "${TM_OUTPUT_DIR}:/wheelhouse" \ + -v "${HOME}:/home/${USER}" \ + --user ${USERID}:${GROUPID} \ + --workdir="/home/$USER" \ + --volume="/etc/group:/etc/group:ro" \ + --volume="/etc/passwd:/etc/passwd:ro" \ + --volume="/etc/shadow:/etc/shadow:ro" \ + --ipc=host \ + --ulimit nofile=32768:32768 \ -e __MANYLINUX_BUILD_WHEELS_IN_DOCKER=1 \ -e "TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION}" \ - -e "python_versions=${python_versions}" \ - -e "packages=${packages}" \ - "${manylinux_docker_image}" \ - -- bash /main_checkout/torch-mlir/build_tools/python_deploy/build_linux_packages.sh + -e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \ + -e "TM_PACKAGES=${package}" \ + -e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \ + -e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \ + -e 
"CCACHE_DIR=/main_checkout/torch-mlir/.ccache" \ + "${TM_CURRENT_DOCKER_IMAGE}" \ + /bin/bash /main_checkout/torch-mlir/build_tools/python_deploy/build_linux_packages.sh } function run_in_docker() { echo "Running in docker" - echo "Using python versions: ${python_versions}" + echo "Using python versions: ${TM_PYTHON_VERSIONS}" local orig_path="$PATH" # Build phase. - for package in $packages; do - echo "******************** BUILDING PACKAGE ${package} ********************" - for python_version in $python_versions; do + for package in $TM_PACKAGES; do + echo "******************** BUILDING PACKAGE ${package} (docker) ************" + for python_version in $TM_PYTHON_VERSIONS; do python_dir="/opt/python/$python_version" if ! [ -x "$python_dir/bin/python" ]; then - echo "ERROR: Could not find python: $python_dir (skipping)" - continue + echo "Could not find python: $python_dir (using system default Python3)" + python_dir=`which python3` + echo "Defaulting to $python_dir (expected for CI builds)" fi export PATH=$python_dir/bin:$orig_path - echo ":::: Python version $(python --version)" + echo ":::: Python version $(python3 --version)" case "$package" in torch-mlir) clean_wheels torch_mlir "$python_version" build_torch_mlir #run_audit_wheel torch_mlir "$python_version" + clean_build torch_mlir "$python_version" + ;; + out-of-tree) + setup_venv "$python_version" + build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" + if [ "${TM_SKIP_TESTS}" == "OFF" ]; then + test_out_of_tree + fi + ;; + in-tree) + setup_venv "$python_version" + build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version" + if [ "${TM_SKIP_TESTS}" == "OFF" ]; then + test_in_tree; + fi ;; *) echo "Unrecognized package '$package'" @@ -97,6 +159,160 @@ function run_in_docker() { done } + +function build_in_tree() { + local torch_from_src="$1" + local python_version="$2" + echo ":::: Build in-tree Torch from source: $torch_from_src with Python: $python_version" + cmake -GNinja -B/main_checkout/torch-mlir/build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_LINKER=lld \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="/main_checkout/torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="/main_checkout/torch-mlir/externals/llvm-external-projects/torch-mlir-dialects" \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_src" \ + -DPython3_EXECUTABLE="$(which python3)" \ + /main_checkout/torch-mlir/externals/llvm-project/llvm + cmake --build /main_checkout/torch-mlir/build + ccache -s +} + +function _check_file_not_changed_by() { + # _check_file_not_changed_by + cmd="$1" + file="$2" + file_backup="$PWD/$(basename $file)" + file_new="$PWD/$(basename $file).new" + # Save the original file. + cp "$file" "$file_backup" + # Run the command to regenerate it. + "$1" || return 1 + # Save the new generated file. + cp "$file" "$file_new" + # Restore the original file. We want this function to not change the user's + # working tree state. + mv "$file_backup" "$file" + # We use git-diff as "just a diff program" (no SCM stuff) because it has + # nicer output than regular `diff`. + if ! 
git diff --quiet "$file" "$file_new"; then + echo "#######################################################" + echo "Generated file '${file}' is not up to date (see diff below)" + echo ">>> Please run '${cmd}' to update it <<<" + echo "#######################################################" + git diff --color=always "$file" "$file_new" + # TODO: Is there a better cleanup strategy that doesn't require duplicating + # this inside and outside the `if`? + rm "$file_new" + rm "$file_backup" + return 1 + fi + rm "$file_new" + rm "$file_backup" +} + +function test_in_tree() { + echo ":::: Test in-tree" + cmake --build /main_checkout/torch-mlir/build --target check-torch-mlir-all + + cd /main_checkout/torch-mlir/ + export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir" + + echo ":::: Check that update_shape_lib.sh has been run" + _check_file_not_changed_by ./build_tools/update_shape_lib.sh lib/Dialect/Torch/Transforms/ShapeLibrary.cpp + + echo ":::: Check that update_torch_ods.sh has been run" + _check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td + + echo ":::: Run refbackend e2e integration tests" + python -m e2e_testing.main --config=refbackend -v + + echo ":::: Run eager_mode e2e integration tests" + python -m e2e_testing.main --config=eager_mode -v + + echo ":::: Run TOSA e2e integration tests" + python -m e2e_testing.main --config=tosa -v + + # Temporarily disabled in top of main (https://github.com/llvm/torch-mlir/pull/1292) + #echo ":::: Run Lazy Tensor Core e2e integration tests" + #python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v +} + +function setup_venv() { + local python_version="$1" + echo ":::: Setting up VENV with Python: $python_version" + python3 -m venv /main_checkout/torch-mlir/docker_venv + source /main_checkout/torch-mlir/docker_venv/bin/activate + + echo ":::: pip installing dependencies" + python3 -m pip install -r /main_checkout/torch-mlir/externals/llvm-project/mlir/python/requirements.txt + python3 -m pip install -r /main_checkout/torch-mlir/requirements.txt + +} + +function build_out_of_tree() { + local torch_from_src="$1" + local python_version="$2" + echo ":::: Build out-of-tree Torch from source: $torch_from_src with Python: $python_version" + + if [ ! -d "/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" ] + then + echo ":::: LLVM / MLIR is not built so building it first.." + cmake -GNinja -B/main_checkout/torch-mlir/llvm-build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_LINKER=lld \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="$(which python3)" \ + /main_checkout/torch-mlir/externals/llvm-project/llvm + cmake --build /main_checkout/torch-mlir/llvm-build + fi + + # Incremental builds come here directly and can run cmake if required. 
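+  # Assuming the llvm-build directory survives between runs (e.g. on a
+  # workstation checkout), re-invoking with TM_PACKAGES="out-of-tree" skips
+  # the LLVM build guarded above and only reconfigures/rebuilds torch-mlir.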
+ cmake -GNinja -B/main_checkout/torch-mlir/build_oot \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_LINKER=lld \ + -DLLVM_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/llvm/" \ + -DMLIR_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_src" \ + -DPython3_EXECUTABLE="$(which python3)" \ + /main_checkout/torch-mlir + cmake --build /main_checkout/torch-mlir/build_oot + ccache -s +} + +function test_out_of_tree() { + echo ":::: Test out-of-tree" + cmake --build /main_checkout/torch-mlir/build_oot --target check-torch-mlir-all +} + +function clean_build() { + # clean up for recursive runs + local package="$1" + local python_version="$2" + echo ":::: Clean build dir $package $python_version" + rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch +} + function build_torch_mlir() { python -m pip install -r /main_checkout/torch-mlir/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu CMAKE_GENERATOR=Ninja \ @@ -123,7 +339,10 @@ function clean_wheels() { # Trampoline to the docker container if running on the host. if [ -z "${__MANYLINUX_BUILD_WHEELS_IN_DOCKER-}" ]; then - run_on_host "$@" + for package in $TM_PACKAGES; do + echo "******************** BUILDING PACKAGE ${package} (host) *************" + run_on_host "${package} $@" + done else run_in_docker "$@" fi diff --git a/build_tools/python_deploy/build_macos_packages.sh b/build_tools/python_deploy/build_macos_packages.sh index b97c791a4a79..18606a0c2263 100755 --- a/build_tools/python_deploy/build_macos_packages.sh +++ b/build_tools/python_deploy/build_macos_packages.sh @@ -35,6 +35,10 @@ export CMAKE_OSX_ARCHITECTURES="${TORCH_MLIR_OSX_ARCH:-arm64;x86_64}" echo "CMAKE_OSX_ARCHITECTURES: $CMAKE_OSX_ARCHITECTURES" echo "MACOSX_DEPLOYMENT_TARGET $MACOSX_DEPLOYMENT_TARGET" +# Disable LTC build on MacOS to avoid linkage issues +# https://github.com/llvm/torch-mlir/issues/1253 +export TORCH_MLIR_ENABLE_LTC=0 + function run() { echo "Using python versions: ${python_versions}" diff --git a/build_tools/torchscript_e2e_heavydep_tests/README.md b/build_tools/torchscript_e2e_heavydep_tests/README.md index e469ef1229be..5c7a34bcec0e 100644 --- a/build_tools/torchscript_e2e_heavydep_tests/README.md +++ b/build_tools/torchscript_e2e_heavydep_tests/README.md @@ -13,19 +13,19 @@ self-contained virtual environment. It can be used like so: # serialized test artifacts in the other specified directory. # This command is safe to re-run if you have already built the virtual # environment and just changed the tests. -build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh \ +build_tools/e2e_heavydep_tests/generate_serialized_tests.sh \ path/to/heavydep_venv \ path/to/heavydep_serialized_tests # Add the --serialized-test-dir flag to point at the directory containing the # serialized tests. All other functionality is the same as the normal invocation -# of torchscript_e2e_test.sh, but the serialized tests will be available. -tools/torchscript_e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests +# of e2e_test.sh, but the serialized tests will be available. 
+
+function setup_venv() {
+  local python_version="$1"
+  echo ":::: Setting up VENV with Python: $python_version"
+  python3 -m venv /main_checkout/torch-mlir/docker_venv
+  source /main_checkout/torch-mlir/docker_venv/bin/activate
+
+  echo ":::: pip installing dependencies"
+  python3 -m pip install -r /main_checkout/torch-mlir/externals/llvm-project/mlir/python/requirements.txt
+  python3 -m pip install -r /main_checkout/torch-mlir/requirements.txt
+
+}
+
+function build_out_of_tree() {
+  local torch_from_src="$1"
+  local python_version="$2"
+  echo ":::: Build out-of-tree Torch from source: $torch_from_src with Python: $python_version"
+
+  if [ ! -d "/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" ]
+  then
+    echo ":::: LLVM / MLIR is not built, so building it first..."
+    cmake -GNinja -B/main_checkout/torch-mlir/llvm-build \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DCMAKE_C_COMPILER=clang \
+      -DCMAKE_CXX_COMPILER=clang++ \
+      -DCMAKE_C_COMPILER_LAUNCHER=ccache \
+      -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
+      -DCMAKE_LINKER=lld \
+      -DLLVM_ENABLE_ASSERTIONS=ON \
+      -DLLVM_ENABLE_PROJECTS=mlir \
+      -DLLVM_TARGETS_TO_BUILD=host \
+      -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+      -DPython3_EXECUTABLE="$(which python3)" \
+      /main_checkout/torch-mlir/externals/llvm-project/llvm
+    cmake --build /main_checkout/torch-mlir/llvm-build
+  fi
+
+  # Incremental builds come here directly and can run cmake if required.
+  cmake -GNinja -B/main_checkout/torch-mlir/build_oot \
+    -DCMAKE_C_COMPILER=clang \
+    -DCMAKE_CXX_COMPILER=clang++ \
+    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
+    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
+    -DCMAKE_LINKER=lld \
+    -DLLVM_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/llvm/" \
+    -DMLIR_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" \
+    -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
+    -DTORCH_MLIR_ENABLE_LTC=OFF \
+    -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_src" \
+    -DPython3_EXECUTABLE="$(which python3)" \
+    /main_checkout/torch-mlir
+  cmake --build /main_checkout/torch-mlir/build_oot
+  ccache -s
+}
+
+function test_out_of_tree() {
+  echo ":::: Test out-of-tree"
+  cmake --build /main_checkout/torch-mlir/build_oot --target check-torch-mlir-all
+}
+
+function clean_build() {
+  # Clean up for recursive runs.
+  local package="$1"
+  local python_version="$2"
+  echo ":::: Clean build dir $package $python_version"
+  rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch
+}
+
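For reference, a sketch of how these helpers compose into the out-of-tree flow; the argument values are illustrative, and the real call sites live in parts of this script outside the hunk:

```shell
setup_venv "3.9"
build_out_of_tree "ON" "3.9"   # "ON" feeds -DTORCH_MLIR_USE_INSTALLED_PYTORCH
test_out_of_tree
clean_build torch-mlir "3.9"   # wipes build/, llvm-build/, the venv, libtorch
```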
 function build_torch_mlir() {
   python -m pip install -r /main_checkout/torch-mlir/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
   CMAKE_GENERATOR=Ninja \
@@ -123,7 +339,10 @@ function clean_wheels() {
 
 # Trampoline to the docker container if running on the host.
 if [ -z "${__MANYLINUX_BUILD_WHEELS_IN_DOCKER-}" ]; then
-  run_on_host "$@"
+  for package in $TM_PACKAGES; do
+    echo "******************** BUILDING PACKAGE ${package} (host) *************"
+    run_on_host "${package} $@"
+  done
 else
   run_in_docker "$@"
 fi
diff --git a/build_tools/python_deploy/build_macos_packages.sh b/build_tools/python_deploy/build_macos_packages.sh
index b97c791a4a79..18606a0c2263 100755
--- a/build_tools/python_deploy/build_macos_packages.sh
+++ b/build_tools/python_deploy/build_macos_packages.sh
@@ -35,6 +35,10 @@ export CMAKE_OSX_ARCHITECTURES="${TORCH_MLIR_OSX_ARCH:-arm64;x86_64}"
 echo "CMAKE_OSX_ARCHITECTURES: $CMAKE_OSX_ARCHITECTURES"
 echo "MACOSX_DEPLOYMENT_TARGET $MACOSX_DEPLOYMENT_TARGET"
 
+# Disable LTC build on macOS to avoid linkage issues
+# https://github.com/llvm/torch-mlir/issues/1253
+export TORCH_MLIR_ENABLE_LTC=0
+
 function run() {
   echo "Using python versions: ${python_versions}"
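The macOS packaging script is driven entirely by environment variables, so narrowing a local build is a matter of prefixing the call; the value below is illustrative, with the variable name taken from the hunk context above:

```shell
# Build arm64-only wheels; the script itself now forces TORCH_MLIR_ENABLE_LTC=0.
TORCH_MLIR_OSX_ARCH=arm64 ./build_tools/python_deploy/build_macos_packages.sh
```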
diff --git a/build_tools/torchscript_e2e_heavydep_tests/README.md b/build_tools/torchscript_e2e_heavydep_tests/README.md
index e469ef1229be..5c7a34bcec0e 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/README.md
+++ b/build_tools/torchscript_e2e_heavydep_tests/README.md
@@ -13,19 +13,19 @@ self-contained virtual environment. It can be used like so:
 # serialized test artifacts in the other specified directory.
 # This command is safe to re-run if you have already built the virtual
 # environment and just changed the tests.
-build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh \
+build_tools/e2e_heavydep_tests/generate_serialized_tests.sh \
   path/to/heavydep_venv \
   path/to/heavydep_serialized_tests
 
 # Add the --serialized-test-dir flag to point at the directory containing the
 # serialized tests. All other functionality is the same as the normal invocation
-# of torchscript_e2e_test.sh, but the serialized tests will be available.
-tools/torchscript_e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests
+# of e2e_test.sh, but the serialized tests will be available.
+tools/e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests
 ```
 
 The tests use the same (pure-Python) test framework as the normal
-torchscript_e2e_test.sh, but the tests are added in
-`build_tools/torchscript_e2e_heavydep_tests` instead of
+e2e_test.sh, but the tests are added in
+`build_tools/e2e_heavydep_tests` instead of
 `frontends/pytorch/e2e_testing/torchscript`. We rely critically on
 serialized TorchScript compatibility across PyTorch
diff --git a/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh b/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh
index aa0d42953d47..ddda4d9a6778 100755
--- a/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh
+++ b/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh
@@ -38,4 +38,4 @@ cd "$torch_mlir_src_root"
 export PYTHONPATH=${PYTHONPATH-}
 source "$torch_mlir_src_root/.env"
 
-python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir="$serialized_test_dir"
+python3 -m build_tools.e2e_heavydep_tests.main --output_dir="$serialized_test_dir"
diff --git a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py b/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
index 8f1f834bed71..f387dafca374 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
@@ -6,9 +6,9 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 torch.manual_seed(0)
diff --git a/build_tools/torchscript_e2e_heavydep_tests/main.py b/build_tools/torchscript_e2e_heavydep_tests/main.py
index 48a409495738..28e374628568 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/main.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/main.py
@@ -5,7 +5,7 @@
 
 import argparse
 
-from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
+from torch_mlir_e2e_test.serialization import serialize_all_tests_to
 
 from . import hf_sequence_classification
 from . import vision_models
diff --git a/build_tools/torchscript_e2e_heavydep_tests/train_models.py b/build_tools/torchscript_e2e_heavydep_tests/train_models.py
index 29f5c4a8fd2f..d4d9e39994b9 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/train_models.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/train_models.py
@@ -9,9 +9,9 @@
 from torch.nn.utils import _stateless
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 from torch import fx
 import copy
diff --git a/build_tools/torchscript_e2e_heavydep_tests/vision_models.py b/build_tools/torchscript_e2e_heavydep_tests/vision_models.py
index 59ca23e127a9..71043eb985af 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/vision_models.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/vision_models.py
@@ -6,9 +6,9 @@
 import torch
 import torchvision.models as models
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 import timm
 torch.manual_seed(0)
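All of the Python changes above exist only because `torch_mlir_e2e_test.torchscript.*` collapsed to `torch_mlir_e2e_test.*`. Downstream code that still uses the old import paths can likely be migrated mechanically; a sketch assuming GNU sed:

```shell
# Rewrite old-style imports in place across a source tree.
grep -rl 'torch_mlir_e2e_test\.torchscript\.' . \
  | xargs sed -i 's/torch_mlir_e2e_test\.torchscript\./torch_mlir_e2e_test./g'
```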
diff --git a/build_tools/update_shape_lib.sh b/build_tools/update_shape_lib.sh
index e0de9f6cb0cb..b2d619d3484b 100755
--- a/build_tools/update_shape_lib.sh
+++ b/build_tools/update_shape_lib.sh
@@ -16,19 +16,17 @@ build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
 torch_transforms_cpp_dir="${src_dir}/lib/Dialect/Torch/Transforms"
 python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"
 
+TORCH_MLIR_EXT_PYTHONPATH="${TORCH_MLIR_EXT_PYTHONPATH:-""}"
 pypath="${python_packages_dir}/torch_mlir"
-# TODO: Re-enable once custom op support is back.
-#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
-#  pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
-#fi
-#ext_module="torch_mlir._torch_mlir_custom_op_example"
-#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
-#  ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
-#fi
+if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
+  pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
+fi
+TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
+if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
+  ext_module="${TORCH_MLIR_EXT_MODULES} "
+fi
 
 PYTHONPATH="${pypath}" python \
   -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
+  --pytorch_op_extensions=${ext_module:-""} \
   --torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
-
-# TODO: Add back to shape_lib_gen invocation once custom op support is back.
-#  --pytorch_op_extensions=${ext_module} \
diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh
index 2e22ab12e08f..2b30ffe663ae 100755
--- a/build_tools/update_torch_ods.sh
+++ b/build_tools/update_torch_ods.sh
@@ -16,20 +16,19 @@ build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
 torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
 python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"
 
+TORCH_MLIR_EXT_PYTHONPATH="${TORCH_MLIR_EXT_PYTHONPATH:-""}"
 pypath="${python_packages_dir}/torch_mlir"
-# TODO: Re-enable once custom op support is back.
-#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
-#  pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
-#fi
-#ext_module="torch_mlir._torch_mlir_custom_op_example"
-#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
-#  ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
-#fi
+if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
+  pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
+fi
+TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
+ext_module="${ext_module:-""}"
+if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
+  ext_module="${TORCH_MLIR_EXT_MODULES}"
+fi
 
 PYTHONPATH="${pypath}" python \
   -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
   --torch_ir_include_dir="${torch_ir_include_dir}" \
+  --pytorch_op_extensions="${ext_module}" \
   --debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
-
-# TODO: Add back to torch_ods_gen invocation once custom op support is back.
-#  --pytorch_op_extensions="${ext_module}" \
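With the TODOs resolved, both regeneration scripts now honor the same two environment variables. A usage sketch; the module name and path below are hypothetical:

```shell
# `my_torch_ops` is a hypothetical PyTorch op-extension module.
TORCH_MLIR_EXT_PYTHONPATH="$HOME/my_ext/python" \
TORCH_MLIR_EXT_MODULES="my_torch_ops" \
  ./build_tools/update_torch_ods.sh
```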
diff --git a/development.md b/development.md
deleted file mode 100644
index 8bffe251bd0e..000000000000
--- a/development.md
+++ /dev/null
@@ -1,269 +0,0 @@
-# Checkout and build from source
-
-## Check out the code
-
-```shell
-git clone https://github.com/llvm/torch-mlir
-cd torch-mlir
-git submodule update --init
-```
-
-## Set up your Python virtual environment and dependencies
-
-Also, ensure that you have the appropriate `python-dev` package installed
-to access the Python development libraries / headers.
-
-```shell
-python -m venv mlir_venv
-source mlir_venv/bin/activate
-# Some older pip installs may not be able to handle the recent PyTorch deps
-python -m pip install --upgrade pip
-# Install latest PyTorch nightlies and build requirements.
-python -m pip install -r requirements.txt
-```
-
-## Build Python Packages
-
-We have preliminary support for building Python packages. This can be done
-with the following commands:
-
-```
-python -m pip install --upgrade pip
-python -m pip install -r requirements.txt
-CMAKE_GENERATOR=Ninja python setup.py bdist_wheel
-```
-
-## CMake Build
-
-Two build setups are possible: in-tree and out-of-tree. The in-tree setup is
-the most straightforward, as it builds the LLVM dependencies as well.
-
-### Building torch-mlir in-tree
-
-The following command generates configuration files to build the project
-*in-tree*, that is, using llvm/llvm-project as the main build. This will build
-LLVM as well as torch-mlir and its subprojects.
-
-```shell
-cmake -GNinja -Bbuild \
-    -DCMAKE_BUILD_TYPE=Release \
-    -DCMAKE_C_COMPILER=clang \
-    -DCMAKE_CXX_COMPILER=clang++ \
-    -DPython3_FIND_VIRTUALENV=ONLY \
-    -DLLVM_ENABLE_PROJECTS=mlir \
-    -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
-    -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
-    -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
-    -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-    -DLLVM_TARGETS_TO_BUILD=host \
-    externals/llvm-project/llvm
-```
-The following additional quality-of-life flags can be used to reduce build time:
-* Enabling ccache:
-```shell
-  -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-```
-* Enabling LLD (links in seconds compared to minutes):
-```shell
-  -DCMAKE_EXE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_MODULE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_SHARED_LINKER_FLAGS_INIT="-fuse-ld=lld"
-# Use --ld-path= instead of -fuse-ld=lld for clang > 13
-```
-* Enabling the libtorch binary cache:
-By default we download the latest version of libtorch every time you build, so
-we are always on the latest version. Set `-DLIBTORCH_CACHE=ON` to not download
-the latest version every time. If libtorch gets out of date and you test
-against a newer PyTorch, you may notice failures.
-```shell
-  -DLIBTORCH_CACHE=ON
-```
-* Enabling building libtorch as part of your build:
-By default we download the latest version of libtorch. We have an experimental
-path to build libtorch (and PyTorch wheels) from source.
-```shell
-  -DLIBTORCH_SRC_BUILD=ON  # Build libtorch from source
-  -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. (`shared`|`static` and optionally `cxxabi11`)
-```
-
-### Building against a pre-built LLVM
-
-If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`,
-you can also build the project *out-of-tree* using the following command as a
-template:
-```shell
-cmake -GNinja -Bbuild \
-    -DCMAKE_BUILD_TYPE=Release \
-    -DCMAKE_C_COMPILER=clang \
-    -DCMAKE_CXX_COMPILER=clang++ \
-    -DPython3_FIND_VIRTUALENV=ONLY \
-    -DMLIR_DIR="$LLVM_INSTALL_DIR/lib/cmake/mlir/" \
-    -DLLVM_DIR="$LLVM_INSTALL_DIR/lib/cmake/llvm/" \
-    -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-    -DLLVM_TARGETS_TO_BUILD=host \
-    .
-```
-The same QoL CMake flags can be used to enable ccache and lld. Be sure to have
-built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
-
-Be aware that the installed version of LLVM generally needs to match the
-committed version in `externals/llvm-project`. Using a different version may
-or may not work.
-
-
-### Build commands
-
-After either cmake run (in-tree/out-of-tree), use one of the following commands
-to build the project:
-```shell
-# Build just torch-mlir (not all of LLVM)
-cmake --build build --target tools/torch-mlir/all
-
-# Run unit tests.
-cmake --build build --target check-torch-mlir
-
-# Run Python regression tests.
-cmake --build build --target check-torch-mlir-python
-
-# Build everything (including LLVM if in-tree)
-cmake --build build
-```
-
-## Set up the Python environment to export the built Python packages
-```shell
-export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
-```
-
-## Jupyter
-
-Jupyter notebook:
-```shell
-python -m ipykernel install --user --name=torch-mlir --env PYTHONPATH "$PYTHONPATH"
-# Open in jupyter, and then navigate to
-# `examples/resnet_inference.ipynb` and use the `torch-mlir` kernel to run.
-jupyter notebook
-```
-
-[Example IR](https://gist.github.com/silvasean/e74780f8a8a449339aac05c51e8b0caa) for a simple 1-layer MLP to show the compilation steps from TorchScript.
-
-
-## Interactive Use
-
-The `build_tools/write_env_file.sh` script will output a `.env`
-file in the workspace folder with the correct PYTHONPATH set. This allows
-tools like VSCode to work by default for debugging. This file can also be
-manually `source`'d in a shell.
-
-
-## Bazel Build
-
-> **NOTE** Our Bazel build follows LLVM's Bazel build policy: only the
-> subcommunity interested in Bazel is responsible for fixing it. Average
-> Torch-MLIR developers should not be notified of any Bazel build issues and are
-> not responsible for fixing any breakages (though any help is, of course,
-> welcome). For more info, see LLVM's
-> [Peripheral Support Tier](https://llvm.org/docs/SupportPolicy.html#peripheral-tier)
-> definition.
-
-Torch-MLIR can also be built using Bazel (apart from the official CMake build)
-for users who depend on Bazel in their workflows. To build `torch-mlir-opt`
-using Bazel, follow these steps:
-
-1. Install [Bazel](https://docs.bazel.build/versions/main/install.html) if you
-   don't already have it.
-2. Install a relatively new release of [Clang](https://releases.llvm.org/download.html).
-3. Build:
-```shell
-cd utils/bazel
-bazel build @torch-mlir//...
-```
-4. Find the built binary at `bazel-bin/external/torch-mlir/torch-mlir-opt`.
-
-
-# Testing
-
-Torch-MLIR has two types of tests:
-
-1. End-to-end execution tests. These compile and run a program and check the
-   result against the expected output from execution on native Torch. These use
-   a homegrown testing framework (see
-   `python/torch_mlir_e2e_test/torchscript/framework.py`) and the test suite
-   lives at `python/torch_mlir_e2e_test/test_suite/__init__.py`.
-
-2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework.
-   For example, these might involve using `torch-mlir-opt` to run a pass and
-   check the output with `FileCheck`.
-
-
-## Running execution (end-to-end) tests:
-
-```shell
-# Run all tests on the reference backend
-./tools/torchscript_e2e_test.sh
-# Run tests that match the regex `Conv2d`, with verbose errors.
-./tools/torchscript_e2e_test.sh --filter Conv2d --verbose
-# Run tests on the TOSA backend.
-./tools/torchscript_e2e_test.sh --config tosa
-```
-
-## Running unit tests.
-
-To run all of the unit tests, run:
-
-```
-ninja check-torch-mlir-all
-```
-
-This can be broken down into:
-
-```
-ninja check-torch-mlir check-torch-mlir-dialects check-torch-mlir-python
-```
-
-To run more fine-grained tests for `check-torch-mlir`, you can do:
-
-```
-cd $TORCH_MLIR_BUILD_DIR/tools/torch-mlir/test
-$TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonicalize
-```
-
-See [the `lit` documentation](https://llvm.org/docs/CommandGuide/lit.html) for details on the available lit args.
-
-For example, if you wanted to test just `test/Dialect/Torch/canonicalize.mlir`,
-then you might do:
-
-```
-cd $TORCH_MLIR_BUILD_DIR/tools/torch-mlir/test
-$TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonicalize.mlir
-```
-
-Most of the unit tests use the [`FileCheck` tool](https://llvm.org/docs/CommandGuide/FileCheck.html) to verify expected outputs.
-
-# Unexpected test failures with PyTorch / Libtorch skew
-
-Torch-MLIR currently links to libtorch binaries by default, and tests are run
-with the PyTorch nightlies. This can cause version / API skew in your tests
-(see https://github.com/llvm/torch-mlir/issues/1007). If you notice any
-unexpected test failures, please follow the steps below:
-
-```
-rm -rf libtorch* # note the asterisk after libtorch, since there is also a .zip file that needs to be removed
-rm -rf build/
-python -m pip install -r requirements.txt --upgrade # to get the latest pytorch
-# Then rebuild and test torch-mlir
-```
-We expect this to be fixed once we take on a dependency on PyTorch and build it
-from source. That work is being tracked in [this](https://github.com/llvm/torch-mlir/tree/release-src-build) branch.
-
-# Updating the LLVM submodule
-
-Torch-MLIR maintains `llvm-project` (which contains, among other things,
-upstream MLIR) as a submodule in `externals/llvm-project`. We aim to update this
-at least weekly to new LLVM revisions to bring in the latest features and spread
-out over time the effort of updating our code for MLIR API breakages.
-
-Updating the LLVM submodule is done by:
-
-1. In the `externals/llvm-project` directory, run `git pull` to update to the
-   upstream revision of interest (such as a particular upstream change that is
-   needed for your Torch-MLIR PR).
-2. Rebuild and test Torch-MLIR (see above), fixing any issues that arise. This
-   might involve fixing various API breakages introduced upstream (they are
-   likely unrelated to what you are working on). If these fixes are too complex,
-   please file a work-in-progress PR explaining the issues you are running into
-   and asking for help so that someone from the community can help.
-3. Run `build_tools/update_shape_lib.sh` to update the shape library -- this is
-   sometimes needed because upstream changes can affect canonicalization and
-   other minor details of the IR in the shape library. See [docs/shape_lib.md](docs/shape_lib.md) for more details on the shape library.
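Condensed into commands, the removed guide's three submodule-update steps look roughly like this (a sketch following the guide's own wording, not a prescribed workflow):

```shell
cd externals/llvm-project && git pull   # step 1: move to the new upstream revision
cd ../.. && cmake --build build         # step 2: rebuild and fix any API breakage
build_tools/update_shape_lib.sh         # step 3: refresh the shape library
```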
-
-
-Here are some examples of PRs updating the LLVM submodule:
-
-- https://github.com/llvm/torch-mlir/pull/958
-- https://github.com/llvm/torch-mlir/pull/856
-
-
-# Other docs
-
-- GitHub wiki: https://github.com/llvm/torch-mlir/wiki
-- Of particular interest is the [How to add end-to-end support for new Torch ops](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation) doc.
diff --git a/docs/Torch-MLIR.excalidraw b/docs/Torch-MLIR.excalidraw
new file mode 100644
index 000000000000..4b7d8d29b413
--- /dev/null
+++ b/docs/Torch-MLIR.excalidraw
@@ -0,0 +1,2639 @@
+[2,639 lines of machine-generated Excalidraw JSON omitted: the editable source for a "Torch-MLIR Architecture" diagram. Its labeled elements are the PyTorch entry points (TorchScript, torch_dispatch, LazyTensorCore), the TorchScript/MLIR converter, per-op JIT graph builder, and LTC MLIR plug-in, the Torch-MLIR dialect, the LinAlg/Arith/Tensor/SCF, TOSA, MHLO, and Custom lowering targets, and the reference MLIR CPU runner plus CPU/GPU/accelerator backends, with repository boundaries marked for github.com/pytorch/pytorch, github.com/llvm/torch-mlir, and the MLIR ecosystem.]
"transparent", + "width": 0.3775514635380546, + "height": 21.31964111328125, + "seed": 639756628, + "groupIds": [], + "strokeSharpness": "round", + "boundElements": [], + "updated": 1660940763344, + "link": null, + "locked": false, + "startBinding": { + "elementId": "tuK_yULMKLM4aneU8P4e0", + "focus": 0.017258250085285455, + "gap": 1.4929962158203125 + }, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -0.3775514635380546, + 21.31964111328125 + ] + ] + }, + { + "type": "arrow", + "version": 1178, + "versionNonce": 99515604, + "isDeleted": false, + "id": "pIXLNNqmJXXOzYe5WlYqa", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 596.2096497552064, + "y": 672.1397247314453, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 0.3775514635380546, + "height": 21.31964111328125, + "seed": 735805012, + "groupIds": [], + "strokeSharpness": "round", + "boundElements": [], + "updated": 1660940753156, + "link": null, + "locked": false, + "startBinding": { + "elementId": "Rjt45nyi1UlloVmswsnId", + "focus": -0.0026784973230609553, + "gap": 1.4732818603515625 + }, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -0.3775514635380546, + 21.31964111328125 + ] + ] + }, + { + "type": "arrow", + "version": 1176, + "versionNonce": 735824724, + "isDeleted": false, + "id": "1COZPP792gFA4J2p8SUL6", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 714.3037659661442, + "y": 671.2773590087891, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 0.3775514635380546, + "height": 21.31964111328125, + "seed": 1043317332, + "groupIds": [], + "strokeSharpness": "round", + "boundElements": [], + "updated": 1660940755596, + "link": null, + "locked": false, + "startBinding": { + "elementId": "ZDNsEoxAd0IRaNlRtJrxj", + "focus": 0.03512711488973025, + "gap": 14.130020141601562 + }, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -0.3775514635380546, + 21.31964111328125 + ] + ] + }, + { + "type": "arrow", + "version": 1185, + "versionNonce": 569227116, + "isDeleted": false, + "id": "H9ED9PW7ahwdjILa_abi_", + "fillStyle": "hachure", + "strokeWidth": 1, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "angle": 0, + "x": 846.5879456536442, + "y": 674.2441558837891, + "strokeColor": "#000000", + "backgroundColor": "transparent", + "width": 0.3775514635380546, + "height": 21.31964111328125, + "seed": 107954388, + "groupIds": [], + "strokeSharpness": "round", + "boundElements": [], + "updated": 1660940758987, + "link": null, + "locked": false, + "startBinding": { + "elementId": "GC28VKCyldd4DkqpJ6x5L", + "focus": 0.04642355602818736, + "gap": 5.4564971923828125 + }, + "endBinding": null, + "lastCommittedPoint": null, + "startArrowhead": null, + "endArrowhead": "arrow", + "points": [ + [ + 0, + 0 + ], + [ + -0.3775514635380546, + 21.31964111328125 + ] + ] + } + ], + "appState": { + "gridSize": null, + "viewBackgroundColor": "#ffffff" + }, + "files": {} +} \ No newline at end of file diff --git a/docs/Torch-MLIR_Architecture.png b/docs/Torch-MLIR_Architecture.png new file mode 100644 index 000000000000..1ad041c323aa Binary files /dev/null and 
b/docs/Torch-MLIR_Architecture.png differ
diff --git a/docs/adding_an_e2e_test.md b/docs/adding_an_e2e_test.md
new file mode 100644
index 000000000000..27cfe42a3075
--- /dev/null
+++ b/docs/adding_an_e2e_test.md
@@ -0,0 +1,169 @@
+# Adding an E2E Test
+
+## Overview
+
+Adding support for a Torch operator in Torch-MLIR should always be accompanied
+by at least one end-to-end test to make sure the implementation of the op
+matches the behavior of PyTorch. The tests live in the
+`torch-mlir/python/torch_mlir_e2e_test/test_suite/` directory. When adding a new
+test, choose the file that best matches the op you're testing, and if no
+existing file is a good match, add a new file for your op.
+
+## An E2E Test Deconstructed
+
+To understand how to create an end-to-end test for your op, let's break down an
+existing test to see what the different parts mean:
+
+```python
+class IndexTensorModule3dInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, (index,))
+
+
+@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
+def IndexTensorModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
+```
+
+### Class Name
+
+```python
+class IndexTensorModule3dInput(torch.nn.Module):
+```
+
+The class name should always contain the name of the op that is being
+tested. This makes it easy to search for tests for a particular op. Oftentimes
+an op will require multiple tests to make sure different paths in the
+compilation work as expected. In such cases, it is customary to add extra
+information to the class name about what is being tested. In this example, the
+op is being tested with a rank-3 tensor as an input.
+
+### `__init__` Method
+
+```python
+    def __init__(self):
+        super().__init__()
+```
+
+In most tests, the `__init__` method simply calls the `__init__` method of the
+`torch.nn.Module` class. However, sometimes this method can be used to
+initialize parameters needed in the `forward` method. An example of such a case
+is in the [E2E test for Resnet18](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/test_suite/vision_models.py#L17-L22).
+
+### `@export` and `@annotate_args` Decorators
+
+```python
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+```
+
+The [`@export` decorator](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/torchscript/annotations.py#L30)
+lets the importer know which methods in the class will be public after the
+`torch.nn.Module` gets imported into the `torch` dialect. All E2E tests should
+have this decorator on the `forward` method.
+
+The [`@annotate_args` decorator](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/torchscript/annotations.py#L53)
+is used to give the importer information about the arguments of the method being
+decorated, which can then be propagated further into the IR of the body of the
+method. The list of annotations **must** have one annotation for each argument,
+including the `self` argument. The `self` argument always gets the annotation of
+`None`, while the other inputs get an annotation with three fields in the
+following order:
+
+1. Shape of the input tensor. Use `-1` for dynamic dimensions.
+2. Dtype of the input tensor.
+3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h#L54-L67). This
+   will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the
+   IR to eventually have value semantics.
+
+From the structure of the annotations for the arguments other than the `self`
+argument, it is clear that only tensor arguments are supported. This means that
+if an op requires an input other than a tensor, you need to do one of the
+following:
+
+- Create the value in the method body
+- Create the value as a class parameter in the `__init__` method
+- In the case of certain values such as `int`s and `float`s, you can pass a
+  zero-rank tensor as an input and use `int(input)` or `float(input)` in the
+  method body to turn the tensor into a scalar `int` or `float`, respectively.
+
+### `forward` Method
+
+```python
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, (index,))
+```
+
+The `forward` method should be a simple test of your op. In other words, it will
+almost always take the form of simply returning the result of calling your
+op. The call to your op should **always** be made using
+`torch.ops.aten.{op_name}` to make it very clear which ATen op is being
+tested. Some ATen ops have different variants under the same base name, such as
+`aten.mean`, which also has a variant `aten.mean.dim`. At the Python level, such
+ops are accessed by just their base name, and the right variant is chosen based
+on the inputs given. For example, to test `aten.mean.dim` the test should use
+`torch.ops.aten.mean(..., dim=...)`.
+
+### `@register_test_case` Decorator
+
+```python
+@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
+```
+
+The `@register_test_case` decorator is used to register the test case
+function. The `module_factory` argument should be a function that, when called,
+produces an instance of the test class. This function will be used to create the
+first argument passed to the test case function.
+
+### Test Case Function
+
+```python
+def IndexTensorModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
+```
+
+The convention adopted for the name of the test case function is to have the
+same name as the test class postfixed by `_basic`. The test function always
+takes an instance of the test class as the first argument and a
+[`TestUtils`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L167)
+object as the second argument. The `TestUtils` object has some methods, such as
+[`tu.rand`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L182)
+and
+[`tu.randint`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L185),
+that allow the creation of random tensors in a way that makes sure the compiled
+module and the golden trace receive the same tensors as input. Therefore, all
+random inputs should be generated through the `TestUtils` object.
+
+## Things to Consider When Creating New Tests
+
+- Do you need negative numbers? If so,
+  [`tu.rand`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L182)
+  and
+  [`tu.randint`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L185)
+  both allow you to specify a lower and upper bound for random number
+  generation.
+- Make sure the annotation of the `forward` method matches the input types and
+  shapes.
+- If an op takes optional flag arguments, there should be a test for each flag
+  that is supported.
+- If there are tricky edge cases that your op needs to handle, have a test for
+  each edge case.
+- Always follow the style and conventions of the file you're adding a test
+  in. An attempt has been made to keep all E2E test files stylistically
+  consistent, but file-specific variations do exist.
+
diff --git a/docs/architecture.md b/docs/architecture.md
new file mode 100644
index 000000000000..ebfb9029d4f0
--- /dev/null
+++ b/docs/architecture.md
@@ -0,0 +1,448 @@
+# Torch-MLIR Architecture
+
+## Introduction
+
+The Torch-MLIR project provides core infrastructure for bridging the PyTorch
+ecosystem and the MLIR ecosystem. For example, Torch-MLIR enables PyTorch models
+to be lowered to a few different MLIR dialects. Torch-MLIR does not attempt to
+provide a production end-to-end flow for PyTorch programs by itself, but is a
+useful component for constructing one.
+
+## Overview
+
+Torch-MLIR has two parts, which we call the "frontend" and "backend". These two
+halves interface at an abstraction layer that we call the "backend contract",
+which is a subset of the `torch` dialect with certain properties appealing for
+backends to lower from.
+
+![Torch-MLIR Architecture](Torch-MLIR_Architecture.png)
+
+The frontend of Torch-MLIR is concerned with interfacing to PyTorch itself, and
+then normalizing the program to the "backend contract". This part involves build
+system complexity and exposure to PyTorch APIs to get the program into the MLIR
+`torch` dialect. When we interface with TorchScript, we additionally have a
+large amount of lowering and simplification to do within MLIR on the `torch`
+dialect.
+
+The "backend" of Torch-MLIR takes IR in the "backend contract" form and lowers
+it to various target dialects of interest to the MLIR ecosystem (various
+"backends"). In particular, right now we support lowering to:
+
+- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
+- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
+- [MHLO](https://github.com/tensorflow/mlir-hlo)
+
+The terms "frontend" and "backend" are highly overloaded in any compiler
+project, but in Torch-MLIR this is usually what they mean. Sometimes "frontend"
+can mean something even further up the stack, such as something in PyTorch
+itself. When there is ambiguity we will refer to this as "at the PyTorch
+level". Similarly, "backend" can sometimes refer to something sitting below
+Linalg-on-Tensors, TOSA, or MHLO.
+
+## The `torch` dialect
+
+See [include/torch-mlir/Dialect/Torch/IR](https://github.com/llvm/torch-mlir/tree/main/include/torch-mlir/Dialect/Torch/IR)
+
+The central MLIR abstraction in the Torch-MLIR project is the `torch` dialect.
+This dialect supports progressive lowering from the raw imported PyTorch
+programs that various PyTorch integration points provide, all the way down to
+the backend contract.
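+
+For a concrete feel for what this dialect looks like, the `torch_mlir.compile`
+API described in [development.md](development.md) can print the IR at this
+level. A minimal sketch (the toy module and example shapes here are
+illustrative assumptions, not authoritative output):
+
+```python
+# Sketch: inspect the `torch` dialect IR for a toy module. Assumes a
+# torch_mlir build (per development.md) is on PYTHONPATH; the module and
+# example shapes are made up for illustration.
+import torch
+import torch_mlir
+
+class AddModule(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+# output_type="torch" stops at the `torch` dialect rather than lowering
+# further to Linalg-on-Tensors, TOSA, or MHLO.
+module = torch_mlir.compile(AddModule(), [torch.ones(3), torch.ones(3)],
+                            output_type="torch")
+print(module)
+```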
+ +The `torch` dialect must be versatile enough to support being imported by any +program capture mechanism in PyTorch -- this could be TorchDynamo, `torch.fx`, +LazyTensorCore, TorchScript, `torch.jit.trace`, etc. Thankfully, PyTorch is +factored such that we can handle this with one core import path, which is +through the PyTorch +"[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)", +and lives in +[torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir). +The JIT IR is a highly principled IR that faithfully models a Python subset (+ +tensors, the PyTorch op registry, and a few other things). All the other PyTorch +program representations can eventually bottom-out on the JIT IR via some path +provided by PyTorch. The `torch` dialect is almost entirely in 1:1 +correspondence with the JIT IR -- this allows the importer to be extremely small +(the core is +[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp#L1)). + +### Ops + +See [TorchOps.td](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1) + +The ops in the `torch` dialect are almost entirely generated based on the +PyTorch JIT IR operator registry via the script +[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)). +This script queries the registry and generates MLIR +[ODS](https://mlir.llvm.org/docs/OpDefinitions/) in +[GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1). We have a guide for [adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation). + +There are also some manually implemented ops in the following categories (see +[TorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1)): + +- Ops used for modeling PyTorch IValue object graphs (e.g. `torch.nn_module`, + `torch.class_type`). +- `torch.global_slot` and related ops which are used to model an incremental + lowering of the IValue object graphs. +- Ops that are supported in the JIT interpreter directly, and so don't have a + corresponding op in the registry (e.g. `torch.prim.If`, + `torch.prim.ListConstruct`, `torch.constant.*`) +- `torch.operator` which is used to represent ops from the registry which + haven't been generated by `torch_ods_gen.py`. + +### Types + +See [TorchTypes.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td#L1) + +The `torch` dialect has a complete set of types modeling the PyTorch type +system, which itself is a strongly typed subset of the Python type system (+ +tensors). These types are almost all 1:1 with the corresponding +[PyTorch types](https://github.com/pytorch/pytorch/blob/c54d18dbc7bb2f9fdd83c5de529702e5a02295c3/aten/src/ATen/core/jit_type.h#L1). 
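+
+As a rough illustration of that correspondence (the `!torch.*` spellings below
+follow TorchTypes.td, but treat the exact pairing as an assumption rather than
+an exhaustive mapping):
+
+```python
+# Sketch: TorchScript-typed Python arguments on the left, with the `torch`
+# dialect type each one roughly corresponds to in the comments.
+from typing import List, Optional
+
+import torch
+
+def f(x: torch.Tensor,      # !torch.tensor / !torch.vtensor (see below)
+      n: int,               # !torch.int
+      p: float,             # !torch.float
+      b: bool,              # !torch.bool
+      dims: List[int],      # !torch.list<int>
+      opt: Optional[int]):  # !torch.optional<int>
+    ...
+```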
+
+The one exception where a significant amount of design work has been done in
+Torch-MLIR is the handling of tensors. Torch-MLIR's tensor types allow
+progressive lowering from raw imported IR which may be missing shapes, dtypes,
+and value semantics, into the backend contract which provides those. Torch-MLIR
+has two tensor types, `ValueTensorType` (`!torch.vtensor`) and
+`NonValueTensorType` (`!torch.tensor`), sharing most of their definition in
+[TorchTypes.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td#L58).
+The `NonValueTensorType` models a `torch.Tensor` including mutation, aliasing,
+etc., while the `ValueTensorType` has value semantics. That is, `ValueTensorType`
+is immutable and non-aliased. These types have a common C++ base class
+[`BaseTensorType`](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h#L40)
+which permits abstracting across them. Both `ValueTensorType` and
+`NonValueTensorType` have an optional list of optional sizes and an optional
+dtype.
+
+## The "backend contract"
+
+See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp#L151)
+
+The backend contract is a normalized form of the `torch` dialect with a set of
+properties that make it easy to lower into various forms such as
+Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
+box. The primary guarantees that we provide to Torch-MLIR's backends are:
+
+- All tensors have been converted to value semantics.
+- All tensors have at least a known number of dimensions (i.e. rank), and
+  ideally also have a precise size for each dimension.
+- All tensors have a known dtype.
+- Certain ops have been decomposed to make them easier to handle (this is
+  configurable).
+
+See the extensive comments in the function `satisfiesBackendContract` (and its
+callees) in the `LowerToBackendContract` pass for an extended rationale for
+these decisions, and a precise definition of the backend contract.
+
+## The Frontends
+
+Torch-MLIR provides 2 main frontends:
+
+- LazyTensorCore - a frontend that is based around intercepting PyTorch
+  dispatcher calls and creating a graph that is lazily evaluated on demand.
+- TorchScript - a frontend based around importing TorchScript functions or
+  modules. Such modules or functions can be obtained via `torch.jit.script`,
+  `torch.jit.trace`, or a few other methods in the PyTorch ecosystem.
+
+Internally these share a lot of the core import code.
+
+### LazyTensorCore
+
+Docs: https://github.com/llvm/torch-mlir/blob/main/docs/ltc_backend.md
+
+LazyTensorCore (LTC) is a program capture method provided by PyTorch that does
+device-level tracing. This low-level interception point sits below gradient
+calculations, and is thus a good choice for training flows. The downside of LTC
+is that it depends on having the whole PyTorch runtime available, so it cannot
+be used for ahead-of-time compilation or capturing standalone program artifacts.
+
+From an implementation perspective, the JIT IR that is produced by
+LazyTensorCore has already had a number of transformations performed on it; in
+particular, after importing from JIT IR to MLIR, the backend contract is
+trivially satisfied.
+So the Torch-MLIR implementation complexity for LazyTensorCore is restricted to
+build system and PyTorch integration, rather than actual MLIR compiler passes.
+
+### TorchScript (`torch.jit.script`)
+
+[TorchScript](https://pytorch.org/docs/stable/jit.html) is a strict Python
+subset which is modeled faithfully in the JIT IR. Additionally, TorchScript can
+represent a full `torch.nn.Module` object graph (hierarchy). This means the
+frontend must do a significant amount of work to lower it to the backend
+contract:
+
+- The `torch.nn.Module` hierarchy must be lowered to the backend contract, which
+  does not allow any program state.
+- The program must be converted to value semantics (functionalized).
+- Shapes and dtypes must be inferred.
+- Many "Python-isms" must be simplified away, such as list appends, string
+  operations, etc.
+
+Because TorchScript does not naturally give shapes or dtypes, we usually require
+the user to annotate a set of expected shapes and dtypes of any arguments. We
+then propagate those throughout the program.
+
+`torch.jit.trace` produces JIT IR with shapes and dtypes already, but no value
+semantics, and users often want to erase the shapes in the trace to allow
+dynamic shapes. Additionally, the Python-level data structures and APIs are
+largely parallel between `torch.jit.script` and `torch.jit.trace`, so from the
+perspective of the compiler's responsibilities we treat the two the same. Both
+are accessed via the `torch_mlir.compile` Python API.
+
+### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript
+
+PyTorch consistently models a subset of Python objects with its concept of
+[`IValue`](https://github.com/pytorch/pytorch/blob/1ee9eb52b612f5fb4b63bbda832e44c8902edb64/aten/src/ATen/core/ivalue.h#L171)
+(interpreter value). These are used throughout PyTorch to represent Python
+values. When one `torch.jit.script`'s a `torch.nn.Module`, the result is
+actually an `IValue` that represents the module, with a hierarchy of children
+`IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only used to
+represent the bodies of methods on the modules. So in addition to importing the
+JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp#L1).
+
+Most of the IValue modeling can reuse `torch` dialect ops that already exist
+otherwise, such as `torch.constant.int` to represent an int in the object graph.
+However, special IR constructs are needed for modeling the `torch.nn.Module`'s
+themselves.
+
+An example is:
+
+```mlir
+torch.class_type @c {
+  torch.attr "b" : !torch.bool
+  torch.attr "i" : !torch.int
+  torch.attr "f" : !torch.float
+  torch.attr "t" : !torch.tensor
+  torch.method "get_tensor", @get_tensor
+}
+func.func private @get_tensor(%arg0: !torch.nn.Module<"c">) -> !torch.tensor {
+  %2 = torch.prim.GetAttr %arg0["t"] : !torch.nn.Module<"c"> -> !torch.tensor
+  return %2 : !torch.tensor
+}
+
+%true = torch.constant.bool true
+%int3 = torch.constant.int 3
+%float4.250000e01 = torch.constant.float 4.250000e+01
+%0 = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
+%1 = torch.nn_module {
+  torch.slot "b", %true : !torch.bool
+  torch.slot "i", %int3 : !torch.int
+  torch.slot "f", %float4.250000e01 : !torch.float
+  torch.slot "t", %0 : !torch.tensor
+} : !torch.nn.Module<"c">
+```
+
+See the documentation for the ops for more information on the semantics of this
+form.
+
+### Lowering TorchScript to the backend contract
+
+The `torchscript-module-to-torch-backend-pipeline` contains the set of
+simplifications used to convert TorchScript to the backend contract. At a high
+level, it consists of the following transformations:
+
+1. GlobalizeObjectGraph: This takes the `IValue` object graph and converts it
+   into a flat list of globals (see `torch.global_slot` and related ops).
+1. LowerToBackendContract: This pass iteratively applies a simplification
+   pipeline until the backend contract is reached. The simplification pipeline
+   consists of:
+   - Standard canonicalization.
+   - Shape refinement. See [shape_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md) for details.
+   - DType refinement. See `RefineTypes`.
+   - Decomposing ops into more primitive ops. See `DecomposeComplexOps`.
+
+### Layering of the PyTorch Dependency
+
+One of the core principles of our Torch-MLIR <-> PyTorch interop is that
+anything that links against PyTorch must interact with MLIR through
+[the Torch-MLIR C API](https://github.com/llvm/torch-mlir/tree/main/include/torch-mlir-c).
+This bypasses a number of very complex dependency and shared library issues.
+
+Additionally, we maintain the invariant that the core MLIR compiler code (in
+`lib/` and `include/`) never has a build dependency on PyTorch itself. This
+strict isolation avoids a number of complex dependency issues and ensures that
+`torch-mlir-opt` and similar debugging tools always provide the excellent
+development and debugging experience that MLIR developers expect. Sometimes,
+certain highly stable enums and related logic must be shared with upstream
+PyTorch, and for those we copy code from PyTorch into
+[TorchUpstream.h](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L13).
+
+## The Backends
+
+Torch-MLIR provides 3 built-in backends, which take the backend contract IR and
+lower it to the requirements of each backend. The 3 backends are:
+
+- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
+  `tensor`, etc.)
+- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
+- [MHLO](https://github.com/tensorflow/mlir-hlo)
+
+### The Linalg Backend (Linalg-on-Tensors)
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToLinalg
+
+The Linalg-on-Tensors backend was the first backend that we added, and it is
+still the most complete. It fully supports dynamic shapes (known number of
+dimensions but arbitrary dynamic dimension sizes).
+Since linalg was originally designed as a dialect for transformations, it can
+be too low-level for certain consumers.
+
+### The TOSA Backend
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToTosa
+
+The TOSA backend was the second backend that we added. It remains preferred by
+many users (especially "hardware" or "hardware-adjacent" folks). Some of its
+characteristics are:
+- It is tied to a [spec](https://www.mlplatform.org/tosa/tosa_spec.html) with a
+  really clear "ISA-like" expository style that resonates with a lot of folks.
+- The coarse-grained named-op approach is a good match for the many compilers
+  that are designed that way.
+- It has really good support for quantization / integer data types.
+- It has clear versioning/stability guarantees on the op semantics.
+- It is extremely solid with static shapes (and many of its users only care
+  about static shapes, so that's fine).
+
+### The MHLO Backend
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
+
+The MHLO backend was the third backend that we added, and it offers a reasonable
+blend of the benefits of the other two.
+- It is a coarse-grained named-op approach.
+- It has a pretty clear spec for most of the ops (with a bit of mental
+  translation and hoping that MHLO is the same as HLO):
+  https://www.tensorflow.org/xla/operation_semantics
+- It functionally supports dynamic shapes (though not as coherently and
+  consistently as Linalg-on-Tensors, and the dynamic shape support falls
+  outside the wonderful HLO docs above).
+- It appears to be pretty tied to HLO (which is highly mature), so most of the
+  op surface area doesn't change too much.
+- It has a different set of principles than TOSA, which tend to make it more
+  expressive at the cost of having a larger abstraction gap from hardware. For
+  example, TOSA limits (for highly considered reasons) the number of dimensions
+  that certain operators can handle to 1D-4D, when from a purely algebraic
+  perspective there isn't a good reason not to be more general. Similarly, more
+  general forms of reduction and scatter also fall into MHLO nicely, while
+  TOSA's principles tend to bias it away from that.
+
+### Backend Implementation
+
+All the backends are implemented using the MLIR [Dialect Conversion
+infrastructure](https://mlir.llvm.org/docs/DialectConversion/). This involves
+converting the `torch` dialect types to other types, so we closely follow the
+principles from the "Type Conversions the Not-So-Hard Way" talk
+([slides](https://drive.google.com/file/d/1FVbzCXxZzS9LBLuvpPNLWJD-XDkt54ky/view?usp=sharing),
+[recording](https://drive.google.com/file/d/1VfVajitgf8ZPnd-HRkJvaJiFLhBsluXN/view?usp=sharing)).
+We follow the standard `{include,lib}/Conversion/TorchTo*` convention used in
+MLIR for conversion passes.
+
+For type conversion, we provide
+[BackendTypeConversion.cpp](https://github.com/llvm/torch-mlir/blob/57681f794764a34c34e2be7f07f7dfbcafa683c1/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp#L1)
+and
+[BackendTypeConversionPasses.cpp](https://github.com/llvm/torch-mlir/blob/57681f794764a34c34e2be7f07f7dfbcafa683c1/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp#L1)
+which provide a default conversion from `torch` dialect types to the builtin
+`tensor` type and scalar integer/float types. These are not the right choice for
+all backends, but can be copied and adapted as needed. These files closely
+follow the "Type Conversions the Not-So-Hard Way" talk.
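+
+From the user's perspective, each of these backend lowerings can be exercised
+through the same `torch_mlir.compile` entry point. The sketch below is
+illustrative only: the `output_type` strings match the values listed in the
+development guide, but treat the exact loop and printed form as assumptions
+rather than a definitive API reference:
+
+```python
+# Sketch: lower one toy module to each of the three built-in backends and
+# print the resulting IR. Assumes a working torch_mlir build on PYTHONPATH.
+import torch
+import torch_mlir
+
+class MulModule(torch.nn.Module):
+    def forward(self, x):
+        return x * 2.0
+
+example_input = torch.ones(4, 5)
+for output_type in ("linalg-on-tensors", "tosa", "mhlo"):
+    module = torch_mlir.compile(MulModule(), example_input,
+                                output_type=output_type)
+    print(f"// ----- {output_type} -----")
+    print(module)
+```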
+
+
+## Testing
+
+See
+[development.md](https://github.com/llvm/torch-mlir/blob/9c8b96272057f4f8210de5842b6952228434cfa2/development.md#testing)
+for more details on running tests.
+
+Torch-MLIR has two types of tests:
+
+1. End-to-end execution tests. These compile and run a program and check the
+   result against the expected output from execution on native Torch. These use
+   a homegrown testing framework (see
+   [framework.py](https://github.com/llvm/torch-mlir/blob/7d4a0d0e2b65c7ce8de19993f3b10ad5344fe32b/python/torch_mlir_e2e_test/torchscript/framework.py#L6))
+   and the test suite lives at `python/torch_mlir_e2e_test/test_suite`.
+
+2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework.
+   For example, these might involve using `torch-mlir-opt` to run a pass and
+   check the output with `FileCheck`. `lit` is flexible enough to unit test
+   various Python pieces, importers, and LTC this way as well.
+
+### Why so much end-to-end testing?
+
+Torch-MLIR places a heavy emphasis on end-to-end testing for the following
+reasons:
+
+Reason 1: Even if a compiler pass produces the output IR that the author
+expected, that output IR may not correctly implement the semantics of the op.
+This is especially true for the complex, often poorly specified deep learning
+operators that Torch-MLIR is mainly concerned with. It is critical to run these
+against the source of truth to ensure correct implementation.
+
+Reason 2: There are many patterns in Torch-MLIR's backends that really just
+expand one op into other ops without any real logic. When we started Torch-MLIR,
+we were very religious about always having `.mlir` unit tests even for these
+"macro expansion" patterns, but we found that these tests (1) never caught a bug
+and (2) interfered with refactoring / caused spurious extra work (changing op
+syntax, etc.). There is not much point in having a bunch of tests like this,
+which are basically just rewriting the builder calls in a different syntax:
+
+```
+// MyPass.cpp
+b.create(...)
+b.create(...)
+
+// test.mlir
+// CHECK: foo
+// CHECK: bar
+```
+
+Such a test is simply checking that the implementation of an op is the way it
+is. There is no way to change the implementation while having the test pass. So
+the test is fully redundant with the implementation.
+
+Because of this, many Torch-MLIR patches adding support for new ops have no
+`.mlir` unit tests, and only include end-to-end test(s). We generally make sure
+that our end-to-end tests are as targeted as possible. As a result, when
+debugging end-to-end test failures, the resulting reproducers (which our test
+framework automatically produces for failures) are usually already fully
+reduced test cases.
+
+### Do's and Don'ts for unit vs. end-to-end testing
+
+DO use an end-to-end test if you are implementing a new op or extending the
+support for an existing op.
+
+DO use a unit test if your lowering for an op has multiple cases or nontrivial
+logic. This also helps future maintainers of the lowering to see in one place
+all the different edge cases of the op that you had to handle. (These can
+easily be reduced out of the end-to-end tests you added.)
+
+DON'T use a unit test if your lowering pattern could be described as a trivial
+"macro expansion" of one op into another op or set of ops. That is, if you feel
+like your unit test is just rewriting `b.create<...>(...)` into `CHECK: ...`
+then it is probably not a useful unit test.
+
+DON'T add a unit test for trivial changes to RefineTypes.
+
+With the exceptions above, all changes should include appropriate unit tests, as
+is standard in the LLVM and MLIR community. This includes full coverage of all
+canonicalizations, pretty printing, passes, errors, and diagnostics.
+
+### The RefBackend (Reference Backend)
+
+In order to run end-to-end tests, Torch-MLIR needs an end-to-end flow.
+Thankfully, upstream MLIR has just enough pieces to precariously assemble one
+that is sufficient for testing.
+
+The RefBackend consists of a few minor
+[C++ passes](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/include/torch-mlir/RefBackend/Passes.td#L1)
+filling in some corners missing upstream and
+[Python glue logic](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/python/torch_mlir_e2e_test/linalg_on_tensors_bakends/refbackend.py#L1)
+to pull together upstream functionality into a working system.
+
+The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
+ops and lowers them to loops. Note that TOSA and MHLO support lowering to
+Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
+
+The RefBackend is absolutely not suitable for any production use case. It leaks
+memory, doesn't support any error handling, performs no optimizations, and
+probably does a bunch of other horrible things. We are patiently waiting for the
+upstream MLIR community to produce a viable end-to-end flow with better
+characteristics.
+
+### Presentations and Talks
+
+* 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55))
+* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U&t=2374s) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view))
diff --git a/docs/development.md b/docs/development.md
new file mode 100644
index 000000000000..46bdb9c5609b
--- /dev/null
+++ b/docs/development.md
@@ -0,0 +1,441 @@
+# Checkout and build from source
+
+## Check out the code
+
+```shell
+git clone https://github.com/llvm/torch-mlir
+cd torch-mlir
+git submodule update --init
+```
+
+## Setup your Python Virtual Environment and Dependencies
+
+Ensure that you have the appropriate `python-dev` package installed
+to access the Python development libraries / headers.
+
+```shell
+python -m venv mlir_venv
+source mlir_venv/bin/activate
+# Some older pip installs may not be able to handle the recent PyTorch deps
+python -m pip install --upgrade pip
+# Install latest PyTorch nightlies and build requirements.
+python -m pip install -r requirements.txt
+```
+
+## Docker Builds
+
+We have preliminary support for building with Docker images. This is a new
+flow, and we would like to know how it works for you, so please feel free to
+file any feedback or issues.
+
+Install [Docker Engine](https://docs.docker.com/engine/install/ubuntu/). You don't need Docker Desktop.
+
+Three types of builds are selectable via the environment variable
+`TM_PACKAGES`: `torch-mlir` (the Release build), `out-of-tree` (torch-mlir is
+built against a pre-built MLIR), and `in-tree` (torch-mlir is built as part of
+the LLVM project along with MLIR).
+
+We mount a ccache and pip cache inside the docker container to speed up iterative builds. Iterative
+builds should be as fast as running without docker.
+
+### In-Tree builds
+
+Build MLIR and Torch-MLIR together as part of the LLVM repo.
+
+```shell
+TM_PACKAGES="in-tree" ./build_tools/python_deploy/build_linux_packages.sh
+```
+
+### Out-of-Tree builds
+
+Build LLVM/MLIR first and then build Torch-MLIR referencing that build.
+```shell
+TM_PACKAGES="out-of-tree" ./build_tools/python_deploy/build_linux_packages.sh
+```
+
+### Release builds
+
+Build in a manylinux Docker image so we can upload artifacts to PyPI.
+
+```shell
+TM_PACKAGES="torch-mlir" ./build_tools/python_deploy/build_linux_packages.sh
+```
+
+### Mimicking CI+Release builds
+
+If you want to build all the CI configurations locally:
+
+```shell
+TM_PACKAGES="out-of-tree in-tree" ./build_tools/python_deploy/build_linux_packages.sh
+```
+
+If you want to build all the CI configurations and the Release builds (just with
+Python 3.10, since most other Python builds are redundant):
+
+```shell
+TM_PACKAGES="torch-mlir out-of-tree in-tree" TM_PYTHON_VERSIONS="cp310-cp310" ./build_tools/python_deploy/build_linux_packages.sh
+```
+
+Note: The Release docker still runs as root, so it may generate some files owned by root:root. We hope to move it to run as a user in the future.
+
+### Cleaning up
+
+Docker builds tend to leave a wide variety of files around. Luckily, most are owned by the user, but some still need to be removed
+as superuser.
+
+```shell
+rm -rf build build_oot llvm-build docker_venv externals/pytorch/build .ccache
+```
+
+## Building your own Docker image
+
+If you would like to build your own docker image (usually not necessary), you can run:
+
+```shell
+cd ./build_tools/docker
+docker build -t your-name/torch-mlir-ci --no-cache .
+```
+
+### Other configurable environment variables
+
+The following additional environment variables can be used to customize your docker build:
+
+* Custom Release Docker image:
+  Defaults to `stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest`
+```shell
+  TM_RELEASE_DOCKER_IMAGE="stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest"
+```
+* Custom CI Docker image:
+  Defaults to `powderluv/torch-mlir-ci:latest`. This assumes an Ubuntu LTS-like image. You can build your own with `./build_tools/docker/Dockerfile`
+```shell
+  TM_CI_DOCKER_IMAGE="powderluv/torch-mlir-ci:latest"
+```
+
+* Custom Python Versions for Release builds:
+  Version of Python to use in Release builds. Ignored in CIs. Defaults to `cp38-cp38 cp39-cp39 cp310-cp310`
+```shell
+  TM_PYTHON_VERSIONS="cp38-cp38 cp39-cp39 cp310-cp310"
+```
+
+* Location to store Release build wheels
+```shell
+  TM_OUTPUT_DIR="./build_tools/python_deploy/wheelhouse"
+```
+
+* What "packages" to build:
+  Defaults to torch-mlir. Options are `torch-mlir out-of-tree in-tree`
+```shell
+  TM_PACKAGES="torch-mlir out-of-tree in-tree"
+```
+* Use pre-built PyTorch:
+  Defaults to using pre-built PyTorch. Setting it to `OFF` builds PyTorch from source
+```shell
+  TM_USE_PYTORCH_BINARY="OFF"
+```
+* Skip running tests:
+  Skip tests if you want quick build-only iteration. Defaults to `OFF`
+```shell
+  TM_SKIP_TESTS="OFF"
+```
+
+
+## Build Python Packages
+
+We have preliminary support for building Python packages. This can be done
+with the following commands:
+
+```
+python -m pip install --upgrade pip
+python -m pip install -r requirements.txt
+CMAKE_GENERATOR=Ninja python setup.py bdist_wheel
+```
+
+## CMake Build
+
+Two build setups are possible: in-tree and out-of-tree. The in-tree setup is the most straightforward, as it will build LLVM dependencies as well.
+
+### Building torch-mlir in-tree
+
+The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects.
+
+```shell
+cmake -GNinja -Bbuild \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=clang \
+  -DCMAKE_CXX_COMPILER=clang++ \
+  -DPython3_FIND_VIRTUALENV=ONLY \
+  -DLLVM_ENABLE_PROJECTS=mlir \
+  -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
+  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
+  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+  -DLLVM_TARGETS_TO_BUILD=host \
+  externals/llvm-project/llvm
+```
+The following additional quality-of-life flags can be used to reduce build time:
+* Enabling ccache:
+```shell
+  -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
+```
+* Enabling LLD (links in seconds compared to minutes):
+```shell
+  -DCMAKE_EXE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_MODULE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_SHARED_LINKER_FLAGS_INIT="-fuse-ld=lld"
+# Use --ld-path= instead of -fuse-ld=lld for clang > 13
+```
+* Enabling the libtorch binary cache:
+By default we download the latest version of libtorch every time you build, so we are always on the latest version. Set `-DLIBTORCH_CACHE=ON` to
+avoid downloading it again on every build. If libtorch gets out of date and you test against a newer PyTorch, you may notice failures.
+```shell
+  -DLIBTORCH_CACHE=ON
+```
+* Enabling building libtorch as part of your build:
+By default we download the latest version of libtorch. We have an experimental path to build libtorch (and PyTorch wheels) from source.
+```shell
+  -DLIBTORCH_SRC_BUILD=ON # Build Libtorch from source
+  -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. (`shared`|`static` and optionally `cxxabi11`)
+```
+
+### Building against a pre-built LLVM
+
+If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`, you can also build the project *out-of-tree* using the following command as a template:
+```shell
+cmake -GNinja -Bbuild \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=clang \
+  -DCMAKE_CXX_COMPILER=clang++ \
+  -DPython3_FIND_VIRTUALENV=ONLY \
+  -DMLIR_DIR="$LLVM_INSTALL_DIR/lib/cmake/mlir/" \
+  -DLLVM_DIR="$LLVM_INSTALL_DIR/lib/cmake/llvm/" \
+  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+  -DLLVM_TARGETS_TO_BUILD=host \
+  .
+```
+The same QoL CMake flags can be used to enable ccache and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
+
+Be aware that the installed version of LLVM generally needs to match the committed version in `externals/llvm-project`. Using a different version may or may not work.
+
+
+### Build commands
+
+After either cmake run (in-tree/out-of-tree), use one of the following commands to build the project:
+```shell
+# Build just torch-mlir (not all of LLVM)
+cmake --build build --target tools/torch-mlir/all
+
+# Run unit tests.
+cmake --build build --target check-torch-mlir
+
+# Run Python regression tests.
+cmake --build build --target check-torch-mlir-python
+
+# Build everything (including LLVM if in-tree)
+cmake --build build
+```
+
+## Setup Python Environment to export the built Python packages
+```shell
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
+```
+
+## Testing MLIR output in various dialects
+
+To test the compiler's output to the different MLIR dialects, you can use the example `examples/torchscript_resnet18_all_output_types.py`.
+
+Make sure you have activated the virtualenv and set the `PYTHONPATH` above:
+```shell
+source mlir_venv/bin/activate
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
+python examples/torchscript_resnet18_all_output_types.py
+```
+
+This will display the Resnet18 network example in three dialects: `TORCH`, `LINALG_ON_TENSORS`, and `TOSA`.
+
+The main knob is the `output_type` argument of `torch_mlir.compile()`.
+
+For example:
+```python
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
+```
+
+Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
+
+## Jupyter
+
+Jupyter notebook:
+```shell
+python -m ipykernel install --user --name=torch-mlir --env PYTHONPATH "$PYTHONPATH"
+# Open in jupyter, and then navigate to
+# `examples/resnet_inference.ipynb` and use the `torch-mlir` kernel to run.
+jupyter notebook
+```
+
+[Example IR](https://gist.github.com/silvasean/e74780f8a8a449339aac05c51e8b0caa) for a simple 1-layer MLP to show the compilation steps from TorchScript.
+
+
+## Interactive Use
+
+The `build_tools/write_env_file.sh` script will output a `.env`
+file in the workspace folder with the correct PYTHONPATH set. This allows
+tools like VSCode to work by default for debugging. This file can also be
+manually `source`'d in a shell.
+
+
+## Bazel Build
+
+> **NOTE** Our Bazel build follows LLVM's Bazel build policy: only the
+> subcommunity interested in Bazel is responsible for fixing it. Average
+> Torch-MLIR developers should not be notified of any Bazel build issues and are
+> not responsible for fixing any breakages (though any help is, of course,
+> welcome). For more info, see LLVM's
+> [Peripheral Support Tier](https://llvm.org/docs/SupportPolicy.html#peripheral-tier)
+> definition.
+
+Torch-MLIR can also be built using Bazel (apart from the official CMake build) for users that depend on Bazel in their workflows. To build `torch-mlir-opt` using Bazel, follow these steps:
+
+1. Launch an interactive docker container with the required deps installed:
+```shell
+./utils/bazel/docker/run_docker.sh
+```
+2. Build torch-mlir using bazel (from container):
+```shell
+./utils/bazel/docker/run_bazel_build.sh
+```
+3. Find the built binary at `utils/bazel/bazel-bin/external/torch-mlir/torch-mlir-opt`.
+
+
+# Testing
+
+Torch-MLIR has two types of tests:
+
+1. End-to-end execution tests. These compile and run a program and check the
+   result against the expected output from execution on native Torch. These use
+   a homegrown testing framework (see
+   `python/torch_mlir_e2e_test/torchscript/framework.py`) and the test suite
+   lives at `python/torch_mlir_e2e_test/test_suite/__init__.py`.
+
+2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework.
+   For example, these might involve using `torch-mlir-opt` to run a pass and
+   check the output with `FileCheck`.
+
+
+## Running execution (end-to-end) tests
+
+```shell
+# Run all tests on the reference backend
+./tools/e2e_test.sh
+# Run tests that match the regex `Conv2d`, with verbose errors.
+./tools/e2e_test.sh --filter Conv2d --verbose
+# Run tests on the TOSA backend.
+./tools/e2e_test.sh --config tosa
+```
+
+## Running unit tests
+
+To run all of the unit tests, run:
+
+```
+ninja check-torch-mlir-all
+```
+
+This can be broken down into:
+
+```
+ninja check-torch-mlir check-torch-mlir-dialects check-torch-mlir-python
+```
+
+To run tests at a finer granularity, you can do, for `check-torch-mlir`:
+
+```
+cd $TORCH_MLIR_BUILD_DIR/tools/torch-mlir/test
+$TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonicalize
+```
+
+See [the `lit` documentation](https://llvm.org/docs/CommandGuide/lit.html) for details on the available lit args.
+
+For example, if you wanted to test just `test/Dialect/Torch/canonicalize.mlir`,
+then you might do
+
+```
+cd $TORCH_MLIR_BUILD_DIR/tools/torch-mlir/test
+$TORCH_MLIR_BUILD_DIR/bin/llvm-lit $TORCH_MLIR_SRC_ROOT/test -v --filter=canonicalize.mlir
+```
+
+Most of the unit tests use the [`FileCheck` tool](https://llvm.org/docs/CommandGuide/FileCheck.html) to verify expected outputs.
+
+# PyTorch source builds and custom PyTorch versions
+
+Torch-MLIR by default builds with the latest nightly PyTorch version. This can be toggled to build from the latest PyTorch source with
+```
+-DTORCH_MLIR_USE_INSTALLED_PYTORCH=OFF
+-DTORCH_MLIR_SRC_PYTORCH_REPO=vivekkhandelwal1/pytorch # Optional. Github path. Defaults to pytorch/pytorch
+-DTORCH_MLIR_SRC_PYTORCH_BRANCH=master # Optional. Defaults to PyTorch's main branch
+```
+
+# Updating the LLVM and MLIR-HLO submodules
+
+Torch-MLIR depends on `llvm-project` (which contains, among other things,
+upstream MLIR) and `mlir-hlo`, both of which are submodules in the `externals/`
+directory. We aim to update these at least weekly to bring in the latest
+features and to spread the effort of updating our code for MLIR API breakages
+out over time.
+
+## Which LLVM commit should I pick?
+
+Since downstream projects may want to build Torch-MLIR (and thus LLVM and
+MLIR-HLO) in various configurations (Release versus Debug builds; on Linux,
+Windows, or macOS; possibly with Clang, LLD, and LLDB enabled), it is crucial to
+pick LLVM commits that pass tests for all combinations of these configurations.
+
+So every week, we track the so-called _green_ commit (i.e. the LLVM commit which
+works with all of the above configurations) in Issue
+https://github.com/llvm/torch-mlir/issues/1178. In addition to increasing our
+confidence that the resulting update will not break downstream projects, basing
+our submodule updates on these green commits also helps us stay in sync with
+LLVM updates in other projects like ONNX-MLIR and MLIR-HLO.
+
+## What is the update process?
+
+1. **Look up green commit hashes**: From the Github issue
+   https://github.com/llvm/torch-mlir/issues/1178, find the LLVM and MLIR-HLO
+   green commits for the week when Torch-MLIR is being updated.
+2. **Update the `llvm-project` submodule**: In the `externals/llvm-project`
+   directory, run `git fetch` followed by `git checkout <llvm-commit-hash>`
+   (where `<llvm-commit-hash>` is the green commit hash for the LLVM project
+   from Step 1).
+3. **Update the `mlir-hlo` submodule**: In the `externals/mlir-hlo` directory,
+   run `git fetch` followed by `git checkout <mlir-hlo-commit-hash>` (where
+   `<mlir-hlo-commit-hash>` is the green commit hash for the MLIR-HLO project
+   from Step 1).
+   instructions, fixing any issues that arise. This might involve fixing various
+   API breakages introduced upstream (they are likely unrelated to what you are
+   working on). If these fixes are too complex, please file a work-in-progress
+   PR explaining the issues you are running into, and ask for help so that
+   someone from the community can assist.
+5. **Update Shape Library**: Run `build_tools/update_shape_lib.sh`. This is
+   sometimes needed because upstream changes can affect canonicalization and
+   other minor details of the IR in the shape library. See
+   [docs/shape_lib.md](docs/shape_lib.md) for more details on the shape library.
+
+
+Here are some examples of PRs updating the LLVM and MLIR-HLO submodules:
+
+- https://github.com/llvm/torch-mlir/pull/1180
+- https://github.com/llvm/torch-mlir/pull/1229
+
+# Enabling Address Sanitizer (ASan)
+
+To enable ASan, pass `-DLLVM_USE_SANITIZER=Address` to CMake. This should "just
+work" with all C++ tools like `torch-mlir-opt`. When running a Python script,
+such as through `./tools/e2e_test.sh`, you will need to do:
+
+```shell
+LD_PRELOAD="$(clang -print-file-name=libclang_rt.asan-x86_64.so)" ./tools/e2e_test.sh -s
+# See instructions here for how to get the libasan path for GCC:
+# https://stackoverflow.com/questions/48833176/get-location-of-libasan-from-gcc-clang
+```
+
+TODO: Add ASan docs for LTC.
+
+# Other docs
+
+- GitHub wiki: https://github.com/llvm/torch-mlir/wiki
+- Of particular interest is the [How to add end-to-end support for new Torch ops](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation) doc.
diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md
new file mode 100644
index 000000000000..58c0e8de2d83
--- /dev/null
+++ b/docs/ltc_backend.md
@@ -0,0 +1,137 @@
+# Torch-MLIR Lazy Tensor Core Backend
+
+## Table of Contents
+- [Introduction](#introduction)
+- [Examples](#examples)
+- [Code Structure](#code-structure)
+- [Architecture](#architecture)
+- [Implementing a custom backend](#implementing-a-custom-backend)
+- [Future Expansion](#future-expansion)
+
+## Introduction
+[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
+After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
+
+LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
+Implementations based on this abstract class will be able to specify their own compile and execution workflows.
+Additional details about how to implement a custom backend are available [below](#Implementing-a-custom-backend).
+
+## Examples
+View examples [here](ltc_examples.md). A minimal snippet is also shown below.
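+
+For a quick sense of the flow, a minimal trace-and-sync loop might look like
+this (a sketch assuming the in-tree reference backend has been built and is
+importable; see the examples page for the full version):
+
+```python
+import torch
+import torch._lazy
+import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
+
+# Register the LTC backend; from here on, ops on 'lazy' tensors are recorded.
+lazy_backend._initialize()
+
+x = torch.ones(2, 2, device='lazy')
+y = torch.tanh(x + x)
+
+# Nothing has executed yet; mark_step() hands the traced graph to the backend.
+torch._lazy.mark_step()
+print(y)
+```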
+
+## Code Structure
+
+### Autogen Build Tools ([`build_tools`](../build_tools))
+
+- `autogen_ltc_backend.{py,yaml}`
+  - The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
+    excluding those explicitly blacklisted in the YAML file.
+
+### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated))
+Generated files are created in this directory, which is ignored by version control.
+
+- `LazyIr.h`
+  - Definitions of `torch::lazy::TorchMlirNode` subclasses for each supported autogen op
+- `LazyNativeFunctions.{cpp,h}`
+  - Native function definitions for each supported op (handles `at::Tensor -> at::Tensor` data flow and creation of `torch::lazy::TorchMlirNode`)
+- `LazyNonNativeIr.h`
+  - Non-native `torch::lazy::TorchMlirNode` subclasses
+- `RegisterLazy.cpp`
+  - Registers PyTorch kernels under the `lazy` dispatch key for all supported ops, which map to our native functions
+- `shape_inference.{cpp,h}`
+  - Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
+
+### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend))
+
+- `backend_impl.{cpp,h}`
+  - Base LTC backend that sets up the Torch-MLIR lowering context
+- `dynamic_ir.{cpp,h}`
+  - Manually implemented "dynamic" nodes
+- `ir_builder.h`
+  - Torch-MLIR implementation of `torch::lazy::IrBuilder`
+- `mlir_lowering_context.h`
+  - Handles conversion from `torch::lazy::Node` to MLIR via JIT and Torch-MLIR infrastructure
+- `mlir_native_functions.cpp`
+  - Manually implemented native functions
+- `mlir_node.{cpp,h}`
+  - Torch-MLIR implementation of `torch::lazy::Node`
+- `mlir_node_lowering.{cpp,h}`
+  - Lowers a `torch::lazy::Node` to a JIT graph in preparation for MLIR generation
+- `shape_inference.cpp`
+  - Implementation of select shape inference functions (most functions are [implemented upstream](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/shape_inference.cpp))
+
+### Reference Backend ([`python/torch_mlir/csrc/reference_lazy_backend`](../python/torch_mlir/csrc/reference_lazy_backend))
+
+- `backend_impl.{cpp,h}`
+  - Reference Torch-MLIR LTC backend implementation, which simply stores the MLIR as a string and executes computation on CPU
+- `reference_lazy_backend_pybind.cpp`
+  - Pybind bindings for the reference Torch-MLIR LTC backend
+
+### Examples ([`examples`](../examples))
+
+- `ltc_backend_bert.py`
+  - Example HuggingFace BERT model traced by LTC to MLIR
+- `ltc_backend_mnist.py`
+  - Example MNIST model traced by LTC to MLIR
+
+## Architecture
+
+![LTC Diagram](ltc_images/ltc_architecture.png)
+
+### Tracing LTC graph
+
+The journey begins with a tensor in PyTorch on the `lazy` device, which may undergo a number of operations during its lifetime.
+```python
+>>> lazy_backend._initialize()
+>>> x = torch.tensor(..., device='lazy')
+>>> y = torch.tanh(x)
+...
+```
+The call to `torch.tanh` triggers a chain of events. PyTorch checks the dispatch table under the `lazy` key and finds the kernel for `tanh`
+previously registered in `RegisterLazy.cpp`.
+
+Next, `LazyNativeFunctions::tanh` from `LazyNativeFunctions.cpp` is called; this triggers the creation of a `Tanh` node, a subclass of `TorchMlirNode` and `torch::lazy::Node` defined in `LazyIr.h`.
+These nodes are then tracked internally by LTC as the computation graph is traced out.
+
+![Tracing Tensors](ltc_images/tracing_tensors.png)
+
+### Syncing Tensors
+
+At some point, the tensors will be synced in order to execute the computation -- either explicitly via `mark_step`, or implicitly through some operation that requires the contents of the tensors (e.g. printing to console).
+
+```python
+>>> torch._lazy.mark_step()
+```
+
+This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and
+creates an instance of `TorchMlirLoweringContext`. Here, the `TorchMlirNode`s are lowered to JIT via `mlir_node_lowering.cpp` and inserted into a `jit::Graph`.
+
+Next, `TorchMlirLoweringContext::Build` is executed and the final `jit::Graph` is sent to `torch_mlir::importJitFunctionAsFuncOp` to generate MLIR using the existing infrastructure from Torch-MLIR.
+At this point, a `TorchMlirComputation` is created containing the final `mlir::FuncOp`.
+
+![Syncing Tensors](ltc_images/syncing_tensors.png)
+
+### Final Compilation and Execution
+
+The `TorchMlirComputation` is sent to the vendor-specific implementation of `TorchMlirBackendImpl::Compile` to be handed off to the vendor's compilation stack (if applicable).
+
+Finally, the compiled computation is sent to `TorchMlirBackendImpl::ExecuteComputation` to be executed on the vendor device, which produces results that are sent back to PyTorch.
+
+![Vendor Execution](ltc_images/vendor_execution.png)
+
+## Implementing a custom backend
+
+A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/).
+All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device.
+
+A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself.
+Most of the code in the reference implementation should be reusable, excluding some debug-related functions (e.g. `get_latest_computation`).
+
+## Future Expansion
+
+There are a number of areas for future improvement:
+- Generate source information in `jit::Graph` so it can be embedded in the MLIR
+- The reference backend implementation currently executes via the `jit::Graph` instead of the MLIR, since we lack lowerings for many ops, which would make it difficult to run models such as HF BERT
+  - In the future, we should change the implementation to lower the MLIR to linalg and execute on a reference backend
+- As new models get tested, we will inevitably run into errors related to unimplemented shape inference functions.
+This problem is solved simply by implementing the missing function or adding a structured kernel to PyTorch.
diff --git a/docs/ltc_examples.md b/docs/ltc_examples.md
new file mode 100644
index 000000000000..b9306edce492
--- /dev/null
+++ b/docs/ltc_examples.md
@@ -0,0 +1,54 @@
+# Torch-MLIR Lazy Tensor Core Backend Examples
+
+Refer to the main documentation [here](ltc_backend.md).
+
+## Example Usage
+```python
+import torch
+import torch._lazy
+import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
+
+# Register the example LTC backend.
+lazy_backend._initialize()
+
+device = 'lazy'
+
+# Create some tensors and perform operations.
+inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
+outputs = torch.tanh(inputs)
+
+# Mark end of training/evaluation iteration and lower traced graph.
+torch._lazy.mark_step()
+print('Results:', outputs)
+
+# Optionally dump MLIR graph generated from LTC trace.
+computation = lazy_backend.get_latest_computation()
+if computation:
+    print(computation.debug_string())
+```
+
+```
+Received 1 computation instances at Compile!
+Received 1 arguments, and returned 2 results during ExecuteCompile!
+
+Results: tensor([[0.7616, 0.9640, 0.9951, 0.9993, 0.9999]], device='lazy:0')
+
+JIT Graph:
+graph(%p0 : Float(1, 5)):
+  %1 : Float(1, 5) = aten::tanh(%p0)
+  return (%p0, %1)
+
+MLIR:
+func.func @graph(%arg0: !torch.vtensor<[1,5],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32>) {
+  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[1,5],f32> -> !torch.vtensor<[1,5],f32>
+  return %arg0, %0 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32>
+}
+
+Input/Output Alias Mapping:
+Output: 0 -> Input param: 0
+
+In Mark Step: true
+```
+
+## Example Models
+There are also examples of [HuggingFace BERT](../examples/ltc_backend_bert.py) and [MNIST](../examples/ltc_backend_mnist.py) models running on the example LTC backend.
diff --git a/docs/ltc_images/ltc_architecture.png b/docs/ltc_images/ltc_architecture.png
new file mode 100644
index 000000000000..a00c85c766bf
Binary files /dev/null and b/docs/ltc_images/ltc_architecture.png differ
diff --git a/docs/ltc_images/syncing_tensors.png b/docs/ltc_images/syncing_tensors.png
new file mode 100644
index 000000000000..1905ec2d3aa4
Binary files /dev/null and b/docs/ltc_images/syncing_tensors.png differ
diff --git a/docs/ltc_images/tracing_tensors.png b/docs/ltc_images/tracing_tensors.png
new file mode 100644
index 000000000000..152a82d909d8
Binary files /dev/null and b/docs/ltc_images/tracing_tensors.png differ
diff --git a/docs/ltc_images/vendor_execution.png b/docs/ltc_images/vendor_execution.png
new file mode 100644
index 000000000000..509cb74aaaf6
Binary files /dev/null and b/docs/ltc_images/vendor_execution.png differ
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/main.py
similarity index 76%
rename from e2e_testing/torchscript/main.py
rename to e2e_testing/main.py
index 4e23afb05f2d..623a7739d0a9 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/main.py
@@ -7,28 +7,35 @@
 import re
 import sys
 
-from torch_mlir_e2e_test.torchscript.framework import run_tests
-from torch_mlir_e2e_test.torchscript.reporting import report_results
-from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
-from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_from
+from torch_mlir_e2e_test.framework import run_tests
+from torch_mlir_e2e_test.reporting import report_results
+from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
+from torch_mlir_e2e_test.serialization import deserialize_all_tests_from
 
 # Available test configs.
-from torch_mlir_e2e_test.torchscript.configs import ( - LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig +from torch_mlir_e2e_test.configs import ( + LazyTensorCoreTestConfig, + LinalgOnTensorsBackendTestConfig, + MhloBackendTestConfig, + NativeTorchTestConfig, + TorchScriptTestConfig, + TosaBackendTestConfig, + EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET +from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -36,10 +43,12 @@ def _get_argparse(): help=f''' Meaning of options: "refbackend": run through torch-mlir's RefBackend. +"mhlo": run through torch-mlir's default MHLO backend. "tosa": run through torch-mlir's default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "eager_mode": run through torch-mlir's eager mode frontend, using RefBackend for execution. +"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. ''') parser.add_argument('-f', '--filter', default='.*', help=''' Regular expression specifying which tests to include in this run. @@ -53,7 +62,7 @@ def _get_argparse(): Right now, these are additional tests which require heavy Python dependencies to generate (or cannot even be generated with the version of PyTorch used by torch-mlir). -See `build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh` +See `build_tools/e2e_heavydep_tests/generate_serialized_tests.sh` for more information on building these artifacts. ''') parser.add_argument('-s', '--sequential', @@ -77,6 +86,9 @@ def main(): if args.config == 'tosa': config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET + if args.config == 'mhlo': + config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) + xfail_set = all_test_unique_names - MHLO_PASS_SET elif args.config == 'native_torch': config = NativeTorchTestConfig() xfail_set = {} @@ -86,6 +98,9 @@ def main(): elif args.config == 'eager_mode': config = EagerModeTestConfig() xfail_set = EAGER_MODE_XFAIL_SET + elif args.config == 'lazy_tensor_core': + config = LazyTensorCoreTestConfig() + xfail_set = LTC_XFAIL_SET # Find the selected tests, and emit a diagnostic if none are found. tests = [ @@ -102,7 +117,7 @@ def main(): sys.exit(1) # Run the tests. 
- results = run_tests(tests, config, args.sequential) + results = run_tests(tests, config, args.sequential, args.verbose) # Report the test results. failed = report_results(results, xfail_set, args.verbose) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py deleted file mode 100644 index c7cd67f521fd..000000000000 --- a/e2e_testing/torchscript/xfail_sets.py +++ /dev/null @@ -1,179 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -# This file describes the sets of tests expected to fail for each config. -# This information is deliberately kept in a side table, rather than -# in-situ on the test, as a deliberate layering decision: tests should -# have unique keys to identify them and enable side tables of various kinds -# (this includes down into lower parts of the stack, where a side table -# might be used to keep more elaborate sets of testing configurations). - -from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS - -REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS - -EAGER_MODE_XFAIL_SET = { - # RefBackend fails - "TableBatchEmbeddingModule_basic", - "QuantizedMLP_basic" -} - -# Write the TOSA set as a "passing" set as it is very early in development -# and very few tests work yet. -TOSA_PASS_SET = { - "ElementwiseUnaryModule_basic", - "ElementwiseBinaryModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseReluModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ElementwiseMinimumModule_basic", - "ElementwiseMinimumIntModule_basic", - "ElementwiseMaximumModule_basic", - "ElementwiseMaximumIntModule_basic", - "TanhBackward_basic", - "ElementwiseAddModule_basic", - "ReturnThreeTensorFloat32_basic", - "AddCMulModule_basic", - "AddCDivModule_basic", - "SqueezeModule_broadcast", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "BoolTensorHandleSignless_basic", - "ElementwiseRsqrtModule_basic", - "SqueezeModule_static", - "SqueezeModule_noUnitDim", - "SqueezeModule_allUnitDim", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeDimModule_unitDim", - "ReturnTwoTensorF32I64_basic", - "ElementwisePowModule_basic", - "BmmModule_basic", - "MmDagModule_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_3d", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGtFloatTensorModule_basic", - "ElementwiseGtIntTensorModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatTensorModule_basic", - "ElementwiseLtIntTensorModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatTensorModule_basic", - "ElementwiseEqIntTensorModule_basic", - "ElementwiseMulScalarModule_int", - 
"ElementwiseMulScalarModule_float", - "ElementwiseMulTensorIntModule_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseCeilModule_basic", - "ElementwiseReciprocalModule_basic", - "TypePromotionAlphaWiderModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "FlattenStaticModule_basic", - "FlattenRank0Module_basic", - "ElementwiseFlattenBroadcastModule_basic", - "SquareModule_basic", - "MaxPool2dStaticModule_basic", - "ResNet18StaticModule_basic", - "NativeLayerNormModule4D_basic", - "LayerNormNormalizeOverAllDimsModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ElementwiseLog2Module_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dFloatModule_basic", - "Threshold2dFloatModule_basic", - "Threshold3dFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseMulScalarModule_basic", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "SiluModule_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNoChangeStaticModule_basic", - "UnsafeViewExpandModule_basic", - "ReshapeCollapseModule_basic", - "ElementwiseGeluModule_basic", - "GeluBackwardModule_basic", - "ElementwiseNeIntScalarModule_basic", - "ElementwiseNeFloatTensorModule_basic", - "Convolution2DStaticModule_basic", - "ElementwiseNegModule_basic", - "TestMultipleTensorReturn_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "BaddbmmDynamicModule_basic", - "BaddbmmStaticModule_basic", - "BaddbmmWithAlphaBetaModule_basic", - "BaddbmmWithAlphaModule_basic", - "BaddbmmWithBetaModule_basic", - "BaddbmmBroadcast1DInputModule_basic", - "BaddbmmBroadcast2DInputModule_basic", - "NumpyTRank1Module_basic", - "NumpyTRank2Module_basic", - "NumpyTRankNStaticModule_basic", - "NumpyTRankNDynamicModule_basic", - "EmbeddingModuleI32Static_basic", - "TModuleRank2_basic", - "TransposeIntModule_basic", - "TransposeIntNegDimsModule_basic", - "ArgmaxModule_keepDim", - "ArgmaxModule_with_dim", - "_LogSoftmaxModuleStable_basic", -} diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py new file mode 100644 index 000000000000..1cfa3652e4cd --- /dev/null +++ b/e2e_testing/xfail_sets.py @@ -0,0 +1,483 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# This file describes the sets of tests expected to fail for each config. +# This information is deliberately kept in a side table, rather than +# in-situ on the test, as a deliberate layering decision: tests should +# have unique keys to identify them and enable side tables of various kinds +# (this includes down into lower parts of the stack, where a side table +# might be used to keep more elaborate sets of testing configurations). + +from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS + +REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS + +EAGER_MODE_XFAIL_SET = { + # RefBackend fails + "TableBatchEmbeddingModule_basic", + "QuantizedMLP_basic", + "Matmul_vecmat" +} + +MHLO_PASS_SET = { + "FlattenStaticModule_basic", + "FlattenRank0Module_basic", + "TensorsConcatNegativeDimModule_basic", + "NumelModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "SqueezeModule_allUnitDim", + "SqueezeDimModule_unitDim", + "MeanModule_basic", + "MeanDynamicSizesModule_basic", + "MeanDimEmptyDimModule_basic", + "NumToTensorFloatModule_basic", + "AtenToDeviceModule_basic", + "AvgPool2dStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Convolution2DStaticModule_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ReturnThreeTensorFloat32_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "SqueezeModule_static", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceStartEqEndModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceWholeTensorModule_basic", + "ReturnTwoTensorF32I64_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_2d", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dWithIndicesStaticModule_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MaxPool2dStaticModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + 
"ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "UnsafeViewExpandModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceSumDimIntListFloatModule_basic", + "ReduceSumDimIntListIntModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "RepeatModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeExpandModule_basic", + "RollModule_basic", + "TestMultipleTensorReturn_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "BaddbmmStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NumToTensorIntModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNStaticModule_basic", + "NumpyTRankNDynamicModule_basic", + "TModuleRank2_basic", + "TensorLiteralModule_basic", + "TensorsConcatModule_basic", + "TensorOpaqueLiteralModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "OnesModuleCPUDevice_basic", + "Permute0RankModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", +} + +# Write the TOSA set as a "passing" set as it is very early in development +# and very few tests work yet. 
+TOSA_PASS_SET = { + "ElementwiseUnaryModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaximumIntModule_basic", + "TanhBackward_basic", + "ElementwiseAddModule_basic", + "ReturnThreeTensorFloat32_basic", + "AddCMulModule_basic", + "AddCDivModule_basic", + "SqueezeModule_broadcast", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", + "ElementwiseRsqrtModule_basic", + "SqueezeModule_static", + "SqueezeModule_noUnitDim", + "SqueezeModule_allUnitDim", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "AtenToDeviceModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SqueezeDimModule_unitDim", + "ReturnTwoTensorF32I64_basic", + "ElementwisePowModule_basic", + "BmmModule_basic", + "MmDagModule_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_3d", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseMulScalarModule_float", + "ElementwiseCeilModule_basic", + "ElementwiseReciprocalModule_basic", + "TypePromotionAlphaWiderModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "FlattenStaticModule_basic", + "FlattenRank0Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", + "MaxPool2dStaticModule_basic", + "ResNet18StaticModule_basic", + "NativeLayerNormModule4D_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ElementwiseLog2Module_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dFloatModule_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseMulScalarModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + 
"OnesModuleCPUDevice_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "SiluModule_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNoChangeStaticModule_basic", + "UnsafeViewExpandModule_basic", + "ReshapeCollapseModule_basic", + "ElementwiseGeluModule_basic", + "GeluBackwardModule_basic", + "ElementwiseNeIntScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseNegModule_basic", + "TestMultipleTensorReturn_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "BaddbmmDynamicModule_basic", + "BaddbmmStaticModule_basic", + "BaddbmmWithAlphaBetaModule_basic", + "BaddbmmWithAlphaModule_basic", + "BaddbmmWithBetaModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNStaticModule_basic", + "NumpyTRankNDynamicModule_basic", + "EmbeddingModuleI32Static_basic", + "TModuleRank2_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "ArgmaxModule_keepDim", + "ArgmaxModule_with_dim", + "_LogSoftmaxModuleStable_basic", +} + +LTC_XFAIL_SET = { + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AddIntModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliTensorModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "DivFloatModule_basic", + "DropoutTrainModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideModule_basic", + "ElementwiseWhereScalarModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereSelfModule_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EqIntModule_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullLikeModuleFloat3D_basic", + 
"FullLikeModuleInt2DStatic_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardTanhIntModule_basic", + "HardTanhModule_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "IndexTensorSelectDimModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "Matmul_dot", + "Matmul_matvec", + "MulIntModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "QuantizedMLP_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RollModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "SliceEndSleStartModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceStartEqEndModule_basic", + "SqrtIntModule_basic", + "StdBiasedModule_basic", + "StdDimBiasedModule_basic", + 
"StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimEmptyDimModule_basic", + "StdDimNoneDimModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TableBatchEmbeddingModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatModule_basic", + "UniformModule_basic", + "UniformStaticModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "AtenEmbeddingBagSumExample_basic", + "Aten_EmbeddingBagExample_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRemainderScalarModule_Float_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRemainderScalarModule_Bool_basic", +} diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py new file mode 100644 index 000000000000..048c74233c4c --- /dev/null +++ b/examples/ltc_backend_bert.py @@ -0,0 +1,160 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +""" +Runs a training of the Bert model using the Lazy Tensor Core with the +example Torch MLIR backend. + +Most of the code in this example was copied from the wonderful tutorial + https://huggingface.co/transformers/training.html#fine-tuning-in-native-pytorch + +Based on LTC code samples by ramiro050 + https://github.com/ramiro050/lazy-tensor-samples +""" + +import argparse +import sys +from typing import List + +import torch +import torch._C +import torch._lazy +from datasets import load_dataset +from datasets.dataset_dict import DatasetDict +from torch.utils.data import DataLoader +from transformers import BertForSequenceClassification, \ + BertConfig, BertTokenizer, AdamW, get_scheduler + + +def tokenize_dataset(dataset: DatasetDict) -> DatasetDict: + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + def tokenize_function(examples): + return tokenizer(examples["text"], padding="max_length", + truncation=True) + + tokenized_datasets = dataset.map(tokenize_function, batched=True) + tokenized_datasets = tokenized_datasets.remove_columns(['text']) + tokenized_datasets = tokenized_datasets.rename_column('label', 'labels') + tokenized_datasets.set_format('torch') + + return tokenized_datasets + + +def train(model: BertForSequenceClassification, + num_epochs: int, + num_training_steps: int, + train_dataloader: DataLoader, + device: torch.device) -> List[torch.Tensor]: + optimizer = AdamW(model.parameters(), lr=5e-5) + lr_scheduler = get_scheduler('linear', optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_training_steps) + + model.train() + losses = [] + for _ in range(num_epochs): + for batch in train_dataloader: + batch = {k: v.to(device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + loss.backward() + losses.append(loss) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if 'lazy' in str(model.device): + print("Calling Mark Step") + torch._lazy.mark_step() + + return losses + + +def main(device='lazy', full_size=False): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. 
+ + :param device: name of device to load tensors to + :param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant + """ + torch.manual_seed(0) + + tokenized_datasets = tokenize_dataset(load_dataset('imdb')) + small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ + .select(range(2)) + + train_dataloader = DataLoader(small_train_dataset, shuffle=True, + batch_size=8) + if full_size: + model = BertForSequenceClassification.from_pretrained('bert-base-cased', + num_labels=2) + else: + configuration = BertConfig( + vocab_size=28996, + hidden_size=32, + num_hidden_layers=1, + num_attention_heads=2, + intermediate_size=32, + hidden_act='gelu', + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + layer_norm_eps=1.0e-05, + ) + model = BertForSequenceClassification(configuration) + + model.to(device) + + num_epochs = 3 + num_training_steps = num_epochs * len(train_dataloader) + losses = train(model, num_epochs, num_training_steps, train_dataloader, device) + + # Get debug information from LTC + if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules: + computation = lazy_backend.get_latest_computation() + if computation: + print(computation.debug_string()) + + print('Loss: ', losses) + + return model, losses + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-d", + "--device", + type=str.upper, + choices=["CPU", "TS", "MLIR_EXAMPLE"], + default="MLIR_EXAMPLE", + help="The device type", + ) + parser.add_argument( + "-f", + "--full_size", + action='store_true', + default=False, + help="Use full sized BERT model instead of one with smaller parameterization", + ) + args = parser.parse_args() + + if args.device in ("TS", "MLIR_EXAMPLE"): + if args.device == "TS": + import torch._lazy.ts_backend + torch._lazy.ts_backend.init() + + elif args.device == "MLIR_EXAMPLE": + import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend + + lazy_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = args.device.lower() + + main(device, args.full_size) diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py new file mode 100644 index 000000000000..bdc9edd096f6 --- /dev/null +++ b/examples/ltc_backend_mnist.py @@ -0,0 +1,105 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +""" +Example use of the example Torch MLIR LTC backend. +""" +import argparse +import sys + +import torch +import torch._lazy +import torch.nn.functional as F + + +def main(device='lazy'): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. 
+
+    :param device: name of device to load tensors to
+    """
+    torch.manual_seed(0)
+
+    inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
+    assert inputs.device.type == device
+
+    targets = torch.tensor([3], dtype=torch.int64, device=device)
+    assert targets.device.type == device
+
+    print("Initialized data")
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = torch.nn.Linear(5, 10)
+
+        def forward(self, x):
+            out = self.fc1(x)
+            out = F.relu(out)
+            return out
+
+    model = Model().to(device)
+    model.train()
+    assert all(p.device.type == device for p in model.parameters())
+
+    print("Initialized model")
+
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    num_epochs = 3
+    losses = []
+    for _ in range(num_epochs):
+        optimizer.zero_grad()
+
+        outputs = model(inputs)
+        loss = criterion(outputs, targets)
+        loss.backward()
+        losses.append(loss)
+
+        optimizer.step()
+
+        if device == "lazy":
+            print("Calling Mark Step")
+            torch._lazy.mark_step()
+
+    # Get debug information from LTC
+    if 'torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND' in sys.modules:
+        computation = lazy_backend.get_latest_computation()
+        if computation:
+            print(computation.debug_string())
+
+    print(losses)
+
+    return model, losses
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str.upper,
+        choices=["CPU", "TS", "MLIR_EXAMPLE"],
+        default="MLIR_EXAMPLE",
+        help="The device type",
+    )
+    args = parser.parse_args()
+
+    if args.device in ("TS", "MLIR_EXAMPLE"):
+        if args.device == "TS":
+            import torch._lazy.ts_backend
+            torch._lazy.ts_backend.init()
+
+        elif args.device == "MLIR_EXAMPLE":
+            import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
+
+            lazy_backend._initialize()
+
+        device = "lazy"
+        print("Initialized backend")
+    else:
+        device = args.device.lower()
+
+    main(device)
diff --git a/examples/torchscript_mhlo_backend_resnet.py b/examples/torchscript_mhlo_backend_resnet.py
new file mode 100644
index 000000000000..bb481f6c3366
--- /dev/null
+++ b/examples/torchscript_mhlo_backend_resnet.py
@@ -0,0 +1,14 @@
+import torch
+import torchvision.models as models
+import torch_mlir
+
+model = models.resnet18(pretrained=True)
+model.eval()
+data = torch.randn(2,3,200,200)
+out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
+
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
+with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
+    outf.write(str(module))
+
+print(f"MHLO IR of resnet18 successfully written into {out_mhlo_mlir_path}")
diff --git a/examples/torchscript_mhlo_backend_tinybert.py b/examples/torchscript_mhlo_backend_tinybert.py
new file mode 100644
index 000000000000..62827361e84f
--- /dev/null
+++ b/examples/torchscript_mhlo_backend_tinybert.py
@@ -0,0 +1,24 @@
+import torch
+import torch_mlir
+
+from transformers import BertForMaskedLM
+
+# Wrap the BERT model to avoid the multiple-returns problem
+class BertTinyWrapper(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)
+
+    def forward(self, data):
+        return self.bert(data)[0]
+
+model = BertTinyWrapper()
+model.eval()
+data = torch.randint(30522, (2, 128))
+out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
+
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
+with
open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: + outf.write(str(module)) + +print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}") diff --git a/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt b/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt index 42243e198520..2de2d4eba67e 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt +++ b/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt @@ -30,7 +30,6 @@ else() # Configure CMake and tablegen. list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules) list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake) - set(MLIR_TABLEGEN_EXE mlir-tblgen) include(TableGen) include(AddLLVM) diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 0fa16f1df49b..f10c59e5bdb5 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -516,7 +516,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern { for (OpOperand *opOperand : op.getInputOperands()) { auto tensorCastOp = opOperand->get().getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) - ? tensorCastOp.source() + ? tensorCastOp.getSource() : opOperand->get()); } // Init tensors may fold, in which case the resultType must also change. diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp index b03ea88dcacc..3d68625b768d 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -40,9 +41,14 @@ static LogicalResult lowerToLoopsImpl(OpBuilder &builder, return scalarLoopOp.generateScalarImplementation(builder, loc, ivs); } LogicalResult status = success(); + Value offset = getValueOrCreateConstantIndexOp(builder, loc, + loopRanges[loopDepth].offset); + Value size = + getValueOrCreateConstantIndexOp(builder, loc, loopRanges[loopDepth].size); + Value stride = getValueOrCreateConstantIndexOp(builder, loc, + loopRanges[loopDepth].stride); builder.create( - loc, loopRanges[loopDepth].offset, loopRanges[loopDepth].size, - loopRanges[loopDepth].stride, ValueRange{}, + loc, offset, size, stride, ValueRange{}, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { ivs.push_back(iv); status = diff --git a/externals/llvm-project b/externals/llvm-project index 889c6f399676..00d648bdb5a8 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 889c6f3996769a991a24da957f597e7500d158e7 +Subproject commit 00d648bdb5a8b71785269b4851b651c883de2cd9 diff --git a/externals/mlir-hlo b/externals/mlir-hlo new file mode 160000 index 000000000000..305a2f252296 --- /dev/null +++ 
b/externals/mlir-hlo @@ -0,0 +1 @@ +Subproject commit 305a2f25229660ea789bf70ed8e7336227f6228a diff --git a/include/torch-mlir-c/Dialects.h b/include/torch-mlir-c/Dialects.h index 20ded3520910..99156c17009c 100644 --- a/include/torch-mlir-c/Dialects.h +++ b/include/torch-mlir-c/Dialects.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_C_DIALECTS_H #define TORCHMLIR_C_DIALECTS_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 8cff8da860a9..f459960ee542 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t); /// Gets a !torch.tensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). - /// `optionalDtype` is allowed to be null, meaning that no dtype @@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t); /// Gets a !torch.vtensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). /// - `optionalDtype` is allowed to be null, meaning that no dtype diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index 607ed2f968b0..9ee80b304b66 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,5 +1,9 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +if(TORCH_MLIR_ENABLE_MHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) +else() + mlir_tablegen(Passes.h.inc -gen-pass-decls) +endif() add_public_tablegen_target(TorchMLIRConversionPassIncGen) add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 02a376d192d4..28138edcb478 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -16,9 +16,9 @@ include "mlir/Pass/PassBase.td" // Torch conversions //===----------------------------------------------------------------------===// -def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> { +def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { let summary = "Convert recognized Torch ops to Std ops"; - let constructor = "mlir::torch::createConvertTorchToStdPass()"; + let constructor = "mlir::torch::createConvertTorchToArithPass()"; } def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> { @@ -125,4 +125,25 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; } +#ifdef TORCH_MLIR_ENABLE_MHLO +def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { + let summary = "Convert Torch ops to MHLO ops"; + let description = [{ + Convert Torch ops to mhlo ops. + }]; + let constructor = "mlir::torch::createConvertTorchToMhloPass()"; + + // Specify any options. 
+ let options = [ + Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false", + "Enable static shape conversion">, + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false", + "Enable truncate index from i64 to i32(unsafely)">, + ]; +} +#endif + #endif // TORCHMLIR_CONVERSION_PASSES diff --git a/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h b/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h similarity index 98% rename from include/torch-mlir/Conversion/TorchToStd/TorchToStd.h rename to include/torch-mlir/Conversion/TorchToArith/TorchToArith.h index 3285bd5f0ae8..ab708557b21c 100644 --- a/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h +++ b/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h @@ -16,7 +16,7 @@ namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToStdPass(); +std::unique_ptr> createConvertTorchToArithPass(); } } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h new file mode 100644 index 000000000000..8e2f5fc8630d --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h @@ -0,0 +1,25 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H +#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToMhloPass(); +std::unique_ptr> +createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index); +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4a1f98768933..e36b40e0f55c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -158,6 +158,54 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [ }]; } + +def Torch_AtenRelu6Op : Torch_Op<"aten.relu6", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::relu6 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRelu6Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenRelu6Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenRelu6_Op : Torch_Op<"aten.relu6_", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::relu6_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + 
AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRelu6_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenRelu6_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + + def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [ AllowsTypeRefinement, HasValueSemantics, @@ -565,12 +613,12 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ }]; } -def Torch_AtenCosOp : Torch_Op<"aten.cos", [ +def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cos : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::expm1 : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -579,20 +627,20 @@ def Torch_AtenCosOp : Torch_Op<"aten.cos", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCosOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExpm1Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCosOp::print(OpAsmPrinter &printer) { + void AtenExpm1Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ +def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::cos_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::expm1_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -601,21 +649,21 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCos_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenExpm1_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCos_Op::print(OpAsmPrinter &printer) { + void AtenExpm1_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNegOp : Torch_Op<"aten.neg", [ +def Torch_AtenCosOp : Torch_Op<"aten.cos", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::neg : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cos : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -624,20 +672,20 @@ def Torch_AtenNegOp : Torch_Op<"aten.neg", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNegOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenCosOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNegOp::print(OpAsmPrinter &printer) { + void AtenCosOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ +def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::neg_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::cos_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -646,66 +694,68 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ ); let hasCustomAssemblyFormat = 
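// Editor's aside, not part of the generated file: aten.expm1 computes
// exp(x) - 1 in a way that stays accurate near zero. For x = 1e-15,
// expm1(x) returns ~1.0000000000000005e-15, whereas evaluating
// exp(x) - 1 naively in double precision yields 1.1102230246251565e-15,
// roughly an 11% relative error.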
1; let extraClassDefinition = [{ - ParseResult AtenNeg_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenCos_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNeg_Op::print(OpAsmPrinter &printer) { + void AtenCos_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ +def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFloorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenAtan2Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ +def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFloor_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenAtan2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ +def Torch_AtenNegOp : Torch_Op<"aten.neg", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::neg : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -714,20 +764,20 @@ def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenNegOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCeilOp::print(OpAsmPrinter &printer) { + void AtenNegOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ +def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::neg_ : (Tensor) -> (Tensor)`"; let arguments 
= (ins AnyTorchTensorType:$self ); @@ -736,21 +786,21 @@ def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenNeg_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCeil_Op::print(OpAsmPrinter &printer) { + void AtenNeg_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ +def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_not : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -759,20 +809,20 @@ def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseNotOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseNotOp::print(OpAsmPrinter &printer) { + void AtenFloorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ +def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_not_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -781,107 +831,101 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseNot_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseNot_Op::print(OpAsmPrinter &printer) { + void AtenFloor_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ +def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSubTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSubTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenCeilOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ +def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement 
]> { - let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSub_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenCeil_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ +def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_not : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenBitwiseNotOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMulTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenBitwiseNotOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ +def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_not_ : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenBitwiseNot_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMul_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenBitwiseNot_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } @@ -980,55 +1024,6 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [ }]; } -def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) 
-> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -1266,191 +1261,46 @@ def Torch_AtenNe_TensorOp : Torch_Op<"aten.ne_.Tensor", [ }]; } -def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ +def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenDivScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenDivScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ +def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let 
summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenSubScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMulScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::div.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDivScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenDivScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::div_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins 
- AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDiv_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenDiv_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } void AtenDiv_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); @@ -1836,6 +1686,55 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ }]; } +def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, @@ -2024,12 +1923,12 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ }]; } -def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ +def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2038,20 +1937,20 @@ def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenRsqrtOp::print(OpAsmPrinter &printer) { + void AtenSqrtOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ +def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2060,21 
+1959,21 @@ def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenSqrt_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenRsqrt_Op::print(OpAsmPrinter &printer) { + void AtenSqrt_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ +def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2083,20 +1982,20 @@ def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAbsOp::print(OpAsmPrinter &printer) { + void AtenLog1pOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ +def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2105,21 +2004,21 @@ def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenAbs_Op::print(OpAsmPrinter &printer) { + void AtenLog1p_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ +def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2128,20 +2027,20 @@ def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenReciprocalOp::print(OpAsmPrinter &printer) { + void AtenRsqrtOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ +def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2150,79 +2049,169 @@ def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ ); let 
hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenReciprocal_Op::print(OpAsmPrinter &printer) { + void AtenRsqrt_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ +def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAbsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ +def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAbs_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ +def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$threshold, - AnyTorchScalarType:$value + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenThresholdOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenReciprocalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ + 
IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenReciprocal_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$threshold, + AnyTorchScalarType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenThresholdOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } void AtenThresholdOp::print(OpAsmPrinter &printer) { @@ -2391,165 +2380,163 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [ }]; } -def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ +def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) 
-> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchOptionalStringType:$rounding_mode ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenAddTensorOp::print(OpAsmPrinter &printer) { + void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; let hasCanonicalizer = 1; } -def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [ +def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchOptionalStringType:$rounding_mode ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenAdd_TensorOp::print(OpAsmPrinter &printer) { + void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ +def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$tensor1, - AnyTorchTensorType:$tensor2, - AnyTorchScalarType:$value + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddcmulOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddcmulOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenMulTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$tensor1, - AnyTorchTensorType:$tensor2, - AnyTorchScalarType:$value + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let 
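// Editor's aside, not part of the generated file: `rounding_mode` on
// aten.div.Tensor_mode mirrors torch.div: a null/None value keeps true
// division (-7 / 2 = -3.5), "trunc" rounds toward zero (giving -3), and
// "floor" rounds toward negative infinity (giving -4). The
// `hasCanonicalizer = 1` lines added to the re-generated div/mul/add/sub
// ops in this hunk declare canonicalization patterns on those ops.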
extraClassDefinition = [{ - ParseResult AtenAddcdivOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddcdivOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenMul_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [ +def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaximumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenMaximumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAddTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMinimumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenMinimumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAdd_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ +def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, + AnyTorchTensorType:$other, AnyTorchScalarType:$alpha ); let results = (outs @@ -2557,446 +2544,769 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsubScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenSubTensorOp::parse(OpAsmParser 
&parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenRsubScalarOp::print(OpAsmPrinter &printer) { + void AtenSubTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`"; + let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_StringType:$approximate + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGeluOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGeluOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSub_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [ +def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$exponent + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPowTensorScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenPowTensorScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAddScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, - AnyTorchScalarType:$threshold + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenThresholdBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void 
AtenThresholdBackwardOp::print(OpAsmPrinter &printer) { + void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [ +def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFloorDivideOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSubScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [ +def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$value + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFill_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [ - AllowsTypeRefinement +def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_FloatType:$from, - Torch_FloatType:$to, - AnyTorchOptionalGeneratorType:$generator + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUniform_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenUniform_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenMulScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRandLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenRandLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ +def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bernoulli : (Tensor, Generator?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalGeneratorType:$generator + AnyTorchTensorType:$tensor1, + AnyTorchTensorType:$tensor2, + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBernoulliOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAddcmulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenBernoulliOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAddcmulOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [ - AllowsTypeRefinement +def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)`"; + let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_FloatType:$p, - AnyTorchOptionalGeneratorType:$generator + AnyTorchTensorType:$tensor1, + AnyTorchTensorType:$tensor2, + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBernoulli_FloatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenAddcdivOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenBernoulli_FloatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenAddcdivOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [ - AllowsTypeRefinement +def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$p, - AnyTorchOptionalGeneratorType:$generator + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBernoulli_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMaximumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBernoulli_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMaximumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ +def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::triu : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$diagonal + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTriuOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenMinimumOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenTriuOp::print(OpAsmPrinter &printer) { + void AtenMinimumOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::triu_ : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$diagonal + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTriu_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenRsubScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenTriu_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenRsubScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [ +def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfOptionalTensorType:$indices, - AnyTorchTensorType:$values, - Torch_BoolType:$accumulate + Torch_StringType:$approximate ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult 
AtenIndexPutOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenGeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenIndexPutOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenGeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfOptionalTensorType:$indices, - AnyTorchTensorType:$values, - Torch_BoolType:$accumulate + AnyTorchScalarType:$exponent ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIndexPut_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenPowTensorScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenIndexPut_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenPowTensorScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenIndexPutHackedTwinOp : Torch_Op<"aten.index_put.hacked_twin", [ +def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins + AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, - AnyTorchListOfTensorType:$indices, - AnyTorchTensorType:$values, - Torch_BoolType:$accumulate + AnyTorchScalarType:$threshold ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenThresholdBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenIndexPutHackedTwinOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenThresholdBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::index_put_.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTensorType:$indices, - AnyTorchTensorType:$values, - 
+    AnyTorchTensorType:$other
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenIndexPut_HackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
+    ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenIndexPut_HackedTwinOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
+    void AtenFloorDivideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
 }
 
-def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
+def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)`";
+  let summary = "Generated op for `aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchTensorType:$weight,
-    AnyTorchOptionalTensorType:$bias
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$beta,
+    AnyTorchScalarType:$threshold
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenLinearOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenSoftplusOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 3, 1);
     }
-    void AtenLinearOp::print(OpAsmPrinter &printer) {
+    void AtenSoftplusOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
 }
 
-def Torch_AtenMmOp : Torch_Op<"aten.mm", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
+def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
+    AllowsTypeRefinement
   ]> {
-  let summary = "Generated op for `aten::mm : (Tensor, Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchTensorType:$mat2
+    AnyTorchScalarType:$value
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMmOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenFill_ScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$from,
+    Torch_FloatType:$to,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUniform_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenUniform_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory,
+    AnyTorchOptionalIntType:$memory_format
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRandLikeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenRandLikeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
+def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bernoulli : (Tensor, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBernoulliOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBernoulliOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$p,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBernoulli_FloatOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenBernoulli_FloatOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bernoulli_.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)`";
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$p, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBernoulli_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenBernoulli_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::triu : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$diagonal + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTriuOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenTriuOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenTriu_Op : Torch_Op<"aten.triu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::triu_ : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$diagonal + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTriu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenTriu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfOptionalTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIndexPutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenIndexPutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfOptionalTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIndexPut_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenIndexPut_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenIndexPutHackedTwinOp : Torch_Op<"aten.index_put.hacked_twin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::index_put.hacked_twin : 
(Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenIndexPutHackedTwinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::index_put_.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIndexPut_HackedTwinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenIndexPut_HackedTwinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinearOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLinearOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMmOp : Torch_Op<"aten.mm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mm : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMmOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } void AtenMmOp::print(OpAsmPrinter &printer) { @@ -3005,452 +3315,1182 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [ }]; } -def Torch_AtenAddmmOp : Torch_Op<"aten.addmm", [ +def Torch_AtenAddmmOp : Torch_Op<"aten.addmm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat1, + AnyTorchTensorType:$mat2, + AnyTorchScalarType:$beta, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAddmmOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenAddmmOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} 
+
+def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMatmulOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConv2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenConv2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose2dInputOp : Torch_Op<"aten.conv_transpose2d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose2dInputOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose2dInputOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    AnyTorchListOfTorchIntType:$dilation
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvTranspose3dInputOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void AtenConvTranspose3dInputOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvolutionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenConvolutionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    Torch_BoolType:$benchmark,
+    Torch_BoolType:$deterministic,
+    Torch_BoolType:$cudnn_enabled,
+    Torch_BoolType:$allow_tf32
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ConvolutionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 13, 1);
+    }
+    void Aten_ConvolutionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 13, 1);
+    }
+  }];
+}
+
+def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    Torch_BoolType:$benchmark,
+    Torch_BoolType:$deterministic,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 12, 1);
+    }
+    void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 12, 1);
+    }
+  }];
+}
+
+def Torch_AtenRollOp : Torch_Op<"aten.roll", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::roll : (Tensor, int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$shifts,
+    AnyTorchListOfTorchIntType:$dims
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRollOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenRollOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::flip : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dims
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFlipOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenFlipOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$training,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 3);
+    }
+    void AtenNativeBatchNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 3);
+    }
+  }];
+}
+
+def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$training,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenBatchNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchListOfTorchIntType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enable
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLayerNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenLayerNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
+def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchListOfTorchIntType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeLayerNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 3);
+    }
+    void AtenNativeLayerNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 3);
+    }
+  }];
+}
+
+def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
(Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenMaxPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + +def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode, + AnyTorchTensorType:$indices + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool2dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenMaxPool2dWithIndicesBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSoftmaxIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLogSoftmaxIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_log_softmax : (Tensor, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$half_to_float + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_LogSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void Aten_LogSoftmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def 
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`";
+  let summary = "Generated op for `aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchTensorType:$mat1,
-    AnyTorchTensorType:$mat2,
-    AnyTorchScalarType:$beta,
-    AnyTorchScalarType:$alpha
+    Torch_IntType:$k,
+    Torch_IntType:$dim,
+    Torch_BoolType:$largest,
+    Torch_BoolType:$sorted
+  );
+  let results = (outs
+    AnyTorchTensorType:$values,
+    AnyTorchTensorType:$indices
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTopkOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 2);
+    }
+    void AtenTopkOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 2);
+    }
+  }];
+}
+
+def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::transpose.int : (Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim0,
+    Torch_IntType:$dim1
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenAddmmOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 5, 1);
+    ParseResult AtenTransposeIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
     }
-    void AtenAddmmOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 5, 1);
+    void AtenTransposeIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
 }
 
-def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
+def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dims
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPermuteOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::bmm : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mat2
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBmmOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBmmOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::cumsum : (Tensor, int, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchOptionalIntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCumsumOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenCumsumOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchTensorType:$other
+    AnyTorchScalarType:$other
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenFloorDivideScalarOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenMatmulOp::print(OpAsmPrinter &printer) {
+    void AtenFloorDivideScalarOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
 }
 
-def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
+def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`";
+  let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchTensorType:$weight,
-    AnyTorchOptionalTensorType:$bias,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_IntType:$groups
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenConv2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 7, 1);
+    ParseResult AtenLogsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
    }
-    void AtenConv2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 7, 1);
+    void AtenLogsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
 }
 
-def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
+def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
+  let summary = "Generated op for `aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)`";
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_BoolType:$transposed, - AnyTorchListOfTorchIntType:$output_padding, - Torch_IntType:$groups + AnyTorchTensorType:$self, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenConvolutionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 9, 1); + ParseResult AtenMeanDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenConvolutionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 9, 1); + void AtenMeanDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [ +def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`"; + let summary = "Generated op for `aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_BoolType:$transposed, - AnyTorchListOfTorchIntType:$output_padding, - Torch_IntType:$groups + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 9, 1); + ParseResult Aten__And__TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 9, 1); + void Aten__And__TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [ +def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_BoolType:$transposed, - AnyTorchListOfTorchIntType:$output_padding, - Torch_IntType:$groups, - Torch_BoolType:$benchmark, - Torch_BoolType:$deterministic, - Torch_BoolType:$cudnn_enabled, - Torch_BoolType:$allow_tf32 + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$half_to_float ); let results = (outs 
AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_ConvolutionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 13, 1); + ParseResult Aten_SoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void Aten_ConvolutionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 13, 1); + void Aten_SoftmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenFlipOp : Torch_Op<"aten.flip", [ +def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::flip : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::mean : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dims + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFlipOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenMeanOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFlipOp::print(OpAsmPrinter &printer) { + void AtenMeanOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [ +def Torch_AtenStdOp : Torch_Op<"aten.std", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`"; + let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchOptionalTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchOptionalTensorType:$running_mean, - AnyTorchOptionalTensorType:$running_var, - Torch_BoolType:$training, - Torch_FloatType:$momentum, - Torch_FloatType:$eps + AnyTorchTensorType:$self, + Torch_BoolType:$unbiased ); let results = (outs - AnyTorchTensorType:$result0, - AnyTorchTensorType:$result1, - AnyTorchTensorType:$result2 + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNativeBatchNormOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 8, 3); + ParseResult AtenStdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNativeBatchNormOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 8, 3); + void AtenStdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ +def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchOptionalTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchOptionalTensorType:$running_mean, - AnyTorchOptionalTensorType:$running_var, - 
-    Torch_FloatType:$momentum,
-    Torch_FloatType:$eps,
-    Torch_BoolType:$cudnn_enabled
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$unbiased,
+    Torch_BoolType:$keepdim
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenBatchNormOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 9, 1);
+    ParseResult AtenStdDimOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
    }
-    void AtenBatchNormOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 9, 1);
+    void AtenStdDimOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
 }
 
-def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
+def Torch_AtenVarOp : Torch_Op<"aten.var", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let summary = "Generated op for `aten::var : (Tensor, bool) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchListOfTorchIntType:$normalized_shape,
-    AnyTorchOptionalTensorType:$weight,
-    AnyTorchOptionalTensorType:$bias,
-    Torch_FloatType:$eps,
-    Torch_BoolType:$cudnn_enable
+    AnyTorchTensorType:$self,
+    Torch_BoolType:$unbiased
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenLayerNormOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
+    ParseResult AtenVarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
    }
-    void AtenLayerNormOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
+    void AtenVarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
 }
 
-def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
+def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
+  let summary = "Generated op for `aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchListOfTorchIntType:$normalized_shape,
-    AnyTorchOptionalTensorType:$weight,
-    AnyTorchOptionalTensorType:$bias,
-    Torch_FloatType:$eps
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$unbiased,
+    Torch_BoolType:$keepdim
   );
   let results = (outs
-    AnyTorchTensorType:$result0,
-    AnyTorchTensorType:$result1,
-    AnyTorchTensorType:$result2
+    AnyTorchTensorType:$result
  );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenNativeLayerNormOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 5, 3);
+    ParseResult AtenVarDimOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
    }
-    void AtenNativeLayerNormOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 5, 3);
+    void AtenVarDimOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
 }
 
-def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
+def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
+  let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$kernel_size,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_BoolType:$ceil_mode
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    AnyTorchOptionalIntType:$correction,
+    Torch_BoolType:$keepdim
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
+    ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
    }
-    void AtenMaxPool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
+    void AtenVarCorrectionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
 }
 
-def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
+def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
+  let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$kernel_size,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_BoolType:$ceil_mode
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$reduction,
+    Torch_IntType:$ignore_index
  );
   let results = (outs
-    AnyTorchTensorType:$result0,
-    AnyTorchTensorType:$result1
+    AnyTorchTensorType:$output,
+    AnyTorchTensorType:$total_weight
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 2);
+    ParseResult AtenNllLossForwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 2);
    }
-    void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 2);
+    void AtenNllLossForwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 2);
     }
   }];
 }
 
-def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
+def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$grad_output,
-    AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$kernel_size,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_BoolType:$ceil_mode,
-    AnyTorchTensorType:$indices
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$reduction,
+    Torch_IntType:$ignore_index,
+    AnyTorchTensorType:$total_weight
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMaxPool2dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 8, 1);
+    ParseResult AtenNllLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
    }
-    void AtenMaxPool2dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 8, 1);
+    void AtenNllLossBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
     }
   }];
 }
 
-def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
+def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
+  let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$kernel_size,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    Torch_BoolType:$ceil_mode,
-    Torch_BoolType:$count_include_pad,
-    AnyTorchOptionalIntType:$divisor_override
+    AnyTorchOptionalTensorType:$weights,
+    Torch_IntType:$minlength
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 7, 1);
+    ParseResult AtenBincountOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
    }
-    void AtenAvgPool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 7, 1);
+    void AtenBincountOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
 }
 
-def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
+def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`";
+  let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    Torch_IntType:$dim,
+    AnyTorchScalarType:$ord,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim,
     AnyTorchOptionalIntType:$dtype
   );
   let results = (outs
@@ -3458,1330 +4498,1344 @@ def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
+    ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
    }
-    void AtenSoftmaxIntOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
+    void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
 
-def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [
+def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)`";
+  let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    Torch_IntType:$dim,
-    AnyTorchOptionalIntType:$dtype
+    AnyTorchListOfTorchIntType:$pad,
+    AnyTorchScalarType:$value
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenLogSoftmaxIntOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenConstantPadNdOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 3, 1);
     }
-    void AtenLogSoftmaxIntOp::print(OpAsmPrinter &printer) {
+    void AtenConstantPadNdOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
 }
 
-def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
+def Torch_AtenPadOp : Torch_Op<"aten.pad", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::_log_softmax : (Tensor, int, bool) -> (Tensor)`";
+  let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`";
-> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_BoolType:$half_to_float + AnyTorchListOfTorchIntType:$pad, + Torch_StringType:$mode, + AnyTorchOptionalFloatType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_LogSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void Aten_LogSoftmaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenPadOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ +def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$output_size + Torch_IntType:$dim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenSqueezeDimOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { + void AtenSqueezeDimOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } -def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ +def Torch_AtenSqueezeOp : Torch_Op<"aten.squeeze", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::squeeze : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$k, - Torch_IntType:$dim, - Torch_BoolType:$largest, - Torch_BoolType:$sorted + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$values, - AnyTorchTensorType:$indices + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTopkOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + ParseResult AtenSqueezeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenTopkOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + void AtenSqueezeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } -def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ +def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::transpose.int : (Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim0, - Torch_IntType:$dim1 + Torch_IntType:$start_dim, + Torch_IntType:$end_dim ); let results = (outs AnyTorchTensorType:$result ); 
let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTransposeIntOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFlattenUsingIntsOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenTransposeIntOp::print(OpAsmPrinter &printer) { + void AtenFlattenUsingIntsOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ +def Torch_AtenDimOp : Torch_Op<"aten.dim", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::dim : (Tensor) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dims + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenPermuteOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } -def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [ +def Torch_AtenSizeOp : Torch_Op<"aten.size", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bmm : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::size : (Tensor) -> (int[])`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$mat2 + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + AnyTorchListOfTorchIntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBmmOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBmmOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ +def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cumsum : (Tensor, int, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::Bool.Tensor : (Tensor) -> (bool)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchOptionalIntType:$dtype + AnyTorchTensorType:$a ); let results = (outs - AnyTorchTensorType:$result + Torch_BoolType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCumsumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenBoolTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCumsumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenBoolTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [ +def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + Torch_BoolType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloorDivideScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenFloorDivideScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [ +def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`"; + let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dim, - Torch_BoolType:$keepdim + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLogsumexpOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenOnesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenLogsumexpOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenOnesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [ +def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mean.dim : (Tensor, int[], bool, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dim, - Torch_BoolType:$keepdim, - AnyTorchOptionalIntType:$dtype + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMeanDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenNewOnesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenMeanDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenNewOnesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [ +def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten__And__TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenZerosOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void Aten__And__TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenZerosOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ +def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenNewZerosOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenSqrtOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenNewZerosOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ +def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::tensor : (t[], int?, Device?, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_BoolType:$half_to_float + AnyTorchListType:$data, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_SoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void Aten_SoftmaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ +def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mean : (Tensor, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype + Torch_BoolType:$t, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMeanOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenTensorBoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenMeanOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenTensorBoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenStdOp : Torch_Op<"aten.std", [ +def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_BoolType:$unbiased + Torch_IntType:$t, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenStdOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenTensorIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenStdOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenTensorIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenVarOp : Torch_Op<"aten.var", [ +def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::var : (Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::_shape_as_tensor : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_BoolType:$unbiased + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenVarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult Aten_ShapeAsTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenVarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void Aten_ShapeAsTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [ +def Torch_AtenAllOp : Torch_Op<"aten.all", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::all : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dim, - Torch_BoolType:$unbiased, - Torch_BoolType:$keepdim + AnyTorchTensorType:$self ); let 
results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenVarDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenAllOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenVarDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenAllOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ +def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::all.bool : (bool[]) -> (bool)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index + AnyTorchListOfTorchBoolType:$self ); let results = (outs - AnyTorchTensorType:$output, - AnyTorchTensorType:$total_weight + Torch_BoolType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLossForwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + ParseResult AtenAllBoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNllLossForwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + void AtenAllBoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [ +def Torch_AtenAnyOp : Torch_Op<"aten.any", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::any : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - AnyTorchOptionalTensorType:$weight, - Torch_IntType:$reduction, - Torch_IntType:$ignore_index, - AnyTorchTensorType:$total_weight + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNllLossBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenAnyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenNllLossBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenAnyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [ +def Torch_AtenAnyDimOp : Torch_Op<"aten.any.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bincount : (Tensor, Tensor?, int) -> (Tensor)`"; + let summary = "Generated op for `aten::any.dim : (Tensor, int, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$weights, - Torch_IntType:$minlength + Torch_IntType:$dim, + 
Torch_BoolType:$keepdim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBincountOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAnyDimOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenBincountOp::print(OpAsmPrinter &printer) { + void AtenAnyDimOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ +def Torch_AtenArangeOp : Torch_Op<"aten.arange", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$ord, - AnyTorchOptionalListOfTorchIntType:$dim, - Torch_BoolType:$keepdim, - AnyTorchOptionalIntType:$dtype + AnyTorchScalarType:$end, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenArangeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) { + void AtenArangeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ +def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$pad, - AnyTorchScalarType:$value + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenConstantPadNdOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenArangeStartOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenConstantPadNdOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenArangeStartOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenPadOp : Torch_Op<"aten.pad", [ +def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`"; + let summary = "Generated op for `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$pad, - Torch_StringType:$mode, - AnyTorchOptionalFloatType:$value + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + AnyTorchScalarType:$step, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenArangeStartStepOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenArangeStartStepOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + AnyTorchScalarType:$step, + AnyTorchTensorType:$out + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenArangeStartOutOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenPadOp::print(OpAsmPrinter &printer) { + void AtenArangeStartOutOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [ +def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqueezeDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenArgmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenSqueezeDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenArgmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; - let hasFolder = 1; } -def Torch_AtenSqueezeOp : Torch_Op<"aten.squeeze", [ +def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::squeeze : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$boundaries, + Torch_BoolType:$out_int32, + Torch_BoolType:$right ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqueezeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBucketizeTensorOp::parse(OpAsmParser &parser, OperationState 
&result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenSqueezeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBucketizeTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; - let hasFolder = 1; } -def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ +def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::clone : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$start_dim, - Torch_IntType:$end_dim + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFlattenUsingIntsOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenCloneOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFlattenUsingIntsOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenCloneOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenDimOp : Torch_Op<"aten.dim", [ +def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::dim : (Tensor) -> (int)`"; + let summary = "Generated op for `aten::contiguous : (Tensor, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + Torch_IntType:$memory_format ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenContiguousOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenContiguousOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } -def Torch_AtenSizeOp : Torch_Op<"aten.size", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenCopy_Op : Torch_Op<"aten.copy_", [ + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::size : (Tensor) -> (int[])`"; + let summary = "Generated op for `aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_BoolType:$non_blocking ); let results = (outs - AnyTorchListOfTorchIntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSizeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenCopy_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenSizeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenCopy_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; - let 
hasCanonicalizer = 1; } -def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [ +def Torch_Aten_ToCopyOp : Torch_Op<"aten._to_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::Bool.Tensor : (Tensor) -> (bool)`"; + let summary = "Generated op for `aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$a + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + Torch_BoolType:$non_blocking, + AnyTorchOptionalIntType:$memory_format ); let results = (outs - Torch_BoolType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBoolTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult Aten_ToCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); } - void AtenBoolTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void Aten_ToCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); } }]; } -def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [ +def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`"; + let summary = "Generated op for `aten::detach : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); let results = (outs - Torch_BoolType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDetachOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) { + void AtenDetachOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ +def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$weight, + AnyTorchTensorType:$indices, + Torch_IntType:$padding_idx, + Torch_BoolType:$scale_grad_by_freq, + Torch_BoolType:$sparse ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenOnesOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEmbeddingOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenOnesOp::print(OpAsmPrinter &printer) { + void AtenEmbeddingOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [ +def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$weight, + AnyTorchTensorType:$indices, + AnyTorchTensorType:$offsets, + Torch_BoolType:$scale_grad_by_freq, + Torch_IntType:$mode, + Torch_BoolType:$sparse, + AnyTorchOptionalTensorType:$per_sample_weights, + Torch_BoolType:$include_last_offset, + AnyTorchOptionalIntType:$padding_idx ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2, + AnyTorchTensorType:$result3 ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNewOnesOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenEmbeddingBagPaddingIdxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 4); } - void AtenNewOnesOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenEmbeddingBagPaddingIdxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 4); } }]; } -def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ +def Torch_Aten_EmbeddingBagOp : Torch_Op<"aten._embedding_bag", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)`"; let arguments = (ins - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$weight, + AnyTorchTensorType:$indices, + AnyTorchTensorType:$offsets, + Torch_BoolType:$scale_grad_by_freq, + Torch_IntType:$mode, + Torch_BoolType:$sparse, + AnyTorchOptionalTensorType:$per_sample_weights, + Torch_BoolType:$include_last_offset, + Torch_IntType:$padding_idx ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2, + AnyTorchTensorType:$result3 ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenZerosOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult Aten_EmbeddingBagOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 4); } - void AtenZerosOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void Aten_EmbeddingBagOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 4); } }]; } -def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ +def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, AnyTorchOptionalIntType:$dtype, AnyTorchOptionalIntType:$layout, AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNewZerosOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEmptyLikeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenNewZerosOp::print(OpAsmPrinter &printer) { + void AtenEmptyLikeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ +def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::tensor : (t[], int?, Device?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchListType:$data, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, AnyTorchOptionalDeviceType:$device, - Torch_BoolType:$requires_grad + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenNewEmptyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenNewEmptyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ +def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins - Torch_BoolType:$t, + AnyTorchTensorType:$self, AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, AnyTorchOptionalDeviceType:$device, - Torch_BoolType:$requires_grad + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTensorBoolOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenZerosLikeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenTensorBoolOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenZerosLikeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [ +def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) 
-> (Tensor)`"; let arguments = (ins - Torch_IntType:$t, + AnyTorchTensorType:$self, AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, AnyTorchOptionalDeviceType:$device, - Torch_BoolType:$requires_grad - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenTensorIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); - } - void AtenTensorIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); - } - }]; -} - -def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::_shape_as_tensor : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_ShapeAsTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenOnesLikeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void Aten_ShapeAsTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenOnesLikeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenAllOp : Torch_Op<"aten.all", [ +def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::all : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAllOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenEmptyMemoryFormatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenAllOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenEmptyMemoryFormatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ +def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::all.bool : (bool[]) -> (bool)`"; + let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`"; let arguments = (ins - AnyTorchListOfTorchBoolType:$self + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + Torch_BoolType:$implicit ); let results = (outs - Torch_BoolType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAllBoolOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenExpandOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenAllBoolOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenExpandOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenAnyOp : Torch_Op<"aten.any", [ +def Torch_AtenExpandAsOp : Torch_Op<"aten.expand_as", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::any : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::expand_as : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAnyOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenExpandAsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAnyOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenExpandAsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAnyDimOp : Torch_Op<"aten.any.dim", [ +def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::any.dim : (Tensor, int, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_BoolType:$keepdim + AnyTorchListOfTorchIntType:$size ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let 
extraClassDefinition = [{ - ParseResult AtenAnyDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenBroadcastToOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAnyDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenBroadcastToOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenArangeOp : Torch_Op<"aten.arange", [ +def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)`"; let arguments = (ins - AnyTorchScalarType:$end, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$self, + AnyTorchListOfOptionalTensorType:$indices ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArangeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenIndexTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenArangeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenIndexTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [ +def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::index_select : (Tensor, int, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchScalarType:$start, - AnyTorchScalarType:$end, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArangeStartOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenIndexSelectOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenArangeStartOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenIndexSelectOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [ + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchScalarType:$start, - AnyTorchScalarType:$end, - AnyTorchScalarType:$step, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$self, + AnyTorchListOfOptionalTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate, + Torch_BoolType:$unsafe ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArangeStartStepOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult Aten_IndexPutImpl_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenArangeStartStepOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void Aten_IndexPutImpl_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [ - AllowsTypeRefinement +def Torch_AtenItemOp : Torch_Op<"aten.item", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::item : (Tensor) -> (Scalar)`"; let arguments = (ins - AnyTorchScalarType:$start, - AnyTorchScalarType:$end, - AnyTorchScalarType:$step, - AnyTorchTensorType:$out + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + AnyTorchScalarType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArangeStartOutOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenItemOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenArangeStartOutOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenItemOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ +def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_select : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dim, - Torch_BoolType:$keepdim + AnyTorchTensorType:$mask ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenArgmaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMaskedSelectOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenArgmaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMaskedSelectOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ +def Torch_AtenNumelOp : Torch_Op<"aten.numel", [ AllowsTypeRefinement, 
HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::numel : (Tensor) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$boundaries, - Torch_BoolType:$out_int32, - Torch_BoolType:$right + AnyTorchTensorType:$self ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBucketizeTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenNumelOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBucketizeTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenNumelOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ +def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clone : (Tensor, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$memory_format + AnyTorchListOfTorchIntType:$repeats ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCloneOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRepeatOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenCloneOp::print(OpAsmPrinter &printer) { + void AtenRepeatOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [ +def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::contiguous : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$memory_format + AnyTorchListOfTorchIntType:$shape ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenContiguousOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReshapeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenContiguousOp::print(OpAsmPrinter &printer) { + void AtenReshapeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCopy_Op : Torch_Op<"aten.copy_", [ - AllowsTypeRefinement +def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ + AllowsTypeRefinement, + ReadOnly ]> { - let summary = "Generated op for `aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$src, - Torch_BoolType:$non_blocking + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCopy_Op::parse(OpAsmParser 
&parser, OperationState &result) { + ParseResult Aten_ReshapeAliasOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenCopy_Op::print(OpAsmPrinter &printer) { + void Aten_ReshapeAliasOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_Aten_ToCopyOp : Torch_Op<"aten._to_copy", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - Torch_BoolType:$non_blocking, + AnyTorchListOfTorchIntType:$size, AnyTorchOptionalIntType:$memory_format ); let results = (outs @@ -4789,1463 +5843,1464 @@ def Torch_Aten_ToCopyOp : Torch_Op<"aten._to_copy", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_ToCopyOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenResize_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void Aten_ToCopyOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenResize_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ +def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::detach : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::select.int : (Tensor, int, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_IntType:$index ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDetachOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenSelectIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenDetachOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenSelectIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [ +def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::size.int : (Tensor, int) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$weight, - AnyTorchTensorType:$indices, - Torch_IntType:$padding_idx, - Torch_BoolType:$scale_grad_by_freq, - Torch_BoolType:$sparse + AnyTorchTensorType:$self, + Torch_IntType:$dim ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEmbeddingOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + 
ParseResult AtenSizeIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenEmbeddingOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenSizeIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } -def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [ +def Torch_AtenStackOp : Torch_Op<"aten.stack", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchListOfTensorType:$tensors, + Torch_IntType:$dim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEmptyLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenEmptyLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenStackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [ +def Torch_AtenSumOp : Torch_Op<"aten.sum", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::sum : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNewEmptyOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenSumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNewEmptyOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenSumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [ +def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) 
-> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenZerosLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenSumDimIntListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenZerosLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenSumDimIntListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [ +def Torch_AtenMaxOp : Torch_Op<"aten.max", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenOnesLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenOnesLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenMaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ +def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; let arguments = (ins - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEmptyMemoryFormatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); } - void AtenEmptyMemoryFormatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenMaxDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); } }]; } -def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ +def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`"; + let summary = "Generated op for `aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - Torch_BoolType:$implicit + Torch_IntType:$dtype, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExpandOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenToDtypeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenExpandOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenToDtypeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasFolder = 1; } -def Torch_AtenExpandAsOp : Torch_Op<"aten.expand_as", [ +def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::expand_as : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) 
-> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenExpandAsOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenToDtypeLayoutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); } - void AtenExpandAsOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenToDtypeLayoutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); } }]; + let hasFolder = 1; } -def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ +def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size + AnyTorchTensorType:$other, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBroadcastToOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenToOtherOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenBroadcastToOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenToOtherOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ +def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)`"; + let summary = "Generated op for `aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfOptionalTensorType:$indices + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIndexTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenToPrimDeviceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenIndexTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenToPrimDeviceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ +def Torch_AtenToDeviceOp : Torch_Op<"aten.to.device", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::index_select : 
(Tensor, int, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$index + Torch_DeviceType:$device, + Torch_IntType:$dtype, + Torch_BoolType:$non_blocking, + Torch_BoolType:$copy, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIndexSelectOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenToDeviceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenIndexSelectOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenToDeviceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [ - AllowsTypeRefinement +def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::type_as : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfOptionalTensorType:$indices, - AnyTorchTensorType:$values, - Torch_BoolType:$accumulate, - Torch_BoolType:$unsafe + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_IndexPutImpl_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenTypeAsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void Aten_IndexPutImpl_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenTypeAsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenItemOp : Torch_Op<"aten.item", [ +def Torch_AtenViewOp : Torch_Op<"aten.view", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::item : (Tensor) -> (Scalar)`"; + let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size ); let results = (outs - AnyTorchScalarType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenItemOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenViewOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenItemOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenViewOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } -def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [ +def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::masked_select : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for 
`aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$mask + AnyTorchListOfTorchIntType:$size ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaskedSelectOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult Aten_UnsafeViewOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaskedSelectOp::print(OpAsmPrinter &printer) { + void Aten_UnsafeViewOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNumelOp : Torch_Op<"aten.numel", [ +def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::numel : (Tensor) -> (int)`"; + let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$condition, + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNumelOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenWhereSelfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenNumelOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenWhereSelfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ +def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$repeats + AnyTorchTensorType:$condition, + AnyTorchScalarType:$self, + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRepeatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenWhereScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenRepeatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenWhereScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ +def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins + AnyTorchTensorType:$condition, AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$shape + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReshapeOp::parse(OpAsmParser &parser, 
OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenWhereScalarOtherOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenReshapeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenWhereScalarOtherOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ +def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - AnyTorchListOfTorchIntType:$stride + AnyTorchTensorType:$condition, + AnyTorchScalarType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_ReshapeAliasOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenWhereScalarSelfOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void Aten_ReshapeAliasOp::print(OpAsmPrinter &printer) { + void AtenWhereScalarSelfOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ - AllowsTypeRefinement +def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ + AllowsTypeRefinement, + ReadOnly ]> { - let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - AnyTorchOptionalIntType:$memory_format + Torch_IntType:$dim, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, + Torch_IntType:$step ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenResize_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenSliceTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenResize_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenSliceTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ +def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::select.int : (Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::len.Tensor : (Tensor) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_IntType:$index + AnyTorchTensorType:$t ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSelectIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLenTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSelectIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLenTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ +def Torch_AtenCpuOp : Torch_Op<"aten.cpu", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::cpu : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$src, - Torch_IntType:$dim, - Torch_IntType:$index + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenCpuOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSelectScatterOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenCpuOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ +def Torch_AtenGatherOp : Torch_Op<"aten.gather", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::size.int : (Tensor, int) -> (int)`"; + let summary = "Generated op for `aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim + Torch_IntType:$dim, + 
AnyTorchTensorType:$index, + Torch_BoolType:$sparse_grad ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSizeIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenGatherOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenSizeIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenGatherOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; - let hasFolder = 1; } -def Torch_AtenStackOp : Torch_Op<"aten.stack", [ +def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; + let summary = "Generated op for `aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchListOfTensorType:$tensors, - Torch_IntType:$dim + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenScatterAddOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenStackOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenScatterAddOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenSumOp : Torch_Op<"aten.sum", [ +def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sum : (Tensor, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::IntImplicit : (Tensor) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype + AnyTorchTensorType:$a ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenIntImplicitOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenSumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenIntImplicitOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ +def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalListOfTorchIntType:$dim, - Torch_BoolType:$keepdim, - AnyTorchOptionalIntType:$dtype + Torch_FloatType:$t, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + Torch_BoolType:$requires_grad ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSumDimIntListOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTensorFloatOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenSumDimIntListOp::print(OpAsmPrinter &printer) { + void AtenTensorFloatOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenMaxOp : Torch_Op<"aten.max", [ +def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::max : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::Int.Tensor : (Tensor) -> (int)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$a ); let results = (outs - AnyTorchTensorType:$result + Torch_IntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaxOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenIntTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMaxOp::print(OpAsmPrinter &printer) { + void AtenIntTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } -def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ +def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_BoolType:$keepdim + AnyTorchTensorType:$a ); let results = (outs - AnyTorchTensorType:$values, - AnyTorchTensorType:$indices + Torch_FloatType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 2); + ParseResult AtenFloatTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenMaxDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 2); + void AtenFloatTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } -def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ +def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::to.dtype : (Tensor, int, bool, bool, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::dropout : (Tensor, float, bool) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dtype, - Torch_BoolType:$non_blocking, - Torch_BoolType:$copy, - AnyTorchOptionalIntType:$memory_format + AnyTorchTensorType:$input, + Torch_FloatType:$p, + Torch_BoolType:$train ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenToDtypeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenDropoutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenToDtypeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenDropoutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; - let hasFolder = 1; } -def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [ - AllowsTypeRefinement, - ReadOnly +def Torch_AtenDropout_Op : Torch_Op<"aten.dropout_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::dropout_ : (Tensor, float, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - Torch_BoolType:$non_blocking, - Torch_BoolType:$copy, - AnyTorchOptionalIntType:$memory_format + Torch_FloatType:$p, + Torch_BoolType:$train ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenToDtypeLayoutOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 8, 1); + ParseResult AtenDropout_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenToDtypeLayoutOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 8, 1); + void AtenDropout_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; - let hasFolder = 1; } -def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ +def Torch_AtenNativeDropoutOp : Torch_Op<"aten.native_dropout", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::native_dropout : (Tensor, float, bool?) 
-> (Tensor, Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - Torch_BoolType:$non_blocking, - Torch_BoolType:$copy, - AnyTorchOptionalIntType:$memory_format + AnyTorchTensorType:$input, + Torch_FloatType:$p, + AnyTorchOptionalBoolType:$train ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenToOtherOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenNativeDropoutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); } - void AtenToOtherOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenNativeDropoutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); } }]; } -def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [ +def Torch_AtenTOp : Torch_Op<"aten.t", [ AllowsTypeRefinement, ReadOnly ]> { - let summary = "Generated op for `aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalIntType:$dtype, - Torch_BoolType:$non_blocking, - Torch_BoolType:$copy + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenToPrimDeviceOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenTOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenToPrimDeviceOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenTOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [ +def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::type_as : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTypeAsOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenTypeAsOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenNumpyTOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenViewOp : Torch_Op<"aten.view", [ +def Torch_AtenFullOp : Torch_Op<"aten.full", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::full : (int[], Scalar, int?, int?, Device?, bool?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size + AnyTorchListOfTorchIntType:$size, + AnyTorchScalarType:$fill_value, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenViewOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenFullOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenViewOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenFullOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; - let hasFolder = 1; } -def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ +def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size + AnyTorchScalarType:$fill_value, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_UnsafeViewOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenFullLikeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); } - void Aten_UnsafeViewOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenFullLikeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); } }]; } -def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [ +def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$condition, AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$batch1, + AnyTorchTensorType:$batch2, + AnyTorchScalarType:$beta, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenWhereSelfOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenBaddbmmOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenWhereSelfOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenBaddbmmOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly +def Torch_AtenBaddbmm_Op : 
Torch_Op<"aten.baddbmm_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::baddbmm_ : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$condition, - AnyTorchScalarType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$self, + AnyTorchTensorType:$batch1, + AnyTorchTensorType:$batch2, + AnyTorchScalarType:$beta, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenWhereScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenBaddbmm_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenWhereScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenBaddbmm_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ +def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::alias_copy : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$condition, - AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenWhereScalarOtherOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenAliasCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenWhereScalarOtherOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenAliasCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ +def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::as_strided_copy : (Tensor, int[], int[], int?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$condition, - AnyTorchScalarType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchOptionalIntType:$storage_offset ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenWhereScalarSelfOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenAsStridedCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenWhereScalarSelfOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenAsStridedCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ +def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)`"; + let summary = "Generated op for `aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchOptionalIntType:$start, - AnyTorchOptionalIntType:$end, - Torch_IntType:$step + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSliceTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + ParseResult AtenDiagonalCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenSliceTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + void AtenDiagonalCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [ +def Torch_AtenExpandCopyOp : Torch_Op<"aten.expand_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`"; + let summary = "Generated op for `aten::expand_copy : (Tensor, int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$src, - Torch_IntType:$dim, - AnyTorchOptionalIntType:$start, - AnyTorchOptionalIntType:$end, - Torch_IntType:$step + AnyTorchListOfTorchIntType:$size, + Torch_BoolType:$implicit ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenExpandCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenSliceScatterOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenExpandCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [ +def Torch_AtenPermuteCopyOp : Torch_Op<"aten.permute_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for 
`aten::len.Tensor : (Tensor) -> (int)`"; + let summary = "Generated op for `aten::permute_copy : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$t + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dims ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLenTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenPermuteCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLenTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenPermuteCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCpuOp : Torch_Op<"aten.cpu", [ +def Torch_Aten_ReshapeAliasCopyOp : Torch_Op<"aten._reshape_alias_copy", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cpu : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::_reshape_alias_copy : (Tensor, int[], int[]) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCpuOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult Aten_ReshapeAliasCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenCpuOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void Aten_ReshapeAliasCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGatherOp : Torch_Op<"aten.gather", [ +def Torch_AtenSelectCopyIntOp : Torch_Op<"aten.select_copy.int", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::select_copy.int : (Tensor, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, - AnyTorchTensorType:$index, - Torch_BoolType:$sparse_grad + Torch_IntType:$index ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGatherOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenSelectCopyIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGatherOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenSelectCopyIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [ +def Torch_AtenDetachCopyOp : Torch_Op<"aten.detach_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::detach_copy : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$index, - 
AnyTorchTensorType:$src + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenScatterAddOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenDetachCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenScatterAddOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenDetachCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ +def Torch_AtenSliceCopyTensorOp : Torch_Op<"aten.slice_copy.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::IntImplicit : (Tensor) -> (int)`"; + let summary = "Generated op for `aten::slice_copy.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$a + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, + Torch_IntType:$step ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenIntImplicitOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenSliceCopyTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenIntImplicitOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenSliceCopyTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ +def Torch_AtenSqueezeCopyOp : Torch_Op<"aten.squeeze_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::squeeze_copy : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_FloatType:$t, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalDeviceType:$device, - Torch_BoolType:$requires_grad + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTensorFloatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenSqueezeCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenTensorFloatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenSqueezeCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [ +def Torch_AtenSqueezeCopyDimOp : Torch_Op<"aten.squeeze_copy.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::Int.Tensor : (Tensor) -> (int)`"; + let summary = "Generated op for `aten::squeeze_copy.dim : (Tensor, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$a + AnyTorchTensorType:$self, + Torch_IntType:$dim ); let results = (outs - Torch_IntType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let 
extraClassDefinition = [{ - ParseResult AtenIntTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenSqueezeCopyDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenIntTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenSqueezeCopyDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } -def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [ +def Torch_AtenTCopyOp : Torch_Op<"aten.t_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`"; + let summary = "Generated op for `aten::t_copy : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$a + AnyTorchTensorType:$self ); let results = (outs - Torch_FloatType:$result + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloatTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTCopyOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenFloatTensorOp::print(OpAsmPrinter &printer) { + void AtenTCopyOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasFolder = 1; } -def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [ +def Torch_AtenTransposeCopyIntOp : Torch_Op<"aten.transpose_copy.int", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::dropout : (Tensor, float, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::transpose_copy.int : (Tensor, int, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - Torch_FloatType:$p, - Torch_BoolType:$train + AnyTorchTensorType:$self, + Torch_IntType:$dim0, + Torch_IntType:$dim1 ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDropoutOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenTransposeCopyIntOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenDropoutOp::print(OpAsmPrinter &printer) { + void AtenTransposeCopyIntOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenDropout_Op : Torch_Op<"aten.dropout_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenUnsqueezeCopyOp : Torch_Op<"aten.unsqueeze_copy", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::dropout_ : (Tensor, float, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::unsqueeze_copy : (Tensor, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_FloatType:$p, - Torch_BoolType:$train + Torch_IntType:$dim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDropout_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenUnsqueezeCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDropout_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void 
AtenUnsqueezeCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNativeDropoutOp : Torch_Op<"aten.native_dropout", [ +def Torch_AtenViewCopyOp : Torch_Op<"aten.view_copy", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::view_copy : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - Torch_FloatType:$p, - AnyTorchOptionalBoolType:$train + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size ); let results = (outs - AnyTorchTensorType:$result0, - AnyTorchTensorType:$result1 + AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNativeDropoutOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 2); + ParseResult AtenViewCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNativeDropoutOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 2); + void AtenViewCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenTOp : Torch_Op<"aten.t", [ +def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::view_copy.dtype : (Tensor, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + Torch_IntType:$dtype ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenViewCopyDtypeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenTOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenViewCopyDtypeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [ +def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + Torch_IntType:$dimension, + Torch_IntType:$size, + Torch_IntType:$step ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenUnfoldCopyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenNumpyTOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenUnfoldCopyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenFullOp : Torch_Op<"aten.full", [ +def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ AllowsTypeRefinement, 
HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; + let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`"; let arguments = (ins - AnyTorchListOfTorchIntType:$size, - AnyTorchScalarType:$fill_value, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_IntType:$dim, + Torch_IntType:$index ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFullOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); } - void AtenFullOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenSelectScatterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); } }]; } -def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ +def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)`"; + let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$fill_value, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format + AnyTorchTensorType:$src, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, + Torch_IntType:$step ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFullLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenFullLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + void AtenSliceScatterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [ +def Torch_AtenDiagonalScatterOp : Torch_Op<"aten.diagonal_scatter", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$batch1, - AnyTorchTensorType:$batch2, - AnyTorchScalarType:$beta, - AnyTorchScalarType:$alpha + AnyTorchTensorType:$src, + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBaddbmmOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDiagonalScatterOp::parse(OpAsmParser &parser, OperationState &result) { return 
parseDefaultTorchOp(parser, result, 5, 1); } - void AtenBaddbmmOp::print(OpAsmPrinter &printer) { + void AtenDiagonalScatterOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 5, 1); } }]; } -def Torch_AtenBaddbmm_Op : Torch_Op<"aten.baddbmm_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::baddbmm_ : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$batch1, - AnyTorchTensorType:$batch2, - AnyTorchScalarType:$beta, - AnyTorchScalarType:$alpha + AnyTorchTensorType:$src, + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchOptionalIntType:$storage_offset ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBaddbmm_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAsStridedScatterOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenBaddbmm_Op::print(OpAsmPrinter &printer) { + void AtenAsStridedScatterOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 5, 1); } }]; @@ -6482,6 +7537,7 @@ def Torch_AtenAddTOp : Torch_Op<"aten.add.t", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [ @@ -7107,6 +8163,30 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [ let hasFolder = 1; } +def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRemainderScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -7906,6 +8986,31 @@ def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [ let hasFolder = 1; } +def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::narrow : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_IntType:$start, + Torch_IntType:$length + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNarrowOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNarrowOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ AllowsTypeRefinement, HasValueSemantics, diff --git 
a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
index 392f9a4c3fa6..f5d36b53e6df 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -187,6 +187,15 @@ m_TorchTensorSizeInt(Value tensor, int64_t *dim) {
 Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
                        Value tensor);

+/// Adjusts the static information in the type of `value` to `desiredType`.
+///
+/// Returns null if such an adjustment is not possible.
+///
+/// If `userAllowsRefinement` is true, then the original value will be returned
+/// if it is a subtype of `desiredType`.
+Value adjustStaticInformation(OpBuilder &builder, Location loc, Value value,
+                              Type desiredType, bool userAllowsRefinement);
+
 /// Returns true if `list` is potentially mutated.
 bool isListPotentiallyMutated(Value list);

diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 5e6ac9dd0b66..fae78b45ae37 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -228,8 +228,6 @@ def Torch_AttrOp : Torch_Op<"attr", [

 def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
     Symbol,
-    IsolatedFromAbove,
-    SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
   ]> {
   let summary = "A slot with global storage";
   let description = [{
@@ -245,17 +243,66 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
     TypeAttr:$typeBound
   );
   let results = (outs);
+
+  let assemblyFormat = [{
+    ($sym_visibility^)? $sym_name attr-dict `:` $typeBound
+  }];
+}
+
+def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
+    IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
+  ]> {
+  let summary = "Module initializer for all `torch.global_slot` ops";
+  let description = [{
+    Initializer function that runs once at program startup to initialize
+    all `torch.global_slot` ops in the module.
+
+    The only ops that should be in the module initializer are ops
+    generated by the IValue importer. This restriction avoids the need to
+    define the behavior in case of certain kinds of side effects in the
+    initializer (except for the side effect of updating the torch.global_slot
+    ops with the `torch.initialize.global_slots` op).
+  }];
+
+  let arguments = (ins);
+  let results = (outs);
   let regions = (region SizedRegion<1>:$initializer);

   let assemblyFormat = [{
-    ($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
+    $initializer attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
+def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
+    Terminator,
+    HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
+  let summary = "Terminator for torch.global_slot.module_initializer region";
+  let description = [{
+    Atomically updates the value of all the global slots named in `slotSymNames`
+    with the corresponding values provided in `initialValues`.
+  }];
+
+  let arguments = (ins
+    SymbolRefArrayAttr:$slotSymNames,
+    Variadic<AnyTorchType>:$initialValues
+  );
+  let results = (outs);
+
+  // This builder creates an illegal op, but is needed to appease
+  // ensureTerminator in the default builders for SingleBlockImplicitTerminator
+  // on the parent op.
+  // TODO: Have a SingleBlockExplicitTerminator trait.
+  let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
}

def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
    Terminator,
    HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
-  let summary = "yield-like terminator for torch.global_slot initializer region";
+  let summary = "yield-like terminator for the torch.global_slot.module_initializer region";
  let description = [{
    The operand to this op becomes the initial value of the parent
    torch.global_slot.
@@ -310,8 +357,10 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {

 // See `torch/csrc/jit/runtime/instruction.h`.
-def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
-    [AllowsTypeRefinement]> {
+def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
   let summary = "TorchScript prim::ListUnpack op";
   let arguments = (ins AnyTorchType:$operand);
   let results = (outs Variadic<AnyTorchType>:$results);
@@ -319,6 +368,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
   }];
+  let hasCanonicalizer = 1;
 }

 def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
index 46771dc72663..59be99c885e7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -54,12 +54,12 @@ class BaseTensorType : public Type {
   Type getOptionalDtype() const;

   /// Return true if this type has a list of sizes.
-  bool hasSizes() const { return getOptionalSizes().hasValue(); }
+  bool hasSizes() const { return getOptionalSizes().has_value(); }

   /// Get the list of sizes. Requires `hasSizes()`.
   ArrayRef<int64_t> getSizes() const {
     assert(hasSizes() && "must have sizes");
-    return getOptionalSizes().getValue();
+    return getOptionalSizes().value();
   }

   /// Return true if all sizes of this tensor are known.
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index e0780395e7a4..ae9dd4249219 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -49,8 +49,8 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
     AttrOrTypeParameter<
       "::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
   let allocator = [{
-    if ($_self.hasValue()) {
-      $_dst.getValue() = $_allocator.copyInto($_self.getValue());
+    if ($_self.has_value()) {
+      $_dst.value() = $_allocator.copyInto($_self.value());
     }
   }];
 }
@@ -78,11 +78,8 @@ class AnyTorchTensorType
    If the type is `!torch.vtensor` then the tensor is restricted to operations
    that have value semantics ("v" = "value semantics"). This helps to maintain
    a strict separation between the value-semantic and potentially-mutating
-    worlds, as one of our main jobs in the compiler is to isolate the mutating
-    parts as much as possible because most lower levels of the compiler stack
-    are expected to require value semantics. E.g. many backend contracts
-    mostly use linalg-on-tensor for compute-heavy ops, which require
-    a conversion to the builtin `tensor` type which has value semantics.
+    worlds, as one of our main jobs in Torch-MLIR is to normalize the program
+    into a form with value semantics.
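[Editor's note: to make the `!torch.tensor` / `!torch.vtensor` distinction concrete, here is a minimal illustrative check. The type names are the dialect's real C++ types; the helper function itself is hypothetical and not part of this patch.]

```cpp
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;

// Hypothetical helper: `!torch.vtensor` maps to Torch::ValueTensorType and
// has value semantics; `!torch.tensor` maps to Torch::NonValueTensorType
// and may be aliased or mutated.
static bool hasValueSemantics(Value v) {
  return v.getType().isa<torch::Torch::ValueTensorType>();
}
```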
    Some notes about value semantics:
    - Using the type system described in PEP 483 (which TorchScript and other
      Python systems follow), `!torch.tensor` is a subtype of
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index dda974c32fff..98ee5151ed42 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -28,15 +28,33 @@ createPrepareForGlobalizeObjectGraphPass();

 struct TorchLoweringPipelineOptions
     : public PassPipelineOptions<TorchLoweringPipelineOptions> {
-  // If this option is true, then perform optimizations.
-  // If this option is false, only do the bare minimum for correctness.
-  Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
-                        llvm::cl::init(true)};
-
-  // If this option is false, decompose complex operations.
-  // If this option is true, skip decomposition of complex operations.
-  Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
-                         llvm::cl::init(true)};
+  // The maximum number of invocations of the simplification pipeline in
+  // LowerToBackendContract.
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc(
+          "Maximum number of invocations of the simplification pipeline."),
+      llvm::cl::init(10)};
+  // If this option is true, decompose complex operations.
+  // If this option is false, skip decomposition of complex operations.
+  Option<bool> decompose{*this, "decompose-complex-ops",
+                         llvm::cl::desc("Decompose complex operations."),
+                         llvm::cl::init(true)};
+  // A list of ops that should be considered legal for the backend.
+  // TODO: The meaning of this list should be formalized.
+  // A sketch of the semantics would be:
+  // - In torch_ods_gen.py, we mark each op as "legal in backend contract",
+  //   "illegal in backend contract", or "conditionally legal in backend
+  //   contract".
+  //   This option would be a list of ops from the "conditionally legal" set
+  //   which should be considered legal for a particular invocation of the
+  //   lowering pipeline.
+  // TODO: The "decompose" flag should be expanded into this formulation
+  // of legality for the backend. Ultimately we will want LowerToBackendContract
+  // to check for a specific set of legal ops to stop its iteration.
+  ListOption<std::string> backendLegalOps{
+      *this, "backend-legal-ops",
+      llvm::cl::desc("List of ops to be considered legal for the backend.")};
 };

 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -50,10 +68,16 @@ void createTorchScriptModuleToTorchBackendPipeline(
 void createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);

-/// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(
+/// Creates a pipeline that simplifies the computations in the program.
+/// This pipeline does not do any global program restructuring -- it works
+/// entirely within a single semantic model of a `builtin.module` with
+/// `torch.global_slot` ops and `func.func` ops.
+void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);

+/// Creates a pipeline that refines shapes of tensor operations in the program.
+void createTorchShapeRefinementPipeline(OpPassManager &pm);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();

 std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
@@ -66,7 +90,8 @@ std::unique_ptr<OperationPass<func::FuncOp>>
 createMaximizeValueSemanticsPass();

 std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();

-std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
@@ -78,7 +103,11 @@ createSimplifyShapeCalculationsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();

 std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyConversionToValueSemanticsPass();
+createEraseModuleInitializerPass();
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createLowerToBackendContractPass(int maxIterations, bool decompose,
+                                 ArrayRef<std::string> backendLegalOps);

 StringRef getShapeLibrary();
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
index 0afc46521f37..c1ce31aa6611 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -143,7 +143,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
     Note: This pass inlines everything that is safe to inline. That is,
     it doesn't have a cost model. This is likely to pessimize programs with
-    significant amounts of computation inside torch.global_slot initializer
+    significant amounts of computation inside torch.global_slot.module_initializer
     regions (but this currently doesn't happen due to how TorchScript modules
     are imported -- the contents are just constants).
   }];
@@ -217,7 +224,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {

 def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
   let summary = "Decompose complicated torch operations";
-  let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createDecomposeComplexOpsPass(/*legalOps=*/{})
+  }];
+  let options = [
+    ListOption<"legalOps", "legal-ops", "std::string",
+               "List of operation names that should be considered legal",
+               "llvm::cl::ZeroOrMore">
+  ];
   let description = [{
     Decompose torch operations that are losslessly represented as combinations
     of other operations, modulo appropriate compiler fusion. Note that this pass
@@ -253,16 +260,73 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
   }];
 }

-def VerifyConversionToValueSemantics
-  : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
-  let summary = "Verify that all tensors have been converted to value semantics";
+def EraseModuleInitializer
+  : Pass<"torch-erase-module-initializer", "ModuleOp"> {
+  let summary = "Erase the `torch.global_slot.module_initializer` op.";
   let constructor =
-    "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
+    "mlir::torch::Torch::createEraseModuleInitializerPass()";
+  let description = [{
+    Backends cannot currently handle module initializers, so we omit them from
+    our backend contract. This pass removes the
+    `torch.global_slot.module_initializer` op from the module if legal.
+ }]; +} + +def LowerToBackendContract + : Pass<"torch-lower-to-backend-contract", "ModuleOp"> { + let summary = "Perform simplifications until the backend contract is satisfied."; + let constructor = [{ + mlir::torch::Torch::createLowerToBackendContractPass( + /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}) + }]; let description = [{ - Prior passes in the pipeline may have missed converting all tensors to value - semantics and we wish to catch such failures early instead of fixing - individual cases downstream. + This pass performs the bulk of the lowering of the program's computations + to the backend contract. This pass does not do any global program + restructuring -- it works entirely within a single semantic model + of a `builtin.module` with `torch.global_slot` ops and `func.func` ops. + + This pass runs a set of simplifications within that semantic model until + the backend contract is satisfied, and fails if it cannot be satisfied. + In particular, the backend contract consists of: + - Tensors + - Have been converted to value semantics. + - Have at least a known rank, though ideally a maximally inferred shape. + - Have a known dtype. + - `torch.global_slot`'s have been eliminated from the program. + - Ops have been decomposed. + + This particular choice of backend contract was born out of a common set of + requirements from backends, along with aligning with long-term PyTorch + direction of being more tracing-based. The set of simplifications performed + here can be thought of as simulating the kinds of simplifications that + happen naturally as part of tracing, but in a way that is applicable + to our TorchScript frontend. For the LazyTensorCore frontend, the backend + contract trivially holds (except for certain decompositions). + + Generally it is not desirable to have a compiler where successful + compilation depends on "optimizing hard enough", but in this case, there + seems to be enough alignment and recognition in the industry that the + Python-based programming model in the source program is too dynamic + to feasibly handle in totality without a tracing approach that has access + to the source program to re-trace in the face of dynamism (e.g. the ability + to do what TorchDynamo calls "graph break"). We are attempting to maintain + a practical compiler that works well given the current set of constraints + of the TorchScript frontend that PyTorch provides us, and are working to + co-design PyTorch's direction so that we land in a place where most of this + "optimizing hard enough" is not necessary. }]; + let options = [ + Option<"maxIterations", "max-iterations", "int", /*default=*/"10", + "Maximum number of invocations of the simplification pipeline.">, + Option<"decompose", "decompose", "bool", /*default=*/"true", + "Decompose ops.">, + ListOption<"backendLegalOps", "backend-legal-ops", "std::string", + "List of ops to be considered legal for the backend."> + + ]; + // TODO: Debug why this is needed, even though the input program has func.func + // ops in it. 
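[Editor's note: as a usage sketch, not part of the patch, the two new passes declared above might be scheduled from C++ as follows. The signatures are taken from the Passes.h changes earlier in this diff; the function wrapping them and the pass ordering are illustrative assumptions.]

```cpp
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;

// Illustrative only: iterate the simplification pipeline until the backend
// contract holds (at most 10 attempts, decomposing complex ops, no extra
// backend-legal ops), then erase the module initializer, which backends
// cannot currently handle.
static void buildExampleBackendContractPipeline(PassManager &pm) {
  pm.addPass(torch::Torch::createLowerToBackendContractPass(
      /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}));
  pm.addPass(torch::Torch::createEraseModuleInitializerPass());
}
```

The same configuration is reachable from the command line through the declared pass options (`max-iterations`, `decompose`, `backend-legal-ops`).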
+ let dependentDialects = ["func::FuncDialect"]; } #endif // TORCHMLIR_TORCH_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index 0d2d75b7818f..1a87cdabd458 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -153,6 +153,14 @@ enum MemoryFormat { //===----------------------------------------------------------------------===// enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions }; +//===----------------------------------------------------------------------===// +// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops. +// Source: +// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +//===-----------------------------------------------------------------------===// +enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX }; + +ScalarType promoteTypes(ScalarType a, ScalarType b); } // namespace torch_upstream } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td index 32782186fe6e..f9a8850b4f05 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -42,6 +42,8 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", let assemblyFormat = [{ $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result)) }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; } def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor"> @@ -60,6 +62,8 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso let assemblyFormat = [{ $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result)) }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; } def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [ @@ -152,6 +156,7 @@ def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [ @@ -170,6 +175,7 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_I64ToGeneratorOp : TorchConversion_Op<"i64_to_generator", [ diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index ce6ef9da1efd..c4008dec4944 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -34,8 +34,12 @@ void createTorchBackendToTosaBackendPipeline( OpPassManager &pm, const torch::Torch::TorchLoweringPipelineOptions &options); -std::unique_ptr> -createVerifyInvariantsBeforeBackendLoweringPass(); +// Do not register the torch-to-mhlo pipeline if mhlo target is disabled +#ifdef TORCH_MLIR_ENABLE_MHLO +void createTorchBackendToMhloBackendPipeline( + OpPassManager &pm, + const torch::Torch::TorchLoweringPipelineOptions &options); +#endif std::unique_ptr> createFuncBackendTypeConversionPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 
cbadd0b92144..8afd9850b7eb 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -12,27 +12,6 @@ include "mlir/Pass/PassBase.td" -def VerifyInvariantsBeforeBackendLowering - : Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> { - let summary = "Verify invariants required by backend lowering"; - let constructor = - "mlir::torch::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()"; - let description = [{ - This pass checks any invariants needed by the process of lowering the - `torch` dialect to the linalg-on-tensors backend contract. - - The most important invariant is that all tensors should be ranked and have - a known dtype. It is useful to catch this early because it usually - represents a simple bug in RefineTypes, but can manifest as many different - kinds of obscure symptoms during lowering. - - TODO: This pass should probably be phrased as checking the - "torch backend contract" and moved to that dialect once we have more - substantial definition definition around what that layer is from an - "allowlist" perspective. - }]; -} - def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> { let summary = "Convert functions to operate on builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionPass()"; diff --git a/lib/CAPI/TorchOps.cpp b/lib/CAPI/TorchOps.cpp index b67e4a8945ed..b3b459ad0833 100644 --- a/lib/CAPI/TorchOps.cpp +++ b/lib/CAPI/TorchOps.cpp @@ -31,45 +31,8 @@ MlirValue torchMlirAdjustStaticInformation(MlirBlock block_, OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_))); builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator() : block->end()); - Value value = unwrap(value_); - Type type = value.getType(); Type desiredType = unwrap(desiredType_); - - // If the value is already of the desired type, we're done. - if (type == desiredType) - return wrap(value); - - // If the type is a tensor, then adjust the static information. - if ((type.isa() && - desiredType.isa()) || - (type.isa() && - desiredType.isa())) { - Value adjusted = builder.create( - value.getLoc(), desiredType, value); - return wrap(adjusted); - } - - // If the type is a subtype of desiredType, then we need to derefine it to - // desiredType, unless the user allows refinement. - if (Torch::isValidSubtype(type, desiredType)) { - if (!userAllowsRefinement) { - Value adjusted = - builder.create(value.getLoc(), desiredType, value); - return wrap(adjusted); - } else { - return wrap(value); - } - } - - // If the desiredType is subtype of type, then we assume that the desiredType - // is dynamically valid, so we do an unchecked cast. - if (Torch::isValidSubtype(desiredType, type)) { - Value adjusted = builder.create( - value.getLoc(), desiredType, value); - return wrap(adjusted); - } - - // No known adjustment. - return {}; + return wrap(Torch::adjustStaticInformation( + builder, value.getLoc(), value, desiredType, userAllowsRefinement)); } diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 0c67453f3421..7465fc06ef08 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. 
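[Editor's note: the `numSizes == -1` convention introduced here distinguishes "unranked" (no size list at all) from "rank 0" (an empty size list). A hedged illustration of a caller, assuming the public CAPI header path and that `ctx` and `f32` are a live context and dtype; dynamic extents are conventionally -1 within a ranked size list.]

```cpp
#include "torch-mlir-c/TorchTypes.h"

// Assumes `ctx` is a live MlirContext and `f32` an MlirType for float32.
void makeTensorTypes(MlirContext ctx, MlirType f32) {
  // Ranked, rank 2, second dim dynamic (dynamic extents stay in the array).
  int64_t sizes[2] = {2, -1};
  MlirType ranked = torchMlirTorchNonValueTensorTypeGet(ctx, 2, sizes, f32);
  // Unranked: numSizes == -1 means no size list is recorded at all.
  MlirType unranked =
      torchMlirTorchNonValueTensorTypeGet(ctx, -1, nullptr, f32);
  (void)ranked;
  (void)unranked;
}
```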
+ if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::NonValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); @@ -212,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation( } MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) { - auto attrTensorType = unwrap(attr).getType().cast(); + auto attrTensorType = + unwrap(attr).cast().getType().cast(); return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(), attrTensorType.getShape(), attrTensorType.getElementType())); @@ -231,7 +233,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. + if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::ValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index f9cabe41bcab..ec6ee8cee77a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(TorchMLIRInitAll Core LINK_LIBS PUBLIC + MLIRFuncDialect MLIRIR MLIRSupport diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 9cc4019bb94a..29318e3b66a3 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,12 +1,23 @@ add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) -add_subdirectory(TorchToStd) +add_subdirectory(TorchToArith) add_subdirectory(TorchToTosa) +if(TORCH_MLIR_ENABLE_MHLO) + add_subdirectory(TorchToMhlo) +endif() add_subdirectory(TorchToTMTensor) add_subdirectory(Utils) # TODO: Automate this with add_torch_mlir_conversion_library. 
-#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS) +set(linked_libs TorchMLIRTorchToLinalg + TorchMLIRTorchToSCF + TorchMLIRTorchToArith + TorchMLIRTorchToTosa + TorchMLIRTorchToTMTensor + TorchMLIRConversionUtils) +if(TORCH_MLIR_ENABLE_MHLO) + list(APPEND linked_libs TorchMLIRTorchToMhlo) +endif() add_mlir_library(TorchMLIRConversionPasses Passes.cpp @@ -18,11 +29,6 @@ add_mlir_library(TorchMLIRConversionPasses Core LINK_LIBS PUBLIC - TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToStd - TorchMLIRTorchToTosa - TorchMLIRTorchToTMTensor - TorchMLIRConversionUtils + ${linked_libs} #${torch_mlir_conversion_libs} ) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 710e91187fd3..98f1acb75e05 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,10 +9,14 @@ #include "torch-mlir/Conversion/Passes.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" //===----------------------------------------------------------------------===// @@ -24,4 +28,11 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { ::registerPasses(); } +void mlir::torch::registerConversionPasses() { + ::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_MHLO + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::mhlo::createLegalizeHloToLinalgPass(); + }); +#endif // TORCH_MLIR_ENABLE_MHLO +} diff --git a/lib/Conversion/TorchToArith/CMakeLists.txt b/lib/Conversion/TorchToArith/CMakeLists.txt new file mode 100644 index 000000000000..4524c3b07ce8 --- /dev/null +++ b/lib/Conversion/TorchToArith/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(TorchMLIRTorchToArith + TorchToArith.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRFuncDialect + TorchMLIRTorchDialect +) + +torch_mlir_target_includes(TorchMLIRTorchToArith) diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp similarity index 94% rename from lib/Conversion/TorchToStd/TorchToStd.cpp rename to lib/Conversion/TorchToArith/TorchToArith.cpp index 00b969f79b82..db7155fdd7f0 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" @@ -16,6 +16,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include 
"torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -178,16 +179,17 @@ class ConvertTorchTensorLiteralOp })); return success(); } - if (auto elements = op.valueAttr().dyn_cast()) { + if (auto elements = op.valueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = IntegerType::get(context, intType.getIntOrFloatBitWidth()); auto shapedType = RankedTensorType::get(type.getShape(), builtinTensorElemTy); + AsmResourceBlob *blob = elements.getRawHandle().getBlob(); + assert(blob && "Expecting dense resource with a valid blob"); rewriter.replaceOpWithNewOp( - op, OpaqueElementsAttr::get(elements.getDialect(), shapedType, - elements.getValue())); + op, DenseElementsAttr::get(shapedType, blob->getData())); return success(); } } @@ -294,7 +296,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToStd : public ConvertTorchToStdBase { +class ConvertTorchToArith : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -323,7 +325,7 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns .add>( typeConverter, context); @@ -333,6 +335,9 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenIntComparisonOp>( + typeConverter, context); target.addIllegalOp(); patterns.add< @@ -359,13 +364,17 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { target.addIllegalOp(); patterns.add>(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); patterns.add>( typeConverter, context); @@ -400,6 +409,6 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToStdPass() { - return std::make_unique(); +mlir::torch::createConvertTorchToArithPass() { + return std::make_unique(); } diff --git a/lib/Conversion/TorchToLinalg/CMakeLists.txt b/lib/Conversion/TorchToLinalg/CMakeLists.txt index 7a86e43a61c1..ece92959719c 100644 --- a/lib/Conversion/TorchToLinalg/CMakeLists.txt +++ b/lib/Conversion/TorchToLinalg/CMakeLists.txt @@ -1,6 +1,4 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg -# TODO: Re-enable after MacOS support is fixed for the custom op extension. 
-# CustomOpExample.cpp DataMovement.cpp IndirectDataMovement.cpp Linear.cpp diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 14b46d9db539..5283774f0026 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -76,6 +76,11 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); + int64_t inputRank = inputType.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value dimSize = inputShape[dim]; @@ -121,6 +126,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, strides[dim] = rewriter.create(loc, strides[dim], stepIndex); return success(); } + namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -179,12 +185,84 @@ class ConvertAtenFlattenUsingIntsOp namespace { /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to -/// `linalg.TensorExpandShape` op only when one or multiple static dimensions -/// are expanded. All the other cases of `aten.View` op need to be handled. +/// one `linalg.TensorExpandShape` op for all expanded dimensions and one +/// `linalg.TensorCollapseShape` op for all collapsed dimensions. Cases where +/// there is neither an expand or collapse of dimensions (e.g. [2, 3] -> [3, 2]) +/// is not handled. Additionally, certain dynamic dimension cases rely on naive +/// assumptions or aren't supported. /// TODO: Handle all the other cases of `aten.View` op. class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + + // Helper for filling in remaining un-collapsed dims when the + // input/output dim is next to the next boundary dim. Additionally + // computes the size of a collapsed dynamic dim if necessary. + static LogicalResult + collapseToSingleDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter, + int64_t collapseDim, int64_t maxCollapseDim, + int64_t startExpandDim, int64_t maxExpandDim, + SmallVector &collapseShape, + const SmallVector &expandShape, + ReassociationIndices &expandIndices) { + int64_t collapseDimSize = 1; + for (auto i : llvm::seq(startExpandDim, maxExpandDim)) { + expandIndices.push_back(i); + if (collapseDimSize == kUnknownSize) + continue; + + int64_t expandedDimSize = expandShape[i]; + if (expandedDimSize == kUnknownSize) { + collapseDimSize = kUnknownSize; + continue; + } + collapseDimSize *= expandedDimSize; + } + int64_t rawCollapseDimSize = collapseShape[collapseDim]; + if (rawCollapseDimSize != kUnknownSize && collapseDimSize != kUnknownSize && + collapseDimSize != rawCollapseDimSize) { + return rewriter.notifyMatchFailure( + op, "desired size is not compatible with the input tensor size"); + } + collapseShape[collapseDim] = collapseDimSize; + return success(); + } + + // Helper to find the minimum set of dims to collapse with the + // same number of elements as that of collapseDim. This function assumes + // the size of the collapsed dim is never dynamic. 
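[Editor's note: the grouping rule the helper below encodes can be illustrated standalone: accumulate sizes of the finer-grained shape until the product exactly matches the coarser dim, rejecting dynamic sizes and non-divisible products. A self-contained sketch, independent of the MLIR types used in this file; names are illustrative.]

```cpp
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Returns the indices of `fine` (starting at `begin`) whose sizes multiply
// to exactly `coarseSize`, or std::nullopt if no such minimal prefix exists.
// Mirrors the "minimally collapse" rule: dynamic dims and non-divisible
// running products are rejected.
std::optional<std::vector<int64_t>>
minimalGroup(const std::vector<int64_t> &fine, int64_t begin,
             int64_t coarseSize) {
  std::vector<int64_t> group;
  int64_t product = 1;
  for (int64_t i = begin; i < (int64_t)fine.size(); ++i) {
    if (fine[i] == kUnknownSize)
      return std::nullopt; // cannot reason about dynamic sizes here
    product *= fine[i];
    if (coarseSize % product != 0)
      return std::nullopt; // sizes are incompatible
    group.push_back(i);
    if (product == coarseSize)
      return group; // minimal set found
  }
  return std::nullopt; // ran out of dims before matching
}
```

For example, grouping [2, 3, 4] from index 0 against a coarse dim of 6 returns {0, 1}, leaving index 2 for the next group.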
+ static LogicalResult + minimallyCollapseDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter, + int64_t collapseDim, int64_t maxCollapseDim, + int64_t startExpandDim, int64_t maxExpandDim, + const SmallVector &collapseShape, + const SmallVector &expandShape, + ReassociationIndices &expandIndices) { + int64_t collapseDimSize = collapseShape[collapseDim]; + int64_t expandedSize = 1; + + for (auto i : llvm::seq(startExpandDim, maxExpandDim)) { + int64_t expandDimSize = expandShape[i]; + if (expandDimSize == kUnknownSize || + collapseDimSize % (expandedSize *= expandDimSize)) { + return rewriter.notifyMatchFailure( + op, "desired size is not compatible with the input tensor size"); + } + expandIndices.push_back(i); + if (expandedSize == collapseDimSize) + return success(); + + if (expandedSize > collapseDimSize) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only supports expanding and collapsing " + "in view"); + } + } + + return rewriter.notifyMatchFailure( + op, "total number of elements mismatch in the expansion"); + } + LogicalResult matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -208,10 +286,6 @@ class ConvertAtenViewOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unimplemented: input rank 0 is not supported"); - bool isCollapse = inputRank > resultRank ? true : false; - int64_t collapsedRank = isCollapse ? resultRank : inputRank; - int64_t expandedRank = isCollapse ? inputRank : resultRank; - // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. SmallVector outputSizeTorchInt; @@ -227,43 +301,26 @@ class ConvertAtenViewOp : public OpConversionPattern { op, "desired size list length mismatches with the result type rank"); } - SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef expandedShapeInt = - llvm::makeArrayRef(isCollapse ? inputSize : outputSizeInt); - ArrayRef collapsedShapeInt = - llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSize); - - // Currently, we only handle the expanding or collapsing cases or the - // identity cases where the rank and shape of the input and result are - // equal, and the input itself is the result. We do not handle expanding And - // collapsing happening at the same time or cases where it's neither + // Currently, we only handle the cases where each dimension is either + // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. - // TODO: For the expanding And collapsing case, we will need to identify - // which dimensions are collapsing and which are expanding and do it in two - // steps. // TODO: For neither collapsing nor expanding, we could find a intermediate // shape to collapse and then expanded to the target shape. Like [2,3] => // [6] => [3, 2]. - if (inputRank == resultRank) { - for (unsigned i = 0; i < inputRank; i++) - checkDimEqualHelper(rewriter, loc, inputSize[i], outputSizeInt[i]); - rewriter.replaceOpWithNewOp(op, resultType, input); - return success(); - } // Iterate through the view op size list to do the following: // // 1. Combine output size list and input tensor type info to get the most // static outputShape. // - // 2. Fill in the reassociation for size list item where the output dim size - // is got from `torch.aten.size.int(inputTensor, inputDim)`. 
We naively - // assume this means the corresponding dimension is not expanded or + // 2. Mark dims in unchangedDims for size list items where the output dim + // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We + // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption - // is violated. + // is violated for the cases of dynamic dimensions. SmallVector outputShape(resultRank, kUnknownSize); - SmallVector reassociation(collapsedRank); + SmallVector unchangedDims; llvm::Optional inferredDimension; for (auto en : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; @@ -272,9 +329,9 @@ class ConvertAtenViewOp : public OpConversionPattern { // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(en.value(), m_TorchTensorSizeInt(op.self(), &inputDim))) { - auto collapsedDim = isCollapse ? outputDim : inputDim; - auto expandedDim = isCollapse ? inputDim : outputDim; - reassociation[collapsedDim].push_back(expandedDim); + unchangedDims.emplace_back(); + unchangedDims.back().push_back(inputDim); + unchangedDims.back().push_back(outputDim); if (!inputType.isDynamicDim(inputDim)) { outputShape[outputDim] = inputShape[inputDim]; continue; @@ -285,7 +342,7 @@ class ConvertAtenViewOp : public OpConversionPattern { continue; } - if (inferredDimension.hasValue()) { + if (inferredDimension.has_value()) { return rewriter.notifyMatchFailure( op, "at most one element in size list is allowed to be -1"); } @@ -293,6 +350,11 @@ class ConvertAtenViewOp : public OpConversionPattern { } } + // Mark the end of the input/output shapes + unchangedDims.emplace_back(); + unchangedDims.back().push_back(inputRank); + unchangedDims.back().push_back(resultRank); + // Use static information of input tensor to determine size of inferred // dimension in output shape. // @@ -301,7 +363,7 @@ class ConvertAtenViewOp : public OpConversionPattern { // then we don't need to analyze the static information of the input // shape since the reassociation of dimensions only requires rank // information. - if (inferredDimension.hasValue() && outputShape.size() > 1) { + if (inferredDimension.has_value() && outputShape.size() > 1) { if (llvm::count(outputShape, kUnknownSize) != 1 || llvm::count(inputShape, kUnknownSize) != 0) { return rewriter.notifyMatchFailure( @@ -329,139 +391,208 @@ class ConvertAtenViewOp : public OpConversionPattern { numOfElements / outputKnownNumOfElements; } - SmallVector collapsedShape = - isCollapse ? outputShape : llvm::to_vector(inputShape); - SmallVector expandedShape = - isCollapse ? llvm::to_vector(inputShape) : outputShape; - - // The while loop does the following: - // 1. Fill in the reassociation indices for dimensions that are expanded. - // Check the interval dimensions between two unchanged dims in the - // collapsedShape. If the interval is size 1, associate all the dims - // in the expandedShape shape until the next unchanged dim. If the interval - // is larger than size 1, figure out the associations with assumptions that - // dynamic dimensions are not splitted. - // 2. Set collapsedShape and expandedShape following the requirements by + SmallVector inputSize = getTensorSizes(rewriter, loc, input); + ArrayRef outputShapeInt = llvm::makeArrayRef(outputSizeInt); + ArrayRef inputShapeInt = llvm::makeArrayRef(inputSize); + + // Association indices for expand/collapse ops. 
These two vectors + // are populated such that two entries at the same index corresponds + // to an expand or collapse. For example, + // + // inputAssociations: [[0, 1], [2]] + // outputAssociations: [[0], [1, 2, 3]] + // + // indicates that the first two dims of the input tensor + // are collapsed into the first dim of the output, and the + // third dim of the input is expanded into the last three dims + // of the output. + SmallVector inputAssociations; + SmallVector outputAssociations; + + SmallVector inputShapeVec = llvm::to_vector(inputShape); + + // The for loop does the following: + // 1. Attempt to match the indices from inputDim and outputDim to the next + // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or + // until (inputRank, resultRank) if there is no such op. Look at the first + // dimension of the input and output and collapse the larger one by finding + // a minimal set of opposing indices with the same number of elements. If + // the number of dims to the next boundary is 1, then we assume all + // remaining opposing dims must collapse into it. + // 2. For handling of dynamic dimensions, we first assume they are only + // split if we can easily compute the correct size. + // e.g. [2, -1] -> [2, 3, 4] + // This mainly happens at the edges of boundaries. Otherwise we try to match + // the dynamic dimension with the one across from it and give up if we can't + // reason about how the dimensions are associated. + // e.g. [-1, -1] -> [2, 3, 4] + // 3. Set inputShapeVec and outputShape following the requirements by // tensor.expand_shape verification code: // a. As long as one or more of the related dimensions in the expanded // shape is dynamic the collapsed dimension is dynamic. // b. If all of the related dimensions are static, the collapsed // dimension must be static. In other words, if a collapsed dimension is // dynamic, at least one of the related dimensions need to be dynamic. - int64_t collapsedDim = 0, expandedDim = 0; - while (collapsedDim < collapsedRank && expandedDim < expandedRank) { - // Not empty means the associations has been filled in and the dimension - // is unchanged. - if (!reassociation[collapsedDim].empty()) { - if (expandedDim != reassociation[collapsedDim][0]) - return op.emitOpError("Unsupported: expanded dims are off from the " - "expected dim got from reassociation"); - collapsedDim++; - expandedDim++; - continue; - } - - // Collect the dims that are collapsed until hitting the next dim that's - // unchanged. - SmallVector collapsedDims; - while (collapsedDim < collapsedRank && - reassociation[collapsedDim].empty()) { - collapsedDims.push_back(collapsedDim); - collapsedDim++; - } - // the next reassociation is for a dim that's unchanged. - int64_t expandedDimNext = collapsedDim != collapsedRank - ? 
reassociation[collapsedDim][0] - : expandedRank; - if (collapsedDims.size() == 1) { - int64_t collapsedDimSize = 1; - int64_t collapsedDim = collapsedDims[0]; - for (auto i : llvm::seq(expandedDim, expandedDimNext)) { - reassociation[collapsedDim].push_back(i); - if (collapsedDimSize == kUnknownSize) - continue; - - int64_t expandedDimSize = expandedShape[i]; - if (expandedDimSize == kUnknownSize) { - collapsedDimSize = kUnknownSize; - continue; + int64_t inputDim = 0, outputDim = 0; + for (auto boundary : unchangedDims) { + // We assume dims specified by AtenSizeInt ops are unchanged + int64_t nextUnchangedInput = boundary[0]; + int64_t nextUnchangedOutput = boundary[1]; + + bool hasDynamic = false; + while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + + // outputDim is next to the boundary + if (outputDim == nextUnchangedOutput - 1) { + if (hasDynamic && inputDim != nextUnchangedInput - 1) { + return rewriter.notifyMatchFailure( + op, "found ambiguous collapse of dynamic input sizes (e.g. " + "[-1, -1, -1] -> [-1, -1])"); } - collapsedDimSize *= expandedShape[i]; + outputAssociations.back().push_back(outputDim); + if (failed(collapseToSingleDimHelper( + op, rewriter, outputDim, nextUnchangedOutput, inputDim, + nextUnchangedInput, outputShape, inputShapeVec, + inputAssociations.back()))) + return failure(); + outputDim = nextUnchangedOutput; + inputDim = nextUnchangedInput; + continue; } - // To meet both requirements from tensor.expand_shape verification code. - collapsedShape[collapsedDim] = collapsedDimSize; - expandedDim = expandedDimNext; - continue; - } - // collpasedDims are expanded to [expandedDim, expandedDimNext) - if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size()) - op.emitError("unimplemented: mixed of expanding and collapsing " - "operations for view"); - for (auto collapsedDim : collapsedDims) { - if (collapsedShape[collapsedDim] == kUnknownSize) { - if (expandedDim >= expandedDimNext) { + // inputDim is next to the boundary + if (inputDim == nextUnchangedInput - 1) { + if (hasDynamic && inputShape[inputDim] == kUnknownSize) { return rewriter.notifyMatchFailure( - op, - "desired size is not compatible with the input tensor size"); - } - checkDimEqualHelper(rewriter, loc, collapsedShapeInt[collapsedDim], - expandedShapeInt[expandedDim]); - // To meet the second requirement from tensor.expand_shape - // verification code. - expandedShape[expandedDim] = kUnknownSize; - reassociation[collapsedDim].push_back(expandedDim++); - } else { - int64_t remainingSizeToExpand = collapsedShape[collapsedDim]; - // A do-while loop is used here to handle the cases where the - // collapsed shape tensor has a dimension of size 1. - do { - int64_t expandedDimSize = expandedShape[expandedDim]; - if (expandedDim >= expandedDimNext || - expandedShape[expandedDim] == kUnknownSize || - remainingSizeToExpand % expandedDimSize != 0) { - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); - } - reassociation[collapsedDim].push_back(expandedDim++); - remainingSizeToExpand /= expandedDimSize; - } while (remainingSizeToExpand != 1); - - // If all dims until `expandedDimNext` are of size 1, then group those - // with the reassociation for the current `collapsedDim`. 
- auto expandedShapeSlice = - llvm::makeArrayRef(expandedShape) - .slice(expandedDim, expandedDimNext - expandedDim); - if (llvm::all_of(expandedShapeSlice, - [](int64_t val) { return val == 1; })) { - reassociation[collapsedDim].append( - llvm::to_vector(llvm::seq(expandedDim, expandedDimNext))); - expandedDim = expandedDimNext; + op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> " + "[-1, -1, -1])"); } + inputAssociations.back().push_back(inputDim); + if (failed(collapseToSingleDimHelper( + op, rewriter, inputDim, nextUnchangedInput, outputDim, + nextUnchangedOutput, inputShapeVec, outputShape, + outputAssociations.back()))) + return failure(); + outputDim = nextUnchangedOutput; + inputDim = nextUnchangedInput; + continue; } + + int64_t inputMatchingDimSize = inputShapeVec[inputDim]; + int64_t outputMatchingDimSize = outputShape[outputDim]; + + // If the input is dynamic, first assume it is not split + if (inputMatchingDimSize == kUnknownSize) { + checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim], + outputShapeInt[outputDim]); + outputShape[outputDim] = kUnknownSize; + inputAssociations.back().push_back(inputDim++); + outputAssociations.back().push_back(outputDim++); + hasDynamic = true; + continue; + } + + // inputDim size is larger; try to collapse onto it + if (inputMatchingDimSize >= outputMatchingDimSize) { + inputAssociations.back().push_back(inputDim); + if (failed(minimallyCollapseDimHelper( + op, rewriter, inputDim, nextUnchangedInput, outputDim, + nextUnchangedOutput, inputShapeVec, outputShape, + outputAssociations.back()))) + return failure(); + hasDynamic = false; + outputDim = outputAssociations.back().back() + 1; + inputDim++; + continue; + } + + // outputDim is larger; try to collapse onto it + outputAssociations.back().push_back(outputDim); + if (failed(minimallyCollapseDimHelper( + op, rewriter, outputDim, nextUnchangedOutput, inputDim, + nextUnchangedInput, outputShape, inputShapeVec, + inputAssociations.back()))) + return failure(); + hasDynamic = false; + inputDim = inputAssociations.back().back() + 1; + outputDim++; + continue; + } + + if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) { + return rewriter.notifyMatchFailure( + op, "could not match input tensor shape to output shape; " + "potentially unsupported view shape"); + } + + // Append the associations for the dims matching `aten.size.int` + if (nextUnchangedInput != inputRank && + nextUnchangedOutput != resultRank) { + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + inputAssociations.back().push_back(inputDim++); + outputAssociations.back().push_back(outputDim++); } } - if (collapsedDim != collapsedRank || expandedDim != expandedRank) - return rewriter.notifyMatchFailure(op, "view shape is not supported"); + // Check if the shapes already match up to dynamic sizes. If so, we can just + // cast as the result type because the previous loop sets up the necessary + // dim checks in case of dynamic sizes. + if (llvm::all_of( + inputAssociations, + [](ReassociationIndices indices) { return indices.size() == 1; }) && + llvm::all_of(outputAssociations, [](ReassociationIndices indices) { + return indices.size() == 1; + })) { + rewriter.replaceOpWithNewOp(op, resultType, input); + return success(); + } + Type adjustedResultType = - RankedTensorType::get(isCollapse ? 
collapsedShape : expandedShape, - resultType.getElementType()); + RankedTensorType::get(outputShape, resultType.getElementType()); Type adjustedInputType = - RankedTensorType::get(isCollapse ? expandedShape : collapsedShape, - resultType.getElementType()); + RankedTensorType::get(inputShapeVec, resultType.getElementType()); Value castedInput = rewriter.create(loc, adjustedInputType, input); - Value result = - isCollapse - ? rewriter - .create(loc, adjustedResultType, - castedInput, reassociation) - .result() - : rewriter - .create(loc, adjustedResultType, - castedInput, reassociation) - .result(); + llvm::Optional expandedInput; + llvm::Optional collapsedInput; + + if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) { + return indices.size() > 1; + })) { + SmallVector intermediateShape; + for (auto i : llvm::seq(0, (int)inputAssociations.size())) { + if (inputAssociations[i].size() > 1) { + intermediateShape.push_back(outputShape[outputAssociations[i][0]]); + } else { + intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]); + } + } + Type intermediateResultType = + RankedTensorType::get(intermediateShape, resultType.getElementType()); + expandedInput = + rewriter + .create(loc, intermediateResultType, + castedInput, inputAssociations) + .getResult(); + } + + if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) { + return indices.size() > 1; + })) { + collapsedInput = rewriter + .create( + loc, adjustedResultType, + expandedInput.has_value() ? expandedInput.value() + : castedInput, + outputAssociations) + .getResult(); + } + + Value result = collapsedInput.has_value() ? collapsedInput.value() + : expandedInput.value(); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } @@ -893,6 +1024,10 @@ class ConvertAtenCatOp : public OpConversionPattern { for (int i = 0; i < rank; ++i) sizes.push_back(rewriter.createOrFold(loc, tensors[0], i)); + dim = toPositiveDim(dim, rank); + if (!isValidDim(dim, rank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + // Calculate the size of the `dim` result dimension by adding the dim size // of each tensor together. Value resultDimSize = sizes[dim]; diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index ae6ff2669c17..c246593d200a 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -168,6 +168,276 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern { }; } // namespace +namespace { +// AtenEmbeddingPaddingIdxOp +// SUM mode == integer 0 +// Sums bags of embeddings together from a weight tensor based on an index and +// offset Vector. Example arguments weight = [[1, 3, 5, 3], +// [3, 4, 2, 1], +// [2, 2, 3, 2], +// [0, 4, 2, 1]] +// +// indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1] +// offsets = [0, 3, 5] +// +// output_tensor = initZeroTensor(offsets_length, embedding_size) +// +// for i in range(offsets_length): <- dim0 +// for j in range(indices_length): <- dim1 +// for k in range(embedding_size): <- dim2 +// if(offsets[i] <= j and j < offsets[i+1]): +// output_tensor[i][k] = output_tensor[i][k] + +// weight[indices[j]][k] +// else: +// break +// +// Indexing maps for linalg::Generic ops +// +// +// indices_indexing_map = (d0, d1, d2) -> (d1) +// offset_indexing_map = (d0, d1, d2) -> (d0) +// output_indexing_map = (d0, d1, d2) -> (d0, d2) +// +// TODO: Find an optimal lowering. 
+// The current lowering is not optimal for bags of large embeddings, since it
+// traverses the output tensor multiple times.
+//
+
+class ConvertAtenEmbeddingBagPaddingIdxOp
+    : public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    auto context = op->getContext();
+    Value weight = adaptor.weight();
+    Value indices = adaptor.indices();
+    Value offsets = adaptor.offsets();
+    Value scaleGradByFreq = op.scale_grad_by_freq();
+    Value mode = op.mode();
+    Value sparse = op.sparse();
+    Value includeLastOffset = op.include_last_offset();
+
+    bool scaleGradByFreqBool;
+    if (!matchPattern(scaleGradByFreq,
+                      m_TorchConstantBool(&scaleGradByFreqBool))) {
+      return rewriter.notifyMatchFailure(
+          op, "scale_grad_by_freq is expected to be a constant boolean value.");
+    }
+
+    if (scaleGradByFreqBool) {
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: scale_grad_by_freq=True.");
+    }
+
+    int64_t modeInt;
+    if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "mode is expected to be a constant integer value.");
+    }
+
+    if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: Mean and Max mode are not supported yet for "
+              "EmbeddingBag.");
+    }
+
+    bool isSparse;
+    if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
+      return rewriter.notifyMatchFailure(
+          op, "sparse is expected to be a constant boolean value.");
+    }
+
+    if (isSparse) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
+    }
+
+    bool discardLastOffset;
+    if (!matchPattern(includeLastOffset,
+                      m_TorchConstantBool(&discardLastOffset))) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "include_last_offset is expected to be a constant boolean value.");
+    }
+
+    auto weightTy = weight.getType().cast<RankedTensorType>();
+    if (weightTy.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "weight must be rank 2");
+
+    auto indicesTy = indices.getType().cast<RankedTensorType>();
+    if (indicesTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "indices must be a vector");
+
+    auto offsetsTy = offsets.getType().cast<RankedTensorType>();
+    if (offsetsTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "offsets must be a vector");
+
+    Type weightElemTy = weightTy.getElementType();
+
+    int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank();
+    SmallVector<AffineExpr> indicesExpr;
+    indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
+    auto indicesIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       indicesExpr, context);
+
+    SmallVector<AffineExpr> offsetsExpr;
+    offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));
+
+    auto offsetIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       offsetsExpr, context);
+
+    SmallVector<AffineExpr> outputExpr;
+    outputExpr.push_back(mlir::getAffineDimExpr(0, context));
+    outputExpr.push_back(mlir::getAffineDimExpr(2, context));
+
+    auto outputIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       outputExpr, context);
+
+    SmallVector<AffineMap> indexingMaps = {
+        indicesIndexingMap,
+        offsetIndexingMap,
+        outputIndexingMap,
+    };
+
+    SmallVector<StringRef> iteratorTypes(iterationMapDimension,
getParallelIteratorTypeName()); + + Value embeddingDim = getDimOp(rewriter, loc, weight, 1); + Value initTensor; + Value offsetsLength; + Value indicesLength; + if (!discardLastOffset) { + SmallVector sizes{getDimOp(rewriter, loc, offsets, 0), + embeddingDim}; + + initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy); + offsetsLength = getDimOp(rewriter, loc, offsets, 0); + indicesLength = getDimOp(rewriter, loc, indices, 0); + } else { + return rewriter.notifyMatchFailure( + op, "Unimplemented: include last offset is not yet " + "supported for EmbeddingBag."); + } + + Value embeddingBagResult = + rewriter + .create( + loc, initTensor.getType(), ValueRange{indices, offsets}, + initTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value indexInIndices = args[0]; + Value offsetsI = args[1]; + Value initTensorElem = args[2]; + + Value indexI = b.create(loc, /*value=*/0); + Value indexIToInt = castIndexToInt64(b, loc, indexI); + Value one = getConstant( + b, loc, 1, + mlir::IntegerType::get(getContext(), 64, + IntegerType::Signless)); + Value offsetIndexPlusOneInt = + b.create(loc, indexIToInt, one); + + Value offsetIndexPlusOne = + castIntToIndex(b, loc, offsetIndexPlusOneInt); + Value checkLast = b.create( + loc, arith::CmpIPredicate::eq, + castIndexToInt64(b, loc, offsetsLength), + offsetIndexPlusOneInt); + Value nextOffset = b.create( + loc, checkLast, castIndexToInt64(b, loc, indicesLength), + b.create(loc, offsets, + offsetIndexPlusOne)); + + Value indicesIndex = castIndexToInt64( + b, loc, b.create(loc, /*value=*/1)); + + Value offsetLessThanIndicesIndex = b.create( + loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex); + Value offsetEqualToIndicesIndex = b.create( + loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); + Value offsetLessThanOrEqualToIndicesIndex = + b.create(loc, offsetLessThanIndicesIndex, + offsetEqualToIndicesIndex); + + Value indicesIndexLessThanNextOffset = + b.create(loc, arith::CmpIPredicate::slt, + indicesIndex, nextOffset); + + Value indicesIndexWithinBounds = b.create( + loc, offsetLessThanOrEqualToIndicesIndex, + indicesIndexLessThanNextOffset); + + SmallVector indexIntoWeight; + indexIntoWeight.push_back( + castIntToIndex(b, loc, indexInIndices)); + indexIntoWeight.push_back( + b.create(loc, /*value=*/2)); + Value weightElem = b.create( + loc, weight, indexIntoWeight); + + Value addResult = b.create(loc, weightElem, + initTensorElem); + Value select = + b.create(loc, indicesIndexWithinBounds, + addResult, initTensorElem); + b.create(loc, select); + }) + .getResult(0); + + // cast outputType. 
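+    // aten.embedding_bag produces four results: the summed bags plus
+    // offset2bag, bag_size, and max_indices; in the sum mode only the first
+    // carries data (see the empty/zero tensors created below). Continuing the
+    // worked example from the header comment (sums computed by hand, shown
+    // for illustration only):
+    //   bag 0 = w[0] + w[2] + w[3]               = [3, 9, 10, 6]
+    //   bag 1 = w[1] + w[2]                      = [5, 6, 5, 3]
+    //   bag 2 = w[3] + w[2] + w[1] + w[0] + w[1] = [9, 17, 14, 8]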
+    auto resultType0 = typeConverter->convertType(op->getResult(0).getType());
+    Value castedEmbeddingBagResult =
+        rewriter.create<tensor::CastOp>(loc, resultType0, embeddingBagResult);
+
+    // offset2bag result: for the sum mode this is an empty tensor.
+    SmallVector<Value> offsetResultSize;
+    Type offsetElemTy = offsetsTy.getElementType();
+    Value zeroDim = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/0);
+    offsetResultSize.push_back(zeroDim);
+    Value offsetResult = rewriter.create<linalg::InitTensorOp>(
+        loc, offsetResultSize, offsetElemTy);
+    auto resultType1 = typeConverter->convertType(op->getResult(1).getType());
+    Value castedOffsetResult =
+        rewriter.create<tensor::CastOp>(loc, resultType1, offsetResult);
+
+    SmallVector<Value> offsetSize = getTensorSizes(rewriter, loc, offsets);
+    // bag_size result: a vector of the same size as offsets, filled with
+    // zeros; in the sum mode this is always all zeros.
+    Value bagSize =
+        createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
+    auto resultType2 = typeConverter->convertType(op->getResult(2).getType());
+    Value castedBagSizeResult =
+        rewriter.create<tensor::CastOp>(loc, resultType2, bagSize);
+
+    // max_indices result: also a zero vector of the same size as offsets in
+    // the sum mode. It is mainly used in the max mode.
+    Value indicesOut =
+        createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
+    auto resultType3 = typeConverter->convertType(op->getResult(3).getType());
+    Value castedMaxIndices =
+        rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);
+
+    rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
+                            castedBagSizeResult, castedMaxIndices});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -244,6 +514,21 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 };
 } // namespace
 
+// IndexTensor for multiple input tensors broadcasts their shapes to a common
+// shape and then replaces the indexed dims with the indices given by the
+// indexing tensors:
+// x[i_1, i_2, ..., i_M] = result
+// result[...] = x[i_1[...], i_2[...], ..., i_M[...]]
+//
+// where the result shape is computed as follows:
+// 1. broadcast i_1, i_2, ..., i_M to a common shape
+// 2. if i_1, i_2, ..., i_M are not contiguous, transpose the broadcasted
+//    shape to the beginning of the result shape, while removing the
+//    unchanged dims (marked by None)
+// 3. otherwise, replace the indexed dims with the broadcasted shape
+//
+// e.g.
x: [2, 3] +// x[[4], [6, 1]] -> x[6, 4] namespace { class ConvertAtenIndexTensorOp : public OpConversionPattern { public: @@ -251,6 +536,7 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); @@ -262,65 +548,198 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unimplemented: the indices list is not from a list construct"); } - if (indicesTuple.size() != 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: only one index tensor is supported"); - } SmallVector indicesVal = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple); - Value indexTensor = indicesVal[0]; - if (failed(checkNotNone(rewriter, op, indexTensor))) { + + // Identify the indices with non-None index tensors and determine if they + // are contiguous within the input list. + SmallVector indexTensorDims; + SmallVector indexTensors; + bool contiguous = true; + for (auto i : llvm::seq(0, (int)indicesVal.size())) { + Value index = indicesVal[i]; + if (!index || failed(checkNotNone(rewriter, op, index))) + continue; + if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1) + contiguous = false; + indexTensorDims.push_back(i); + indexTensors.push_back(index); + } + + if (indexTensors.empty()) { return rewriter.notifyMatchFailure( - op, "unimplemented: index tensor must not be None"); + op, "aten.index.Tensor: index tensor must not be None"); } RankedTensorType inputType = input.getType().cast(); - RankedTensorType indexTensorType = - indexTensor.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); int inputRank = inputType.getRank(); - int indexTensorRank = indexTensorType.getRank(); + int resultRank = resultType.getRank(); + int firstIndexDim = indexTensorDims[0]; + int replacedIndexCount = indexTensorDims.size(); + int64_t startIndex = contiguous ? firstIndexDim : 0; + + // Currently we only support statically sized index tensors or dynamic size + // index tensors without overlapping dynamic dims when there is more than + // one index tensor. + // TODO: Add support for dynamic size index tensors with overlapping + // dynamic dims. 
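+    // e.g. (shapes for illustration only): index tensors of shapes [3, 1] and
+    // [2] broadcast to [3, 2], which replaces the indexed dims in the result
+    // shape. Index tensors of shapes [?, 2] and [3, ?] would require
+    // broadcasting two dynamic dims at the same position, which is the
+    // unsupported "overlapping dynamic dims" case rejected below.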
+ SmallVector broadcastedIndexShape; + if (indexTensors.size() > 1) { + int maxRank = -1; + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast(); + maxRank = std::max(maxRank, (int)indexTensorType.getRank()); + } + + // Because we are assuming static shapes, we can get the shape of the + // broadcasted index tensors from the shape refinement pass + auto refinedResultShape = resultType.getShape(); + for (auto i : llvm::seq(startIndex, startIndex + maxRank)) { + auto resultDimSize = refinedResultShape[i]; + if (ShapedType::isDynamic(resultDimSize)) { + SmallVector dynamicDims; + int64_t staticDimSize = -1; + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast(); + int64_t indexTensorRank = indexTensorType.getRank(); + if ((maxRank - indexTensorRank) > (i - startIndex)) + continue; + int64_t dim = i - startIndex - maxRank + indexTensorRank; + if (ShapedType::isDynamic(indexTensorType.getShape()[dim])) + dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim)); + else + staticDimSize = + std::max(staticDimSize, indexTensorType.getShape()[dim]); + } + if (dynamicDims.size() >= 2) + return rewriter.notifyMatchFailure( + op, + "unimplemented: index tensors with overlapping dynamic dims"); + if (staticDimSize > 1) { + Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, + rewriter.getIndexType()); + auto equalToRunning = rewriter.create( + loc, arith::CmpIPredicate::eq, cstStaticDimSize, + dynamicDims[0]); + rewriter.create(loc, equalToRunning, + "mismatched size for broadcast"); + } + broadcastedIndexShape.push_back(dynamicDims[0]); + } else { + broadcastedIndexShape.push_back(getConstant( + rewriter, loc, resultDimSize, rewriter.getIndexType())); + } + } + } else { + // For a single indexing tensor we can simply use its (dynamic) sizes + broadcastedIndexShape = + getTensorSizes(rewriter, loc, indexTensors.front()); + } // This result shape calculation assumes that there is only one - // index tensor and that it is indexing the first dimension of the - // input tensor. The calculation for arbitrary inputs is much more complex. + // index tensor, or all of the index tensors are statically shaped. + int broadcastRank = broadcastedIndexShape.size(); + SmallVector resultShape; - for (auto i : llvm::seq(0, indexTensorRank)) { - resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i)); - } - for (auto i : llvm::seq(1, inputRank)) { - resultShape.push_back(getDimOp(rewriter, loc, input, i)); + if (contiguous) { + for (auto i : llvm::seq(0, firstIndexDim)) { + resultShape.push_back(getDimOp(rewriter, loc, input, i)); + } + resultShape.append(broadcastedIndexShape); + for (auto i : llvm::seq((int)resultShape.size(), resultRank)) { + resultShape.push_back(getDimOp(rewriter, loc, input, + i - broadcastRank + replacedIndexCount)); + } + } else { + resultShape.append(broadcastedIndexShape); + int j = 0; + for (auto i : llvm::seq(0, inputRank)) { + if (j < replacedIndexCount && i == indexTensorDims[j]) { + j++; + continue; + } + resultShape.push_back(getDimOp(rewriter, loc, input, i)); + } } - int resultRank = resultShape.size(); + // Initialize the indexing maps for the generic op. Because we are assuming + // static shapes for the indexing tensors when there are more than 1, we can + // safely map all size 1 dims to 0 in the corresponding affine maps. 
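+    // e.g. (illustrative): with broadcastRank == 2, an index tensor of shape
+    // [3, 1] gets an indexing map of the form (d0, ..., dN) -> (dK, 0), where
+    // dK is the result dim aligned with its non-unit dim; always reading
+    // element 0 along the size-1 dim implements the broadcast without
+    // materializing a larger index tensor.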
+ // TODO: For dynamic shapes, we have to either broadcast the index tensors + // to a common shape or introduce some form of control flow. Value initTensor = rewriter.create(loc, resultShape, elementType); - SmallVector indicesExpr, resultExpr; + SmallVector indexingMaps; SmallVector iteratorTypes; - for (auto i : llvm::seq(0, indexTensorRank)) - indicesExpr.push_back(rewriter.getAffineDimExpr(i)); + for (auto indexTensor : indexTensors) { + RankedTensorType indexTensorType = + indexTensor.getType().cast(); + auto indexTensorShape = indexTensorType.getShape(); + int rank = indexTensorShape.size(); + SmallVector indicesExpr; + for (auto dim : llvm::seq(0, rank)) { + if (indexTensorShape[dim] == 1) { + indicesExpr.push_back(rewriter.getAffineConstantExpr(0)); + continue; + } + indicesExpr.push_back( + rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim)); + } + indexingMaps.push_back( + AffineMap::get(resultRank, 0, indicesExpr, op->getContext())); + } + + SmallVector resultExpr; for (auto i : llvm::seq(0, resultRank)) { resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } - auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); + + indexingMaps.push_back( + AffineMap::get(resultRank, 0, resultExpr, op->getContext())); Value finalRes = rewriter .create( - loc, initTensor.getType(), indexTensor, initTensor, + loc, initTensor.getType(), indexTensors, initTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector extractionIndices{ - castIntToIndex(b, loc, args[0])}; - for (auto i : llvm::seq(1, inputRank)) { - extractionIndices.push_back(b.create( - loc, i + indexTensorRank - 1)); + SmallVector extractionIndices; + if (contiguous) { + for (auto i : llvm::seq(0, firstIndexDim)) { + extractionIndices.push_back( + b.create(loc, i)); + } + for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { + extractionIndices.push_back( + castIntToIndex(b, loc, args[i])); + } + for (auto i : + llvm::seq((int)extractionIndices.size(), inputRank)) { + extractionIndices.push_back(b.create( + loc, i + broadcastRank - replacedIndexCount)); + } + } else { + int indexCount = 0, unchanged = 0; + for (auto i : llvm::seq(0, inputRank)) { + if (indexCount < replacedIndexCount && + i == indexTensorDims[indexCount]) { + extractionIndices.push_back( + castIntToIndex(b, loc, args[indexCount++])); + continue; + } + extractionIndices.push_back(b.create( + loc, broadcastRank + unchanged)); + unchanged++; + } } Value extractedElement = b.create( loc, input, extractionIndices); @@ -347,4 +766,6 @@ void mlir::torch::torch_to_linalg:: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 550e41d27f69..9482187e5bed 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include using namespace mlir; using namespace mlir::torch; @@ -635,12 +636,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value input = adaptor.input(); /* in form of N*C*H*W */ Value weight = adaptor.weight(); /* in form of F*C*H*W */ + bool transposed = true; + if 
(!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant transposed supported"); + Type elementType = input.getType().cast().getElementType(); if (!elementType.isa()) return op.emitError("unimplemented: non-floating point type"); size_t inRank = input.getType().cast().getRank(); - if (inRank != 4) + size_t numSpacialDims = inRank - 2; + if (numSpacialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolution currently supported"); @@ -663,57 +670,161 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); - Value N = getDimOp(rewriter, loc, input, 0); + Value inBatch = getDimOp(rewriter, loc, input, 0); + Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; for (size_t i = 2; i < inRank; i++) inDims.push_back(getDimOp(rewriter, loc, input, i)); - Value F = getDimOp(rewriter, loc, weight, 0); + Value weightBatch = getDimOp(rewriter, loc, weight, 0); + Value weightChannels = getDimOp(rewriter, loc, weight, 1); SmallVector weightDims; for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Guard unused values (transposed, groups) - int64_t group_size; - if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) || - group_size != 1) - return rewriter.notifyMatchFailure( - op, "unimplemented: only group size of 1 supported"); - bool transposed = true; - if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) || - transposed) - return rewriter.notifyMatchFailure( - op, "unimplemented: only non-transposed convolution supported"); - - // Pad the input tensor according to padding. - SmallVector paddingIncludingNC = {0, 0}; - paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(), - paddingInts.end()); - Value paddedInput = torch_to_linalg::getZeroPaddedTensor( - op, rewriter, input, paddingIncludingNC); - - SmallVector paddingIntValues = - getAsConstantIntValues(rewriter, loc, paddingInts); + // Checks for valid group size + int64_t groupSize; + if (!matchPattern(op.groups(), m_TorchConstantInt(&groupSize))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.groups()); + + auto validate = [&](Value toValidate, std::string err) { + Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value inputValid = rewriter.create( + loc, arith::CmpIPredicate::eq, c0, + rewriter.create(loc, toValidate, groups)); + rewriter.create(loc, inputValid, + rewriter.getStringAttr(err)); + }; + validate(inChannels, + "invalid: groups must divide input channel size evenly."); + validate(weightBatch, + "invalid: groups must divide weight batch size evenly."); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector paddingIntValues = + getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); - SmallVector outDims{N, F}; - for (size_t i = 0; i < inRank - 2; i++) - outDims.push_back(torch_to_linalg::getOutputDimForConvOps( - rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], - castIndexToInt(weightDims[i]), strideIntValues[i])); + // Pad the input tensor according to padding. 
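+    // Along each spatial dim, a transposed convolution produces
+    //   dim_out = (dim_in - 1) * stride - 2 * padding
+    //             + dilation * (kernelSize - 1) + 1
+    // e.g. (numbers for illustration only): dim_in = 4, stride = 2,
+    // padding = 1, dilation = 1, kernelSize = 3 gives
+    //   dim_out = 3 * 2 - 2 + 2 + 1 = 7.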
+ SmallVector outDims{inBatch, weightBatch}; + Value paddedInput; + if (transposed) { + Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value c2 = + rewriter.create(loc, rewriter.getIndexAttr(2)); + + // Transpose and flip weight + SmallVector weightInitDims = getTensorSizes(rewriter, loc, weight); + std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1); + outDims[1] = weightInitDims[0]; + Value weightInitTensor = + createZeroInitTensor(rewriter, loc, weightInitDims, elementType); + SmallVector iteratorTypes(inRank, + getParallelIteratorTypeName()); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(inRank, context)); + weight = rewriter + .create( + loc, weightInitTensor.getType(), weight, + weightInitTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (size_t i = 0; i < inRank; i++) + indices.push_back(b.create(loc, i)); + std::iter_swap(indices.begin(), indices.begin() + 1); + // Flip only the spatial dimensions (from 2 to inRank) + for (size_t flipDim = 2; flipDim < inRank; flipDim++) { + indices[flipDim] = b.create( + loc, + b.create( + loc, weightInitDims[flipDim], c1), + indices[flipDim]); + } + Value res = + b.create(loc, weight, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + + // Calculate padded input size, allocate tensor + SmallVector outerSizes{inBatch, inChannels}; + SmallVector innerSizes{inBatch, inChannels}; + SmallVector offsets{c0, c0}; + for (size_t i = 0; i < numSpacialDims; i++) { + Value innerSize = rewriter.create(loc, inDims[i], c1); + innerSize = rewriter.create( + loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); + innerSize = rewriter.create(loc, innerSize, c1); + + Value offset = rewriter.create(loc, weightDims[i], c1); + offset = rewriter.create( + loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i])); + offset = rewriter.create( + loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); + + Value outerSize = rewriter.create(loc, offset, c2); + outerSize = rewriter.create(loc, outerSize, innerSize); + + outerSizes.push_back(outerSize); + offsets.push_back(offset); + } + + // Allocate padded input tensor + Value initTensor = + createZeroInitTensor(rewriter, loc, outerSizes, elementType); + + // Insert input into allocated tensor + SmallVector strideIndexValues{c1, c1}; + for (auto stride : strideIntValues) + strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride)); + SmallVector insertSizes = getTensorSizes(rewriter, loc, input); + + paddedInput = rewriter.create( + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initTensor, offsets, insertSizes, strideIndexValues); + + // Calculate output dims + for (size_t i = 0; i < numSpacialDims; i++) + outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( + rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], + castIndexToInt(weightDims[i]), strideIntValues[i])); + + // Set stride to 1 + strideInts.clear(); + strideInts.append(numSpacialDims, 1); + + } else { + // Pad input + SmallVector paddingIncludingNC = {0, 0}; + paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(), + paddingInts.end()); + paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input, + paddingIncludingNC); + + // Calculate output dims + for (size_t i = 0; i < numSpacialDims; i++) + 
outDims.push_back(torch_to_linalg::getOutputDimForConvOps( + rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], + castIndexToInt(weightDims[i]), strideIntValues[i])); + } Value initTensor = rewriter.create(loc, outDims, elementType); Value bias = adaptor.bias(); - Value biasInitTensor; + Value outputTensor; if (bias.getType().isa()) { Value c0float = rewriter.create( loc, FloatAttr::get(elementType, 0.0)); - biasInitTensor = rewriter.create(loc, c0float, initTensor) - .getResult(0); + outputTensor = rewriter.create(loc, c0float, initTensor) + .getResult(0); } else { auto biasType = bias.getType().cast(); if (biasType.getRank() != 1) @@ -727,27 +838,144 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context), rewriter.getMultiDimIdentityMap(resultRank)}; - SmallVector iteratorTypes(resultRank, "parallel"); - biasInitTensor = rewriter - .create( - loc, initTensor.getType(), bias, initTensor, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); + SmallVector iteratorTypes(resultRank, + getParallelIteratorTypeName()); + outputTensor = rewriter + .create( + loc, initTensor.getType(), bias, initTensor, + indexingMaps, iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); - // TODO: add 1D and 3D case - Value conv = - rewriter - .create( - loc, biasInitTensor.getType(), ValueRange{paddedInput, weight}, - biasInitTensor, stridesAttr, dilationAttr) - .getResult(0); + Value inputStride = + rewriter.create(loc, inChannels, groups); + Value weightStride = + rewriter.create(loc, weightBatch, groups); + + SmallVector zeroOffsets(inRank, rewriter.create( + loc, rewriter.getIndexAttr(0))); + SmallVector unitStrides(inRank, rewriter.create( + loc, rewriter.getIndexAttr(1))); + SmallVector outDimSlice(outDims); + outDimSlice[1] = weightStride; + SmallVector inputSliceSizes{inBatch, inputStride}; + inputSliceSizes.append(inDims); + SmallVector weightSliceSizes{weightStride, weightChannels}; + weightSliceSizes.append(weightDims); + + Value conv; + if (groupSize == 1) { + // TODO: add 1D and 3D case + conv = + rewriter + .create( + loc, outputTensor.getType(), ValueRange{paddedInput, weight}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + } else { + // Special depthwise case + auto inShape = input.getType().cast().getShape(); + auto weightShape = weight.getType().cast().getShape(); + if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && + weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { + // Collapse weight shape + SmallVector collapsedDims = {{0, 1}, {2}, {3}}; + SmallVector collapsedShape{ + (weightShape[0] == kUnknownSize ? 
kUnknownSize + : weightShape[0] * weightShape[1]), + weightShape[2], weightShape[3]}; + Type collapsedType = RankedTensorType::get(collapsedShape, elementType); + Value collapsedWeight = rewriter.create( + loc, collapsedType, weight, collapsedDims); + + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); + } + + // Grouped case, use the grouped conv linalg op + + auto expandGroups = [&](Value tensor, size_t dim) { + auto inType = tensor.getType().cast(); + auto inShape = inType.getShape(); + + SmallVector outShape; + for (auto i = 0; i < (long)inShape.size(); i++) { + if (i == 1) { + outShape.push_back(groupSize); + } + if (i == (long)dim) { + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / groupSize); + } else { + outShape.push_back(inShape[i]); + } + } + + SmallVector indices; + for (auto i = 0; i <= (long)inShape.size(); i++) { + if (i == (long)dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + auto retType = inType.clone(outShape); + return rewriter.create(loc, retType, tensor, + indices); + }; + + auto expandWeight = [&](Value tensor) { + auto inType = tensor.getType().cast(); + auto inShape = inType.getShape(); + + SmallVector outShape{ + groupSize, (inShape[0] == kUnknownSize ? kUnknownSize + : inShape[0] / groupSize)}; + outShape.append(inShape.begin() + 1, inShape.end()); + + SmallVector indices{{0, 1}}; + for (auto i = 2; i <= (long)inShape.size(); i++) + indices.push_back({i}); + + auto retType = inType.clone(outShape); + return rewriter.create(loc, retType, tensor, + indices); + }; + + Value paddedInputExpanded = expandGroups(paddedInput, 1); + Value weightExpanded = expandWeight(weight); + Value outputTensorExpanded = expandGroups(outputTensor, 1); + + // TODO: add 1D and 3D case + conv = rewriter + .create( + loc, outputTensorExpanded.getType(), + ValueRange{paddedInputExpanded, weightExpanded}, + outputTensorExpanded, stridesAttr, dilationAttr) + .getResult(0); + + SmallVector indices{{0}, {1, 2}}; + for (auto dim = 3; dim <= (int64_t)inRank; dim++) + indices.push_back({dim}); + conv = rewriter.create( + loc, outputTensor.getType(), conv, indices); + } Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); diff --git a/lib/Conversion/TorchToLinalg/PopulatePatterns.h b/lib/Conversion/TorchToLinalg/PopulatePatterns.h index 384c89d333eb..56691c82c7c1 100644 --- a/lib/Conversion/TorchToLinalg/PopulatePatterns.h +++ b/lib/Conversion/TorchToLinalg/PopulatePatterns.h @@ -63,9 +63,6 @@ void populateIndirectDataMovementPatternsAndLegality( void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); -//void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter, -// RewritePatternSet &patterns, -// ConversionTarget &target); } // namespace torch_to_linalg } // namespace torch diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index cf2d2beee6e3..728c53bf21dd 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -239,7 +239,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value 
elem = payloadArgs[0]; Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); - auto abs = b.create(loc, self); + auto abs = b.create(loc, self); AtenLinalgVectorNormOp::Adaptor adaptor(operands); Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType); auto pow = b.create(loc, abs, ord); @@ -270,6 +270,8 @@ class ConvertReductionOp : public ConversionPattern { "`keepdim` must be a constant bool"); SmallVector dimList; + bool isNoneOrEmptyDimList = + op.dim().getType().template isa(); if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) { // Fix negative dimensions, if any, before adding to the list. for (int64_t dim : dimList) { @@ -278,13 +280,16 @@ class ConvertReductionOp : public ConversionPattern { if (isValidDim(dim, inputType.getRank())) opInfo.dimSet.insert(dim); } - } else if (op.dim().getType().template isa()) { + if (dimList.empty()) + isNoneOrEmptyDimList = true; + } else if (!isNoneOrEmptyDimList) { + return rewriter.notifyMatchFailure( + op, "`dim` argument must be a constant int list or None"); + } + if (isNoneOrEmptyDimList) { // If no dimensions were specified, reduce along all dimensions for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); - } else { - return rewriter.notifyMatchFailure( - op, "`dim` argument must be a constant int list or None"); } return opInfo; diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 1caa1408f9a0..f8ebc349fad7 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -62,8 +62,6 @@ class ConvertTorchToLinalg RewritePatternSet patterns(context); - //torch_to_linalg::populateCustomOpExamplePatternsAndLegality( - // typeConverter, patterns, target); torch_to_linalg::populateTensorScalarInteropPatternsAndLegality( typeConverter, patterns, target); torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 3ce2b7a2b16f..d1f473990361 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -131,6 +131,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -139,6 +143,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -202,7 +210,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, lhsTest, rhsTest); } if (isa(op)) - return b.create(loc, payloadArgs[0]); + return b.create(loc, payloadArgs[0]); if (isa(op)) { auto negate = createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -383,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, lhs, rhs); } } + if (auto atan2 = dyn_cast(op)) { + Type dtype = 
converter->convertType(atan2.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + atan2.emitError("Atan2 requires floating point result type"); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } if (auto gtTensor = dyn_cast(op)) { AtenGtTensorOp::Adaptor adaptor(operands); Type lhsDtype = payloadArgs[0].getType(); @@ -783,6 +803,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value other = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, self, other); } + if (auto remScalar = dyn_cast(op)) { + Type newResultType = converter->convertType(remScalar.getType()) + .cast() + .getElementType(); + + Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); + Value other = convertScalarToDtype(b, loc, operands[1], newResultType); + Value result; + + if (newResultType.isa()) { + result = b.create(loc, self, other); + } else if (newResultType.isa()) { + result = b.create(loc, self, other); + } else { + remScalar.emitError( + "Unsupported type encountered for AtenRemainderScalarOp."); + } + + return result; + } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() @@ -844,9 +884,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( threshold); return b.create(loc, predicate, constantZero, grad); } - if (auto maskedFill = dyn_cast(op)) { + if (auto maskedFillScalar = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(maskedFill.getType()) + Type dtype = converter->convertType(maskedFillScalar.getType()) .cast() .getElementType(); @@ -856,6 +896,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, mask, fillValue, input); } + if (auto maskedFillTensor = dyn_cast(op)) { + AtenMaskedFillScalarOp::Adaptor adaptor(operands); + Type dtype = converter->convertType(maskedFillTensor.getType()) + .cast() + .getElementType(); + + Value input = payloadArgs[0]; + Value mask = payloadArgs[1]; + Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); + return b.create(loc, mask, fillValue, input); + } if (auto triu = dyn_cast(op)) { // Check if the rank of the input tensor is valid. 
@@ -918,18 +969,19 @@ class ConvertElementwiseOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const override { if (!isa(op)) + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1660,14 +1712,15 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, - AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, - AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, - AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, - AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, - AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, - AtenLogicalOrOp, AtenTriuOp>(); + AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, + AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, + AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, + AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0d04dc552640..57a50a688f81 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -97,6 +97,31 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, return castIntToIndex(b, loc, out); } +Value torch_to_linalg::getOutputDimForConvTransposeOps( + OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, + Value kernelSizeInt, Value strideInt) { + Value c1 = b.create(loc, b.getI64IntegerAttr(1)); + Value c2 = b.create(loc, b.getI64IntegerAttr(2)); + + // (in - 1) * stride + Value inStrided = + b.create(loc, castIndexToInt64(b, loc, in), c1); + inStrided = b.create(loc, inStrided, strideInt); + + // 2 * padding + Value doublePadding = b.create(loc, paddingInt, c2); + + // (kernelSize - 1) * dilation + Value kernelDilated = b.create(loc, kernelSizeInt, c1); + kernelDilated = b.create(loc, kernelDilated, dilationInt); + + Value out = b.create(loc, inStrided, doublePadding); + out = b.create(loc, out, kernelDilated); + out = b.create(loc, out, c1); + + return castIntToIndex(b, loc, out); +} + Value torch_to_linalg::createReductionLinalgGeneric( OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild) { @@ -338,3 +363,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( return success(); } + +Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, + Value tensor) { + auto tensorType = tensor.getType().cast(); + auto rank = tensorType.getRank(); + SmallVector unknownSizes(rank, kUnknownSize); + return 
b.create(loc, tensorType.clone(unknownSizes), tensor); +} diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index 6279b8c9e802..f57c7eaa376d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -39,6 +39,14 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, Value kernelSizeInt, Value strideInt, bool ceilMode = false); +// As above but for transposed convolution ops +// Along each dim: +// dim_out = +// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1 +Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in, + Value paddingInt, Value dilationInt, + Value kernelSizeInt, Value strideInt); + // Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions // in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the // same rank as the `opInfo.tensorOperand` and reduced dimensions are set to @@ -61,6 +69,9 @@ LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, SmallVector broadcastToShape, Value &result); +// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> -> +// +Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor); } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp new file mode 100644 index 000000000000..677419f37ec1 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -0,0 +1,1072 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include +#include + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::TorchConversion; +using namespace mlir::torch::torch_to_mhlo; + +bool skipMultiplyAlpha(Value alphaValue) { + double doubleValue; + auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue)); + + int64_t intValue; + auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue)); + + return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0)); +} + +// These legalizations are for unary ops with only for floating point datatypes. +// There is no supported quantized integer mode for these. 
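+// For example (illustrative IR, not from a test): aten.log on a
+// floating-point tensor lowers 1:1,
+//   %0 = "mhlo.log"(%arg) : (tensor<4xf32>) -> tensor<4xf32>
+// while an integer-typed input is rejected with an error.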
+namespace { +template +class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + auto selfTy = self.getType().cast(); + + if (!selfTy) + return op.emitError("only Tensor types supported in MHLO"); + + if (selfTy.getElementType().isa()) { + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + self); + return success(); + } else { + return op.emitError( + "only floating-point datatype legalization supported"); + } + } +}; +} // namespace + +// aten.ones & aten.zeros +// Ref: Error checking based on the Torch to TOSA lowering +namespace { +template +class ConvertAtenConstPatternOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + if (!outType) + return op.emitError("only Tensor types supported in MHLO"); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) + return op.emitError( + "only floating-point or integer datatype legalization supported"); + + // FIXME: Handle layout, device and pin_memory. Assume dtype has been + // processed to set output type correctly? + if (!op.layout().getType().template isa()) + return op.emitError("only default layout is supported"); + + bool pinMemory; + if (!op.pin_memory().getType().template isa() && + (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || + pinMemory)) { + return op.emitError( + "unsupported pin_memory, should be either None or false"); + } + + SmallVector shape; + if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) { + return op.emitError("shape must be a list of Scalar constants"); + } + + int64_t size = 1; + for (auto s : shape) + size *= s; + + SmallVector values(size, fillVal); + auto constOp = + mhlo::getConstTensor(rewriter, op, values, shape).value(); + + rewriter.replaceOpWithNewOp(op, outType, constOp); + return success(); + } +}; + +} // namespace + +// These binary op legalizations are specific to add/sub which have an +// alpha multiplier. 
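+// e.g. aten.add.Tensor(self, other, alpha) computes self + alpha * other;
+// when alpha is the constant 1 (see skipMultiplyAlpha above), the scaling
+// multiply is skipped and the op lowers directly to a broadcasted chlo add.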
+namespace { +template +class ConvertAtenAddSubOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + RankedTensorType lhsType = lhs.getType().dyn_cast(); + Value rhs = adaptor.other(); + RankedTensorType rhsType = rhs.getType().dyn_cast(); + + if (!lhsType) + return op.emitError("only Tensor types supported in MHLO"); + + TensorType outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + + if (!rhsType) { + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy); + if (isa(op)) { + std::swap(lhs, rhs); + } + } + + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); + + if (!skipMultiplyAlpha(op.alpha())) { + Value alpha = + mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy); + DenseIntElementsAttr bcastDimensions; + rhs = rewriter.create(op->getLoc(), rhs, alpha, + bcastDimensions); + } + + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, + bcastDimensions); + return success(); + } +}; +} // namespace + +// Binary op legalizations for Mul/Div variants. +namespace { +template +class ConvertAtenMulDivOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + auto lhsType = lhs.getType().dyn_cast(); + Value rhs = adaptor.other(); + TensorType rhsType = rhs.getType().dyn_cast(); + + if (!lhsType) + return op.emitError("only Tensor types supported in MHLO"); + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + + if (std::is_same()) { + rhs = lhs; + } else if (!rhsType) { + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy); + } + DenseIntElementsAttr bcastDimensions; + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); + auto loc = op.getLoc(); + Value result = + rewriter.create(loc, outType, lhs, rhs, bcastDimensions); + + if (!isa(op)) { + rewriter.replaceOp(op, result); + return success(); + } + + AtenDivTensorModeOp divTensorModeOp = + llvm::dyn_cast(op.getOperation()); + std::string roundingMode; + if (!matchPattern(divTensorModeOp.rounding_mode(), + m_TorchConstantStr(roundingMode))) + return rewriter.notifyMatchFailure( + op, "only support constant str rounding mode"); + + if (roundingMode == "trunc") { + // "trunc" - rounds the results of the division towards zero. Equivalent + // to C-style integer division. 
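+      // Worked example (illustrative): -7.0 / 2.0 = -3.5, and
+      // sign(-3.5) * floor(|-3.5|) = -1 * 3.0 = -3.0, which matches C-style
+      // integer division; "floor" mode below would give floor(-3.5) = -4.0.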
+ auto sign = rewriter.create(loc, result); + auto abs = rewriter.create(loc, result); + auto floor = rewriter.create(loc, abs); + result = rewriter.create(loc, sign, floor).getResult(); + } + if (roundingMode == "floor") { + // "floor" - rounds the results of the division down. Equivalent to + // floor division in Python (the // operator) + result = rewriter.create(loc, result).getResult(); + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +// Binary op legalizations for comparator ops. +namespace { +template +class ConvertAtenCompareOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + Value rhs = adaptor.other(); + RankedTensorType lhsTy = lhs.getType().dyn_cast(); + RankedTensorType rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy) + return op.emitError("only Tensor types supported in MHLO"); + + RankedTensorType outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type lhsElemTy = lhsTy.getElementType(); + if (!lhsElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + + if (!rhsTy) { + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy); + } + + // TODO: what is the PyTorch default type promotion? + rhs = mhlo::promoteType(rewriter, rhs, lhsTy); + + chlo::ComparisonTypeAttr compareTypeAttr; + chlo::ComparisonDirectionAttr compareDirectionAttr; + + if (lhsElemTy.isa()) { + compareTypeAttr = chlo::ComparisonTypeAttr::get( + op->getContext(), chlo::ComparisonType::FLOAT); + } else if (lhsElemTy.isa()) { + compareTypeAttr = chlo::ComparisonTypeAttr::get( + op->getContext(), chlo::ComparisonType::SIGNED); + } + + if (std::is_same() || + std::is_same()) { + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::LT); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::GT); + } else if (std::is_same()) { + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::GE); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::EQ); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::NE); + } + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, + compareTypeAttr); + return success(); + } +}; + +} // namespace + +// AtenTransposeIntOp +namespace { +class ConvertAtenTransposeIntOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + int64_t dim0; + if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) { + return rewriter.notifyMatchFailure(op, "dim0 must be constant"); + } + int64_t dim1; + if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) { + return rewriter.notifyMatchFailure(op, "dim1 must 
be constant"); + } + + auto inType = self.getType().cast(); + auto inputRank = inType.getRank(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + dim0 = toPositiveDim(dim0, inputRank); + if (!isValidDim(dim0, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim0 out of range"); + } + dim1 = toPositiveDim(dim1, inputRank); + if (!isValidDim(dim1, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim1 out of range"); + } + + SmallVector permValues(inputRank); + std::iota(std::begin(permValues), std::end(permValues), 0); + std::swap(permValues[dim0], permValues[dim1]); + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(permValues.size())}, + rewriter.getI64Type()), + permValues); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); + return success(); + } +}; +} // namespace + +// AtenBroadcastToOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBroadcastToOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + auto selfTy = self.getType().cast(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + if (options.enableStaticShape && selfTy.hasStaticShape()) { + Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); + rewriter.replaceOp(op, bcastOp); + return success(); + } + + SmallVector shape; + if (!(getListConstructElements(adaptor.size(), shape))) { + return op->emitError("desired shape must be a list of scalar"); + } + SmallVector bcastShapeVec; + int64_t totalRank = shape.size(); + int64_t selfRank = selfTy.getRank(); + int64_t leadingRank = totalRank - selfRank; + + for (int64_t i = 0; i < totalRank; ++i) { + Value dValue = shape[i]; + Value newD; + int64_t dInt; + if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) && + dInt == -1) { + newD = rewriter.create(op->getLoc(), self, + i - leadingRank); + } else { + dValue = rewriter.create(op->getLoc(), + dValue); + newD = rewriter.create( + op->getLoc(), rewriter.getIndexType(), dValue); + } + bcastShapeVec.push_back(newD); + } + + if (options.dimSizeIndexBits == 32) { + for (auto &dsize : bcastShapeVec) { + auto dsizeI64 = rewriter.create( + op->getLoc(), rewriter.getI64Type(), dsize); + dsize = rewriter.create(op->getLoc(), + rewriter.getI32Type(), dsizeI64); + } + } + + Value bcastShapeTensor = rewriter.create( + op->getLoc(), ValueRange{bcastShapeVec}); + auto dimensionNumbers = + llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); + rewriter.replaceOpWithNewOp( + op, outType, self, bcastShapeTensor, + rewriter.getI64TensorAttr(dimensionNumbers)); + return success(); +} + +// AtenPermuteOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPermuteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + // Not a ranked tensor type + auto inType = self.getType().dyn_cast(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + if (!inType) + return op.emitError("only ranked tensor types with static shapes are " + "currently supported"); + + SmallVector permValues; + if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) + return rewriter.notifyMatchFailure( + op, "only constant dimensions are currently supported"); + + int64_t inRank = inType.getRank(); + for (auto &d : permValues) { + d = toPositiveDim(d, inRank); + if (!isValidDim(d, inRank)) + return 
op.emitError("not all dims are valid"); + } + + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(permValues.size())}, + rewriter.getI64Type()), + permValues); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); + return success(); +} + +// AtenTanhOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTanhOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + auto selfTy = self.getType().cast(); + if (selfTy && selfTy.getElementType().isa()) { + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self); + return success(); + } else { + return op.emitError( + "only floating-point datatype legalization currently supported"); + } +} + +// ValueTensorLiteralOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + ValueTensorLiteralOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + // Tensors with integer types need to be converted to signless integer + // element type. All tensors with element types other than integer can reuse + // existing elements attribute. + // TODO: what about unsigned integer? + if (auto elements = op.valueAttr().dyn_cast()) { + Type builtinTensorElemTy = resultType.getElementType(); + unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); + + DenseElementsAttr valueAttr = + elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { + return APInt(bitWidth, v.getSExtValue()); + }); + rewriter.replaceOpWithNewOp(op, resultType, valueAttr); + return success(); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.value()); + return success(); +} + + +// AtenReciprocalOp +// Reciprocal(x) = Div(1, x) +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReciprocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + if (!inputTy.getElementType().isa()) { + return op.emitError("only floating-point datatype legalization supported " + "for AtenReciprocalOp"); + } + + Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); + rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); + return success(); +} + +// PrimNumToTensorScalarOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimNumToTensorScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + RankedTensorType outputType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + auto outputElemType = outputType.getElementType(); + Value mhloTensor = + mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType); + rewriter.replaceOp(op, mhloTensor); + return success(); +} + +// AtenContiguousOp +// Ref: TosaToTosa.cpp for implementation details +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenContiguousOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return op.emitError("only tensor types are currently supported"); + + // FIXME: memory_format is not handled. 
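+  // At the value level aten.contiguous is a no-op: MHLO does not model memory
+  // layout, so (assuming the default torch.contiguous_format) the operand can
+  // simply be forwarded.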
+ + rewriter.replaceOp(op, adaptor.self()); + + return success(); +} + + +// AtenReluOp +// Relu(x) = Max(0, x) +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + auto lhsElemTy = lhsTy.getElementType(); + + if (!lhsElemTy.isa()) { + return op->emitError("only float tensor in relu op is supported"); + } + + Value zeroTensor; + zeroTensor = chlo::getConstantLike( + rewriter, op->getLoc(), + APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), + false), + lhs); + auto outType = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + rewriter.replaceOpWithNewOp(op, outType, lhs, zeroTensor); + return success(); +} + + +// Convert a Aten::GELU to HLO +// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenGeluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value input = adaptor.self(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + + Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); + Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); + Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); + auto rsqrtTwo = rewriter.create(loc, two); + auto erfElement = rewriter.create(loc, input, rsqrtTwo); + auto erf = rewriter.create(loc, erfElement); + auto erfAdd = rewriter.create(loc, erf, one); + auto halfMul = rewriter.create(loc, erfAdd, half); + auto outType = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + rewriter.replaceOpWithNewOp(op, outType, input, halfMul); + return success(); +} + +// AtenErfOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenErfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputType = input.getType().cast(); + if (!inputType.getElementType().isa()) { + return rewriter.notifyMatchFailure(op, "only float tensor is supported"); + } + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), input); + return success(); +} + + +// AtenBatchNormOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBatchNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.input(); + // shape = [N, C, H, W] + auto inputTy = input.getType().cast(); + Value weight = adaptor.weight(); + Value bias = adaptor.bias(); + Value runningMean = adaptor.running_mean(); + Value runningVar = adaptor.running_var(); + // momentum is ignored + Value momentum = adaptor.momentum(); + (void)momentum; + + if (inputTy.getRank() <= 2) { + return rewriter.notifyMatchFailure(op, + "input should have rank larger than 2"); + } + if (!inputTy.getElementType().template isa()) { + return op.emitError("only input tensor of float type is supported"); + } + auto inputElemTy = inputTy.getElementType().cast(); + + Value channelDim = rewriter.create(op->getLoc(), input, 1); + + if (options.dimSizeIndexBits == 32) { + auto channelDimI64 = rewriter.create( + op->getLoc(), rewriter.getI64Type(), channelDim); + channelDim = rewriter.create( + op->getLoc(), rewriter.getI32Type(), channelDimI64); + } + + Value channelShape = rewriter.create( + op->getLoc(), ValueRange{channelDim}); + if (failed(checkNotNone(rewriter, op, 
weight))) { + weight = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); + } + if (failed(checkNotNone(rewriter, op, bias))) { + bias = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); + } + if (failed(checkNotNone(rewriter, op, runningVar))) { + runningVar = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); + } + if (failed(checkNotNone(rewriter, op, runningMean))) { + runningMean = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); + } + + auto weightTy = weight.getType().cast(); + auto biasTy = bias.getType().cast(); + auto runningMeanTy = runningMean.getType().cast(); + auto runningVarTy = runningVar.getType().cast(); + + if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || + runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { + return rewriter.notifyMatchFailure( + op, "expect weight, bias, running_mean and running_var to be rank 1"); + } + if (!weightTy.getElementType().template isa() || + !biasTy.getElementType().template isa() || + !runningMeanTy.getElementType().template isa() || + !runningVarTy.getElementType().template isa()) { + return op.emitError("only float weight/bias/runningMean/runningVar tensor " + "of float type is supported"); + } + + double eps = 0.0; + if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { + return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported"); + } + bool training = false; + if (!matchPattern(op.training(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "non-bool training unsupported"); + } + // TODO: handle cudnnEnabled parameter. Here, we just ignore it! 
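+  // Note: the lowering below splits on `training`: the training path lowers
+  // to mhlo.batch_norm_training (which also computes the batch statistics),
+  // while the inference path lowers to mhlo.batch_norm_inference using the
+  // provided running statistics.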
+ bool cudnnEnabled = false; + if (!matchPattern(op.cudnn_enabled(), m_TorchConstantBool(&cudnnEnabled))) { + return rewriter.notifyMatchFailure(op, + "non-bool cudnn_enabled unsupported"); + } + if (training) { + Type outputTy = getTypeConverter()->convertType(op.getType()); + Type batchMeanOrVarTy = + RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); + auto batchNormTrainingResult = rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); + rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); + return success(); + } else { + Type outputTy = getTypeConverter()->convertType(op.getType()); + SmallVector castShape{inputTy.getShape().begin(), + inputTy.getShape().end()}; + castShape[1] = weightTy.getShape()[0]; + auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); + // feature counts must match among operands of mhlo::BatchNormInferenceOp + Value inputCasted = + rewriter.create(op.getLoc(), castTy, input); + Value output = rewriter.create( + op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, + runningMean, runningVar, + // 'epsilon' must satisfy constraint: 32-bit float attribute + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp(op, outputTy, output); + return success(); + } +} + + +// AtenNativeLayerNormOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenNativeLayerNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.input(); + auto inputTy = input.getType().cast(); + auto inputShape = inputTy.getShape(); + auto inputRank = inputTy.getRank(); + Value weight = adaptor.weight(); + Value bias = adaptor.bias(); + + if (!inputTy.hasStaticShape()) { + return op->emitError("dynamic shaped input is not supported"); + } + + SmallVector normalizedShape; + if (!matchPattern(op.normalized_shape(), + m_TorchConstantIntList(normalizedShape))) { + return rewriter.notifyMatchFailure( + op, "normalized_shape must be a list of const int"); + } + double eps = 0; + if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { + return rewriter.notifyMatchFailure(op, + "non const float eps is unsupported"); + } + if (failed(checkNotNone(rewriter, op, weight)) || + failed(checkNotNone(rewriter, op, bias))) { + return op->emitError("none weight or bias is unsupported"); + } + auto weightTy = weight.getType().cast(); + auto biasTy = bias.getType().cast(); + + if (!inputTy.getElementType().isa() || + !biasTy.getElementType().isa() || + !weightTy.getElementType().isa()) { + return op->emitError("currently only float data type are supported"); + } + int64_t normalizedShapeRank = normalizedShape.size(); + if (weightTy.getRank() != normalizedShapeRank || + biasTy.getRank() != normalizedShapeRank || + inputRank < normalizedShapeRank || normalizedShapeRank < 1) { + return rewriter.notifyMatchFailure(op, "input or weight or bias shape or" + "normalized shape not compatible"); + } + for (int64_t i = 1; i <= normalizedShapeRank; i++) { + if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] || + weightTy.getShape()[normalizedShapeRank - i] != + normalizedShape[normalizedShapeRank - i] || + biasTy.getShape()[normalizedShapeRank - i] != + normalizedShape[normalizedShapeRank - i]) { + return op.emitError("mismatching contracting dimension"); + } + } + + // Flatten dims to fit batch_norm operation. 
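+  // E.g. (illustrative) an input of shape [2, 3, 4, 5] with
+  // normalized_shape = [4, 5] is reshaped to [1, 6, 20], where
+  // numFeatureDimSize = 2 * 3 and numEmbeddingDimSize = 4 * 5.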
+ int64_t numFeatureDimSize = 1; + int64_t numEmbeddingDimSize = 1; + for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { + numFeatureDimSize *= inputShape[i]; + } + for (int64_t i = 0; i < normalizedShapeRank; i++) { + numEmbeddingDimSize *= normalizedShape[i]; + } + SmallVector inputFlattenShape{1, numFeatureDimSize, + numEmbeddingDimSize}; + SmallVector meanOrVarMhloOutShape{numFeatureDimSize}; + + auto mhloBatchNormOutTy = + RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); + auto mhloBathNormOutMeanOrVarTy = + RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); + + // Reshape input + auto mhloInput = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, input, + mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), + {static_cast(inputFlattenShape.size())}) + .value()); + + // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. + SmallVector zeroConstVec( + numFeatureDimSize, APFloat::getZero(inputTy.getElementType() + .cast() + .getFloatSemantics())); + SmallVector oneConstVec( + numFeatureDimSize, + APFloat( + inputTy.getElementType().cast().getFloatSemantics(), + 1)); + auto oneOrZeroConstType = + RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); + + Value scale = rewriter.create( + op->getLoc(), oneOrZeroConstType, + DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); + Value offset = rewriter.create( + op->getLoc(), oneOrZeroConstType, + DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); + auto batchNormTrainingResult = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, + mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + + // Reshape back + auto outputTy = + getTypeConverter()->convertType(op.getType(0)).cast(); + auto outputMeanOrVarTy = + getTypeConverter()->convertType(op.getType(1)).cast(); + + auto output = rewriter.create( + op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), + mhlo::getConstTensor(rewriter, op, outputTy.getShape(), + {static_cast(outputTy.getShape().size())}) + .value()); + auto mean = rewriter.create( + op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), + mhlo::getConstTensor( + rewriter, op, outputMeanOrVarTy.getShape(), + {static_cast(outputMeanOrVarTy.getShape().size())}) + .value()); + auto var = rewriter.create( + op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), + mhlo::getConstTensor( + rewriter, op, outputMeanOrVarTy.getShape(), + {static_cast(outputMeanOrVarTy.getShape().size())}) + .value()); + + // Apply affine transform: output x weight + bias [element-wise] + auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); + auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); + auto outputMulWeight = + rewriter.create(op->getLoc(), output, bcastedWeight); + auto finalOuput = + rewriter.create(op->getLoc(), outputMulWeight, bcastedBias); + rewriter.replaceOp(op, {finalOuput, mean, var}); + return success(); +} + + +// AtenCatOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenCatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto outType = + getTypeConverter()->convertType(op.getType()).cast(); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure(op, + "only constant dim param is supported"); + } + + SmallVector torchTensors; + if 
(!getListConstructElements(op.tensors(), torchTensors)) { + return rewriter.notifyMatchFailure( + op, "input should comes from a PrimListConstructOp"); + } + SmallVector builtinTensors = getTypeConvertedValues( + rewriter, op->getLoc(), getTypeConverter(), torchTensors); + + // Promote type + for (auto &v : builtinTensors) { + v = mhlo::promoteType(rewriter, v, outType); + } + + size_t posDim = toPositiveDim(dim, outType.getRank()); + rewriter.replaceOpWithNewOp( + op, outType, ValueRange(builtinTensors), posDim); + return success(); +} + +// AtenNumelOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenNumelOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().dyn_cast(); + size_t rank = selfTy.getRank(); + + Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); + auto loc = op->getLoc(); + Value numel = + rewriter.create(loc, rewriter.getIntegerAttr(intType, 1)); + for (size_t d = 0 ; d < rank; ++ d) { + Value dimSize = rewriter.create( + loc, intType, rewriter.create(loc, self, d)); + numel = rewriter.create(loc, numel, dimSize); + } + + auto outTy = getTypeConverter()->convertType(op.getType()); + if (outTy != numel.getType()) { + rewriter.replaceOpWithNewOp( + op, outTy, numel); + } else { + rewriter.replaceOp(op, numel); + } + return success(); +} + +void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToMhloOptions &options) { + MLIRContext *context = patterns.getContext(); + + target.addIllegalOp(); + patterns.add(typeConverter, context); + +#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context) + INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); + INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); + INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp); +#undef INSERT_UNARY_FPONLY_PATTERN + +#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context) + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); +#undef INSERT_CONSTANT_FILL_PATTERN + +#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp); +#undef INSERT_BINARY_ADDSUB_PATTERN + +#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context) + INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); + INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); + INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); + INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); + INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); +#undef INSERT_BINARY_MULDIV_PATTERN + +#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, 
context) + + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp); +#undef INSERT_BINARY_COMPARE_PATTERN + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + + INSERT_ATENOP_PATTERN(AtenTanhOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReciprocalOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenErfOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenNumelOp); +#undef INSERT_ATENOP_PATTERN +} diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt new file mode 100644 index 000000000000..39d956fddb17 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -0,0 +1,33 @@ +add_mlir_conversion_library(TorchMLIRTorchToMhlo + TorchToMhlo.cpp + MhloLegalizeUtils.cpp + Basic.cpp + Gather.cpp + Linear.cpp + ViewLike.cpp + Reduction.cpp + Pooling.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo + + DEPENDS + MhloDialect + MhloToLinalg + MLIRMhloPassIncGen + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + ChloOps + MLIRIR + MLIRPass + MhloDialect + MhloToLinalg + StablehloBase + TorchMLIRTorchDialect +) + +torch_mlir_target_includes(TorchMLIRTorchToMhlo) diff --git a/lib/Conversion/TorchToMhlo/Gather.cpp b/lib/Conversion/TorchToMhlo/Gather.cpp new file mode 100644 index 000000000000..a1185c2c14cf --- /dev/null +++ b/lib/Conversion/TorchToMhlo/Gather.cpp @@ -0,0 +1,180 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; + +namespace { +Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, + Value input, Value indices, int64_t axis, + size_t dimSizeIndexBits) { + auto loc = op->getLoc(); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); + Value one = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + + // sliceSizes + auto inputRankTy = input.getType().dyn_cast(); + auto inputRank = inputRankTy.getRank(); + SmallVector sliceSizes; + sliceSizes.reserve(inputRank); + for (int64_t r = 0; r < inputRank; ++r) { + if (r == axis) { + sliceSizes.push_back(one); + } else { + sliceSizes.push_back(rewriter.create( + loc, intType, rewriter.create(loc, input, r))); + } + } + auto sliceSizesTensor = + rewriter.create(loc, sliceSizes); + + // offsetDims + SmallVector offsetDims; + offsetDims.reserve(inputRank); + for (int64_t r = 0; r < axis; ++r) { + offsetDims.push_back(r); + } + auto indicesRankTy = indices.getType().dyn_cast(); + auto indicesRank = indicesRankTy.getRank(); + for (int64_t r = axis + 1; r < inputRank; ++r) { + offsetDims.push_back(r + indicesRank - 1); + } + + // collapsedSliceDims + SmallVector collapsedSliceDims(1, axis); + // startIndexMap + SmallVector startIndexMap(1, axis); + // indexVecDim + int64_t indexVecDim = indicesRank; + auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/offsetDims, + /*collapsedSliceDims=*/collapsedSliceDims, + /*startIndexMap=*/startIndexMap, + /*indexVecDim=*/indexVecDim); + + // outputShape = input.shape[:axis] + indices.shape + + // input.shape[axis + 1:] + auto inputShape = inputRankTy.getShape(); + auto indicesShape = indicesRankTy.getShape(); + SmallVector outputShape(inputShape.begin(), + inputShape.begin() + axis); + outputShape.insert(outputShape.end(), indicesShape.begin(), + indicesShape.end()); + outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1, + inputShape.end()); + + // create output tensor type + auto outputTy = + RankedTensorType::get(outputShape, inputRankTy.getElementType()); + return rewriter + .create(loc, outputTy, input, indices, + sliceSizesTensor, dimsAttr) + .getResult(); +} +} // namespace + +// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html +// padding_idx (int, optional) +// – If specified, the entries at padding_idx do not contribute to the gradient; +// therefore, the embedding vector at padding_idx is not updated during training, +// i.e. it remains as a fixed “pad”. +// scale_grad_by_freq (boolean, optional) +// – If given, this will scale gradients by the inverse of frequency of the +// words in the mini-batch. Default False. +// sparse (bool, optional) +// – If True, gradient w.r.t. weight matrix will be a sparse tensor. 
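+//
+// The lowering below implements the embedding lookup as a gather along axis 0
+// of the weight matrix, e.g. (illustrative) weight : tensor<10x3xf32> gathered
+// with indices : tensor<2x4xi64> yields tensor<2x4x3xf32>.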
+template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenEmbeddingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto weight = adaptor.weight(); + auto weightTy = weight.getType().template cast(); + if (!weightTy) + return op.emitError("only ranked tensor types are supported"); + + int64_t padding_idx; + if (!matchPattern(op.padding_idx(), m_TorchConstantInt(&padding_idx))) + return rewriter.notifyMatchFailure( + op, "only constant padding_idx is currently supported"); + + bool scale_grad_by_freq; + if (!matchPattern(op.scale_grad_by_freq(), + m_TorchConstantBool(&scale_grad_by_freq))) + return rewriter.notifyMatchFailure( + op, "only constant scale_grad_by_freq is currently supported"); + if (scale_grad_by_freq) + return rewriter.notifyMatchFailure( + op, "scale gradients is currently not supported"); + bool sparse; + if (!matchPattern(op.sparse(), m_TorchConstantBool(&sparse))) + return rewriter.notifyMatchFailure( + op, "only constant sparse is currently supported"); + if (sparse) + return rewriter.notifyMatchFailure( + op, "sparse gradients is currently not supported"); + + Value output = gatherTensorAlongSingleAxis( + rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output); + + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexSelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + + Value output = gatherTensorAlongSingleAxis( + rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output); + + return success(); +} + +void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToMhloOptions &options) { + MLIRContext *context = patterns.getContext(); + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); +#undef INSERT_ATENOP_PATTERN +} diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp new file mode 100644 index 000000000000..2d446a1f15ba --- /dev/null +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -0,0 +1,830 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; + +namespace { +Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, + ArrayRef shape, ArrayRef dimSizes, + ArrayRef broadcastDims) { + auto tensorTy = tensor.getType().dyn_cast(); + auto loc = op->getLoc(); + Value mhloShape = rewriter.create(loc, dimSizes); + + RankedTensorType outTy = + RankedTensorType::get(shape, tensorTy.getElementType()); + + RankedTensorType attrTy = + RankedTensorType::get({static_cast(broadcastDims.size())}, + rewriter.getIntegerType(64)); + auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); + + auto broadcast = rewriter.create( + loc, outTy, tensor, mhloShape, broadcastAttr); + return broadcast; +} + +Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, + ArrayRef inpTransDims) { + auto inputTy = input.getType().dyn_cast(); + auto rank = inputTy.getRank(); + auto transDims = mhlo::toPositiveDims(inpTransDims, rank); + auto inpShape = inputTy.getShape(); + std::vector newShape; + newShape.reserve(rank); + + for (auto d : transDims) { + newShape.push_back(inpShape[d]); + } + + auto attrTy = RankedTensorType::get({static_cast(transDims.size())}, + rewriter.getIntegerType(64)); + auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); + + auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); + auto result = rewriter.create(op->getLoc(), outTy, input, + permuteAttr); + return result.getResult(); +} + +RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, + Value &lhs, Value &rhs, + int64_t lhsResultDim, int64_t rhsResultDim, + int64_t lhsContractingDim, + int64_t rhsContractingDim) { + auto lhsTy = lhs.getType().dyn_cast(); + auto rhsTy = rhs.getType().dyn_cast(); + + auto oldLhsShape = lhsTy.getShape(); + auto oldRhsShape = rhsTy.getShape(); + SmallVector lhsShape; + SmallVector rhsShape; + lhsShape.append(oldLhsShape.begin(), oldLhsShape.end()); + rhsShape.append(oldRhsShape.begin(), oldRhsShape.end()); + auto lhsContractingDimSize = lhsShape[lhsContractingDim]; + auto rhsContractingDimSize = rhsShape[rhsContractingDim]; + if (lhsContractingDimSize != rhsContractingDimSize) { + if (lhsContractingDimSize == ShapedType::kDynamicSize && + rhsContractingDimSize >= 0) { + lhsShape[lhsContractingDim] = rhsContractingDimSize; + auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); + lhs = rewriter.create(op->getLoc(), newRankTy, lhs); + } else if (rhsContractingDimSize == ShapedType::kDynamicSize && + lhsContractingDimSize >= 0) { + rhsShape[rhsContractingDim] = lhsContractingDimSize; + auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); + rhs = rewriter.create(op->getLoc(), 
newRankTy, rhs); + } + } + SmallVector outShape; + // set batch dims, will skip invalid dimensions + for (int k = 0; k < lhsShape.size(); ++k) { + if (k == lhsResultDim || k == lhsContractingDim) + continue; + outShape.push_back(lhsShape[k]); + } + for (int k = 0, b = 0; k < rhsShape.size(); ++k) { + if (b >= outShape.size()) + break; + if (k == rhsResultDim || k == rhsContractingDim) + continue; + if (outShape[b] == ShapedType::kDynamicSize && rhsShape[k] >= 0) { + outShape[b] = rhsShape[k]; + } + b++; + } + + // set result dimensions + if (lhsResultDim < lhsShape.size() && lhsResultDim >= 0) { + outShape.push_back(lhsShape[lhsResultDim]); + } + if (rhsResultDim < rhsShape.size() && rhsResultDim >= 0) { + outShape.push_back(rhsShape[rhsResultDim]); + } + return RankedTensorType::get(outShape, lhsTy.getElementType()); +} + +void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, + Value &inpRhs, int64_t leadingRank, + size_t dimSizeIndexBits) { + Value lhs = inpLhs; + Value rhs = inpRhs; + auto lhsRankTy = inpLhs.getType().dyn_cast(); + auto rhsRankTy = inpRhs.getType().dyn_cast(); + + auto lhsRank = lhsRankTy.getRank(); + auto rhsRank = rhsRankTy.getRank(); + + // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be + // broadcastable). + auto minRank = std::min(lhsRank, rhsRank); + auto leadingDims = llvm::to_vector<4>(llvm::seq(0, leadingRank)); + auto broadcastDims = llvm::to_vector<4>( + llvm::seq(leadingRank, minRank + leadingRank)); + auto lhsShape = lhsRankTy.getShape(); + auto rhsShape = rhsRankTy.getShape(); + if (lhsRank < rhsRank) { + std::vector newShape(rhsShape.begin(), + rhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, rhs, leadingDims, dimSizeIndexBits); + auto lhsDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); + newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), + lhsDimSizes.end()); + lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, + broadcastDims); + } else { + std::vector newShape(lhsShape.begin(), + lhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, lhs, leadingDims, dimSizeIndexBits); + auto rhsDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), + rhsDimSizes.end()); + rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, + broadcastDims); + } + + inpLhs = lhs; + inpRhs = rhs; +} + +// Perform the basic n-dim matmul operation encompassing the handling of +// broadcasting and dynamic shape propagation. +// All PyTorch ops that leverage matrix multiplication will derive this and +// implement their specialized input processing (e.g transpose), and output +// processing, e.g. GEMM or fully connected bias handling. +template +class ConvertAtenMatmulBaseOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + // Each variant must implement corresponding parameter parsing options. + // Maintain separate input read functions for each variant because it is not + // necessarily true with all variants that the first two operands are the lhs + // and rhs. 
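+  // For example, aten.mm reads (self, mat2) while aten.linear reads
+  // (input, weight); see the subclasses below.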
+ virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const { + return rewriter.notifyMatchFailure( + op, + "unimplemented matrix multiplication variant input parsing function"); + } + LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &lhs, + Value &rhs, Value &output) const { + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + auto lhsElemTy = lhsTy.getElementType(); + auto rhsElemTy = rhsTy.getElementType(); + + if (lhsElemTy != rhsElemTy) + return op.emitError("matmul: input datatypes mismatched"); + if (lhsRank < 1 || rhsRank < 1) { + return op.emitError("matmul: inputs can't be 0-rank"); + } + + if (lhsRank <= 2 && rhsRank <= 2) { + output = rewriter.create(op->getLoc(), lhs, rhs, nullptr); + return success(); + } + + const auto &options = ConvertAtenOp::getOptions(); + int64_t nBatchDims; + if (rhsRank <= 2) { + auto leadingRank = lhsRank - 2; + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); + nBatchDims = leadingRank; + } else if (lhsRank <= 2) { + auto leadingRank = rhsRank - 2; + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); + nBatchDims = leadingRank; + } else { + assert(rhsRank > 2 && lhsRank > 2); + auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank); + nBatchDims = std::max(lhsRank - 2, rhsRank - 2); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); + } + auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; + auto lhsContractingDim = nBatchDims + 1; + auto rhsContractingDim = nBatchDims; + if (lhsRank == 1) { + lhsResultDim = nBatchDims + 1; + lhsContractingDim = nBatchDims; + } + + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhsBatchingDimensions=*/batchDims, + /*rhsBatchingDimensions=*/batchDims, + /*lhsContractingDimensions=*/{lhsContractingDim}, + /*rhsContractingDimensions=*/{rhsContractingDim}); + auto outTy = + castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, + lhsContractingDim, rhsContractingDim); + output = rewriter + .create(op->getLoc(), outTy, lhs, rhs, + dotDimensionNumbers, nullptr) + .getResult(); + return success(); + } + + // The default version just reads two inputs, computes output and returns it. + // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc. + virtual LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs, rhs; + if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) + return op.emitError("failed to read matmul inputs"); + + Value output; + + if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output))) + return op.emitError("failed to perform matmul operation"); + + rewriter.replaceOpWithNewOp( + op, + ConvertAtenOp::getTypeConverter() + ->convertType(op.getType()) + .template cast(), + output); + + return success(); + } +}; + +// Legalizes the torch.matmul op for general n-dim matmul. 
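+// E.g. (illustrative) for lhs : tensor<2x3x4xf32> and rhs : tensor<4x5xf32>,
+// getBmmBroadcast() first broadcasts rhs to tensor<2x4x5xf32>, and the
+// resulting mhlo.dot_general (batch dim 0, contracting dims 2 and 1) yields
+// tensor<2x3x5xf32>.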
+template +class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + + rhs = adaptor.other(); + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + return success(); + } +}; + +// Implements handling of aten.mm and aten.bmm ops. +template +class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + + rhs = adaptor.mat2(); + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + if (isa(op)) { + // Mm takes two 2D tensors. + if (lhsRank != 2 || rhsRank != 2) + return op.emitError("aten.mm called but matrix rank != 2"); + } else if (isa(op)) { + // Bmm takes two 3D tensors. + if (lhsRank != 3 || rhsRank != 3) + return op.emitError("aten.bmm called but matrix rank != 3"); + } + + return success(); + } +}; + +// Implements handling of aten.linear op. +template +class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.input(); + auto lhsTy = lhs.getType().cast(); + + rhs = adaptor.weight(); + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + if (lhsRank != 2 && lhsRank != 3) + return op.emitError("aten.Linear called but input rank not 2 or 3"); + if (rhsRank != 2 && rhsRank != 3) + return op.emitError("aten.Linear called but weight rank not 2 or 3"); + + return success(); + } + // Override the default rewriter to perform RHS transpose and bias addition + // as well. + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs, rhs; + + if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) + return op.emitError("failed to read matmul inputs"); + + // The aten.Linear op has a bias tensor that is added to the matmul + // output. + auto bias = adaptor.bias(); + auto biasTy = bias.getType(); + + // MHLO does not mandate that elementwise op tensors need to be ranked. 
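+    // This pattern nonetheless accepts only a ranked tensor (or none) here,
+    // so that the bias broadcast below can reason about shapes.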
+ if (!biasTy.template isa() && + !biasTy.template isa()) + return op.emitError("only ranked tensor types are supported in MHLO " + "matmul for bias tensor"); + + // weight.T + rhs = getPermutedTensor(rewriter, op, rhs, {1, 0}); + + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), + rhsTy.getRank() - lhsTy.getRank()); + + const auto &options = ConvertAtenOp::getOptions(); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); + auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); + auto nBatchDims = resultRank - 2; + auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; + auto lhsContractingDim = nBatchDims + 1; + auto rhsContractingDim = nBatchDims; + + auto outTy = + castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, + lhsContractingDim, rhsContractingDim); + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhsBatchingDimensions=*/batchDims, + /*rhsBatchingDimensions=*/batchDims, + /*lhsContractingDimensions=*/{lhsContractingDim}, + /*rhsContractingDimensions=*/{rhsContractingDim}); + Value matmulOutput = rewriter.create( + op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); + + Value matmulPlusBias = matmulOutput; + if (!biasTy.template isa()) { + // Bias addition broadcasts to the matmul output shape. + matmulPlusBias = rewriter + .create( + op->getLoc(), outTy, matmulOutput, bias, nullptr) + .getResult(); + } + + auto resultTy = + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); + return success(); + } +}; + +class ConvertAtenConvolutionOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenConvolutionOp::Adaptor; + + Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, + Value weight, int64_t groups) const { + auto weightTy = weight.getType().cast(); + auto weightElemTy = weightTy.getElementType(); + auto rank = weightTy.getRank(); + const auto &options = getOptions(); + SmallVector weightShapeVec = *mhlo::getDimSizesOfTensor( + rewriter, op, weight, options.dimSizeIndexBits); + auto weightShape = weightTy.getShape(); + SmallVector weightShapeInt(rank); + std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); + + // 1. [IC, OC, H, W, ...] => [G, IC//G, OC, H, W, ...] + Value GValue = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(groups)); + Value ICDivGValue = rewriter.create( + op->getLoc(), weightShapeVec[0], GValue); + Value OCMulGValue = rewriter.create( + op->getLoc(), weightShapeVec[1], GValue); + weightShapeVec[0] = ICDivGValue; + weightShapeVec.insert(weightShapeVec.begin(), GValue); + + if (weightShapeInt[0] == ShapedType::kDynamicSize) { + weightShapeInt.insert(weightShapeInt.begin(), groups); + } else { + weightShapeInt[0] /= groups; + weightShapeInt.insert(weightShapeInt.begin(), groups); + } + Value weightShapeTensor = rewriter.create( + op->getLoc(), weightShapeVec); + weight = rewriter.create( + op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), + weight, weightShapeTensor); + + // 2. [G, IC//G, OC, H, W, ...] => [IC//G, G, OC, H, W, ...] 
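+    // E.g. (illustrative) with groups = 2 and weight shape [6, 8, 3, 3]:
+    // step 1 gives [2, 3, 8, 3, 3], this transpose gives [3, 2, 8, 3, 3],
+    // and step 3 below flattens it to [3, 16, 3, 3].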
+ std::vector transposeDims(rank + 1); + for (int64_t i = 0; i <= rank; i++) + transposeDims[i] = i; + std::swap(transposeDims[1], transposeDims[0]); + weight = rewriter.create( + op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); + + // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] + weightShapeInt.erase(weightShapeInt.begin()); + if (weightShapeInt[1] != ShapedType::kDynamicSize) { + weightShapeInt[1] *= groups; + } + weightShapeVec.erase(weightShapeVec.begin()); + weightShapeVec[1] = OCMulGValue; + weightShapeTensor = rewriter.create( + op->getLoc(), weightShapeVec); + weight = rewriter.create( + op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), + weight, weightShapeTensor); + return weight; + } + + Value convertTransposedConv(AtenConvolutionOp op, + ConversionPatternRewriter &rewriter, + RankedTensorType outType, Value input, + Value weight, ArrayRef stride, + ArrayRef padding, + ArrayRef dilation, + ArrayRef outputPadding, int64_t groups, + bool needHandleOutputPadding) const { + auto inputTy = input.getType().cast(); + auto weightTy = weight.getType().cast(); + auto weightShape = weightTy.getShape(); + + auto nDims = inputTy.getRank(); + auto nSpatialDims = nDims - 2; + auto convOutTy = outType; + + if (needHandleOutputPadding) { + SmallVector outShape(nDims); + auto finalOutShape = outType.getShape(); + std::copy(finalOutShape.begin(), finalOutShape.end(), outShape.begin()); + for (int i = 2; i < nDims; ++i) { + if (finalOutShape[i] == ShapedType::kDynamicSize) + continue; + outShape[i] = finalOutShape[i] - outputPadding[i - 2]; + } + convOutTy = RankedTensorType::get(outShape, outType.getElementType()); + } + + // Prepare for transposed convolution + SmallVector mhloStrideVec(nSpatialDims, 1); + DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec); + SmallVector mhloPaddingVec(nSpatialDims * 2, 0); + for (int i = 0; i < nSpatialDims; ++i) { + int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; + mhloPaddingVec[i * 2] = padInt; + mhloPaddingVec[i * 2 + 1] = padInt; + } + DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( + RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), + mhloPaddingVec); + SmallVector mhloLhsDilationVec(nSpatialDims); + std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin()); + DenseIntElementsAttr mhloLhsDilation = + rewriter.getI64TensorAttr(mhloLhsDilationVec); + SmallVector mhloRhsDilationVec(nSpatialDims); + std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin()); + DenseIntElementsAttr mhloRhsDilation = + rewriter.getI64TensorAttr(mhloRhsDilationVec); + + DenseElementsAttr windowReversal; + ArrayAttr precisionConfig; + + SmallVector spatialDims; + for (int i = 0; i < nSpatialDims; ++i) { + spatialDims.push_back(i + 2); + } + mhlo::ConvDimensionNumbersAttr dimensionNumbers = + mhlo::ConvDimensionNumbersAttr::get( + /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, + /*inputFeatureDimension=*/1, + /*inputSpatialDimensions=*/spatialDims, + /*kernelInputFeatureDimension=*/0, + /*kernelOutputFeatureDimension=*/1, + /*kernelSpatialDimensions=*/spatialDims, + /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, + /*outputSpatialDimensions=*/spatialDims); + + // Reverse and transpose weight + weight = rewriter.create( + op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); + if (groups != 1) { + weight = reshapeConvWeight(rewriter, op, weight, groups); + } + + // Create transposed convolution + auto 
transposedConvOp = rewriter.create( + op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding, + mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, + static_cast(groups), 1, precisionConfig); + + // Handle output padding + if (!needHandleOutputPadding) { + return transposedConvOp.getResult(); + } + SmallVector edgePaddingLowVec(nDims, 0); + SmallVector edgePaddingHighVec(nDims, 0); + SmallVector interiorPaddingVec(nDims, 0); + std::copy(outputPadding.begin(), outputPadding.end(), + edgePaddingHighVec.begin() + 2); + Value paddingValue = + mhlo::getConstTensor(rewriter, op, {0.0}, {}).value(); + paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy); + mlir::DenseIntElementsAttr edgePaddingLow = + rewriter.getI64VectorAttr(edgePaddingLowVec); + mlir::DenseIntElementsAttr edgePaddingHigh = + rewriter.getI64VectorAttr(edgePaddingHighVec); + mlir::DenseIntElementsAttr interiorPadding = + rewriter.getI64VectorAttr(interiorPaddingVec); + + auto paddedOutput = rewriter.create( + op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, + edgePaddingHigh, interiorPadding); + + return paddedOutput.getResult(); + } + + Value convertNormalConv(AtenConvolutionOp op, + ConversionPatternRewriter &rewriter, + RankedTensorType outType, Value input, Value weight, + ArrayRef stride, ArrayRef padding, + ArrayRef dilation, int64_t groups) const { + int64_t nDims = outType.getRank(); + + // Get mhlo::ConvolutionOp attributes + DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(stride.size())}, + rewriter.getI64Type()), + stride); + std::vector mhloPaddingVec; + for (size_t i = 0; i < padding.size(); i++) { + mhloPaddingVec.emplace_back(padding[i]); + mhloPaddingVec.emplace_back(padding[i]); + } + DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(padding.size()), static_cast(2)}, + rewriter.getI64Type()), + mhloPaddingVec); + DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(dilation.size())}, + rewriter.getI64Type()), + dilation); + SmallVector spatialDimensions; + for (int64_t i = 2; i < nDims; i++) { + spatialDimensions.emplace_back(i); + } + mhlo::ConvDimensionNumbersAttr dimensionNumbers = + mhlo::ConvDimensionNumbersAttr::get( + /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, + /*inputFeatureDimension=*/1, + /*inputSpatialDimensions=*/spatialDimensions, + /*kernelInputFeatureDimension=*/1, + /*kernelOutputFeatureDimension=*/0, + /*kernelSpatialDimensions=*/spatialDimensions, + /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, + /*outputSpatialDimensions=*/spatialDimensions); + + // mhlo::ConvolutionOp's optional attributes, leave them as default + DenseIntElementsAttr mhloLhsDilation; + DenseElementsAttr windowReversal; + ArrayAttr precisionConfig; + + auto mhloConvOp = rewriter.create( + op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding, + mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, + static_cast(groups), 1, precisionConfig); + + return mhloConvOp.getResult(); + } + + LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.input(); + Value weight = adaptor.weight(); + + // The input shape is [N, C, H, W] + auto inputTy = input.getType().template cast(); + // The weight shape is [OC, (IC//G), KH, KW] + // If transposed is set to true, + // the weight 
shape changes to [IC, (OC//G), KH, KW] + auto weightTy = weight.getType().template cast(); + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template cast(); + if (!inputTy || !weightTy || !outTy) { + return op.emitError("input, weight and output must be ranked tensors"); + } + if (inputTy.getRank() < 3) + return op.emitError("only input with at least 3 dims valid"); + SmallVector stride; + if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) { + return rewriter.notifyMatchFailure(op, + "non-const stride list unsupported"); + } + SmallVector padding; + if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) { + return rewriter.notifyMatchFailure(op, + "non-const padding list unsupported"); + } + SmallVector dilation; + if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) { + return rewriter.notifyMatchFailure(op, + "non-const dilation list unsupported"); + } + SmallVector outputPadding; + if (!matchPattern(op.output_padding(), + m_TorchConstantIntList(outputPadding))) { + return rewriter.notifyMatchFailure( + op, "non-const output_padding list unsupported"); + } + int64_t groups; + if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) { + return rewriter.notifyMatchFailure(op, "non-int groups unsupported"); + } + bool transposed; + if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) { + return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported"); + } + // Whether need to handle outputpadding + bool needHandleOutputPadding = false; + for (int64_t i : outputPadding) { + if (i != 0) { + needHandleOutputPadding = true; + break; + } + } + // Op validation check + if (needHandleOutputPadding && !transposed) { + return op->emitError( + "output padding attr is valid only in transposed convolution"); + } + assert(padding.size() == dilation.size() && + padding.size() == stride.size() && + padding.size() == static_cast(inputTy.getRank()) - 2 && + inputTy.getRank() == weightTy.getRank()); + + auto nSpatialDims = padding.size(); + auto nDims = inputTy.getRank(); + + // Kernel size must be constant. 
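+    // (E.g. the transposed path below derives the MHLO padding from
+    // weightShape[i + 2], so dynamic kernel dims cannot be handled.)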
+    auto weightShape = weightTy.getShape();
+    for (int i = 2; i < nDims; ++i) {
+      if (weightShape[i] == ShapedType::kDynamicSize) {
+        return rewriter.notifyMatchFailure(
+            op, "only constant kernel size is supported");
+      }
+    }
+
+    Value mhloConvResult;
+    if (transposed) {
+      mhloConvResult = convertTransposedConv(
+          op, rewriter, outTy, input, weight, stride, padding, dilation,
+          outputPadding, groups, needHandleOutputPadding);
+    } else {
+      mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
+                                         stride, padding, dilation, groups);
+    }
+
+    auto bias = adaptor.bias();
+
+    // No bias provided
+    if (failed(checkNotNone(rewriter, op, op.bias()))) {
+      rewriter.replaceOp(op, mhloConvResult);
+      return success();
+    }
+
+    // Handle bias
+    if (!bias.getType().dyn_cast<RankedTensorType>()) {
+      return op.emitError("bias provided but not a ranked tensor");
+    }
+
+    auto biasTy = bias.getType().template cast<RankedTensorType>();
+    if (!biasTy.getElementType().isIntOrFloat()) {
+      return op.emitError("only floating-point or integer datatype "
+                          "legalization for bias supported");
+    }
+
+    assert(biasTy.getRank() <= 1);
+
+    // Reshape and promote bias
+    auto inputUnsqzDims =
+        llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
+
+    const auto &options = getOptions();
+    bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
+                                  options.dimSizeIndexBits);
+    bias = mhlo::promoteType(rewriter, bias, outTy);
+
+    DenseIntElementsAttr bcastDimensions;
+    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
+        op, outTy, mhloConvResult, bias, bcastDimensions);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToMhloOptions &options) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context, options)
+  INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
+#undef INSERT_MATMUL_ATENOP_PATTERN
+
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context, options)
+  INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+  INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATENOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context, options)
+  INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATENOP_PATTERN
+
+#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp)                              \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context, options)
+  INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
+#undef INSERT_CONVOLUTION_ATENOP_PATTERN
+}
diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
new file mode 100644
index 000000000000..5cab3e7d17a9
--- /dev/null
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -0,0 +1,367 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+// +//===----------------------------------------------------------------------===// + +#include "./MhloLegalizeUtils.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace mlir { +namespace mhlo { + +// Create a 32-bit float constant operator from a float +Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val) { + auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); + auto const_attr = DenseElementsAttr::get(const_type, val); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Create a 64-bit float constant operator from a double +Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val) { + auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); + auto const_attr = DenseElementsAttr::get(const_type, val); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Templated function to create a constant op for given type and shape. +// T: storage C type. +// Default template creates a constant tensor in T. +template +llvm::Optional getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + + auto const_type = + RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Template specialization for APInt +template <> +llvm::Optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + auto const_type = RankedTensorType::get( + shape, rewriter.getIntegerType(vec[0].getBitWidth())); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Template specialization for float +template <> +llvm::Optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + + auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +template <> +llvm::Optional +getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape) { + 
uint64_t num_total_elements = 1;
+  for (int64_t a : shape) {
+    num_total_elements *= a;
+  }
+
+  if (vec.size() != num_total_elements) {
+    op->emitOpError("getConstTensor(): number of elements mismatch.");
+    return llvm::None;
+  }
+
+  auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
+  auto const_attr = DenseElementsAttr::get(const_type, vec);
+
+  auto const_op =
+      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Template instantiation
+template llvm::Optional<Value>
+getConstTensor<int32_t>(PatternRewriter &, Operation *, ArrayRef<int32_t> vec,
+                        ArrayRef<int64_t> shape);
+
+template llvm::Optional<Value>
+getConstTensor<int64_t>(PatternRewriter &, Operation *, ArrayRef<int64_t> vec,
+                        ArrayRef<int64_t> shape);
+
+template <typename T>
+static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
+                           const int64_t &intValue) {
+  if (isFloat) {
+    // Do a round-trip check here instead of numeric limits due to
+    // compiler warnings around double <-> int conversion.
+    return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
+  } else {
+    assert(isInt);
+    return (intValue >= std::numeric_limits<T>::min()) &&
+           (intValue <= std::numeric_limits<T>::max());
+  }
+}
+
+template <typename T>
+Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                          T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
+  auto const_type = RankedTensorType::get(dshape, dtype);
+  auto const_attr = SplatElementsAttr::get(const_type, val);
+  auto const_op =
+      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype) {
+  auto tensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), ArrayRef<Value>{scalarValue});
+  auto dtype_tensor =
+      rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
+  return rewriter.create<mhlo::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
+      dtype_tensor);
+}
+
+Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
+  Operation *op = input.getDefiningOp();
+  TensorType in_type = input.getType().dyn_cast<TensorType>();
+
+  if (in_type.getElementType() != outType.getElementType()) {
+    TensorType promotedType =
+        in_type.cloneWith(in_type.getShape(), outType.getElementType());
+    return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
+  }
+  return input;
+}
+
+Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
+                          TensorType outType) {
+  // Two tensors are "broadcastable" if the following rules hold:
+  //   - Each tensor has at least one dimension.
+  //   - When iterating over the dimension sizes, starting at the trailing
+  //     dimension, the dimension sizes must either be equal, one of them
+  //     is 1, or one of them does not exist.
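+  // Worked example (hypothetical shapes): broadcasting a [4, 1] input to an
+  // outType of [3, 4, 5] matches dims from the right, records broadcast
+  // dimensions {1, 2}, and emits a single mhlo::BroadcastInDimOp.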
+  Operation *op = input.getDefiningOp();
+  TensorType in_type = input.getType().dyn_cast<TensorType>();
+
+  if (in_type.getElementType() != outType.getElementType()) {
+    TensorType promoted_type =
+        in_type.cloneWith(in_type.getShape(), outType.getElementType());
+    input =
+        rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
+  }
+
+  ArrayRef<int64_t> inShape = in_type.getShape();
+  ArrayRef<int64_t> outShape = outType.getShape();
+
+  bool do_bcast = (inShape.size() != outShape.size());
+  SmallVector<int64_t> bcastDims;
+  for (size_t i = 0; i < inShape.size(); ++i) {
+    // iterating over the dimension sizes, starting at the trailing dimension
+    size_t outPos = outShape.size() - 1 - i;
+    size_t inPos = inShape.size() - 1 - i;
+    int64_t outDim = outShape[outPos];
+    int64_t inDim = inShape[inPos];
+    if (inDim == outDim) {
+      bcastDims.push_back(outPos);
+    } else if (inDim != outDim && inDim == 1) {
+      bcastDims.push_back(outPos);
+      do_bcast = true;
+    } else {
+      op->emitError("the size of tensor a (")
+          << inDim << ") must match the size of tensor b (" << outDim
+          << ") at non-singleton dimension " << inPos;
+    }
+  }
+  std::reverse(bcastDims.begin(), bcastDims.end());
+  if (!do_bcast) {
+    return input;
+  }
+  DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<int64_t>(bcastDims.size())},
+                            rewriter.getI64Type()),
+      bcastDims);
+  auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(
+      op->getLoc(), outType, input, bcast_attr);
+  return bcast_op.getResult();
+}
+
+SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
+  SmallVector<size_t> posDims;
+  posDims.reserve(rank);
+  std::transform(
+      dims.begin(), dims.end(), std::back_inserter(posDims),
+      [rank](int64_t d) -> size_t { return toPositiveDim(d, rank); });
+  return posDims;
+}
+
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
+                    ArrayRef<int64_t> inpDims, size_t dimSizeIndexBits) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
+  }
+
+  auto rank = valueTy.getRank();
+  auto dims = toPositiveDims(inpDims, rank);
+  SmallVector<Value, 4> dimSizes;
+  dimSizes.reserve(dims.size());
+
+  auto loc = op->getLoc();
+  for (auto d : dims) {
+    dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(dimSizeIndexBits),
+        rewriter.create<tensor::DimOp>(loc, value, d)));
+  }
+  return dimSizes;
+}
+
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
+                    size_t dimSizeIndexBits) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
+  }
+
+  auto rank = valueTy.getRank();
+  // Get int vector [0, 1, ..., rank-1]
+  std::vector<int64_t> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
+}
+
+FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                 Value tensor,
+                                 ArrayRef<int64_t> inputUnsqzDims,
+                                 size_t dimSizeIndexBits) {
+  // Returns a new tensor with dims of size 1 inserted at the specified
+  // position.
+  //
+  // The position indices (must be high to low dimension number of the returned
+  // tensor) are specified with unsqzDims. Indices must be in-order, and in
+  // range of tensor rank. Thus, unsqueezing a rank 1 tensor with {0, 2},
+  // {0, 1, 3}, or {0, 1, 2} are all valid dimension sets, but {0, 3} and {2}
+  // are not.
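+  // E.g. unsqueezing a [3, 4] tensor with unsqzDims = {0, 2} yields a
+  // [1, 3, 1, 4] tensor (sizes are illustrative).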
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  auto rank = dimSizes.size();
+  size_t newRank = rank + inputUnsqzDims.size();
+  auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
+  for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k)
+    if (k > 0 && unsqzDims[k] <= unsqzDims[k - 1])
+      return rewriter.notifyMatchFailure(
+          op, "unsqueeze dimensions must be specified in order");
+
+  auto loc = op->getLoc();
+  auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+  auto one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+  for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
+    if (j < unsqzDims.size() && unsqzDims[j] == k) {
+      newDimSizes.push_back(one);
+      newShape.push_back(1);
+      j++;
+    } else {
+      newDimSizes.push_back(dimSizes[i]);
+      newShape.push_back(oldShape[i]);
+      i++;
+    }
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
+      .getResult();
+}
+
+Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
+                         const APFloat &constant, Value shape,
+                         TensorType outType) {
+  auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
+  auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
+  return rewriter
+      .create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
+                                             rewriter.getI64TensorAttr({}))
+      .getResult();
+}
+} // namespace mhlo
+} // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
new file mode 100644
index 000000000000..466cadb81cf7
--- /dev/null
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -0,0 +1,77 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
+#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace mhlo {
+
+using mlir::ConversionPatternRewriter;
+
+// Create a 32-bit float constant operator from a float
+Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
+                                  float val);
+
+// Create a 64-bit float constant operator from a double
+Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
+                                  double val);
+
+// Templated function to create a constant op for given type and shape.
+// T: storage C type.
+// Default template creates a constant tensor in T.
+// To create an INT48 MHLO constant, pass in llvm::APInt instead.
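+// Example call (illustrative only): materialize a zero-dim i64 zero tensor:
+//   auto zero = getConstTensor<int64_t>(rewriter, op, {0}, {});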
+template +llvm::Optional getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape); + +template +Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, + T val, Type dtype, llvm::ArrayRef dshape); + +Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, + Value scalarValue, Type dtype); + +Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); + +Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, + TensorType outType); + +SmallVector toPositiveDims(ArrayRef dims, int64_t rank); + +// Get the dimension sizes of the input tensor, given the dimension axes +FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + ArrayRef inpDims, + size_t dimSizeIndexBits); + +// Get the dimension sizes of the input tensor +FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + size_t dimSizeIndexBits); + +// Get a tensor that unsqueezed the specified dimensions of the input tensor +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, ArrayRef inputUnsqzDims, + size_t dimSizeIndexBits); + +Value getConstantOfShape(PatternRewriter &rewriter, Location loc, + const APFloat &constant, Value shape, + TensorType outType); +} // namespace mhlo +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp new file mode 100644 index 000000000000..514f941a434b --- /dev/null +++ b/lib/Conversion/TorchToMhlo/Pooling.cpp @@ -0,0 +1,542 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include +#include + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; + +static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, + PatternRewriter &rewriter) { + auto constType = RankedTensorType::get({}, elementTy); + // Avg pooling + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getZero( + elementTy.cast().getFloatSemantics(), + /*negative=*/false)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + + // Max pooling + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getLargest( + elementTy.cast().getFloatSemantics(), + /*negative=*/true)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, + {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + op->emitError("unimplemented lowering in AtenPoolingOp"); + return nullptr; +} + +// AtenMaxPool2dOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenMaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().cast(); + auto inputElemTy = inputTy.getElementType(); + + auto inputRank = inputTy.getRank(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + + if (inputRank <= 2) { + return op.emitError( + "max_pooling2d only supports inputs with rank higher than 2"); + } + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { + return 
rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + + // prepend 1 to kernelSize, stride, dilation until they are of same rank as + // input + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); + std::copy(dilation.begin(), dilation.end(), + mhloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + std::copy(kernelSize.begin(), kernelSize.end(), + mhloKernelSize.begin() + inputRank - 2); + + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; + + DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloKernelSize.size())}, + rewriter.getI64Type()), + mhloKernelSize); + DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloStride.size())}, + rewriter.getI64Type()), + mhloStride); + DenseIntElementsAttr baseDilations; + DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloDilation.size())}, + rewriter.getI64Type()), + mhloDilation); + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + mhloPadding); + auto reduceWindowOp = rewriter.create( + op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, + baseDilations, windowDilations, pad); + + Block &block = reduceWindowOp.body().emplaceBlock(); + + auto blockArgumentTy = RankedTensorType::get({}, inputElemTy); + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArg = block.args_begin(); + auto secondArg = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value result = + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), result); + } + + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); +} + +// AtenMaxPool2dWithIndicesOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().cast(); + auto inputElemTy = inputTy.getElementType(); + auto inputShape = inputTy.getShape(); + auto inputRank = inputTy.getRank(); + auto outValTy = + getTypeConverter()->convertType(op.getType(0)).cast(); + auto outIdxTy = + getTypeConverter()->convertType(op.getType(1)).cast(); + + if (inputRank <= 2) { + return op.emitError( + "max_pooling2d only supports inputs with rank higher than 2"); + } + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { + return rewriter.notifyMatchFailure(op, + 
"non-const int padding unsupported!"); + } + if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + + // prepend 1 to kernelSize, stride, dilation until they are of same rank as + // input + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); + std::copy(dilation.begin(), dilation.end(), + mhloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + std::copy(kernelSize.begin(), kernelSize.end(), + mhloKernelSize.begin() + inputRank - 2); + + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; + + DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloKernelSize.size())}, + rewriter.getI64Type()), + mhloKernelSize); + DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloStride.size())}, + rewriter.getI64Type()), + mhloStride); + DenseIntElementsAttr baseDilations; + DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloDilation.size())}, + rewriter.getI64Type()), + mhloDilation); + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + mhloPadding); + + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + + SmallVector initIndexShapeVec; + for (int64_t i = 0; i < inputRank - 2; i++) + initIndexShapeVec.push_back(inputShapeVec[i]); + initIndexShapeVec.push_back(rewriter.create( + op->getLoc(), inputShapeVec[inputRank - 1], + inputShapeVec[inputRank - 2])); + auto initIndexShapeTensor = rewriter.create( + op->getLoc(), initIndexShapeVec); + + SmallVector initIndexShapeForType(inputShape.begin(), + inputShape.end() - 2); + if (inputShape[inputRank - 1] == ShapedType::kDynamicSize || + inputShape[inputRank - 2] == ShapedType::kDynamicSize) { + initIndexShapeForType.push_back(ShapedType::kDynamicSize); + } else { + initIndexShapeForType.push_back(inputShape[inputRank - 1] * + inputShape[inputRank - 2]); + } + + auto initIndexTensor = + rewriter + .create( + op->getLoc(), + RankedTensorType::get(initIndexShapeForType, + rewriter.getI64Type()), + initIndexShapeTensor, static_cast(inputRank - 2)) + .getResult(); + + auto indexTensor = + rewriter + .create( + op->getLoc(), + RankedTensorType::get(inputShape, rewriter.getI64Type()), + initIndexTensor, inputShapeTensor) + .getResult(); + + Value initIdx = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + + auto reduceWindowOp = rewriter.create( + op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, + 
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, + windowDimensions, windowStrides, baseDilations, windowDilations, pad); + + Block &block = reduceWindowOp.body().emplaceBlock(); + + // Add bb argument + auto blockValArgumentType = RankedTensorType::get({}, inputElemTy); + auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); + auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + auto *firstValArg = block.args_begin(); + auto *firstIdxArg = std::next(firstValArg); + auto *secondValArg = std::next(firstIdxArg); + auto *secondIdxArg = std::next(secondValArg); + + mhlo::ComparisonTypeAttr compareTypeAttr; + if (inputTy.getElementType().isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::FLOAT); + } else if (inputTy.getElementType().isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::SIGNED); + } + mhlo::ComparisonDirectionAttr compareGeDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::GE); + mhlo::ComparisonDirectionAttr compareEqDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::EQ); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value compareGeResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + Value retValResult = rewriter.create( + op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + + // Get smaller index if compared values are equal. 
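+    // E.g. for window values [9, 9, 3], the GE comparison alone could keep
+    // either of the first two positions; taking min(idx) on equality
+    // deterministically keeps the first occurrence.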
+ Value compareEqResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareEqDirectionAttr, compareTypeAttr); + Value minIdx = + rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = rewriter.create( + op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = rewriter.create( + op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + rewriter.create( + op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + } + + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); +} + +// AtenAvgPool2dOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenAvgPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().cast(); + auto inputElemTy = inputTy.getElementType(); + auto inputRank = inputTy.getRank(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + auto outShape = outTy.getShape(); + + if (inputRank <= 2) { + return op.emitError( + "avg_pooling2d only supports inputs with rank higher than 2"); + } + SmallVector padding, kernelSize, stride; + bool ceilMode = false; + bool countIncludePad = true; + + if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + if (!(matchPattern(op.count_include_pad(), + m_TorchConstantBool(&countIncludePad)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool count_include_pad unsupported!"); + } + if (succeeded(checkNotNone(rewriter, op, op.divisor_override()))) { + return rewriter.notifyMatchFailure( + op, "only None divisor_override supported for now!"); + } + + // prepend 1 to kernelSize, stride, dilation until they are of same rank as + // input + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); + + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + std::copy(kernelSize.begin(), kernelSize.end(), + mhloKernelSize.begin() + inputRank - 2); + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; + + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloKernelSize.size())}, + rewriter.getI64Type()), + mhloKernelSize); + DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloStride.size())}, + rewriter.getI64Type()), + mhloStride); + DenseIntElementsAttr baseDilations; + DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(mhloDilation.size())}, + rewriter.getI64Type()), + mhloDilation); + DenseIntElementsAttr 
pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + mhloPadding); + + auto reduceWindowSum = rewriter.create( + op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, + baseDilations, windowDilations, pad); + + Block &sumBlock = reduceWindowSum.body().emplaceBlock(); + + // Add bb argument + auto blockArgumentType = RankedTensorType::get({}, inputElemTy); + sumBlock.addArgument(blockArgumentType, op->getLoc()); + sumBlock.addArgument(blockArgumentType, op->getLoc()); + auto *firstArg = sumBlock.args_begin(); + auto secondArg = sumBlock.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sumBlock); + + Value sumResult = + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); + } + + // Use kernel size as the divisor + if (countIncludePad) { + Value divisor = mhlo::getConstTensor( + rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) + .value(); + divisor = mhlo::promoteType(rewriter, divisor, outTy); + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); + return success(); + } + + // Use another mhlo.ReduceWindowOp to get the divisor + Value windowSizeConst = + mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); + windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); + const auto &options = getOptions(); + auto inputShapeVec = + *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + + windowSizeConst = rewriter.create( + op->getLoc(), + RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), + windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); + + Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + auto reduceWindowSize = rewriter.create( + op->getLoc(), RankedTensorType::get(outShape, inputElemTy), + windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, + windowDilations, pad); + + Block &sizeBlock = reduceWindowSize.body().emplaceBlock(); + + // Add bb argument + blockArgumentType = RankedTensorType::get({}, inputElemTy); + sizeBlock.addArgument(blockArgumentType, op->getLoc()); + sizeBlock.addArgument(blockArgumentType, op->getLoc()); + firstArg = sizeBlock.args_begin(); + secondArg = sizeBlock.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sizeBlock); + + Value sumResult = + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); + } + + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); + return success(); +} + +void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToMhloOptions &options) { + MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add>(typeConverter, context, options); + target.addIllegalOp(); + patterns.add>(typeConverter, context, options); + target.addIllegalOp(); + patterns.add>(typeConverter, + context, options); +} diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h new file mode 100644 index 000000000000..2e195a87fb77 --- /dev/null +++ 
b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -0,0 +1,74 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
+#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace torch {
+namespace torch_to_mhlo {
+
+struct TorchToMhloOptions {
+  bool enableStaticShape = false;
+  size_t dimSizeIndexBits = 64;
+};
+
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
+                const TorchToMhloOptions &options)
+      : OpConversionPattern<AtenOpT>(typeConverter, context) {
+    this->options = options;
+  }
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(op, "not implemented");
+  }
+  const TorchToMhloOptions &getOptions() const { return options; }
+
+private:
+  TorchToMhloOptions options;
+};
+
+void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns,
+                                        ConversionTarget &target,
+                                        const TorchToMhloOptions &options);
+void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns,
+                                           ConversionTarget &target,
+                                           const TorchToMhloOptions &options);
+void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         ConversionTarget &target,
+                                         const TorchToMhloOptions &options);
+void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
+                                            RewritePatternSet &patterns,
+                                            ConversionTarget &target,
+                                            const TorchToMhloOptions &options);
+void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         ConversionTarget &target,
+                                         const TorchToMhloOptions &options);
+
+void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
+                                          RewritePatternSet &patterns,
+                                          ConversionTarget &target,
+                                          const TorchToMhloOptions &options);
+
+} // namespace torch_to_mhlo
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp
new file mode 100644
index 000000000000..a9a1ee162e4d
--- /dev/null
+++ b/lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -0,0 +1,594 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; + +static Value createInitialValueForReduceOp(Operation *op, Type elementTy, + PatternRewriter &rewriter) { + auto constType = RankedTensorType::get({}, elementTy); + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getZero( + elementTy.cast().getFloatSemantics(), + /*negative=*/false)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + + if (isa(op)) { + if (elementTy.isa()) { + auto constAttr = DenseElementsAttr::get( + constType, {APFloat::getLargest( + elementTy.cast().getFloatSemantics(), + /*negative=*/true)}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } else if (elementTy.isa() && + elementTy.getIntOrFloatBitWidth() != 8) { + auto constAttr = DenseElementsAttr::get( + constType, + {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); + return rewriter.create(op->getLoc(), constType, + constAttr); + } + } + + op->emitError("unimplemented lowering in " + "createInitialValueForReduceOp"); + return nullptr; +} + +// Util for converting AtenArgmaxOp and AtenMaxDimOp +static llvm::Optional +getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, + ArrayRef inputShapeVec, int64_t dim, + size_t dimSizeIndexBits) { + auto inputTy = input.getType().template cast(); + if (!inputTy) { + return llvm::None; + } + if (!inputTy.getElementType().isIntOrFloat()) { + return llvm::None; + } + auto inputShape = inputTy.getShape(); + auto inputElemTy = inputTy.getElementType(); + + Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); + if (!initValue) return llvm::None; + Value initIndex; + if (dimSizeIndexBits == 32) { + initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + } else { + initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + } + + DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI64Type()), dim); + + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + auto indexTensor = rewriter.create( + op->getLoc(), + RankedTensorType::get(inputShape, + rewriter.getIntegerType(dimSizeIndexBits)), + inputShapeTensor, static_cast(dim)); + + auto mhloReduceOp = rewriter.create( + op->getLoc(), ValueRange{input, indexTensor}, + ValueRange{ + initValue, + initIndex, + }, + dimensions); + + Block &block = mhloReduceOp.body().emplaceBlock(); + + // Add block 
arguments + auto blockValArgumentType = + RankedTensorType::get({}, inputTy.getElementType()); + auto blockIdxArgumentType = + RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits)); + auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + + auto *firstValArg = block.args_begin(); + auto *firstIdxArg = std::next(firstValArg); + auto *secondValArg = std::next(firstIdxArg); + auto *secondIdxArg = std::next(secondValArg); + + mhlo::ComparisonTypeAttr compareTypeAttr; + if (inputTy.getElementType().isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::FLOAT); + } else if (inputTy.getElementType().isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::SIGNED); + } + mhlo::ComparisonDirectionAttr compareGeDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::GE); + mhlo::ComparisonDirectionAttr compareEqDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::EQ); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value compareGeResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + Value retValResult = rewriter.create( + op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + + // get smaller index value if compared nums are equal. + Value compareEqResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareEqDirectionAttr, compareTypeAttr); + Value minIdx = + rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = rewriter.create( + op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = rewriter.create( + op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + rewriter.create( + op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + } + return mhloReduceOp.getResults(); +} + +namespace { +template +class ConvertAtenReductionOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace + +// AtenArgmaxOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenArgmaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().template cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + } + + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + // Currently, (u)int8 dtype is not supported! 
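+  // (This mirrors createInitialValueForReduceOp above, which deliberately
+  // produces no init value for 8-bit integer element types.)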
+ if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenArgmaxOp to MHLO"); + } + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + if (!isValidDim(dim, inputTy.getRank())) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + + bool keepDim = false; + if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; + auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); + + if (keepDim) { + auto outShapeVec = inputShapeVec; + outShapeVec[dim] = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + + auto outShapeTensor = rewriter.create( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), mhloReduceResults[1], + outShapeTensor); + return success(); + } + + rewriter.replaceOp(op, mhloReduceResults[1]); + return success(); +} +} // namespace + +// AtenMaxDimOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenMaxDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + // Currently, (u)int8 dtype is not supported + if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenMaxDimOp to MHLO"); + } + + RankedTensorType valResultType = getTypeConverter() + ->convertType(op.getResult(0).getType()) + .template cast(); + RankedTensorType idxResultType = getTypeConverter() + ->convertType(op.getResult(1).getType()) + .template cast(); + Type idxElementType = idxResultType.getElementType(); + if (!idxElementType.isa()) { + return op.emitError("Aten.max.dim needs integer-like result"); + } + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + if (!isValidDim(dim, inputTy.getRank())) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + bool keepDim = false; + if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension 
+  if (failed(inputShapeInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  }
+  auto inputShapeVec = *inputShapeInfo;
+  auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
+                                       dim, options.dimSizeIndexBits)
+                               .value();
+
+  if (keepDim) {
+    auto outShapeVec = inputShapeVec;
+    outShapeVec[dim] = rewriter.create<arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    auto outShapeTensor = rewriter.create<tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+
+    auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
+        op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
+    auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
+        op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
+    rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
+    return success();
+  }
+
+  rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
+  return success();
+}
+} // namespace
+
+// AtenSumOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
+    AtenSumOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in MHLO");
+  }
+  auto dtype = adaptor.dtype();
+  if (!dtype.getType().isa<Torch::NoneType>()) {
+    auto dstElemTy = getTypeConverter()
+                         ->convertType(op.getType())
+                         .template dyn_cast<RankedTensorType>()
+                         .getElementType();
+    input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported.
+  if (inputElemTy.isa<mlir::IntegerType>() &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenSumOp to MHLO");
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+
+  llvm::sort(dims.begin(), dims.end());
+  auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
+      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+
+  Block &block = mhloReduceOp.body().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value addResult = rewriter.create<mhlo::AddOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
+  }
+
+  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  return success();
+}
+} // namespace
+
+// AtenMaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
+    AtenMaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in MHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
supported"); + } + // Currently, (u)int8 dtype is not supported + if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenMaxOp to MHLO"); + } + + SmallVector dims; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + dims.push_back(i); + } + + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) return failure(); + llvm::sort(dims.begin(), dims.end()); + auto mhloReduceOp = rewriter.create( + op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + + Block &block = mhloReduceOp.body().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value maxResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), maxResult); + } + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + mhloReduceOp.getResults()); + return success(); +} +} // namespace + +// AtenSumDimIntListOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenSumDimIntListOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + } + auto dtype = adaptor.dtype(); + if (!dtype.getType().isa()) { + auto dstElemTy = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast() + .getElementType(); + input = rewriter.create(op->getLoc(), input, dstElemTy); + inputTy = input.getType().dyn_cast(); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + // Currently, (u)int8 dtype is not supported + if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenSumDimIntListOp to MHLO"); + } + + SmallVector inputDims; + SmallVector dims; + if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) { + return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); + } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } + + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } + } + + bool keepDim = false; + if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) return failure(); + + llvm::sort(dims.begin(), dims.end()); + auto mhloReduceOp = rewriter.create( + op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + + Region ®ion = mhloReduceOp.body(); + Block &block = region.emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); 
+ + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value addResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), addResult); + } + + if (keepDim) { + const auto &options = getOptions(); + auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + for (int64_t i : dims) { + outShapeVec[i] = one; + } + auto outShapeTensor = rewriter.create( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + mhloReduceOp.getResult(0), outShapeTensor); + return success(); + } + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + mhloReduceOp.getResults()); + return success(); +} +} // namespace + +void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToMhloOptions &options) { + MLIRContext *context = patterns.getContext(); +#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); +#undef INSERT_ATEN_REDUCTION_OP_PATTERN +} diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp new file mode 100644 index 000000000000..67ff28c39ba3 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertTorchToMhlo : public ConvertTorchToMhloBase { +public: + ConvertTorchToMhlo() = default; + ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) { + this->enableStaticShape = enableStaticShape; + this->enableI32Index = enableI32Index; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + + torch_to_mhlo::TorchToMhloOptions options{enableStaticShape, + enableI32Index ? 32u : 64u}; + torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populateViewLikeOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populateReductionOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns, + target, options); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToMhloPass() { + return std::make_unique(false, false); +} + +std::unique_ptr> +mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape, + bool enableI32Index) { + return std::make_unique(enableStaticShape, + enableI32Index); +} diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp new file mode 100644 index 000000000000..4d2ea6946883 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp @@ -0,0 +1,426 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./MhloLegalizeUtils.h"
+#include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
+using namespace mlir::torch::torch_to_mhlo;
+
+namespace {
+// A dimension index coming from the Torch dialect may lie outside the range
+// [0, dimSize]. This function normalizes the input index into that range.
+Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
+                                   Value index, Value dimSize) {
+  auto loc = op->getLoc();
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+
+  // Normalize the index into the range [-dimSize, dimSize]:
+  //   index = min(max(-dimSize, index), dimSize)
+  auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
+  index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
+  index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
+
+  auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
+  auto indexPositive = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, zero);
+  // Get the positive index: (index >= 0) ? index : index + dimSize
+  return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
+                                          dimSizePlusIndex);
+}
+
+Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
+                              Type outTy, Value input, Value startIndex,
+                              Value endIndex, Value step, size_t dimIndex,
+                              ArrayRef<Value> dimSizes,
+                              size_t dimSizeIndexBits) {
+  auto loc = op->getLoc();
+  // startIndex & endIndex have already been normalized into the range
+  // [0, dimSize].
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 0));
+  Value one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  SmallVector<Value, 4> startIndices;
+  SmallVector<Value, 4> endIndices;
+  SmallVector<Value, 4> strides;
+
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  size_t rank = inputTy.getRank();
+  startIndices.reserve(rank);
+  endIndices.reserve(rank);
+  strides.reserve(rank);
+
+  auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, endIndex, zero);
+  endIndex = rewriter.create<arith::SelectOp>(loc, endIndexIsZero,
+                                              dimSizes[dimIndex], endIndex);
+
+  for (size_t r = 0; r < rank; ++r) {
+    if (r == dimIndex) {
+      startIndices.push_back(startIndex);
+      endIndices.push_back(endIndex);
+      strides.push_back(step);
+    } else {
+      startIndices.push_back(zero);
+      endIndices.push_back(dimSizes[r]);
+      strides.push_back(one);
+    }
+  }
+
+  auto startTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, startIndices).getResult();
+  auto endTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, endIndices).getResult();
+  auto stridesTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
+
+  return rewriter.create<mhlo::RealDynamicSliceOp>(
+      loc, outTy, input, startTensor, endTensor, stridesTensor);
+}
+
+// Gets a dynamic slice of the tensor from startIndex to endIndex with stride
+// step on the specified dimension. startIndex (defaults to 0), endIndex
+// (defaults to dimSize), and step (defaults to 1) are all optional.
+FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
+                                 Type outTy, Value input,
+                                 llvm::Optional<Value> startIndexOpt,
+                                 llvm::Optional<Value> endIndexOpt,
+                                 llvm::Optional<Value> stepOpt, int64_t dim,
+                                 size_t dimSizeIndexBits) {
+  auto loc = op->getLoc();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  auto rank = inputTy.getRank();
+
+  dim = (dim + rank) % rank;
+  Value dimSize = rewriter.create<arith::IndexCastOp>(
+      loc, rewriter.getI64Type(),
+      rewriter.create<tensor::DimOp>(loc, input, dim));
+
+  Value normStartIndex =
+      startIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
+          : rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+  Value normEndIndex =
+      endIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
+          : dimSize;
+  Value step =
+      stepOpt ? *stepOpt
+              : rewriter.create<arith::ConstantOp>(
+                    loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+
+  if (dimSizeIndexBits == 32) {
+    Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+    normStartIndex =
+        rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
+    normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
+    step = rewriter.create<arith::TruncIOp>(loc, intType, step);
+  }
+  FailureOr<SmallVector<Value, 4>> dimSizesInfo =
+      mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
+                                 normEndIndex, step, dim, dimSizes,
+                                 dimSizeIndexBits);
+}
+
+// This defines a template to construct ops whose legalizations are
+// specialized.
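Aside (illustrative, not part of the patch): the two helpers above implement Python-style slice-index handling: clamp the index into [-dimSize, dimSize], then wrap negative values by adding dimSize. The same arithmetic as a runnable scalar sketch (function name hypothetical):

#include <algorithm>
#include <cassert>
#include <cstdint>

static int64_t normalizeDimIndex(int64_t index, int64_t dimSize) {
  index = std::min(std::max(-dimSize, index), dimSize); // min(max(-d, i), d)
  return index >= 0 ? index : index + dimSize;          // wrap negatives
}

int main() {
  assert(normalizeDimIndex(-1, 4) == 3); // t[-1] is t[3]
  assert(normalizeDimIndex(-9, 4) == 0); // clamped to -4, wraps to 0
  assert(normalizeDimIndex(9, 4) == 4);  // clamped to dimSize
}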
+template +class ConvertAtenViewOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto rankType = + adaptor.self().getType().template dyn_cast(); + if (!rankType) + return op.emitError("Only ranked tensor types are currently supported"); + + SmallVector dimSizes; + if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) { + return op.emitError("Dims size must be a list of Scalar"); + } + + auto loc = op.getLoc(); + auto newRank = dimSizes.size(); + auto outTy = OpConversionPattern::getTypeConverter()->convertType( + op.getType()); + + if (newRank == 0 || rankType.getRank() == 0) { + SmallVector newShape(newRank, 1); + Value output = rewriter.create( + loc, + RankedTensorType::get( + newShape, + outTy.template dyn_cast().getElementType()), + adaptor.self()); + rewriter.replaceOpWithNewOp(op, outTy, output); + return success(); + } + + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { + dSize = rewriter.create(loc, dSize).getResult(); + return dSize; + }); + + const auto &options = ConvertAtenOp::getOptions(); + Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); + if (options.dimSizeIndexBits == 32) { + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are + // unlikely to exceed the range of i32(4GiB) + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { + // dimSize: cast i64 -> i32 + dSize = rewriter.create(loc, intType, dSize); + return dSize; + }); + } + + Value numel = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + auto rank = rankType.getRank(); + for (size_t d = 0; d < rank; ++d) { + Value dimSize = rewriter.create( + loc, intType, rewriter.create(loc, adaptor.self(), d)); + numel = rewriter.create(loc, numel, dimSize); + } + numel = rewriter.create(loc, rewriter.getIndexType(), + numel); + Value mhloShape = rewriter.create(loc, dimSizes); + Value computedShape = rewriter.create( + loc, mhloShape.getType(), numel, mhloShape); + rewriter.replaceOpWithNewOp( + op, outTy, adaptor.self(), computedShape); + return success(); + } + + bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const; +}; + +template <> +bool ConvertAtenViewOp::getAtenViewOpSizes( + AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { + return getListConstructElements(adaptor.size(), dimSizes); +} + +template <> +bool ConvertAtenViewOp::getAtenViewOpSizes( + AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { + return getListConstructElements(adaptor.shape(), dimSizes); +} +} // namespace + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSliceTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + + auto getOptionalVal = [&](Value val) -> llvm::Optional { + if 
(val.getType().isa<Torch::NoneType>()) {
+      return llvm::None;
+    } else {
+      return val;
+    }
+  };
+
+  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
+  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
+  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
+
+  FailureOr<Value> sliceInfo =
+      getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
+                      options.dimSizeIndexBits);
+  if (failed(sliceInfo))
+    return op.emitError("cannot create a dynamic slice");
+
+  auto slice = *sliceInfo;
+  rewriter.replaceOp(op, slice);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
+    AtenSqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("only ranked tensor types are supported");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of the tensor must be greater than 0");
+
+  SmallVector<int64_t, 4> dims;
+  dims.reserve(rank);
+  for (int r = 0; r < rank; ++r) {
+    auto dSize = selfTy.getShape()[r];
+    if (dSize == ShapedType::kDynamicSize)
+      return rewriter.notifyMatchFailure(
+          op, "the size of the dimension being squeezed can't be unknown");
+    if (dSize != 1)
+      dims.push_back(r);
+  }
+  if (dims.size() == 0) {
+    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), self);
+    return success();
+  }
+
+  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
+                                                   options.dimSizeIndexBits);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
+    AtenSqueezeDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("only ranked tensor types are supported");
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of the tensor must be greater than 0");
+
+  dim = toPositiveDim(dim, rank);
+  if (selfTy.getShape()[dim] != 1) {
+    if (selfTy.getShape()[dim] == ShapedType::kDynamicSize)
+      return rewriter.notifyMatchFailure(
+          op, "the size of the dimension being squeezed can't be unknown");
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op, getTypeConverter()->convertType(op.getType()), adaptor.self());
+    return success();
+  }
+
+  SmallVector<int64_t, 4> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  dims.erase(dims.begin() + dim);
+  if (dims.size() == 0) {
+    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), self);
+    return success();
+  }
+  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
+                                                   options.dimSizeIndexBits);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnsqueezeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) { + return op.emitError("only tensor types are currently supported"); + } + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return op->emitError("dim must be a Scalar constant"); + + auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), + {dim}, options.dimSizeIndexBits); + if (failed(unsqzTensorInfo)) + return rewriter.notifyMatchFailure(op, + "failed to create unsqueezed tensor"); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo); + return success(); +} + +void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToMhloOptions &options) { + MLIRContext *context = patterns.getContext(); + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenSqueezeOp); + INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); +#undef INSERT_ATENOP_PATTERN + +#define INSERT_VIEW_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_VIEW_OP_PATTERN(AtenViewOp); + INSERT_VIEW_OP_PATTERN(AtenReshapeOp); +#undef INSERT_VIEW_OP_PATTERN +} diff --git a/lib/Conversion/TorchToStd/CMakeLists.txt b/lib/Conversion/TorchToStd/CMakeLists.txt deleted file mode 100644 index 63f0c8c6aca3..000000000000 --- a/lib/Conversion/TorchToStd/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToStd - TorchToStd.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStd - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRFuncDialect - TorchMLIRTorchDialect -) - -torch_mlir_target_includes(TorchMLIRTorchToStd) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 22182ab204c3..558ac82a1ab9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -43,7 +43,8 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { auto selfTy = self.getType().cast(); if (!selfTy) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); if (selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( @@ -53,8 +54,8 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { self); return success(); } else { - return op.emitError( - "Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); } } }; @@ -94,13 +95,14 @@ class ConvertAtenBinaryOp : public OpConversionPattern { auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) - return op.emitError("Add: input datatypes mismatched"); + return 
rewriter.notifyMatchFailure(op, "Input datatypes mismatched"); rewriter.replaceOpWithNewOp( op, @@ -140,37 +142,43 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue)); if (!isFloat && !isInt) - return op->emitError("Unable to extract the scalar constant"); + return rewriter.notifyMatchFailure(op, + "Unable to extract the scalar constant"); if (dtype.isa()) { tosaTensor = tosa::getConstTensor( rewriter, op, (isFloat ? doubleValue : intValue), dshape) - .getValue(); + .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); if (w != 32 && w != 64) - return op->emitError("Unsupported integer type") << intType; + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "Unsupported integer type: " << intType; + }); if (w == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { - return op->emitError("Supplied value of scalar constant exceeds limits " - "of destination type"); + return rewriter.notifyMatchFailure( + op, "Supplied value of scalar constant exceeds limits " + "of destination type"); } int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue); tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).getValue(); + tosa::getConstTensor(rewriter, op, {d}, dshape).value(); } else if (w == 64) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { - return op->emitError("Supplied value of scalar constant exceeds limits " - "of destination type"); + return rewriter.notifyMatchFailure( + op, "Supplied value of scalar constant exceeds limits " + "of destination type"); } int64_t d = (isFloat ? static_cast(doubleValue) : intValue); tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).getValue(); + tosa::getConstTensor(rewriter, op, {d}, dshape).value(); } - } else - return op->emitError("Usupported element type"); + } else { + return rewriter.notifyMatchFailure(op, "Usupported element type"); + } return success(); } @@ -186,11 +194,13 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter, // `alpha` has not been specified. int64_t alphaValue; if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue))) - return op->emitError("Currently only scalar constants are supported for " - "alpha in TOSA operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "alpha in TOSA operation"); // When no alpha has been specified, this must be 1. 
if (checkForUnity && alphaValue != 1) - return op->emitError("Unsupported integer value for alpha"); + return rewriter.notifyMatchFailure(op, + "Unsupported integer value for alpha"); alphaTensor = mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue); @@ -214,12 +224,13 @@ class ConvertAtenAddSubOp : public OpConversionPattern { auto rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); if (auto lhsElemTy = lhsType.getElementType().dyn_cast()) { if (lhsElemTy.getWidth() > 32) - return op.emitError( - "Integers with widths greater than 32 are not supported"); + return rewriter.notifyMatchFailure( + op, "Integers with widths greater than 32 are not supported"); } auto outType = OpConversionPattern::getTypeConverter() @@ -228,16 +239,17 @@ class ConvertAtenAddSubOp : public OpConversionPattern { Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); } Value rhsAsTensor; if (!rhsType) { if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor, outElemTy, {}))) - return op.emitError("Currently only scalar constants are supported for " - "conversion in TOSA operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA operation"); } auto rhsTensor = rhsType ? rhs : rhsAsTensor; @@ -246,8 +258,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(), alphaTensor, outElemTy, /*checkForUnity=*/false))) { - return op.emitError("Currently only scalar constants are supported for " - "alpha in conversion to TOSA operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "alpha in conversion to TOSA operation"); } auto multTensor = rewriter.create( @@ -262,8 +275,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern { return success(); } else { - return op.emitError( - "Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); } } }; // namespace @@ -283,26 +296,29 @@ class ConvertAtenCompareOp : public OpConversionPattern { auto rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat()) - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); // For bitwise operators, only integer datatype legalization is supported if (lhsElemTy.isa() && std::is_same()) { - return op.emitError("For bitwise operators, only integer datatype " - "legalization is supported"); + return rewriter.notifyMatchFailure(op, + "For bitwise operators, only integer " + "datatype legalization is supported"); } Value rhsAsTensor; if (!rhsTy) { if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor, lhsElemTy, {}))) - return op.emitError("Currently only scalar constants are supported for " 
- "conversion in TOSA operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; // There is no Lesser operator in TOSA. @@ -343,7 +359,8 @@ class ConvertAtenMulOp : public OpConversionPattern { auto lhsType = lhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -351,8 +368,8 @@ class ConvertAtenMulOp : public OpConversionPattern { Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); Value rhsTensor; if (std::is_same()) { @@ -363,10 +380,11 @@ class ConvertAtenMulOp : public OpConversionPattern { auto rhsType = rhs.getType().dyn_cast(); if (!rhsType) { if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), - rhsAsTensor, outElemTy, {}))) - return op.emitError( - "Currently only scalar constants are supported for " - "conversion in TOSA operation"); + rhsAsTensor, outElemTy, {}))) { + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA operation"); + } } rhsTensor = rhsType ? rhs : rhsAsTensor; } @@ -383,11 +401,12 @@ class ConvertAtenMulOp : public OpConversionPattern { lhs, rhsTensor, /*shift=*/0); return success(); - } else { - // Quantized multiplication may need to rescale inputs. - return op.emitError("Only floating-point or integer datatype " - "legalization currently supported"); } + + // Quantized multiplication may need to rescale inputs. + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype " + "legalization currently supported"); } }; @@ -405,19 +424,21 @@ class ConvertAtenDivOp : public OpConversionPattern { auto rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat()) - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); Value rhsAsTensor; if (!rhsTy) { if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor, lhsElemTy, {}))) - return op.emitError("Currently only scalar constants are supported for " - "conversion in TOSA operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; @@ -463,12 +484,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); - } else { - // Sigmoid legalization in TOSA for quantized element-type uses - // specialized tosa.table construct. - return op.emitError( - "Only floating-point datatype legalization currently supported"); } + // Sigmoid legalization in TOSA for quantized element-type uses specialized + // tosa.table construct. 
+ return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization currently supported"); } template <> @@ -481,12 +501,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); - } else { - // Sigmoid legalization in TOSA for quantized element-type uses - // specialized tosa.table construct. - return op.emitError( - "Only floating-point datatype legalization currently supported"); } + // Sigmoid legalization in TOSA for quantized element-type uses + // specialized tosa.table construct. + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization currently supported"); } template <> @@ -499,22 +518,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Maps to tosa.clamp which has both int and fp limits. int64_t clampMin = 0; Value clampIn = self; - if (selfTy) { - // Rescale the clampIn for quantized types. TBD - if (!selfTy.getElementType().isa()) { - return op.emitError( - "Only floating-point datatype legalization currently supported"); - } - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), clampIn, - rewriter.getI64IntegerAttr(clampMin), - rewriter.getI64IntegerAttr(std::numeric_limits::max()), - rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); - return success(); - } else { - return op.emitError("Only Tensor types supported in TOSA"); + if (!selfTy) { + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); } + + // Rescale the clampIn for quantized types. TBD + if (!selfTy.getElementType().isa()) { + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization currently supported"); + } + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), clampIn, + rewriter.getI64IntegerAttr(clampMin), + rewriter.getI64IntegerAttr(std::numeric_limits::max()), + rewriter.getF32FloatAttr(0.0f), + rewriter.getF32FloatAttr(std::numeric_limits::max())); + return success(); } using ReductionConvFunc = llvm::Optional (*)(PatternRewriter &, @@ -547,14 +567,15 @@ class ConvertAtenReductionOp : public OpConversionPattern { auto selfTy = self.getType().cast(); if (!selfTy) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); auto outputTy = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); if (!outputTy) - return op.emitError( - "Only ranked tensor type outputs permitted for reduce_mean"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type outputs permitted for reduce_mean"); ElementsAttr reduceDimsAttr; bool keepDims; @@ -571,7 +592,7 @@ class ConvertAtenReductionOp : public OpConversionPattern { // TBD - support dtype casting. 
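Aside (illustrative, not part of the patch): the recurring change in this hunk swaps op.emitError, which emits a hard diagnostic, for rewriter.notifyMatchFailure, which merely records why this pattern declined so the conversion driver can try other patterns or report the op as illegal. A rough stand-alone analogy of that control flow (all names hypothetical):

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// A "pattern" returns a failure reason, or nothing when it applies.
using Pattern = std::function<std::optional<std::string>()>;

int main() {
  std::vector<Pattern> patterns = {
      [] { return std::optional<std::string>("only Tensor types supported"); },
      [] { return std::optional<std::string>(); }, // this one applies
  };
  for (const auto &pattern : patterns) {
    if (auto failure = pattern())
      std::cout << "match failure: " << *failure << "\n"; // keep trying
    else
      return (std::cout << "legalized\n", 0);
  }
  return 1; // no pattern applied: partial conversion reports the illegal op
}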
- rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -677,7 +698,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in TOSA argmax"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA argmax"); int64_t reduceDim; if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) { @@ -729,11 +751,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto castToInt64 = [&](Value result) -> LogicalResult { auto resTy = result.getType().cast(); if (!resTy) - return op.emitError("Argmax: Result is not a shaped type"); + return rewriter.notifyMatchFailure(op, + "Argmax: Result is not a shaped type"); auto resShape = resTy.getShape(); - auto outTy = - RankedTensorType::get(resShape, outputETy); // rewriter.getI64Type()); + auto outTy = RankedTensorType::get(resShape, outputETy); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(outTy), result); @@ -779,11 +801,13 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in TOSA argmax"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA argmax"); SmallVector newOutputShape; if (failed(generateSqueezedShape(op, selfTy, rewriter, newOutputShape))) - return op.emitError("Squeeze could not compute new shape"); + return rewriter.notifyMatchFailure(op, + "Squeeze could not compute new shape"); auto resultTy = OpConversionPattern::getTypeConverter() ->convertType(op.getResult().getType()) @@ -871,17 +895,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in TOSA Pow"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Pow"); if (!selfTy.getElementType().isa()) - return op.emitError("Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); Value expTensor; Value expScalar = op.exponent(); if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor, selfTy.getElementType(), {}))) - return op.emitError("Currently only scalar constants are supported for " - "conversion in TOSA Pow operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Pow operation"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, expTensor); @@ -927,7 +954,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) - return op.emitError("Matmul: input datatypes mismatched"); + return rewriter.notifyMatchFailure(op, + "Matmul: input datatypes mismatched"); // Legalization constructs may offer input shapes but expect output shapes // to be inferred, e.g. 
@@ -1194,7 +1222,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedLhsType), - rankBroadcastedLhs, transposedLhsDimsConst.getValue()) + rankBroadcastedLhs, transposedLhsDimsConst.value()) .getResult(); } @@ -1273,7 +1301,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedRhsType), - rankBroadcastedRhs, transposedRhsDimsConst.getValue()) + rankBroadcastedRhs, transposedRhsDimsConst.value()) .getResult(); } @@ -1424,14 +1452,13 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto transposedOpType = RankedTensorType::get(transposedOpShape, outputElemTy); - output = - rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter() - ->convertType(transposedOpType), - reshapedOp.getResult(), transposedOpShapeConst.getValue()) - .getResult(); + output = rewriter + .create( + op->getLoc(), + OpConversionPattern::getTypeConverter() + ->convertType(transposedOpType), + reshapedOp.getResult(), transposedOpShapeConst.value()) + .getResult(); } else { output = reshapedOp.getResult(); @@ -1451,12 +1478,13 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { Value lhs, rhs; if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) - return op.emitError("Failed to read matmul inputs"); + return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs"); Value output; if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output))) - return op.emitError("Failed to perform matmul operation"); + return rewriter.notifyMatchFailure(op, + "Failed to perform matmul operation"); rewriter.replaceOpWithNewOp( op, @@ -1485,7 +1513,8 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) - return op.emitError("Only ranked tensor types supported in TOSA matmul"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA matmul"); return success(); } @@ -1508,7 +1537,8 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) - return op.emitError("Only ranked tensor types supported in TOSA matmul"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -1544,7 +1574,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) - return op.emitError("Only ranked tensor types supported in TOSA matmul"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -1555,8 +1586,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { return op.emitError("aten.Linear called but weight rank not 2 or 3"); // Protection against crash due to unguarded code in TOSA->LinAlg. + // TODO: This should be handled in TOSA->LinAlg instead. 
if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape()) - return op.emitError("aten.Linear needs statically shaped input"); + return rewriter.notifyMatchFailure( + op, "aten.Linear needs statically shaped input"); return success(); } @@ -1569,7 +1602,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { Value lhs, rhs; if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) - return op.emitError("Failed to read matmul inputs"); + return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs"); // The aten.Linear op has a bias tensor that is added to the matmul output. auto bias = adaptor.bias(); @@ -1578,8 +1611,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { // TOSA does not mandate that elementwise op tensors need to be ranked. if (!biasTy.template isa() && !biasTy.template isa()) - return op.emitError("Only tensor types supported in GEMM to " - "TOSA for bias tensor"); + return rewriter.notifyMatchFailure( + op, "Only tensor types supported in GEMM to TOSA for bias tensor"); // RHS must have its last two dims transposed prior to matrix // multiplication. @@ -1612,12 +1645,13 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( transposedRhsType), - rhs, transposedRhsShapeConst.getValue()); + rhs, transposedRhsShapeConst.value()); Value matmulOutput; if (failed( this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput))) - return op.emitError("Failed to perform matmul operation"); + return rewriter.notifyMatchFailure(op, + "Failed to perform matmul operation"); Value matmulPlusBias = matmulOutput; if (!biasTy.template isa()) { @@ -1651,17 +1685,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in TOSA Rsub"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Rsub"); if (!selfTy.getElementType().isa()) - return op.emitError("Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); Value otherTensor, alphaTensor; if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor, selfTy.getElementType(), {}))) - return op.emitError("Currently only scalar constants are supported for " - "conversion in TOSA Rsub operation"); + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Rsub operation"); if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar, alphaTensor, selfTy.getElementType(), @@ -1694,8 +1731,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .template cast(); if (!inputTy || !weightTy || !outputTy) - return op.emitError( - "Input, weight and output to Convolution must be ranked tensors"); + return rewriter.notifyMatchFailure( + op, "Input, weight and output to Convolution must be ranked tensors"); auto inputElemTy = inputTy.getElementType(); auto weightElemTy = weightTy.getElementType(); @@ -1703,10 +1740,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto weightShape = weightTy.getShape(); if (inputTy.getRank() != 4) - return op.emitError("Unimplemented: only 2D convolutions supported"); + return rewriter.notifyMatchFailure( + op, "Unimplemented: only 2D convolutions supported"); if (!weightTy.hasStaticShape()) - return op.emitError("Unimplemented: TOSA only supports static weight"); + return 
rewriter.notifyMatchFailure( + op, "Unimplemented: TOSA only supports static weight"); // Bias is optional. TOSA mandates a zero tensor here, so construct one if // required. @@ -1719,16 +1758,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector zeroVec(weightShape[0], 0); bias = tosa::getConstTensor( rewriter, op, zeroVec, {static_cast(weightShape[0])}) - .getValue(); + .value(); } else { SmallVector zeroVec(weightShape[0], 0); bias = tosa::getConstTensor(rewriter, op, zeroVec, {static_cast(weightShape[0])}) - .getValue(); + .value(); } } else { if (!bias.getType().cast()) - return op.emitError("Bias provided but not a ranked tensor"); + return rewriter.notifyMatchFailure( + op, "Bias provided but not a ranked tensor"); } auto biasElemTy = inputElemTy.template isa() ? inputElemTy @@ -1767,7 +1807,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedInputType), input, - nchwToNhwcTransposeConst.getValue()) + nchwToNhwcTransposeConst.value()) .getResult(); SmallVector transposedWeightShape( @@ -1779,7 +1819,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedWeightType), weight, - nchwToNhwcTransposeConst.getValue()) + nchwToNhwcTransposeConst.value()) .getResult(); int64_t outputHDim, outputWDim; @@ -1826,7 +1866,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedOutputType), - convOpResult, nhwcToNchwTransposeConst.getValue()) + convOpResult, nhwcToNchwTransposeConst.value()) .getResult(); Value rescaledResult = transposedOutput; @@ -1850,19 +1890,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in TOSA Reshape"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Reshape"); // Check that at most one dimension is -1 SmallVector newShape; if (!matchPattern(op.shape(), m_TorchConstantIntList(newShape))) - return op.emitError("Only constant shape supported in TOSA Reshape"); + return rewriter.notifyMatchFailure( + op, "Only constant shape supported in TOSA Reshape"); int auto_sz = 0; for (auto s : newShape) auto_sz += (s == -1 ? 1 : 0); if (auto_sz > 1) - op.emitError("At most one dimension may be specified as -1 to " - "automatically calculate its size"); + return rewriter.notifyMatchFailure( + op, "At most one dimension may be specified as -1 to " + "automatically calculate its size"); auto newType = RankedTensorType::get(newShape, selfTy.getElementType()); @@ -1929,7 +1972,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a ranked tensor output if (!adaptor.input().getType().dyn_cast()) - return op.emitError("Only ranked tensor types are supported"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); auto outType = getTypeConverter()->convertType(op.getType()); @@ -1937,12 +1981,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // FIXME: Handle training and momentum. 
if (op.momentum().getType().isa()) - op.emitError("Unsupported None for momentum"); + return rewriter.notifyMatchFailure(op, "Unsupported None for momentum"); auto meanType = adaptor.running_mean().getType().dyn_cast(); auto varianceType = adaptor.running_var().getType().dyn_cast(); if (!varianceType || !meanType) - return op.emitError("Only ranked tensor types are supported"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); // Normalization ops perform elementwise ops of a single mean/stdev value // against the feature map and because input is NCHW, the rank-1 value must be @@ -1954,7 +1999,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType toBcastType = toBcast.getType().dyn_cast(); if (toBcastType.getRank() > 1) - op->emitError("Rank cannot be more than 1"); + return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1"); RankedTensorType outTensorType = outType.cast(); SmallVector newShape = {toBcastType.getShape()[0]}; @@ -1974,26 +2019,27 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (failed(reshapeToNormInputDim(op.getOperation(), rewriter, getTypeConverter(), outType, adaptor.running_mean(), meanVal))) - op.emitError("Failed to reshape running mean"); + return rewriter.notifyMatchFailure(op, "Failed to reshape running mean"); if (failed(reshapeToNormInputDim(op.getOperation(), rewriter, getTypeConverter(), outType, adaptor.running_var(), varianceVal))) - op.emitError("Failed to reshape running variance"); + return rewriter.notifyMatchFailure(op, + "Failed to reshape running variance"); if (failed(reshapeToNormInputDim(op.getOperation(), rewriter, getTypeConverter(), outType, adaptor.weight(), weightVal))) - op.emitError("Failed to reshape weight"); + return rewriter.notifyMatchFailure(op, "Failed to reshape weight"); if (failed(reshapeToNormInputDim(op.getOperation(), rewriter, getTypeConverter(), outType, adaptor.bias(), biasVal))) - op.emitError("Failed to reshape bias"); + return rewriter.notifyMatchFailure(op, "Failed to reshape bias"); double eps; if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) - return op.emitError("eps must be a scalar constant"); + return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); @@ -2021,11 +2067,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a ranked tensor output if (!adaptor.input().getType().dyn_cast()) - return op.emitError("Only ranked tensor types are supported"); + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); auto inputType = adaptor.input().getType().cast(); if (inputType.getRank() > 4) - return op.emitError("Only up to 4D tensors are supported"); + return rewriter.notifyMatchFailure(op, + "Only up to 4D tensors are supported"); auto outType = getTypeConverter()->convertType(op.getType(0)); @@ -2033,9 +2081,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // FIXME: Handle the None cases for the optional parameters. 
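Aside (illustrative, not part of the patch): reshapeToNormInputDim exists because batch-norm parameters are rank-1 [C] tensors while the input is NCHW; to participate in elementwise ops they are reshaped to [C, 1, 1] so broadcasting lines them up with the channel dimension. A scalar sketch of that shape computation (function name hypothetical):

#include <cstdint>
#include <vector>

// Rank-1 [C] -> [C, 1, 1, ...] so it broadcasts against [N, C, H, W, ...].
static std::vector<int64_t> normParamShape(int64_t channels,
                                           int64_t inputRank) {
  std::vector<int64_t> shape{channels};
  for (int64_t i = 2; i < inputRank; ++i) // one 1 per trailing spatial dim
    shape.push_back(1);
  return shape; // e.g. {C, 1, 1} for a rank-4 NCHW input
}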
if (adaptor.weight().getType().isa()) - return op.emitError("Unsupported None for weight"); + return rewriter.notifyMatchFailure(op, "Unsupported None for weight"); if (adaptor.bias().getType().isa()) - return op.emitError("Unsupported None for bias"); + return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); auto weightType = adaptor.weight().getType().cast(); auto biasType = adaptor.bias().getType().cast(); @@ -2065,7 +2113,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (inputType.getShape()[index + meanAndVarShapeRank] != value || weightType.getShape()[index] != value || biasType.getShape()[index] != value) - return op.emitError("mismatching contracting dimension"); + return rewriter.notifyMatchFailure(op, + "mismatching contracting dimension"); } // Helper for computing mean and variance. @@ -2096,7 +2145,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto elemCntConst = tosa::getConstTensor(rewriter, op.getOperation(), {static_cast(elemCnt)}, {1}) - .getValue(); + .value(); Value elemCntRcp = rewriter.create( op.getLoc(), elemCntConst.getType(), elemCntConst); @@ -2146,7 +2195,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( double eps; if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) - return op.emitError("eps must be a scalar constant"); + return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); @@ -2181,7 +2230,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a ranked tensor type auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType || !selfType.hasStaticShape()) - return op.emitError( + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types with static shapes are currently supported"); int64_t selfRank = selfType.getRank(); @@ -2189,19 +2239,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t start_dim, end_dim; if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim))) - return op.emitError("start_dim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, + "start_dim must be a Scalar constant"); start_dim = toPositiveDim(start_dim, selfRank); if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim))) - return op.emitError("end_dim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "end_dim must be a Scalar constant"); end_dim = toPositiveDim(end_dim, selfRank); if (selfRank > 0 && !isValidDim(start_dim, selfRank)) - return op.emitError("start_dim is statically invalid"); + return rewriter.notifyMatchFailure(op, "start_dim is statically invalid"); if (selfRank > 0 && !isValidDim(end_dim, selfRank)) - return op.emitError("end_dim is statically invalid"); + return rewriter.notifyMatchFailure(op, "end_dim is statically invalid"); if (end_dim < start_dim) - return op.emitError("end_dim must be larger than start_dim"); + return rewriter.notifyMatchFailure(op, + "end_dim must be larger than start_dim"); SmallVector newShape; for (auto s : llvm::enumerate(selfType.getShape())) { @@ -2238,7 +2290,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a ranked tensor type auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError( + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types with static shapes are currently supported"); SmallVector dimListInt; @@ -2247,10 +2300,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only constant dimensions are currently supported"); int64_t selfRank = selfType.getRank(); + // 
TODO: If this is already verified on the op then we can drop checking here. for (auto &d : dimListInt) { d = toPositiveDim(d, selfRank); if (!isValidDim(d, selfRank)) - return op.emitError("Not all dims are valid"); + return rewriter.notifyMatchFailure(op, "Not all dims are valid"); } auto transposeDimsConst = mlir::tosa::getConstTensor( @@ -2258,7 +2312,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.self(), - transposeDimsConst.getValue()); + transposeDimsConst.value()); return success(); } @@ -2271,13 +2325,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, ln2Shape) - .getValue(); + .value(); auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); @@ -2298,29 +2353,32 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); if (!selfElemTy.isIntOrFloat()) - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); // Integer types with width > 32 are not supported auto selfIntType = selfElemTy.dyn_cast(); if (selfIntType && selfIntType.getWidth() > 32) { - return op.emitError( - "Integer types with width greater than 32 are not supported"); + return rewriter.notifyMatchFailure( + op, "Integer types with width greater than 32 are not supported"); } SmallVector constTypeShape(selfType.getRank(), 1); Value threshold, value; if (failed(torchScalarToTosaTensor(rewriter, op, op.threshold(), threshold, selfElemTy, constTypeShape))) - return op.emitError("Only scalar constant is supported for threshold"); + return rewriter.notifyMatchFailure( + op, "Only scalar constant is supported for threshold"); if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), value, selfElemTy, constTypeShape))) - return op.emitError("Only scalar constant is supported for value"); + return rewriter.notifyMatchFailure( + op, "Only scalar constant is supported for value"); // Threshold only clamps the upper values. tosa::ClampOp has the same // value for both threshold and clamped value so cannot be used. @@ -2345,23 +2403,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. 
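// [Editor's note, not part of the patch] On the threshold lowering above:
// tosa.clamp cannot express aten.threshold because clamp's boundary and its
// substituted value coincide, whereas threshold(x) = x if x > threshold, else
// `value`. A plausible greater/select formulation (illustrative sketch only,
// variable names are not from this patch):
//
//   Value cond = rewriter.create<tosa::GreaterOp>(loc, condType, self, threshold);
//   rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cond, self, value);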
auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) { - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); } auto selfRank = selfType.getRank(); auto selfElemTy = selfType.getElementType(); if (!selfElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); } int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return op->emitError("dim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); dim = toPositiveDim(dim, selfRank); if (!isValidDim(dim, selfRank)) - return op.emitError("dim is statically invalid"); + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector outShape; for (auto en : llvm::enumerate(selfType.getShape())) { @@ -2386,7 +2445,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); // FIXME: memory_format is not handled. @@ -2403,16 +2463,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.input().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); // FIXME: train and p are not handled. bool train; if (!matchPattern(op.train(), m_TorchConstantBool(&train))) - op.emitError("train must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "train must be a Scalar constant"); if (train) - op.emitError("train must be false"); + return rewriter.notifyMatchFailure(op, "train must be false"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.input()); @@ -2428,17 +2489,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. 
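// [Editor's note, not part of the patch] The dropout conversion above is
// inference-only: it now requires a constant train == false and forwards the
// input through a type-adjusting cast, which is correct because eval-mode
// dropout is the identity; masking and the 1/(1-p) rescale exist only in
// training mode.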
auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); if (!selfElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); } SmallVector outShape; if (!matchPattern(op.size(), m_TorchConstantIntList(outShape))) - return op.emitError("size must consist of Scalar constants"); + return rewriter.notifyMatchFailure(op, + "size must consist of Scalar constants"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.self(), @@ -2459,24 +2522,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto outType = x.getType().cast(); auto loc = op->getLoc(); auto absX = rewriter.create(loc, outType, x); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).getValue(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).getValue(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).getValue(); + auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).getValue(); + auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).getValue(); + auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).getValue(); + auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2500,8 +2563,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x) { - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).getValue(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).getValue(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); auto loc = op->getLoc(); // buildNormalCdf, mean = zero, sigma = one @@ -2510,12 +2573,12 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}).getValue(); + tosa::getConstTensor(rewriter, op, 0.70710678, {}).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); Value erf = approximateErfOp(rewriter, op, erfArg); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = 
tosa::getConstTensor(rewriter, op, 0.5, {}).getValue(); + Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); return normalCdf; @@ -2530,18 +2593,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); if (!selfElemTy.isa()) { - return op.emitError("Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); } // TODO: Handle approximate. std::string approximate; if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) || approximate != "none") { - return op.emitError("Unsupported value of approximate"); + return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self()); @@ -2561,18 +2626,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types are currently supported"); auto selfElemTy = selfType.getElementType(); if (!selfElemTy.isa()) { - return op.emitError("Only floating-point datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); } // TODO: Handle approximate. std::string approximate; if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) || approximate != "none") { - return op.emitError("Unsupported value of approximate"); + return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } auto loc = op->getLoc(); @@ -2583,10 +2650,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( const double kAlpha = cstAlpha0 * cstAlpha1; Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}) - .getValue(); + tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}).value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}).getValue(); + tosa::getConstTensor(rewriter, op, -0.5, {}).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( @@ -2620,10 +2686,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesType = indices.getType().dyn_cast(); if (!indicesType || !indicesType.getElementType().isa()) - return op.emitError("Indices must be of integer tensor type"); + return rewriter.notifyMatchFailure( + op, "Indices must be of integer tensor type"); if (indicesType.getRank() != 2) - return op.emitError("indices must be of rank 2"); + return rewriter.notifyMatchFailure(op, "indices must be of rank 2"); auto weightType = weight.getType().cast(); if (weightType.getRank() != 2) @@ -2711,22 +2778,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are supported"); + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); // Only statically resolvable values are currently supported int64_t dim0, dim1; if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) - 
return op->emitError("dim0 must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "dim0 must be a Scalar constant"); if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) - return op->emitError("dim1 must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "dim1 must be a Scalar constant"); dim0 = toPositiveDim(dim0, selfType.getRank()); dim1 = toPositiveDim(dim1, selfType.getRank()); auto selfRank = selfType.getRank(); if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank)) - return op->emitError("dim0 and dim1 must be less than tensor rank"); + return rewriter.notifyMatchFailure( + op, "dim0 and dim1 must be less than tensor rank"); SmallVector transposeDims; for (auto i = 0; i < selfType.getRank(); ++i) @@ -2740,7 +2808,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.self(), - transposeDimsConst.getValue()); + transposeDimsConst.value()); return success(); } @@ -2752,12 +2820,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are supported"); + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); auto indicesType = getTypeConverter()->convertType(op.getType(1)).dyn_cast(); if (!indicesType) - return op.emitError("Only tensor types are supported"); + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); auto selfElemType = selfType.getElementType(); auto indicesElemType = indicesType.getElementType(); @@ -2765,16 +2833,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Only statically deducible values are currently supported int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return op->emitError("dim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); dim = toPositiveDim(dim, selfType.getRank()); if (!isValidDim(dim, selfType.getRank())) - return op->emitError("dim must be less than tensor rank"); + return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank"); bool keepDim; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) - return op->emitError("keepdim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant"); SmallVector reducedShape, prunedShape; for (auto en : llvm::enumerate(selfType.getShape())) { @@ -2820,39 +2888,42 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType || !selfType.hasStaticShape()) - return op.emitError("Only tensor types with static shape are supported"); + return rewriter.notifyMatchFailure( + op, "Only tensor types with static shape are supported"); // Only statically deducible values are currently supported int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return op->emitError("dim must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); dim = toPositiveDim(dim, selfType.getRank()); if (!isValidDim(dim, selfType.getRank())) - return op->emitError("dim must less than tensor rank"); + return rewriter.notifyMatchFailure(op, "dim must less than tensor rank"); int64_t start; if (!matchPattern(op.start(), m_TorchConstantInt(&start))) - return op->emitError("start must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); if (start < 0) - return 
op->emitError("Currently unsupported: start < 0"); + return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); int64_t end; if (!matchPattern(op.end(), m_TorchConstantInt(&end))) - return op->emitError("end must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); // FIXME: add support for start/end < 0 and end < start if (end < start) - return op->emitError("Currently unsupported: end < start"); + return rewriter.notifyMatchFailure(op, + "Currently unsupported: end < start"); int64_t step; if (!matchPattern(op.step(), m_TorchConstantInt(&step))) - return op->emitError("step must be a Scalar constant"); + return rewriter.notifyMatchFailure(op, "step must be a Scalar constant"); if (step != 1) - return op->emitError("step value other than 1 is currently unsupported"); + return rewriter.notifyMatchFailure( + op, "step value other than 1 is currently unsupported"); SmallVector startSlice(selfType.getRank(), 0); SmallVector sizeSlice = llvm::to_vector(selfType.getShape()); @@ -2919,7 +2990,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { RankedTensorType::get(transposedInputShape, inputElemTy); return rewriter .create(op->getLoc(), transposedInputType, input, - transposeDimsConst.getValue()) + transposeDimsConst.value()) .getResult(); } @@ -2962,7 +3033,8 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { // case of adaptive pooling. Also performs input CHW->HWC transpose. if (failed(processInputs(op, adaptor, rewriter, input, kernel, stride, pad, outputTy))) - return op.emitError("Failed to process inputs for pooling"); + return rewriter.notifyMatchFailure( + op, "Failed to process inputs for pooling"); auto pooledOutput = rewriter @@ -2996,7 +3068,8 @@ class ConvertAtenAdaptivePoolingOp auto inputXchw = adaptor.self(); auto inputTy = inputXchw.getType().template cast(); if (!inputTy) - return op.emitError("Adaptive avgpool requires ranked tensor input"); + return rewriter.notifyMatchFailure( + op, "Adaptive avgpool requires ranked tensor input"); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); @@ -3004,7 +3077,8 @@ class ConvertAtenAdaptivePoolingOp // Rank sanity check. if (inputTy.getRank() != 4 && inputRank != 3) - return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor"); + return rewriter.notifyMatchFailure( + op, "NCHW->NHWC transpose requires 3D or 4D tensor"); int64_t inputHDim = inputShape[inputRank - 2]; int64_t inputWDim = inputShape[inputRank - 1]; @@ -3020,8 +3094,8 @@ class ConvertAtenAdaptivePoolingOp outputHDim = outputWDim = outputSize[0]; } else { if (outputSize.size() != 2) - return op.emitError( - "Adaptive avgpool output_size not 1 or 2 elements."); + return rewriter.notifyMatchFailure( + op, "Adaptive avgpool output_size not 1 or 2 elements."); // Assumes 'None' (e.g. output_size=(None, 5) ) is expressed as <=0. outputHDim = @@ -3096,12 +3170,14 @@ static LogicalResult getOutputTypeAndPoolingParameters( RankedTensorType inputTy = inputXchw.getType().cast(); if (!inputTy) - return op.emitError("Pooling op requires ranked tensor input"); + return rewriter.notifyMatchFailure( + op, "Pooling op requires ranked tensor input"); auto inputRank = inputTy.getRank(); // Rank sanity check. 
if (inputTy.getRank() != 4 && inputRank != 3) - return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor"); + return rewriter.notifyMatchFailure( + op, "NCHW->NHWC transpose requires 3D or 4D tensor"); SmallVector kernelSizeInts, strideInts, paddingInts; if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts))) @@ -3149,7 +3225,8 @@ class ConvertAtenMaxPool2dOp op, "Non-const dilation for pooling op unsupported."); // TOSA pooling only supports unit dilation. if (dilationArray[0] > 1 || dilationArray[1] > 1) - return op.emitError("Cannot process non-unit pooling dilation."); + return rewriter.notifyMatchFailure( + op, "Cannot process non-unit pooling dilation."); if (failed(getOutputTypeAndPoolingParameters( @@ -3206,29 +3283,32 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { .template dyn_cast(); if (!outType) - return op.emitError("Only Tensor types supported in TOSA"); + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); // FIXME: Handle layout, device and pin_memory. Assume dtype has been // processed to set output type correctly? if (!op.layout().getType().template isa()) - return op.emitError("Only default layout is supported"); + return rewriter.notifyMatchFailure(op, + "Only default layout is supported"); bool pinMemory; if (!op.pin_memory().getType().template isa() && (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { - return op.emitError( - "Unsupported pin_memory, should be either None or false"); + return rewriter.notifyMatchFailure( + op, "Unsupported pin_memory, should be either None or false"); } SmallVector shape; if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) { - return op.emitError("Shape must be a list of Scalar constants"); + return rewriter.notifyMatchFailure( + op, "Shape must be a list of Scalar constants"); } int64_t size = 1; @@ -3237,7 +3317,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { SmallVector values(size, fillVal); auto constOp = - tosa::getConstTensor(rewriter, op, values, shape).getValue(); + tosa::getConstTensor(rewriter, op, values, shape).value(); rewriter.replaceOpWithNewOp(op, outType, constOp); @@ -3258,18 +3338,19 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { .template dyn_cast(); if (!outType || !outType.hasStaticShape()) - return op.emitError( - "Only Tensor types with static shapes are currently supported"); + return rewriter.notifyMatchFailure( + op, "Only Tensor types with static shapes are currently supported"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); } Value constOp; if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), constOp, outElemTy, outType.getShape()))) - return op.emitError("Supplied value must be a Scalar constant"); + return rewriter.notifyMatchFailure( + op, "Supplied value must be a Scalar constant"); rewriter.replaceOpWithNewOp(op, outType, constOp); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp 
b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 9b1b61cd5338..bcc29c0195fe 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -297,13 +297,13 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale, output_zp); - if (!val.hasValue()) + if (!val.has_value()) return llvm::None; if (!input_is_qtype) { Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - val.getValue(), div_const, 0) + val.value(), div_const, 0) .getResult(); } diff --git a/lib/Conversion/Utils/CMakeLists.txt b/lib/Conversion/Utils/CMakeLists.txt index f4baf8634539..6b352bdc5491 100644 --- a/lib/Conversion/Utils/CMakeLists.txt +++ b/lib/Conversion/Utils/CMakeLists.txt @@ -3,4 +3,9 @@ add_mlir_conversion_library(TorchMLIRConversionUtils ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils + + LINK_LIBS PUBLIC + MLIRArithmeticDialect + MLIRLinalgDialect + TorchMLIRTorchDialect ) diff --git a/lib/Dialect/Torch/IR/CMakeLists.txt b/lib/Dialect/Torch/IR/CMakeLists.txt index 989fed0dcbac..cf54afe06c2e 100644 --- a/lib/Dialect/Torch/IR/CMakeLists.txt +++ b/lib/Dialect/Torch/IR/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_library(TorchMLIRTorchDialect Core LINK_LIBS PUBLIC + MLIRFuncDialect MLIRIR MLIRSupport MLIRControlFlowInterfaces diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index bae66e8b07b3..835be031644b 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -65,7 +65,7 @@ Type Torch::parseTorchDialectType(AsmParser &parser) { StringRef mnemonic; Type genType; auto parseResult = generatedTypeParser(parser, &mnemonic, genType); - if (parseResult.hasValue()) + if (parseResult.has_value()) return genType; parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `" << TorchDialect::getDialectNamespace() << "`"; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 47af3796a3db..f34ece3468da 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -28,6 +28,49 @@ using namespace mlir::torch::Torch; // Utilities //===----------------------------------------------------------------------===// +Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, + Location loc, Value value, + Type desiredType, + bool userAllowsRefinement) { + Type type = value.getType(); + + // If the value is already of the desired type, we're done. + if (type == desiredType) + return value; + + // If the type is a tensor, then adjust the static information. + if ((type.isa() && desiredType.isa()) || + (type.isa() && + desiredType.isa())) { + Value adjusted = builder.create(value.getLoc(), + desiredType, value); + return adjusted; + } + + // If the type is a subtype of desiredType, then we need to derefine it to + // desiredType, unless the user allows refinement. + if (isValidSubtype(type, desiredType)) { + if (!userAllowsRefinement) { + Value adjusted = + builder.create(value.getLoc(), desiredType, value); + return adjusted; + } else { + return value; + } + } + + // If the desiredType is subtype of type, then we assume that the desiredType + // is dynamically valid, so we do an unchecked cast. 
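// [Editor's note, not part of the patch] Summarizing the new helper's
// decision ladder, in order:
//   1. type == desiredType                  -> return the value unchanged
//   2. tensor type <-> tensor type          -> static-info cast
//   3. isValidSubtype(type, desiredType)    -> derefine (unless refinement ok)
//   4. isValidSubtype(desiredType, type)    -> unchecked cast (checked below)
//   5. otherwise                            -> null Value, caller must handle
// A hedged usage sketch (caller-side names are illustrative):
//
//   Value adjusted = adjustStaticInformation(builder, loc, operand, argType,
//                                            /*userAllowsRefinement=*/false);
//   if (!adjusted)
//     return emitError(loc) << "unadjustable operand type";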
+ if (isValidSubtype(desiredType, type)) { + Value adjusted = + builder.create(value.getLoc(), desiredType, value); + return adjusted; + } + + // No known adjustment. + return Value(); +} + Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType, Value tensor) { @@ -86,6 +129,10 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { static Value getScalarValue(Value input, Location loc, PatternRewriter &rewriter) { + auto inputType = input.getType(); + if (inputType.isa()) { + return input; + } Value scalar = nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { if (valueTensorLiteralOp && @@ -246,8 +293,9 @@ LogicalResult ClassTypeOp::verify() { // PrimLoopOp //===----------------------------------------------------------------------===// -OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional index) { - assert(index.hasValue() && index.value() == 0); +OperandRange +PrimLoopOp::getSuccessorEntryOperands(Optional index) { + assert(index.has_value() && index.value() == 0); return iterArgsInit(); } @@ -256,7 +304,7 @@ void PrimLoopOp::getSuccessorRegions( SmallVectorImpl ®ions) { (void)operands; - if (!index.hasValue()) { + if (!index.has_value()) { regions.emplace_back(®ion(), region().getArguments().slice(1)); return; } @@ -328,7 +376,7 @@ void PrimIfOp::getSuccessorRegions(Optional index, ArrayRef operands, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. - if (index.hasValue()) { + if (index.has_value()) { regions.push_back(RegionSuccessor(getResults())); return; } @@ -466,8 +514,8 @@ OpFoldResult DerefineOp::fold(ArrayRef operands) { return nullptr; } -void DerefineOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { +void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) { bool madeChange = false; for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) { @@ -536,7 +584,7 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef operands) { // r[i] = lo + step*i such that i >= 0 and r[i] < hi // So maximize `i` such that lo + step * i < hi // ==> i == ceildiv(hi - lo, step) - return IntegerAttr::get(lo.getType(), + return IntegerAttr::get(lo.cast().getType(), llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, APInt::Rounding::UP)); } @@ -554,7 +602,8 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef operands) { auto indexInt = index.dyn_cast_or_null().getValue(); auto startInt = start.dyn_cast_or_null().getValue(); auto stepInt = step.dyn_cast_or_null().getValue(); - return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt); + return IntegerAttr::get(index.cast().getType(), + startInt + stepInt * indexInt); } //===----------------------------------------------------------------------===// @@ -777,41 +826,163 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenLenStrOp::fold(ArrayRef operands) { - if(auto stringConstruct = s().getDefiningOp()) - return getI64IntegerAttr(getContext(), stringConstruct.valueAttr().getValue().size()); + if (auto stringConstruct = s().getDefiningOp()) + return getI64IntegerAttr(getContext(), + stringConstruct.valueAttr().getValue().size()); return nullptr; } +LogicalResult rewrite0DBinaryTensorOp(Operation *op, + PatternRewriter &rewriter) { + Location loc 
= op->getLoc();
+  // This canonicalization pattern also includes aten div/mul/add/sub ops
+  // between tensor and scalar, like aten.add.Scalar op
+  if (op->getNumOperands() < 2) {
+    return failure();
+  }
+  auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
+  auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
+  auto outType = op->getResult(0).getType();
+
+  if (!lhs || !rhs) {
+    return rewriter.notifyMatchFailure(
+        op, "only int scalar lhs or rhs is supported");
+  }
+  if (isa(
+          op)) {
+    Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
+    if (!alpha) {
+      return rewriter.notifyMatchFailure(op,
+                                         "only int scalar alpha is supported");
+    }
+    rhs = rewriter.create(loc, rhs, alpha);
+  }
+
+  if (isa(op)) {
+    // None rounding mode
+    if (op->getOperand(2).getType().isa()) {
+      Value quotient = rewriter.create(loc, lhs, rhs);
+      rewriter.replaceOpWithNewOp(op, outType,
+                                  quotient);
+      return success();
+    }
+    std::string roundingMode;
+    if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) {
+      return rewriter.notifyMatchFailure(
+          op, "only None, 'floor' or 'trunc' rounding mode is supported");
+    }
+    if (roundingMode == "floor") {
+      Value quotient = rewriter.create(loc, lhs, rhs);
+      rewriter.replaceOpWithNewOp(op, outType,
+                                  quotient);
+      return success();
+    }
+    // For "trunc" rounding mode, instead of canonicalizing it into
+    // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds
+    // complexity but helps little in optimization (such as constant folding),
+    // we are trying to fold it.
+    if (roundingMode == "trunc") {
+      int64_t lhsInt;
+      int64_t rhsInt;
+      if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) {
+        return failure();
+      }
+      if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) {
+        return failure();
+      }
+
+      int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt);
+      Value resultScalar = rewriter.create(
+          loc, rewriter.getI64IntegerAttr(result));
+      rewriter.replaceOpWithNewOp(op, outType,
+                                  resultScalar);
+      return success();
+    }
+
+    return failure();
+  }
+
+  Value result;
+  // Other Add/Sub/Mul ops
+  if (isa(op)) {
+    result = rewriter.create(loc, lhs, rhs);
+  } else if (isa(op)) {
+    result = rewriter.create(loc, lhs, rhs);
+  } else if (isa(op)) {
+    result = rewriter.create(loc, lhs, rhs);
+  }
+  rewriter.replaceOpWithNewOp(op, outType, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddTensorOp
 //===----------------------------------------------------------------------===//
-
 void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
   patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
-    // The lhs and rhs of the add.tensor op should be 0d tensors for the
-    // canonicalization to be carried out.
-    // `aten.add.tensor(self, other, alpha)` is canonicalized to
-    // `aten.add.int(self, aten.mul.int(other, alpha))`.
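// [Editor's note, not part of the patch] Why "trunc" is constant-folded
// rather than decomposed: truncation and flooring disagree on negative
// quotients, e.g.
//
//   (int64_t)std::trunc(7.0 / -2.0);  // -3, rounds toward zero
//   (int64_t)std::floor(7.0 / -2.0);  // -4, rounds toward -infinity
//
// so "floor" maps onto an existing integer floor-division, while a faithful
// "trunc" would need the abs/floor/sign/mul chain named in the comment above;
// folding the two constant operands directly is simpler.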
- - Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter); - if (!lhs) - return rewriter.notifyMatchFailure(op, "lhs scalar is empyty"); - if (!lhs.getType().isa()) - return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType"); - - Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter); - if (!rhs) - return rewriter.notifyMatchFailure(op, "rhs scalar is empyty"); - if (!rhs.getType().isa()) - return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType"); - - Value mul = rewriter.create(op->getLoc(), rhs, op.alpha()); - Value add = rewriter.create(op->getLoc(), lhs, mul); - rewriter.replaceOpWithNewOp( - op, op.self().getType(), add); - return success(); + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenAddScalarOp +//===----------------------------------------------------------------------===// +void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSubTensorOp +//===----------------------------------------------------------------------===// +void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSubScalarOp +//===----------------------------------------------------------------------===// +void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenMulTensorOp +//===----------------------------------------------------------------------===// +void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenMulScalarOp +//===----------------------------------------------------------------------===// +void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenDivTensorModeOp +//===----------------------------------------------------------------------===// +void AtenDivTensorModeOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); }); } @@ -1479,6 +1650,35 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenAddTOp +//===----------------------------------------------------------------------===// + +void 
AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { + auto lhsListConstruct = op.a().getDefiningOp(); + if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) + return failure(); + + auto rhsListConstruct = op.b().getDefiningOp(); + if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) + return failure(); + + SmallVector concatenatedList; + for (auto a : lhsListConstruct.getOperands()) { + concatenatedList.push_back(a); + } + for (auto b : rhsListConstruct.getOperands()) { + concatenatedList.push_back(b); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), + concatenatedList); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenEqIntListOp //===----------------------------------------------------------------------===// @@ -1574,6 +1774,27 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// PrimListUnpackOp +//===----------------------------------------------------------------------===// + +void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) { + auto torchList = op.operand(); + if (isListPotentiallyMutated(torchList)) { + return failure(); + } + + auto listConstruct = torchList.getDefiningOp(); + if (!listConstruct) + return failure(); + + rewriter.replaceOp(op, listConstruct.elements()); + return success(); + }); +} + static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) { if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) { return isa operands) { static bool isListConstructNotModified(Value torchList) { return llvm::all_of(torchList.getUsers(), [](Operation *op) { - return isa(op); - }); + return isa(op); + }); } OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef operands) { @@ -1853,7 +2074,7 @@ void ShapeCalculateOp::getSuccessorRegions( SmallVectorImpl ®ions) { (void)operands; - if (!index.hasValue()) { + if (!index.has_value()) { // First thing the op does is branch into the shape calculation. regions.emplace_back(&shapeCalculation()); return; @@ -1886,3 +2107,153 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() { return emitOpError("expected number of shapes to match number of results"); return success(); } + +//===----------------------------------------------------------------------===// +// GlobalSlotModuleInitializerOp +//===----------------------------------------------------------------------===// + +LogicalResult GlobalSlotModuleInitializerOp::verify() { + // We centralize all verification of the global slots and the + // InitializeGlobalSlotsOp into here, since it requires processing the whole + // module. + + // TODO: We should really have a `torch.module` and have this initializer be + // a region attached to it. + + ModuleOp module = cast(getOperation()->getParentOp()); + for (auto op : module.getOps()) { + if (op.getOperation() != getOperation()) + return op.emitError("there must be only one global slot initializer"); + } + + // Collect the relevant symbol names we will verify. 
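// [Editor's note, not part of the patch] The verification below enforces an
// exact bijection between the torch.global_slot symbols in the module and
// the slots named by the single initializer terminator: duplicate entries,
// missing initializers, and initializers for non-existent slots each get a
// dedicated diagnostic, and both symbol sets are sorted by string value,
// presumably so the attached notes come out in a deterministic order for
// testing.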
+ DenseSet knownGlobalSlots; + for (auto op : module.getOps()) + knownGlobalSlots.insert(op.sym_nameAttr()); + DenseSet initializedGlobalSlots; + auto initialize = cast(getBody()->getTerminator()); + for (Attribute symName : initialize.slotSymNames()) { + auto wasInserted = initializedGlobalSlots + .insert(symName.cast().getAttr()) + .second; + if (!wasInserted) + return initialize.emitError("duplicate initialization of global slot: ") + << symName; + } + auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) { + return lhs.cast().getValue() < + rhs.cast().getValue(); + }; + auto known = llvm::to_vector(knownGlobalSlots); + llvm::sort(known, lessThanByStringValue); + auto initialized = llvm::to_vector(initializedGlobalSlots); + llvm::sort(initialized, lessThanByStringValue); + + // Check that the global slots in the module are all initialized. + SymbolTable symbolTable(module); + if (initializedGlobalSlots != knownGlobalSlots) { + InFlightDiagnostic diag = initialize.emitOpError( + "must have one initializer for each global slot in the module"); + for (auto knownGlobalSlot : known) { + auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast()); + if (!initializedGlobalSlots.count(knownGlobalSlot)) { + diag.attachNote( + symbolTable.lookup(symName.getAttr()).getLoc()) + .append("missing global slot initializer for ", symName); + } + } + for (auto initializedGlobalSlot : initialized) { + if (!knownGlobalSlots.count(initializedGlobalSlot)) { + diag.attachNote().append( + "unexpected global slot initializer for non-existent global slot ", + FlatSymbolRefAttr::get(initializedGlobalSlot.cast())); + } + } + return diag; + } + + // Check that initial values satisfy type bounds. + for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) { + auto symName = initialize.slotSymNames()[i].cast(); + auto initialValue = initialize.getOperand(i); + auto globalSlotOp = symbolTable.lookup(symName.getValue()); + if (!isValidSubtype(initialValue.getType(), globalSlotOp.typeBound())) { + return initialize.emitOpError().append( + "initial value for global slot ", symName, " has type ", + initialValue.getType(), " which is not within the bound ", + globalSlotOp.typeBound()); + } + } + + auto walkResult = getOperation()->walk([](Operation *op) { + // We only permit a small set of ops in the module initializer. + // These ops are essentially those which can be produced by the IValue + // importer. 
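// [Editor's note, not part of the patch] The check below is the standard
// walk-with-interrupt whitelist idiom: WalkResult::advance() keeps
// traversing, WalkResult::interrupt() aborts the walk, and wasInterrupted()
// maps the abort to verifier failure. Its generic shape, with a hypothetical
// isAllowed predicate:
//
//   auto result = getOperation()->walk([](Operation *op) {
//     return isAllowed(op) ? WalkResult::advance() : WalkResult::interrupt();
//   });
//   if (result.wasInterrupted())
//     return failure();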
+ if (isa(op)) + return WalkResult::advance(); + op->emitOpError() << "is not allowed in a module initializer"; + return WalkResult::interrupt(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// InitializeGlobalSlotsOp +//===----------------------------------------------------------------------===// + +ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if (parser.parseLSquare()) + return failure(); + SmallVector slotSymNames; + while (!succeeded(parser.parseOptionalRSquare())) { + NamedAttrList dummy; + StringAttr slotSymName; + if (parser.parseSymbolName(slotSymName, "dummy", dummy)) + return failure(); + slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName)); + if (parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand initialValue; + if (parser.parseOperand(initialValue)) + return failure(); + Type initialValueType; + if (parser.parseColonType(initialValueType)) + return failure(); + if (parser.parseRParen()) + return failure(); + if (parser.resolveOperand(initialValue, initialValueType, result.operands)) + return failure(); + } + result.addAttribute("slotSymNames", + ArrayAttr::get(parser.getContext(), slotSymNames)); + return success(); +} + +void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict(getOperation()->getAttrs(), + /*elidedAttrs=*/{"slotSymNames"}); + p << " ["; + p.printNewline(); + for (int i = 0, e = getNumOperands(); i < e; ++i) { + p << " " << slotSymNames()[i] << "(" << initialValues()[i] << " : " + << initialValues()[i].getType() << ")"; + p.printNewline(); + } + p << "]"; +} + +LogicalResult InitializeGlobalSlotsOp::verify() { + if (initialValues().size() != slotSymNames().size()) + return emitOpError("expected number of operands to match number of slots"); + return success(); +} diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 10e2008adf18..3c790b58f2b0 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -236,7 +236,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, } int64_t size; auto optionalInt = parser.parseOptionalInteger(size); - if (optionalInt.hasValue()) { + if (optionalInt.has_value()) { if (failed(*optionalInt)) return Type(); sizes.push_back(size); @@ -356,6 +356,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { if (auto floatType = dtype.dyn_cast()) { return dtype; } else if (auto integerType = dtype.dyn_cast()) { + if (integerType.isUnsignedInteger()) { + return IntegerType::get(context, integerType.getWidth(), + IntegerType::Unsigned); + } return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); } diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index bef668f3a289..d9de61f0063d 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -2,9 +2,11 @@ add_mlir_library(TorchMLIRTorchPasses AdjustCallingConventions.cpp DecomposeComplexOps.cpp DropShapeCalculations.cpp + EraseModuleInitializer.cpp Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp + LowerToBackendContract.cpp MaximizeValueSemantics.cpp PrepareForGlobalizeObjectGraph.cpp ReduceOpVariants.cpp @@ -13,7 +15,6 @@ 
add_mlir_library(TorchMLIRTorchPasses ReifyShapeCalculations.cpp ShapeLibrary.cpp SimplifyShapeCalculations.cpp - VerifyConversionToValueSemantics.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6d3dda97651e..a77a1b9cfa60 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -228,6 +229,32 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNarrowOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNarrowOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value start = op.start(); + Value dim = op.dim(); + Value length = op.length(); + + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value startPlusLength = + rewriter.create(loc, one.getType(), start, length); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), op.self(), /*dim=*/dim, /*start=*/start, + /*end=*/startPlusLength, /*step=*/one); + + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenZeroOp : public OpRewritePattern { @@ -683,6 +710,80 @@ class DecomposeAtenTOp : public OpRewritePattern { }; } // namespace +// Decompose aten.roll into aten.slice and aten.cat ops. +// https://pytorch.org/docs/stable/generated/torch.roll.html +namespace { +class DecomposeAtenRollOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRollOp op, + PatternRewriter &rewriter) const override { + SmallVector shifts; + if (!getListConstructElements(op.shifts(), shifts)) + return rewriter.notifyMatchFailure( + op, "unimplemented: shifts not list of Scalar"); + SmallVector dims; + if (!getListConstructElements(op.dims(), dims)) + return rewriter.notifyMatchFailure( + op, "unimplemented: dims not list of Scalar"); + + if (shifts.size() != dims.size()) + return op.emitError("list sizes of shifts and dims are not the same"); + + auto loc = op.getLoc(); + Value constNone = rewriter.create(loc); + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + auto self = op.self(); + auto selfTy = self.getType().cast(); + // roll(input, shift, dim) = cat({ + // slice(input, dim, (dimSize-shift)%dimSize, none), + // slice(input, dim, 0, (dimSize-shift)%dimSize}, dim) + auto imitateRoll = [&](Value input, Value shift, Value dim, + int64_t cstDim) { + Value dimSize = rewriter.create(loc, input, dim); + Value shiftPlus = rewriter.create(loc, dimSize, shift); + Value splitPos = + rewriter.create(loc, shiftPlus, dimSize); + ArrayRef inputShape = selfTy.getSizes(); + SmallVector sizes; + sizes.append(inputShape.begin(), inputShape.end()); + sizes[cstDim] = ShapedType::kDynamicSize; + Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), + selfTy.getDtype()); + Value slice0 = rewriter.create( + loc, sliceTy, input, dim, splitPos, constNone, constOne); + Value slice1 = rewriter.create( + loc, sliceTy, input, dim, constZero, splitPos, constOne); + + Type 
listType = Torch::ListType::get(sliceTy); + Value slices = rewriter.create( + loc, listType, llvm::ArrayRef{slice0, slice1}); + return rewriter.create(loc, op.getType(), slices, dim); + }; + int rank = getTensorRank(self); + if (rank < 0) + return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); + Value output = self; + auto nShifts = shifts.size(); + for (size_t k = 0; k < nShifts; ++k) { + auto dim = dims[k]; + int64_t cstDim = -1; + if (!matchPattern(dim, m_TorchConstantInt(&cstDim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: dim must be constant"); + + cstDim = toPositiveDim(cstDim, rank); + output = imitateRoll(output, shifts[k], dim, cstDim); + } + rewriter.replaceOp(op, output); + return success(); + } +}; +} // namespace + // Decompose aten.repeat into aten.expand and aten.view ops. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html @@ -705,7 +806,8 @@ class DecomposeAtenTOp : public OpRewritePattern { // unsqueezed_sizes += [1, s] // expanded_sizes += [m, s] // reshaped_sizes += [m * s] -// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) +// return +// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) // namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { @@ -754,7 +856,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { assert(leadingRank >= 0 && "leadingRank should greater than 0"); for (size_t i = 0; i < leadingRank; ++i) { insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{repeats[i]}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{repeats[i]}); reshapedSizes.push_back(repeats[i]); } @@ -772,18 +875,20 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { loc, rewriter.getI64IntegerAttr(selfShape[i])); } - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one, dimSize}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{scale, dimSize}); + insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, + ArrayRef{one, dimSize}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{scale, dimSize}); Value scaledSize = rewriter.create(loc, dimSize, scale); reshapedSizes.push_back(scaledSize); } Type dtype = self.getType().cast().getDtype(); - Type unsqueezedType = - ValueTensorType::get(context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype); + Type unsqueezedType = ValueTensorType::get( + context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); + Type expandedType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = @@ -792,8 +897,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { rewriter.create(loc, listType, expandedSizes); Value reshapedDims = rewriter.create(loc, listType, reshapedSizes); - auto reshaped = - rewriter.create(loc, unsqueezedType, op.self(), unsqueezedDims); + auto reshaped = rewriter.create(loc, unsqueezedType, op.self(), + unsqueezedDims); auto expanded = rewriter.create(loc, expandedType, reshaped, expandedDims); @@ -804,6 +909,68 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { }; } // namespace +// Decompose aten.flatten.using_ints into aten.view op. 
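// [Editor's note, not part of the patch] Worked example for the flatten
// decomposition below: flattening dims [start=1, end=2] of a tensor of shape
// (2, 3, 4, 5) builds the size list {size(x, 0), -1, size(x, 3)} = {2, -1, 5},
// and the emitted aten.view yields shape (2, 12, 5); the single -1 lets view
// infer the collapsed extent 3 * 4 = 12. The rank-0 special case views the
// input as shape (1).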
+namespace { +class DecomposeAtenFlattenUsingIntsOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.self(); + MLIRContext *context = op.getContext(); + int64_t rank = getTensorRank(self); + if (rank < 0) + return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); + + int64_t start, end; + if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start)) || + !matchPattern(op.end_dim(), m_TorchConstantInt(&end))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires start and end dims to be constants"); + } + + SmallVector newSizes; + if (rank == 0) { + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + newSizes.push_back(one); + } else { + start = toPositiveDim(start, rank); + end = toPositiveDim(end, rank); + + if (start > end) { + return rewriter.notifyMatchFailure( + op, "expected end dim larger than start dim"); + } + + newSizes.reserve(rank - end + start); + for (int64_t k = 0; k < start; ++k) { + Value dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(k)); + newSizes.push_back( + rewriter.create(loc, self, /*dim=*/dim)); + } + Value flattenDimSize = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + newSizes.push_back(flattenDimSize); + for (int64_t k = end + 1; k < rank; ++k) { + Value dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(k)); + newSizes.push_back( + rewriter.create(loc, self, /*dim=*/dim)); + } + } + Value newSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), newSizes); + rewriter.replaceOpWithNewOp(op, op.getType(), op.self(), + newSizeList); + return success(); + } +}; +} // namespace + // Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { @@ -897,13 +1064,14 @@ class DecomposeAtenConvolutionOverrideableOp }; } // namespace -// Decompose aten.convolution_overrideable to aten.convolution +// Decompose aten._convolution-like to aten.convolution namespace { -class DecomposeAten_ConvolutionOp - : public OpRewritePattern { +template +class DecomposeAten_ConvolutionLikeOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Aten_ConvolutionOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConvolutionLikeOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( @@ -938,6 +1106,26 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv_transpose2d to aten.convolution +namespace { +class DecomposeAtenConvTranspose2dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.input(), op.weight(), op.bias(), + op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue, + op.output_padding(), op.groups()); + + return success(); + } +}; +} // namespace + // Decompose aten.addmm into aten.mm and aten.add.Tensor op. 
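// [Editor's note, not part of the patch] For reference, PyTorch defines
//
//   aten.addmm(self, mat1, mat2, beta, alpha) = beta * self + alpha * (mat1 @ mat2)
//
// which is why the section comment above names aten.mm for the matrix
// product and aten.add.Tensor, whose alpha operand naturally carries one of
// the scalings, for the addition.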
namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { @@ -1009,6 +1197,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.self(); + unsigned inputRank = getTensorRank(input); Value dimList = op.dim(); Value keepDim = op.keepdim(); Value dtype = op.dtype(); @@ -1022,10 +1211,11 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { op, "only floating-point type is supported"); } - auto dimListConstruct = dimList.getDefiningOp(); - if (!dimListConstruct) { + SmallVector dimListElements; + if (!getListConstructElements(dimList, dimListElements) && + !dimList.getType().isa()) { return rewriter.notifyMatchFailure( - op, "expect dimList to be constructed from list construct"); + op, "expected `dim` to be `None` or constructed from list construct"); } // Compute sum along dimensions specified in `dimList`. @@ -1033,12 +1223,18 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { loc, outputType, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. - Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, input, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); + Value productDimSize; + // Case: Reduce along all dims. + if (dimListElements.empty() && inputRank != 0) { + productDimSize = rewriter.create(loc, input); + } else { + productDimSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + for (Value dim : dimListElements) { + Value dimSize = rewriter.create(loc, input, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } } rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, productDimSize); @@ -1117,6 +1313,67 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { }; } // namespace +// grad_output * mask * scale +namespace { +class DecomposeAtenNativeDropoutBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Value maskedGradOutput = rewriter.create( + loc, op.getType(), op.grad_output(), op.mask()); + rewriter.replaceOpWithNewOp(op, op.getType(), + maskedGradOutput, op.scale()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenNativeDropoutOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value input = op.input(); + Value prob = op.p(); + bool train = false; + if (!matchPattern(op.train(), m_TorchConstantBool(&train))) + return rewriter.notifyMatchFailure(op, "train must be a boolean constant"); + + BaseTensorType inputType = input.getType().cast(); + if (!train) { + // TODO(yancey.yx): support inference mode + return op.emitError( + "native_dropout with train=False is not supported yet"); + } + if (!inputType.hasDtype() || !inputType.getDtype().isa()) + return rewriter.notifyMatchFailure( + op, "only floating-point type input is supported for training mode"); + Value noneVal = rewriter.create(loc); + Value floatOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = rewriter.create(loc, floatOne, prob); + Value boolMask = rewriter.create( + loc,
inputType, input, oneMinusP, /*generator=*/noneVal); + Value maskedInput = + rewriter.create(loc, inputType, boolMask, input); + Value output = + rewriter.create(loc, inputType, maskedInput, oneMinusP); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + boolMask = rewriter.create( + loc, op.getResult(1).getType(), boolMask, one); + rewriter.replaceOp(op, {output, boolMask}); + return success(); + } +}; +} // namespace + // Decompose aten.var into: aten.var.dim op. namespace { class DecomposeAtenVarOp : public OpRewritePattern { @@ -1172,6 +1429,64 @@ class DecomposeAtenStdOp : public OpRewritePattern { }; } // namespace +// Softplus(x, beta, threshold) = +// x * beta > threshold ? x : log(1 + exp(x * beta)) / beta +namespace { +class DecomposeAtenSoftplusOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSoftplusOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.self(); + BaseTensorType inputType = input.getType().cast(); + + Value inputTimesBeta = + rewriter.create(loc, inputType, input, op.beta()); + + // out = log1p(exp(input * beta)) / beta + Value exp = rewriter.create(loc, inputType, inputTimesBeta); + Value log1p = rewriter.create(loc, inputType, exp); + Value out = + rewriter.create(loc, inputType, log1p, op.beta()); + + // Select where x * beta > threshold + auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), + rewriter.getI1Type()); + Value condition = rewriter.create( + loc, boolResType, inputTimesBeta, op.threshold()); + + rewriter.replaceOpWithNewOp(op, op.getType(), condition, + input, out); + return success(); + } +}; +} // namespace + +// Decompose aten.std.dim to sqrt(var.dim(x)) +namespace { +class DecomposeAtenStdDimOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenStdDimOp op, + PatternRewriter &rewriter) const override { + Value self = op.self(); + BaseTensorType inputTensorType = self.getType().cast(); + if (!inputTensorType.hasDtype() || + !inputTensorType.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "aten.std.dim expects input tensor of floating-point type"); + } + + Value varDim = + rewriter.create(op->getLoc(), op.getType(), self, + op.dim(), op.unbiased(), op.keepdim()); + rewriter.replaceOpWithNewOp(op, op.getType(), varDim); + return success(); + } +}; +} // namespace + // Hardsigmoid(x) = max(0, min(1, (x+3)/6)) namespace { class DecomposeAtenHardsigmoidOp : public OpRewritePattern { @@ -1275,6 +1590,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, Value input, Value prob, Value &output) { auto inputType = input.getType().cast(); + auto inputDtype = inputType.getDtype(); auto probType = prob.getType().cast(); // Both the `input` and `prob` must be ranked tensors. if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() || @@ -1290,8 +1606,14 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, // Since the `aten.rand_like` op expects float-type operand, create a // float-type tensor with the same shape as that of the `input`. 
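+  // The Bernoulli draw itself is computed below as (rand_like(input) < prob), +  // i.e. an elementwise sample from U[0, 1) compared against the probability.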
+ Type floatDtype = rewriter.getF64Type(); + if (inputDtype.isa() && + inputDtype.cast().getWidth() < 64) { + floatDtype = rewriter.getF32Type(); + } + Value floatTensor = - convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type()); + convertTensorToDtype(rewriter, loc, input, floatDtype); Value none = rewriter.create(loc); Value randomVal = rewriter.create( loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none, @@ -1305,7 +1627,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, // As the `output` is expected to be of the `input` type, convert the boolean // tensor `lessThanP` to a `input` type tensor. - output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype()); + output = convertTensorToDtype(rewriter, loc, lessThanP, inputDtype); return success(); } @@ -1440,6 +1762,187 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNativeLayerNormBackwardOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeLayerNormBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.input().getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputTy.getSizes().size(); + Value normalizedShape = op.normalized_shape(); + SmallVector normalizedShapeSizesTorchInt; + getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); + int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); + SmallVector reduceDimInts(normalizedShapeSizesTorchInt.size()); + SmallVector outerDimInts(axis); + std::iota(reduceDimInts.begin(), reduceDimInts.end(), axis); + std::iota(outerDimInts.begin(), outerDimInts.end(), 0); + auto reducedTy = op.getResult(1).getType(); + auto sizeListType = ListType::get(IntType::get(context)); + + auto fromIntsToList = [&](ArrayRef dimInts) -> Value { + SmallVector dimVals; + dimVals.reserve(dimInts.size()); + std::transform(dimInts.begin(), dimInts.end(), + std::back_inserter(dimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value dimList = + rewriter.create(loc, sizeListType, dimVals); + return dimList; + }; + // build reduce & outer dims + auto reduceDimList = fromIntsToList(reduceDimInts); + auto outerDimList = fromIntsToList(outerDimInts); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value cstFalse = rewriter.create(loc, false); + Value none = rewriter.create(loc); + + // x_hat + Value inputSubMean = rewriter.create( + loc, inputTy, op.input(), op.mean(), one); + Value xHat = + rewriter.create(loc, inputTy, inputSubMean, op.rstd()); + + // grad(x_hat) + Value xHatGrad = op.grad_out(); + Value weight = op.weight(); + Value wGrad = none; + if (!weight.getType().isa()) { + xHatGrad = rewriter.create(loc, xHatGrad.getType(), + xHatGrad, weight); + wGrad = rewriter.create( + loc, weight.getType(), + rewriter.create(loc, inputTy, op.grad_out(), xHat), + outerDimList, cstFalse, none); + } + Value bias = op.bias(); + Value bGrad = none; + if (!bias.getType().isa()) { + bGrad = rewriter.create( + loc, bias.getType(), op.grad_out(), outerDimList, cstFalse, none); + } + + Value cstTrue = rewriter.create(loc, true); + // grad(mean) + Value meanGrad = rewriter.create( + loc, op.mean().getType(), xHatGrad, reduceDimList, cstTrue, 
none); + // grad(rstd) + Value xHatGradMulXHat = + rewriter.create(loc, inputTy, xHatGrad, xHat); + Value rstdGrad0 = rewriter.create( + loc, op.rstd().getType(), xHatGradMulXHat, reduceDimList, cstTrue, + none); + Value rstdGrad1 = + rewriter.create(loc, inputTy, xHat, rstdGrad0); + + // grad(input) + Value inner = + rewriter.create(loc, inputTy, xHatGrad, meanGrad, one); + inner = + rewriter.create(loc, inputTy, inner, rstdGrad1, one); + Value gradInput = + rewriter.create(loc, inputTy, op.rstd(), inner); + + rewriter.replaceOp(op, {gradInput, wGrad, bGrad}); + + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenNativeLayerNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeLayerNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.input().getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputTy.getSizes().size(); + Value normalizedShape = op.normalized_shape(); + SmallVector normalizedShapeSizesTorchInt; + getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); + int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); + SmallVector reduceDimInts(normalizedShapeSizesTorchInt.size()); + std::iota(reduceDimInts.begin(), reduceDimInts.end(), axis); + auto reducedTy = op.getResult(1).getType(); + auto sizeListType = ListType::get(IntType::get(context)); + + // build reduce dims + SmallVector reduceDimVals; + reduceDimVals.reserve(reduceDimInts.size()); + std::transform(reduceDimInts.begin(), reduceDimInts.end(), + std::back_inserter(reduceDimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.input(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = rewriter.create(loc, inputTy, inputMean, op.input()); + Value inputZeroMean = rewriter.create( + loc, inputTy, op.input(), inputMeanExpanded, one); + // var(x) = mean((x - mean(x))^2) + Value inputZeroMeanSquare = rewriter.create( + loc, inputTy, inputZeroMean, inputZeroMean); + Value inputVar = rewriter.create( + loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none); + + // rsqrt(var(x) + eps) + Value inputVarPlusEps = rewriter.create( + loc, reducedTy, inputVar, op.eps(), one); + Value inputRsqrtVar = + rewriter.create(loc, reducedTy, inputVarPlusEps); + + // (x - mean(x)) * rsqrt(var(x) + eps) + Value inputRsqrtVarExpanded = + rewriter.create(loc, inputTy, inputRsqrtVar, op.input()); + Value inputNormalized = rewriter.create( + loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); + Value out = rewriter.create(loc, op.getResult(0).getType(), + inputNormalized); + + Value weight = op.weight(); + Value bias = op.bias(); + if (!weight.getType().isa()) { + out = rewriter.create(loc, out.getType(), out, weight); + } + if (!bias.getType().isa()) { + out = + rewriter.create(loc, out.getType(), out, bias, one); + } + rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.empty_like` 
op into `aten.size` and `aten.empty` ops. class DecomposeAtenEmptyLikeOp : public OpRewritePattern { @@ -1900,6 +2403,25 @@ class DecomposeAtenToDtypeLayoutOp }; } // namespace +namespace { +// Decompose `aten.to.device` op into `aten.to.dtype` op. +class DecomposeAtenToDeviceOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenToDeviceOp op, + PatternRewriter &rewriter) const override { + + // Device information isn't relevant to torch-mlir, so we can drop that info + // here. + rewriter.replaceOpWithNewOp(op, op.getType(), op.self(), + op.dtype(), op.non_blocking(), + op.copy(), op.memory_format()); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op. // @@ -2071,8 +2593,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFloorDivideOp op, PatternRewriter &rewriter) const override { + // https://pytorch.org/docs/stable/generated/torch.floor_divide.html + // PyTorch aten.floor_divide is a misnomer because it actually rounds + // the quotient towards zero instead of taking its floor. Value cstStrFloor = - rewriter.create(op.getLoc(), "floor"); + rewriter.create(op.getLoc(), "trunc"); rewriter.replaceOpWithNewOp( op, op.getType(), op.self(), op.other(), /*rounding_mode=*/cstStrFloor); @@ -2104,6 +2629,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { }; } // namespace +template +static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, + bool unbiased, int64_t correction) { + Location loc = op.getLoc(); + Value self = op.self(); + Value dimList = op.dim(); + Value keepDim = op.keepdim(); + BaseTensorType inputTensorTy = self.getType().cast(); + Type outputType = op.getType(); + BaseTensorType outputTensorType = outputType.cast(); + Type newOutputType = outputTensorType.getWithSizesAndDtype( + outputTensorType.getSizes(), rewriter.getF64Type()); + if (!inputTensorTy.hasDtype() || + !inputTensorTy.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "support floating-point type input only"); + } + + // Upcasting the input tensor to `F64` dtype for higher precision during the + // computation of the result. 
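+  // E.g. accumulating mean((x - mean(x))^2) in f32 can lose low-order bits to +  // cancellation; doing the intermediate arithmetic in f64 and converting the +  // result back to the output dtype at the end keeps the computed variance +  // closer to the true value.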
+ if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { + self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); + inputTensorTy = self.getType().cast(); + } + + unsigned inputRank = getTensorRank(self); + SmallVector dimListElements; + bool isNoneOrEmpty = true; + if (!dimList.getType().template isa()) { + if (!getListConstructElements(dimList, dimListElements)) + return rewriter.notifyMatchFailure( + op, "expect dimList to be constructed from list construct"); + if (!dimListElements.empty() || inputRank == 0) + isNoneOrEmpty = false; + } + if (isNoneOrEmpty) { + for (unsigned i = 0; i < inputRank; i++) + dimListElements.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimListElements); + } + Type meanDimResultType = inputTensorTy; + for (unsigned i = 0; i < dimListElements.size(); i++) + meanDimResultType = computeReductionType( + rewriter, op, meanDimResultType.cast(), + dimListElements[i], + /*keepDim=*/true); + + Value constantNone = rewriter.create(loc); + Value constantTrue = rewriter.create(loc, true); + Value meanAlongDims = rewriter.create( + loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, + /*dtype=*/constantNone); + Value subMean = + createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); + Value square = rewriter.create(loc, inputTensorTy, subMean); + + if (!unbiased) { + Value result = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + result = convertTensorToDtype(rewriter, loc, result, + outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); + } + // Divide the square sum by productDimSize - correction. + Value squareSum = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + + // `productDimSize` is product of sizes of dimensions to be reduced. + Value constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value productDimSize = constantOne; + for (Value dim : dimListElements) { + Value dimSize = rewriter.create(loc, self, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } + Value cstCorrection = rewriter.create( + loc, rewriter.getI64IntegerAttr(correction)); + // The `correction` value should be less than or equal to `productDimSize + + // 1`. 
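+  // E.g. reducing over dims of sizes {4, 5} gives productDimSize = 20; with +  // the default unbiased correction of 1, the sum of squared deviations is +  // divided by 20 - 1 = 19 (Bessel's correction).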
+ Value productDimSizePlusOne = + rewriter.create(loc, productDimSize, constantOne); + Value cond = + rewriter.create(loc, productDimSizePlusOne, cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + Value productDimSizeSubCorrection = + rewriter.create(loc, productDimSize, cstCorrection); + Value result = rewriter.create(loc, newOutputType, squareSum, + productDimSizeSubCorrection); + result = + convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); +} + // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) @@ -2117,70 +2743,44 @@ class DecomposeAtenVarDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarDimOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.self(); - Value dimList = op.dim(); - Value keepDim = op.keepdim(); - Type outputType = op.getType(); - BaseTensorType inputTensorTy = self.getType().cast(); - if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { - return rewriter.notifyMatchFailure(op, - "support floating type input only"); - } - - auto dimListConstruct = dimList.getDefiningOp(); - if (!dimListConstruct) { - return rewriter.notifyMatchFailure( - op, "expect dimList to be constructed from list construct"); - } - bool unbiased; if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) { return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } + int64_t correction = unbiased ? 1 : 0; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); + return success(); + } +}; +} // namespace - SmallVector dimListElements = dimListConstruct.elements(); - Type meanDimResultType = inputTensorTy; - for (unsigned i = 0; i < dimListElements.size(); i++) - meanDimResultType = computeReductionType( - rewriter, op, meanDimResultType.cast(), - dimListElements[i], - /*keepDim=*/true); - - Value constantNone = rewriter.create(loc); - Value constantTrue = rewriter.create(loc, true); - Value meanAlongDims = rewriter.create( - loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, - /*dtype=*/constantNone); - Value subMean = - createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); - Value square = rewriter.create(loc, inputTensorTy, subMean); - if (unbiased) { - // Bessel’s correction is used. Divide the square sum by - // productDimSize-1. - Value squareSum = rewriter.create( - loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone); - - // `productDimSize` is product of sizes of dimensions to be reduced. 
- Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, self, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); - } - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value productDimSizeSubOne = - rewriter.create(loc, productDimSize, constantOne); - rewriter.replaceOpWithNewOp(op, outputType, squareSum, - productDimSizeSubOne); +// Decompose aten.var(x, dims) into: +// sub = aten.sub(x, aten.mean(x, dims)) +// square = aten.square(sub) +// out = aten.sum(square, dims) / (productDimSize - correction) +namespace { +class DecomposeAtenVarCorrectionOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenVarCorrectionOp op, + PatternRewriter &rewriter) const override { + int64_t correction; + if (!op.correction().getType().isa()) { + if (!matchPattern(op.correction(), m_TorchConstantInt(&correction))) + return rewriter.notifyMatchFailure( + op, "Only support constant int correction for aten.var"); } else { - rewriter.replaceOpWithNewOp( - op, outputType, square, dimList, keepDim, /*dtype=*/constantNone); + // The default value in case of `correction` being None is 1. + correction = 1; } + bool unbiased = correction == 0 ? false : true; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; @@ -2235,9 +2835,48 @@ class DecomposeAtenSelectScatterOp }; } // namespace +namespace { +class DecomposeAten_EmbeddingBagOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op, + PatternRewriter &rewriter) const override { + Value weight = op.weight(); + Value indices = op.indices(); + Value offsets = op.offsets(); + Value scaleGradByFreq = op.scale_grad_by_freq(); + Value mode = op.mode(); + Value sparse = op.sparse(); + Value perSampleWeights = op.per_sample_weights(); + Value includeLastOffset = op.include_last_offset(); + Value paddingIdx = op.padding_idx(); + + auto resultType0 = op->getResult(0).getType(); + auto resultType1 = op->getResult(1).getType(); + auto resultType2 = op->getResult(2).getType(); + auto resultType3 = op->getResult(3).getType(); + + mlir::TypeRange returnTypes{resultType0, resultType1, resultType2, + resultType3}; + + rewriter.replaceOpWithNewOp( + op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, + sparse, perSampleWeights, includeLastOffset, paddingIdx); + + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { +public: + DecomposeComplexOpsPass() = default; + DecomposeComplexOpsPass(ArrayRef legalOps) { + this->legalOps = legalOps; + } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -2260,10 +2899,14 @@ class DecomposeComplexOpsPass patterns.add>( context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -2306,14 +2949,23 @@ class DecomposeComplexOpsPass target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + 
target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); - target.addIllegalOp(); - patterns.add(context); + target.addIllegalOp(); + patterns.add, + DecomposeAten_ConvolutionLikeOp>( + context); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -2344,6 +2996,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add>( @@ -2365,7 +3019,11 @@ class DecomposeComplexOpsPass patterns.add(context); target.addIllegalOp(); patterns.add(context); + patterns.add(context); + target.addIllegalOp(); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); patterns.add(context); @@ -2374,6 +3032,8 @@ class DecomposeComplexOpsPass patterns.add(context); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -2390,6 +3050,18 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + + for (std::string opName : legalOps) { + target.addLegalOp(OperationName(opName, context)); + } if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -2398,7 +3070,9 @@ class DecomposeComplexOpsPass } }; } // namespace + std::unique_ptr> -mlir::torch::Torch::createDecomposeComplexOpsPass() { - return std::make_unique(); +mlir::torch::Torch::createDecomposeComplexOpsPass( + ArrayRef legalOps) { + return std::make_unique(legalOps); } diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp new file mode 100644 index 000000000000..450d84b22ed3 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -0,0 +1,47 @@ +//===- EraseModuleInitializer.cpp --------------------------------*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { +class EraseModuleInitializerPass + : public EraseModuleInitializerBase { + void runOnOperation() override { + for (auto initializer : + getOperation().getOps()) { + auto initialize = + cast(initializer.getBody()->getTerminator()); + if (initialize.getNumOperands() == 0) { + initializer.erase(); + } + // The verifier ensures there is only one GlobalSlotModuleInitializerOp. + break; + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createEraseModuleInitializerPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index 8ca07604db82..43eba2dc5650 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -48,12 +48,6 @@ static FailureOr findRootNnModule(ModuleOp module) { return rootNnModule; } -static bool hasMeaningfulObjectIdentity(Type type) { - return !type.isa(); -} - //===----------------------------------------------------------------------===// // Object graph recursive traversal. //===----------------------------------------------------------------------===// @@ -100,6 +94,9 @@ class ObjectGraphInfo { assert(it != slotToGlobalSlot.end() && "didn't create global slot"); return it->second; } + llvm::MapVector &getGlobalSlotInitialValues() { + return globalSlotInitialValues; + } private: LogicalResult collectUsedSlots() { @@ -187,8 +184,7 @@ class ObjectGraphInfo { assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end()); slotToGlobalSlot[slot] = globalSlot; slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()}; - if (failed(populateGlobalSlotInitializer(globalSlot, slot))) - return failure(); + globalSlotInitialValues[globalSlot.sym_nameAttr()] = slot.value(); } nameStack.pop_back(); } @@ -201,44 +197,6 @@ class ObjectGraphInfo { } return success(); } - LogicalResult populateGlobalSlotInitializer(GlobalSlotOp globalSlot, - SlotOp slot) { - OpBuilder builder(globalSlot.getContext()); - builder.createBlock(&globalSlot.getRegion()); - - SmallPtrSet needToClone; - Value initialValue = slot.value(); - SmallVector worklist = {initialValue.getDefiningOp()}; - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - if (!needToClone.insert(op).second) - continue; - for (Value operand : op->getOperands()) { - if (auto def = operand.getDefiningOp()) - worklist.push_back(def); - } - } - worklist.assign(needToClone.begin(), needToClone.end()); - llvm::sort(worklist, [](Operation *lhs, Operation *rhs) { - return lhs->isBeforeInBlock(rhs); - }); - BlockAndValueMapping mapping; - for (Operation *op : worklist) { - builder.clone(*op, mapping); - for (Value result : op->getResults()) { - if (!hasMeaningfulObjectIdentity(result.getType())) - continue; - if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result) - .second) { - return
op->emitError() << "potentially-aliased value used to " - "initialize multiple slots"; - } - } - } - builder.create(globalSlot->getLoc(), - mapping.lookup(initialValue)); - return success(); - } // Builder for creating GlobalSlotOp's in the module. OpBuilder globalSlotBuilder; // Symbol table for the module. @@ -262,16 +220,50 @@ class ObjectGraphInfo { DenseMap, LinkageInfo> funcLinkageInfo; // The corresponding GlobalSlotOp for each SlotOp in the program. DenseMap slotToGlobalSlot; - // A set of values that we have copied into torch.global_slot initializers, - // which cannot be used in multiple initializers because their object - // identity is important. - DenseSet objectsWithIdentityAlreadyCopiedIntoInitializers; + // The initializing value for each GlobalSlotOp. + // This is a MapVector to keep the order deterministic. + llvm::MapVector globalSlotInitialValues; // Used to keep track of all the used torch slots so that the restrictions can // be applied to those slots only. DenseSet usedSlots; }; } // namespace +LogicalResult +createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable, + ObjectGraphInfo &objectGraphInfo) { + auto builder = OpBuilder::atBlockBegin(module.getBody()); + auto moduleInitializer = + builder.create(module.getLoc()); + Block *body = builder.createBlock(&moduleInitializer.initializer()); + builder.setInsertionPointToEnd(body); + SmallVector opsToMove; + for (Operation &op : *module.getBody()) { + if (isa(op)) + continue; + opsToMove.push_back(&op); + } + BlockAndValueMapping mapping; + for (Operation *op : opsToMove) { + // The ops are used by `torch.slot` ops in the enclosing module. + // Cloning avoids needing to handle those uses specially. + builder.clone(*op, mapping); + } + SmallVector slotSymNames; + SmallVector initialValues; + for (auto &kv : objectGraphInfo.getGlobalSlotInitialValues()) { + StringAttr symName = kv.first; + Value initializer = kv.second; + slotSymNames.push_back(FlatSymbolRefAttr::get(symName)); + initialValues.push_back(mapping.lookup(initializer)); + } + builder.create( + moduleInitializer.getLoc(), + ArrayAttr::get(module.getContext(), slotSymNames), initialValues); + return success(); +} + //===----------------------------------------------------------------------===// // Monomorphization. //===----------------------------------------------------------------------===// @@ -464,26 +456,17 @@ static LogicalResult verifyNnModuleValueUses(Value value) { // Verify that `func` conforms to the subset of allowable method bodies // that we can convert. static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) { - // TODO: Investingate why WalkResult::interrupt() doesn't propagate properly. 
- LogicalResult ret = success(); - func.walk([&](Block *block) { - for (Value arg : block->getArguments()) { - if (failed(verifyNnModuleValueUses(arg))) { - ret = failure(); + auto walkResult = func.walk([&](Block *block) { + for (Value arg : block->getArguments()) + if (failed(verifyNnModuleValueUses(arg))) return WalkResult::interrupt(); - } - } - for (Operation &op : *block) { - for (Value result : op.getResults()) { - if (failed(verifyNnModuleValueUses(result))) { - ret = failure(); + for (Operation &op : *block) + for (Value result : op.getResults()) + if (failed(verifyNnModuleValueUses(result))) return WalkResult::interrupt(); - } - } - } return WalkResult::advance(); }); - return ret; + return success(!walkResult.wasInterrupted()); } static LogicalResult @@ -605,7 +588,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { instances[classType].push_back(nnModule); } - // Step 2: Verify all functions are suitable to be analyzed by our later code. + // Step 2: Create the torch.global_slot.module_initializer op. + + if (failed(createGlobalSlotModuleInitializer(module, symbolTable, + objectGraphInfo))) + return failure(); + + // Step 3: Verify all functions are suitable to be analyzed by our later code. // This eliminates special handling / error code later. // // This is important, because in principle, we can perform arbitrarily complex @@ -617,7 +606,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { return failure(); } - // Step 3: Calculate the set of monomorphized functions that need to be + // Step 4: Calculate the set of monomorphized functions that need to be // created. For each call that passes !torch.nn.Module to a function, we need // to create a specialized version of that function just for that instance (or // combination of instances in the case of multiple arguments). @@ -642,7 +631,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { return failure(); } - // Step 4: Clone/rewrite functions to implement the necessary + // Step 5: Clone/rewrite functions to implement the necessary // monomorphizations. DenseMap newFuncs; int uniquifier = 0; @@ -657,7 +646,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { monomorphization.argInstances[0].instance.getDefiningOp(), monomorphization.func); } - if (linkageInfo.hasValue()) { + if (linkageInfo.has_value()) { // It's a method. newFunc.setVisibility(linkageInfo->isPrivate ? SymbolTable::Visibility::Private @@ -681,13 +670,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { return failure(); } - // Step 5: Clean up object graph. + // Step 6: Clean up object graph. DenseSet liveFuncs; for (auto &kv : newFuncs) { liveFuncs.insert(kv.second); } for (auto &op : llvm::make_early_inc_range(module.getOps())) { - if (isa(&op)) + if (isa(&op)) continue; if (auto func = dyn_cast(op)) { if (liveFuncs.contains(func)) diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 64a117c523d9..a166068a3f3c 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -6,83 +6,427 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// +// +// This file implements an optimistic dataflow analysis that proves that values +// used in global slot initializers are "safe" (see definition below). This +// analysis allows us to inline global slot initializers. 
+// +// One thing to note is that this inlining (as with all inlining) can create +// duplicate ops. That is usually not a problem, except for certain large +// tensor literals. We rely on later CSE passes to deduplicate those literals. +// +// For debugging this pass an effort has been made for +// `-debug-only=dataflow` and `-debug-only=torch-inline-global-slots` to give a +// good experience. When debugging this pass, it is recommended to start with +// `-debug-only=torch-inline-global-slots` to find values that are marked +// unsafe unexpectedly and then `-debug-only=dataflow` to find why. +// +//===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "torch-inline-global-slots" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +/// A program point representing a symbol. +/// +/// In principle we could use the `Operation *` program point of the Symbol op, +/// but that just adds a layer of indirection through a symbol table for the +/// purpose of this analysis. +/// +/// This is easier because we only support FlatSymbolRefAttr's in Torch-MLIR in +/// a single module. If we had to support complex nested symbol references, we +/// would probably want to go through the effort to indirect through the symbol +/// tables to make things clearer. +class FlatSymbolRefProgramPoint + : public GenericProgramPointBase { +public: + using Base::Base; + void print(raw_ostream &os) const override { + os << "FlatSymbolRefProgramPoint(" << getValue() << ")"; + } + Location getLoc() const override { + return UnknownLoc::get(getValue().getContext()); + } +}; + +static bool isTypeTriviallySafe(Type type) { + return type.isa(); +} + +static bool isUseTreatedWithValueSemantics(OpOperand &use) { + Operation *op = use.getOwner(); + // If the op unconditionally has value semantics, then the use has value + // semantics. + if (op->hasTrait()) + return true; + // The condition of the torch.prim.if op is treated with value semantics. + if (isa(op) && use.getOperandNumber() == 0) + return true; + // TODO: Generalize the HasValueSemantics trait to support + // operand/result-granularity. + return false; +} + +/// State tracking if an IR construct is "safe". +/// +/// This state is tracked on Value's and also on global slots (via a +/// FlatSymbolRefProgramPoint). +/// +/// In this context, "safe" means that the object is safe to inline. +/// This covers a few concepts +/// - the value cannot be mutated by the program +/// - the value cannot be potentially aliased, with that alias itself being +/// unsafe +class InlineGlobalSlotsAnalysisState : public AnalysisState { +public: + InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { + setSafe(); + } + + bool isUninitialized() const override { + // We are an optimistic analysis, so we are always default initialized to + // the optimistic "assumed safe" state. 
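+    // Returning false tells the dataflow framework that this state never +    // needs separate initialization before propagation begins.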
+ return false; + } + + void print(raw_ostream &os) const override { + os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? "safe" : "unsafe") + << ")"; + } + + /// Helper for setting the state with the correct ChangeResult. + ChangeResult setSafe(bool newIsSafe = true) { + // As an optimistic analysis, once we prove that a value is unsafe, nothing + // can prove that it is safe again. This is the monotonicity property of + // the dataflow analysis that guarantees that we reach a fixed-point. + // If that property doesn't hold, then there is a bug in the analysis. + assert(!(isSafe == false && newIsSafe == true) && "non-monotonic update"); + if (isSafe == newIsSafe) + return ChangeResult::NoChange; + isSafe = newIsSafe; + return ChangeResult::Change; + } + + /// Helper for updating the state with the correct ChangeResult based on the + /// safety of a use. + ChangeResult + incorporateSafetyOfUse(const InlineGlobalSlotsAnalysisState *useState) { + // The use is safe, so no need to change anything. + if (useState->isSafe) + return ChangeResult::NoChange; + return setSafe(false); + } + + /// This is an optimistic analysis. We start assuming everything is safe. + bool isSafe = true; +}; + +class InlineGlobalSlotsAnalysis : public DataFlowAnalysis { +public: + InlineGlobalSlotsAnalysis(DataFlowSolver &solver); + LogicalResult initialize(Operation *top) override; + LogicalResult visit(ProgramPoint point) override; + +private: + /// The local transfer function determining the safety of `value`. + bool isValueSafeTransferFunction(Value value); + /// The InitializeGlobalSlotsOp of the current module we are analyzing. + /// + /// This is used to propagate the analysis from globals into the module + /// initializer. + InitializeGlobalSlotsOp initializeGlobalSlotsOp; +}; + +InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver) + : DataFlowAnalysis(solver) { + registerPointKind(); +} + +LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { + auto walkResult = top->walk([this](Operation *op) { + if (auto globalSlot = dyn_cast(op)) { + auto *state = getOrCreate( + getProgramPoint( + FlatSymbolRefAttr::get(globalSlot.sym_nameAttr()))); + propagateIfChanged(state, + state->setSafe(globalSlot.getVisibility() != + SymbolTable::Visibility::Public)); + } + if (auto globalSlotSet = dyn_cast(op)) { + auto *state = getOrCreate( + getProgramPoint(globalSlotSet.slotAttr())); + propagateIfChanged(state, state->setSafe(false)); + } + // Save the InitializeGlobalSlotsOp for later reference + if (auto initialize = dyn_cast(op)) { + initializeGlobalSlotsOp = initialize; + } + for (Value result : op->getResults()) { + if (failed(visit(result))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + return success(); +} + +LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { + if (Value value = point.dyn_cast()) { + bool isSafe = isValueSafeTransferFunction(value); + auto *state = getOrCreate(value); + propagateIfChanged(state, state->setSafe(isSafe)); + + // Handle GlobalSlotGetOp's.
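+    // A torch.global_slot.get produces a new value aliasing the slot's +    // contents, so a slot is only safe if every value produced by such a get +    // is itself safe; the code below adds that dependency edge from the get's +    // result back to the slot's FlatSymbolRefProgramPoint.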
+ if (auto opResult = value.dyn_cast()) { + if (auto globalSlotGet = + dyn_cast(opResult.getOwner())) { + auto *flatSymbolRefPoint = getProgramPoint( + globalSlotGet.slotAttr()); + auto *valueState = getOrCreateFor( + flatSymbolRefPoint, globalSlotGet.result()); + auto *globalState = + getOrCreate(flatSymbolRefPoint); + propagateIfChanged(globalState, + globalState->incorporateSafetyOfUse(valueState)); + } + } + + return success(); + } + if (auto *genericProgramPoint = point.dyn_cast()) { + if (auto *flatSymbolRefPoint = + dyn_cast(genericProgramPoint)) { + if (initializeGlobalSlotsOp) { + auto it = + llvm::find(initializeGlobalSlotsOp.slotSymNames(), + static_cast(flatSymbolRefPoint->getValue())); + Value value = initializeGlobalSlotsOp->getOperand( + std::distance(initializeGlobalSlotsOp.slotSymNames().begin(), it)); + auto *flatSymbolRefState = + getOrCreateFor(value, + flatSymbolRefPoint); + auto *valueState = getOrCreate(value); + propagateIfChanged(valueState, + valueState->setSafe(flatSymbolRefState->isSafe)); + } + return success(); + } + } + LLVM_DEBUG( + { llvm::dbgs() << "visit failing because of: " << point << "\n"; }); + return failure(); +} + +// This is only a member function to access protected get* functions. +bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) { + if (isTypeTriviallySafe(value.getType())) + return true; + for (OpOperand &use : value.getUses()) { + Operation *op = use.getOwner(); + + if (isUseTreatedWithValueSemantics(use)) + continue; + // If the op is read-only and all results are safe, then this value is + // safe. This covers, for example, view-like ops that create aliases. + if ((op->hasTrait() || + MemoryEffectOpInterface::hasNoEffect(op)) && + llvm::all_of(op->getResults(), [&](Value result) { + auto *state = + getOrCreateFor(value, result); + return state->isSafe; + })) + continue; + if (auto initialize = dyn_cast(op)) { + auto symName = initialize.slotSymNames()[use.getOperandNumber()] + .cast(); + auto *state = getOrCreateFor( + value, getProgramPoint(symName)); + if (state->isSafe) + continue; + } + // We may not create all the dependency edges, but that is ok since at + // this point we have already reached the fixed-point. + return false; + } + return true; +} + +SmallVector getBackwardSliceIncludingRoot(Value initialValue) { + SetVector sliceSet; + getBackwardSlice(initialValue, &sliceSet); + SmallVector slice; + llvm::append_range(slice, sliceSet); + slice.push_back(initialValue.getDefiningOp()); + return slice; +} + +static bool isInitialValueTransitivelySafeToInline(Value initialValue, + DataFlowSolver &solver) { + SmallVector slice = getBackwardSliceIncludingRoot(initialValue); + for (Operation *op : slice) { + for (auto result : op->getResults()) { + auto *state = solver.lookupState(result); + if (!state->isSafe) { + return false; + } + } + } + return true; +} + namespace { class InlineGlobalSlotsPass : public InlineGlobalSlotsBase { + void runOnOperation() override { + ModuleOp module = getOperation(); - SymbolTable symbolTable(module); - auto uses = SymbolTable::getSymbolUses(&module.getBodyRegion()); - if (!uses) { - module.emitError() << "cannot analyze symbol uses"; + DataFlowSolver solver; + solver.load(); + if (failed(solver.initializeAndRun(module))) return signalPassFailure(); - } - // Find all the global slots potentially written from within the module. - // (we handle the case of non-private symbols later). 
- DenseSet potentiallyWrittenGlobalSlots; - for (const SymbolTable::SymbolUse &use : *uses) { - auto flatSymbolRef = use.getSymbolRef().dyn_cast(); - if (!flatSymbolRef) { - use.getUser()->emitError() << "unimplemented: nested SymbolRef's"; - return signalPassFailure(); - } - auto globalSlot = - symbolTable.lookup(flatSymbolRef.getValue()); - if (!globalSlot) - continue; - if (isa(use.getUser())) - continue; + LLVM_DEBUG({ + module->walk([&](Operation *op) { + if (auto globalSlot = dyn_cast(op)) { + auto *state = solver.lookupState( + solver.getProgramPoint( + FlatSymbolRefAttr::get(globalSlot.sym_nameAttr()))); + state->print(llvm::dbgs()); + llvm::dbgs() << ": " + << FlatSymbolRefAttr::get(globalSlot.sym_nameAttr()) + << "\n"; + return; + } + if (op->getNumResults() != 1) + return; + auto *state = solver.lookupState( + op->getResult(0)); + state->print(llvm::dbgs()); + llvm::dbgs() << ": "; + op->dump(); + }); + }); - potentiallyWrittenGlobalSlots.insert(globalSlot); + Torch::InitializeGlobalSlotsOp initialize; + // TODO: Have a torch.module with an optional initializer region to make + // this tighter. + for (auto moduleInitializer : + module.getOps()) { + initialize = cast( + moduleInitializer.getBody()->getTerminator()); + } + if (!initialize) { + return; } - DenseSet toErase; - // Inline all the global slots that are not potentially written. - for (const SymbolTable::SymbolUse &use : *uses) { - auto flatSymbolRef = use.getSymbolRef().cast(); - auto globalSlot = - symbolTable.lookup(flatSymbolRef.getValue()); - if (!globalSlot) - continue; - // And external user might write to the global slot. - if (!globalSlot.isPrivate()) - continue; - // An internal user exists which might write to the global slot. - if (potentiallyWrittenGlobalSlots.contains(globalSlot)) + DenseSet safeToInline; + for (int i = 0, e = initialize->getNumOperands(); i != e; i++) { + auto slotSymName = initialize.slotSymNames()[i].cast(); + Value operand = initialize.getOperand(i); + auto symbolRefPoint = solver.getProgramPoint( + initialize.slotSymNames()[i].cast()); + auto *state = + solver.lookupState(symbolRefPoint); + // We roll the analysis of whether a slot is set or public into the + // main dataflow analysis, so we need to check the slot's + // FlatSymbolRefProgramPoint itself to see if it is safe to inline. + // For example, a public !torch.int is not safe to inline, even though + // it is a value-semantic type and so the actual initializer value + // itself is conceptually safe to inline. + if (!state->isSafe) { continue; - auto globalSlotGet = cast(use.getUser()); - OpBuilder builder(globalSlotGet); - BlockAndValueMapping mapper; - for (Operation &op : globalSlot.getBody()->without_terminator()) - builder.clone(op, mapper); - Value cloned = mapper.lookup( - cast(globalSlot.getBody()->getTerminator()) - .getOperand()); - globalSlotGet.replaceAllUsesWith(cloned); - toErase.insert(globalSlotGet); - toErase.insert(globalSlot); + } + // Check to see if the initializing value is safe to inline. + // This requires a transitive check of all subobjects. + // TODO: This would really be more logical to do as a forward dataflow + // analysis on the whole module initializer rather than doing the + // transitive check backward for each initial value. But it is just + // too much boilerplate to write that with the dataflow framework and we + // generally don't expect long transitive chains of values here -- most + // initial values are just single tensor literals.
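+      // E.g. a slot whose initial value is a single torch.vtensor.literal is +      // inlinable only if that literal and every op in its backward slice have +      // been proven safe by the analysis.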
+ if (isInitialValueTransitivelySafeToInline(operand, solver)) { + safeToInline.insert(slotSymName); + } } + SymbolTable symbolTable(module); + DenseSet toErase; + module.walk([&](Torch::GlobalSlotGetOp op) { + if (!safeToInline.count(op.slotAttr())) + return; + // TODO: Make this more ergonomic. + auto it = llvm::find(initialize.slotSymNames(), op.slotAttr()); + Value initialValue = initialize.getOperand( + std::distance(initialize.slotSymNames().begin(), it)); + // It seems inefficient to get a backward slice again here, but we are + // going to be cloning the whole slice anyway, so it doesn't seem like a + // big deal. + SmallVector slice = + getBackwardSliceIncludingRoot(initialValue); + BlockAndValueMapping mapping; + OpBuilder builder(op); + for (Operation *opInSlice : slice) + builder.clone(*opInSlice, mapping); + auto inlinedInitialValue = mapping.lookup(initialValue); + inlinedInitialValue = Torch::adjustStaticInformation( + builder, op.getLoc(), inlinedInitialValue, op.getType(), + /*userAllowsRefinement=*/false); + op.replaceAllUsesWith(inlinedInitialValue); + toErase.insert(op); + }); + + // Clean up after the transform. + + // Erase any pending ops. for (Operation *op : toErase) op->erase(); + // Erase any global slots that we inlined. + // This could be left to SymbolDCE but it's not hard to do here. + for (FlatSymbolRefAttr symName : + llvm::map_range(safeToInline, [](Attribute attr) { + return attr.cast(); + })) { + auto globalSlot = + symbolTable.lookup(symName.getValue()); + globalSlot.erase(); + } + + // Update the initializer. + SmallVector newSlotSymNames; + SmallVector newInitialValues; + for (int i = 0, e = initialize.getNumOperands(); i != e; i++) { + auto slotSymName = initialize.slotSymNames()[i].cast(); + if (!safeToInline.count(slotSymName)) { + newSlotSymNames.push_back(slotSymName); + newInitialValues.push_back(initialize.getOperand(i)); + } + } + { + OpBuilder builder(initialize); + builder.create( + initialize.getLoc(), + ArrayAttr::get(module.getContext(), newSlotSymNames), + newInitialValues); + } + initialize.erase(); } }; } // namespace diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp new file mode 100644 index 000000000000..c9fd3a4845c4 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -0,0 +1,252 @@ +//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "torch-lower-to-backend-contract" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +//===----------------------------------------------------------------------===// +// Checking the backend contract. 
+//===----------------------------------------------------------------------===// + +static LogicalResult checkType(Operation *op, Type type, + bool actuallyEmitDiagnostics) { + // Allow various scalar types that backends are expected to be able to handle. + if (type.isa()) + return success(); + + // Backends are not expected to support dynamic computations on these types, + // but they frequently appear as parameters to ops which backends + // can statically pattern match and eliminate from the program. + // For example, a tensor operand might be optional, and the backend + // will pattern-match statically whether it is passed as a tensor or None. + if (type.isa()) + return success(); + + // We blanket prohibit non-value-semantic tensors. + // All of our backends are currently based on value-semantic tensors, so + // we consider it our responsibility to lower all non-value-semantic tensors + // to value-semantic tensors. + if (type.isa()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError("unsupported by backend contract: non-value tensor type") + .attachNote() + .append("this is likely due to a missing case in the " + "MaximizeValueSemantics pass"); + } else { + return failure(); + } + } + + // For value-semantic tensors, we require at least a known rank and dtype. + // We are not aware of a situation where our backends can handle an unranked + // tensor type or a tensor with a dynamic dtype. + // + // There are somewhat fundamental reasons for this. In particular, the problem + // of unranked codegen is completely different from the problem of ranked + // codegen (since ranked corresponds to a fixed loop nest structure). For all + // codegen systems we are aware of, the program must be reduced to operate + // on ranked tensors at some point in compilation, and we are not aware of + // any backend with a general solution to this problem before it reaches + // codegen. So we consider it our responsibility to eliminate unranked tensors + // from the program. + // + // We aren't aware of any backend with any infrastructure to represent dynamic + // dtypes, let alone transform and optimize them. Additionally, it is unlikely + // that any backend, even if it supports dynamic dtypes in some form, will + // have a sufficiently rich system for representing PyTorch type promotion + // rules. So we consider it our responsibility to ensure that all dtypes are + // statically known. + if (auto tensorType = type.dyn_cast()) { + if (!tensorType.hasSizes()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError( + "unsupported by backend contract: tensor with unknown rank") + .attachNote() + .append("this is likely due to a missing shape transfer function " + "in shape_lib_gen.py"); + } else { + return failure(); + } + } + if (!tensorType.hasDtype()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError( + "unsupported by backend contract: tensor with unknown dtype") + .attachNote() + .append("this is likely due to a missing case in RefineTypes"); + } else { + return failure(); + } + } + return success(); + } + + // Optional types are also in the category of types which we don't expect + // backends to dynamically compute with, but they can be pattern matched + // in many cases that are practically necessary. + if (auto optionalType = type.dyn_cast()) { + // TODO: Be stricter about tensor types. + // See comment below for ListType.
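+    // E.g. an optional tensor operand (such as a convolution bias) passes +    // through here unchanged; backends pattern-match statically whether a +    // tensor or None was provided rather than branching at runtime.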
+    if (optionalType.getContainedType().isa<Torch::BaseTensorType>())
+      return success();
+    return checkType(op, optionalType.getContainedType(),
+                     actuallyEmitDiagnostics);
+  }
+  // List types are also in the category of types which we don't expect
+  // backends to dynamically compute with, but they can be pattern matched
+  // in many cases that are practically necessary. For example, the
+  // strides of a convolution op are represented as a list.
+  if (auto listType = type.dyn_cast<Torch::ListType>()) {
+    // TODO: Be stricter about tensor types.
+    // For the moment, there are cases (such as for torch.cat) where we end
+    // up with `!torch.list<vtensor>` which doesn't have shape or dtype in
+    // the contained type information. Somehow this slips through and works.
+    // We should be stricter about this and properly infer the contained type
+    // and shape.
+    if (listType.getContainedType().isa<Torch::BaseTensorType>())
+      return success();
+    return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
+  }
+  // Tuple types are also in the category of types which we don't expect
+  // backends to dynamically compute with, but they can be pattern matched
+  // in many cases that are practically necessary.
+  if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
+    for (auto containedType : tupleType.getContainedTypes()) {
+      if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
+        return failure();
+    }
+    return success();
+  }
+
+  // Unsupported type.
+  if (actuallyEmitDiagnostics) {
+    return op->emitError("unsupported by backend contract: type ") << type;
+  } else {
+    return failure();
+  }
+}
+
+static bool satisfiesBackendContract(ModuleOp module,
+                                     bool actuallyEmitDiagnostics = false) {
+  // We do not permit `torch.global_slot`'s in the backend contract, since
+  // support for them is not widespread, and this does not align with
+  // PyTorch's more tracing-based direction.
+  //
+  // We just check for the GlobalSlotModuleInitializerOp since its verifier
+  // ensures that the set of global slots matches those initialized by the
+  // module initializer.
+  auto walkResult0 = module.walk([&](Torch::GlobalSlotModuleInitializerOp op) {
+    if (actuallyEmitDiagnostics) {
+      // Report the error on the terminator to avoid dumping the whole
+      // initializer itself, which can have pages of ops in it.
+      op.getBody()
+          ->getTerminator()
+          ->emitError("unsupported by backend contract: module initializers")
+          .attachNote()
+          .append("this is likely due to InlineGlobalSlots being unable to "
+                  "inline a global slot");
+    }
+    return WalkResult::interrupt();
+  });
+  if (walkResult0.wasInterrupted())
+    return false;
+
+  // Check the type of all Values in the program.
+  //
+  // A pre-order walk gives a more intuitive "first error".
+  // TODO: Should we report more than the first error?
+  // How do we avoid making it too spammy?
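+  //
+  // As a concrete illustration (composed by hand from the messages emitted
+  // by checkType above, not captured from a real run), a rank that could not
+  // be inferred surfaces to the user as:
+  //
+  //   error: unsupported by backend contract: tensor with unknown rank
+  //   note: this is likely due to a missing shape transfer function in
+  //         shape_lib_gen.py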
+  auto walkResult1 = module.walk([&](Block *block) {
+    for (BlockArgument arg : block->getArguments())
+      if (failed(checkType(block->getParentOp(), arg.getType(),
+                           actuallyEmitDiagnostics))) {
+        return WalkResult::interrupt();
+      }
+    for (Operation &op : *block)
+      for (OpResult result : op.getResults())
+        if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  if (walkResult1.wasInterrupted())
+    return false;
+  return true;
+}
+
+namespace {
+class LowerToBackendContractPass
+    : public LowerToBackendContractBase<LowerToBackendContractPass> {
+public:
+  LowerToBackendContractPass() = default;
+  LowerToBackendContractPass(int maxIterations, bool decompose,
+                             ArrayRef<std::string> backendLegalOps) {
+    this->maxIterations = maxIterations;
+    this->decompose = decompose;
+    this->backendLegalOps = backendLegalOps;
+  }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    OpPassManager pm(module.getOperationName());
+    TorchLoweringPipelineOptions options;
+    options.decompose = decompose;
+    options.backendLegalOps = backendLegalOps;
+    createTorchSimplificationPipeline(pm, options);
+
+    int i = 0;
+    do {
+      if (i++ == maxIterations) {
+        LLVM_DEBUG({
+          llvm::dbgs() << "LowerToBackendContractPass: "
+                       << "failed to satisfy backend contract after "
+                       << maxIterations
+                       << " iterations of the simplification pipeline\n";
+        });
+        // Show the diagnostics.
+        (void)satisfiesBackendContract(module,
+                                       /*actuallyEmitDiagnostics=*/true);
+        return signalPassFailure();
+      }
+
+      if (failed(runPipeline(pm, module)))
+        return signalPassFailure();
+    } while (!satisfiesBackendContract(module));
+    LLVM_DEBUG({
+      llvm::dbgs() << "LowerToBackendContractPass: "
+                   << "succeeded after " << i
+                   << " iterations of the simplification pipeline\n";
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::Torch::createLowerToBackendContractPass(
+    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
+  return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
+                                                      backendLegalOps);
+}
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 6168f3d0b5d1..a4db59642028 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -37,7 +37,8 @@ static bool isViewLikeOp(Operation *op) {
                  Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
                  AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
                  AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-                 TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp>(op);
+                 TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
+                 AtenNarrowOp, AtenToDeviceOp>(op);
 }
 
 namespace {
@@ -122,8 +123,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
                                 PatternRewriter &rewriter) {
     DenseMap<int, Type> originalReturnTypes;
 
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (auto operand : llvm::enumerate(returnOp->getOperands())) {
         auto type = operand.value().getType();
         if (!type.isa<NonValueTensorType>())
@@ -159,8 +160,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
         result.setType(resultType.getWithValueSemantics());
       });
     }
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) {
         OpOperand &operand =
             returnOp->getOpOperand(i);
         auto it = originalReturnTypes.find(i);
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index fbbdbcb191b5..3681eef27e4d 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -31,6 +31,10 @@ void mlir::torch::registerTorchPasses() {
       "Pipeline lowering a Torch function to Torch backend form.",
       mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
   mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+      "torch-simplification-pipeline",
+      "Pipeline simplifying computations in the program.",
+      mlir::torch::Torch::createTorchSimplificationPipeline);
+  mlir::PassPipelineRegistration<>(
       "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
       mlir::torch::Torch::createTorchShapeRefinementPipeline);
 }
@@ -66,112 +70,83 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
 void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
-  // General considerations: As a matter of bring-up, we are simultaneously
-  // building out the frontend pipeline and also co-developing the backend
-  // support story as well. This means that sometimes the most expedient way to
-  // support a given program is to "optimize hard enough" that the parts of the
-  // program that touch unimplemented backend support go away (constant folded,
-  // dead-code-eliminated, etc.). In the fullness of time, most of that
-  // optimization should not be necessary, and we should have an "O0" pipeline
-  // that runs practically no optimizations.
-  // However, as a matter of expediency, at the moment we do run those
-  // optimizations. We guard those passes under the `options.optimize` option
-  // (which default to true, currently). We leave notes with the `OPT-ONLY` tag
-  // why we currently need that pass for correctness.
-  // We should eventually remove those passes from the default pipeline once
-  // backends have enough support.
-  // In particular the following features are needed in some form from backends:
-  // - Error handling (RaiseException + error string formatting)
-  // - First-class list type
-  // - torch.global_slot lowering
-  // - ...
-  // Please try to keep this list somewhat up to date when adding
-  // "optimize hard enough that it works" transformations.
-
   // Incorporate user annotations and remove signature Python-isms.
   pm.addPass(createAdjustCallingConventionsPass());
+  // Perform the bulk of lowering to the backend contract.
+  // See the pass documentation for more information.
+  pm.addPass(createLowerToBackendContractPass(
+      options.maxIterations, options.decompose, options.backendLegalOps));
+}
-
-  if (options.optimize) {
-    // Eliminate the PrimTupleIndexOp generated from the
-    // adjustCallingConventions
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-    // Inline global slots, which for most inference scenarios deletes them.
-    // This also exposes more information to intraprocedural transformations
-    // below like MaximizeValueSemantics and RefineTypes.
-    // OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting.
-    // Also don't rely on this pass to expose constants into the program to
-    // simplify handling of "optional".
-    pm.addPass(createInlineGlobalSlotsPass());
-  }
-
+// A simplification pipeline to establish the invariants of the backend
+// contract (see `satisfiesBackendContract` in `LowerToBackendContract`).
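+//
+// Because it is registered above as `torch-simplification-pipeline`, the
+// pipeline can also be exercised in isolation when debugging contract
+// failures, e.g. (illustrative; the flag spelling assumes MLIR's standard
+// PassPipelineRegistration behavior):
+//
+//   torch-mlir-opt --torch-simplification-pipeline input.mlir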
+//
+// We structure this so that a single run of this pipeline is enough for
+// most models, but it is possible for it to take multiple runs to fully
+// clean things up when there are cyclic dependencies between certain
+// simplifications, such as a decomposition relying on shape refinement which
+// depends on another decomposition.
+//
+// Although technically this pipeline is an implementation detail of
+// LowerToBackendContract, we expose it here to help debugging.
+//
+// LowerToBackendContract will run this pipeline as many times as necessary,
+// but in general, it is costly to re-run this pipeline, since all the passes
+// do O(module size) work. We want the number of iterations of this pipeline
+// to be bounded by meaningful "always in practice small" program properties,
+// such as loop nesting depth, number of sequentially dependent steps of
+// constant global slots proving that other global slots are dead, etc.
+//
+// It is generally always possible to construct a pathological input that will
+// exceed the number of iterations. If we do find practical cases with
+// O(module size) number of iterations of this simplification pipeline, then
+// we may need to adjust the approach, such as to do some of the
+// transformations together at finer granularity.
+void mlir::torch::Torch::createTorchSimplificationPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
+  // General cleanup.
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  // Inline global slots to expose a bunch of simplification opportunities
+  // from constant hyperparameters, weights, etc.
+  pm.addPass(createInlineGlobalSlotsPass());
+  // Erase the module initializer if we have proven that all the global slots
+  // are gone.
+  pm.addPass(createEraseModuleInitializerPass());
+  // Clean up again to avoid needing to go back around the fixed-point
+  // iteration.
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Reduce variants of ops to a smaller set of primitives.
   pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
-
-  if (options.optimize) {
-    // OPT-ONLY: Right now we rely on this to eliminate certain branches that
-    // guard unreachable code that backends can't handle yet, such as lists,
-    // RaiseException, unimplemented tensor ops, and only-used-in-training
-    // operations on `torch.global_slot`'s.
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-    // OPT-ONLY: We may have deleted some `torch.global_slot.get` /
-    // `torch.global_slot.get` ops, which may have left more
-    // `torch.global_slot`'s unused.
-    pm.addPass(createSymbolDCEPass());
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Lowering to ranked !torch.vtensors of known dtype.
-  //===--------------------------------------------------------------------===//
-
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  // Remove dead global slots.
+  pm.addPass(createSymbolDCEPass());
   // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
   pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
-
-  // Update the return op to return value tensors and remove dead ops.
+  // Update the return op to return value tensors.
   pm.addPass(Torch::createRefinePublicReturnPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-
-  // Ensure that all tensors have been converted to value semantics.
-  pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
-
   // Do shape refinement.
-  // This must be run before RefineTypes (which primarily does dtype inference),
-  // because Torch type promotion rules actually depend on the shape of the
-  // operand.
-  createTorchShapeRefinementPipeline(pm, options);
+  // This should be run before RefineTypes (which primarily does dtype
+  // inference), because Torch type promotion rules actually depend on the shape
+  // of the operand.
+  createTorchShapeRefinementPipeline(pm);
   // Refine types in the program, which mainly means inferring dtypes of ops.
   pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
-
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
   pm.addPass(Torch::createRefinePublicReturnPass());
-
-  if (options.optimize) {
-    // This can fold away some branches given the information got from
-    // RefineTypes before doing maximize value sematics which only works with
-    // basic blocks.
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  }
-
-  if (options.optimize) {
-    // All the type refinement we've done above has exposed new information
-    // that allows folding away more stuff.
-    // OPT-ONLY: Right now we rely on this to eliminate certain
-    // branches that guard unreachable code that backends can't handle yet, such
-    // as lists, RaiseException, unimplemented aten ops, and
-    // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  }
-
+  // This can fold away some branches given the information gained from
+  // RefineTypes before doing MaximizeValueSemantics, which only works with
+  // basic blocks.
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   if (options.decompose) {
-    pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
+    pm.addNestedPass<func::FuncOp>(
+        Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
     pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
-
-  // TODO: VerifyTorchBackendContractPass.
 }
 
-void mlir::torch::Torch::createTorchShapeRefinementPipeline(
-    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
+void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
   // Reify the shape functions for each op that is present in the shape library.
   pm.addPass(Torch::createReifyShapeCalculationsPass());
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 7985ce2b3da6..705ce06608cb 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -55,9 +55,9 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
 
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
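 // For example (an illustrative sketch, not taken from the test suite), a
 // computation written against mutable tensors:
 //   %0 = torch.aten.tanh %t : !torch.tensor -> !torch.tensor
 // is rewritten to compute on value tensors, with explicit copies at the
 // boundary:
 //   %v = torch.copy.to_vtensor %t : !torch.vtensor
 //   %r = torch.aten.tanh %v : !torch.vtensor -> !torch.vtensor
 //   %0 = torch.copy.to_tensor %r : !torch.tensor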
-class ConvertToImmutableTensors : public RewritePattern {
+class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 public:
-  ConvertToImmutableTensors(MLIRContext *context)
+  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
@@ -260,7 +260,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<ConvertToImmutableTensors>(context);
+    patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
     patterns.add(context);
     patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
index 4adf613466e0..5109a8c5735e 100644
--- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
+++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
@@ -71,7 +71,9 @@ class RefinePublicReturnPass
           // If the return (or transitively other ops) are not the only users,
           // then we can't be sure that the tensor hasn't been mutated, so stop
           // here.
-          if (!llvm::hasSingleElement(copy->getUsers()))
+          SetVector<Operation *> users(copy->getUsers().begin(),
+                                       copy->getUsers().end());
+          if (users.size() != 1)
             break;
           newOperand = copy.getOperand();
         } else {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 4a8ff9aa9f1d..c8049c567a01 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -310,15 +310,15 @@ struct ValueKnowledge {
                                              const ValueKnowledge &rhs) {
     Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
 
-    if (!knowledge.hasValue())
+    if (!knowledge.has_value())
       return None;
-    ValueKnowledge result = knowledge.getValue();
+    ValueKnowledge result = knowledge.value();
 
     Optional<OptionalKnowledge> optional =
         meetOptionalKnowledge(lhs.optional, rhs.optional);
-    if (!optional.hasValue())
+    if (!optional.has_value())
       return None;
-    result.optional = optional.getValue();
+    result.optional = optional.value();
 
     return result;
   }
@@ -377,6 +377,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
   using BaseT = dataflow::SparseDataFlowAnalysis<dataflow::Lattice<ValueKnowledge>>;
   using BaseT::SparseDataFlowAnalysis;
+  void setToEntryState(dataflow::Lattice<ValueKnowledge> *lattice) override {}
 
   // Compute the knowledge for the results of an op, based on the knowledge of
   // the operands and any information intrinsic to `op`.
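The `meet` shown above is what guards against contradictory dtype facts: when two sources of knowledge about the same `Value` cannot be reconciled, `meetTypes` yields `None` rather than silently picking one (and `incorporateKnowledge` later asserts with "IR has contradictory type!"). A minimal sketch of the intended algebra, with `f32`/`i64` as stand-in dtypes:

    // Meet of identical facts is the fact itself:
    meet({dtype = f32}, {dtype = f32})     == {dtype = f32}
    // "unknown" (null dtype) is neutral and keeps the known fact:
    meet({dtype = f32}, {dtype = unknown}) == {dtype = f32}
    // Incompatible facts have no meet:
    meet({dtype = f32}, {dtype = i64})     == None

The empty `setToEntryState` override leaves block entry states uninitialized rather than pessimistic, which pairs with the `isUninitialized()` checks added to `getMostRefinedStaticType` later in this file.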
@@ -446,6 +447,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< ArrayRef operands); void visitAtenScalarImplicitOp(AtenScalarImplicitOp op, ArrayRef operands); + void visitAtenEmbeddingBagOp(Operation *op); }; } // namespace @@ -517,13 +519,13 @@ updateResultTypeState(const ValueKnowledge *tensor, Optional rankIsNonZero, const torch_upstream::ResultTypeState &inState, bool skipRankCheck = false) { - if (!rankIsNonZero.hasValue() && !skipRankCheck) + if (!rankIsNonZero.has_value() && !skipRankCheck) return torch_upstream::ResultTypeState{}; assert(tensor->dtype && "tensor.dtype must be not none"); torch_upstream::ResultTypeState new_state = inState; torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype); - if (skipRankCheck || rankIsNonZero.getValue()) + if (skipRankCheck || rankIsNonZero.value()) new_state.dimResult = promote_skip_undefined(inState.dimResult, current); else new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current); @@ -651,21 +653,22 @@ void TypeAnalysis::visitOperation(Operation *op, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, - AtenSelectScatterOp, AtenSliceTensorOp, AtenSliceScatterOp, - AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, - AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, - AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, - ValsemVariantAtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, - AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, - AtenTriuOp>(op)) { - incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - return; + AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp, + AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp, + AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, + AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, + AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp, + AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp, + PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp, + AtenRollOp>( + op)) { + return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } // Dtype is always float32, except for bfloat16, float64 and nullptr. - if (isa(op)) { + if (isa(op)) { ValueKnowledge knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); Type dtype = operands[0]->getValue().dtype; @@ -711,7 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming non-zero rank. if (isa(op)) { + Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, + AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( @@ -734,6 +738,23 @@ void TypeAnalysis::visitOperation(Operation *op, return; } + // Dtype is always float32, except for bfloat16, float64 and nullptr after + // promotion and assuming possible-zero rank. 
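+  // (Worked examples, illustrative: promoting (f64, f64) keeps f64 and
+  // (bf16, bf16) keeps bf16, while (i64, i64) and (f16, f16) both land on
+  // f32 under this rule.)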
+ if (isa(op)) { + ValueKnowledge knowledge = + ValueKnowledge::getTensorPessimisticValueState(op->getContext()); + Type promotedDtype = getPromotedResultType( + op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, + getRankIsNonZeroArray(op->getOperands())); + if (promotedDtype) { + knowledge.dtype = Float32Type::get(op->getContext()); + if (promotedDtype.isa()) + knowledge.dtype = promotedDtype; + } + incorporateKnowledge(op->getResult(0), knowledge); + return; + } + // Promote three dtypes. if (isa(op)) { auto knowledge = @@ -753,7 +774,7 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote LHS with scalar RHS. if (isa(op)) { + AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) { auto lhs = operands[0]->getValue(); Value scalar = op->getOperand(1); auto knowledge = @@ -931,7 +952,8 @@ void TypeAnalysis::visitOperation(Operation *op, Type dtype = operands[0]->getValue().dtype; visitReductionAlongAllDimsOp(max, dtype, operands); return; - } else if (isa(op)) { + } else if (isa(op)) { auto input = operands[0]->getValue(); visitReductionAlongAllDimsOp(op, input.dtype, operands); return; @@ -1006,6 +1028,11 @@ void TypeAnalysis::visitOperation(Operation *op, return; } + if (auto toDtype = dyn_cast(op)) { + visitAtenToDtypeLikeOp(toDtype, operands); + return; + } + if (auto toOther = dyn_cast(op)) { visitTypeConversionOp(toOther, operands); return; @@ -1035,6 +1062,11 @@ void TypeAnalysis::visitOperation(Operation *op, incorporateKnowledge(embedding.getResult(), knowledge); return; } + + if (isa(op)) { + visitAtenEmbeddingBagOp(op); + return; + } if (auto softmaxIntOp = dyn_cast(op)) { visitAtenSoftmaxLikeOp(softmaxIntOp, operands); @@ -1077,7 +1109,7 @@ void TypeAnalysis::visitOperation(Operation *op, // Otherwise, this is an unknown operation. Just mark all results as // having reached a pessimistic fixpoint. - markAllPessimisticFixpoint(results); + setAllToEntryStates(results); return; } @@ -1085,8 +1117,8 @@ void TypeAnalysis::incorporateKnowledge(Value v, const ValueKnowledge &knowledge) { auto updatedKnowledge = ValueKnowledge::meet( knowledge, ValueKnowledge::getPessimisticValueState(v)); - assert(updatedKnowledge.hasValue() && "IR has contradictory type!"); - getLatticeElement(v)->join(updatedKnowledge.getValue()); + assert(updatedKnowledge.has_value() && "IR has contradictory type!"); + getLatticeElement(v)->join(updatedKnowledge.value()); } void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op, @@ -1112,6 +1144,23 @@ void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op, incorporateKnowledge(op->getResult(0), knowledge); } +void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) { + auto resultFloatKnowledge = + ValueKnowledge::getTensorPessimisticValueState(op->getContext()); + resultFloatKnowledge.dtype = Float32Type::get(op->getContext()); + + incorporateKnowledge(op->getResult(0), resultFloatKnowledge); + auto resultIntKnowledge = + ValueKnowledge::getTensorPessimisticValueState(op->getContext()); + resultIntKnowledge.dtype = + IntegerType::get(op->getContext(), 64, IntegerType::Signed); + + for (int64_t i = 1, e = op->getNumResults(); i < e; i++) { + incorporateKnowledge(op->getResult(i), resultIntKnowledge); + } + return; +} + // Arange like ops returns a 1-D tensor of size ceil(end - start). 
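 // (Illustrative: torch.arange(0, 5) yields [0, 1, 2, 3, 4], i.e. size
 // ceil(5 - 0) = 5; per the dtype logic below, a float start/end/step makes
 // the inferred dtype f32, otherwise it is int64.)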
void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op, llvm::Optional start, @@ -1130,9 +1179,9 @@ void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op, // `dtype` is inferred to be the default dtype, see // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to // be `torch.int64` - if ((start.hasValue() && (*start).getType().isa()) || + if ((start.has_value() && (*start).getType().isa()) || end.getType().isa() || - (step.hasValue() && (*step).getType().isa())) { + (step.has_value() && (*step).getType().isa())) { // TODO: Should get the dtype from torch.get_default_dtype(). // For now, use float32 which is the initial default dtype. knowledge.dtype = Float32Type::get(op->getContext()); @@ -1213,6 +1262,12 @@ void TypeAnalysis::visitAtenTensorOp(AtenTensorOp op) { while (auto listType = type.dyn_cast()) { type = listType.getContainedType(); } + // TODO: Support tensor as the contained type of the list. + // These are the only types handled by fillInDTypeGivenDTypeAndDataType below. + if (!type.isa()) { + incorporateKnowledge(op.getResult(), knowledge); + return; + } fillInDTypeGivenDTypeAndDataType(knowledge, dtype, type); incorporateKnowledge(op.getResult(), knowledge); } @@ -1224,7 +1279,7 @@ void TypeAnalysis::visitConstantTensorAllocOp(OpTy op, ValueKnowledge::getTensorPessimisticValueState(op->getContext()); if (!dataType) dataType = Torch::FloatType::get(op->getContext()); - fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue()); + fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.value()); incorporateKnowledge(op.getResult(), knowledge); } @@ -1294,11 +1349,11 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op, })); for (auto tensor : tensors) { auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype); - if (!newDtype.hasValue()) { + if (!newDtype.has_value()) { incorporateKnowledge(op.getResult(), knowledge); return; } - knowledge.dtype = newDtype.getValue(); + knowledge.dtype = newDtype.value(); } incorporateKnowledge(op.getResult(), knowledge); } @@ -1372,13 +1427,13 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { }; if (auto tensorType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); return getRefinedTensorType(tensorType, knowledge); } else if (auto optionalType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.optional == OptionalKnowledge::isNone) @@ -1392,7 +1447,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { } } else if (auto scalarType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.kind == torch_upstream::TypeKind::IntType) diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index bda4173631b2..f7a8f69ca355 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3494,6 +3494,142 @@ module { %6 = torch.prim.TupleConstruct %0, 
%2, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> return %6 : !torch.tuple, list, list> } + func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int + %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %3 = torch.prim.ListConstruct : () -> !torch.list + %4 = torch.prim.If %arg6 -> (!torch.int) { + torch.prim.If.yield %int1 : !torch.int + } else { + torch.prim.If.yield %int0 : !torch.int + } + %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %6 = torch.aten.append.t %3, %5 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.append.t %3, %7 : !torch.list, !torch.int -> !torch.list + %9 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %9, %true, init() { + ^bb0(%arg9: !torch.int): + %10 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %11 = torch.prim.If %1 -> (!torch.int) { + %12 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.__getitem__.t %arg5, %12 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %13 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + torch.prim.If %arg6 -> () { + %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int + %16 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %18 = torch.aten.__getitem__.t %arg3, %17 : !torch.list, !torch.int -> !torch.int + %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.__getitem__.t %arg4, %20 : !torch.list, !torch.int -> !torch.int + %22 = torch.aten.mul.int %21, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.sub.int %19, %22 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %14 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } else { + %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int + %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int + %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list, 
!torch.int -> !torch.int + %19 = torch.aten.mul.int %18, %int2 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list, !torch.int -> !torch.int + %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield + } + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %3 : !torch.list + } + func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.int, %arg7: !torch.optional>) -> !torch.list { + %true = torch.constant.bool true + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %2 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool + %3 = torch.prim.If %2 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %4 = torch.aten.__is__ %arg7, %none : !torch.optional>, !torch.none -> !torch.bool + %5 = torch.prim.If %4 -> (!torch.list) { + %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + torch.prim.If.yield %15 : !torch.list + } else { + %15 = torch.prim.unchecked_cast %arg7 : !torch.optional> -> !torch.list + torch.prim.If.yield %15 : !torch.list + } + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %9 = torch.prim.ListConstruct : () -> !torch.list + %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list + %12 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list + %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %14, %true, init() { + ^bb0(%arg8: !torch.int): + %15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %16 = torch.prim.If %7 -> (!torch.int) { + %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %33 = torch.aten.__getitem__.t %5, %32 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %33 : !torch.int + } else { + torch.prim.If.yield %int1 : !torch.int + } + %17 = torch.aten.__getitem__.t %arg1, %15 : 
!torch.list, !torch.int -> !torch.int + %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int + %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int + %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.__getitem__.t %1, %22 : !torch.list, !torch.int -> !torch.int + %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.__getitem__.t %3, %25 : !torch.list, !torch.int -> !torch.int + %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int + %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int + %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int + %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int + %31 = torch.aten.append.t %9, %30 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + return %9 : !torch.list + } func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { %none = torch.constant.none %str = torch.constant.str "AssertionError: " @@ -4369,70 +4505,92 @@ module { } return %6 : !torch.list } - func.func @__torch__.torch.jit._shape_functions.mean_dim(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { - %none = torch.constant.none + func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %false = torch.constant.bool false %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { + %1 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool + } + %3 = torch.prim.If %2 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + torch.prim.Loop %4, %true, init() { ^bb0(%arg4: !torch.int): - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.prim.Loop %2, %true, init(%false) { + %5 = torch.aten.len.t %3 : !torch.list -> !torch.int + %6 = torch.prim.Loop %5, %true, init(%false) { ^bb0(%arg5: !torch.int, 
%arg6: !torch.bool): - %4 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.int) { + %7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool + %10 = torch.prim.If %9 -> (!torch.int) { torch.prim.If.yield %int1 : !torch.int } else { - torch.prim.If.yield %5 : !torch.int + torch.prim.If.yield %8 : !torch.int } - %8 = torch.aten.neg.int %7 : !torch.int -> !torch.int - %9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.lt.int %4, %8 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { + %11 = torch.aten.neg.int %10 : !torch.int -> !torch.int + %12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> !torch.bool + %14 = torch.prim.If %13 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { - %17 = torch.aten.gt.int %4, %9 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool + %20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %20 : !torch.bool } - %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool - torch.prim.If %12 -> () { + %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool + torch.prim.If %15 -> () { torch.prim.If.yield } else { torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } - %13 = torch.aten.lt.int %4, %int0 : !torch.int, !torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.int) { - %17 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %17 : !torch.int + %16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + %17 = torch.prim.If %16 -> (!torch.int) { + %20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + torch.prim.If.yield %20 : !torch.int } else { - torch.prim.If.yield %4 : !torch.int + torch.prim.If.yield %7 : !torch.int } - %15 = torch.aten.eq.int %arg4, %14 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.bool) { + %18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool + %19 = torch.prim.If %18 -> (!torch.bool) { torch.prim.If.yield %true : !torch.bool } else { torch.prim.If.yield %arg6 : !torch.bool } - torch.prim.Loop.condition %true, iter(%16 : !torch.bool) + torch.prim.Loop.condition %true, iter(%19 : !torch.bool) } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - torch.prim.If %3 -> () { + torch.prim.If %6 -> () { torch.prim.If %arg2 -> () { - %4 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } else { torch.prim.If.yield } torch.prim.If.yield } else { - %4 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %0, %4 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.append.t %0, %7 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } torch.prim.Loop.condition %true, iter() @@ -4442,10 +4600,10 @@ module { func.func 
@__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> { %false = torch.constant.bool false %true = torch.constant.bool true - %int1 = torch.constant.int 1 + %none = torch.constant.none %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %str = torch.constant.str "AssertionError: " - %none = torch.constant.none %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list %1 = torch.prim.ListConstruct : () -> !torch.list %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -5317,6 +5475,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.softplus"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.square"(%arg0: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5333,6 +5495,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.expm1"(%arg0: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5365,6 +5531,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.log1p"(%arg0: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.rsqrt"(%arg0: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5436,6 +5606,10 @@ module { func.func @"__torch_mlir_shape_fn.aten.to.dtype_layout"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.list { return %arg0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.to.device"(%arg0: !torch.list, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.to.other"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5504,6 +5678,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> 
!torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list @@ -5552,16 +5730,28 @@ module { %0 = torch.prim.ListConstruct : () -> !torch.list return %0 : !torch.list } - func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { + func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { + %none = torch.constant.none + %0 = torch.derefine %none : !torch.none to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list + } + func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { %none = torch.constant.none %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { %0 = torch.prim.ListConstruct : () -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { + %none = torch.constant.none + %0 = torch.derefine %none : !torch.none to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list { %none = torch.constant.none %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool @@ -5614,25 +5804,14 @@ module { %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list> return %1 : !torch.tuple, list> } - func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { + func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { %0 = torch.derefine %arg3 : !torch.optional to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %none = torch.constant.none - %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If 
%0 -> (!torch.list) { - %2 = torch.prim.ListConstruct : () -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list - } else { - %2 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list - } + %0 = torch.derefine %arg3 : !torch.optional to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { @@ -5820,6 +5999,10 @@ module { } return %6 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.roll"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.expand"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list return %0 : !torch.list @@ -6071,6 +6254,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.masked_fill.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list) -> !torch.list { return %arg0 : !torch.list } @@ -6154,6 +6341,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list return %0 : !torch.list @@ -6282,14 +6473,26 @@ module { %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.conv_transpose2d.input"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list { + %0 = torch.derefine %arg3 : !torch.list to !torch.optional> + %1 = torch.derefine %arg4 : !torch.list to !torch.optional> + %2 = torch.derefine %arg5 : !torch.list to !torch.optional> + %3 = torch.derefine %arg7 : !torch.list to 
!torch.optional> + %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list + return %4 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list + %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten._convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list { %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list { + %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { return %arg0 : !torch.list } @@ -6300,6 +6503,14 @@ module { %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.narrow"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { + %int1 = torch.constant.int 1 + %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int + %1 = torch.derefine %arg2 : !torch.int to !torch.optional + %2 = torch.derefine %0 : !torch.int to !torch.optional + %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list + return %3 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.slice_scatter"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.list { return %arg0 : !torch.list } @@ -6322,6 +6533,76 @@ module { %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple, list, list, list> { + %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list> + return %0 : !torch.tuple, list, list, list> + } + func.func @__torch__._embedding_bag_helper(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int) -> !torch.tuple, list, list, list> { + %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool + torch.prim.If %1 -> () { + torch.prim.If.yield + } else { + torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.If.yield + } + %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int + %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool + torch.prim.If %3 -> () { + torch.prim.If.yield + } else { + torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.If.yield + } + %4 = torch.aten.len.t %arg2 : !torch.list -> !torch.int + %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool + torch.prim.If %5 -> () { + torch.prim.If.yield + } else { + torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.If.yield + } + %6 = torch.prim.ListConstruct : () -> !torch.list + %7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int + %8 = torch.prim.If %arg3 -> (!torch.int) { + %19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int + torch.prim.If.yield %19 : !torch.int + } else { + torch.prim.If.yield %7 : !torch.int + } + %9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int + %10 = torch.aten.append.t %6, %8 : !torch.list, !torch.int -> !torch.list + %11 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list + %12 = torch.prim.ListConstruct : () -> !torch.list + %13 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool + %14 = torch.prim.If %13 -> (!torch.list) { + %19 = torch.aten.append.t %12, %int0 : !torch.list, !torch.int -> !torch.list + torch.prim.If.yield %12 : !torch.list + } else { + %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list + torch.prim.If.yield %19 : !torch.list + } + %15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list + %16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool + %17 = torch.prim.If %16 -> (!torch.list) { + %19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list) -> !torch.list + torch.prim.If.yield %19 : !torch.list + } else { + %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list + torch.prim.If.yield %19 : !torch.list + } + %18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list, !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list, list> + return %18 : 
!torch.tuple, list, list, list> + } + func.func @"__torch_mlir_shape_fn.aten._embedding_bag"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple, list, list, list> { + %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list> + return %0 : !torch.tuple, list, list, list> + } func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> { %int-1 = torch.constant.int -1 %true = torch.constant.bool true @@ -6578,30 +6859,30 @@ module { %10 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list %12 = torch.prim.min.self_int %11 : !torch.list -> !torch.int - %13:3 = torch.prim.Loop %12, %true, init(%true, %int-1, %int-1) { - ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.int): + %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) { + ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int): %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>>, !torch.int -> !torch.optional> %17 = torch.aten.__isnot__ %16, %none : !torch.optional>, !torch.none -> !torch.bool - %18:3 = torch.prim.If %17 -> (!torch.bool, !torch.int, !torch.int) { + %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) { %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool - %20:3 = torch.prim.If %19 -> (!torch.bool, !torch.int, !torch.int) { - torch.prim.If.yield %arg3, %arg2, %arg2 : !torch.bool, !torch.int, !torch.int + %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) { + torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int } else { - %21 = torch.aten.sub.int %arg2, %arg5 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool %23 = torch.prim.If %22 -> (!torch.bool) { torch.prim.If.yield %false : !torch.bool } else { torch.prim.If.yield %arg3 : !torch.bool } - torch.prim.If.yield %23, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int + torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int } - torch.prim.If.yield %20#0, %20#1, %20#2 : !torch.bool, !torch.int, !torch.int + torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int } else { - torch.prim.If.yield %arg3, %arg4, %arg5 : !torch.bool, !torch.int, !torch.int + torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int } - torch.prim.Loop.condition %true, iter(%18#0, %18#1, %18#2 : !torch.bool, !torch.int, !torch.int) - } : (!torch.int, !torch.bool, !torch.bool, !torch.int, !torch.int) -> (!torch.bool, !torch.int, !torch.int) + torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int) + } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int) %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool %15 = torch.prim.If %14 -> (!torch.list) { %16 = torch.aten.add.t %6, %4 : !torch.list, !torch.list -> !torch.list @@ -6656,25 +6937,9 @@ module { return %none : !torch.none } func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, 
%arg4: !torch.optional) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %0 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %5 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %4, %true, init() { - ^bb0(%arg5: !torch.int): - %6 = torch.aten.append.t %5, %arg5 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %5 : !torch.list - } else { - %4 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - torch.prim.If.yield %4 : !torch.list - } - %2 = torch.derefine %arg4 : !torch.optional to !torch.any - %3 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %1, %arg3, %2) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %3 : !torch.list + %0 = torch.derefine %arg4 : !torch.optional to !torch.any + %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list + return %1 : !torch.list } } )mlir"); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index b9a4ea29aede..f8bd58878738 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -418,6 +418,7 @@ class SimplifyShapeCalculationsPass Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); AtenSizeOp::getCanonicalizationPatterns(patterns, context); AtenLenTOp::getCanonicalizationPatterns(patterns, context); + AtenAddTOp::getCanonicalizationPatterns(patterns, context); // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. diff --git a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp b/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp deleted file mode 100644 index 0bbce7320fe3..000000000000 --- a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp +++ /dev/null @@ -1,64 +0,0 @@ -//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" - -#include "mlir/IR/BuiltinOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::torch::Torch; - -static LogicalResult checkValueType(Operation *op, Value value) { - auto isNotValueTensorType = value.getType().isa(); - return isNotValueTensorType - ? 
op->emitError( - "found a non-value tensor type, this is likely due to a " - "missing case in the MaximizeValueSemantics pass") - : success(); -} - -namespace { -class VerifyConversionToValueSemanticsPass - : public VerifyConversionToValueSemanticsBase< - VerifyConversionToValueSemanticsPass> { - void runOnOperation() override { - bool didFail = false; - auto walkResult = getOperation().walk([&](Block *block) { - for (BlockArgument arg : block->getArguments()) { - if (failed(checkValueType(block->getParentOp(), arg))) { - didFail = true; - return WalkResult::interrupt(); - } - } - - for (Operation &op : *block) { - for (OpResult result : op.getResults()) { - if (failed(checkValueType(&op, result))) { - didFail = true; - return WalkResult::interrupt(); - } - } - } - - return WalkResult::advance(); - }); - - if (didFail || walkResult.wasInterrupted()) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 37ffffabd8fd..6cd6f1e1143d 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) { // Type promotion related code are copied from // aten/src/ATen/native/TypeProperties.*. //===----------------------------------------------------------------------===// -static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { +ScalarType promoteTypes(ScalarType a, ScalarType b) { // This is generated according to NumPy's promote_types constexpr auto u1 = ScalarType::Byte; constexpr auto i1 = ScalarType::Char; diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 7e9f9a947ef0..906243e14668 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Bool; if (type.isBF16()) return torch_upstream::ScalarType::BFloat16; + if (type.isF16()) + return torch_upstream::ScalarType::Half; llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } @@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType( return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: return mlir::FloatType::getBF16(context); + case torch_upstream::ScalarType::Half: + return mlir::FloatType::getF16(context); default: return Type(); } diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index 3dc36bf53d8b..61dade9408ac 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -21,10 +21,29 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace mlir::torch; +static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) { + if (lhs.hasRank() != rhs.hasRank()) + return false; + bool sameSize = lhs.hasRank() ? 
lhs.getShape().equals(rhs.getShape()) : true; + bool sameElementType = lhs.getElementType() == rhs.getElementType(); + return sameElementType && sameSize; +} + //===----------------------------------------------------------------------===// // ToBuiltinTensorOp //===----------------------------------------------------------------------===// +LogicalResult ToBuiltinTensorOp::verify() { + auto resultType = getResult().getType().cast(); + auto operandType = + getOperand().getType().cast().toBuiltinTensor(); + if (!haveSameSizeAndElementType(resultType, operandType)) { + return emitError() + << "operand and result must have the same size and dtype"; + } + return success(); +} + LogicalResult ToBuiltinTensorOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, @@ -37,6 +56,48 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes( return success(); } +void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) { + auto fromBuiltinTensorOp = + op.getOperand().getDefiningOp(); + if (!fromBuiltinTensorOp) + return rewriter.notifyMatchFailure(op, "operand not FromBuiltinTensorOp"); + rewriter.replaceOp(op, fromBuiltinTensorOp.getOperand()); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// FromBuiltinTensorOp +//===----------------------------------------------------------------------===// + +LogicalResult FromBuiltinTensorOp::verify() { + auto resultType = + getResult().getType().cast().toBuiltinTensor(); + auto operandType = getOperand().getType().cast(); + if (!haveSameSizeAndElementType(resultType, operandType)) { + return emitError() + << "operand and result must have the same size and dtype"; + } + return success(); +} + +void FromBuiltinTensorOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](FromBuiltinTensorOp op, PatternRewriter &rewriter) { + auto toBuiltinTensorOp = op.getOperand().getDefiningOp(); + if (!toBuiltinTensorOp) + return rewriter.notifyMatchFailure(op, "operand not ToBuiltinTensorOp"); + rewriter.replaceOp(op, toBuiltinTensorOp.getOperand()); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// FromI64Op +//===----------------------------------------------------------------------===// + OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { auto attr = operands[0].dyn_cast_or_null(); if (attr) { @@ -46,6 +107,10 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { } } +//===----------------------------------------------------------------------===// +// ToI64Op +//===----------------------------------------------------------------------===// + OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { auto attr = operands[0].dyn_cast_or_null(); if (attr) { @@ -55,5 +120,31 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { } } +//===----------------------------------------------------------------------===// +// ToF64Op +//===----------------------------------------------------------------------===// + +OpFoldResult ToF64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// FromF64Op 
+//===----------------------------------------------------------------------===// + +OpFoldResult FromF64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + #define GET_OP_CLASSES #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 77f58e53cd4f..f8c3373cbb41 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,8 +1,23 @@ +set(LinkedLibs MLIRIR + MLIRPass + MLIRFuncTransforms + TorchMLIRTorchConversionDialect + TorchMLIRTorchDialect + TorchMLIRTorchPasses + TorchMLIRTorchToLinalg + TorchMLIRTorchToTMTensor + TorchMLIRTorchToArith + TorchMLIRTorchToSCF + MLIRMemRefTransforms) + +if(TORCH_MLIR_ENABLE_MHLO) + list(APPEND LinkedLibs ChloPasses) +endif() + add_mlir_library(TorchMLIRTorchConversionPasses BackendTypeConversion.cpp BackendTypeConversionPasses.cpp Passes.cpp - VerifyInvariantsBeforeBackendLowering.cpp VerifyLinalgOnTensorsBackendContract.cpp VerifyTosaBackendContract.cpp @@ -17,15 +32,5 @@ add_mlir_library(TorchMLIRTorchConversionPasses Core LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRFuncTransforms - TorchMLIRTorchConversionDialect - TorchMLIRTorchDialect - TorchMLIRTorchPasses - TorchMLIRTorchToLinalg - TorchMLIRTorchToTMTensor - TorchMLIRTorchToStd - TorchMLIRTorchToSCF - MLIRMemRefTransforms + ${LinkedLibs} ) diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 9550e2bba79a..6e1d68a76836 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -17,9 +17,13 @@ #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#endif #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; @@ -42,39 +46,41 @@ void mlir::torch::registerTorchConversionPasses() { "Pipeline lowering torch backend contract to linalg-on-tensors backend " "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); + mlir::PassPipelineRegistration( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", TorchConversion::createTorchBackendToTosaBackendPipeline); +#ifdef TORCH_MLIR_ENABLE_MHLO + mlir::PassPipelineRegistration( + "torch-backend-to-mhlo-backend-pipeline", + "Pipeline lowering torch backend contract to MHLO backend " + "contract.", + TorchConversion::createTorchBackendToMhloBackendPipeline); +#endif } void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { - // Check some invariants to catch errors in a clear way. - pm.addPass( - TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); - // Lower to linalg + guards which is the input to codegen backends. 
// We do this first as it tends to involve pattern-matching against constants, // (e.g. dimensions which must be constant in a ranked programming model) - // and those constants get somewhat obscured by TorchToStd. + // and those constants get somewhat obscured by TorchToArith. pm.addNestedPass(createConvertTorchToTMTensorPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); - pm.addNestedPass(createConvertTorchToStdPass()); + pm.addNestedPass(createConvertTorchToArithPass()); pm.addNestedPass(memref::createExpandOpsPass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // Resolve `dim` ops on tensors (which currently live in the `memref` - // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass( - memref::createResolveShapedTypeResultDimsPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // Resolve `dim` ops on tensors (which currently live in the `memref` + // dialect for some reason -- we don't have memrefs at this level). + pm.addNestedPass( + memref::createResolveShapedTypeResultDimsPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Finish the type conversion from `torch` types to the types of the // linalg-on-tensors backend contract. @@ -91,20 +97,14 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( void TorchConversion::createTorchBackendToTosaBackendPipeline( OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { - // Check some invariants to catch errors in a clear way. - pm.addPass( - TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); - pm.addNestedPass(createConvertTorchToTosaPass()); // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Finish the type conversion from `torch` types to the types of the // TOSA backend contract. @@ -118,3 +118,30 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( // correct form. pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } + +#ifdef TORCH_MLIR_ENABLE_MHLO +void TorchConversion::createTorchBackendToMhloBackendPipeline( + OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { + pm.addNestedPass(createConvertTorchToMhloPass()); + pm.addNestedPass(createConvertTorchToSCFPass()); + pm.addNestedPass(createConvertTorchToArithPass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Convert CHLO ops to MHLO ops + pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); + // Clean up any non-canonical code introduced above.. 
+ pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // MHLO backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); +} +#endif diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp deleted file mode 100644 index 5e0406ee5c15..000000000000 --- a/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp +++ /dev/null @@ -1,87 +0,0 @@ -//===- VerifyInvariantsBeforeBackendLowering.cpp -----------------*- C++-*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::torch; -using namespace mlir::torch::TorchConversion; -using namespace mlir::torch; - -static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) { - // TODO: Make this an allowlist instead of a denylist. - // TODO: Make this stricter. - auto type = v.getType(); - if (auto valueTensorType = type.dyn_cast()) { - if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes()) - return errorReportOp->emitError() - .append("unsupported by backend lowering: tensor with unknown rank " - "or dtype") - .attachNote() - .append("this is likely due to a missing shape transfer function in " - "shape_lib_gen.py"); - } - return success(); -} - -namespace { - -class VerifyInvariantsBeforeBackendLoweringPass - : public VerifyInvariantsBeforeBackendLoweringBase< - VerifyInvariantsBeforeBackendLoweringPass> { - void runOnOperation() override { - // TODO: It seems that the walkers over blocks are not correctly - // propagating `walkResult.wasInterrupted()` so use a manual `didFail` - // boolean. - bool didFail = false; - getOperation().walk([&](Block *block) { - // Check invariants on all the Value's in the program. - // That is, check all BlockArgument's and OpResult's. 
- for (BlockArgument arg : block->getArguments()) { - if (failed(checkValueInvariants(block->getParentOp(), arg))) { - didFail = true; - return WalkResult::interrupt(); - } - } - for (Operation &op : *block) { - if (isa(op)) { - op.emitError() - .append("unsupported by backend lowering: `torch.operator` op") - .attachNote() - .append("this is likely due to a missing op that needs to be " - "generated by torch_ods_gen.py"); - didFail = true; - return WalkResult::interrupt(); - } - for (OpResult result : op.getResults()) { - if (failed(checkValueInvariants(&op, result))) { - didFail = true; - return WalkResult::interrupt(); - } - } - } - return WalkResult::advance(); - }); - if (didFail) - return signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> mlir::torch::TorchConversion:: - createVerifyInvariantsBeforeBackendLoweringPass() { - return std::make_unique(); -} diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 9e3fb2298d97..b01d62152b25 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/InitAll.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" @@ -20,6 +21,7 @@ #include "torch-mlir/RefBackend/Passes.h" void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 0a311c38537a..3a218ffe223c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -13,6 +13,30 @@ set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") # We vendor our own MLIR instance in the `torch_mlir` namespace. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") +################################################################################ +# PyTorch +################################################################################ + +option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON) + +if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) + # Source builds + set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) + set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) + set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) + set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) + set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) + set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) + execute_process( + COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh + RESULT_VARIABLE _result + ) + if(_result) + message(FATAL_ERROR "Failed to run `build_libtorch.sh`") + endif() + set(TORCH_INSTALL_PREFIX "libtorch") +endif() + ################################################################################ # Sources ################################################################################ @@ -57,26 +81,21 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main ) ################################################################################ -# Optionally handle JIT IR importer. 
+# Lazy Tensor Core ################################################################################ -option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON) +if(TORCH_MLIR_ENABLE_LTC) + add_subdirectory(torch_mlir/csrc/base_lazy_backend) +endif() +# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it +# generates a dummy Python library when disabled. +add_subdirectory(torch_mlir/csrc/reference_lazy_backend) + +################################################################################ +# Optionally handle JIT IR importer. +################################################################################ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) - # Source builds - set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) - set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) - set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) - execute_process( - COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh - RESULT_VARIABLE _result - ) - if(_result) - message(FATAL_ERROR "Failed to run `build_libtorch.sh`") - endif() - set(TORCH_INSTALL_PREFIX "libtorch") - endif() add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir) add_subdirectory(torch_mlir_e2e_test) endif() @@ -92,8 +111,7 @@ add_subdirectory(torch_mlir/eager_mode) # Required for running the update_torch_ods.sh and update_shape_lib.sh scripts. ################################################################################ -# TODO: renable once it build on macOS Intel / M1 -#add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) +# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) ################################################################################ # Generate packages and shared library @@ -110,8 +128,7 @@ set(_source_components # tree, which seems excessive. MLIRPythonSources MLIRPythonExtension.Core - MLIRPythonExtension.AllPassesRegistration - MLIRPythonExtension.ExecutionEngine + MLIRPythonExtension.RegisterEverything TorchMLIRPythonSources TorchMLIRPythonExtensions ) @@ -137,12 +154,15 @@ add_mlir_python_modules(TorchMLIRPythonModules # Then it would "just work". if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter) + add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind) # Build the E2E Tests (which depend on the JIT IR importer now). 
add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules) endif() -# TODO: Add after macOS builds are fixed -#add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example) +if(TORCH_MLIR_ENABLE_LTC) + # Add Torch-MLIR LTC backend as dependency + add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) + add_dependencies(TorchMLIRPythonModules reference_lazy_backend) +endif() add_subdirectory(test) - diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp index c84f3b6f7e3e..e0b045143366 100644 --- a/python/TorchMLIRModule.cpp +++ b/python/TorchMLIRModule.cpp @@ -8,7 +8,6 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Registration.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include "torch-mlir-c/Dialects.h" #include "torch-mlir-c/Registration.h" diff --git a/python/test/annotations-sugar.py b/python/test/annotations-sugar.py index 60171daf687d..98cbec74d1c5 100644 --- a/python/test/annotations-sugar.py +++ b/python/test/annotations-sugar.py @@ -7,7 +7,7 @@ import torch -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.annotations import annotate_args, export from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations diff --git a/python/test/compile_api/tracing.py b/python/test/compile_api/tracing.py index 4557e80ed758..bd12f3b2b9ed 100644 --- a/python/test/compile_api/tracing.py +++ b/python/test/compile_api/tracing.py @@ -31,3 +31,24 @@ def forward(self, x): print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> + +# TensorPlaceholder support. +placeholder = torch_mlir.TensorPlaceholder.like( + tanh_example_input, dynamic_axes=[1]) +print(torch_mlir.compile(TanhModule(), [placeholder], + use_tracing=True, ignore_traced_shapes=True)) +# CHECK-LABEL: @forward +# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> + +try: + # CHECK: `ignore_traced_shapes` requires `use_tracing` + torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) +except Exception as e: + print(e) + + +try: + # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True` + torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True) +except Exception as e: + print(e) diff --git a/python/test/lazy_backend/device_data_name.py b/python/test/lazy_backend/device_data_name.py new file mode 100644 index 000000000000..fe596287c6ab --- /dev/null +++ b/python/test/lazy_backend/device_data_name.py @@ -0,0 +1,44 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
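+# These tests exercise how the reference lazy backend reports device data
+# (graph inputs): by default input tensors are anonymous, while
+# `lazy_backend.set_parameter_name` tags a tensor's device data so that it is
+# listed as a named input when the trace is materialized by
+# `torch._lazy.mark_step()`.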
+ +# RUN: %PYTHON %s | FileCheck %s + + +import torch +import torch._lazy + +import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend + +from run_test import run_test + +lazy_backend._initialize() + +device = "lazy" + + +# CHECK: 0 input tensors found +# ----- +# CHECK: PASS - test_no_device_data_name +@run_test +def test_no_device_data_name(): + x = torch.tensor(1).to(device) + y = torch.tensor(2).to(device) + z = x + y + torch._lazy.mark_step() + + +# CHECK: Input tensor: input_x +# CHECK: 1 input tensors found +# ----- +# CHECK: PASS - test_device_data_name +@run_test +def test_device_data_name(): + x = torch.tensor(1).to(device) + y = torch.tensor(2).to(device) + + lazy_backend.set_parameter_name(x, "input_x") + + z = x + y + torch._lazy.mark_step() diff --git a/python/test/lazy_backend/run_test.py b/python/test/lazy_backend/run_test.py new file mode 100644 index 000000000000..5ef560a475a9 --- /dev/null +++ b/python/test/lazy_backend/run_test.py @@ -0,0 +1,23 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: true + + +def run_test(*args, XPASS=False, XFAIL=False): + def _run_test(test): + test_name = test.__name__ + try: + test() + print(("X" if XPASS else "") + f"PASS - {test_name}") + except Exception as e: + print(("X" if XFAIL else "") + f"FAIL - {test_name}") + print("Errors: ", e) + print(flush=True) + + if len(args): + _run_test(args[0]) + else: + return _run_test diff --git a/python/test/lit.cfg.py b/python/test/lit.cfg.py index 9ec3c7921a52..0bb6760a7a9d 100644 --- a/python/test/lit.cfg.py +++ b/python/test/lit.cfg.py @@ -51,6 +51,9 @@ # directories. config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] +if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))): + config.excludes.append("lazy_backend") + # test_source_root: The root path where tests are located. 
config.test_source_root = os.path.dirname(__file__) diff --git a/python/test/torchscript_e2e_test/basic.py b/python/test/torchscript_e2e_test/basic.py index 438468c20f31..fa3f6f29729b 100644 --- a/python/test/torchscript_e2e_test/basic.py +++ b/python/test/torchscript_e2e_test/basic.py @@ -7,10 +7,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig class MmModule(torch.nn.Module): diff --git a/python/test/torchscript_e2e_test/compilation_failure.py b/python/test/torchscript_e2e_test/compilation_failure.py index a47750caac70..b53f31c496e3 100644 --- a/python/test/torchscript_e2e_test/compilation_failure.py +++ b/python/test/torchscript_e2e_test/compilation_failure.py @@ -7,10 +7,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig class MmModule(torch.nn.Module): diff --git a/python/test/torchscript_e2e_test/error_reports.py b/python/test/torchscript_e2e_test/error_reports.py index 686522644ede..f3321285999a 100644 --- a/python/test/torchscript_e2e_test/error_reports.py +++ b/python/test/torchscript_e2e_test/error_reports.py @@ -9,10 +9,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig # CHECK: Unexpected outcome summary: # CHECK: FAIL - "ErroneousModule_basic" @@ -118,7 +118,7 @@ def test_recursive(self): # CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch" # CHECK-NEXT: @ output of call to "test_tensor_value_mismatch" - # CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5) + # CHECK-NEXT: ERROR: value (Tensor with shape=[3], dtype=torch.float32, min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3], dtype=torch.float32, min=+1.5, max=+3.5, mean=+2.5) @torch.jit.export def test_tensor_value_mismatch(self): if torch.jit.is_scripting(): diff --git a/python/test/torchscript_e2e_test/non_tensor_values.py 
b/python/test/torchscript_e2e_test/non_tensor_values.py index b5edc83351a5..a1c8c5adfdf4 100644 --- a/python/test/torchscript_e2e_test/non_tensor_values.py +++ b/python/test/torchscript_e2e_test/non_tensor_values.py @@ -9,10 +9,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig class NonTensorValuesModule(torch.nn.Module): diff --git a/python/test/torchscript_e2e_test/runtime_failure.py b/python/test/torchscript_e2e_test/runtime_failure.py index 9bdd602ad223..3581c1b6d41f 100644 --- a/python/test/torchscript_e2e_test/runtime_failure.py +++ b/python/test/torchscript_e2e_test/runtime_failure.py @@ -7,10 +7,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig class MmModule(torch.nn.Module): diff --git a/python/test/torchscript_e2e_test/submodule.py b/python/test/torchscript_e2e_test/submodule.py index fa157b531614..c88ad53b31b3 100644 --- a/python/test/torchscript_e2e_test/submodule.py +++ b/python/test/torchscript_e2e_test/submodule.py @@ -7,10 +7,10 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import run_tests, TestUtils -from torch_mlir_e2e_test.torchscript.reporting import report_results -from torch_mlir_e2e_test.torchscript.registry import register_test_case, GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.torchscript.configs import TorchScriptTestConfig +from torch_mlir_e2e_test.framework import run_tests, TestUtils +from torch_mlir_e2e_test.reporting import report_results +from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir_e2e_test.configs import TorchScriptTestConfig class Submodule2(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 79b6c4512894..046c2fd44ae0 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -6,11 +6,14 @@ from typing import Sequence, Union, List from enum import Enum +import sys +from io import StringIO + import torch from torch_mlir.passmanager import PassManager from .compiler_utils import run_pipeline_with_repro_report -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder +from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder class OutputType(Enum): @@ -28,21 +31,25 @@ class OutputType(Enum): # This output type consists of `torch` dialect ops that have been converted # maximally to value semantics, decomposed, and shapes have been inferred. 
- TORCH = 0 - - # This output type consists of `tosa` dialect ops. It can be thought of - # as taking the `TORCH` output type and lowering it to TOSA. - TOSA = 1 + TORCH = "torch" # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and # `arith` ops (and also `math` and `tm_tensor`). It can be thought of # as taking the `TORCH` output type and lowering it so that tensor # computations are done with `linalg`-on-tensors ops. - LINALG_ON_TENSORS = 2 + LINALG_ON_TENSORS = "linalg-on-tensors" + + # This output type consists of `tosa` dialect ops. It can be thought of + # as taking the `TORCH` output type and lowering it to TOSA. + TOSA = "tosa" + + # This output type consists of `mhlo` dialect ops. It can be thought of + # as taking the `TORCH` output type and lowering it to MHLO. + MHLO = "mhlo" # Raw output of the JIT IR importer. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. - RAW = 3 + RAW = "raw" @staticmethod def get(spec: Union[str, "OutputType"]) -> "OutputType": @@ -112,13 +119,28 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): return TensorPlaceholder(shape, tensor.dtype) +# The set of ops that are considered legal for each backend. +# These are currently quite load-bearing, since different backends might be +# missing patterns for decomposed forms of certain ops. +# TODO: Tighten up the definition of these "conditionally legal for backends" +# ops in the backend contract, and move these lists somewhere deeper in the +# compiler where each backend can "own" its set of legal ops. +BACKEND_LEGAL_OPS = { + OutputType.TOSA: ['torch.aten.flatten.using_ints',], + OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',], + OutputType.MHLO: [], +} + + _example_arg = Union[TensorPlaceholder, torch.Tensor] def compile(model: torch.nn.Module, example_args: Union[_example_arg, Sequence[_example_arg]], output_type: Union[str, "OutputType"] = OutputType.TORCH, - use_tracing=False): + use_tracing: bool = False, + ignore_traced_shapes = False, + verbose: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -132,12 +154,22 @@ def compile(model: torch.nn.Module, details. use_tracing: If True, use `torch.jit.trace` to convert the model to JIT IR rather than `torch.jit.script`. + ignore_traced_shapes: If True, ignore the shapes that were observed + during tracing. This should only be used if one knows that the + original traced program would result in the same trace (modulo + shapes) for all shape combinations implied by any + `TensorPlaceholder`'s used as `example_args`. Also, + strictly-speaking, this option covers dtypes too, but we just say + "shapes" to be succinct. + verbose: If true, print extra information about the conversion. Returns: An MLIR module that contains the converted model in the specified output type. """ output_type = OutputType.get(output_type) + if ignore_traced_shapes and not use_tracing: + raise Exception("`ignore_traced_shapes` requires `use_tracing`") # Special case -- many models have just one input, so canonicalize a single # tensor to a list of a single tensor to make the API more ergonomic. @@ -147,10 +179,29 @@ def compile(model: torch.nn.Module, # TODO: Don't hardcode "forward". See `torch.onnx.export` and # `torch.jit.trace_module` for API inspiration. 
if use_tracing: - scripted = torch.jit.trace(model, tuple(example_args)) + example_args_for_trace = [] + for arg in example_args: + if isinstance(arg, TensorPlaceholder): + if not ignore_traced_shapes: + # To avoid accidental footguns, we require + # `ignore_traced_shapes` to be true if we're using + # TensorPlaceholder's, as it falls into the same + # "hopefully the trace works for different inputs" bucket + # of concerns. + raise Exception( + "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`") + # For any dynamic dimensions, replace them with "7" arbitrarily. + # If a user is using dynamic dimensions with tracing, they are + # walking on thin ice already -- assume they know what they are + # doing. + shape = [s if s != -1 else 7 for s in arg.shape] + example_args_for_trace.append( + torch.ones(*shape, dtype=arg.dtype)) + else: + example_args_for_trace.append(arg) + scripted = torch.jit.trace(model, tuple(example_args_for_trace)) else: scripted = torch.jit.script(model) - # Convert all concrete inputs to TensorPlaceholder's, for consistency. arg_placeholders = [] for arg in example_args: @@ -171,14 +222,38 @@ def compile(model: torch.nn.Module, scripted._c._type(), ["forward"], forward_annotation) mb = ModuleBuilder() - mb.import_module(scripted._c, class_annotator) - + import_options = ImportOptions() + import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes + try: + original_stderr = sys.stderr + sys.stderr = StringIO() + # Import the TorchScript module to MLIR + mb.import_module(scripted._c, class_annotator, import_options) + except Exception as e: + raise Exception(f""" +PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: +### Importer C++ Exception: +{e} +### Importer Diagnostics: +{sys.stderr.getvalue()} +""") from None + finally: + sys.stderr = original_stderr if output_type == OutputType.RAW: return mb.module - run_pipeline_with_repro_report(mb.module, - "torchscript-module-to-torch-backend-pipeline", - "Lowering TorchScript IR -> Torch Backend IR") + backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + run_pipeline_with_repro_report( + mb.module, + f"torchscript-module-to-torch-backend-pipeline{option_string}", + "Lowering TorchScript IR -> Torch Backend IR", + ) + + if verbose: + print("\n====================") + print("Torch Backend IR") + print(mb.module) if output_type == OutputType.TORCH: return mb.module @@ -188,6 +263,10 @@ def compile(model: torch.nn.Module, mb.module, "torch-backend-to-tosa-backend-pipeline", "Lowering Torch Backend IR -> TOSA Backend IR") + if verbose: + print("\n====================") + print("TOSA Backend IR") + print(mb.module) return mb.module if output_type == OutputType.LINALG_ON_TENSORS: @@ -195,6 +274,20 @@ def compile(model: torch.nn.Module, mb.module, "torch-backend-to-linalg-on-tensors-backend-pipeline", "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + if verbose: + print("\n====================") + print("LINALG Backend IR") + print(mb.module) return mb.module + elif output_type == OutputType.MHLO: + run_pipeline_with_repro_report( + mb.module, + "torch-backend-to-mhlo-backend-pipeline", + "Lowering Torch Backend IR -> MHLO Backend IR") + if verbose: + print("\n====================") + print("MHLO Backend IR") + print(mb.module) + return mb.module raise Exception(f"Unknown OutputType: {output_type}") diff --git 
a/python/torch_mlir/_torch_mlir_custom_op_example/CMakeLists.txt b/python/torch_mlir/_torch_mlir_custom_op_example/CMakeLists.txt index fdd5997f4828..a5a011b5543a 100644 --- a/python/torch_mlir/_torch_mlir_custom_op_example/CMakeLists.txt +++ b/python/torch_mlir/_torch_mlir_custom_op_example/CMakeLists.txt @@ -1,5 +1,5 @@ # Setup PyTorch -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../dialects/torch/importer/jit_ir/cmake/modules") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules") include(TorchMLIRPyTorch) TorchMLIRProbeForPyTorchInstall() find_package(Torch 1.8 REQUIRED) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/cmake/modules/TorchMLIRPyTorch.cmake b/python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake similarity index 100% rename from python/torch_mlir/dialects/torch/importer/jit_ir/cmake/modules/TorchMLIRPyTorch.cmake rename to python/torch_mlir/cmake/modules/TorchMLIRPyTorch.cmake diff --git a/python/torch_mlir/csrc/.clang-format b/python/torch_mlir/csrc/.clang-format new file mode 100644 index 000000000000..e71b6c2c8771 --- /dev/null +++ b/python/torch_mlir/csrc/.clang-format @@ -0,0 +1,4 @@ +BasedOnStyle: LLVM +AlignAfterOpenBracket: AlwaysBreak # BlockIndent +PointerAlignment: Left +ReflowComments: false diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt new file mode 100644 index 000000000000..a872c802ea8e --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -0,0 +1,137 @@ +#------------------------------------------------------------------------------- +# Setup PyTorch/LTC +#------------------------------------------------------------------------------- + + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") +include(TorchMLIRPyTorch) + +TorchMLIRProbeForPyTorchInstall() +if(TORCH_MLIR_USE_INSTALLED_PYTORCH) + TorchMLIRConfigurePyTorch() +else() + set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch") +endif() + +find_package(Torch 1.11 REQUIRED) + +set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen) + +include_directories(BEFORE + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ${Python3_INCLUDE_DIRS} + ${PROJECT_SOURCE_DIR}/python +) +link_directories("${TORCH_INSTALL_PREFIX}/lib") + +set(LTC_GENERATED + generated/LazyNativeFunctions.cpp + generated/RegisterLazy.cpp + generated/shape_inference.cpp +) +set(LTC_BACKEND_DEPENDS + mlir_lowering_context.cpp + mlir_native_functions.cpp + mlir_node_lowering.cpp + shape_inference.cpp +) + +# Generate Lazy IR Nodes + +add_custom_command( + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/build_tools/autogen_ltc_backend.py -b ${TORCH_MLIR_BINARY_DIR} + OUTPUT + ${TORCH_MLIR_BINARY_DIR}/generated_backend.hash + ${LTC_GENERATED} + DEPENDS + ${PROJECT_SOURCE_DIR}/build_tools/autogen_ltc_backend.py + ${PROJECT_SOURCE_DIR}/build_tools/autogen_ltc_backend.yaml + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td + ${LTC_BACKEND_DEPENDS} + ${TORCHGEN_DIR}/gen_backend_stubs.py + ${TORCHGEN_DIR}/gen_lazy_tensor.py + ${TORCHGEN_DIR}/api/lazy.py + ${TORCHGEN_DIR}/dest/lazy_ir.py + COMMENT "Generating Lazy Tensor Core IR Nodes" +) +add_custom_target( + torch_mlir_ltc_backend_generated ALL + DEPENDS + ${TORCH_MLIR_BINARY_DIR}/generated_backend.hash + ${LTC_GENERATED} +) + +add_library(torch_mlir_ltc_backend SHARED + ${LTC_GENERATED} + ${LTC_BACKEND_DEPENDS} + 
backend_impl.cpp + dynamic_ir.cpp + mlir_node.cpp + ops/device_data.cpp + ops/generic.cpp +) +target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) + +add_dependencies(torch_mlir_ltc_backend + TorchMLIRJITIRImporter + torch_mlir_ltc_backend_generated +) +target_link_libraries(torch_mlir_ltc_backend + TorchMLIRAggregateCAPI + TorchMLIRJITIRImporter + ${TORCH_LIBRARIES} +) + +message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") +set_target_properties(torch_mlir_ltc_backend PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + OUTPUT_NAME lib_torch_mlir_ltc + PREFIX "" + SUFFIX ".so" + CXX_VISIBILITY_PRESET "hidden" + COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" + LINK_FLAGS "-rdynamic" +) + +# Copy header files into python package + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND mkdir -p + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND cp + ${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/*.h + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND cp + ${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/generated/*.h + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND mkdir -p + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND cp + ${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/ops/*.h + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND mkdir -p + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/) + +add_custom_command( + TARGET torch_mlir_ltc_backend POST_BUILD + COMMAND cp + ${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/utils/*.h + ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/) diff --git a/python/torch_mlir/csrc/base_lazy_backend/README.md b/python/torch_mlir/csrc/base_lazy_backend/README.md new file mode 100644 index 000000000000..5041986691ce --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/README.md @@ -0,0 +1,26 @@ +# Torch-MLIR Lazy Tensor Core Backend + +## Detailed Documentation + +Detailed documentation about the architecture of this LTC backend is available [here](../../../../docs/ltc_backend.md). + +## Summary + +Contained within this directory are the components that implement the +Torch-MLIR LTC backend. Note that the code style for LTC components is +consistent with that of LTC itself, rather than the rest of Torch-MLIR. + +The components are subclasses of the backend API interface classes found under +[torch/csrc/lazy/backend](https://github.com/pytorch/pytorch/tree/master/torch/csrc/lazy/backend). + +Importantly, the subclasses are still abstract classes. Pure virtual methods +such as `Compile` were purposefully not overridden, as Torch-MLIR does not know +how to compile the model for the target hardware. + +The intent is that vendor hardware-specific plugins will subclass the Torch-MLIR +backend classes and override the remaining pure virtual functions to complete +the backend.
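+For illustration, here is roughly how a completed backend is driven from
+eager PyTorch. This is a sketch, not a definitive recipe: it assumes a build
+with `TORCH_MLIR_ENABLE_LTC=1`, and the `_REFERENCE_LAZY_BACKEND` module name
+is taken from the tests under `python/test/lazy_backend/` in this patch.
+
+```python
+import torch
+import torch._lazy
+
+import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
+
+# Register the backend with LTC before creating any lazy tensors.
+lazy_backend._initialize()
+
+x = torch.tensor(1).to("lazy")
+y = torch.tensor(2).to("lazy")
+z = x + y                # recorded in the lazy trace; nothing executes yet
+torch._lazy.mark_step()  # the trace is lowered to MLIR and handed to the backend
+```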
+ +The Torch-MLIR LTC backend is responsible for the lowering from ATen to MLIR. A + hardware vendor's backend is then responsible for the actual compilation and + execution of the lowered MLIR. diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp new file mode 100644 index 000000000000..38b804c22421 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp @@ -0,0 +1,204 @@ +//===- backend_impl.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include "backend_impl.h" +#include "ir_builder.h" +#include "mlir_lowering_context.h" +#include "ops/device_data.h" +#include "utils/debug.h" +#include "utils/exception.h" + +namespace torch { +namespace lazy { + +TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) + : BackendData(device, shape), + info_(std::make_unique()) { + PRINT_FUNCTION(); +} +TorchMlirBackendData::TorchMlirBackendData( + const at::Scalar& scalar, BackendDevice device) + : BackendData(device, Shape(scalar.type(), {})), + info_(std::make_unique(scalar)) { + PRINT_FUNCTION(); +} +TorchMlirBackendData::TorchMlirBackendData( + const at::Tensor& tensor, BackendDevice device, Shape shape) + : BackendData(device, shape), + info_(std::make_unique(tensor)) { + PRINT_FUNCTION(); +} + +BackendData::Handle TorchMlirBackendData::GetHandle() { + return reinterpret_cast(this); +} + +void TorchMlirBackendData::Assign(const BackendData& data) { + const TorchMlirBackendData* torch_mlir_data = + dynamic_cast(&data); + TORCH_CHECK( + torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + + TorchMlirBackendData::Info* info = + dynamic_cast(torch_mlir_data->mlir_info()); + TORCH_CHECK( + info, + "Invalid Backend Data Pointer. 
Expected TorchMlirBackendData::Info."); + + info_ = std::make_unique(*info); +} + +bool TorchMlirBackendData::HasValue() const { return bool(info_); } + +TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const { + return info_.get(); +} + +/** + * Initialization/Teardown + * */ +void TorchMlirBackendImpl::PrepareToExit() const {} + +/** + * IR Tracing + * */ + +const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { + static const IrBuilder* builder = new TorchMlirIrBuilder(); + return builder; +} + +/** + * Data Transfer + * */ + +BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor( + const at::Tensor& tensor, const Shape& shape, + const BackendDevice& device) const { + PRINT_FUNCTION(); + return std::make_shared(tensor, device, shape); +} + +BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar( + const at::Scalar& scalar, const BackendDevice& device) const { + PRINT_FUNCTION(); + return std::make_shared(scalar, device); +} + +BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder( + const BackendDevice& device, const Shape& shape) const { + PRINT_FUNCTION(); + return std::make_shared(device, shape); +} + +BackendDataPtr +TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const { + PRINT_FUNCTION(); + auto* device_data_node = dynamic_cast(node); + if (!device_data_node) { + return nullptr; + } + return device_data_node->data(); +} + +at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( + const BackendDataPtr data, + c10::optional logical_scalar_type) const { + PRINT_FUNCTION(); + + TorchMlirBackendData* torch_mlir_data = + dynamic_cast(data.get()); + TORCH_CHECK( + torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + + TorchMlirBackendData::Info* info = + dynamic_cast(torch_mlir_data->mlir_info()); + TORCH_CHECK( + info, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); + + return info->tensor; +} + +/** + * Lowering, Compilation, Execution + * */ + +std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( + const std::string& name, BackendDevice device, + c10::ArrayRef post_order, Util::EmissionMap emit_status) const { + PRINT_FUNCTION(); + return std::make_unique( + name, std::forward(device), + std::forward>(post_order), + std::forward(emit_status)); +} + +std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( + const std::string& name, BackendDevice device) const { + PRINT_FUNCTION(); + return std::make_unique( + name, std::forward(device)); +} + +/** + * Device Configuration + * */ + +// Set or get the default device type. +// For backends used with virtual c10:: Devices, this configures what real +// device type the backend should use, and matters if the backend supports +// more than one type of real device. + +// Specify which aten device should be used for eager fallback +// may change depending on current 'Default' DeviceType +at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const { + PRINT_FUNCTION(); + return at::DeviceType::CPU; +} + +// Query all available backend devices +std::vector TorchMlirBackendImpl::GetBackendDevices() const { + PRINT_FUNCTION(); + return { + GetBackendDevice(c10::Device(c10::kLazy, 0)), + GetBackendDevice(c10::Device(c10::kCPU, 0))}; +} + +// Map a particular c10:: device to a concrete backend device +// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are +// virtual devices, meaning they may map to a gpu, tpu, etc. behind the +// scenes. 
In the future, non-virtual c10:: devices may also use lazy tensors +// through a mode, in which case these APIs should still work, but should be +// identity mappings. +BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const { + PRINT_FUNCTION(); + return BackendDevice(GetDefaultDeviceType(), device.index()); +} + +int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const { + return default_device_ordinal; +} + +void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) { + default_device_ordinal = ordinal; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h new file mode 100644 index 000000000000..70008bf216e8 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h @@ -0,0 +1,186 @@ +//===- backend_impl.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// The Torch-MLIR backend class API that handles lowering LTC ATen ops to MLIR +// using the Torch-MLIR ATen dialect +// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include + +namespace torch { +namespace lazy { + +class TORCH_API TorchMlirBackendData : public BackendData { +public: + struct Info : public BackendData::Info { + at::Tensor tensor; + c10::optional scalar; + bool requires_grad; + std::string name; + + Info() {} + Info(const Info& other) + : tensor{other.tensor}, scalar{other.scalar}, + requires_grad{other.requires_grad}, name{other.name} {} + Info(const at::Tensor& tensor) + : tensor{tensor}, requires_grad{tensor.requires_grad()} { + static int num_tensors = 0; + std::ostringstream oss; + oss << "tensor" << num_tensors; + this->name = oss.str(); + ++num_tensors; + } + Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {} + }; + + TorchMlirBackendData(BackendDevice device, Shape shape); + TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device); + TorchMlirBackendData( + const at::Tensor& tensor, BackendDevice device, Shape shape); + + virtual BackendData::Handle GetHandle() override; + + virtual void Assign(const BackendData& data) override; + + virtual bool HasValue() const override; + + TorchMlirBackendData::Info* mlir_info() const; + +private: + std::unique_ptr info_; +}; + +class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { +public: + virtual ~TorchMlirBackendImpl() = default; + + /** + * Initialization/Teardown + * */ + virtual void PrepareToExit() const override; + + /** + * IR Tracing + * */ + + const IrBuilder* GetIrBuilder() const override; + + /** + * Configuration + * */ + // virtual void SetRngSeed(size_t seed) const = 0; + + /** + * Data Transfer + * */ + + virtual BackendDataPtr MakeComputationDataFromTensor( + const at::Tensor& tensor, const Shape& shape, + const BackendDevice& device) const override; + + virtual BackendDataPtr MakeComputationDataFromScalar( + const at::Scalar& scalar, const 
BackendDevice& device) const override; + + virtual BackendDataPtr CreateDataPlaceholder( + const BackendDevice& device, const Shape& shape) const override; + + // Gets backend data if the node is a device data node. Otherwise returns + // nullptr. + virtual BackendDataPtr GetComputationDataFromNode(Node*) const override; + + virtual at::Tensor MakeTensorFromComputationData( + const BackendDataPtr data, + c10::optional logical_scalar_type) const override; + + /** + * Lowering, Compilation, Execution + * */ + + virtual std::unique_ptr CreateLoweringContext( + const std::string& name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const override; + + virtual std::unique_ptr CreateLoweringContext( + const std::string& name, BackendDevice device) const override; + + // TODO(whc) need to keep this? + // virtual std::vector GetCompilationDevices( + // const std::string& device, c10::ArrayRef devices + // ) const = 0; + + // virtual std::vector Compile( + // std::vector instances + // ) const = 0; + + // virtual std::vector ExecuteComputation( + // Computation& computation, + // c10::ArrayRef arguments, + // const BackendDevice& device + // ) const = 0; + + /** + * Device Configuration + * */ + + // Set or get the default device type. + // For backends used with virtual c10:: Devices, this configures what real + // device type the backend should use, and matters if the backend supports + // more than one type of real device. + + // virtual std::shared_ptr GetDefaultDeviceType() const = + // 0; + // virtual void SetDefaultDeviceType(std::string device_type) = 0; + + // Specify which aten device should be used for eager fallback + // may change depending on current 'Default' DeviceType + virtual at::DeviceType EagerFallbackDeviceType() const override; + + // Query all available backend devices + virtual std::vector GetBackendDevices() const override; + + // Map a particular c10:: device to a concrete backend device + // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are + // virtual devices, meaning they may map to a gpu, tpu, etc. behind the + // scenes. In the future, non-virtual c10:: devices may also use lazy tensors + // through a mode, in which case these APIs should still work, but should be + // identity mappings. + virtual BackendDevice GetBackendDevice(c10::Device device) const override; + + virtual int64_t GetDefaultDeviceOrdinal() const override; + + virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override; + + /** + * Debug/Metrics + * */ + + // virtual std::map GetMetrics() const = 0; + + // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; + + // virtual std::string GetComputationBackendText( + // const ComputationPtr computation + // ) const = 0; + +protected: + int64_t default_device_ordinal = 0; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp b/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp new file mode 100644 index 000000000000..ca6d80f1f419 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.cpp @@ -0,0 +1,74 @@ +//===- dynamic_ir.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.cpp +//===----------------------------------------------------------------------===// + +#include "dynamic_ir.h" + +namespace torch { +namespace lazy { + +DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed) + : TorchMlirNode( + op, operands, /*num_outputs=*/1, + /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} + +std::string DimensionNode::ToString() const { return "DimensionNode"; } + +SizeNode::SizeNode(Value input, size_t dim) + : DimensionNode( + OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, + MHash(dim)), + dim_(dim){}; + +int64_t SizeNode::getStaticValue() const { + return dynamic_cast(operand(0).node) + ->shape(0) + .size(dim_); +} + +std::string SizeNode::ToString() const { return "SizeNode"; } + +SizeAdd::SizeAdd(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; + +int64_t SizeAdd::getStaticValue() const { + return dynamic_cast(operand(0).node)->getStaticValue() + + dynamic_cast(operand(1).node)->getStaticValue(); +} + +std::string SizeAdd::ToString() const { return "SizeAdd"; } + +SizeMul::SizeMul(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; + +int64_t SizeMul::getStaticValue() const { + return dynamic_cast(operand(0).node)->getStaticValue() * + dynamic_cast(operand(1).node)->getStaticValue(); +} + +std::string SizeMul::ToString() const { return "SizeMul"; } + +SizeDiv::SizeDiv(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){}; + +int64_t SizeDiv::getStaticValue() const { + TORCH_CHECK( + dynamic_cast(operand(1).node)->getStaticValue() != + 0, + "Can't divide a dimension by zero"); + return dynamic_cast(operand(0).node)->getStaticValue() / + dynamic_cast(operand(1).node)->getStaticValue(); +} + +std::string SizeDiv::ToString() const { return "SizeDiv"; } + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.h b/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.h new file mode 100644 index 000000000000..97bad3c15b2c --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/dynamic_ir.h @@ -0,0 +1,99 @@ +//===- dynamic_ir.h -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/dynamic_ir.h +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mlir_node.h" +#include +#include +#include +#include +#include + +C10_DECLARE_bool(ltc_enable_dynamic_shapes); + +namespace torch { +namespace lazy { + +/** + * The goal of "dynamic" Nodes is to patch a hole in our tracing. 
+ * Previously, if a user called `sizes` on a Tensor, it would leak out + * of our tracing system, as `sizes` returns a torch.Size or an int. To + * prevent this from happening, we introduce DimensionNode, a new type + * of Node that abstracts the operation of getting the dimensions of a + * Tensor. + * + * Consider the following example: + * ``` + * numel = x.shape()[0] * x.shape()[1] + * ``` + * + * Here, `x.shape()[i]` will be a SizeNode (subclass of DimensionNode), + * and the multiplication of the two SizeNodes will be represented by + * a SizeMul (also a subclass of DimensionNode). Through this, we can + * prevent `numel` from being represented as a Python int and thus + * burned into the Graph. + */ + +class TORCH_API DimensionNode : public lazy::TorchMlirNode { +public: + DimensionNode(OpKind op, OpList operands, hash_t hash_seed = kHashSeed); + bool isDynamic() { return false; } + + std::string ToString() const override; + + virtual int64_t getStaticValue() const = 0; +}; + +// Represents the result of calling `size` on a Tensor +class TORCH_API SizeNode : public DimensionNode { +public: + SizeNode(Value input, size_t dim); + int64_t getStaticValue() const override; + std::string ToString() const override; + size_t dim_ = 0; +}; + +class TORCH_API SizeAdd : public DimensionNode { +public: + SizeAdd(Value a, Value b); + int64_t getStaticValue() const override; + std::string ToString() const override; +}; + +class TORCH_API SizeMul : public DimensionNode { +public: + SizeMul(Value a, Value b); + int64_t getStaticValue() const override; + std::string ToString() const override; +}; + +class TORCH_API SizeDiv : public DimensionNode { +public: + SizeDiv(Value a, Value b); + int64_t getStaticValue() const override; + std::string ToString() const override; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h b/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h new file mode 100644 index 000000000000..d1b8dd08a51a --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h @@ -0,0 +1,66 @@ +//===- ir_builder.h -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ir_builder.h +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include "dynamic_ir.h" +#include "generated/LazyNonNativeIr.h" +#include "mlir_node.h" +#include "ops/device_data.h" +#include "ops/generic.h" +#include "utils/exception.h" + +// This file contains the TorchMlir IrBuilder + +namespace torch { +namespace lazy { + +// clang-format off + +struct TorchMlirIrBuilder : IrBuilder { + NodePtr MakeDeviceData(const std::shared_ptr& data) const override { return MakeNode(data); } + NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode(value, type); } + NodePtr MakeExpand(const Value& input0, const std::vector& size, const bool& is_scalar_expand) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeView(const Value& input0, const std::vector& output_size) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional& stype = c10::nullopt) const override { return MakeNode(input0, dtype, stype); } + NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode(inputs); } + NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const override { return MakeNode(op, operands, shape, num_outputs, hash_seed); } + + // view ops + NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeAsStrided(const Value& input0, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector& base_indices) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeNarrow(const Value& input0, const std::vector& base_indices, const std::vector& sizes) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakePermute(const Value& input0, const std::vector& dims) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeResize(const Value& input0, const std::vector& size) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeSqueeze(const Value& input0, const int& dim) const override { UNIMPLEMENTED_FUNCTION_ERROR(); } + NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const 
override { UNIMPLEMENTED_FUNCTION_ERROR(); } + + // dynamic ir nodes + NodePtr MakeSizeNode(const Value& input, size_t dim) const override { return MakeNode(input, dim); } + NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { return MakeNode(a, b); } + NodePtr MakeSizeMul(const Value& a, const Value& b) const override { return MakeNode(a, b); } + NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { return MakeNode(a, b); } +}; + +// clang-format on + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp new file mode 100644 index 000000000000..fdef6271965a --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -0,0 +1,361 @@ +//===- mlir_lowering_context.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp +//===----------------------------------------------------------------------===// + +#include + +#include +#include +#include + +#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" +#include "backend_impl.h" +#include "mlir_lowering_context.h" +#include "mlir_node.h" +#include "torch-mlir-c/Registration.h" +#include "utils/debug.h" +#include "utils/exception.h" + +namespace torch { +namespace lazy { + +/////////////////////////////////////////////////////////////////////////////// +// TorchMlir Lowering Context +/////////////////////////////////////////////////////////////////////////////// + +TorchMlirLoweringContext::TorchMlirLoweringContext( + const std::string& name, BackendDevice device) + : LoweringContext(name, std::forward(device)), + graph_(std::make_shared()), + function_( + std::make_shared(name, graph_, nullptr)), + mlir_context_(mlirContextCreate()) { + RegisterMlirDialects(); +} + +TorchMlirLoweringContext::TorchMlirLoweringContext( + const std::string& name, BackendDevice device, + c10::ArrayRef post_order, Util::EmissionMap emit_status) + : LoweringContext( + name, std::forward(device), + std::forward>(post_order), + std::forward(emit_status)), + graph_(std::make_shared()), + function_( + std::make_shared(name, graph_, nullptr)), + mlir_context_(mlirContextCreate()) { + RegisterMlirDialects(); + + for (auto node : post_order) { + Lower(node); + } +} + +void TorchMlirLoweringContext::Lower(const Node* node) { + if (auto* torch_mlir_node = + dynamic_cast(node)) { + TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); + CHECK(!ops.empty()) << "Failed to lower: " << *node; + TORCH_CHECK_EQ(node->num_outputs(), ops.size()); + for (size_t i = 0; i < ops.size(); ++i) { + AssignOutputOp(torch::lazy::Output(node, i), ops[i]); + } + } else { + throw std::runtime_error( + "Expected torch::lazy::TorchMlirNode but could not dynamic cast"); + } +} + +void TorchMlirLoweringContext::SetUpAlias( + const std::vector& output_index, int64_t param_number, + const std::vector& param_index, bool must_alias) { + input_output_aliases_.push_back( + {output_index, 
param_number, param_index, must_alias}); +} + +bool TorchMlirLoweringContext::CheckResultShape( + const BackendDataPtr& parameter_data, size_t result_idx) { + TORCH_CHECK( + result_idx < root_tuple_.size(), "Tried getting result shape at index ", + result_idx, " which is out of bounds!"); + + torch::jit::Value* output = root_tuple_[result_idx]; + + if (c10::TensorTypePtr tensor_type = + output->type()->cast()) { + auto scalar_type = tensor_type->scalarType(); + auto sizes = tensor_type->sizes().concrete_sizes(); + + // Not guaranteed to have concrete size, so we need to check it exists. + if (scalar_type && sizes) { + return Shape(parameter_data->shape()) == + Shape(scalar_type.value(), c10::ArrayRef(sizes.value())); + } + } + + return false; +} + +size_t TorchMlirLoweringContext::AddResult(const Output& output) { + PRINT_FUNCTION(); + + return AddResult(GetOutputOp(output)); +} + +// Associates the given output with the input parameter of the given index and +// shape. Only used for the operator-by-operator execution, mostly for +// debugging purposes. +void TorchMlirLoweringContext::AddParameter( + const torch::lazy::Output& output, size_t index, + const torch::lazy::Shape& shape, const std::string& name) { + UNIMPLEMENTED_FUNCTION_ERROR(); +} + +// Build the computation capturing all the operations created with the +// embedded builder (returned by the builder() API). +ComputationPtr TorchMlirLoweringContext::Build() { + PRINT_FUNCTION(); + + // Since we mutated the types of some nodes to insert shape information, we + // must perform this pass to ensure tuples have up to date output types. + torch::jit::RefineTupleTypes(graph_); + + // Insert return values into graph. + for (torch::jit::Value* output : root_tuple_) { + graph_->block()->registerOutput(output); + } + + // Generate MLIR. + MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp( + /*context=*/mlir_context_, + /*function=*/generate_jit_fn().get(), + /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, + /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); + + return std::make_shared( + func_op, mlir_context_, graph_, input_output_aliases_); +} + +torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { + PRINT_FUNCTION(); + + auto it = emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { + auto post_order = Util::ComputePostOrder(output.node, &emit_status_); + for (auto node : post_order) { + Lower(node); + } + // At this point the output better be present, otherwise there is an issue + // with the lowering code. + it = emitted_outputs_.find(output); + TORCH_CHECK( + it != emitted_outputs_.end(), + "No MLIR operation emitted for output: ", output.ToString()); + } + return it->second; +} + +void TorchMlirLoweringContext::AssignOutputOp( + const Output& output, torch::jit::Value* op) { + PRINT_FUNCTION(); + + // TODO (antoniojkim): Do we need this? + // auto torch_mlir_node = + // NodeCast(output.node, output.node->op()); + // if (!torch_mlir_node->getPythonStacktrace().empty()) { + // op->node()->s_( + // c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace()); + // } + emitted_outputs_[output] = std::move(op); +} + +torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { + PRINT_FUNCTION(); + + if (!dynamic_cast(data.get())) { + TORCH_CHECK( + false, + "Expected TorchMlirBackendData. 
Got some other BackendData type"); + } + const auto mlir_data = std::static_pointer_cast(data); + + BackendData::Handle handle = mlir_data->GetHandle(); + auto it = parameters_map_.find(handle); + + if (it == parameters_map_.end()) { + torch::jit::Value* param = + graph_->addInput(c10::str("p", parameters_.size())); + + auto info = mlir_data->mlir_info(); + if (info->scalar.has_value()) { + auto& scalar = info->scalar.value(); + if (scalar.isFloatingPoint()) { + param->setType(c10::FloatType::get()); + } else if (scalar.isIntegral(true)) { + param->setType(c10::IntType::get()); + } else { + TORCH_CHECK( + false, "Unhandled scalar type: ", c10::toString(scalar.type())); + } + } else { + // Save parameter shape information. + param->setType(torch::jit::TensorType::create( + /*scalar_type=*/data->shape().scalar_type(), + /*device=*/c10::nullopt, + /*sizes=*/c10::VaryingShape(data->shape().sizes()), + /*strides=*/c10::VaryingShape(), + /*requires_grad=*/c10::nullopt)); + } + + it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) + .first; + parameters_.push_back(mlir_data); + } + + parameter_sequence_.push_back(it->second.index); + return it->second.param; +} + +std::shared_ptr TorchMlirLoweringContext::graph() const { + return graph_; +} + +size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { + PRINT_FUNCTION(); + root_tuple_.push_back(std::move(op)); + return root_tuple_.size() - 1; +} + +// Sync vector of c10::Argument with type specified from parallel list of +// jit::Value. There must be a 1:1 map between elements of args and values. +std::vector sync_argument_types( + const std::vector& args, + c10::ArrayRef values) { + TORCH_CHECK( + args.size() == values.size(), + "Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ", + args.size(), ":", values.size(), " instead!"); + + std::vector updated_args; + for (unsigned i = 0; i < args.size(); i++) { + updated_args.push_back(args[i].cloneWithType(values[i]->type())); + } + + return updated_args; +} + +std::unique_ptr +TorchMlirLoweringContext::generate_jit_fn() const { + // IMPORTANT: We pass in a COPY of the graph into create_function, since it + // may get mutated in the process. + auto fn = std::make_unique( + c10::QualifiedName("graph"), graph_->copy(), nullptr); + + c10::FunctionSchema schema = fn->getSchema(); + + // When constructing the default schema of a jit::GraphFunction, input and + // output shapes are stripped (via call to unshapedType(...)); however, + // since we want to have shape information in our MLIR, we'll add it back. 
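  // Illustration (hypothetical types, not actual output): a schema such as
  //   graph(Tensor p0) -> Tensor
  // is re-typed below from the corresponding jit::Values, e.g. to
  //   graph(Float(2, 3) p0) -> Float(2, 3)
  // using the TensorType that GetParameter() attached to each graph input.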
+ std::vector arguments = + sync_argument_types(schema.arguments(), graph_->inputs()); + std::vector returns = + sync_argument_types(schema.returns(), graph_->outputs()); + + fn->setSchema(schema.cloneWithArguments(arguments).cloneWithReturns(returns)); + + return fn; +} + +void TorchMlirLoweringContext::RegisterMlirDialects() { + // https://reviews.llvm.org/D88162 + torchMlirRegisterAllDialects(mlir_context_); +} + +/////////////////////////////////////////////////////////////////////////////// +// TorchMlir Computation +/////////////////////////////////////////////////////////////////////////////// + +TorchMlirComputation::TorchMlirComputation( + MlirOperation func_op, MlirContext mlir_context, + const std::shared_ptr& graph, + InputOutputAliases input_output_aliases) + : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)), + graph_(graph), input_output_aliases_(input_output_aliases) { + for (torch::jit::Value* input : graph_->inputs()) { + parameter_names_.push_back(input->debugName()); + } +} + +int TorchMlirComputation::parameters_size() const { + return parameter_names_.size(); +} + +const std::vector& +TorchMlirComputation::parameter_shapes() const { + throw std::runtime_error( + "todo(whc) implement ts computation shapes or change interface"); + return parameter_shapes_; +} + +const std::vector& TorchMlirComputation::parameter_names() const { + return parameter_names_; +} + +const torch::lazy::Shape& TorchMlirComputation::result_shape() const { + throw std::runtime_error( + "todo(whc) implement ts computation shapes or change interface"); + return result_shape_; +} + +std::shared_ptr TorchMlirComputation::graph() const { + return graph_; +} + +MlirOperation TorchMlirComputation::func_op() const { return func_op_; } + +const std::string TorchMlirComputation::debug_string() const { + std::stringstream ss; + + // JIT Graph + ss << "JIT Graph: \n" << graph_->toString() << "\n\n"; + + // MLIR + ss << "MLIR: \n" << to_string() << "\n"; + + // Input/Output Mapping + ss << "Input/Output Alias Mapping: \n"; + for (InputOutputAlias input_output_alias : input_output_aliases_) { + ss << "Output: " << input_output_alias.output_index + << " -> Input param: " << input_output_alias.param_number << "\n"; + } + ss << "\n"; + + // Mark Step + ss << "In Mark Step: " << (in_mark_step ? "true" : "false") << "\n"; + + return ss.str(); +} + +const std::string TorchMlirComputation::to_string() const { + // Since we use the C-MLIR API, we need to use a callback to print. + MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + // user_data is a void ptr to some data structure of our choice -- in this + // case, the string stream where we'll be accumulating the strings. + std::stringstream* ss_ptr = static_cast(user_data); + *ss_ptr << std::string(part.data, part.length); + }; + std::stringstream ss; + mlirOperationPrint(func_op_, print_callback, &ss); + return ss.str(); +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h new file mode 100644 index 000000000000..b6e6b1ceffdd --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -0,0 +1,154 @@ +//===- mlir_lowering_context.h --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.h +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include + +#include "mlir-c/IR.h" +#include "mlir_node_lowering.h" + +namespace torch { +namespace lazy { + +class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { +public: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + // Specifies the index of the aliased buffer in the result tuple. + std::vector output_index; + // Specifies the parameter containing the buffer to be aliased. + int64_t param_number; + // Specifies the index of the aliased buffer in the parameter + std::vector param_index; + // Specifies if the alias is a must alias or may alias. + bool must_alias; + }; + using InputOutputAliases = std::vector; + + TorchMlirLoweringContext( + const std::string& name, torch::lazy::BackendDevice device); + TorchMlirLoweringContext( + const std::string& name, torch::lazy::BackendDevice device, + c10::ArrayRef post_order, + torch::lazy::Util::EmissionMap emit_status); + + void Lower(const Node* node); + + // Adds a new input/output alias. + void SetUpAlias( + const std::vector& output_index, int64_t param_number, + const std::vector& param_index, + bool must_alias = false) override; + + // Check if parameter shape matches result at index. + bool CheckResultShape( + const BackendDataPtr& parameter_data, size_t result_idx) override; + + // Adds the given output as a component of the result tuple and returns its + // assigned position within the tuple. + size_t AddResult(const torch::lazy::Output& output) override; + + // Associates the given output with the input parameter of the given index and + // shape. Only used for the operator-by-operator execution, mostly for + // debugging purposes. + void AddParameter( + const torch::lazy::Output& output, size_t index, + const torch::lazy::Shape& shape, const std::string& name) override; + + // Build the computation capturing all the operations created with the + // embedded builder (returned by the builder() API). + torch::lazy::ComputationPtr Build() override; + + // Retrieves the lowered operation for an output. If the requested output is + // not available yet, the graph behind the output's Node is lowered, and the + // corresponding TS operation returned. + torch::jit::Value* GetOutputOp(const Output& output); + + // Assigns the given TS operation to the specified output. As outputs are + // lowered in a post-order fashion, later nodes should always find their + // operands among the emitted outputs. + void AssignOutputOp(const Output& output, torch::jit::Value* op); + + // If a parameter associated with data has already been declared, it will be + // returned. Otherwise a new one will be created, associated with the tensor + // held in data. + torch::jit::Value* GetParameter(BackendDataPtr data); + + std::shared_ptr graph() const; + +private: + struct Parameter { + torch::jit::Value* param; + size_t index = 0; + }; + + size_t AddResult(torch::jit::Value* op); + + // Creates a jit::Function from the current jit::Graph. Input and output + // type information is patched to include shape. 
+ std::unique_ptr generate_jit_fn() const; + + void RegisterMlirDialects(); + + // Holds the input/output alias information populated by the SetUpAlias() API. + InputOutputAliases input_output_aliases_; + std::shared_ptr graph_; + std::shared_ptr function_; + MlirContext mlir_context_; + std::unordered_map parameters_map_; + std::vector root_tuple_; + OutputMap emitted_outputs_; +}; + +class TORCH_API TorchMlirComputation : public torch::lazy::Computation { +public: + using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases; + using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; + + TorchMlirComputation( + MlirOperation func_op, MlirContext mlir_context, + const std::shared_ptr& graph, + InputOutputAliases input_output_aliases); + + int parameters_size() const override; + + const std::vector& parameter_shapes() const override; + + const std::vector& parameter_names() const override; + + const torch::lazy::Shape& result_shape() const override; + + std::shared_ptr graph() const; + + MlirOperation func_op() const; + + virtual const std::string debug_string() const; + + virtual const std::string to_string() const override; + +protected: + std::vector parameter_names_; + std::vector parameter_shapes_; + Shape result_shape_; + + MlirOperation func_op_; + MlirContext mlir_context_; + std::shared_ptr graph_; + InputOutputAliases input_output_aliases_; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp new file mode 100644 index 000000000000..32cba4fdf63f --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -0,0 +1,461 @@ +//===- aten_ltc_mlir_type.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "generated/LazyNativeFunctions.h" +#include "generated/shape_inference.h" +#include "ops/to_copy.h" +#include "utils/exception.h" +#include "utils/sys_utils.h" + +namespace torch { +namespace lazy { + +namespace { + +at::Tensor CreateLtcTensor( + const at::Tensor& tensor, + const c10::optional& device) { + if (tensor.defined() && device) { + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(tensor, *device)); + } + return tensor; +} + +c10::optional +GetLtcDevice(const c10::optional& device) { + if (!device) { + return c10::nullopt; + } + if (device->type() != at::kLazy) { + return c10::nullopt; + } + return torch::lazy::atenDeviceToBackendDevice(*device); +} + +torch::lazy::Value MaybeExpand( + const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) { + if (input.shape().sizes() == target_shape.sizes()) { + return input; + } + return torch::lazy::MakeExpand( + input, target_shape.sizes().vec(), + /*is_scalar_expand=*/false); +} + +void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { + if (input->GetDevice() == src->GetDevice()) { + torch::lazy::Value copy_value; + if (input->dtype() == src->dtype()) { + copy_value = src->GetIrValue(); + } else { + copy_value = torch::lazy::MakeCast( + src->GetIrValue(), input->dtype(), src->dtype()); + } + input->SetIrValue(MaybeExpand(copy_value, input->shape())); + } else { + auto input_shape = input->shape(); + at::Tensor src_tensor = src->ToTensor(/*detached=*/true); + if (src_tensor.sizes() != input_shape.Get().sizes()) { + src_tensor = src_tensor.expand(input_shape.Get().sizes().vec()); + } + input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false); + } +} + +} // namespace + +// at::Tensor LazyNativeFunctions::bernoulli( +// const at::Tensor& self, c10::optional generator) { +// TORCH_LAZY_FN_COUNTER("lazy::"); +// if (generator.has_value() && generator->defined()) { +// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value"); +// } +// auto self_tensor = torch::lazy::TryGetLtcTensor(self); + +// UNIMPLEMENTED_FUNCTION_ERROR(); +// // return torch::lazy::CreateAtenFromLtcTensor( +// // torch::lazy::bernoulli(self_tensor)); +// } + +// at::Tensor& LazyNativeFunctions::bernoulli_( +// at::Tensor& self, double p, c10::optional generator) { +// TORCH_LAZY_FN_COUNTER("lazy::"); +// if (generator.has_value() && generator->defined()) { +// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value"); +// } +// auto self_tensor = torch::lazy::TryGetLtcTensor(self); + +// UNIMPLEMENTED_FUNCTION_ERROR(); +// // torch::lazy::bernoulli_(self_tensor, p); +// // return self; +// } + +// clone is special in LT because we make it a no-op. +// This should be safe to do, because every operator in the LT is functional. 
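// For intuition, an illustrative (hypothetical) trace:
//   at::Tensor y = x.clone();
//   y.add_(1);
// Under lazy tracing, add_ is captured as a fresh IR node feeding y rather
// than a mutation of x's storage, so the clone can safely alias x's IrValue.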
+at::Tensor LazyNativeFunctions::clone( + const at::Tensor& self, c10::optional memory_format) { + auto self_lt = torch::lazy::TryGetLtcTensor(self); + return torch::lazy::CreateAtenFromLtcTensor( + self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); +} + +at::Tensor LazyNativeFunctions::_copy_from( + const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); + auto self_tensor = torch::lazy::TryGetLtcTensor(self); + if (!self_tensor) { + // providing a new 'eager' value (self) for an existing lazy tensor (dst) + static bool sync_update = + sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true); + CHECK(dst_tensor); + dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); + } else if (!dst_tensor) { + // materializing a lazy tensor (self) and copying its value into eager + // tensor (dst) + // detached=false lets us skip a copy in `ToTensor`, which should be safe + // because we are only going to use the tensor for dst.copy_() + CHECK(self_tensor); + at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false); + at::Tensor typed_tensor = + torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false); + dst.resize_as_(typed_tensor).copy_(typed_tensor); + } else { + // Copying one lazy tensor to another + if (!dst_tensor->CurrentIrValue()) { + // if dest is not backed by IR (e.g. result of some lazy operation), + // then it should have at::Tensor data backing it instead + auto dst_tensor_data = dst_tensor->CurrentTensorData(); + CHECK(dst_tensor_data); + auto src_tensor_data = self_tensor->CurrentTensorData(); + if (src_tensor_data) { + // both src/dst are simply backed by at::Tensor data, no IR- do a + // straightforward copy + dst_tensor_data->copy_(*src_tensor_data); + } else { + // src needs to be materialized before its result can be used for a copy + // into dst + // since we use the src tensor only for making a copy, we don't need to + // detach it + // note: it would be even more efficient if we could cause ToTensor to + // materialize the + // value directly into dst's buffer (that would need to be detached + // though). 
+ dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false)); + } + } else { + copy_(dst_tensor, self_tensor); + auto* impl = + dynamic_cast(dst.unsafeGetTensorImpl()); + impl->set_tensor(dst_tensor); + } + } + return dst; +} + +at::Tensor LazyNativeFunctions::_copy_from_and_resize( + const at::Tensor& self, const at::Tensor& dst) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); + auto self_tensor = torch::lazy::TryGetLtcTensor(self); + if (!self_tensor) { + CHECK(dst_tensor); + dst_tensor->UpdateFromTensorOut(self); + } else if (!dst_tensor) { + CHECK(self_tensor); + at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true); + at::Tensor typed_tensor = + torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false); + dst.resize_as_(typed_tensor).copy_(typed_tensor); + } else { + // at this point we know dst is a lazy tensor + auto* dest_impl = + dynamic_cast(dst.unsafeGetTensorImpl()); + dest_impl->tensor()->UpdateFromTensorOut(self_tensor); + dest_impl->force_refresh_sizes(); + } + return dst; +} + +at::Tensor LazyNativeFunctions::_to_copy( + const at::Tensor& self, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, bool non_blocking, + c10::optional memory_format) { + PRINT_FUNCTION(); + auto options = self.options(); + if (dtype) { + // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... + // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it + options = options.dtype(dtype); + } + if (layout) { + options = options.layout(layout); + } + if (memory_format) { + options = options.memory_format(memory_format); + } + if (pin_memory) { + // TODO(whc) can we honor 'pin_memory' in some/all cases? + options = options.pinned_memory(pin_memory); + TORCH_WARN_ONCE("Pinned memory used in lazy _to_copy, check if the " + "behavior is as intended"); + } + + TORCH_LAZY_FN_COUNTER("lazy::"); + auto lazy_self = torch::lazy::TryGetLtcTensor(self); + if (!lazy_self && device && device->type() == c10::kLazy) { + // Case 1: eager->lazy (we create a new lazy tensor) + // See Note [Lazy Tensor Functionalization] + // Invariant: if the functionalization key is in the exclude set, then we're expected + // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. + bool functionalize_output = + !c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::Functionalize); + return torch::lazy::to_lazy_tensor( + self, options, *device, /*non_blocking=*/non_blocking, + /*functionalize_output=*/functionalize_output); + } else if (device && device->type() != c10::kLazy) { + // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + options = options.device(device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + return moved_eager_tensor; + } else if ( + device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { + // Case 3: lazy:0 -> lazy:1 + + // TODO(whc) what do we actually want to do here? 
+ // option 1: materialize, move eager tensor, create new lazy tensor + // - this should be our default, as it is what would happen before we implemented _to_copy + // - actually combines case 1 + case 2 + // option 2: support multiple devices inside one lazy/TS executor (case 4) + // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + // we move the eager tensor to the 'eager' equivalent of our lazy device + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use + auto eager_device = c10::Device( + torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); + options = options.device(eager_device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true); + lazy_self = torch::lazy::GetOrCreateLtcTensor( + moved_eager_tensor, + torch::lazy::atenDeviceToBackendDevice(eager_device)); + return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + + } else { + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) + + // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. + // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to + // convert an eager tensor back to a lazy one inside the torchscript executor + // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument + device = c10::nullopt; + + auto shapes = torch::lazy::compute_shape__to_copy( + self, dtype, layout, device, pin_memory, non_blocking, memory_format); + TORCH_INTERNAL_ASSERT(shapes.size() == 1); + auto node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), dtype, layout, device, pin_memory, + non_blocking, memory_format, std::move(shapes)); + + auto result = + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + std::move(node), lazy_self->GetDevice())); + return result; + } +}; + +at::Tensor LazyNativeFunctions::empty( + at::SymIntArrayRef sym_size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { + // TODO: support this directly + auto size = c10::asIntArrayRefSlow(sym_size); + const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); + at::TensorOptions options = at::TensorOptions() + .device(c10::Device(device_type)) + .layout(layout) + .pinned_memory(pin_memory) + .dtype(dtype); + auto x_result = at::empty(size, options, memory_format); + auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device)); + // See Note [Lazy Tensor Functionalization] + if (c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::Functionalize)) { + // Invariant: if the functionalization key is in the exclude set, then we're expected + // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. 
+ return tensor; + } else { + auto wrapped = at::functionalization::impl::to_functional_tensor(tensor); + return wrapped; + } +} + +at::Tensor LazyNativeFunctions::empty_strided( + at::IntArrayRef size, at::IntArrayRef stride, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + TORCH_LAZY_FN_COUNTER("lazy::"); + at::Tensor t = empty( + c10::SymIntArrayRef::fromIntArrayRef(size), + dtype, layout, device, pin_memory, c10::nullopt); + return t.as_strided(size, stride, /*storage_offset=*/0); +} + +at::Tensor& +LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto self_tensor = torch::lazy::TryGetLtcTensor(self); + + torch::lazy::Value constant = + torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( + value, self_tensor->shape(), self_tensor->GetDevice()); + self_tensor->SetInPlaceIrValue(std::move(constant)); + return self; +} + +at::Tensor LazyNativeFunctions::_unsafe_view( + const at::Tensor& self, at::IntArrayRef size) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return LazyNativeFunctions::view_copy(self, c10::SymIntArrayRef::fromIntArrayRef(size)); +} + +// This is needed by the torch.tensor constructor. +// LazyTensor always opts into functionalization. +// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object. +at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); +} + +at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); +} + +// All of the below ops correspond to CompositeExplicitAutograd kernels from core +// that call into view operators internally. +// These are all composite ops that LTC can technically re-use / get for free, +// but we need to "functionalize" them to remove the view ops before we can use them. 
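// Illustrative note (assumption: the functionalize_aten_op helper in upstream
// PyTorch of this vintage): each call below is spelled roughly as
//   at::functionalization::functionalize_aten_op<ATEN_OP(block_diag)>::call(tensors);
// i.e. the composite kernel is re-dispatched with functionalization enabled,
// so its internal view ops become copying equivalents that LTC can trace.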
+at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { + return at::functionalization::functionalize_aten_op::call(tensors); +} +at::Tensor LazyNativeFunctions::new_empty_strided( + const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + return at::functionalization:: + functionalize_aten_op::call( + self, size, stride, dtype, layout, device, pin_memory); +} + +at::Tensor LazyNativeFunctions::pixel_shuffle( + const at::Tensor& self, int64_t upscale_factor) { + return at::functionalization::functionalize_aten_op::call(self, upscale_factor); +} +at::Tensor LazyNativeFunctions::pixel_unshuffle( + const at::Tensor& self, int64_t downscale_factor) { + return at::functionalization::functionalize_aten_op::call(self, downscale_factor); +} +at::Tensor LazyNativeFunctions::select_backward( + const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, + int64_t index) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, index); +} +at::Tensor LazyNativeFunctions::slice_backward( + const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, + int64_t start, int64_t end, int64_t step) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, start, end, step); +} +at::Tensor LazyNativeFunctions::diagonal_backward( + const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset, + int64_t dim1, int64_t dim2) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, offset, dim1, dim2); +} +at::Tensor LazyNativeFunctions::_trilinear( + const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, + at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, + at::IntArrayRef sumdim, int64_t unroll_dim) { + return at::functionalization::functionalize_aten_op:: + call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); +} +at::Tensor LazyNativeFunctions::linalg_pinv( + const at::Tensor& self, const c10::optional& atol, + const c10::optional& rtol, bool hermitian) { + return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); +} + +// functionalize_aten_op can't handle out= ops directly. +// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs. +at::Tensor& LazyNativeFunctions::logsumexp_out( + const at::Tensor& self, at::IntArrayRef dim, bool keepdim, + at::Tensor& out) { + auto self_wrapped = at::functionalization::impl::to_functional_tensor(self); + auto out_wrapped = at::functionalization::impl::to_functional_tensor(out); + // directly call the composite kernel from core. + // Make sure to re-enable functionalization first. 
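  // What follows is that re-enable dance: copy the thread-local dispatch key
  // set, drop Functionalize from the excluded keys, and install the modified
  // set with a ForceDispatchKeyGuard for the duration of the composite call.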
+  auto curr_tls = c10::impl::tls_local_dispatch_key_set();
+  auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
+  tls_reenable_functionalize.set_included(curr_tls.included_);
+  tls_reenable_functionalize.set_excluded(
+      curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
+  c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
+  at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
+  auto out_unwrapped =
+      at::functionalization::impl::from_functional_tensor(out_wrapped);
+  // propagate mutations back to the inputs (including resizing)
+  out.resize_(out_unwrapped.sizes());
+  out.copy_(out_unwrapped);
+  return out;
+}
+
+void InitializeAtenBindings() {}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
new file mode 100644
index 000000000000..51907c9b4c11
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
@@ -0,0 +1,104 @@
+//===- mlir_node.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
+//===----------------------------------------------------------------------===//
+
+#include "mlir_node.h"
+#include "utils/exception.h"
+
+namespace torch {
+namespace lazy {
+
+namespace {
+
+hash_t OperandHashes(
+    const OpList& operands, const c10::ArrayRef<Shape>& shapes,
+    const hash_t& seed, bool bakeInSizes) {
+  hash_t hash = seed;
+  for (auto& operand : operands) {
+    if (!operand) {
+      hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
+      continue;
+    }
+    auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
+    hash = HashCombine(hash, operand_hash);
+  }
+  for (auto& shape : shapes) {
+    hash = HashCombine(hash, shape.hash(bakeInSizes));
+  }
+  return hash;
+}
+
+} // namespace
+
+TorchMlirNode::TorchMlirNode(
+    OpKind op, OpList operands, std::vector<Shape>&& shapes,
+    size_t num_outputs, hash_t hash_seed)
+    : Node(op, operands, std::move(shapes), num_outputs) {
+  hash_seed = HashCombine(op.hash(), hash_seed);
+  shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
+  dag_hash_ =
+      (enableDynamicShape()
+           ? OperandHashes(operands, this->shapes(), hash_seed, false)
+           : shape_hash_);
+}
+
+TorchMlirNode::TorchMlirNode(
+    OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
+    size_t num_outputs, hash_t hash_seed)
+    : TorchMlirNode(
+          op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
+  addComputedShape(shape_fn);
+}
+
+TorchMlirNode::TorchMlirNode(
+    OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
+    : TorchMlirNode(
+          op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
+
+TorchMlirNode::TorchMlirNode(
+    OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
+    : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
+
+hash_t TorchMlirNode::hash() const { return dag_hash_; }
+
+hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
+
+OpKind TorchMlirTensorList::ClassOpKind() {
+  // Note: this OpKind is separate from ltc_ops.h since it would be a circular
+  // import otherwise
+  static const OpKind tensor_list_opkind =
+      OpKind::Get("lazy_tensors::tensor_list");
+  return tensor_list_opkind;
+}
+
+TorchMlirTensorList::TorchMlirTensorList(OpList values)
+    : TorchMlirNode(
+          /*op=*/TorchMlirTensorList::ClassOpKind(),
+          /*operands=*/values,
+          /*shapes=*/std::vector<Shape>(),
+          /*num_outputs=*/1,
+          /*hash_seed=*/kHashSeed) {}
+
+torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::Value*> tensor_list;
+  CHECK(!operands().empty());
+  for (const torch::lazy::Output& operand : operands()) {
+    tensor_list.emplace_back(loctx->GetOutputOp(operand));
+  }
+  auto graph = function->graph();
+  auto listnode =
+      graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
+  return {listnode->output()};
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
new file mode 100644
index 000000000000..c7e10d8cf6ed
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
@@ -0,0 +1,90 @@
+//===- mlir_node.h --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "mlir_lowering_context.h"
+#include "utils/debug.h"
+#include "utils/exception.h"
+
+namespace torch {
+namespace lazy {
+
+class TORCH_API TorchMlirNode : public torch::lazy::Node {
+public:
+  TorchMlirNode(
+      OpKind op, OpList operands, std::vector<Shape>&& shapes,
+      size_t num_outputs, hash_t hash_seed = kHashSeed);
+
+  TorchMlirNode(
+      OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
+      size_t num_outputs, hash_t hash_seed = kHashSeed);
+
+  TorchMlirNode(
+      OpKind op, OpList operands, size_t num_outputs,
+      hash_t hash_seed = kHashSeed);
+
+  TorchMlirNode(
+      OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
+
+  ~TorchMlirNode() override = default;
+
+  hash_t hash() const override;
+
+  hash_t shapeHash() const override;
+
+  virtual TorchMlirOpVector
+  Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
+
+private:
+  // The hash of the dag WITH size info. Used for shape caching.
+  hash_t shape_hash_;
+  // The hash used to look up the compiled graph: the dag hash WITHOUT size
+  // info if dynamic shape is enabled, and the dag hash WITH size info
+  // otherwise.
+  hash_t dag_hash_;
+};
+
+// TensorList represents an at::TensorList which is a vector[Tensor] but is
+// also a first-class IValue and can be fed as a single input to a TS program.
+// It is much easier to handle TensorLists in Lazy Tensor code if they are
+// represented as a single Node so there can be more than one TensorList and
+// more than one Tensor side-by-side as operands to an op.
+//
+// Note: shape is undefined for TensorList. We assert in some places that
+// #shapes matches #outputs and this stems from
+// the fact that currently all IR nodes represent tensors (there is no
+// type system for this IR). Because of this, TensorList is a bit of a
+// hack.
+//
+// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
+// then implement it as NotImplemented for TensorList, also fixing the
+// assertion that would fail.
+struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
+  static OpKind ClassOpKind();
+
+  TorchMlirTensorList() = delete;
+  TorchMlirTensorList(OpList values);
+
+  torch::lazy::TorchMlirOpVector Lower(
+      TorchMlirFunction function,
+      TorchMlirLoweringContext* loctx) const override;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
new file mode 100644
index 000000000000..e3d4fab862f4
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
@@ -0,0 +1,288 @@
+//===- mlir_node_lowering.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp
+//===----------------------------------------------------------------------===//
+
+#include "mlir_node_lowering.h"
+#include "generated/LazyNonNativeIr.h"
+#include "mlir_lowering_context.h"
+#include "mlir_node.h"
+#include "ops/device_data.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace lazy {
+
+TorchMlirOpVector LowerTorchMlirBuiltin(
+    TorchMlirFunction function, c10::Symbol sym,
+    const std::vector<c10::TypePtr> tensor_types,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments) {
+  auto builtin =
+      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
+  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
+  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
+  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
+  CHECK(sv);
+
+  TorchMlirOpVector results;
+  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
+    // Op returns multiple values.
+    const auto tuple_call_result = sv->asTuple({}, *function);
+    for (const auto& tuple_component : tuple_call_result) {
+      auto tuple_component_sv =
+          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
+      results.push_back(tuple_component_sv->getValue());
+    }
+  } else {
+    // Op returns single value.
+    results.push_back(sv->getValue());
+  }
+
+  // Insert known tensor type information.
+  unsigned tensor_type_idx = 0;
+  for (jit::Value* value : results) {
+    if (value->type()->kind() == c10::TypeKind::TensorType) {
+      TORCH_CHECK(
+          tensor_type_idx < tensor_types.size(), function->graph()->toString(),
+          "\nTensor corresponding to JIT SSA value %", value->debugName(),
+          " corresponds to result #", tensor_type_idx, ", but we only have ",
+          tensor_types.size(), " known types!");
+
+      value->setType(tensor_types[tensor_type_idx++]);
+    }
+  }
+
+  // Ensure that we use up all the known tensor type information available.
+  TORCH_CHECK(
+      tensor_type_idx == tensor_types.size(), tensor_type_idx,
+      " known types were injected into jit::Value, but ", tensor_types.size(),
+      " were provided from lazy::Node!");
+
+  return results;
+}
+
+TorchMlirOpVector LowerTorchMlirBuiltin(
+    TorchMlirFunction function, c10::Symbol sym,
+    const c10::ArrayRef<Shape> result_shapes,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments) {
+  std::vector<c10::TypePtr> tensor_types;
+
+  // Generate types with fixed tensor shape information.
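+  // (For example, a lazy Shape of (Float, [2, 3]) maps to a TensorType with
+  // sizes [2, 3] and the matching dtype; strides and device are deliberately
+  // left as c10::nullopt, since only sizes and dtype are known here.)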
+  for (const Shape& shape : result_shapes) {
+    tensor_types.push_back(torch::jit::TensorType::create(
+        /*scalar_type=*/shape.scalar_type(),
+        /*device=*/c10::nullopt,
+        /*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
+        /*strides=*/c10::VaryingShape<int64_t>(),
+        /*requires_grad=*/c10::nullopt));
+  }
+
+  return LowerTorchMlirBuiltin(
+      function, sym, tensor_types, arguments, kwarguments);
+}
+
+TorchMlirOpVector LowerBuiltin(
+    const torch::lazy::Node* node, TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(
+      function, node->op().op, node->shapes(), arguments, kwarguments);
+}
+TorchMlirOpVector LowerBuiltin(
+    c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
+    TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(
+      function, sym, result_shapes, arguments, kwarguments);
+}
+TorchMlirOpVector LowerBuiltin(
+    c10::Symbol sym, const std::vector<c10::TypePtr> types,
+    TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments);
+}
+
+c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
+  auto tensor_type = value_type->cast<c10::TensorType>();
+  TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");
+
+  return *tensor_type.get();
+}
+
+c10::optional<std::vector<int64_t>>
+get_tensor_type_shape(c10::TensorType& tensor_type) {
+  auto& symbolic_shape = tensor_type.symbolic_sizes();
+  if (!symbolic_shape.rank()) {
+    return c10::nullopt;
+  }
+
+  // Get current tensor shape.
+  std::vector<int64_t> dims;
+  dims.resize(*symbolic_shape.rank());
+  for (size_t i = 0; i < dims.size(); ++i) {
+    auto shape_symbol = symbolic_shape[i];
+    dims[i] = shape_symbol.is_static() ? shape_symbol.static_size() : -1;
+  }
+
+  return dims;
+}
+
+std::vector<Shape> compute_shape_copy(c10::TypePtr value_type) {
+  c10::TensorType& tensor_type = cast_tensor_type(value_type);
+
+  auto maybe_dims = get_tensor_type_shape(tensor_type);
+  TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!");
+
+  auto scalar_type = tensor_type.scalarType();
+  TORCH_CHECK(
+      scalar_type.has_value(), "Unable to copy due to lack of scalar type!");
+  return {Shape(scalar_type.value(), maybe_dims.value())};
+}
+
+std::vector<Shape> compute_shape_slice(
+    c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end,
+    int64_t step) {
+  c10::TensorType& tensor_type = cast_tensor_type(value_type);
+
+  auto maybe_dims = get_tensor_type_shape(tensor_type);
+  TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!");
+
+  std::vector<int64_t> dims = maybe_dims.value();
+  int64_t num_dims = dims[dim];
+
+  // Index may be negative, so we must normalize it.
+  auto normalize_index = [](int64_t index, unsigned num_dims) {
+    return index < 0 ? (int64_t)num_dims + index : index;
+  };
+  start = normalize_index(start, num_dims);
+  end = normalize_index(end, num_dims);
+
+  if (start >= end || start >= num_dims || end <= 0) {
+    // Slice is out of bounds, nothing in range.
+    dims[dim] = 0;
+  } else {
+    // Clamp upper and lower bound to valid indices.
+    start = std::max((int64_t)0, start);
+    end = std::min(num_dims, end);
+
+    // Final size is determined by step and interval size.
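+    // (Worked example: start=1, end=10, step=3 keeps indices 1, 4, and 7,
+    // and ceil((10 - 1) / 3) = 3 matches that count.)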
+    dims[dim] = std::ceil((double)(end - start) / (double)step);
+  }
+
+  auto scalar_type = tensor_type.scalarType();
+  TORCH_CHECK(
+      scalar_type.has_value(), "Unable to slice due to lack of scalar type!");
+  return {Shape(scalar_type.value(), dims)};
+}
+
+torch::jit::Value*
+GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> clone_arguments;
+  clone_arguments.emplace_back(val);
+
+  // Type of cloned value should be identical to the original one.
+  TorchMlirOpVector cloned =
+      LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments);
+  TORCH_CHECK_EQ(cloned.size(), 1);
+  return cloned.front();
+}
+
+void GenerateCopy(
+    torch::jit::Value* destination, torch::jit::Value* source,
+    TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(destination);
+  arguments.emplace_back(source);
+  LowerBuiltin(
+      at::aten::copy_,
+      c10::ArrayRef<Shape>(compute_shape_copy(source->type())), function,
+      arguments);
+}
+
+torch::jit::Value* GenerateSlice(
+    torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
+    int64_t step, TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(base);
+  arguments.emplace_back(dim);
+  arguments.emplace_back(start);
+  arguments.emplace_back(end);
+  arguments.emplace_back(step);
+
+  TorchMlirOpVector selected = LowerBuiltin(
+      at::aten::slice,
+      c10::ArrayRef<Shape>(
+          compute_shape_slice(base->type(), dim, start, end, step)),
+      function, arguments);
+  TORCH_CHECK_EQ(selected.size(), 1);
+  return selected.front();
+}
+
+// Node Lowerings
+
+// Default Node Lowering
+TorchMlirOpVector TorchMlirNode::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  for (const torch::lazy::Output& output : operands()) {
+    arguments.emplace_back(loctx->GetOutputOp(output));
+  }
+  return LowerBuiltin(this, function, arguments);
+}
+
+// TorchMlir specific nodes
+
+// Non-native nodes
+
+TorchMlirOpVector
+Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(dtype);
+  return LowerBuiltin(at::aten::to, shapes(), function, arguments);
+}
+
+TorchMlirOpVector DeviceData::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  auto infoptr = data_->info();
+  auto deviceDataInfoPtr =
+      (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
+  if (GRAPH_DUMP_ENABLED) {
+    LOG(ERROR) << "Lowering device data node, tensor id "
+               << deviceDataInfoPtr->tensor_id << std::endl;
+  }
+  return {loctx->GetParameter(data_)};
+}
+
+TorchMlirOpVector Scalar::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  auto options =
+      at::TensorOptions()
+          .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
+          .dtype(shape().scalar_type());
+  return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
new file mode 100644
index 000000000000..f9e028a5cc15
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
@@ -0,0 +1,31 @@
+//===- mlir_node_lowering.h -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+
+namespace torch {
+namespace lazy {
+
+typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
+typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
+
+TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
+    TorchMlirFunction function, c10::Symbol sym,
+    const c10::ArrayRef<Shape> result_shapes,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {});
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp
new file mode 100644
index 000000000000..653211e8d9c3
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp
@@ -0,0 +1,68 @@
+#include
+
+#include
+
+#include "device_data.h"
+#include "../backend_impl.h"
+
+namespace torch {
+namespace lazy {
+
+DeviceData::DeviceData(std::shared_ptr<BackendData> data)
+    : TorchMlirNode(
+          ClassOpKind(),
+          data->shape(),
+          /*num_outputs=*/1,
+          /*hash_seed=*/static_cast<uint32_t>(101)),
+      data_(std::move(data)) {
+  propagate_name();
+}
+
+void DeviceData::propagate_name() {
+  if (data_ && name_ != "") {
+    // Add device data name to backend data
+    TorchMlirBackendData* mlir_data =
+        dynamic_cast<TorchMlirBackendData*>(data_.get());
+    TORCH_CHECK(mlir_data);
+    TorchMlirBackendData::Info* info = mlir_data->mlir_info();
+    TORCH_CHECK(info);
+    info->name = name_;
+  }
+}
+
+void DeviceData::SetData(std::shared_ptr<BackendData> data) {
+  data_ = data;
+  propagate_name();
+}
+
+void DeviceData::SetName(const std::string& name) {
+  name_ = name;
+  propagate_name();
+}
+
+std::string DeviceData::ToString() const {
+  std::stringstream ss;
+  ss << TorchMlirNode::ToString() << ", device=" << data_->device();
+  if (name_ != "") {
+    ss << ", name=" << name_;
+  }
+  return ss.str();
+}
+
+const DeviceData* DeviceData::Cast(const Node* node) {
+  return NodeCast<DeviceData>(node);
+}
+
+NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
+  NodePtr node = ReuseOrMakeNode<DeviceData>(data);
+  // ReuseOrMakeNode may return a reused node which has the same shape;
+  // however, we need to replace the old data_ with the new one.
+  // Ditching the old data_ is safe because tracing is done iteration
+  // by iteration, and after we launch the async device execution for the
+  // previous iteration, data_ in DeviceData nodes are not needed anymore.
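+  // (Note: reuse is keyed purely on shape; CanBeReused() in device_data.h
+  // compares data_->shape() == data->shape(), and the node hash is derived
+  // from the shape, so swapping data_ below does not invalidate any caching.)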
+  DeviceData* device_data = static_cast<DeviceData*>(node.get());
+  device_data->SetData(data);
+  return node;
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h
new file mode 100644
index 000000000000..ad9d9d0eb94b
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include "../mlir_lowering_context.h"
+#include "../mlir_node.h"
+
+#include
+#include
+
+
+namespace torch {
+namespace lazy {
+
+class TORCH_API DeviceData : public TorchMlirNode {
+ public:
+  static OpKind ClassOpKind() {
+    return ltc_device_data;
+  }
+
+  explicit DeviceData(std::shared_ptr<BackendData> data);
+
+  // A DeviceData node can be reused if the shape matches,
+  // but we will substitute the actual data_ pointer under
+  // the hood.
+  bool CanBeReused(std::shared_ptr<BackendData> data) const {
+    return data_->shape() == data->shape();
+  }
+
+  std::string ToString() const override;
+
+  const std::shared_ptr<BackendData>& data() const { return data_; }
+
+  void SetData(std::shared_ptr<BackendData> data);
+
+  TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
+
+  static const DeviceData* Cast(const Node* node);
+
+  // To reuse IR nodes, use this method to create DeviceData nodes
+  // instead of calling the constructor directly.
+  static NodePtr Create(std::shared_ptr<BackendData> data);
+
+  const std::string& GetName() const { return name_; }
+  void SetName(const std::string& name);
+
+ private:
+  void propagate_name();
+
+  std::shared_ptr<BackendData> data_;
+  std::string name_;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/generic.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/generic.cpp
new file mode 100644
index 000000000000..1df8be231023
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/ops/generic.cpp
@@ -0,0 +1,28 @@
+//===- generic.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.cpp
+//===----------------------------------------------------------------------===//
+
+#include "generic.h"
+
+namespace torch {
+namespace lazy {
+
+Generic::Generic(
+    OpKind op,
+    OpList operands,
+    Shape shape,
+    size_t num_outputs,
+    hash_t hash_seed)
+    : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed),
+      hash_seed_(hash_seed) {}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/generic.h b/python/torch_mlir/csrc/base_lazy_backend/ops/generic.h
new file mode 100644
index 000000000000..f294b1cfaed2
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/ops/generic.h
@@ -0,0 +1,39 @@
+//===- generic.h ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "../mlir_node.h"
+
+namespace torch {
+namespace lazy {
+
+// Generic IR Node implementation for nodes which can simply be described by a
+// specific OpKind and a lowering function. IR nodes carrying
+// metadata should not be using this class (and should have the metadata
+// captured by the LowerFn), but they should instead create a dedicated IR
+// node. Doing the former would limit IR introspection.
+class TORCH_API Generic : public TorchMlirNode {
+ public:
+  Generic(
+      OpKind op,
+      OpList operands,
+      Shape shape,
+      size_t num_outputs = 1,
+      hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
+
+ private:
+  hash_t hash_seed_;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h
new file mode 100644
index 000000000000..c6b75baaf8f3
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h
@@ -0,0 +1,101 @@
+//===- to_copy.h ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ops/to_copy.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "../mlir_node.h"
+
+namespace torch {
+namespace lazy {
+
+
+// This IR was copied from code-generated output, but the entire _to_copy
+// operator cannot be trivially code generated, since it is only desirable to
+// capture IR for certain permutations of _to_copy (e.g.
+// dtype); for the others it is difficult to even invoke the aten/eager
+// fallback, necessitating directly implementing the right to(device) behavior.
+class ToCopy : public torch::lazy::TorchMlirNode {
+ public:
+  ToCopy(
+      const torch::lazy::Value& self,
+      const c10::optional<at::ScalarType>& dtype,
+      const c10::optional<at::Layout>& layout,
+      const c10::optional<at::Device>& device,
+      const c10::optional<bool>& pin_memory,
+      const bool& non_blocking,
+      const c10::optional<at::MemoryFormat>& memory_format,
+      std::vector<torch::lazy::Shape>&& shapes)
+      : torch::lazy::TorchMlirNode(
+            torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes),
+            /* num_outputs */ 1,
+            torch::lazy::MHash(
+                dtype, layout, device, pin_memory, non_blocking,
+                memory_format)),
+
+        dtype(dtype),
+        layout(layout),
+        device(device),
+        pin_memory(pin_memory),
+        non_blocking(non_blocking),
+        memory_format(memory_format) {}
+
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << torch::lazy::TorchMlirNode::ToString();
+    if (dtype.has_value()) {
+      ss << ", dtype=" << dtype.value();
+    } else {
+      ss << ", dtype=null";
+    }
+    if (layout.has_value()) {
+      ss << ", layout=" << layout.value();
+    } else {
+      ss << ", layout=null";
+    }
+    if (device.has_value()) {
+      ss << ", device=" << device.value();
+    } else {
+      ss << ", device=null";
+    }
+    if (pin_memory.has_value()) {
+      ss << ", pin_memory=" << pin_memory.value();
+    } else {
+      ss << ", pin_memory=null";
+    }
+    ss << ", non_blocking=" << non_blocking;
+    if (memory_format.has_value()) {
+      ss << ", memory_format=" << memory_format.value();
+    } else {
+      ss << ", memory_format=null";
+    }
+    return ss.str();
+  }
+
+  torch::lazy::TorchMlirOpVector Lower(
+      TorchMlirFunction function,
+      torch::lazy::TorchMlirLoweringContext* loctx) const override {
+    std::vector<torch::jit::NamedValue> arguments;
+    std::vector<torch::jit::NamedValue> kwarguments;
+    arguments.reserve(1);
+    kwarguments.reserve(6);
+    size_t i = 0;
+    arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+    kwarguments.emplace_back("dtype", dtype);
+    kwarguments.emplace_back("layout", layout);
+    kwarguments.emplace_back("device", device);
+    kwarguments.emplace_back("pin_memory", pin_memory);
+    kwarguments.emplace_back("non_blocking", non_blocking);
+    kwarguments.emplace_back("memory_format", memory_format);
+    torch::lazy::TorchMlirOpVector _to_copy_out =
+        torch::lazy::LowerTorchMlirBuiltin(
+            function, op().op, shapes(), arguments, kwarguments);
+    TORCH_CHECK_EQ(_to_copy_out.size(), 1);
+
+    return _to_copy_out;
+  }
+
+  c10::optional<at::ScalarType> dtype;
+  c10::optional<at::Layout> layout;
+  c10::optional<at::Device> device;
+  c10::optional<bool> pin_memory;
+  bool non_blocking;
+  c10::optional<at::MemoryFormat> memory_format;
+};
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
new file mode 100644
index 000000000000..48004d9d34eb
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
@@ -0,0 +1,40 @@
+//===- shape_inference.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include
+#include
+#include
+
+#include "generated/shape_inference.h"
+#include "utils/exception.h"
+
+namespace torch {
+namespace lazy {
+
+// TODO(henrytu): Upstream these shape inference functions to PyTorch in the
+// future.
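+// (Pattern for new entries here, sketched with a made-up op name: each
+// function returns one lazy Shape per op output, computed from input
+// metadata only, e.g.
+//
+//   std::vector<torch::lazy::Shape>
+//   compute_shape_my_unary_op(const at::Tensor& self) {
+//     return {Shape(self.scalar_type(), self.sizes().vec())};
+//   }
+//
+// compute_shape_my_unary_op is hypothetical; the real entries below follow
+// the same contract.)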
+
+std::vector<torch::lazy::Shape>
+compute_shape_div(const at::Tensor& self, const at::Scalar& other) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
+std::vector<torch::lazy::Shape>
+compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
+  return {Shape(self.scalar_type(), self.sizes().vec())};
+}
+
+std::vector<torch::lazy::Shape> compute_shape_var(
+    const at::Tensor& self, at::OptionalIntArrayRef dim,
+    c10::optional<int64_t> correction, bool keepdim) {
+  // Result of variance is scalar tensor.
+  return {Shape(self.scalar_type(), {})};
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/debug.h b/python/torch_mlir/csrc/base_lazy_backend/utils/debug.h
new file mode 100644
index 000000000000..98a86b3d7c88
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/debug.h
@@ -0,0 +1,27 @@
+//===- debug.h ------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+
+#include "sys_utils.h"
+
+#define PRINT_DEBUG(msg)                                                       \
+  std::cout << msg << " (" << __FILE__ << ":" << __LINE__ << ")"               \
+            << std::endl;
+
+#define PRINT_FUNCTION()                                                       \
+  if (verbose_print_function) {                                                \
+    std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__   \
+              << ")" << std::endl;                                             \
+  }
+
+static const bool verbose_print_function =
+    sys_util::GetEnvBool("VERBOSE_PRINT_FUNCTION", false);
diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/exception.h b/python/torch_mlir/csrc/base_lazy_backend/utils/exception.h
new file mode 100644
index 000000000000..96510d830aef
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/utils/exception.h
@@ -0,0 +1,32 @@
+//===- exception.h --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#define UNIMPLEMENTED_ERROR(msg) \ + { \ + std::ostringstream err; \ + err << "Unimplemented Error: " << msg; \ + throw std::runtime_error(err.str()); \ + } + +#define UNIMPLEMENTED_FUNCTION_ERROR() \ + UNIMPLEMENTED_ERROR( \ + "\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__) + +#define UNSUPPORTED_ERROR(msg) \ + { \ + std::ostringstream err; \ + err << "Unsupported Error: " << msg; \ + throw std::runtime_error(err.str()); \ + } diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h new file mode 100644 index 000000000000..c4c2ea79d6ab --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/string_utils.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + + +template +std::ostream& string_join(std::ostream& out, const std::vector& v, const std::string& delimiter) { + size_t i = 0; + for (const T& e : v) { + if ((i++) > 0) { out << delimiter; } + out << e; + } + return out; +} + +template +std::string string_join(const std::vector& v, const std::string& delimiter) { + std::ostringstream joined; + string_join(joined, v, delimiter); + return joined.str(); +} + + +/* + * Returns true if str starts with prefix + */ +inline bool startswith(const std::string& str, const std::string& prefix) { + return str.rfind(prefix, 0) == 0; +} diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h new file mode 100644 index 000000000000..6cb47895af92 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +namespace sys_util { + +template +static T GetEnv(const std::string& name, const T& default_value = T(0)) { + const char* env = std::getenv(name.c_str()); + if (!env) { + return default_value; + } + return T(std::atoi(env)); +} + +static bool GetEnvBool(const char* name, bool defval) { + const char* env = std::getenv(name); + if (env == nullptr) { + return defval; + } + if (std::strcmp(env, "true") == 0) { + return true; + } + if (std::strcmp(env, "false") == 0) { + return false; + } + return std::atoi(env) != 0; +} + +} // namespace sys_util diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h new file mode 100644 index 000000000000..75a900abd776 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h @@ -0,0 +1,30 @@ +#pragma once + +#include "torch/csrc/lazy/backend/backend_device.h" +#include "torch/csrc/lazy/core/tensor.h" + +#include "../ops/device_data.h" + + +namespace torch { +namespace lazy { + +inline torch::lazy::DeviceData* device_data_cast( + const at::Tensor& tensor, c10::optional device = c10::nullopt +) { + if (!device) { + device = torch::lazy::GetBackendDevice(tensor); + } + TORCH_CHECK(device); + torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); + if (lazy_tensor) { + torch::lazy::Value param_value = lazy_tensor->GetIrValue(); + if (param_value && param_value->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(param_value.node.get()); + } + } + return nullptr; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt 
b/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt new file mode 100644 index 000000000000..dcfce9e89570 --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt @@ -0,0 +1,79 @@ +########################################################################### +# Setup PyTorch +########################################################################### + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") +include(TorchMLIRPyTorch) + +TorchMLIRProbeForPyTorchInstall() +if(TORCH_MLIR_USE_INSTALLED_PYTORCH) + TorchMLIRConfigurePyTorch() +else() + set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch") +endif() + +find_package(Torch 1.11 REQUIRED) + +########################################################################### +# Setup Python development +########################################################################### + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/externals/llvm-project/mlir/cmake/modules") +include(MLIRDetectPythonEnv) +mlir_configure_python_dev_packages() + +########################################################################### +# Library definition +########################################################################### + +# We piggyback on the shared library setup/infra used by Torch-MLIR Python bindings for consistency. +# https://github.com/llvm/torch-mlir/pull/1283 +set(LIBRARY_OUTPUT_PATH "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs") +set(OUTPUT_NAME "_REFERENCE_LAZY_BACKEND") + +if(TORCH_MLIR_ENABLE_LTC) + include_directories(BEFORE + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ${Python3_INCLUDE_DIRS} + ${PYTHON_H_DIR} + ${PROJECT_SOURCE_DIR}/python + ) + link_directories("${TORCH_INSTALL_PREFIX}/lib") + link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib) + add_link_options(-Wl,-rpath,$ORIGIN/lib) + + add_library(reference_lazy_backend MODULE + backend_impl.cpp + reference_lazy_backend_pybind.cpp + ) + add_dependencies(reference_lazy_backend + torch_mlir_ltc_backend + ) + target_link_libraries(reference_lazy_backend + ${TORCH_LIBRARIES} + torch_python + torch_mlir_ltc_backend + ) + + message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") + set_target_properties(reference_lazy_backend PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_PATH} + OUTPUT_NAME ${OUTPUT_NAME} + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" + CXX_VISIBILITY_PRESET "hidden" + COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" + ) + mlir_python_setup_extension_rpath(reference_lazy_backend) + + torch_mlir_python_target_compile_options(reference_lazy_backend) + mlir_check_all_link_libraries(reference_lazy_backend) +else() + # To avoid import errors when LTC is disabled (and a bunch of checks + # associated with that), we will generate a dummy placeholder library. 
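+  # (For context: the stub emitted by gen_dummy_lib.py below defines a
+  # LazyTensorCoreTestConfig whose constructor asserts immediately, so the
+  # package still imports cleanly while any actual LTC use fails loudly.)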
+ execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/gen_dummy_lib.py ${LIBRARY_OUTPUT_PATH} ${OUTPUT_NAME} + ) +endif() diff --git a/python/torch_mlir_e2e_test/torchscript/__init__.py b/python/torch_mlir/csrc/reference_lazy_backend/__init__.py similarity index 100% rename from python/torch_mlir_e2e_test/torchscript/__init__.py rename to python/torch_mlir/csrc/reference_lazy_backend/__init__.py diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp new file mode 100644 index 000000000000..b6e37e74b43d --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -0,0 +1,191 @@ +//===- backend_impl.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "backend_impl.h" + +using namespace torch::lazy; + +namespace torch { +namespace lazy { + +struct ReferenceLazyBackendDeviceType : public BackendDeviceType { + ReferenceLazyBackendDeviceType(c10::DeviceType device_type) + : device_type_(device_type) {} + ReferenceLazyBackendDeviceType(int8_t device_type) + : device_type_(static_cast(device_type)) {} + + std::string toString() const override { + return c10::DeviceTypeName(device_type_); + } + + c10::DeviceType device_type_; +}; + +class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { +public: + ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {} + + /** + * Configuration + * */ + void SetRngSeed(size_t seed) const override { + std::cout << "RNG Seed Set to: " << seed << std::endl; + } + + /** + * Lowering, Compilation, Execution + * */ + std::vector GetCompilationDevices( + const std::string& device, + c10::ArrayRef devices) const override { + return std::vector(devices.begin(), devices.end()); + }; + + std::vector + Compile(std::vector instances) const override { + PRINT_FUNCTION(); + + // Vendor backend specific lowering can be exec here before returning. + for (const auto& instance : instances) { + // Store computation instance for external access after compilation. + GetLatestComputation() = instance; + } + + std::cout << "Received " << instances.size() + << " computation instances at Compile!" << std::endl; + + return instances; + } + + std::vector ExecuteComputation( + torch::lazy::ComputationPtr computation, + c10::ArrayRef arguments, + const BackendDevice& device) const override { + PRINT_FUNCTION(); + + // `arguments` maps 1:1 with the parameters in the generated MLIR. In this + // function, we will generate a list of BackendData that corresponds to the + // return values in the MLIR. + + auto mlir_computation = + static_cast(computation.get()); + + int num_inputs = 0; + + // Vendor backend specific execution can be inserted here. + // + // We don't have a way to execute a computation based on the generated MLIR, + // so we'll fallback to the implementation used by the TS LTC backend. 
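+    // (A vendor backend would replace the JIT fallback below with its own
+    // pipeline: lower the TorchMlirComputation's MLIR module through the
+    // vendor compiler and execute the result on the target device.)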
+ // + // JIT Execution adopted from: + // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp + torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), ""); + std::vector stack; + for (const auto& argument : arguments) { + const auto mlir_data = + std::static_pointer_cast(argument); + if (mlir_data->mlir_info()->scalar.has_value()) { + stack.emplace_back(mlir_data->mlir_info()->scalar.value()); + } else { + at::Tensor tensor = mlir_data->mlir_info()->tensor; + stack.emplace_back(tensor); + } + + // count number of inputs + auto name = mlir_data->mlir_info()->name; + if (startswith(name, "input_")) { + // Printing tensor name for testing purposes + std::cout << "Input tensor: " << name << std::endl; + ++num_inputs; + } + } + // Printing number of input tensors for testing purposes + std::cout << num_inputs << " input tensors found" << std::endl; + graph_executor.run(stack); + std::vector results; + for (torch::jit::IValue component : stack) { + at::Tensor result = component.toTensor(); + at::IntArrayRef result_sizes = result.sizes(); + torch::lazy::Shape shape( + result.scalar_type(), + std::vector(result_sizes.begin(), result_sizes.end())); + results.push_back( + std::make_shared(result, device, shape)); + } + + std::cout << "Received " << arguments.size() << " arguments, and returned " + << results.size() << " results during ExecuteCompile!" + << std::endl; + + return results; + } + + /** + * Device Configuration + * */ + std::shared_ptr + GetDefaultDeviceType() const override { + return std::make_shared(default_device_type_); + } + + void SetDefaultDeviceType(int8_t device_type) override { + default_device_type_ = ReferenceLazyBackendDeviceType(device_type); + } + + /** + * Debug/Metrics + * */ + std::string + GetComputationBackendText(const ComputationPtr computation) const override { + // Store computation instance for external access after compilation. + // We do this in GetComputationBackendText since there may be instances + // where a ComputationPtr does not pass through Compile (e.g. when using + // DumpUtil::ToBackend.) + GetLatestComputation() = computation; + + return computation->to_string(); + } + +private: + ReferenceLazyBackendDeviceType default_device_type_; +}; + +BackendImplInterface* GetReferenceLazyBackendImpl() { + static ReferenceLazyBackendImpl* reference_lazy_backend_impl = + new ReferenceLazyBackendImpl(); + return reference_lazy_backend_impl; +} + +void InitReferenceLazyBackend() { + at::RegisterTorchMlirLazyNativeFunctions(); + static std::unique_ptr g_registrar; + g_registrar.reset(new BackendRegistrar(GetReferenceLazyBackendImpl())); +} + +ComputationPtr& GetLatestComputation() { + // Store the computation from the most recent compile. + static ComputationPtr computation; + return computation; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.h b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.h new file mode 100644 index 000000000000..6366fe9fd90f --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.h @@ -0,0 +1,29 @@ +//===- backend_impl.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace at { +// This function is defined in the codegenerated RegisterLazy.cpp file. +TORCH_API void RegisterTorchMlirLazyNativeFunctions(); +} // namespace at + +namespace torch { +namespace lazy { + +torch::lazy::BackendImplInterface* GetReferenceLazyBackendImpl(); + +void InitReferenceLazyBackend(); + +ComputationPtr& GetLatestComputation(); + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py b/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py new file mode 100755 index 000000000000..34c9e61907b6 --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py @@ -0,0 +1,23 @@ +# When LTC is disabled in Torch-MLIR build, we will generate a dummy module to +# ensure that no import errors occur. + +import sys +import os + +if __name__ == '__main__': + path = sys.argv[1] # dummy script path + file_name = sys.argv[2] # dummy script + + contents = ''' +# This file was automatically generated due to LTC being disabled in build. + +class LazyTensorCoreTestConfig: + def __init__(self): + assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`" + ''' + + if not os.path.exists(path): + os.makedirs(path) + + with open(os.path.join(path, file_name + '.py'), 'w') as file: + file.write(contents) diff --git a/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp new file mode 100644 index 000000000000..b2ff81c67a22 --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -0,0 +1,95 @@ +//===- reference_lazy_backend_pybind.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch/csrc/jit/python/pybind.h" +#include "torch/csrc/lazy/backend/backend_interface.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "backend_impl.h" + +namespace py = pybind11; + +namespace { +bool verbose = sys_util::GetEnv("VERBOSE", false); + +struct NoGilSection { + NoGilSection() : state(PyEval_SaveThread()) {} + ~NoGilSection() { PyEval_RestoreThread(state); } + PyThreadState* state = nullptr; +}; + +/** + * @brief Install the plugin + */ +void Initialize() { + // Initialize the Reference Lazy Backend + torch::lazy::InitReferenceLazyBackend(); + + // sanity check + const torch::lazy::BackendImplInterface* mlir_backend = + torch::lazy::GetReferenceLazyBackendImpl(); + const torch::lazy::BackendImplInterface* lazy_backend = + torch::lazy::getBackend(); + if (lazy_backend != mlir_backend) { + std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl; + throw std::runtime_error("Failed to initialize MLIR Lazy Backend"); + } + + if (verbose) { + std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl; + } +} + +/** + * @brief Uninstall the plugin + */ +void Shutdown() { + if (verbose) { + std::cout << "MLIR LTC PyTorch Plugin Shut down." 
<< std::endl; + } +} +} // anonymous namespace + +PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) { + py::class_(m, "TorchMlirComputation") + .def("to_string", &torch::lazy::TorchMlirComputation::to_string) + .def("debug_string", &torch::lazy::TorchMlirComputation::debug_string); + + m.doc() = ("pybind11 for the Reference Lazy backend."); + m.def("get_latest_computation", []() { + auto computation = static_cast( + torch::lazy::GetLatestComputation().get()); + return py::cast(computation); + }); + m.def("set_parameter_name", + [](const at::Tensor& tensor, const std::string& name) -> bool { + torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor); + if (ir_node) { + ir_node->SetName(name); + return true; + } + return false; + }); + m.def("_initialize", []() { + NoGilSection gil; + Initialize(); + }); + m.def("_shutdown", []() { + NoGilSection gil; + Shutdown(); + }); +} diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt index b0076762687a..72fde4f56b8b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/CMakeLists.txt @@ -2,7 +2,7 @@ # Setup PyTorch #------------------------------------------------------------------------------- -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") include(TorchMLIRPyTorch) TorchMLIRProbeForPyTorchInstall() diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index aa0031324143..3f547720a40a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -18,7 +18,6 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder from torch_mlir.passmanager import PassManager -import torch_mlir.all_passes_registration from .registry import Registry @@ -82,6 +81,34 @@ def __repr__(self): else: return f"TensorOfShape({args_str}, dtype={self.dtype})" +def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): + assert len(weight) == 2 + assert len(indices) == 1 + assert len(offsets) == 1 + output_bag_shape: List[int] = [] + out_dim0 = offsets[0] + if (include_last_offset): + out_dim0 = out_dim0 - 1 + out_dim1 = weight[1] + output_bag_shape.append(out_dim0) + output_bag_shape.append(out_dim1) + + offset2bag_shape: List[int] = [] + if mode == 1: + offset2bag_shape.append(0) + else: + offset2bag_shape = upstream_shape_functions._copy(indices) + + bag_size_shape = upstream_shape_functions._copy(offsets) + + max_indices_shape: List[int] = [] + if mode == 2: + max_indices_shape = upstream_shape_functions._copy(output_bag_shape) + else: + max_indices_shape = upstream_shape_functions._copy(offsets) + + return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape + def LongTensorOfShape(*args, **kwargs): """Helper for indicating a TensorOfShape with integer type.""" return TensorOfShape(*args, **kwargs, dtype=torch.long) @@ -301,6 +328,9 @@ def aten〇sigmoid(self: List[int]) -> List[int]: def aten〇hardsigmoid(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇softplus(self: List[int], beta: float = 1, threshold: float = 20) -> 
List[int]: + return upstream_shape_functions.unary(self) + def aten〇square(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -313,6 +343,9 @@ def aten〇silu(self: List[int]) -> List[int]: def aten〇exp(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇expm1(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇sin(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -337,6 +370,9 @@ def aten〇detach(self: List[int]) -> List[int]: def aten〇log2(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇log1p(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇rsqrt(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -391,6 +427,9 @@ def aten〇to〇dtype(self: List[int], dtype: int, non_blocking: bool = False, c def aten〇to〇dtype_layout(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return self +def aten〇to〇device(self: List[int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇to〇other(self: List[int], other: List[int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -442,6 +481,9 @@ def aten〇mul〇Scalar(self: List[int], other: float) -> List[int]: def aten〇div〇Scalar(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇remainder〇Scalar(self: List[int], other: float) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -481,12 +523,18 @@ def aten〇mean(self: List[int], dtype: Optional[int] = None) -> List[int]: def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: return [] -def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: - return upstream_shape_functions.mean_dim(self, dim, keepdim, None) +def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) + +def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: return [] +def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) + def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self)) out: List[int] = [] @@ -519,14 +567,11 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[ reduced_shape = _reduce_along_dim(self, dim, keepdim) return reduced_shape, reduced_shape -def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - return 
upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) +def aten〇mean〇dim(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None: - return upstream_shape_functions.mean_dim(self, [], keepdim, dtype) - else: - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇permute(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -590,6 +635,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]: out.append(self[i] * repeats[i + leading_rank]) return out +def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]: return upstream_shape_functions.expand(self, size) @@ -722,6 +770,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇masked_fill〇Tensor(self: List[int], mask: List[int], value: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇zero(self: List[int]) -> List[int]: return self @@ -755,7 +806,7 @@ def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]: def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self -def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: +def aten〇arange〇start_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory) def aten〇arange〇start(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: @@ -788,6 +839,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇atan2(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -884,12 +938,18 @@ def aten〇topk(self: List[int], k: int, dim: int = -1, largest: bool = True, so def aten〇conv2d(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv_transpose2d〇input(input: 
List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> List[int]: + return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) + def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: - return upstream_shape_functions.conv_output_size(input, weight, bias, stride, padding, dilation, groups) + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) - + +def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: + return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) + def aten〇flip(self: List[int], dims: List[int]) -> List[int]: return self @@ -905,6 +965,9 @@ def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) +def aten〇narrow(self: List[int], dim: int, start: int, length: int) -> List[int]: + return upstream_shape_functions.slice(self, dim, start, start + length, 1) + def aten〇slice_scatter(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return self @@ -926,6 +989,12 @@ def aten〇index_put〇hacked_twin(self: List[int], indices: List[List[int]], va def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]: return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse) +def aten〇embedding_bag〇padding_idx(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]: + return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode) + +def aten〇_embedding_bag(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights: Optional[List[int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[List[int], List[int], List[int], List[int]]: + return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode) + @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic 
case. Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim. @@ -1008,6 +1077,7 @@ def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: O Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions. Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions. + Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors. Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions. ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions. @@ -1029,15 +1099,13 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) - if len(unused_dim_sizes) == 0: return broadcasted_shape - prev_index_tensor_location = -1 first_index_tensor_location = -1 index_tensors_are_together = True for e, index_tensor_shape in enumerate(indices): if index_tensor_shape is not None: if first_index_tensor_location == -1: first_index_tensor_location = e - prev_index_tensor_location = e - elif e - prev_index_tensor_location != 1: + elif e - first_index_tensor_location != 1: index_tensors_are_together = False if not index_tensors_are_together: @@ -1100,13 +1168,7 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen return [hacky_get_unknown_dimension_size()] def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None: - dim = list(range(len(self))) - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) - -# TODO: Re-enable after MacOS support is fixed for the extension. -#def _torch_mlir_custom_op_example〇identity(t: List[int]) -> List[int]: -# return upstream_shape_functions.unary(t) + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) # ============================================================================== # Shape library generator main(). 
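Note on the reduction-op changes in the shape library above: the switch from `upstream_shape_functions.mean_dim` to `upstream_shape_functions.sum_mean_dim` is what lets `dim` become `Optional[List[int]]` in `aten〇var〇dim`, `aten〇std〇dim`, `aten〇mean〇dim`, and `aten〇linalg_vector_norm`, and lets the old `if dim is None` special case in `aten〇sum〇dim_IntList` be deleted. Below is a minimal, self-contained sketch of the semantics being relied on (illustrative only, not the upstream helper itself; the real helper routes negative dims through `maybe_wrap_dim`):

```python
from typing import List, Optional

def sum_mean_dim_sketch(self: List[int], dim: Optional[List[int]],
                        keepdim: bool, dtype: Optional[int]) -> List[int]:
    # dim=None (or an empty list) means "reduce over every dimension".
    if dim is None or len(dim) == 0:
        dim = list(range(len(self)))
    # Wrap negative dims, then drop each reduced dim (or keep it as size 1).
    reduced = {d % len(self) for d in dim}
    out: List[int] = []
    for i, size in enumerate(self):
        if i in reduced:
            if keepdim:
                out.append(1)
        else:
            out.append(size)
    # dtype affects the element type, never the shape, so it is ignored here.
    return out

assert sum_mean_dim_sketch([2, 3, 4], None, False, None) == []       # full reduction
assert sum_mean_dim_sketch([2, 3, 4], [1], True, None) == [2, 1, 4]  # keepdim
assert sum_mean_dim_sketch([2, 3, 4], [-1], False, None) == [2, 3]   # negative dim
```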
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 86cdf5fd9c0e..e2ff4146bc16 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -250,24 +250,20 @@ def emit_with_mutating_variants(key, **kwargs): "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", + "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::floor : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", - "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", - "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", @@ -277,10 +273,13 @@ def emit_with_mutating_variants(key, **kwargs): "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", + "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)", + "aten::sqrt : (Tensor) -> (Tensor)", + "aten::log1p : (Tensor) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", @@ -293,7 +292,14 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants(key) # Elementwise tensor compute ops that don't have the standard mutating # variants. - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) 
-> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") @@ -303,6 +309,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") + emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") # Ops without value semantics but the corresponding without trailing # underscore variant doesn't exist. @@ -327,9 +334,14 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") + emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") + emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"), emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit( "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" @@ -372,14 +384,15 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") - emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)") + emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") - emit("aten::sqrt : (Tensor) -> (Tensor)") emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::mean : (Tensor, int?) 
-> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") + emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") emit("aten::var : (Tensor, bool) -> (Tensor)") - emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)") + emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") + emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") @@ -419,6 +432,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)") emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") + emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") + emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)") emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") @@ -438,7 +453,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") - emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::sum : (Tensor, int?) -> (Tensor)") @@ -449,6 +463,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True) emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)") emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") + emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)") emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") @@ -457,7 +472,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)") - emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") @@ -474,6 +488,30 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") + # Functionalization ops + emit("aten::alias_copy : (Tensor) -> (Tensor)") + emit("aten::as_strided_copy : (Tensor, int[], int[], int?) 
-> (Tensor)") + emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)") + emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)") + emit("aten::_reshape_alias_copy : (Tensor, int[], int[]) -> (Tensor)") + emit("aten::select_copy.int : (Tensor, int, int) -> (Tensor)") + emit("aten::detach_copy : (Tensor) -> (Tensor)") + emit("aten::slice_copy.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)") + emit("aten::squeeze_copy : (Tensor) -> (Tensor)") + emit("aten::squeeze_copy.dim : (Tensor, int) -> (Tensor)") + emit("aten::t_copy : (Tensor) -> (Tensor)") + emit("aten::transpose_copy.int : (Tensor, int, int) -> (Tensor)") + emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)") + emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") + emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") + emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") + emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") + emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") + emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") + + # Dict ops. emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) emit("aten::__contains__.int_list : (int[], int) -> (bool)", has_folder=True) @@ -486,7 +524,7 @@ def emit_with_mutating_variants(key, **kwargs): # List ops. emit("aten::cat : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") - emit("aten::add.t : (t[], t[]) -> (t[])") + emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) emit("aten::list.t : (t[]) -> (t[])") emit("aten::slice.t : (t[], int?, int?, int) -> (t[])") @@ -519,6 +557,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::eq.int : (int, int) -> (bool)", has_folder=True) emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) + emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True) @@ -555,6 +594,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::eq.device : (Device, Device) -> (bool)") emit("aten::ceil.float : (float) -> (int)", has_folder=True) + emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)") emit("aten::ScalarImplicit : (Tensor) -> (Scalar)") # backprop ops @@ -598,16 +638,6 @@ def emit_with_mutating_variants(key, **kwargs): "quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)", traits=["HasValueSemantics"]) - # ========================================================================== - # `_torch_mlir_custom_op_example::` namespace. - # - # This is a demonstration of supporting an operation defined in a PyTorch - # extension. - # ========================================================================== - - # TODO: Re-enable after MacOS support is fixed for the extension. 
- #emit("_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)") - def dump_registered_ops(outfile: TextIO, registry: Registry): for _, v in sorted(registry.by_unique_key.items()): diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt index 5545a32bdd7b..287e9a20c87b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt @@ -10,36 +10,60 @@ include_directories(BEFORE ) link_directories("${TORCH_INSTALL_PREFIX}/lib") -add_library(TorchMLIRJITIRImporter MODULE +# Static library with core functionality. +# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) +# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 +add_library(TorchMLIRJITIRImporter STATIC class_annotator.cpp - class_annotator_pybind.cpp - get_registered_ops.cpp function_importer.cpp - module_builder.cpp node_importer.cpp ivalue_importer.cpp - init_python_bindings.cpp torch_to_mlir_utils.cpp ) - target_link_libraries(TorchMLIRJITIRImporter TorchMLIRAggregateCAPI + ${TORCH_LIBRARIES} + ) +message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}") +set_target_properties(TorchMLIRJITIRImporter PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + OUTPUT_NAME lib_jit_ir_importer + PREFIX "" + SUFFIX ".a" + CXX_VISIBILITY_PRESET "default" + COMPILE_FLAGS "${TORCH_CXXFLAGS}" + ) + +# Separate Pybind MODULE due to issues with a SHARED library. +# https://github.com/llvm/torch-mlir/issues/1154 +add_library(TorchMLIRJITIRImporterPybind MODULE + class_annotator_pybind.cpp + get_registered_ops.cpp + import_options_pybind.cpp + init_python_bindings.cpp + module_builder.cpp + ) +add_dependencies(TorchMLIRJITIRImporterPybind + TorchMLIRJITIRImporter + ) +target_link_libraries(TorchMLIRJITIRImporterPybind ${TORCH_LIBRARIES} torch_python -) + TorchMLIRJITIRImporter + ) # On static Python builds, there may not be Python libraries to link against # (they will late bind at runtime from the executable). We have to condition # this because in that case it is set to NOTFOUND and CMake will consider # this an error. 
if(Python3_LIBRARIES) - target_link_libraries(TorchMLIRJITIRImporter + target_link_libraries(TorchMLIRJITIRImporterPybind ${Python3_LIBRARIES} ) endif() message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}") -set_target_properties(TorchMLIRJITIRImporter PROPERTIES +set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" OUTPUT_NAME _jit_ir_importer PREFIX "${PYTHON_MODULE_PREFIX}" @@ -47,7 +71,7 @@ set_target_properties(TorchMLIRJITIRImporter PROPERTIES CXX_VISIBILITY_PRESET "hidden" COMPILE_FLAGS "${TORCH_CXXFLAGS}" ) -mlir_python_setup_extension_rpath(TorchMLIRJITIRImporter) +mlir_python_setup_extension_rpath(TorchMLIRJITIRImporterPybind) -torch_mlir_python_target_compile_options(TorchMLIRJITIRImporter) -mlir_check_all_link_libraries(TorchMLIRJITIRImporter) +torch_mlir_python_target_compile_options(TorchMLIRJITIRImporterPybind) +mlir_check_all_link_libraries(TorchMLIRJITIRImporterPybind) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp index c02f014c774b..c499448e9bbd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp @@ -8,9 +8,13 @@ //===----------------------------------------------------------------------===// #include "class_annotator.h" - +#include "torch_to_mlir_utils.h" #include +#if TORCH_VERSION_LT(1, 8) +#include "ATen/core/function.h" +#endif + using namespace torch_mlir; //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp index 0ec368903ac9..b9eb261965ba 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp @@ -22,12 +22,13 @@ using namespace torch_mlir; MlirOperation torch_mlir::importJitFunctionAsFuncOp( MlirContext context, torch::jit::Function *function, - std::function<MlirAttribute(int)> getArgAttribute) { + std::function<MlirAttribute(int)> getArgAttribute, + const ImportOptions &importOptions) { // Useful for debugging: // graph->dump(); MlirLocation loc = mlirLocationUnknownGet(context); MlirType functionType = - getFunctionTypeFromSchema(context, function->getSchema()); + getFunctionTypeFromSchema(context, function->getSchema(), importOptions); // Use the function's qualified name from the compilation unit. 
// This is a stable linkage name that matches Python module lookup // conventions (see compilation unit import in IValueImporter for more details @@ -68,8 +69,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( /*userAllowsRefinement=*/false)); }; MlirBlock block = importBlock( - context, torch::jit::toGraphFunction(*function).graph()->block(), - createTerminator, inputTypes); + context, getGraphFromFunction(function)->block(), + createTerminator, inputTypes, importOptions); mlirRegionAppendOwnedBlock(bodyRegion, block); return func; } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h index f9f5b10f7357..9b8ae065f4af 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h @@ -12,6 +12,7 @@ #include +#include "import_options.h" #include "node_importer.h" #include "mlir-c/IR.h" @@ -38,10 +39,11 @@ namespace torch_mlir { /// will be attached as an argument attribute to the func op's argument. If a /// null MlirAttribute is returned, no attribute will be attached to that /// argument. -MlirOperation importJitFunctionAsFuncOp( +TORCH_API MlirOperation importJitFunctionAsFuncOp( MlirContext context, torch::jit::Function *function, std::function<MlirAttribute(int)> getArgAttribute = - [](int) -> MlirAttribute { return {nullptr}; }); + [](int) -> MlirAttribute { return {nullptr}; }, + const ImportOptions &importOptions = {}); } // namespace torch_mlir diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h new file mode 100644 index 000000000000..b620c784ac24 --- /dev/null +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h @@ -0,0 +1,39 @@ +//===- import_options.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H +#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H + +namespace torch_mlir { +// Common import options across importers. We define this as a struct to avoid +// an unstructured proliferation of different kinds of ways to control different +// parts of the import process. +struct ImportOptions { + // If this is set to true, then all tensors in the program can be assumed to + // have value semantics. This can happen, for example, when coming from + // LazyTensorCore since conversion to value semantics has already happened at + // a higher level there before we see the program. For + // calling-convention-impacting decisions, this flag should be interpreted as + // a requirement to use a value-semantic tensor type (!torch.vtensor) in + // signatures. + bool assumeTensorsHaveValueSemantics = false; + + // If this is set to true, then the shape and dtype information in the + // JIT IR graph should be ignored. This can be useful when importing from + // torch.jit.trace'd graphs, since those will have shapes burned into them. 
+ // In certain scenarios, users know that their trace will be correct for + // a variety of shapes, and this option allows them to use such traced graphs. + // + // In that case, the appropriate shape information is provided via the type + // bound annotations on the function arguments instead. + bool ignoreExistingTensorShapesAndDtypes = false; +}; +} // namespace torch_mlir + +#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp new file mode 100644 index 000000000000..b072b0ed922c --- /dev/null +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp @@ -0,0 +1,24 @@ +//===- import_options_pybind.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "import_options_pybind.h" +#include "import_options.h" + +namespace py = pybind11; + +using namespace torch_mlir; + +void torch_mlir::initImportOptionsBindings(py::module &m) { + py::class_<ImportOptions>(m, "ImportOptions") + .def(py::init<>()) + .def_readwrite("assumeTensorsHaveValueSemantics", + &ImportOptions::assumeTensorsHaveValueSemantics) + .def_readwrite("ignoreExistingTensorShapesAndDtypes", + &ImportOptions::ignoreExistingTensorShapesAndDtypes); +} diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h new file mode 100644 index 000000000000..6e8e1389ca3a --- /dev/null +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.h @@ -0,0 +1,19 @@ +//===- import_options_pybind.h ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H +#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H + +#include + +namespace torch_mlir { +void initImportOptionsBindings(pybind11::module &m); +} // namespace torch_mlir + +#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp index 7506e4c9d6ed..e6d01ca9abcf 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/init_python_bindings.cpp @@ -13,6 +13,7 @@ #include "class_annotator_pybind.h" #include "get_registered_ops.h" +#include "import_options_pybind.h" #include "module_builder.h" using namespace torch_mlir; @@ -21,4 +22,5 @@ PYBIND11_MODULE(_jit_ir_importer, m) { ModuleBuilder::bind(m); initClassAnnotatorBindings(m); initGetRegisteredOpsBindings(m); + initImportOptionsBindings(m); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp index cea471c3d68d..e9529f9ee92e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp @@ -21,7 +21,11 @@ #include "mlir-c/Diagnostics.h" #include "torch-mlir-c/TorchTypes.h" -#include "ATen/native/quantized/PackedParams.h" +#if TORCH_VERSION_LT(1, 12) +// do nothing +#else +#include "ATen/native/quantized/packed_params.h" +#endif #include "caffe2/core/scope_guard.h" using namespace torch_mlir; @@ -49,14 +53,50 @@ using namespace torch_mlir; // which is compatible with the semantics we want (for the subset it doesn't // throw an error on). 
namespace { +#if TORCH_VERSION_LT(1, 7) +#include "torch/csrc/utils/hash.h" +#endif + +#if TORCH_VERSION_LT(1, 8) +inline size_t IValueHash(const c10::IValue &v) { + using namespace torch; + using namespace c10; + if (v.isNone()) { + return 0; + } else if (v.isInt()) { + return get_hash(v.toInt()); + } else if (v.isBool()) { + return get_hash(v.toBool()); + } else if (v.isDouble()) { + return get_hash(v.toDouble()); + } else if (v.isTensor()) { + // Tensor __hash__ is equivalent to `id()`, so take the pointer value of + // the tensor to emulate it + return get_hash(v.toTensor().unsafeGetTensorImpl()); + } else if (v.isString()) { + return get_hash(v.toStringRef()); + } else if (v.isTuple()) { + return get_hash(v.toTuple()); + } else if (v.isDevice()) { + return get_hash(v.toDevice()); + } else { + return std::hash()( + static_cast(v.internalToPointer())); + } +} +#else +inline size_t IValueHash(const c10::IValue &v) { + return c10::IValue::hash(v); +} +#endif + struct IValueHasher { size_t operator()(const c10::IValue &ivalue) const { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { return std::hash()( static_cast(ivalue.internalToPointer())); } - - return c10::IValue::hash(ivalue); + return IValueHash(ivalue); } }; } // namespace @@ -101,8 +141,9 @@ namespace { class IValueImporter { public: IValueImporter(MlirBlock importBlock, MlirContext context, - ClassAnnotator &annotator) - : importBlock(importBlock), context(context), annotator(annotator) {} + ClassAnnotator &annotator, const ImportOptions &importOptions) + : importBlock(importBlock), context(context), annotator(annotator), + importOptions(importOptions) {} MlirValue importIValue(c10::IValue ivalue); @@ -118,6 +159,7 @@ class IValueImporter { MlirBlock importBlock; MlirContext context; ClassAnnotator &annotator; + const ImportOptions &importOptions; // Map tracking already-imported values. std::unordered_map valueMap; @@ -331,6 +373,9 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { torchMlirTorchNoneTypeGet(context)); return mlirOperationGetResult(operation, 0); } +#if TORCH_VERSION_LT(1, 12) + // do nothing +#else if (ivalue.isCustomClass()) { if (ivalue.type().get() == c10::getCustomClassType>() @@ -351,6 +396,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { return mlirOperationGetResult(operation, 0); } } +#endif std::stringstream msg; msg << "Unsupported ivalue: " << ivalue; throw std::invalid_argument(msg.str()); @@ -503,7 +549,8 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { MethodAnnotation *annotation = annotator.getMethodAnnotationForFunction(function); MlirOperation func = importJitFunctionAsFuncOp( - context, function, [&](int argIndex) -> MlirAttribute { + context, function, + [&](int argIndex) -> MlirAttribute { if (!annotation || !annotation->argAnnotations.has_value()) { return {nullptr}; } @@ -541,7 +588,8 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute( "torch.type_bound", mlirTypeAttrGet(typeBound)); return mlirDictionaryAttrGet(context, 1, &typeBoundAttr); - }); + }, + importOptions); // For IValue importing, the logical linkage structure of the module // is determined by the object graph. 
// @@ -560,10 +608,12 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { } MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block, - MlirContext context, ClassAnnotator &annotator) { + MlirContext context, + ClassAnnotator &annotator, + const ImportOptions &importOptions) { // When debugging module importing, it can be useful to dump as so: // if (ivalue.isModule()) // ivalue.toModule().dump(true, false, false); - IValueImporter importer(block, context, annotator); + IValueImporter importer(block, context, annotator, importOptions); return importer.importIValue(ivalue); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h index 36e07cc2c0f9..7cbc7ece8488 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.h @@ -13,6 +13,7 @@ #include #include "class_annotator.h" +#include "import_options.h" #include "mlir-c/IR.h" @@ -25,7 +26,8 @@ namespace torch_mlir { /// Main entry-point for importing torch IValue's . /// Recursively imports `ivalue`, inserting operations at the end of `block`. MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context, - ClassAnnotator &annotator); + ClassAnnotator &annotator, + const ImportOptions &importOptions); } // namespace torch_mlir diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp index 70ac687b978b..da85a9b24828 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp @@ -17,7 +17,6 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "mlir-c/Registration.h" #include "torch-mlir-c/Registration.h" namespace py = pybind11; @@ -134,12 +133,18 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { } void ModuleBuilder::importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator) { + py::object maybeClassAnnotator, + py::object maybeImportOptions) { ClassAnnotator dummyAnnotator; ClassAnnotator *classAnnotator = &dummyAnnotator; if (!maybeClassAnnotator.is_none()) { classAnnotator = py::cast(maybeClassAnnotator); } + ImportOptions importOptions; + if (!maybeImportOptions.is_none()) { + importOptions = py::cast(maybeImportOptions); + } + // Set a debugging name for the MLIR Module based on the jitModule's class // name. 
// This is a bit hacky, because we are mutating the enclosing ModuleOp @@ -164,7 +169,7 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule, toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr); importIValue(jitModule._ivalue(), mlirModuleGetBody(module), - mlirModuleGetContext(module), *classAnnotator); + mlirModuleGetContext(module), *classAnnotator, importOptions); } MlirBlock ModuleBuilder::getBodyBlock() { @@ -179,5 +184,6 @@ void ModuleBuilder::bind(py::module &m) { .def_property_readonly("module", &ModuleBuilder::getModuleObj) .def("import_function", &ModuleBuilder::importFunction) .def("import_module", &ModuleBuilder::importModule, py::arg("module"), - py::arg("classAnnotator") = py::none()); + py::arg("classAnnotator") = py::none(), + py::arg("importOptions") = py::none()); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h index 95c85536da88..6e1e0beead94 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h @@ -45,7 +45,8 @@ class ModuleBuilder { // annotations, if not none, provided in `maybeClassAnnotator` which should be // a ClassAnnotator. void importModule(torch::jit::Module jitModule, - py::object maybeClassAnnotator); + py::object maybeClassAnnotator, + py::object maybeImportOptions); private: MlirBlock getBodyBlock(); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index b7f563851d43..81b3e17413b6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -33,15 +33,18 @@ class NodeImporter { public: NodeImporter(MlirContext context) : context(context) {} - void importNode(Node *node, MlirBlock appendToBlock); + void importNode(Node *node, MlirBlock appendToBlock, + const ImportOptions &importOptions = {}); MlirBlock importBlock( Block *jitBlock, CreateTerminatorFn createTerminator, - c10::optional> blockArgTypes = c10::nullopt); + c10::optional> blockArgTypes = c10::nullopt, + const ImportOptions &importOptions = {}); private: MlirBlock createBlockFor(Block *jitBlock, - c10::optional> blockArgTypes); + c10::optional> blockArgTypes, + const ImportOptions &importOptions = {}); void mapValue(Value *jitValue, MlirValue value); void mapResults(Node *node, MlirOperation operation); MlirValue lookupMappedValue(Value *jitValue); @@ -76,39 +79,39 @@ rearrangeDictConstructInputs(std::vector &inputs) { return rearranged; } -void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { +void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, + const ImportOptions &importOptions) { MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, InputsTransformFn t) { std::vector mappedInputs = lookupMappedValues(node->inputs()); - MlirOperation operation = - createMlirOperationAtEnd(appendToBlock, opName, loc, - getMlirTypesFromValues(loc, node->outputs()), - t ? t(mappedInputs) : mappedInputs); + MlirOperation operation = createMlirOperationAtEnd( + appendToBlock, opName, loc, + getMlirTypesFromValues(loc, node->outputs(), importOptions), + t ? 
t(mappedInputs) : mappedInputs); mapResults(node, operation); }; - auto createAndMapNodeWithAttribute = [&](Node *node, - const std::string &opName, - const std::string &attrName, - MlirAttribute attr) { - MlirOperation operation = - createMlirOperationAtEnd(appendToBlock, opName, loc, - getMlirTypesFromValues(loc, node->outputs()), - lookupMappedValues(node->inputs()), - toMlirNamedAttribute(attrName.c_str(), attr)); - mapResults(node, operation); - }; + auto createAndMapNodeWithAttribute = + [&](Node *node, const std::string &opName, const std::string &attrName, + MlirAttribute attr) { + MlirOperation operation = createMlirOperationAtEnd( + appendToBlock, opName, loc, + getMlirTypesFromValues(loc, node->outputs(), importOptions), + lookupMappedValues(node->inputs()), + toMlirNamedAttribute(attrName.c_str(), attr)); + mapResults(node, operation); + }; // Trivial ops with schema. auto maybeSchema = node->maybeSchema(); if (maybeSchema) { - MlirOperation operation = - createOperationFromSchema(appendToBlock, loc, node->schema(), - getMlirTypesFromValues(loc, node->outputs()), - lookupMappedValues(node->inputs())); + MlirOperation operation = createOperationFromSchema( + appendToBlock, loc, node->schema(), + getMlirTypesFromValues(loc, node->outputs(), importOptions), + lookupMappedValues(node->inputs())); mapResults(node, operation); return; } @@ -130,7 +133,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { auto containedTypes = c10::fmap( node->output()->type()->cast()->containedTypes(), [&](const c10::TypePtr &t) { - MlirType type = getMlirTypeFromTorchType(loc, t); + MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); if (mlirTypeIsNull(type)) { throw mlir_diagnostic_emitted(); } @@ -178,13 +181,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.int", loc, - getMlirTypeFromTorchType(loc, output->type()), + getMlirTypeFromTorchType(loc, output->type(), importOptions), toMlirNamedAttribute("value", importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.float", loc, - getMlirTypeFromTorchType(loc, output->type()), + getMlirTypeFromTorchType(loc, output->type(), importOptions), toMlirNamedAttribute("value", importAttribute(loc, node, c10::attr::value))); } else if (output->type()->cast()) { @@ -202,7 +205,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { } else if (output->type()->cast()) { op = createMlirOperation( "torch.constant.device", loc, - getMlirTypeFromTorchType(loc, output->type()), + getMlirTypeFromTorchType(loc, output->type(), importOptions), toMlirNamedAttribute( "value", mlirStringAttrGet(context, toMlirStringRef(node->s( c10::attr::value))))); @@ -211,16 +214,16 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { const std::string &symName = function->qualname().qualifiedName(); op = createMlirOperation( "func.constant", loc, - getFunctionTypeFromSchema(context, function->getSchema()), + getFunctionTypeFromSchema(context, function->getSchema(), + importOptions), toMlirNamedAttribute( "value", mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); } else if (output->type()->cast()) { ClassAnnotator dummyAnnotator; - MlirValue listValue = importIValue(node->ival(c10::attr::value), - appendToBlock, - context, - dummyAnnotator); + MlirValue listValue = + importIValue(node->ival(c10::attr::value), appendToBlock, context, 
+ dummyAnnotator, importOptions); mapResults(node, mlirOpResultGetOwner(listValue)); return; // Early return, since `importIValue` already added op to block. } else { @@ -237,7 +240,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { if (kind == c10::prim::Loop) { std::vector resultTypes = - getMlirTypesFromValues(loc, node->outputs()); + getMlirTypesFromValues(loc, node->outputs(), importOptions); MlirOperation operation = createMlirOperationAtEnd( appendToBlock, "torch.prim.Loop", loc, resultTypes, lookupMappedValues(node->inputs().slice(0, 2)), @@ -260,13 +263,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { }; mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator)); + importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); return; } if (kind == c10::prim::If) { std::vector resultTypes = - getMlirTypesFromValues(loc, node->outputs()); + getMlirTypesFromValues(loc, node->outputs(), importOptions); MlirOperation operation = createMlirOperationAtEnd( appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()), resultTypes, mlirRegionCreate(), mlirRegionCreate()); @@ -281,10 +284,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { }; mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 0), - importBlock(node->blocks()[0], createTerminator)); + importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions)); mlirRegionAppendOwnedBlock( mlirOperationGetRegion(operation, 1), - importBlock(node->blocks()[1], createTerminator)); + importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions)); return; } @@ -293,14 +296,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { auto methodName = node->s(c10::attr::name); torch::jit::Function *function = classType->findMethod(methodName); MlirType calleeType = - getFunctionTypeFromSchema(context, function->getSchema()); + getFunctionTypeFromSchema(context, function->getSchema(), importOptions); std::vector expectedTypes; for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) { expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i)); } MlirOperation operation = createMlirOperationAtEnd( appendToBlock, "torch.prim.CallMethod", loc, - getMlirTypesFromValues(loc, node->outputs()), + getMlirTypesFromValues(loc, node->outputs(), importOptions), adjustStaticInformationForValues( appendToBlock, loc, lookupMappedValues(node->inputs()), expectedTypes, /*userAllowsRefinement=*/false), @@ -313,13 +316,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { if (kind == c10::prim::CallFunction) { auto functionType = node->input(0)->type()->cast(); torch::jit::Block *calleeEntryBlock = - torch::jit::toGraphFunction(*functionType->function()).graph()->block(); + getGraphFromFunction(functionType->function())->block(); auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { - return getMlirTypeFromTorchType(loc, v->type()); + return getMlirTypeFromTorchType(loc, v->type(), importOptions); }); MlirOperation operation = createMlirOperationAtEnd( appendToBlock, "func.call_indirect", loc, - getMlirTypesFromValues(loc, node->outputs()), + getMlirTypesFromValues(loc, node->outputs(), importOptions), lookupMappedValue(node->input(0)), adjustStaticInformationForValues( appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)), @@ -339,10 +342,11 @@ void NodeImporter::importNode(Node 
*node, MlirBlock appendToBlock) { MlirBlock NodeImporter::importBlock( Block *jitBlock, CreateTerminatorFn createTerminator, - c10::optional<std::vector<MlirType>> blockArgTypes) { - MlirBlock block = createBlockFor(jitBlock, blockArgTypes); + c10::optional<std::vector<MlirType>> blockArgTypes, + const ImportOptions &importOptions) { + MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions); for (Node *node : jitBlock->nodes()) { - importNode(node, block); + importNode(node, block, importOptions); } Node *returnNode = jitBlock->return_node(); createTerminator(lookupMappedValues(returnNode->inputs()), block); @@ -350,11 +354,12 @@ MlirBlock NodeImporter::importBlock( } MlirBlock NodeImporter::createBlockFor( - Block *jitBlock, c10::optional<std::vector<MlirType>> blockArgTypes) { + Block *jitBlock, c10::optional<std::vector<MlirType>> blockArgTypes, + const ImportOptions &importOptions) { Node *paramNode = jitBlock->param_node(); MlirLocation loc = getMlirLocationFromNode(context, paramNode); std::vector<MlirType> paramNodeTypes = - getMlirTypesFromValues(loc, paramNode->outputs()); + getMlirTypesFromValues(loc, paramNode->outputs(), importOptions); if (!blockArgTypes) blockArgTypes = paramNodeTypes; else @@ -405,7 +410,8 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) { MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock, CreateTerminatorFn createTerminator, - c10::optional<std::vector<MlirType>> blockArgTypes) { + c10::optional<std::vector<MlirType>> blockArgTypes, + const ImportOptions &importOptions) { NodeImporter importer(context); - return importer.importBlock(jitBlock, createTerminator, blockArgTypes); + return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h index b05a37cfd9b7..dd01444f415a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h @@ -10,6 +10,8 @@ #ifndef TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H #define TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H +#include "import_options.h" + #include #include "mlir-c/IR.h" @@ -37,7 +39,8 @@ using CreateTerminatorFn = MlirBlock importBlock( MlirContext context, torch::jit::Block *jitBlock, CreateTerminatorFn createTerminator, - c10::optional<std::vector<MlirType>> blockArgTypes = c10::nullopt); + c10::optional<std::vector<MlirType>> blockArgTypes = c10::nullopt, + const ImportOptions &importOptions = {}); } // namespace torch_mlir diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index 289b95c45c70..4e65a7a9ae22 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -22,6 +22,19 @@ #include "torch-mlir-c/TorchOps.h" #include "torch-mlir-c/TorchTypes.h" +#if TORCH_VERSION_LT(1, 8) +#include "torch/custom_class.h" +#endif + +std::shared_ptr<torch::jit::Graph> +torch_mlir::getGraphFromFunction(torch::jit::Function *function) { +#if TORCH_VERSION_LT(1, 11) + return function->graph(); +#else + return toGraphFunction(*function).graph(); +#endif +} + using namespace torch_mlir; static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, @@ -117,14 +130,27 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, throw mlir_diagnostic_emitted(); } -MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, - const c10::TypePtr 
&torchType) { +MlirType +torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, + const c10::TypePtr &torchType, + const ImportOptions &importOptions) { MlirContext context = mlirLocationGetContext(loc); using c10::TypeKind; auto kind = torchType->kind(); switch (kind) { case TypeKind::TensorType: { auto tensorType = torchType->cast(); + auto getMlirTensorType = importOptions.assumeTensorsHaveValueSemantics + ? torchMlirTorchValueTensorTypeGet + : torchMlirTorchNonValueTensorTypeGet; + + if (importOptions.ignoreExistingTensorShapesAndDtypes) { + return getMlirTensorType(context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/{nullptr}); + } + // Element type. MlirType elementType = {nullptr}; if (tensorType->scalarType()) { @@ -137,18 +163,30 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, auto &sizes = tensorType->symbolic_sizes(); if (!sizes.rank()) { // Unranked. - return torchMlirTorchNonValueTensorTypeGet(context, - /*numSizes=*/0, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/ - elementType); + return getMlirTensorType(context, + /*numSizes=*/-1, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/ + elementType); } + // Ranked with possibly dynamic dims. auto &symbolicShape = tensorType->symbolic_sizes(); +#if TORCH_VERSION_LT(1, 8) + auto getSymbolicShape = [&](size_t d) { + const auto &dims = symbolicShape.sizes(); + if (!dims) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims).at(d); + }; +#else + auto getSymbolicShape = [&](size_t d) { return symbolicShape[d]; }; +#endif std::vector dims; dims.resize(*sizes.rank()); for (size_t i = 0; i < dims.size(); ++i) { - auto shapeSymbol = symbolicShape[i]; + auto shapeSymbol = getSymbolicShape(i); dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; } @@ -158,10 +196,10 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, // case. So use a dummy data pointer. int64_t dummy; int64_t *dimsData = dims.size() == 0 ? 
&dummy : dims.data(); - return torchMlirTorchNonValueTensorTypeGet(context, dims.size(), - /*optionalSizes=*/dimsData, - /*optionalDtype=*/ - elementType); + return getMlirTensorType(context, dims.size(), + /*optionalSizes=*/dimsData, + /*optionalDtype=*/ + elementType); } case TypeKind::IntType: { return torchMlirTorchIntTypeGet(context); @@ -180,17 +218,22 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, } case TypeKind::OptionalType: { return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType( - loc, torchType->cast()->getElementType())); + loc, torchType->cast()->getElementType(), + importOptions)); } case TypeKind::TupleType: { std::vector containedTypes; for (const c10::TypePtr &type : torchType->cast()->containedTypes()) { - containedTypes.push_back(getMlirTypeFromTorchType(loc, type)); + containedTypes.push_back( + getMlirTypeFromTorchType(loc, type, importOptions)); } return torchMlirTorchTupleTypeGet(context, containedTypes.size(), containedTypes.data()); } +#if TORCH_VERSION_LT(1, 10) +// do nothing +#else case TypeKind::UnionType: { std::vector containedTypes; for (const c10::TypePtr &type : @@ -200,15 +243,17 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, return torchMlirTorchUnionTypeGet(context, containedTypes.size(), containedTypes.data()); } +#endif case TypeKind::ListType: { return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( - loc, torchType->cast()->getElementType())); + loc, torchType->cast()->getElementType(), + importOptions)); } case TypeKind::DictType: { auto dictType = torchType->cast(); return torchMlirTorchDictTypeGet( - getMlirTypeFromTorchType(loc, dictType->getKeyType()), - getMlirTypeFromTorchType(loc, dictType->getValueType())); + getMlirTypeFromTorchType(loc, dictType->getKeyType(), importOptions), + getMlirTypeFromTorchType(loc, dictType->getValueType(), importOptions)); } case TypeKind::NoneType: { return torchMlirTorchNoneTypeGet(context); @@ -243,10 +288,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, MlirType torch_mlir::getFunctionTypeFromSchema(MlirContext context, - const c10::FunctionSchema &schema) { + const c10::FunctionSchema &schema, + const ImportOptions &importOptions) { MlirLocation loc = mlirLocationUnknownGet(context); auto mapType = [&](const c10::TypePtr &torchType) { - MlirType type = getMlirTypeFromTorchType(loc, torchType); + MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions); if (mlirTypeIsNull(type)) { std::stringstream msg; msg << "unsupported type in function schema: '" @@ -336,6 +382,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, case ScalarType::BFloat16: return mlirDenseElementsAttrBFloat16Get( shapedType, numElements, static_cast(tensorData)); + case ScalarType::Half: + return mlirDenseElementsAttrFloat16Get( + shapedType, numElements, static_cast(tensorData)); + default: throwUnsupportedTensorError(); } @@ -383,10 +433,11 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context, std::vector torch_mlir::getMlirTypesFromValues(MlirLocation loc, - c10::ArrayRef values) { + c10::ArrayRef values, + const ImportOptions &importOptions) { std::vector ret; for (auto value : values) { - MlirType t = getMlirTypeFromTorchType(loc, value->type()); + MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions); if (mlirTypeIsNull(t)) throw mlir_diagnostic_emitted("unsupported type"); ret.push_back(t); diff --git 
index e3ff4f45d314..b5eace32b862 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
@@ -10,6 +10,8 @@
 #ifndef TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
 #define TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H

+#include "import_options.h"
+
 #include <memory>

 #include "mlir-c/IR.h"
@@ -20,6 +22,13 @@

 namespace torch_mlir {

+#define TORCH_VERSION_LT(major, minor)                                         \
+  (defined(PYTORCH_MAJOR_VERSION) && defined(PYTORCH_MINOR_VERSION) &&         \
+   (PYTORCH_MAJOR_VERSION == major && PYTORCH_MINOR_VERSION < minor))
+
+std::shared_ptr<torch::jit::Graph>
+getGraphFromFunction(torch::jit::Function *function);
+
 /// Thrown on failure when details are in MLIR emitted diagnostics.
 class mlir_diagnostic_emitted : public std::runtime_error {
 public:
@@ -42,14 +51,16 @@ MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
 /// Maps a torch type to a corresponding MlirType. Returns a null type
 /// on failure and emits a diagnostic.
 MlirType getMlirTypeFromTorchType(MlirLocation loc,
-                                  const c10::TypePtr &torchType);
+                                  const c10::TypePtr &torchType,
+                                  const ImportOptions &importOptions = {});

 /// Creates a FunctionType suitable for expressing the signature of `schema`.
 ///
 /// This can differ from the type inferred from the block of a
 /// torch::jit::Function due to derefinement and refinement of tensor types.
 MlirType getFunctionTypeFromSchema(MlirContext context,
-                                   const c10::FunctionSchema &schema);
+                                   const c10::FunctionSchema &schema,
+                                   const ImportOptions &importOptions = {});

 /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
 MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
@@ -63,7 +74,8 @@ MlirLocation getMlirLocationFromNode(MlirContext context,

 std::vector<MlirType>
 getMlirTypesFromValues(MlirLocation loc,
-                       c10::ArrayRef<torch::jit::Value *> values);
+                       c10::ArrayRef<torch::jit::Value *> values,
+                       const ImportOptions &importOptions = {});

 std::vector<MlirValue> adjustStaticInformationForValues(
     MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py b/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py
index 66006319527a..d495dda4836f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/torchscript_annotations.py
@@ -18,7 +18,7 @@
 # to be expressed conveniently and gives clearer error reports when
 # the annotations aren't acceptable.

-# This module is kept separate from torch_mlir_e2e_test.torchscript.annotations so that
+# This module is kept separate from torch_mlir_e2e_test.annotations so that
 # we can use that module from code without C++ dependencies, which prevent us
 # from interfacing the test framework across environments.
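+#
+# (Note: the `torch_mlir_e2e_test.annotations` module referenced above is
+# what provides the `export` and `annotate_args` decorators used throughout
+# the e2e test suite; this module translates those annotations into
+# `ClassAnnotator` calls for the JIT IR importer.)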
diff --git a/python/torch_mlir_e2e_test/torchscript/annotations.py b/python/torch_mlir_e2e_test/annotations.py
similarity index 100%
rename from python/torch_mlir_e2e_test/torchscript/annotations.py
rename to python/torch_mlir_e2e_test/annotations.py
diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/configs/__init__.py
similarity index 83%
rename from python/torch_mlir_e2e_test/torchscript/configs/__init__.py
rename to python/torch_mlir_e2e_test/configs/__init__.py
index 14c2f48c36cb..a7118c0eff98 100644
--- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py
+++ b/python/torch_mlir_e2e_test/configs/__init__.py
@@ -3,8 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.

+from .lazy_tensor_core import LazyTensorCoreTestConfig
 from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
 from .native_torch import NativeTorchTestConfig
 from .torchscript import TorchScriptTestConfig
+from .mhlo_backend import MhloBackendTestConfig
 from .tosa_backend import TosaBackendTestConfig
 from .eager_mode import EagerModeTestConfig
diff --git a/python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py b/python/torch_mlir_e2e_test/configs/eager_mode.py
similarity index 96%
rename from python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py
rename to python/torch_mlir_e2e_test/configs/eager_mode.py
index 2a81702b88c2..157ef0f36acc 100644
--- a/python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py
+++ b/python/torch_mlir_e2e_test/configs/eager_mode.py
@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_map

 from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
-from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
+from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem


 def wrap(e):
diff --git a/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py b/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py
new file mode 100644
index 000000000000..29842ccfc1bb
--- /dev/null
+++ b/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py
@@ -0,0 +1,42 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
+import torch
+from torch.utils._pytree import tree_map
+
+from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
+
+
+def to_device(device):
+    """Returns a lambda that maps `torch.Tensor` objects to `device`, and ignores other types"""
+    return lambda e: e.to(device) if isinstance(e, torch.Tensor) else e
+
+
+class LazyTensorCoreTestConfig(TestConfig):
+    """TestConfig that runs torch.nn.Module through the Lazy Tensor Core frontend for Torch MLIR"""
+
+    def __init__(self):
+        super().__init__()
+        lazy_backend._initialize()
+
+    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
+        return program.to('lazy')
+
+    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
+        result: Trace = []
+
+        for item in trace:
+            # We need to move all the inputs to the lazy device before running in LTC.
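+            # (Lazy tensors record the computation instead of executing it
+            # eagerly; the recorded graph is only compiled and run by the
+            # reference lazy backend when the outputs are copied back to the
+            # CPU below.)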
+            lazy_inputs = tree_map(to_device('lazy'), item.inputs)
+            output = getattr(artifact, item.symbol)(*lazy_inputs)
+            cpu_outputs = tree_map(to_device('cpu'), output)
+
+            result.append(
+                TraceItem(symbol=item.symbol,
+                          inputs=item.inputs,
+                          output=cpu_outputs))
+
+        return result
diff --git a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py b/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py
similarity index 71%
rename from python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py
rename to python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py
index adb4e3cca3ed..6ad41dd6dccb 100644
--- a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py
+++ b/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py
@@ -3,23 +3,18 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.

-import sys
 from typing import Any
-from io import StringIO
-import os
-import tempfile

-import numpy as np
 import torch
+import torch_mlir

 from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
-from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
+from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders

 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
-    convert_torchscript_module_to_torch_backend_contract_mlir,
 )


@@ -34,14 +29,9 @@ def __init__(self, backend: LinalgOnTensorsBackend):
         self.backend = backend

     def compile(self, program: torch.nn.Module) -> Any:
-
-        module = convert_torchscript_module_to_torch_backend_contract_mlir(
-            program)
-
-        run_pipeline_with_repro_report(
-            module,
-            "torch-backend-to-linalg-on-tensors-backend-pipeline",
-            "Lower Torch Backend IR -> Linalg-on-Tensors Backend IR")
+        example_args = convert_annotations_to_placeholders(program.forward)
+        module = torch_mlir.compile(
+            program, example_args, output_type="linalg-on-tensors")

         return self.backend.compile(module)

diff --git a/python/torch_mlir_e2e_test/configs/mhlo_backend.py b/python/torch_mlir_e2e_test/configs/mhlo_backend.py
new file mode 100644
index 000000000000..0b7b3253499a
--- /dev/null
+++ b/python/torch_mlir_e2e_test/configs/mhlo_backend.py
@@ -0,0 +1,52 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from typing import Any
+
+import torch
+import torch_mlir
+
+from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
+from torch_mlir_e2e_test.framework import (
+    TestConfig,
+    Trace,
+    TraceItem
+)
+from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
+from .utils import (
+    recursively_convert_to_numpy,
+    recursively_convert_from_numpy,
+)
+
+
+class MhloBackendTestConfig(TestConfig):
+    """Base class for TestConfig's that are implemented with MHLO.
+
+    This class handles all the common lowering that torch-mlir does before
+    reaching the MHLO abstraction level.
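+
+    Rough usage sketch (`MyAnnotatedModule` is a hypothetical module
+    annotated with `export`/`annotate_args`, `golden_trace` comes from the
+    framework's golden-trace generation, and `LinalgOnTensorsMhloBackend`
+    is added later in this patch):
+
+        config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
+        artifact = config.compile(MyAnnotatedModule())
+        results = config.run(artifact, golden_trace)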
+ """ + def __init__(self, backend: MhloBackend): + super().__init__() + self.backend = backend + + def compile(self, program: torch.nn.Module) -> Any: + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile( + program, example_args, output_type="mhlo") + + return self.backend.compile(module) + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, item.symbol)(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result diff --git a/python/torch_mlir_e2e_test/torchscript/configs/native_torch.py b/python/torch_mlir_e2e_test/configs/native_torch.py similarity index 92% rename from python/torch_mlir_e2e_test/torchscript/configs/native_torch.py rename to python/torch_mlir_e2e_test/configs/native_torch.py index 06115b34a7aa..b85353f65cef 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/native_torch.py +++ b/python/torch_mlir_e2e_test/configs/native_torch.py @@ -8,7 +8,7 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem class NativeTorchTestConfig(TestConfig): diff --git a/python/torch_mlir_e2e_test/torchscript/configs/torchscript.py b/python/torch_mlir_e2e_test/configs/torchscript.py similarity index 93% rename from python/torch_mlir_e2e_test/torchscript/configs/torchscript.py rename to python/torch_mlir_e2e_test/configs/torchscript.py index f79aa92cc40f..9d105557ccb9 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/torchscript.py +++ b/python/torch_mlir_e2e_test/configs/torchscript.py @@ -8,7 +8,7 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem class TorchScriptTestConfig(TestConfig): diff --git a/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py b/python/torch_mlir_e2e_test/configs/tosa_backend.py similarity index 72% rename from python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py rename to python/torch_mlir_e2e_test/configs/tosa_backend.py index f157433c7366..8b41cfeda535 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py +++ b/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -3,22 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. 
-import sys from typing import Any -from io import StringIO -import os -import tempfile -import numpy as np import torch +import torch_mlir from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend -from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem -from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( recursively_convert_to_numpy, recursively_convert_from_numpy, - convert_torchscript_module_to_torch_backend_contract_mlir, ) @@ -33,14 +28,9 @@ def __init__(self, backend: TosaBackend): self.backend = backend def compile(self, program: torch.nn.Module) -> Any: - - module = convert_torchscript_module_to_torch_backend_contract_mlir( - program) - - run_pipeline_with_repro_report( - module, - "torch-backend-to-tosa-backend-pipeline", - "Lower Torch Backend IR -> TOSA Backend IR") + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile( + program, example_args, output_type="tosa") return self.backend.compile(module) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/utils.py b/python/torch_mlir_e2e_test/configs/utils.py similarity index 53% rename from python/torch_mlir_e2e_test/torchscript/configs/utils.py rename to python/torch_mlir_e2e_test/configs/utils.py index 6e5da13650a3..c8f912f43aac 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/utils.py +++ b/python/torch_mlir_e2e_test/configs/utils.py @@ -3,17 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -import sys from typing import Any -from io import StringIO import numpy as np import torch -from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder -from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations -from torch_mlir.compiler_utils import run_pipeline_with_repro_report - def recursively_convert_to_numpy(o: Any): if isinstance(o, torch.Tensor): @@ -50,41 +44,3 @@ def recursively_convert_from_numpy(o: Any): if isinstance(o, int): return o raise Exception(f"Unexpected Python function output: {o}") - - -def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module): - """Perform common lowering from TorchScript to Torch MLIR - - Returns an MLIR module that satisfies the Torch backend contract. - """ - mb = ModuleBuilder() - scripted = torch.jit.script(program) - class_annotator = ClassAnnotator() - - extract_annotations(program, scripted, class_annotator) - - - # TODO: Find a way to make each of these calls own its own - # "debuggable error report" situation. 
- try: - original_stderr = sys.stderr - sys.stderr = StringIO() - # Import the TorchScript module to MLIR - mb.import_module(scripted._c, class_annotator) - except Exception as e: - raise Exception(f""" -PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: -Exception: -{e} -Diagnostics: -{sys.stderr.getvalue()} -""") from None - finally: - sys.stderr = original_stderr - - run_pipeline_with_repro_report( - mb.module, - "torchscript-module-to-torch-backend-pipeline", - "Lowering TorchScript Object Graph IR -> Torch Backend IR") - - return mb.module diff --git a/python/torch_mlir_e2e_test/torchscript/framework.py b/python/torch_mlir_e2e_test/framework.py similarity index 91% rename from python/torch_mlir_e2e_test/torchscript/framework.py rename to python/torch_mlir_e2e_test/framework.py index fdaa46084ffa..38ba064c9548 100644 --- a/python/torch_mlir_e2e_test/torchscript/framework.py +++ b/python/torch_mlir_e2e_test/framework.py @@ -23,6 +23,7 @@ import abc from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict +import sys import traceback import torch @@ -182,6 +183,9 @@ def __init__(self): def rand(self, *sizes, low=0.0, high=1.0): return torch.empty(sizes).uniform_(low, high) + def randint(self, *sizes, low=0, high=10): + return torch.randint(low, high, sizes) + def nans(self, *sizes): vals = torch.empty(sizes) vals[...] = torch.nan @@ -277,9 +281,11 @@ def generate_golden_trace(test: Test) -> Trace: return trace -def compile_and_run_test(test: Test, config: TestConfig) -> Any: +def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: try: golden_trace = generate_golden_trace(test) + if verbose: + print(f"Compiling {test.unique_name}...", file=sys.stderr) compiled = config.compile(test.program_factory()) except Exception as e: return TestResult(unique_name=test.unique_name, @@ -290,6 +296,8 @@ def compile_and_run_test(test: Test, config: TestConfig) -> Any: trace=None, golden_trace=None) try: + if verbose: + print(f"Running {test.unique_name}...", file=sys.stderr) trace = config.run(compiled, golden_trace) except Exception as e: return TestResult(unique_name=test.unique_name, @@ -309,31 +317,39 @@ def compile_and_run_test(test: Test, config: TestConfig) -> Any: queue_sentinel = "QUEUE_SENTINEL" -def run_workers_in_parallel(task_queue: mp.Queue, worker): - NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize()) - - # TODO: We've noticed that on certain 2 core machine parallelizing the tests - # makes the llvm backend legacy pass manager 20x slower than using a - # single process. Need to investigate the root cause eventually. This is a - # hack to work around this issue. 
- if mp.cpu_count() == 2: - NUMBER_OF_PROCESSES = 1 - +def run_workers_in_parallel(task_queue: mp.Queue, worker, num_processes: int): processes = [] - for i in range(NUMBER_OF_PROCESSES): + for i in range(num_processes): p = mp.get_context("fork").Process(target=worker, args=(task_queue, )) p.start() processes.append(p) - for i in range(NUMBER_OF_PROCESSES): + for i in range(num_processes): task_queue.put(queue_sentinel) for p in processes: p.join() -def run_tests(tests: List[Test], config: TestConfig, sequential = False) -> List[TestResult]: +def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: """Invoke the given `Test`'s with the provided `TestConfig`.""" - if sequential: - return [compile_and_run_test(test, config) for test in tests] + num_processes = min(int(mp.cpu_count() * 1.1), len(tests)) + # TODO: We've noticed that on certain 2 core machine parallelizing the tests + # makes the llvm backend legacy pass manager 20x slower than using a + # single process. Need to investigate the root cause eventually. This is a + # hack to work around this issue. + # Also our multiprocessing implementation is not the most efficient, so + # the benefit at core count 2 is probably not worth it anyway. + if mp.cpu_count() == 2: + num_processes = 1 + + # Sort the tests to make output nicer. + tests = list(sorted(tests, key=lambda t: t.unique_name)) + + # TODO: If num_processes == 1, then run without any of the multiprocessing + # machinery. In theory it should work, but any crash in the testing process + # seems to cause a cascade of failures resulting in undecipherable error + # messages. + if num_processes == 1 or sequential: + return [compile_and_run_test(test, config, verbose) for test in tests] # To run e2e tests in parallel: # The tests are put into a synchronized queue. Multiple worker processes are @@ -357,7 +373,7 @@ def worker(tests_queue: mp.Queue): sync_results.append( compile_and_run_test(tests_dict[test_name], config)) - run_workers_in_parallel(tests_queue, worker) + run_workers_in_parallel(tests_queue, worker, num_processes) tests_with_results = {result.unique_name for result in sync_results} all_tests = {test.unique_name for test in tests} # For processes that are crashed due to compile time or runtime error, diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index fd4110afe05a..09ab07c4dc07 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -10,8 +10,6 @@ from torch_mlir.passmanager import * from torch_mlir.execution_engine import * from torch_mlir.runtime import * -# Imported for side effects. -import torch_mlir.all_passes_registration import torch_mlir.dialects.torch from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -144,6 +142,8 @@ def invoke(*args): "func.func(refback-expand-ops-for-llvm)", "func.func(arith-expand)", "func.func(convert-math-to-llvm)", + # Handle some complex mlir::math ops (e.g. 
atan2)
+            "convert-math-to-libm",
             "convert-linalg-to-llvm",
             "convert-memref-to-llvm",
             "func.func(convert-arith-to-llvm)",
diff --git a/python/torch_mlir_e2e_test/mhlo_backends/__init__.py b/python/torch_mlir_e2e_test/mhlo_backends/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/torch_mlir_e2e_test/mhlo_backends/abc.py b/python/torch_mlir_e2e_test/mhlo_backends/abc.py
new file mode 100644
index 000000000000..8fc51ac00f7a
--- /dev/null
+++ b/python/torch_mlir_e2e_test/mhlo_backends/abc.py
@@ -0,0 +1,49 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import abc
+from typing import TypeVar
+
+import torch
+
+from torch_mlir.ir import Module
+
+# A type shared between the result of `MhloBackend.compile` and the
+# input to `MhloBackend.load`. Each backend will likely have a
+# different definition of this type.
+CompiledArtifact = TypeVar('CompiledArtifact')
+
+# A wrapper around a backend-specific loaded program representation
+# that uniformly translates the `x.method(...)` interface expected of
+# Torch modules into appropriate lower-level operations.
+Invoker = TypeVar('Invoker')
+
+
+class MhloBackend(abc.ABC):
+    """The interface to an MHLO backend.
+
+    Backends are recommended to raise meaningful exceptions in case of error,
+    ideally with easy reproduction instructions.
+    """
+    @abc.abstractmethod
+    def compile(self, module: Module) -> CompiledArtifact:
+        """Compile the provided MLIR module into a compiled artifact.
+
+        The module adheres to the MHLO backend contract
+        (see the VerifyMhloBackendContract pass).
+
+        The compiled artifact can be any type, but must be correctly
+        interpreted by the `load` method.
+        """
+
+    @abc.abstractmethod
+    def load(self, artifact: CompiledArtifact) -> Invoker:
+        """Load the compiled artifact into a uniformly invokable form.
+
+        The compiled artifact is the result of a previous call to `compile`.
+
+        See the description of `Invoker` for the requirements on the returned
+        type.
+        """
diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py
new file mode 100644
index 000000000000..25896e0a0043
--- /dev/null
+++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py
@@ -0,0 +1,45 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from torch_mlir.ir import *
+from torch_mlir.passmanager import *
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+
+from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
+
+from .abc import MhloBackend
+
+__all__ = [
+    "LinalgOnTensorsMhloBackend",
+]
+
+class LinalgOnTensorsMhloBackend(MhloBackend):
+    """Main entry-point for the linalg-on-tensors based MHLO backend.
+
+    This currently uses the linalg-on-tensors RefBackend for actual execution.
+    """
+    def __init__(self):
+        super().__init__()
+        self.refbackend = RefBackendLinalgOnTensorsBackend()
+
+    def compile(self, imported_module: Module):
+        """Compiles an imported module that satisfies the MHLO backend contract.
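+
+        The lowering here is a single `hlo-legalize-to-linalg` pass; the
+        resulting linalg-on-tensors module is then compiled and executed by
+        the RefBackend.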
+ + Args: + imported_module: The MLIR module consisting of funcs in the MHLO + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + "func.func(hlo-legalize-to-linalg)", + "Lowering MLIR-HLO to Linalg-on-Tensors") + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/python/torch_mlir_e2e_test/torchscript/registry.py b/python/torch_mlir_e2e_test/registry.py similarity index 67% rename from python/torch_mlir_e2e_test/torchscript/registry.py rename to python/torch_mlir_e2e_test/registry.py index 7c60226059b0..2f6cab581749 100644 --- a/python/torch_mlir_e2e_test/torchscript/registry.py +++ b/python/torch_mlir_e2e_test/registry.py @@ -11,6 +11,8 @@ # The global registry of tests. GLOBAL_TEST_REGISTRY = [] +# Ensure that there are no duplicate names in the global test registry. +_SEEN_UNIQUE_NAMES = set() def register_test_case(module_factory: Callable[[], torch.nn.Module]): @@ -22,6 +24,13 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]): `program_invoker` is the decorated function. """ def decorator(f): + # Ensure that there are no duplicate names in the global test registry. + if f.__name__ in _SEEN_UNIQUE_NAMES: + raise Exception( + f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.") + _SEEN_UNIQUE_NAMES.add(f.__name__) + + # Store the test in the registry. GLOBAL_TEST_REGISTRY.append( Test(unique_name=f.__name__, program_factory=module_factory, diff --git a/python/torch_mlir_e2e_test/torchscript/reporting.py b/python/torch_mlir_e2e_test/reporting.py similarity index 98% rename from python/torch_mlir_e2e_test/torchscript/reporting.py rename to python/torch_mlir_e2e_test/reporting.py index b3764a5bfe22..bb95d3523ab1 100644 --- a/python/torch_mlir_e2e_test/torchscript/reporting.py +++ b/python/torch_mlir_e2e_test/reporting.py @@ -24,9 +24,10 @@ def __init__(self, tensor): self.max = torch.max(tensor.type(torch.float64)) self.mean = torch.mean(tensor.type(torch.float64)) self.shape = list(tensor.shape) + self.dtype = tensor.dtype def __str__(self): - return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}' + return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}' class ErrorContext: diff --git a/python/torch_mlir_e2e_test/torchscript/serialization.py b/python/torch_mlir_e2e_test/serialization.py similarity index 100% rename from python/torch_mlir_e2e_test/torchscript/serialization.py rename to python/torch_mlir_e2e_test/serialization.py diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 36e9c4868ab5..af77f919b306 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -9,10 +9,10 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "TableBatchEmbeddingModule_basic", - "MobilenetV2Module_basic", - "MobilenetV3Module_basic", "Convolution3DModule_basic", "Convolution1DModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose1dModule_basic", "MaxPool2dWith3dInputModule_basic", "MaxPool2dWithIndicesWith3dInputModule_basic", } @@ -53,5 +53,3 @@ def register_all_tests(): from . 
import return_types from . import control_flow from . import stats - # TODO: Re-enable after MacOS support is fixed for the extension. - #from . import custom_op_example diff --git a/python/torch_mlir_e2e_test/test_suite/arange.py b/python/torch_mlir_e2e_test/test_suite/arange.py index a22f5efa7377..d7ca3b6e2bac 100644 --- a/python/torch_mlir_e2e_test/test_suite/arange.py +++ b/python/torch_mlir_e2e_test/test_suite/arange.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/argmax.py b/python/torch_mlir_e2e_test/test_suite/argmax.py index 575af604c058..098ed508b63c 100644 --- a/python/torch_mlir_e2e_test/test_suite/argmax.py +++ b/python/torch_mlir_e2e_test/test_suite/argmax.py @@ -4,9 +4,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -63,4 +63,3 @@ def forward(self, a): @register_test_case(module_factory=lambda: ArgmaxKeepDimsModule()) def ArgmaxModule_keepDim(module, tu: TestUtils): module.forward(tu.rand(4, 6)) - diff --git a/python/torch_mlir_e2e_test/test_suite/backprop.py b/python/torch_mlir_e2e_test/test_suite/backprop.py index 5a27bc7d377b..bd5a01e590cd 100644 --- a/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1bb6ae8f016e..70981ead2e2e 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -80,7 +80,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: IsFloatingPointInt()) def 
IsFloatingPointInt_False(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3))) + module.forward(tu.randint(3, 3, high=100)) # ============================================================================== @@ -104,7 +104,7 @@ def forward(self, x): def IsFloatingPointFloat_True(module, tu: TestUtils): module.forward(tu.rand(3)) - + # ============================================================================== @@ -137,7 +137,7 @@ def forward(self): @register_test_case(module_factory=lambda: ContainsIntListFalse()) def ContainsIntList_False(module, tu: TestUtils): module.forward() - + # ============================================================================== @@ -596,6 +596,30 @@ def TensorsConcatModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsConcatNegativeDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x, y, z): + return torch.cat([x, y, z], dim=-2) + + +@register_test_case(module_factory=lambda: TensorsConcatNegativeDimModule()) +def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4)) + + +# ============================================================================== + + class GatherModule(torch.nn.Module): def __init__(self): @@ -712,7 +736,7 @@ def forward(self, indices): @register_test_case(module_factory=lambda: EmbeddingModuleI64()) def EmbeddingModuleI64_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3))) + module.forward(tu.randint(3, 3, high=100)) # ============================================================================== @@ -738,7 +762,7 @@ def forward(self, indices): @register_test_case(module_factory=lambda: EmbeddingModuleI32()) def EmbeddingModuleI32_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3)).to(torch.int32)) + module.forward(tu.randint(3, 3, high=100).to(torch.int32)) # ============================================================================== @@ -763,7 +787,7 @@ def forward(self, indices): @register_test_case(module_factory=lambda: EmbeddingModuleI32Static()) def EmbeddingModuleI32Static_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3)).to(torch.int32)) + module.forward(tu.randint(3, 3, high=100).to(torch.int32)) # ============================================================================== @@ -789,7 +813,7 @@ def forward(self, indices): @register_test_case(module_factory=lambda: EmbeddingModule1DIndices()) def EmbeddingModule1DIndices_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3,)).to(torch.int32)) + module.forward(tu.randint(3, high=100).to(torch.int32)) # ============================================================================== @@ -935,6 +959,28 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils): # ============================================================================== +class SoftplusModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.softplus(x) + + +@register_test_case(module_factory=lambda: SoftplusModule()) +def SoftplusModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + + +# 
============================================================================== + + class HardsigmoidModule(torch.nn.Module): def __init__(self): @@ -1001,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class RollModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, -1, 2], torch.float32, True), + ]) + def forward(self, x): + return x.roll([2, -1], [0, 2]) + + +@register_test_case(module_factory=lambda: RollModule()) +def RollModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 2)) + +# ============================================================================== + + class RepeatModule(torch.nn.Module): def __init__(self): @@ -1019,7 +1086,6 @@ def forward(self, x): def RepeatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) - # ============================================================================== @@ -1265,7 +1331,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: DropoutEvalIntModule()) def DropoutEvalIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(5, 10, (3, 4))) + module.forward(tu.randint(3, 4, low=5, high=10)) # ============================================================================== @@ -1354,7 +1420,7 @@ def forward(self, input): @register_test_case(module_factory=lambda: NumelZeroRankModule()) def NumelZeroRankModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, [])) + module.forward(tu.randint(high=10)) # ============================================================================== @@ -1580,7 +1646,7 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: ReturnTwoTensorF32I64()) def ReturnTwoTensorF32I64_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(5, (2, 3))) + module.forward(tu.rand(2, 3), tu.randint(2, 3, high=5)) # ============================================================================== @@ -1603,7 +1669,7 @@ def forward(self, x, index): @register_test_case(module_factory=lambda: IndexTensorModule()) def IndexTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5), torch.randint(4, (2, 3))) + module.forward(tu.rand(5), tu.randint(2, 3, high=4)) # ============================================================================== @@ -1626,7 +1692,272 @@ def forward(self, x, index): @register_test_case(module_factory=lambda: IndexTensorModule3dInput()) def IndexTensorModule3dInput_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(3, (2, 3))) + module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3)) + + +# ============================================================================== + + +class IndexTensorSelectDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, a, ind): + return torch.ops.aten.index(a, (None, ind, None)) + + +@register_test_case(module_factory=lambda: IndexTensorSelectDimModule()) +def IndexTensorSelectDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3)) + +# ============================================================================== + + +class IndexTensorMultiInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, 
True), + ([3, 3], torch.int64, True), + ([3], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, index2,)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInput()) +def IndexTensorMultiInput_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3), tu.randint(3, high=3)) + + +# ============================================================================== + + +class IndexTensorMultiInputOneDim(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([6, 1], torch.int64, True), + ([3], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, index2,)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim()) +def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3)) + + +# ============================================================================== + + +class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, ( + None, + index1, + index2, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic()) +def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), + tu.randint(3, high=3)) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, ( + index1, + None, + index2, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic()) +def IndexTensorMultiInputNonContiguousOneDimDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), + tu.randint(3, high=3)) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, 2], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, ( + index2, + None, + index1, + )) + + +@register_test_case( + module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic()) +def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(6, 2, high=2), + tu.randint(2, high=3)) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 1], torch.int64, True), + ([1, 3], torch.int64, True), + 
([-1, 3], torch.int64, True), + ]) + def forward(self, x, index1, index2, index3): + return torch.ops.aten.index(x, (index1, index2, index3)) + + +@register_test_case(module_factory=lambda: + IndexTensorMultiInputNonContiguousMultipleStaticDims()) +def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3), + tu.randint(1, 3, high=1), tu.randint(4, 3, high=1)) + + +# ============================================================================== + + +class IndexTensorMultiInputNonContiguous(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([4, 2], torch.int64, True), + ([4, 2], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (index1, None, index2)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous()) +def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1)) + + +# ============================================================================== + + +class IndexTensorMultiInputThreeIndexers(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([8, 4, 2], torch.int64, True), + ([8, 1, 1], torch.int64, True), + ([4, 2], torch.int64, True), + ]) + def forward(self, x, index1, index2, index3): + return torch.ops.aten.index(x, (None, None, index1, None, index2, index3)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers()) +def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 4, 4, 5, 3), + tu.randint(8, 4, 2, high=3), + tu.randint(8, 1, 1, high=4), + tu.randint(4, 2, high=2)) + + +# ============================================================================== + + +class IndexTensorMultiInputContiguousCenter(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([2, 2], torch.int64, True), + ([2], torch.int64, True), + ]) + def forward(self, x, index1, index2): + return torch.ops.aten.index(x, (None, index1, index2, None)) + + +@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter()) +def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3), tu.randint(2, high=2)) # ============================================================================== @@ -1758,7 +2089,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: HardTanhIntModule()) def HardTanhIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-5, 5, (100, 100))) + module.forward(tu.randint(100, 100, low=-5, high=5)) # ============================================================================== @@ -1780,7 +2111,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: BincountModule()) def BincountModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (1000, ))) + module.forward(tu.randint(1000, high=10)) # ============================================================================== @@ -1802,7 +2133,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: BincountStaticSizeModule()) def BincountStaticSizeModule_basic(module, 
tu: TestUtils): - module.forward(torch.randint(100, (200, ))) + module.forward(tu.randint(200, high=100)) # ============================================================================== @@ -1824,7 +2155,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: BincountMinlengthModule()) def BincountMinlengthModule_basic(module, tu: TestUtils): - module.forward(torch.randint(5, (20, ))) + module.forward(tu.randint(20, high=5)) # ============================================================================== @@ -1867,8 +2198,8 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ExpandAsIntModule()) def ExpandAsIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (1, 1, 1)), - torch.randint(200, (4, 5, 6))) + module.forward(tu.randint(1, 1, 1, high=100), + tu.randint(4, 5, 6, high=200)) # ============================================================================== @@ -1931,7 +2262,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule()) def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4)) + module.forward(tu.randint(3, 2, 4, high=100), tu.rand(3, 2, 4)) class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module): @@ -1952,7 +2283,7 @@ def forward(self, x, y): @register_test_case( module_factory=lambda: CopyWithDifferentDTypesAndSizesModule()) def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1))) + module.forward(tu.rand(3, 2, 4), tu.randint(3, 2, 1, high=1000)) # ============================================================================== @@ -2120,7 +2451,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ScalarImplicitIntModule()) def ScalarImplicitIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100)) # ============================================================================== @@ -2186,7 +2517,7 @@ def forward(self, input, batch1, batch2): @register_test_case(module_factory=lambda: BaddbmmDifferentDtypesModule()) def BaddbmmDifferentDtypesModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4, 5)), tu.rand(3, 4, 6), + module.forward(tu.randint(3, 4, 5, high=10), tu.rand(3, 4, 6), tu.rand(3, 6, 5)) @@ -2391,3 +2722,66 @@ def forward(self, lhs): @register_test_case(module_factory=lambda: NumpyTRank0Module()) def NumpyTRank0Module_basic(module, tu: TestUtils): module.forward(torch.tensor(7, dtype=torch.float32)) + +class AtenEmbeddingBagSumExample(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, weight, indices, offsets): + return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None) + +@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample()) +def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils): + weight = torch.rand(100, 10) + indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) + module.forward(weight, indices, offsets) + +class Aten_EmbeddingBagExample(torch.nn.Module): + + def __init__(self): + 
super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, weight, indices, offsets): + return torch.ops.aten._embedding_bag(weight, indices, offsets) + +@register_test_case(module_factory=lambda: Aten_EmbeddingBagExample()) +def Aten_EmbeddingBagExample_basic(module, tu: TestUtils): + weight = torch.rand(100, 10) + indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) + module.forward(weight, indices, offsets) + +# ============================================================================== + +class AtenToDeviceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ]) + + def forward(self, val): + return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False) + +@register_test_case(module_factory=lambda: AtenToDeviceModule()) +def AtenToDeviceModule_basic(module, tu: TestUtils): + module.forward(torch.randn(2, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/cast.py b/python/torch_mlir_e2e_test/test_suite/cast.py index 83406694959e..b6867ad319a6 100644 --- a/python/torch_mlir_e2e_test/test_suite/cast.py +++ b/python/torch_mlir_e2e_test/test_suite/cast.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -26,7 +26,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: TensorToIntZeroRank()) def TensorToIntZeroRank_basic(module, tu: TestUtils): - module.forward(torch.randint(10, ())) + module.forward(tu.randint(high=10)) # ============================================================================== @@ -45,7 +45,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: TensorToInt()) def TensorToInt_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (1, 1))) + module.forward(tu.randint(1, 1, high=10)) # ============================================================================== @@ -122,4 +122,3 @@ def forward(self, x): @register_test_case(module_factory=lambda: TensorToBool()) def TensorToBool_basic(module, tu: TestUtils): module.forward(torch.tensor([[1]], dtype=torch.bool)) - diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 44282a87921b..292c0aacf08e 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -195,6 
+195,24 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward() +class OnesModuleCPUDevice(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.ones(3, 4, device="cpu") + + +@register_test_case(module_factory=lambda: OnesModuleCPUDevice()) +def OnesModuleCPUDevice_basic(module, tu: TestUtils): + module.forward() + + # ============================================================================== @@ -328,7 +346,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: EmptyLikeIntModule()) def EmptyLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class EmptyLikeMemoryFormatModule(torch.nn.Module): @@ -342,7 +360,8 @@ def __init__(self): ([-1, -1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0) + return torch.empty_like(a, + memory_format=torch.preserve_format).fill_(0) @register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule()) @@ -427,7 +446,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ZerosLikeIntModule()) def ZerosLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class ZerosLikeFloatModule(torch.nn.Module): @@ -506,7 +525,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: OnesLikeIntModule()) def OnesLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class OnesLikeFloatModule(torch.nn.Module): @@ -683,7 +702,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewZerosModuleFloat2D()) def NewZerosModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewZerosModuleFloat3D(torch.nn.Module): @@ -702,7 +721,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewZerosModuleFloat3D()) def NewZerosModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewZerosModuleFalsePinMemory(torch.nn.Module): @@ -723,7 +742,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory()) def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) # ============================================================================== @@ -802,7 +821,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewOnesModuleFloat2D()) def NewOnesModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewOnesModuleFloat3D(torch.nn.Module): @@ -821,7 +840,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewOnesModuleFloat3D()) def NewOnesModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewOnesModuleFalsePinMemory(torch.nn.Module): @@ -842,7 +861,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory()) def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) # 
============================================================================== @@ -997,7 +1016,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: FullLikeModuleInt2D()) def FullLikeModuleInt2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5))) + module.forward(tu.randint(4, 5, high=10)) class FullLikeModuleInt3D(torch.nn.Module): @@ -1016,7 +1035,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: FullLikeModuleInt3D()) def FullLikeModuleInt3D_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4, 5)).to(torch.int32)) + module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32)) class FullLikeModuleInt2DStatic(torch.nn.Module): @@ -1035,7 +1054,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: FullLikeModuleInt2DStatic()) def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5))) + module.forward(tu.randint(4, 5, high=10)) class FullLikeModuleFloat2D(torch.nn.Module): @@ -1114,7 +1133,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory()) def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4))) + module.forward(tu.randint(10, 4, high=100)) # ============================================================================== @@ -1155,7 +1174,7 @@ def forward(self, tensor): @register_test_case(module_factory=lambda: ZeroInt32Module()) def ZeroInt32Module_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4), dtype=torch.int32)) + module.forward(tu.randint(10, 4, high=100).to(dtype=torch.int32)) class ZeroInt64Module(torch.nn.Module): @@ -1174,7 +1193,7 @@ def forward(self, tensor): @register_test_case(module_factory=lambda: ZeroInt64Module()) def ZeroInt64Module_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4))) + module.forward(tu.randint(10, 4, high=100)) # ============================================================================== @@ -1255,7 +1274,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyModuleFloat2D()) def NewEmptyModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewEmptyModuleFloat3D(torch.nn.Module): @@ -1275,7 +1294,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyModuleFloat3D()) def NewEmptyModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewEmptyModuleFalsePinMemory(torch.nn.Module): @@ -1296,7 +1315,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory()) def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module): @@ -1335,7 +1354,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype()) def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) class NewEmptyModuleLayoutIntDtype(torch.nn.Module): @@ -1354,7 +1373,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype()) def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils): - 
module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) # ============================================================================== @@ -1378,7 +1397,7 @@ def forward(self, x, mask): @register_test_case(module_factory=lambda: MaskedFillScalarDefaultModule()) def MaskedFillScalarDefaultModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarIntValueModule(torch.nn.Module): @@ -1399,7 +1418,7 @@ def forward(self, x, mask): @register_test_case(module_factory=lambda: MaskedFillScalarIntValueModule()) def MaskedFillScalarIntValueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarFloatValueModule(torch.nn.Module): @@ -1419,5 +1438,27 @@ def forward(self, x, mask): @register_test_case(module_factory=lambda: MaskedFillScalarFloatValueModule()) def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 10, (2, 3)), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + module.forward(tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool)) + + +class MaskedFillTensorFloatValueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ([], torch.float32, True), + ]) + def forward(self, x, mask, value): + return torch.ops.aten.masked_fill(x, mask, value=value) + + +@register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule()) +def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand()) diff --git a/python/torch_mlir_e2e_test/test_suite/control_flow.py b/python/torch_mlir_e2e_test/test_suite/control_flow.py index df1912e3e892..5c00a75e06da 100644 --- a/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -7,9 +7,9 @@ import torch import random -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -32,7 +32,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule()) def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(0, 10, (6, 8))) + module.forward(tu.randint(6, 8, high=10)) # ============================================================================== class TorchPrimLoopWhileLikeModule(torch.nn.Module): @@ -54,4 +54,4 @@ def forward(self, x): @register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule()) def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(0, 10, (6, 8))) + module.forward(tu.randint(6, 8, high=10)) diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 
57deeef8fcf5..a098db07aa39 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -4,9 +4,9 @@ # Also available under a BSD-style license. See LICENSE. import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -405,3 +405,303 @@ def forward(self, inputVec, weight): @register_test_case(module_factory=lambda: _Convolution2DTF32Module()) def _Convolution2DTF32Module_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule()) +def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=True, + deterministic=False, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule()) +def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=True, + cudnn_enabled=False) + +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule()) +def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return 
torch.ops.aten._convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + benchmark=False, + deterministic=False, + cudnn_enabled=True) + +@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule()) +def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) + +class ConvolutionModule2DGroups(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=4) + +@register_test_case(module_factory=lambda: ConvolutionModule2DGroups()) +def ConvolutionModule2DGroups_basic(module, tu: TestUtils): + module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3)) + +# ============================================================================== + +class ConvolutionModule2DTranspose(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[1, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTranspose()) +def ConvolutionModule2DTranspose_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2)) + +class ConvolutionModule2DTransposeStrided(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided()) +def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + +class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic()) +def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + + +class Conv_Transpose1dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d(inputVec, + weight, + 
bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose1dModule()) +def Conv_Transpose1dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5), torch.randn(2, 5, 2)) + + +class Conv_Transpose2dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose2d(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose2dModule()) +def Conv_Transpose2dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) + +class Conv_Transpose3dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d(inputVec, + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dModule()) +def Conv_Transpose3dModule_basic(module, tu: TestUtils): + module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2)) diff --git a/python/torch_mlir_e2e_test/test_suite/custom_op_example.py b/python/torch_mlir_e2e_test/test_suite/custom_op_example.py index c67b7c8d8fb1..3d08708d7cd1 100644 --- a/python/torch_mlir_e2e_test/test_suite/custom_op_example.py +++ b/python/torch_mlir_e2e_test/test_suite/custom_op_example.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -33,4 +33,3 @@ def forward(self, a): @register_test_case(module_factory=lambda: CustomOpExampleModule()) def CustomOpExampleModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) - diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 56114e9012c0..98f0a94fd913 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # TODO: Support scalar !torch.int/!torch.float variants. 
Add support to # ReduceOpVariants to implement them in terms of the tensor-only variants + @@ -57,7 +57,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseUnaryIntModule()) def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -428,7 +428,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule()) def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32)) + module.forward(tu.randint(3, 5, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -474,7 +474,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseMinimumIntModule()) def ElementwiseMinimumIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) # ============================================================================== @@ -520,7 +520,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseMaximumIntModule()) def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) # ============================================================================== @@ -663,7 +663,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: RsubIntModule()) def RsubIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4))) + module.forward(tu.randint(3, 4, high=100)) # ============================================================================== @@ -685,7 +685,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: RsubIntModule_noalpha()) def RsubIntModule_noalpha_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4))) + module.forward(tu.randint(3, 4, high=100)) # ============================================================================== @@ -707,7 +707,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule()) def ElementwiseMulScalarModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4))) + module.forward(tu.randint(3, 4, high=10)) # ============================================================================== @@ -751,7 +751,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseMulScalarModule()) def ElementwiseMulScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, high=10).to(torch.int32)) # ============================================================================== @@ -798,7 +798,78 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: ElementwiseMulTensorIntModule()) def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): module.forward( - torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4])) + tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10)) + + +# ============================================================================== + + +class ElementwiseAtan2TensorFloatModule(torch.nn.Module): + + 
def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule()) +def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + + +# ============================================================================== + + +class ElementwiseAtan2TensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.int64, True), + ]) + def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule()) +def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10)) + + +# ============================================================================== + + +class ElementwiseAtan2FloatIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float64, True), + ]) + def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule()) +def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32), + tu.rand(4, 4).double()) # ============================================================================== @@ -842,7 +913,28 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseLogIntModule()) def ElementwiseLogIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + +# ============================================================================== + + +class ElementwiseLog1pModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.log1p(a) + + +@register_test_case(module_factory=lambda: ElementwiseLog1pModule()) +def ElementwiseLog1pModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) # ============================================================================== @@ -886,7 +978,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseErfIntModule()) def ElementwiseErfIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -930,7 +1022,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseSqrtIntModule()) def ElementwiseSqrtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1078,7 +1170,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseLog2IntModule()) def ElementwiseLog2IntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) 
+ module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1122,7 +1214,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule()) def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1188,7 +1280,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule()) def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (4,), dtype=torch.int32)) + module.forward(tu.randint(4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1212,6 +1304,90 @@ def forward(self, x): def ElementwiseDivScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + def forward(self, x): + return torch.remainder(x, 2.0) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float()) +def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils): + module.forward(tu.randint(3, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.remainder(x, 2.0) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float()) +def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 3)) + + +# ============================================================================== + +class ElementwiseRemainderScalarModule_Int(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.remainder(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int()) +def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 2, high=10).to(torch.int32)) + +# ============================================================================== + +class ElementwiseRemainderScalarModule_Bool(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.bool, True), + ]) + def forward(self, x): + return torch.remainder(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool()) +def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False, True, True, True])) + # ============================================================================== @@ -1302,8 +1478,8 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseAndIntegerModule()) def ElementwiseAndIntegerModule_basic(module, tu: TestUtils): module.forward( -
torch.randint(-10, 10, (3, 4)).to(torch.int32), - torch.randint(-10, 10, (3, 4))) + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(3, 4, low=-10, high=10)) # ============================================================================== @@ -1325,7 +1501,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule()) def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, high=10).to(dtype=torch.int32)) # ============================================================================== @@ -1369,7 +1545,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module()) def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4))) + module.forward(tu.randint(3, 4, high=10)) # ============================================================================== @@ -1391,7 +1567,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule()) def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3), dtype=torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) # ============================================================================== @@ -1501,7 +1677,51 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseExpIntModule()) def ElementwiseExpIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseExpm1Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseExpm1Module()) +def ElementwiseExpm1Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseExpm1IntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) +def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1545,7 +1765,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseSinIntModule()) def ElementwiseSinIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1589,7 +1809,7 @@ def forward(self, a): @register_test_case(module_factory=lambda: ElementwiseCosIntModule()) def ElementwiseCosIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # 
============================================================================== @@ -1704,7 +1924,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule()) def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils): - module.forward(torch.randint(3, 10, (2, 3, 4, 5)), torch.randint(10, 100, (2, 3, 4, 5))) + module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100)) # ============================================================================== @@ -1742,7 +1962,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule()) def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils): - module.forward(torch.neg(torch.randint(3, 10, (2, 3, 4, 5))), torch.neg(torch.randint(10, 100, (2, 3, 4, 5)))) + module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100))) # ============================================================================== @@ -1761,7 +1981,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule()) def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils): - module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3))) + module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 3e7f8a79ad3e..89d334628b21 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -45,7 +45,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule()) def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -64,7 +64,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule()) def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -102,7 +102,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule()) def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -121,7 +121,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: 
ElementwiseGeMixedIntScalarModule()) def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -180,7 +180,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule()) def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) + module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) # ============================================================================== @@ -218,7 +218,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule()) def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -238,7 +238,7 @@ def forward(self, x): @register_test_case( module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -276,7 +276,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule()) def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -295,7 +295,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule()) def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -354,7 +354,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule()) def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) + module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) # ============================================================================== @@ -393,7 +393,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule()) def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (5, 8))) + module.forward(tu.randint(5, 8, low=2, high=4)) # ============================================================================== @@ -413,7 +413,7 @@ def forward(self, x): @register_test_case( module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32)) + module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32)) # ============================================================================== @@ -455,7 +455,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule()) def 
ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, ))) + module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4)) # ============================================================================== @@ -494,7 +494,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule()) def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (8, 5))) + module.forward(tu.randint(8, 5, low=2, high=4)) # ============================================================================== @@ -571,4 +571,3 @@ def forward(self): @register_test_case(module_factory=lambda: AllBoolFalseModule()) def AllBoolFalseModule_basic(module, tu: TestUtils): module.forward() - diff --git a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py index 8a176ce00ec1..ca95e36a3844 100644 --- a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py +++ b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -38,7 +38,7 @@ def __init__(self): torch.empty([_num_interval], dtype=torch.float64).fill_(0.0), ) self.register_buffer("_bin_ids", torch.arange(_num_interval)) - self.positive_weight = torch.tensor([0.4]) + self.register_buffer("positive_weight", torch.tensor([0.4])) self.bin_ctr_in_use_after = 0 self.bin_ctr_weight_value = 0.9995 self.oneminusbin_ctr_weight_value = 0.0005 @@ -54,6 +54,9 @@ def __init__(self): def forward(self, segment_value, segment_lengths, logit): origin_prediction = torch.sigmoid( logit + torch.log(self.positive_weight)) + # TODO: If in the future this test is removed from xfail for LTC, we will probably hit some device related + # issues below when new tensors are created on the CPU, which is currently the default behaviour. + # The solution would be to move these tensors to ensure they are on the same device as the existing ones. 
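+ # One possible shape of that fix, sketched here only as a comment (assuming
+ # the new tensors should follow `logit`'s device; not exercised by this
+ # test): pass the device explicitly to the tensor factories, e.g.
+ #   dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32,
+ #                                     device=logit.device)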
dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32) validoffsets = torch.gt( segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits]) @@ -87,17 +90,12 @@ def forward(self, segment_value, segment_lengths, logit): @register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature()) def HBC_basic(module, tu: TestUtils): logits = torch.rand(NUM_LOGITS, dtype=torch.float) - segment_lengths: Tensor = torch.randint( - 0, 2, (NUM_LOGITS,), dtype=torch.int) + segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int) segment_offsets: Tensor = torch.cumsum(segment_lengths, 0) segment_offsets: Tensor = torch.cat( (torch.tensor([0]), segment_offsets), 0) num_values: int = int(torch.sum(segment_lengths).item()) - segment_values: Tensor = torch.randint( - 0, - NUM_SEGMENTS, - (num_values,), - ) + segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS) segment_values = torch.cat( (segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0) module.forward(segment_values.int(), segment_offsets.int(), logits) diff --git a/python/torch_mlir_e2e_test/test_suite/index_put.py b/python/torch_mlir_e2e_test/test_suite/index_put.py index e12b2627caef..4fb74511b7e1 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_put.py +++ b/python/torch_mlir_e2e_test/test_suite/index_put.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== @@ -34,7 +34,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule()) def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module): @@ -59,7 +59,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule()) def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): @@ -84,7 +84,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule()) def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -113,8 +113,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule()) def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) # 
============================================================================== @@ -142,7 +142,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl1DFloatAccumulateModule()) def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module): @@ -167,7 +167,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl2DFloatAccumulateModule()) def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module): @@ -192,7 +192,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutImpl3DFloatAccumulateModule()) def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -220,8 +220,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule()) def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) # ============================================================================== @@ -248,7 +248,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPut1DFloatNonAccumulateModule()) def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPut2DFloatNonAccumulateModule(torch.nn.Module): @@ -272,7 +272,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPut2DFloatNonAccumulateModule()) def IndexPut2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPut3DFloatNonAccumulateModule(torch.nn.Module): @@ -296,7 +296,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPut3DFloatNonAccumulateModule()) def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -323,8 +323,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule()) def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) class IndexPut2DIntNonAccumulateModule(torch.nn.Module): @@ -347,8 +347,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut2DIntNonAccumulateModule()) def 
IndexPut2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPut3DIntNonAccumulateModule(torch.nn.Module): @@ -371,8 +371,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut3DIntNonAccumulateModule()) def IndexPut3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -398,7 +398,7 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule()) def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPut2DFloatAccumulateModule(torch.nn.Module): @@ -421,7 +421,7 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut2DFloatAccumulateModule()) def IndexPut2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPut3DFloatAccumulateModule(torch.nn.Module): @@ -444,7 +444,7 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule()) def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -471,8 +471,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule()) def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) class IndexPut2DIntAccumulateModule(torch.nn.Module): @@ -495,8 +495,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut2DIntAccumulateModule()) def IndexPut2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPut3DIntAccumulateModule(torch.nn.Module): @@ -519,8 +519,8 @@ def forward(self, input, index, value): @register_test_case(module_factory=lambda: IndexPut3DIntAccumulateModule()) def IndexPut3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -548,7 +548,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DFloatNonAccumulateModule()) def 
IndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module): @@ -572,7 +572,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DFloatNonAccumulateModule()) def IndexPutHackedTwin2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module): @@ -596,7 +596,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule()) def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -624,8 +624,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DIntNonAccumulateModule()) def IndexPutHackedTwin1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module): @@ -649,8 +649,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DIntNonAccumulateModule()) def IndexPutHackedTwin2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module): @@ -674,8 +674,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DIntNonAccumulateModule()) def IndexPutHackedTwin3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -700,7 +700,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DFloatAccumulateModule()) def IndexPutHackedTwin1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module): @@ -722,7 +722,7 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DFloatAccumulateModule()) def IndexPutHackedTwin2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module): @@ -744,7 +744,7 @@ def forward(self, input, index, 
value): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule()) def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -770,8 +770,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DIntAccumulateModule()) def IndexPutHackedTwin1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module): @@ -793,8 +793,8 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DIntAccumulateModule()) def IndexPutHackedTwin2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module): @@ -816,5 +816,5 @@ def forward(self, input, index, value): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DIntAccumulateModule()) def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) diff --git a/python/torch_mlir_e2e_test/test_suite/index_select.py b/python/torch_mlir_e2e_test/test_suite/index_select.py index c4693cc1a42c..1bb575ccb283 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_select.py +++ b/python/torch_mlir_e2e_test/test_suite/index_select.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/matmul.py b/python/torch_mlir_e2e_test/test_suite/matmul.py index 5b9502caec11..e1ecfa6a3b0f 100644 --- a/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -5,9 +5,9 @@ import torch -from torch_mlir_e2e_test.torchscript.framework import TestUtils -from torch_mlir_e2e_test.torchscript.registry import register_test_case -from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/mlp.py b/python/torch_mlir_e2e_test/test_suite/mlp.py index 153f357590f2..faebcadf30f6 100644 --- a/python/torch_mlir_e2e_test/test_suite/mlp.py +++ b/python/torch_mlir_e2e_test/test_suite/mlp.py @@ -6,9 +6,9 @@ import 
diff --git a/python/torch_mlir_e2e_test/test_suite/mlp.py b/python/torch_mlir_e2e_test/test_suite/mlp.py
index 153f357590f2..faebcadf30f6 100644
--- a/python/torch_mlir_e2e_test/test_suite/mlp.py
+++ b/python/torch_mlir_e2e_test/test_suite/mlp.py
@@ -6,9 +6,9 @@
 import torch
 import torch.nn as nn
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/python/torch_mlir_e2e_test/test_suite/nll_loss.py
index edbb8f444fda..f5eeb1f2c009 100644
--- a/python/torch_mlir_e2e_test/test_suite/nll_loss.py
+++ b/python/torch_mlir_e2e_test/test_suite/nll_loss.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -34,7 +34,7 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: NllLossModule())
 def NllLossModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
 class NllLossModule_mean(torch.nn.Module):
@@ -58,7 +58,7 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: NllLossModule_mean())
 def NllLossModule_mean_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
 class NllLossModule_sum(torch.nn.Module):
@@ -82,7 +82,7 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: NllLossModule_sum())
 def NllLossModule_sum_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
 class NllLossModule_1D(torch.nn.Module):
@@ -106,7 +106,7 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: NllLossModule_1D())
 def NllLossModule_1D_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3), torch.randint(0, 3, ()))
+    module.forward(tu.rand(3), tu.randint(high=3))
 
 
 class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
@@ -131,7 +131,7 @@ def forward(self, x, y):
 
 @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
 def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
 class NllLossModule_backward(torch.nn.Module):
diff --git a/python/torch_mlir_e2e_test/test_suite/norm_like.py b/python/torch_mlir_e2e_test/test_suite/norm_like.py
index f8670654a3ed..e8b006c30efb 100644
--- a/python/torch_mlir_e2e_test/test_suite/norm_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py
index 26a18b0dfe14..36ad293605a2 100644
--- a/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -489,7 +489,7 @@ def forward(self, output, input, indices):
     module_factory=lambda: MaxPool2dWithIndicesBackwardStatic4DModule())
 def MaxPool2dWithIndicesBackwardStatic4DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5),
-                   torch.randint(16, (2, 4, 7, 6)))
+                   tu.randint(2, 4, 7, 6, high=16))
 
 
 class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module):
@@ -519,7 +519,7 @@ def forward(self, output, input, indices):
     module_factory=lambda: MaxPool2dWithIndicesBackwardStatic3DModule())
 def MaxPool2dWithIndicesBackwardStatic3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 7, 6), tu.rand(4, 6, 5),
-                   torch.randint(16, (4, 7, 6)))
+                   tu.randint(4, 7, 6, high=16))
 
 
 class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
@@ -549,7 +549,7 @@ def forward(self, output, input, indices):
     module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic4DModule())
 def MaxPool2dWithIndicesBackwardDynamic4DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5),
-                   torch.randint(16, (2, 4, 7, 6)))
+                   tu.randint(2, 4, 7, 6, high=16))
 
 
 class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
@@ -579,7 +579,7 @@ def forward(self, output, input, indices):
     module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic3DModule())
 def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5),
-                   torch.randint(16, (2, 7, 6)))
+                   tu.randint(2, 7, 6, high=16))
 
 # ==============================================================================
@@ -632,7 +632,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: AvgPool2dIntModule())
 def AvgPool2dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (2, 4, 20, 20)))
+    module.forward(tu.randint(2, 4, 20, 20, high=100))
 
 
 class AvgPool2dStaticModule(torch.nn.Module):
diff --git a/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/python/torch_mlir_e2e_test/test_suite/quantized_models.py
index 7080a25fea16..e4a118700aa1 100644
--- a/python/torch_mlir_e2e_test/test_suite/quantized_models.py
+++ b/python/torch_mlir_e2e_test/test_suite/quantized_models.py
@@ -6,9 +6,9 @@
 import torch
 from torch import nn
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py
index be1ed5078485..1eecb5186fd7 100644
--- a/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -106,6 +106,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumDimIntListEmptyDimModule())
+def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
 class ReduceSumUnsignedIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -121,7 +140,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumUnsignedIntModule())
 def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(0, 100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, low=0, high=100))
 
 # ==============================================================================
@@ -140,7 +159,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumSignedIntModule())
 def ReduceSumSignedIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, low=-100, high=100))
 
 # ==============================================================================
@@ -159,7 +178,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumDtypeIntModule())
 def ReduceSumDtypeIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32))
+    module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
 
 # ==============================================================================
@@ -178,7 +197,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumDimIntListIntModule())
 def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, high=100))
 
 # ==============================================================================
@@ -197,7 +216,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeIntModule())
 def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32))
+    module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
 
 # ==============================================================================
@@ -216,7 +235,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
 def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, high=100))
 
 # ==============================================================================
@@ -364,7 +383,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceMaxSignedIntModule())
 def ReduceMaxSignedIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, low=-100, high=100))
 
 # ==============================================================================
@@ -382,7 +401,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
 def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(100, (3, 4, 5)))
+    module.forward(tu.randint(3, 4, 5, high=100))
 
 # ==============================================================================
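The new ReduceSumDimIntListEmptyDimModule pins down an edge case of aten.sum.dim_IntList: on the PyTorch builds this suite targets, an empty dim list reduces over every dimension rather than being a no-op, so the expected result is a rank-0 tensor. A quick way to observe the behavior being locked in:

import torch

a = torch.rand(3, 4, 5)
# Prints torch.Size([]): dim=[] behaves like a full reduction here.
print(torch.sum(a, dim=[]).shape)

The MeanDimEmptyDimModule, StdDimEmptyDimModule, and VarDimEmptyDimModule tests added further down exercise the same convention for mean, std, and var.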
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index c99f32a937ec..a8bdc5859604 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -4,9 +4,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -55,15 +55,15 @@ def __init__(self):
     @export
     @annotate_args([
         None,
-        ([1, 3], torch.float32, True),
+        ([2, 1, 16, 1, 1], torch.float32, True),
     ])
     def forward(self, a):
-        return a.view(1, 1, 3, 1, 1)
+        return a.view(1, 2, 1, 16, 1, 1, 1, 1)
 
 
 @register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
 def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 3))
+    module.forward(tu.rand(2, 1, 16, 1, 1))
 
 # ==============================================================================
@@ -164,6 +164,82 @@ def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ViewExpandCollapseWithOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 8, 8], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, 1, 4, 64)
+
+@register_test_case(module_factory=lambda: ViewExpandCollapseWithOnesModule())
+def ViewExpandCollapseWithOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
+class ViewExpandCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 8, 16, 4], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(8, 2, 4, 16, 2, 2)
+
+@register_test_case(module_factory=lambda: ViewExpandCollapseModule())
+def ViewExpandCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 16, 4))
+
+# ==============================================================================
+
+class ViewDynamicExpandCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 4, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, 4, 64)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseModule())
+def ViewDynamicExpandCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
+class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, a.size(1), 64)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
+def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
 class View1DFoldModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -485,4 +561,3 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
 def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4))
-
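The View*ExpandCollapse* modules added above exercise torch.Tensor.view calls that split some dimensions and merge others in the same call, which is the case the shape-transfer code has to associate carefully. The element-count bookkeeping for the WithOnes variant:

import torch

a = torch.rand(2, 4, 8, 8)
# 2 and 4 are kept, two size-1 dims are inserted, and 8*8 collapses to 64;
# the element count (512) is unchanged, as view requires.
b = a.view(2, 1, 1, 4, 64)
assert b.numel() == a.numel() == 512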
diff --git a/python/torch_mlir_e2e_test/test_suite/return_types.py b/python/torch_mlir_e2e_test/test_suite/return_types.py
index 942f9bbd9cdd..2acfacf9850e 100644
--- a/python/torch_mlir_e2e_test/test_suite/return_types.py
+++ b/python/torch_mlir_e2e_test/test_suite/return_types.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py
index ea13eadaf95f..dcbea55dd89a 100644
--- a/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -1,8 +1,8 @@
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py
index f79ebc206c30..cbf17a09ac9d 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -29,7 +29,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: AddIntModule())
 def AddIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -52,7 +52,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: SubIntModule())
 def SubIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -98,7 +98,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: MulIntModule())
 def MulIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -172,7 +172,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: SqrtIntModule())
 def SqrtIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, ()))
+    module.forward(tu.randint(high=10))
 
 
 class SqrtIntConstantModule(torch.nn.Module):
@@ -273,7 +273,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: BoolIntFalseModule())
 def BoolIntFalseModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(1, 100, ()))
+    module.forward(tu.randint(low=1, high=100))
 
 
 class BoolIntTrueModule(torch.nn.Module):
@@ -292,7 +292,7 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: BoolIntTrueModule())
 def BoolIntTrueModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(1, 100, ()))
+    module.forward(tu.randint(low=1, high=100))
 
 
 class BoolIntConstantModule(torch.nn.Module):
diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py
index 8a626d9625a0..d9d0bd121c64 100644
--- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py
+++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -29,7 +29,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: NeIntModule())
 def NeIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -52,7 +52,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: EqIntModule())
 def EqIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -75,7 +75,30 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: GtIntModule())
 def GtIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
+
+
+# ==============================================================================
+
+
+class GeIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.ops.aten.ge(int(lhs), int(rhs))
+
+
+@register_test_case(module_factory=lambda: GeIntModule())
+def GeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -121,7 +144,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: GeFloatIntModule())
 def GeFloatIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
+    module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -144,7 +167,7 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: NeFloatIntModule())
 def NeFloatIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
+    module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
 
 # ==============================================================================
@@ -167,4 +190,4 @@ def forward(self, lhs, rhs):
 
 @register_test_case(module_factory=lambda: GtFloatIntModule())
 def GtFloatIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
+    module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
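GeIntModule fills the gap between the existing ne/eq/gt scalar tests and pairs with the new torch.aten.ge.int lowering checked in test/Conversion/TorchToArith/basic.mlir below, where the op becomes a signed arith.cmpi sge. In eager mode, which the test framework uses to compute golden values, the op is ordinary integer comparison:

import torch

# aten.ge on Python ints returns a plain bool, like the builtin comparison.
assert torch.ops.aten.ge(3, -5) == (3 >= -5)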
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 3a56826ff3bc..d4dc44ea9209 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -46,7 +46,7 @@ def forward(self, x):
         result = x[:8, :5, 8:]
         cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
         return torch.cat((result,cat_tensor), dim=2)
-    
+
 
 @register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
 def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
@@ -229,7 +229,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: SelectIntModule())
 def SelectIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (5,5)))
+    module.forward(tu.randint(5,5, high=10))
 
 # ==============================================================================
@@ -270,6 +270,31 @@ def forward(self, x, src):
 def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(6, 8), tu.rand(1, 8))
 
+
+class SliceScatterNegativeDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, src):
+        return torch.ops.aten.slice_scatter(x,
+                                            src,
+                                            dim=-2,
+                                            start=0,
+                                            end=1,
+                                            step=1)
+
+
+@register_test_case(module_factory=lambda: SliceScatterNegativeDimModule())
+def SliceScatterNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 8), tu.rand(1, 8))
+
 
 class SliceScatterStepVariationModule(torch.nn.Module):
 
     def __init__(self):
@@ -341,3 +366,81 @@ def forward(self, x, src):
 
 @register_test_case(module_factory=lambda: SelectScatterStaticModule())
 def SelectScattertStaticModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
+
+# ==============================================================================
+
+class NarrowHorizontalTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
+
+
+@register_test_case(module_factory=lambda: NarrowHorizontalTest())
+def NarrowHorizontalTest_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4,3))
+
+# ==============================================================================
+
+
+class NarrowVerticalTest(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=0, length=2)
+
+
+@register_test_case(module_factory=lambda: NarrowVerticalTest())
+def NarrowVerticalTest_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4,3))
+
+# ==============================================================================
+
+class NarrowHorizontalTest2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
+
+
+@register_test_case(module_factory=lambda: NarrowHorizontalTest2())
+def NarrowHorizontalTest2_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+
+class NarrowVerticalTest2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=0, length=2)
+
+
+@register_test_case(module_factory=lambda: NarrowVerticalTest2())
+def NarrowVerticalTest2_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
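The Narrow* tests above cover aten.narrow through both the torch.narrow wrapper and torch.ops.aten.narrow, on 2-D and 3-D inputs. narrow(x, dim, start, length) returns a view of `length` elements along `dim`, beginning at `start`:

import torch

x = torch.arange(24.0).view(6, 4)
assert torch.narrow(x, dim=0, start=0, length=2).shape == (2, 4)  # rows
assert torch.narrow(x, dim=1, start=0, length=2).shape == (6, 2)  # columns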
diff --git a/python/torch_mlir_e2e_test/test_suite/squeeze.py b/python/torch_mlir_e2e_test/test_suite/squeeze.py
index 0fa63601da5e..04bf6e97cbd7 100644
--- a/python/torch_mlir_e2e_test/test_suite/squeeze.py
+++ b/python/torch_mlir_e2e_test/test_suite/squeeze.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py
index 550b96fe0547..157b0c3f6974 100644
--- a/python/torch_mlir_e2e_test/test_suite/stats.py
+++ b/python/torch_mlir_e2e_test/test_suite/stats.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -97,7 +97,7 @@ def __init__(self):
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.mean(x, 0, dtype=torch.float32)
+        return torch.ops.aten.mean(x, (0,), dtype=torch.float32)
 
 
 @register_test_case(module_factory=lambda: MeanDimDtypeModule())
@@ -180,6 +180,45 @@ def forward(self, x):
 def MeanDimNegativeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
+
+# ==============================================================================
+
+class MeanDimEmptyDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, dim=[])
+
+
+@register_test_case(module_factory=lambda: MeanDimEmptyDimModule())
+def MeanDimEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class MeanDimNoneDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, dim=None)
+
+
+@register_test_case(module_factory=lambda: MeanDimNoneDimModule())
+def MeanDimNoneDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
 # ==============================================================================
 
 class VarUnbiasedModule(torch.nn.Module):
@@ -256,6 +295,116 @@ def StdBiasedModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class StdDimKeepDimFalseModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=(1, 2), keepdim=False)
+
+
+@register_test_case(module_factory=lambda: StdDimKeepDimFalseModule())
+def StdDimKeepDimFalseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class StdDimKeepDimTrueModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True)
+
+
+@register_test_case(module_factory=lambda: StdDimKeepDimTrueModule())
+def StdDimKeepDimTrueModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class StdDimBiasedModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=(0, 2), unbiased=False)
+
+
+@register_test_case(module_factory=lambda: StdDimBiasedModule())
+def StdDimBiasedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class StdDimEmptyDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=[], keepdim=False)
+
+
+@register_test_case(module_factory=lambda: StdDimEmptyDimModule())
+def StdDimEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class StdDimNoneDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=None, keepdim=False)
+
+
+@register_test_case(module_factory=lambda: StdDimNoneDimModule())
+def StdDimNoneDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
 class VarDimModule(torch.nn.Module):
 
     def __init__(self):
@@ -311,7 +460,7 @@ def __init__(self):
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=0, unbiased=False, keepdim=True)
+        return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
 
 
 @register_test_case(module_factory=lambda: VarDimBiasedModule())
@@ -333,7 +482,7 @@ def __init__(self):
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=0, keepdim=True)
+        return torch.ops.aten.var(x, dim=(0,), keepdim=True)
 
 
 @register_test_case(module_factory=lambda: VarDimSingleDimModule())
@@ -410,7 +559,7 @@ def VarDimNegativeModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class VarDimKeepDimFalseModule(torch.nn.Module):
+class VarDimEmptyDimModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -421,9 +570,188 @@ def __init__(self):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=False)
+        return torch.ops.aten.var(x, dim=[], keepdim=False)
 
 
-@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule())
-def VarDimKeepDimFalseModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: VarDimEmptyDimModule())
+def VarDimEmptyDimModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class VarDimNoneDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=None, keepdim=False)
+
+
+@register_test_case(module_factory=lambda: VarDimNoneDimModule())
+def VarDimNoneDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class VarCorrectionModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=None, correction=2)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionModule())
+def VarCorrectionModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionSingleDimReduceModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=[1], correction=1)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule())
+def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionAllDimReduceModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x,
+                                  dim=[0, 1, 2],
+                                  correction=10,
+                                  keepdim=False)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule())
+def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionKeepDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionKeepDimModule())
+def VarCorrectionKeepDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionNoneModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=None, correction=None)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionNoneModule())
+def VarCorrectionNoneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionEmptyDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=[], correction=2)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionEmptyDimModule())
+def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+
+# ==============================================================================
+
+
+class VarCorrectionLargeInputModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=[2, 3], correction=2)
+
+
+@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
+def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
+    module.forward(100 + tu.rand(3, 4, 1024, 8192))
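The VarCorrection* modules target the aten.var.correction overload, whose divisor is N - correction: correction=None is treated as 1 (the unbiased estimate), and correction=0 matches unbiased=False. The equivalence the simplest of these tests relies on, in plain PyTorch:

import torch

x = torch.rand(3, 4, 7)
v = torch.ops.aten.var(x, dim=[1], correction=1)
assert torch.allclose(v, x.var(dim=1, unbiased=True))

VarCorrectionLargeInputModule additionally feeds in 100 + tu.rand(3, 4, 1024, 8192), presumably to flush out numerically unstable single-pass variance implementations, where the large mean cancels catastrophically.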
diff --git a/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py b/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py
index 1f74c9dc8d9c..1ed41ffc165a 100644
--- a/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py
+++ b/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -54,8 +54,7 @@ def forward(self, indices, offsets):
 
 @register_test_case(module_factory=lambda: TableBatchEmbeddingModule())
 def TableBatchEmbeddingModule_basic(module, tu: TestUtils):
-    indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
+    indices = tu.randint(NUM_TABLES * BATCH_SIZE * BAG_SIZE, high=NUM_EMBEDDINGS)
     offsets = torch.cumsum(
         torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)],
                      dtype=torch.int64), 0)
     module.forward(indices, offsets)
-
diff --git a/python/torch_mlir_e2e_test/test_suite/threshold.py b/python/torch_mlir_e2e_test/test_suite/threshold.py
index d7a34a89abd3..8efa7e7e2731 100644
--- a/python/torch_mlir_e2e_test/test_suite/threshold.py
+++ b/python/torch_mlir_e2e_test/test_suite/threshold.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -27,7 +27,7 @@ def forward(self, input):
 
 @register_test_case(module_factory=lambda: Threshold1dIntI32Module())
 def Threshold1dIntI32Module_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4,), dtype=torch.int32))
+    module.forward(tu.randint(4, high=10).to(torch.int32))
 
 
 class Threshold1dIntModule(torch.nn.Module):
@@ -45,7 +45,7 @@ def forward(self, input):
 
 @register_test_case(module_factory=lambda: Threshold1dIntModule())
 def Threshold1dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4,)))
+    module.forward(tu.randint(4, high=10))
 
 
 class Threshold2dIntModule(torch.nn.Module):
@@ -63,7 +63,7 @@ def forward(self, input):
 
 @register_test_case(module_factory=lambda: Threshold2dIntModule())
 def Threshold2dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4, 5)))
+    module.forward(tu.randint(4, 5, high=10))
 
 
 class Threshold3dIntModule(torch.nn.Module):
@@ -81,7 +81,7 @@ def forward(self, input):
 
 @register_test_case(module_factory=lambda: Threshold3dIntModule())
 def Threshold3dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4, 5, 6)))
+    module.forward(tu.randint(4, 5, 6, high=10))
 
 
 class Threshold1dFloatModule(torch.nn.Module):
@@ -154,7 +154,7 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
 def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
+    module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
 
 
 class ThresholdBackward2dIntModule(torch.nn.Module):
@@ -173,7 +173,7 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
 def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
+    module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
 
 
 class ThresholdBackward3dIntModule(torch.nn.Module):
@@ -192,7 +192,7 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
 def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
+    module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
 
 
 class ThresholdBackward1dFloatModule(torch.nn.Module):
@@ -268,7 +268,7 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
 def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(4), torch.randint(10, (4,)))
+    module.forward(torch.randn(4), tu.randint(4, high=10))
 
 
 class ThresholdBackward2dMixedModule(torch.nn.Module):
@@ -287,7 +287,7 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
 def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
+    module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5))
 
 
 class ThresholdBackward3dMixedModule(torch.nn.Module):
@@ -306,4 +306,4 @@ def forward(self, grad, input):
 
 @register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
 def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
+    module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10))
diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
index 0d455c3b4a6c..cf41631f56b8 100644
--- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py
+++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -57,7 +57,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI32ToI64Module())
 def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
-    module.forward(torch.randint(5, [2, 3]).type(torch.int32))
+    module.forward(tu.randint(2, 3, high=5).type(torch.int32))
 
 
 class TypeConversionI64ToI32Module(torch.nn.Module):
@@ -73,7 +73,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI64ToI32Module())
 def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
-    module.forward(torch.randint(5, [2, 3]))
+    module.forward(tu.randint(2, 3, high=5))
 
 
 class TypeConversionI1ToI32Module(torch.nn.Module):
@@ -89,7 +89,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI1ToI32Module())
 def TypeConversionI1ToI32Module_basic(module, tu: TestUtils):
-    tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
+    tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
     module.forward(tensor)
@@ -106,7 +106,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI1ToI64Module())
 def TypeConversionI1ToI64Module_basic(module, tu: TestUtils):
-    tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
+    tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
     module.forward(tensor)
@@ -123,7 +123,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI1ToF32Module())
 def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
-    tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
+    tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
     module.forward(tensor)
@@ -140,7 +140,7 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: TypeConversionI1ToF64Module())
 def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
-    tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
+    tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
     module.forward(tensor)
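The type_promotion suite that follows checks the two axes of PyTorch's promotion lattice named by the changed tests: within a category the wider dtype wins, and across categories floating point beats integer. Both rules in a few lines:

import torch

a = torch.randint(10, (4,), dtype=torch.int32)
b = torch.randint(10, (4,))                          # int64
assert (a + b).dtype == torch.int64                  # same category: wider width
assert (b + torch.randn(4)).dtype == torch.float32   # different category: float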
diff --git a/python/torch_mlir_e2e_test/test_suite/type_promotion.py b/python/torch_mlir_e2e_test/test_suite/type_promotion.py
index a7a5491c5b2e..f2ff36fd8b42 100644
--- a/python/torch_mlir_e2e_test/test_suite/type_promotion.py
+++ b/python/torch_mlir_e2e_test/test_suite/type_promotion.py
@@ -5,9 +5,9 @@
 
 import torch
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
@@ -30,8 +30,8 @@ def forward(self, a, b):
     module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
 def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
     module.forward(
-        torch.randint(10, [4]).type(torch.int32),
-        torch.randint(10, [4]))
+        tu.randint(4, high=10).type(torch.int32),
+        tu.randint(4, high=10))
 
 
 class TypePromotionDifferentCategoryModule(torch.nn.Module):
@@ -51,7 +51,7 @@ def forward(self, a, b):
 
 @register_test_case(
     module_factory=lambda: TypePromotionDifferentCategoryModule())
 def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, [4]), torch.randn(4))
+    module.forward(tu.randint(4, high=10), torch.randn(4))
 
 
 class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
@@ -91,7 +91,7 @@ def forward(self, a, b):
 
 @register_test_case(
     module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
 def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, [4]), tu.rand())
+    module.forward(tu.randint(4, high=10), tu.rand())
 
 
 class TypePromotionAlphaWiderModule(torch.nn.Module):
@@ -111,5 +111,3 @@ def forward(self, a, b):
 
 @register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
 def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
-
-
diff --git a/python/torch_mlir_e2e_test/test_suite/vision_models.py b/python/torch_mlir_e2e_test/test_suite/vision_models.py
index c06f1d9188fd..43e81847e207 100644
--- a/python/torch_mlir_e2e_test/test_suite/vision_models.py
+++ b/python/torch_mlir_e2e_test/test_suite/vision_models.py
@@ -6,9 +6,9 @@
 import torch
 import torchvision.models as models
 
-from torch_mlir_e2e_test.torchscript.framework import TestUtils
-from torch_mlir_e2e_test.torchscript.registry import register_test_case
-from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
 
 # ==============================================================================
diff --git a/python/torch_mlir_e2e_test/torchscript/CMakeLists.txt b/python/torch_mlir_e2e_test/torchscript/CMakeLists.txt
deleted file mode 100644
index 9316fdb0944a..000000000000
--- a/python/torch_mlir_e2e_test/torchscript/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-## Declare the sources of the Python module.
-
-declare_mlir_python_sources(TorchMLIRPythonSources.TorchScriptE2ETest
-  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  ADD_TO_PARENT TorchMLIRPythonSources
-  SOURCES_GLOB
-    dialects/torch/e2e_test/torchscript/*.py
-)
diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
index eb98371fb754..eefb492f0a45 100644
--- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
+++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
@@ -5,8 +5,6 @@
 
 from torch_mlir.ir import *
 from torch_mlir.passmanager import *
-# Imported for side effects.
-import torch_mlir.all_passes_registration
 
 from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
diff --git a/python/torch_mlir_e2e_test/utils.py b/python/torch_mlir_e2e_test/utils.py
index e69de29bb2d1..403c455cba64 100644
--- a/python/torch_mlir_e2e_test/utils.py
+++ b/python/torch_mlir_e2e_test/utils.py
@@ -0,0 +1,22 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from torch_mlir import TensorPlaceholder
+from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
+
+def convert_annotations_to_placeholders(forward_method):
+    """Converts the annotations on a forward method into tensor placeholders.
+
+    These placeholders are suitable for being passed to `torch_mlir.compile`.
+    """
+    annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
+    placeholders = []
+    # Skip the "self" annotation.
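+    # Each annotation is a (shape, dtype, has_value_semantics) tuple recorded
+    # by the `annotate_args` decorator, with `None` standing in for `self`.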
+    for annotation in annotations[1:]:
+        if not annotation[2]:
+            raise ValueError(
+                "Can only compile inputs annotated as having value semantics.")
+        placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
+    return placeholders
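A sketch of how the new helper is meant to be used, per its docstring; the Square module here is illustrative and not part of the patch:

import torch
import torch_mlir
from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders

class Square(torch.nn.Module):
    @export
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        return x * x

# The placeholders carry the annotated shapes/dtypes into the compiler.
placeholders = convert_annotations_to_placeholders(Square.forward)
compiled = torch_mlir.compile(Square(), placeholders)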
-> !torch.vtensor<[?,?],bf16> return %0 : !torch.vtensor<[?,?],bf16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.neg.f16 +// CHECK: linalg.generic {{.*}} { +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %{{.*}}: f16): +// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f16 +// CHECK-NEXT: linalg.yield %[[NEG]] : f16 +// CHECK-NEXT: } -> tensor +func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> { + %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> + return %0 : !torch.vtensor<[?,?],f16> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor +// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list> +// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor +// CHECK: %[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor +// CHECK: %[[CST0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor +// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor +// CHECK: %[[CST1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor +// CHECK: %[[OUT_T:.*]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] : tensor +// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor, tensor) outs(%[[OUT_T]] : tensor) { +// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32): +// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IN1]] : i64 to index +// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index +// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index +// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor +// CHECK: linalg.yield %[[RESULT]] : f32 +// CHECK: } -> tensor +// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor to tensor +// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> { + %none = torch.constant.none + %1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list> + %2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list> -> !torch.vtensor<[?,?,?],f32> + return %2 : !torch.vtensor<[?,?,?],f32> +} diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir new file mode 100644 index 000000000000..ae505146d5b7 --- /dev/null +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -0,0 +1,297 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + + +// ----- + +// 
CHECK-LABEL: func.func @torch.aten.clone$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = mhlo.copy %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %none = torch.constant.none + %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> +func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { + %0 = torch.vtensor.literal(dense<0.0> : tensor) : !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> +func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { + %0 = torch.vtensor.literal(dense<1> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic( +// CHECK-SAME: ) -> !torch.vtensor<[],si64> { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> +// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],si64> +// CHECK: return %[[T4]] : !torch.vtensor<[],si64> +func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> + return %0 : !torch.vtensor<[], si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.contiguous( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reciprocal( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> 
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
+func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.transpose$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
+func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.broadcast_to$dynamic_implicit(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int-1 = torch.constant.int -1
+// CHECK: %int4 = torch.constant.int 4
+// CHECK: %int8 = torch.constant.int 8
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int8, %int4, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %int8
+// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_3:.*]] : i64 to index
+// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %int4
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
+// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
+// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
+// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
+func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
+  %int-1 = torch.constant.int -1
+  %int4 = torch.constant.int 4
+  %int8 = torch.constant.int 8
+  %0 = torch.prim.ListConstruct %int8, %int4, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[8,4,?],f32>
+  return %1 : !torch.vtensor<[8,4,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %true = torch.constant.bool true
+// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
+// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
+// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+  %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %true = torch.constant.bool true
+  %float1.000000e-01 = torch.constant.float 1.000000e-01
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32>
+  return %2 : !torch.vtensor<[?,3,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %true = torch.constant.bool true
+// CHECK: %false = torch.constant.bool false
+// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
+// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
+// CHECK: %[[VAL_7:.*]] = "mhlo.batch_norm_inference"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+  %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %true = torch.constant.bool true
+  %false = torch.constant.bool false
+  %float1.000000e-01 = torch.constant.float 1.000000e-01
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32>
+  return %2 : !torch.vtensor<[?,3,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.batch_norm$no_bias_weight(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
+// CHECK: %none = torch.constant.none
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %true = torch.constant.bool true
+// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
+// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
+// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
+// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
+// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
+// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+  %none = torch.constant.none
+  %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+  %true = torch.constant.bool true
+  %float1.000000e-01 = torch.constant.float 1.000000e-01
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %2 = torch.aten.batch_norm %arg0, %none, %none, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.none, !torch.none, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32>
+  return %2 : !torch.vtensor<[?,3,?,?],f32>
+}
+
+// -----
+
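+// native_layer_norm is lowered by flattening the normalized dimensions into a
+// single feature dimension, reusing mhlo.batch_norm_training on the reshaped
+// tensor, then reshaping the results back and applying the affine weight/bias.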
+// CHECK-LABEL: func @torch.aten.native_layer_norm(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32>
+// CHECK: %int4 = torch.constant.int 4
+// CHECK: %int5 = torch.constant.int 5
+// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %true = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
+// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
+// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
+// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
+// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
+// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
+// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
+// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
+// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
+// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
+// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
+func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
+  %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32>
+  %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32>
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %true = torch.constant.bool true
+  %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
+  return %result0 : !torch.vtensor<[3,7,4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cat$convert(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_3]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
+  %1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %1 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cat(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
+  %1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %1 : !torch.vtensor<[?,?],f32>
+}
diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir
new file mode 100644
index 000000000000..b61a61b3bf83
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/dropout.mlir
@@ -0,0 +1,47 @@
+// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
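+
+// native_dropout in training mode lowers to a uniform mhlo.rng sample compared
+// against the keep probability (1 - p); the resulting mask is applied to the
+// input via a broadcasted multiplication, as the checks below record.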
"mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor) -> tensor +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor +// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64 +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor +// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor -> tensor<2xindex> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor) -> tensor +// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor -> tensor<2xindex> +// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor -> tensor<2xindex> +// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> +// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor) { +// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> +// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor +// CHECK: shape.assuming_yield %[[T19]] : tensor +// CHECK: } +// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor +// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor -> tensor<2xindex> +// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor +// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor, tensor) -> tensor +// CHECK: return %[[T24]], %[[T26]] : tensor, tensor +func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) { + %bool_true = torch.constant.bool true + %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> + return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> +} \ No newline at end of file diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir new file mode 100644 index 000000000000..ae41c3fd65dc --- /dev/null +++ b/test/Conversion/TorchToMhlo/elementwise.mlir @@ -0,0 +1,597 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file 
+
+// CHECK-LABEL: func.func @torch.aten.gelu(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "none"
+// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32>
+// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
+// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %str = torch.constant.str "none"
+  %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.tanh$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.log$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.exp$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.neg$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_add %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK: %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T8:.*]] = mhlo.reshape %[[T7]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[T10:.*]] = chlo.broadcast_add %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
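+// For add.Tensor/sub.Tensor, a non-unit alpha is multiplied into the second
+// operand before the chlo broadcast op, i.e. out = lhs + alpha * rhs.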
+// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = chlo.broadcast_add %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK: %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T8:.*]] = mhlo.reshape %[[T7]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[T10:.*]] = chlo.broadcast_subtract %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.sub.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = chlo.broadcast_subtract %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = chlo.broadcast_subtract %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.multensor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_multiply %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.scalar(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT3:.*]] = torch.constant.int 3
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.tensor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.lt.tensor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.eq.tensor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.ne.tensor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.permute$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
+func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[64,4],f32>
+  return %1 : !torch.vtensor<[64,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.relu(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$variable(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
+// CHECK: %[[T6:.*]] = mhlo.convert(%[[T5]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T7:.*]] = mhlo.reshape %[[T6]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T8:.*]] = chlo.broadcast_multiply %[[T4]], %[[T7]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[T9:.*]] = chlo.broadcast_add %[[T0]], %[[T8]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.add.Scalar %arg0, %arg1, %arg1: !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$variable(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG2:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
+// CHECK: %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.float) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.mulscalar$variable(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.mul.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divscalar$variable(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.div.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.scalar$variable(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.gt.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
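+// div.Tensor_mode: "trunc" rounds the quotient toward zero
+// (sign * floor(abs)), while "floor" applies mhlo.floor directly.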
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "trunc"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "floor"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "floor"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
diff --git a/test/Conversion/TorchToMhlo/gather.mlir b/test/Conversion/TorchToMhlo/gather.mlir
new file mode 100644
index 000000000000..a20b32d4994d
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/gather.mlir
@@ -0,0 +1,66 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
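+
+// index_select and embedding both lower to mhlo.dynamic_gather, with slice
+// sizes computed from the runtime dimensions of the input tensor.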
"mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32> +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> +func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.index_select %arg0, %int0, %arg1 : !torch.vtensor<[?,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// CHECK-LABEL: func.func @torch.aten.embedding$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],si64> -> tensor +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> { + %false = torch.constant.bool false + %int-1 = torch.constant.int -1 + %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,?],f32> + return %ret: !torch.vtensor<[?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.embedding$rank_two_indices( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32> +// CHECK: return %[[T7]] : 
!torch.vtensor<[?,1,?],f32> +func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> { + %false = torch.constant.bool false + %int-1 = torch.constant.int -1 + %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,1], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,1,?],f32> + return %ret: !torch.vtensor<[?,1,?],f32> +} + diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir new file mode 100644 index 000000000000..bad66a84dbd9 --- /dev/null +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -0,0 +1,501 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.mm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> +func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32> -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x?xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: 
%[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> +func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[10,4,5],f32> -> !torch.vtensor<[10,3,5],f32> + return %0 : !torch.vtensor<[10,3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor<?x?x4xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor<?x4x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x4x?xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<?x4x?xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x?xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,4],f32>, !torch.vtensor<[?,4,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : 
!torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> +func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256,120],f32>, !torch.vtensor<[4,120,256],f32> -> !torch.vtensor<[4,256,256],f32> + return %0 : !torch.vtensor<[4,256,256],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> +func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : 
!torch.vtensor<[4,?,256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[4,?,?],f32> + return %0 : !torch.vtensor<[4,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$3dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32> +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> +func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[1,?],f32> + return %0 : !torch.vtensor<[1,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx3d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor<?x256x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x256x?xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32> +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32> +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<?x?xf32> +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[?,256,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + 
+// CHECK-LABEL: func.func @torch.aten.matmul$2dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx2d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<f32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[],f32> +func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32> +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 
: index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x256xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> +func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[?,?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,?,256],f32> + return %1 : !torch.vtensor<[?,?,256],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mm$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32> +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x256xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> +func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,256],f32> + return %1 : !torch.vtensor<[?,256],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor<?x?x3x3xf32> +// CHECK: %[[T_2:.*]] = torch.constant.none +// CHECK: %[[T_4:.*]] = torch.constant.int 2 +// CHECK: %[[T_5:.*]] = torch.constant.int 1 +// CHECK: %[[T_6:.*]] = torch.constant.int 4 +// CHECK: %[[T_7:.*]] = torch.constant.int 3 +// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]] +// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> +// CHECK: %[[T_13:.*]] = torch.constant.bool false +// CHECK: %[[T_14:.*]] = 
mhlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %none = torch.constant.none + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %4 = torch.prim.ListConstruct : () -> !torch.list<int> + %false = torch.constant.bool false + %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$bias( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, +// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor<?x?x3x3xf32> +// CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor<?xf32> +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %int4 = torch.constant.int 4 +// CHECK: %int3 = torch.constant.int 3 +// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3 +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> +// CHECK: %false = torch.constant.bool false +// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32> +// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 +// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> +// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32> +// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> 
!torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %4 = torch.prim.ListConstruct : () -> !torch.list<int> + %false = torch.constant.bool false + %5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.vtensor<[?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> +// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> +// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> +func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.aten.convolution %arg0, %arg1, %none, %1, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,9,9],f32> + return %2 : !torch.vtensor<[1,4,9,9],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride( +// CHECK-SAME: %[[ARG_0:.*]]: 
!torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> +func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32> + return %3 : !torch.vtensor<[1,4,15,15],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> 
!torch.list<int> +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32> +// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> +// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> +func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %1, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,16,16],f32> + return %3 : !torch.vtensor<[1,4,16,16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32> +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[T_9:.*]] = tensor.dim %[[T_6]], %[[IDX_1]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 +// CHECK: %[[IDX_2:.*]] = 
arith.constant 2 : index +// CHECK: %[[T_11:.*]] = tensor.dim %[[T_6]], %[[IDX_2]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_12:.*]] = arith.index_cast %[[T_11]] : index to i64 +// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index +// CHECK: %[[T_13:.*]] = tensor.dim %[[T_6]], %[[IDX_3]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_14:.*]] = arith.index_cast %[[T_13]] : index to i64 +// CHECK: %[[T_24:.*]] = arith.constant 2 : i64 +// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 +// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 +// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> +// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> +// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> +// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> +// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> +// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> +func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32> + return %3 : !torch.vtensor<[1,4,15,15],f32> +} diff --git a/test/Conversion/TorchToMhlo/lit.local.cfg b/test/Conversion/TorchToMhlo/lit.local.cfg new file mode 100644 index 000000000000..829a5662f6e6 --- /dev/null +++ b/test/Conversion/TorchToMhlo/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_mhlo: + config.unsupported = True diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToMhlo/pooling.mlir new file mode 100644 index 000000000000..00c918af5b0b --- /dev/null +++ b/test/Conversion/TorchToMhlo/pooling.mlir @@ -0,0 +1,218 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool2d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %int2 = 
torch.constant.int 2 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %false = torch.constant.bool false +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> +// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ +// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): +// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> +// CHECK: mhlo.return %[[VAL_10]] : tensor<f32> +// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool2d$padding( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %false = torch.constant.bool false +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> +// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): +// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> +// CHECK: mhlo.return %[[VAL_10]] : tensor<f32> +// CHECK: }) +// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> 
-> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.max_pool2d %arg0, %0, %1, %2, %2, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %3 : !torch.vtensor<[?,?,?,?],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool2d_with_indices( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32> +// CHECK: %int3 = torch.constant.int 3 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %false = torch.constant.bool false +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32> +// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32> +// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64 +// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32> +// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i64 +// CHECK: %[[VAL_13:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]] : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i64 +// CHECK: %[[VAL_15:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_14]] : tensor<2xi64> +// CHECK: %[[VAL_16:.*]] = "mhlo.dynamic_iota"(%[[VAL_15]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> +// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_16]], %[[VAL_13]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64> +// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64> +// CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({ +// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>): +// CHECK: %[[IVAL_4:.*]] = mhlo.compare GE, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> +// CHECK: %[[IVAL_6:.*]] = mhlo.compare EQ, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], 
%[[IVAL_3]] : tensor<i64> +// CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> +// CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> +// CHECK: mhlo.return %[[IVAL_5]], %[[IVAL_9]] : tensor<f32>, tensor<i64> +// CHECK{LITERAL}: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>) +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64> +// CHECK: return %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> +func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) { + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %false = torch.constant.bool false + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> + return %result0, %result1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.avg_pool2d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %int3 = torch.constant.int 3 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %false = torch.constant.bool false +// CHECK: %none = torch.constant.none +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>): +// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32> +// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32> +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32> +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : 
index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32> +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64 +// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32> +// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64 +// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32> +// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 +// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> +// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ +// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>): +// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32> +// CHECK: mhlo.return %[[IVAL_5]] : tensor<f32> +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> + return %3 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.avg_pool2d$count_include_pad( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> +// CHECK: %int3 = torch.constant.int 3 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %false = torch.constant.bool false +// CHECK: %none = torch.constant.none +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> +// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>): +// 
CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32> +// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32> +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<9> : tensor<i64> +// CHECK: %[[VAL_8:.*]] = mhlo.convert(%[[VAL_7]]) : (tensor<i64>) -> tensor<f32> +// CHECK: %[[VAL_9:.*]] = chlo.broadcast_divide %[[VAL_6]], %[[VAL_8]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %true = torch.constant.bool true + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> + return %3 : !torch.vtensor<[?,?,?,?],f32> +} diff --git a/test/Conversion/TorchToMhlo/reduction.mlir b/test/Conversion/TorchToMhlo/reduction.mlir new file mode 100644 index 000000000000..fb8545b78193 --- /dev/null +++ b/test/Conversion/TorchToMhlo/reduction.mlir @@ -0,0 +1,243 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.max.dim$keepdim( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32> +// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32> +// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64 +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64> +// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64> +// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> +// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>) +// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) { +// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> +// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, 
tensor, tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_16]], %[[VAL_20]] : tensor, tensor +// CHECK: } +// CHECK: %[[VAL_21:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_21]] : tensor<2xi64> +// CHECK: %[[VAL_23:.*]] = mhlo.dynamic_reshape %[[VAL_10]]#0, %[[VAL_22]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[VAL_24:.*]] = mhlo.dynamic_reshape %[[VAL_10]]#1, %[[VAL_22]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor -> !torch.vtensor<[?,1],f32> +// CHECK: %[[VAL_26:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor -> !torch.vtensor<[?,1],si64> +// CHECK: return %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64> + +func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) { + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64> + return %values, %indices : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.max.dim( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %false = torch.constant.bool false +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor +// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor +// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64 +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor +// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64> +// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor +// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor) +// CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) { +// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> 
tensor +// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_16]], %[[VAL_20]] : tensor, tensor +// CHECK: } +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#0 : tensor -> !torch.vtensor<[?],f32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#1 : tensor -> !torch.vtensor<[?],si64> +// CHECK: return %[[VAL_21]], %[[VAL_22]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64> +func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %values, %indices = torch.aten.max.dim %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64> + return %values, %indices : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.argmax$keepdim( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %true = torch.constant.bool true +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor +// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor +// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64 +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor +// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64> +// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor +// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor) +// CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) { +// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_16]], %[[VAL_20]] : tensor, tensor +// CHECK: } +// CHECK: %[[VAL_21:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_21]] : tensor<2xi64> +// CHECK: %[[VAL_23:.*]] = mhlo.dynamic_reshape %[[VAL_10]]#1, %[[VAL_22]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor -> !torch.vtensor<[?,1],si64> +// CHECK: return %[[VAL_24]] : !torch.vtensor<[?,1],si64> +func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> { + 
%int1 = torch.constant.int 1
+  %true = torch.constant.bool true
+  %indices = torch.aten.argmax %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],si64>
+  return %indices : !torch.vtensor<[?,1],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.argmax(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %false = torch.constant.bool false
+// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
+// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
+// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<0> : tensor<i64>
+// CHECK: %[[VAL_8:.*]] = tensor.from_elements %[[VAL_3]], %[[VAL_5]] : tensor<2xi64>
+// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
+// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
+// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
+// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
+// CHECK: mhlo.return %[[VAL_16]], %[[VAL_20]] : tensor<f32>, tensor<i64>
+// CHECK: }
+// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
+// CHECK: return %[[VAL_11]] : !torch.vtensor<[?],si64>
+func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %indices = torch.aten.argmax %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],si64>
+  return %indices : !torch.vtensor<[?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist$keepdim(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %true = torch.constant.bool true
+// CHECK: %none = torch.constant.none
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_4:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_3]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
+// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i64
+// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64
+// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64
+// CHECK: %[[ONE_0:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_11:.*]] = tensor.from_elements %[[ONE_0]], %[[ONE_0]], %[[VAL_10]] : tensor<3xi64>
+// CHECK: %[[VAL_12:.*]] = mhlo.dynamic_reshape %[[VAL_4]], %[[VAL_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<1x1x?xf32>
+// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x?xf32> -> !torch.vtensor<[1,1,?],f32>
+// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,?],f32>
+func.func @torch.aten.sum.dim_Intlist$keepdim(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.sum.dim_IntList %arg0, %0, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,?],f32>
+  return %1 : !torch.vtensor<[1,1,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %int0 = torch.constant.int 0
+// CHECK: %false = torch.constant.bool false
+// CHECK: %none = torch.constant.none
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_4:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_3]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
+func.func @torch.aten.sum.dim_Intlist(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?],f32>
+  return %1 : !torch.vtensor<[?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sum(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK: %none = torch.constant.none
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_4:.*]] = 
torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[],f32> +func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { + %none = torch.constant.none + %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_3:.*]] = mhlo.reduce(%[[VAL_1]] init: %[[VAL_2]]) applies mhlo.maximum across dimensions = [0, 1, 2] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[],f32> +func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir new file mode 100644 index 000000000000..41d84c76208e --- /dev/null +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -0,0 +1,583 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T21:.*]] = arith.index_cast 
%[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor<4x65x256xf32> +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> +// CHECK: 
%[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[2,65,256],f32> +func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,65,256],f32> + return %0 : !torch.vtensor<[2,65,256],f32> +} + + +// CHECK-LABEL: func.func @torch.aten.slice.last$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor +// CHECK: 
%[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor -> !torch.vtensor<[?,1,?],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[?,1,?],f32> +func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1,?],f32> + return %0 : !torch.vtensor<[?,1,?],f32> +} + + +// CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor<4x65x256xf32> +// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] 
= arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[4,1,256],f32> +func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,256],f32> + return %0 : !torch.vtensor<[4,1,256],f32> +} + + +// CHECK-LABEL: func.func @torch.aten.slice.none$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1_0:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 +// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64 +// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64> +// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> +// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[?,?,?],f32> 
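+// Informal note on the lowering above (not a FileCheck directive): with both
+// start and end given as None, the conversion keeps the full extent of the
+// sliced dimension and only the step value feeds the strides operand of
+// mhlo.real_dynamic_slice. As a worked example under that reading, the
+// static variant further below slices dim 1 of a 4x65x256 tensor with step 2,
+// yielding ceil(65 / 2) = 33 elements, which is why its result type is
+// tensor<4x33x256xf32>.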
+func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %none = torch.constant.none + %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.none.static$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1_0:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor<4x65x256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> +// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 +// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64 +// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64> +// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> +// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[4,33,256],f32> +func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %none = torch.constant.none + %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[4,33,256],f32> + return %0 : !torch.vtensor<[4,33,256],f32> +} + +// CHECK-LABEL: func.func @torch.aten.view$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INTneg1:.*]] = torch.constant.int -1 +// CHECK: %[[INT224:.*]] = torch.constant.int 224 +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT224]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T2:.*]] = 
torch_c.to_i64 %[[INTneg1]] +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]] +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64 +// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64 +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index +// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T8:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[T7]] : index, tensor<2xi64> -> tensor<2xi64> +// CHECK: %[[T9:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,224],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[?,224],f32> +func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { + %int-1 = torch.constant.int -1 + %int224 = torch.constant.int 224 + %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,224],f32> + return %1 : !torch.vtensor<[?,224],f32> +} + +// CHECK-LABEL: func.func @torch.aten.reshape$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor +// CHECK: %[[INTneg1:.*]] = torch.constant.int -1 +// CHECK: %[[INT120:.*]] = torch.constant.int 120 +// CHECK: %[[INT4:.*]] = torch.constant.int 4 +// CHECK: %[[INT64:.*]] = torch.constant.int 64 +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT120]], %[[INT4]], %[[INT64]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INTneg1]] +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]] +// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]] +// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]] +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64 +// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64 +// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64 +// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 +// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index +// CHECK: %[[T11:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> +// CHECK: %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<4xi64> -> tensor<4xi64> +// CHECK: %[[T13:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T12]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor -> !torch.vtensor<[?,120,4,64],f32> +// CHECK: return %[[T14]] : !torch.vtensor<[?,120,4,64],f32> +func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { + %int-1 = torch.constant.int -1 + %int120 = torch.constant.int 120 + %int4 = torch.constant.int 4 + %int64 = torch.constant.int 64 + %0 = torch.prim.ListConstruct %int-1, %int120, %int4, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,120,4,64],f32> + return %1 : !torch.vtensor<[?,120,4,64],f32> +} + +// CHECK-LABEL: func.func @torch.aten.view$minus1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> { +// CHECK: %[[T0:.*]] = 
torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32> +// CHECK: %[[INTneg1:.*]] = torch.constant.int -1 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]] +// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64 +// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]] +// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]] +// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]] +// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INTneg1]] +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T7:.*]] = arith.muli %[[C1_I64]], %[[T4]] : i64 +// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T5]] : i64 +// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T6]] : i64 +// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index +// CHECK: %[[T11:.*]] = tensor.from_elements %[[T4]], %[[T5]], %[[T6]] : tensor<3xi64> +// CHECK: %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<3xi64> -> tensor<3xi64> +// CHECK: %[[T13:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T12]] : (tensor<2x3x?x?xf32>, tensor<3xi64>) -> tensor<2x3x?xf32> +// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32> +// CHECK: return %[[T14]] : !torch.vtensor<[2,3,?],f32> +func.func @torch.aten.view$minus1(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> { + %int-1 = torch.constant.int -1 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list -> !torch.vtensor<[2,3,?],f32> + return %3 : !torch.vtensor<[2,3,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.view$to_rank1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list +// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[T3]] : !torch.vtensor<[1],f32> +func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[1],f32> + return %1 : 
!torch.vtensor<[1],f32> +} +// CHECK-LABEL: func.func @torch.aten.view$to_rank0( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[T3]] : !torch.vtensor<[],f32> +func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + return %1 : !torch.vtensor<[],f32> +} +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32> +// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32> +func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32> + return %0 : !torch.vtensor<[2,1,2,1,2],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> +// CHECK: %[[T10:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T9]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,1,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?,1,?],f32> +func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,?,1,?],f32> + return %0 : !torch.vtensor<[?,?,1,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$from_end( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor 
%[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> +// CHECK: %[[T10:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T9]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?],f32> +func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { + %int-2 = torch.constant.int -2 + %0 = torch.aten.squeeze.dim %arg0, %int-2 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?],f32> + return %0 : !torch.vtensor<[?,1,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xi64> +// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> +// CHECK: return %[[T9]] : !torch.vtensor<[2,2,2],f32> +func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { + %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[2,1,2,1,2],f32> -> !torch.vtensor<[2,2,2],f32> + return %0 : !torch.vtensor<[2,2,2],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$0( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: 
%[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T9]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[1,?,?,?,?],f32> +func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[1,?,?,?,?],f32> + return %0 : !torch.vtensor<[1,?,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T9]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?,?],f32> +func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?,?],f32> + return %0 : !torch.vtensor<[?,1,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$from_end( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] 
: index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[C1_I64]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T9]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?,?,1,?],f32> +func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { + %int-2 = torch.constant.int -2 + %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,1,?],f32> + return %0 : !torch.vtensor<[?,?,?,1,?],f32> +} diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir index 33fd2ee85fef..df68492b5e7f 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir @@ -2,26 +2,22 @@ // Basic case. -// CHECK-LABEL: torch.global_slot @b : !torch.bool { -// CHECK: %[[INIT:.*]] = torch.constant.bool true -// CHECK: torch.global_slot.init %[[INIT]] : !torch.bool +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[FLOAT4:.*]] = torch.constant.float 4.250000e+01 +// CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor +// CHECK: torch.initialize.global_slots [ +// CHECK: @b(%[[TRUE]] : !torch.bool) +// CHECK: @i(%[[INT3]] : !torch.int) +// CHECK: @f(%[[FLOAT4]] : !torch.float) +// CHECK: @t(%[[TENSOR]] : !torch.tensor) +// CHECK: ] // CHECK: } - -// CHECK-LABEL: torch.global_slot @i : !torch.int { -// CHECK: %[[INIT:.*]] = torch.constant.int 3 -// CHECK: torch.global_slot.init %[[INIT]] : !torch.int -// CHECK: } - -// CHECK-LABEL: torch.global_slot @f : !torch.float { -// CHECK: %[[INIT:.*]] = torch.constant.float 4.250000e+01 -// CHECK: torch.global_slot.init %[[INIT]] : !torch.float -// CHECK: } - -// CHECK-LABEL: torch.global_slot @t : !torch.tensor { -// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor -// CHECK: torch.global_slot.init %[[T]] : !torch.tensor -// CHECK: } - +// CHECK-LABEL: torch.global_slot @b : !torch.bool +// CHECK-LABEL: torch.global_slot @i : !torch.int +// CHECK-LABEL: torch.global_slot @f : !torch.float +// CHECK-LABEL: torch.global_slot @t : !torch.tensor torch.class_type @c { torch.attr "b" : !torch.bool torch.attr "i" : !torch.int diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir index 8bf281c180c0..1e7d0adf2317 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir @@ -35,43 +35,3 @@ func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">, %2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float return } - -// ----- - -torch.class_type @c { - torch.attr "t1" : !torch.tensor - torch.attr "t2" : 
!torch.tensor -} - -// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}} -%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor -torch.nn_module { - torch.slot "t1", %t : !torch.tensor - torch.slot "t2", %t : !torch.tensor -} : !torch.nn.Module<"c"> -func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor { - %t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor - %t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor - %cst = torch.constant.int 1 - %ret = torch.aten.add.Tensor %t1, %t2, %cst : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- - -torch.class_type @c { - torch.attr "t1" : !torch.tensor - torch.attr "t2" : !torch.tensor -} - -// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}} -%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor -torch.nn_module { - torch.slot "t1", %t : !torch.tensor - torch.slot "t2", %t : !torch.tensor -} : !torch.nn.Module<"c"> -func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) { - torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor - torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor - return -} diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir index c48097316644..a13fa5d6b2f1 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir @@ -1,13 +1,16 @@ // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s -// CHECK that multiple nested initialization ops are properly handled. +// Check that multiple nested initialization ops are properly handled. 
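+//
+// Informal sketch of the expected output shape (the slot name @l and the ops
+// are taken from the CHECK lines below; the nested list type is elided): the
+// pass now builds the nested lists once inside a single module initializer
+// and then initializes every slot through one torch.initialize.global_slots
+// op, instead of emitting a separate torch.global_slot.init region per slot:
+//
+//   torch.global_slot.module_initializer {
+//     %lists = ... (nested torch.prim.ListConstruct ops)
+//     torch.initialize.global_slots [
+//       @l(%lists : ...)
+//     ]
+//   }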
-// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<int>>> {
+// CHECK-LABEL: torch.global_slot.module_initializer {
 // CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>
 // CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<list<int>>, !torch.list<list<int>>) -> !torch.list<list<list<int>>>
-// CHECK: torch.global_slot.init %[[L2]] : !torch.list<list<list<int>>>
+// CHECK: torch.initialize.global_slots [
+// CHECK: @l(%[[L2]] : !torch.list<list<list<int>>>)
+// CHECK: ]
 // CHECK: }
+// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<int>>>
 torch.class_type @c {
 torch.attr "l" : !torch.list<list<list<int>>>
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
index 6d5e94cdf01b..6fc4ac1988e5 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
@@ -12,20 +12,22 @@ torch.class_type @__torch__.Submodule {
 torch.method private "forward", @__torch__.Submodule.forward
 }
+// CHECK-LABEL: torch.global_slot.module_initializer {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: torch.initialize.global_slots [
+// CHECK: @s1.n(%[[INT1]] : !torch.int)
+// CHECK: @s2.n(%[[INT2]] : !torch.int)
+// CHECK: ]
+// CHECK: }
+// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
+// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
 %int1 = torch.constant.int 1
 %s1 = torch.nn_module {
- // CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
- // CHECK: %[[C1:.*]] = torch.constant.int 1
- // CHECK: torch.global_slot.init %[[C1]] : !torch.int
- // CHECK: }
 torch.slot "n", %int1 : !torch.int
 } : !torch.nn.Module<"__torch__.Submodule">
 %int2 = torch.constant.int 2
 %s2 = torch.nn_module {
- // CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
- // CHECK: %[[C2:.*]] = torch.constant.int 2
- // CHECK: torch.global_slot.init %[[C2]] : !torch.int
- // CHECK: }
 torch.slot "n", %int2 : !torch.int
 } : !torch.nn.Module<"__torch__.Submodule">
 %3 = torch.nn_module {
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances.mlir
index 149e4e5288f9..a5fbdfad24d5 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances.mlir
@@ -10,20 +10,23 @@ torch.class_type @__torch__.Submodule {
 torch.method private "forward", @__torch__.Submodule.forward
 }
+// CHECK-LABEL: torch.global_slot.module_initializer {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: torch.initialize.global_slots [
+// CHECK: @s1.n(%[[INT1]] : !torch.int)
+// CHECK: @s2.n(%[[INT2]] : !torch.int)
+// CHECK: ]
+// CHECK: }
+// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
+// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
+
 %int1 = torch.constant.int 1
 %s1 = torch.nn_module {
- // CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
- // CHECK: %[[C1:.*]] = torch.constant.int 1
- // CHECK: torch.global_slot.init %[[C1]] : !torch.int
- // CHECK: }
 torch.slot "n", %int1 : !torch.int
 } : !torch.nn.Module<"__torch__.Submodule">
 %int2 = torch.constant.int 2
 %s2 = torch.nn_module {
- // CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
- // CHECK: %[[C2:.*]] = torch.constant.int 2
- // CHECK: torch.global_slot.init %[[C2]] : !torch.int
- // CHECK: }
 torch.slot "n", %int2 : !torch.int
 } : !torch.nn.Module<"__torch__.Submodule">
 %3 = torch.nn_module {
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
index 0b35e3cb22c6..eacd36493791 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
@@ -2,10 +2,13 @@
 // Check that linkage names consist of the dotted path from the root.
-// CHECK-LABEL: torch.global_slot @m.float : !torch.float {
-// CHECK: %[[INIT:.*]] = torch.constant.float 4.200000e+01
-// CHECK: torch.global_slot.init %[[INIT]] : !torch.float
+// CHECK-LABEL: torch.global_slot.module_initializer {
+// CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01
+// CHECK: torch.initialize.global_slots [
+// CHECK: @m.float(%[[FLOAT]] : !torch.float)
+// CHECK: ]
 // CHECK: }
+// CHECK-LABEL: torch.global_slot @m.float : !torch.float
 torch.class_type @child {
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index d6e68129ec50..36aa26f1058a 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -635,6 +635,53 @@ func.func @torch.aten.__getitem__.t$invalid_index() -> !torch.int {
 return %1 : !torch.int
 }
+// Not canonicalized because of mutated lhs list
+// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_lhs_mutated()
+func.func @torch.aten.add.t$no_canonicalize_lhs_mutated() -> !torch.list<int> {
+ %int4 = torch.constant.int 4
+ %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+ %1 = torch.prim.ListConstruct : () -> !torch.list<int>
+ %2 = torch.aten.append.t %0, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
+ // CHECK: torch.aten.add.t
+ %3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+ return %3 : !torch.list<int>
+}
+
+// Not canonicalized because of mutated rhs list
+// CHECK-LABEL: func.func @torch.aten.add.t$no_canonicalize_rhs_mutated()
+func.func @torch.aten.add.t$no_canonicalize_rhs_mutated() -> !torch.list<int> {
+ %int4 = torch.constant.int 4
+ %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+ %1 = torch.prim.ListConstruct : () -> !torch.list<int>
+ %2 = torch.aten.append.t %1, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
+ // CHECK: torch.aten.add.t
+ %3 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+ return %3 : !torch.list<int>
+}
+
+// CHECK-LABEL: func.func @torch.aten.add.t$concat(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.int,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.list<int> {
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: return %[[LIST]] : !torch.list<int>
+func.func @torch.aten.add.t$concat(%arg0: !torch.int, %arg1: !torch.int) -> !torch.list<int> {
+ %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
+ %2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+ return %2 : !torch.list<int>
+}
+
+// CHECK-LABEL: func.func @torch.aten.add.t$concat_empty(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.list<int> {
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG0]] : (!torch.int) -> !torch.list<int>
+// CHECK: return %[[LIST]] : !torch.list<int>
+func.func @torch.aten.add.t$concat_empty(%arg0: !torch.int) -> !torch.list<int> {
+ %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
+ %1 = torch.prim.ListConstruct : () -> !torch.list<int>
+ %2 = torch.aten.add.t %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+ return %2 : !torch.list<int>
+}
+
 // CHECK-LABEL: func.func @torch.aten.eq.int_list$fold$literals_of_different_sizes
 // CHECK: %[[RET:.*]] = torch.constant.bool false
 // CHECK: return %[[RET]] : !torch.bool
@@ -1357,3 +1404,220 @@ func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.
 %size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
 return %size : !torch.int
 }
+
+// CHECK-LABEL: func.func @prim.ListUnpack$fold_list(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
+// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
+ %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
+ %1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+ return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
+// CHECK: %[[INT3:.*]] = torch.constant.int 3
+// CHECK: %[[INT6:.*]] = torch.constant.int 6
+// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
+ %int6 = torch.constant.int 6
+ %str = torch.constant.str "floor"
+ %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
+ %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64>
+ %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
+ return %2 : !torch.vtensor<[],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
+// CHECK: %[[INT3:.*]] = torch.constant.int 3
+// CHECK: %[[INT6:.*]] = torch.constant.int 6
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
+// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
+func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
+ %int6 = torch.constant.int 6
+ %int2 = torch.constant.int 2
+ %str = torch.constant.str "floor"
+ %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
+ %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64>
+ %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
+ return %2 : !torch.vtensor<[],si64>
+}
+
+// CHECK-LABEL: 
func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.add.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.add.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.sub.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = 
torch.aten.sub.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.sub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.sub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.mul.Scalar %0, %int3 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.mul.Scalar %0, %int3 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return 
%[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %1 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.mul.Tensor %0, %1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int3 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.mul.Tensor %0, %1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %int2 = torch.constant.int 2 + %str = torch.constant.str "trunc" + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %str = torch.constant.str "trunc" + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : 
!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} \ No newline at end of file diff --git a/test/Dialect/Torch/decompose-complex-ops-legal.mlir b/test/Dialect/Torch/decompose-complex-ops-legal.mlir new file mode 100644 index 000000000000..261ae8c96ba2 --- /dev/null +++ b/test/Dialect/Torch/decompose-complex-ops-legal.mlir @@ -0,0 +1,10 @@ +// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim +func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> { + %none = torch.constant.none + %dim = torch.constant.int 1 + // CHECK: torch.aten.softmax.int + %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> + return %ret : !torch.tensor<[2,3],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index d7f058874475..9cd2d18538bc 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -202,8 +202,10 @@ func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : -// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: %[[CST:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[CST]]-1 : (!torch.int) -> !torch.list +// CHECK: %[[FLATTEN:.*]] = torch.aten.view %[[INP]], %[[T0]] : +// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> // CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : // CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64> // CHECK: return %[[IND]] : !torch.vtensor<[],si64> @@ -233,28 +235,36 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> 
!torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[CST1_2:.*]] = torch.constant.int 1 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] 
= torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -269,26 +279,34 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], 
%[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -303,28 +321,36 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> 
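The checks in these var/std hunks all encode one pattern: the decomposition now performs the sum, mean, and squared-difference arithmetic at double precision and casts back at the end. As a standalone sketch, assuming only the ops these tests exercise (%self and %var stand in for surrounding values; dtype codes 7 and 6 are PyTorch's ScalarType values for float64 and float32, as the CST7/CST6 constants above also show):

// Upcast the f32 input to f64 before any reduction.
%int7 = torch.constant.int 7
%false = torch.constant.bool false
%none = torch.constant.none
%wide = torch.aten.to.dtype %self, %int7, %false, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// ... mean/variance arithmetic happens in f64 ...
// Downcast the final statistic back to the original f32 element type.
%int6 = torch.constant.int 6
%out = torch.aten.to.dtype %var, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>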
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[CST1_2:.*]] = torch.constant.int 1 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> +// CHECK: 
%[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true @@ -340,26 +366,34 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = 
torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> +// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false @@ -1081,8 +1115,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch. 
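The hunk below flips aten.floor_divide's decomposition from "floor" to "trunc" rounding mode, matching PyTorch's eager behavior, where floor_divide has historically performed truncation division (rounding toward zero) despite its name. A minimal sketch of the IR the decomposition now produces, per the updated checks (the function name is illustrative):

func.func @floor_divide_sketch(%a: !torch.vtensor<[?,?],f32>, %b: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  // floor_divide lowers to div.Tensor_mode with "trunc" after this change.
  %trunc = torch.constant.str "trunc"
  %q = torch.aten.div.Tensor_mode %a, %b, %trunc : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
  return %q : !torch.vtensor<[?,?],f32>
}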
// CHECK-LABEL: func @torch.aten.floor_divide( // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[CSTFLOOR:.*]] = torch.constant.str "floor" -// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32> +// CHECK: %[[CSTTRUNC:.*]] = torch.constant.str "trunc" +// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> @@ -1165,22 +1199,30 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list // CHECK: %[[UNBIASED:.*]] = torch.constant.bool false // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[NONE_0:.*]] = torch.constant.none // CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,1],f32>, !torch.float -> !torch.vtensor<[3,4,7],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,7],f32> -> !torch.vtensor<[3,4,7],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: 
%[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32> -// CHECK: return %[[VAR]] : !torch.vtensor<[3,4,1],f32> +// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { %int2 = torch.constant.int 2 %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list @@ -1189,3 +1231,154 @@ func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vten %0 = torch.aten.var.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32> return %0 : !torch.vtensor<[3,4,1],f32> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.softplus( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[2,3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> { +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.aten.mul.Scalar %[[VAL_0]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32> +// CHECK: %[[VAL_4:.*]] = torch.aten.exp %[[VAL_3]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32> +// CHECK: %[[VAL_5:.*]] = torch.aten.log1p %[[VAL_4]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32> +// CHECK: %[[VAL_6:.*]] = torch.aten.div.Scalar %[[VAL_5]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32> +// CHECK: %[[VAL_7:.*]] = torch.aten.gt.Scalar %[[VAL_3]], %[[VAL_2]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],i1> +// CHECK: %[[VAL_8:.*]] = torch.aten.where.self %[[VAL_7]], %[[VAL_0]], %[[VAL_6]] : !torch.tensor<[2,3],i1>, !torch.tensor<[2,3],f32>, !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32> +// CHECK: return %[[VAL_8]] : !torch.tensor<[2,3],f32> +// CHECK: } +func.func @torch.aten.softplus(%t: 
!torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> { + %int0 = torch.constant.int 0 + %ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32> + return %ret : !torch.tensor<[2,3],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.var.correction( +// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { +// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list +// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1_0:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[CST2_0:.*]] = torch.constant.int 2 +// CHECK: %[[NUM_ELEMENTS_PLUS_ONE:.*]] = torch.aten.add.int %[[NUM_ELEMENTS_0]], %[[CST1_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[PRED:.*]] = torch.aten.ge.int %[[NUM_ELEMENTS_PLUS_ONE]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[PRED]], "correction value should be less than or equal to productDimSize + 1" +// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool 
false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> +func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { + %int2 = torch.constant.int 2 + %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %keepdim = torch.constant.bool true + %0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32> + return %0 : !torch.vtensor<[3,4,1],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.std.dim( +// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> { +// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list +// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false +// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,5],f64> +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1_0:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = 
torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[STD:.*]] = torch.aten.sqrt %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> -> !torch.vtensor<[3,4,1],f32> +// CHECK: return %[[STD]] : !torch.vtensor<[3,4,1],f32> +func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> { + %int2 = torch.constant.int 2 + %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> + %unbiased = torch.constant.bool false + %keepdim = torch.constant.bool true + %0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32> + return %0 : !torch.vtensor<[3,4,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.flatten.using_ints( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[INT]]-1 : (!torch.int) -> !torch.list<int> +// CHECK: %[[T1:.*]] = torch.aten.view %[[ARG0]], %[[T0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32> +// CHECK: return %[[T1]] : !torch.vtensor<[?],f32> +func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + return %1 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.roll( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1_0:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int +// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>> +// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int +// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none,
!torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>> +// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> { + %0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int> + %int1 = torch.constant.int 1 + %int-2 = torch.constant.int -2 + %1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32> + return %2 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Dialect/Torch/erase-module-initializer.mlir b/test/Dialect/Torch/erase-module-initializer.mlir new file mode 100644 index 000000000000..c0bbbdbbddb0 --- /dev/null +++ b/test/Dialect/Torch/erase-module-initializer.mlir @@ -0,0 +1,8 @@ +// RUN: torch-mlir-opt -torch-erase-module-initializer -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: } +torch.global_slot.module_initializer { + torch.initialize.global_slots [ + ] +} diff --git a/test/Dialect/Torch/inline-global-slots-analysis.mlir b/test/Dialect/Torch/inline-global-slots-analysis.mlir new file mode 100644 index 000000000000..da73bec23c47 --- /dev/null +++ b/test/Dialect/Torch/inline-global-slots-analysis.mlir @@ -0,0 +1,94 @@ +// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s + +// Safety analysis aspect of the pass. + +// ----- + +// Test case: Public slots cannot be inlined. +// Test case: Set slots cannot be inlined. + +// CHECK: torch.global_slot @public : !torch.int +// CHECK: torch.global_slot "private" @set : !torch.int +torch.global_slot @public : !torch.int +torch.global_slot "private" @set : !torch.int + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: %[[C1:.*]] = torch.constant.int 1 +// CHECK: torch.initialize.global_slots [ +// CHECK: @public(%[[C1]] : !torch.int) +// CHECK: @set(%[[C1]] : !torch.int) +// CHECK: ] +// CHECK: } +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + torch.initialize.global_slots [ + @public(%0 : !torch.int) + @set(%0 : !torch.int) + ] +} + +func.func @forward() { + %0 = torch.constant.int 2 + torch.global_slot.set @set = %0 : !torch.int + return +} + +// ----- + +// Test case: Propagate safety transitively through ops without HasValueSemantics.
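+// (Here torch.aten.__getitem__.t lacks HasValueSemantics -- its result aliases +// an element of the list -- while the final torch.aten.mul.Tensor does have +// value semantics and is a read-only use, so both slots below can be inlined.)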
+torch.global_slot "private" @tensor : !torch.tensor +torch.global_slot "private" @list : !torch.list + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: torch.initialize.global_slots [ +// CHECK-NEXT ] +torch.global_slot.module_initializer { + %0 = torch.tensor.literal(dense<0.0> : tensor) : !torch.tensor + %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list + torch.initialize.global_slots [ + @tensor(%0 : !torch.tensor) + @list(%1 : !torch.list) + ] +} + +func.func @forward() { + %int0 = torch.constant.int 0 + %0 = torch.global_slot.get @list : !torch.list + %1 = torch.aten.__getitem__.t %0, %int0 : !torch.list, !torch.int -> !torch.tensor + %2 = torch.aten.mul.Tensor %1, %1 : !torch.tensor, !torch.tensor -> !torch.tensor + return +} + +// ----- + + +// Test case: An unsafe subobject (@tensor) blocks inlining of the containing object (@list). +// Note that we can check just the initializer -- if we inlined the slot, then +// we would have eliminated the slot from the initializer. +// Also, the initializer is verified to match the list of global slots in the +// module. So it is a nice one-stop-shop. + +torch.global_slot "private" @tensor : !torch.tensor +torch.global_slot "private" @list : !torch.list + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: torch.initialize.global_slots [ +// CHECK-NEXT: @tensor(%{{.*}} : !torch.tensor) +// CHECK-NEXT: @list(%{{.*}} : !torch.list) +// CHECK-NEXT: ] +torch.global_slot.module_initializer { + %0 = torch.tensor.literal(dense<0.0> : tensor) : !torch.tensor + %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list + torch.initialize.global_slots [ + @tensor(%0 : !torch.tensor) + @list(%1 : !torch.list) + ] +} + +func.func @forward() { + %int0 = torch.constant.int 0 + %0 = torch.global_slot.get @list : !torch.list + %tensor = torch.global_slot.get @tensor : !torch.tensor + torch.aten.relu_ %tensor : !torch.tensor -> !torch.tensor + return +} diff --git a/test/Dialect/Torch/inline-global-slots-transform.mlir b/test/Dialect/Torch/inline-global-slots-transform.mlir new file mode 100644 index 000000000000..5bb42c6bebc0 --- /dev/null +++ b/test/Dialect/Torch/inline-global-slots-transform.mlir @@ -0,0 +1,81 @@ +// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s + +// Transform aspect of the pass. + +// Test case: Most basic case that can be inlined. + +// CHECK-NOT: @slot0 +torch.global_slot "private" @slot0 : !torch.int + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: torch.initialize.global_slots [ +// CHECK-NEXT ] +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + torch.initialize.global_slots [ + @slot0(%0 : !torch.int) + ] +} + +// CHECK-LABEL: func.func @forward() { +// CHECK: %[[C1:.*]] = torch.constant.int 1 +// CHECK: return +func.func @forward() { + %0 = torch.global_slot.get @slot0 : !torch.int + return +} + +// ----- + +// Test case: Shared objects in object graph shared between two initial values. 
+ +torch.global_slot "private" @tensor : !torch.tensor +torch.global_slot "private" @list_of_tensor : !torch.list + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: torch.initialize.global_slots [ +// CHECK-NEXT ] +torch.global_slot.module_initializer { + %0 = torch.tensor.literal(dense<0.0> : tensor) : !torch.tensor + %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list + torch.initialize.global_slots [ + @tensor(%0 : !torch.tensor) + @list_of_tensor(%1 : !torch.list) + ] +} + +// CHECK-LABEL: func.func @forward() { +// CHECK: %[[T0:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor) : !torch.tensor +// CHECK: %[[T1:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor) : !torch.tensor +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[T1]] : (!torch.tensor) -> !torch.list +// CHECK: return +func.func @forward() { + %0 = torch.global_slot.get @tensor : !torch.tensor + %1 = torch.global_slot.get @list_of_tensor : !torch.tensor + return +} + +// ----- + +// Test case: Adjusting static info. + +// CHECK-NOT: @tensor +torch.global_slot "private" @tensor : !torch.tensor + +// CHECK-LABEL: torch.global_slot.module_initializer { +// CHECK: torch.initialize.global_slots [ +// CHECK-NEXT ] +torch.global_slot.module_initializer { + %0 = torch.tensor.literal(dense<0.0> : tensor) : !torch.tensor<[],f32> + torch.initialize.global_slots [ + @tensor(%0 : !torch.tensor<[],f32>) + ] +} + +// CHECK-LABEL: func.func @forward() { +// CHECK: %[[T:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor) : !torch.tensor<[],f32> +// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.tensor<[],f32> to !torch.tensor +func.func @forward() { + %0 = torch.global_slot.get @tensor : !torch.tensor + return +} diff --git a/test/Dialect/Torch/inline-global-slots.mlir b/test/Dialect/Torch/inline-global-slots.mlir deleted file mode 100644 index 0d86a814ceff..000000000000 --- a/test/Dialect/Torch/inline-global-slots.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s - -// CHECK-NOT: @readonly -torch.global_slot "private" @readonly : !torch.tensor { - %0 = torch.tensor.literal(dense<0.0> : tensor<1xf32>) : !torch.tensor - torch.global_slot.init %0 : !torch.tensor -} -// CHECK-LABEL: torch.global_slot @public -torch.global_slot @public : !torch.tensor { - %0 = torch.tensor.literal(dense<0.0> : tensor<2xf32>) : !torch.tensor - torch.global_slot.init %0 : !torch.tensor -} -// CHECK-LABEL: torch.global_slot "private" @mutated -torch.global_slot "private" @mutated : !torch.tensor { - %0 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor - torch.global_slot.init %0 : !torch.tensor -} - -// CHECK-LABEL: func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) { -func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) { - // Inlined. - // CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor - %0 = torch.global_slot.get @readonly : !torch.tensor - - // Not inlined: potentially mutated by externals. - // CHECK: %[[PUBLIC:.*]] = torch.global_slot.get @public : !torch.tensor - %1 = torch.global_slot.get @public : !torch.tensor - - // Not inlined: potentially mutated internally. 
- // CHECK: torch.global_slot.set @mutated = %[[READONLY]] : !torch.tensor - // CHECK: %[[MUTATED:.*]] = torch.global_slot.get @mutated : !torch.tensor - torch.global_slot.set @mutated = %0 : !torch.tensor - %2 = torch.global_slot.get @mutated : !torch.tensor - - // CHECK: return %[[READONLY]], %[[PUBLIC]], %[[MUTATED]] : !torch.tensor, !torch.tensor, !torch.tensor - return %0, %1, %2 : !torch.tensor, !torch.tensor, !torch.tensor -} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index d960713f216c..84e7d63dae43 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -179,3 +179,89 @@ func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1 %1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32> return %1 : !torch.vtensor<[1],f32> } + +// ----- + +// There must be only one module initializer. + +torch.global_slot.module_initializer { + torch.initialize.global_slots [ + ] +} + +// expected-error @+1 {{there must be only one global slot initializer}} +torch.global_slot.module_initializer { + torch.initialize.global_slots [ + ] +} + +// ----- + +// Initialized slot missing, or non-existent slots initialized. + +// expected-note @+1 {{missing global slot initializer for @slot0}} +torch.global_slot @slot0 : !torch.int +// expected-note @+1 {{missing global slot initializer for @slot1}} +torch.global_slot @slot1 : !torch.int + +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + %1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor + %2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk> + // expected-error @below {{must have one initializer for each global slot in the module}} + // expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot0}} + // expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot1}} + torch.initialize.global_slots [ + @nonexistent_slot0(%0 : !torch.int) + @nonexistent_slot1(%0 : !torch.int) + ] +} + +// ----- + +// Duplicate initialization of global slot. + +torch.global_slot @slot0 : !torch.int + +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + // expected-error @+1 {{duplicate initialization of global slot: @slot0}} + torch.initialize.global_slots [ + @slot0(%0 : !torch.int) + @slot0(%0 : !torch.int) + ] +} + +// ----- + +// Subtyping checks. + +torch.global_slot @tensor : !torch.tensor +torch.global_slot @initialized_with_refined : !torch.tensor +torch.global_slot @error_initialized_with_derefined : !torch.tensor<[],unk> + +torch.global_slot.module_initializer { + %1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor + %2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk> + // expected-error @below {{initial value for global slot @error_initialized_with_derefined has type '!torch.tensor' which is not within the bound '!torch.tensor<[],unk>'}} + torch.initialize.global_slots [ + @tensor(%1 : !torch.tensor) + @initialized_with_refined(%2 : !torch.tensor<[],unk>) + @error_initialized_with_derefined(%1 : !torch.tensor) + ] +} + +// ----- + +// Restricted set of ops in the module initializer.
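+// Only constant-like and aggregate-construction ops (torch.constant.*, +// torch.tensor.literal, torch.prim.ListConstruct, as used in the tests above) +// are expected here; a compute op such as torch.aten.mul.Tensor must be +// rejected.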
+ +torch.global_slot @tensor : !torch.tensor + +torch.global_slot.module_initializer { + %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor + // expected-error @+1 {{'torch.aten.mul.Tensor' op is not allowed in a module initializer}} + %1 = torch.aten.mul.Tensor %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor + torch.initialize.global_slots [ + @tensor(%1 : !torch.tensor) + ] +} diff --git a/test/Dialect/Torch/lower-to-backend-contract-error.mlir b/test/Dialect/Torch/lower-to-backend-contract-error.mlir new file mode 100644 index 000000000000..824f3ae23467 --- /dev/null +++ b/test/Dialect/Torch/lower-to-backend-contract-error.mlir @@ -0,0 +1,61 @@ +// RUN: torch-mlir-opt -torch-lower-to-backend-contract -split-input-file -verify-diagnostics %s + +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + // expected-error @+2 {{unsupported by backend contract: module initializers}} + // expected-note @+1 {{this is likely due to}} + torch.initialize.global_slots [ + @slot0(%0 : !torch.int) + ] +} +torch.global_slot @slot0 : !torch.int + + +// ----- + +// expected-error @+2 {{unsupported by backend contract: non-value tensor type}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.tensor) { + return +} + +// ----- + +// expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.vtensor<*,f32>) { + return +} + +// ----- + +// expected-error @+2 {{unsupported by backend contract: tensor with unknown dtype}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.vtensor<[],unk>) { + return +} + +// ----- + +// expected-error @+1 {{unsupported by backend contract: type '!torch.any'}} +func.func @f(%arg0: !torch.any) { + return +} + +// ----- + +// Test case: checking of op results. +// TODO: In theory we could diagnose every single value, but for now we bail out on the first one.
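+// In the test below the offending value is an op result rather than a block +// argument: the torch.prim.If merges two casts to !torch.vtensor<*,f32>, so +// the unknown-rank error is reported on the If op itself.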
+ +func.func @f(%arg0: !torch.bool, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[7],f32>) -> !torch.vtensor<*,f32> { + // expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}} + // expected-note @+1 {{this is likely due to}} + %0 = torch.prim.If %arg0 -> (!torch.vtensor<*,f32>) { + %1 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32> + torch.prim.If.yield %1 : !torch.vtensor<*,f32> + } else { + %2 = torch.tensor_static_info_cast %arg2 : !torch.vtensor<[7],f32> to !torch.vtensor<*,f32> + torch.prim.If.yield %2 : !torch.vtensor<*,f32> + } + return %0 : !torch.vtensor<*,f32> +} diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index 0cb97d1bd6d1..ad810ec97ccb 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { ^bb2: return %arg0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @return_multiple_copies_of_tensor( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) { +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor +// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor +// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> +func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) { + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor + %1 = torch.copy.to_tensor %0 : !torch.tensor + return %1, %1 : !torch.tensor, !torch.tensor +} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index 38ce507fe785..b8ee124fb0e3 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -240,3 +240,35 @@ func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> { return %result2 : !torch.vtensor<*,unk> } + +// ----- + +// Check that we don't crash on this input. + +// CHECK-LABEL: func.func @forward +func.func @forward() -> !torch.vtensor { + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.prim.ListConstruct : () -> !torch.list<tensor> + // CHECK: torch.aten.tensor + %1 = torch.aten.tensor %0, %none, %none, %false : !torch.list<tensor>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor + return %1 : !torch.vtensor +} + +// ----- + +// Check that we don't crash on this input. +// TODO: This appears to result in aten.mul.Tensor not being visited. +// We should investigate why that happens.
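+// The operands of aten.mul.Tensor are produced by torch.copy.to_vtensor from +// a torch.prim.If result, so refinement has to look through a non-value tensor +// here; that indirection is presumably what keeps the op from being visited.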
+ +// CHECK-LABEL: func.func @forward +func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) { + %0 = torch.prim.If %arg0 -> (!torch.tensor) { + torch.prim.If.yield %arg1 : !torch.tensor + } else { + torch.prim.If.yield %arg1 : !torch.tensor + } + %1 = torch.copy.to_vtensor %0 : !torch.vtensor + %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor + return +} diff --git a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir new file mode 100644 index 000000000000..e281ac732b91 --- /dev/null +++ b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt -pass-pipeline='torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}' -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.square +func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: torch.aten.square + %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.argmax +func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { + %int0 = torch.constant.int 0 + %true = torch.constant.bool true + // CHECK: torch.aten.argmax + %0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64> + return %0 : !torch.vtensor<[1,?],si64> +} diff --git a/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir b/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir deleted file mode 100644 index 7ae6b1e19070..000000000000 --- a/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics - -// ----- - -func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> { - // @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}} - %neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor - return %arg : !torch.vtensor<[2],f32> -} diff --git a/test/Dialect/TorchConversion/canonicalize.mlir b/test/Dialect/TorchConversion/canonicalize.mlir index bf6a9b9d3c91..1af8b5ad1420 100644 --- a/test/Dialect/TorchConversion/canonicalize.mlir +++ b/test/Dialect/TorchConversion/canonicalize.mlir @@ -37,3 +37,41 @@ func.func @torch_c.to_i64$from_i64() -> !torch.int { %1 = torch_c.from_i64 %0 return %1 : !torch.int } + +// CHECK-LABEL: func.func @torch_c.from_f64() -> !torch.float { +// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00 +// CHECK: return %[[FLOAT5]] : !torch.float +func.func @torch_c.from_f64() -> !torch.float { + %c5_f64 = arith.constant 5.000000e+00 : f64 + %0 = torch_c.from_f64 %c5_f64 + return %0 : !torch.float +} + +// CHECK-LABEL: func.func @torch_c.to_f64() -> f64 { +// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64 +// CHECK: return %[[C5_f64]] : f64 +func.func @torch_c.to_f64() -> f64 { + %float5 = torch.constant.float 5.000000e+00 + %0 = torch_c.to_f64 %float5 + return %0 : f64 +} + +// CHECK-LABEL: func.func @torch_c.from_f64$to_f64() -> f64 { +// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64 +// CHECK: return %[[C5_f64]] : f64 +func.func 
@torch_c.from_f64$to_f64() -> f64 { + %c5_f64 = arith.constant 5.000000e+00 : f64 + %0 = torch_c.from_f64 %c5_f64 + %1 = torch_c.to_f64 %0 + return %1 : f64 +} + +// CHECK-LABEL: func.func @torch_c.to_f64$from_f64() -> !torch.float { +// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00 +// CHECK: return %[[FLOAT5]] : !torch.float +func.func @torch_c.to_f64$from_f64() -> !torch.float { + %float5 = torch.constant.float 5.000000e+00 + %0 = torch_c.to_f64 %float5 + %1 = torch_c.from_f64 %0 + return %1 : !torch.float +} diff --git a/test/Dialect/TorchConversion/ops.mlir b/test/Dialect/TorchConversion/ops.mlir index e10aede0421e..c099697ca74a 100644 --- a/test/Dialect/TorchConversion/ops.mlir +++ b/test/Dialect/TorchConversion/ops.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s +// RUN: torch-mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @builtin_tensor_interop( func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) { @@ -14,3 +14,35 @@ func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, % %4 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xi8> return } + +// ----- + +func.func @to_builtin_tensor_invalid_size(%arg0: !torch.vtensor<[3,?],si8>) { + // expected-error @+1 {{operand and result must have the same size and dtype}} + %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,?],si8> -> tensor<?xi8> + return +} + +// ----- + +func.func @to_builtin_tensor_invalid_dtype(%arg0: !torch.vtensor<*,si8>) { + // expected-error @+1 {{operand and result must have the same size and dtype}} + %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<*,si8> -> tensor<*xi64> + return +} + +// ----- + +func.func @from_builtin_tensor_invalid_size(%arg0: tensor<3x?xi8>) { + // expected-error @+1 {{operand and result must have the same size and dtype}} + %1 = torch_c.from_builtin_tensor %arg0 : tensor<3x?xi8> -> !torch.vtensor<[?,?],si8> + return +} + +// ----- + +func.func @from_builtin_tensor_invalid_dtype(%arg0: tensor<*xi8>) { + // expected-error @+1 {{operand and result must have the same size and dtype}} + %1 = torch_c.from_builtin_tensor %arg0 : tensor<*xi8> -> !torch.vtensor<*,si64> + return +} diff --git a/test/Dialect/TorchConversion/verify-invariants-before-backend-lowering.mlir b/test/Dialect/TorchConversion/verify-invariants-before-backend-lowering.mlir deleted file mode 100644 index 1c3e466a2aeb..000000000000 --- a/test/Dialect/TorchConversion/verify-invariants-before-backend-lowering.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-invariants-before-backend-lowering - -// ----- - -func.func @unknown_rank(%arg0: !torch.vtensor<[],f32>) { - // expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}} - // expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}} - %0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<*,f32> - return -} - -// ----- - -func.func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) { - // expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}} - // expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}} - %0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> ->
!torch.vtensor<[],unk> - return -} - -// ----- - -func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) { - // expected-error@+2 {{unsupported by backend lowering: `torch.operator` op}} - // expected-note@+1 {{this is likely due to a missing op that needs to be generated by torch_ods_gen.py}} - torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32> - return -} diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 20f8b2729449..94c93a05c898 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -15,6 +15,7 @@ config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.python_executable = sys.executable config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ +config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@ import lit.llvm lit.llvm.initialize(lit_config, config) diff --git a/test/python/importer/jit_ir/ivalue_import/quantization.py b/test/python/importer/jit_ir/ivalue_import/quantization.py index e5c732208ae2..422e6bb70526 100644 --- a/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -7,6 +7,7 @@ import torch from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +# UNSUPPORTED: system-darwin # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s mb = ModuleBuilder() diff --git a/test/python/importer/jit_ir/ivalue_import/tensors.py b/test/python/importer/jit_ir/ivalue_import/tensors.py index 02a82f06eab9..314c245f58be 100644 --- a/test/python/importer/jit_ir/ivalue_import/tensors.py +++ b/test/python/importer/jit_ir/ivalue_import/tensors.py @@ -22,6 +22,7 @@ def __init__(self): self.ones_f64 = torch.ones(1, dtype=torch.float64) self.ones_bool = torch.ones(1, dtype=torch.bool) self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16) + self.ones_f16 = torch.ones(1, dtype=torch.half) self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8) self.arange = torch.nn.Parameter(torch.arange(3.0)) @@ -34,6 +35,7 @@ def __init__(self): # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64> # CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense : tensor<1xi1>) : !torch.tensor<[1],i1> # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16> +# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16> # CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8> # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 # CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0 @@ -49,6 +51,7 @@ def __init__(self): # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64> # CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1> # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16> +# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16> # CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8> # CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8> # CHECK: } diff --git a/tools/torchscript_e2e_test.sh b/tools/e2e_test.sh similarity index 82% rename from tools/torchscript_e2e_test.sh rename to tools/e2e_test.sh index 152075e63531..cddd1543bbb7 100755 --- 
a/tools/torchscript_e2e_test.sh +++ b/tools/e2e_test.sh @@ -9,4 +9,4 @@ cd "$src_dir" export PYTHONPATH=${PYTHONPATH-} source .env -python -m e2e_testing.torchscript.main "$@" +python -m e2e_testing.main "$@" diff --git a/utils/bazel/.bazelrc b/utils/bazel/.bazelrc index 046b9a2702de..15d329dccc5c 100644 --- a/utils/bazel/.bazelrc +++ b/utils/bazel/.bazelrc @@ -4,8 +4,8 @@ build --action_env=CC=clang build --action_env=CXX=clang++ -build --cxxopt=-std=c++14 -build --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++17 +build --host_cxxopt=-std=c++17 build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 build --cxxopt=-U__GXX_ABI_VERSION build --cxxopt=-D__GXX_ABI_VERSION=1011 diff --git a/utils/bazel/WORKSPACE.bazel b/utils/bazel/WORKSPACE.bazel index 312f8319decc..bc4fa5336ca0 100644 --- a/utils/bazel/WORKSPACE.bazel +++ b/utils/bazel/WORKSPACE.bazel @@ -3,15 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + http_archive( name = "bazel_skylib", + sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", ], - sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c", ) -load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") @@ -26,21 +26,30 @@ new_local_repository( load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") llvm_configure( - name = "llvm-project", - repo_mapping = { - "@python_runtime": "@local_config_python", - }, + name = "llvm-project", + repo_mapping = { + "@python_runtime": "@local_config_python", + }, + targets = [ + "X86", + ], ) + llvm_disable_optional_support_deps() +local_repository( + name = "mlir-hlo", + path = "../../externals/mlir-hlo/", +) + new_local_repository( name = "torch-mlir-raw", build_file_content = "# empty", - path = "../../" + path = "../../", ) load("@torch-mlir-raw//utils/bazel:configure.bzl", "torch_mlir_configure") torch_mlir_configure( - name = "torch-mlir" + name = "torch-mlir", ) diff --git a/utils/bazel/configure.bzl b/utils/bazel/configure.bzl index 7ad082f83e68..564f4a9543fe 100644 --- a/utils/bazel/configure.bzl +++ b/utils/bazel/configure.bzl @@ -52,5 +52,5 @@ def _torch_mlir_configure_impl(repository_ctx): torch_mlir_configure = repository_rule( implementation = _torch_mlir_configure_impl, local = True, - configure = True + configure = True, ) diff --git a/utils/bazel/docker/Dockerfile b/utils/bazel/docker/Dockerfile new file mode 100644 index 000000000000..d7e1221ed920 --- /dev/null +++ b/utils/bazel/docker/Dockerfile @@ -0,0 +1,34 @@ +ARG BASE_IMG=ubuntu:18.04 +FROM ${BASE_IMG} as dev-base + +ARG ARCH="x86_64" +ARG BAZEL_VERSION=5.2.0 + +# Install basic packages +RUN apt-get update && \ + apt-get install -y \ + clang-10 \ + curl \ + git \ + python3-pip \ + python3.8 \ + python3.8-dev \ + wget \ + unzip + +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 10 +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 10 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-10 10 +RUN update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-10 10 + +# Install bazel +RUN wget -q 
https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-linux-${ARCH} -O /usr/bin/bazel \ + && chmod a+x /usr/bin/bazel + +COPY requirements.txt /opt/app/requirements.txt +WORKDIR /opt/app +RUN python -m pip install --upgrade pip +RUN python -m pip install --ignore-installed -r requirements.txt + +WORKDIR /opt/src/torch-mlir diff --git a/utils/bazel/docker/run_bazel_build.sh b/utils/bazel/docker/run_bazel_build.sh new file mode 100755 index 000000000000..c0705206231e --- /dev/null +++ b/utils/bazel/docker/run_bazel_build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +cd "$(pwd)/utils/bazel" && bazel build @torch-mlir//... diff --git a/utils/bazel/docker/run_docker.sh b/utils/bazel/docker/run_docker.sh new file mode 100755 index 000000000000..35c4e67b767f --- /dev/null +++ b/utils/bazel/docker/run_docker.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +docker build -f utils/bazel/docker/Dockerfile \ + -t torch-mlir:dev \ + . + +docker run -it \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:dev diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 5a991552cc04..d103b25d113b 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -2,11 +2,11 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( default_visibility = [ - "//visibility:public", + "//visibility:public", ], ) @@ -17,16 +17,16 @@ td_library( "include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td", "include/torch-mlir/Dialect/Torch/IR/TorchBase.td", "include/torch-mlir/Dialect/Torch/IR/TorchOps.td", - "include/torch-mlir/Dialect/Torch/IR/TorchTypes.td" + "include/torch-mlir/Dialect/Torch/IR/TorchTypes.td", ], includes = ["include"], deps = [ - "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:CastInterfacesTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles" - ] + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], ) gentbl_cc_library( @@ -39,28 +39,28 @@ gentbl_cc_library( ), ( ["-gen-op-defs"], - "include/torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc" + "include/torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc", ), ( [ "-gen-dialect-decls", "-dialect=torch", ], - "include/torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc" + "include/torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc", ), ( [ "-gen-dialect-defs", "-dialect=torch", ], - "include/torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc" + "include/torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/torch-mlir/Dialect/Torch/IR/TorchOps.td", deps = [ - ":MLIRTorchOpsIncGenTdFiles" - ] + ":MLIRTorchOpsIncGenTdFiles", + ], ) gentbl_cc_library( @@ -69,41 +69,41 @@ gentbl_cc_library( tbl_outs = [ ( ["-gen-typedef-decls"], - "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc" + "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc", ), ( ["-gen-typedef-defs"], - "include/torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc" - ) + "include/torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", 
td_file = "include/torch-mlir/Dialect/Torch/IR/TorchTypes.td", deps = [ - ":MLIRTorchOpsIncGenTdFiles" - ] + ":MLIRTorchOpsIncGenTdFiles", + ], ) cc_library( name = "TorchMLIRTorchDialectUtils", srcs = [ + "lib/Dialect/Torch/Utils/TorchUpstream.cpp", "lib/Dialect/Torch/Utils/Utils.cpp", - "lib/Dialect/Torch/Utils/TorchUpstream.cpp" ], - strip_include_prefix = "include", hdrs = [ - "include/torch-mlir/Dialect/Torch/Utils/Utils.h", - "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h", "include/torch-mlir/Dialect/Torch/IR/TorchOps.h", "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h", - "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h" + "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", + "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h", + "include/torch-mlir/Dialect/Torch/Utils/Utils.h", ], + strip_include_prefix = "include", deps = [ ":MLIRTorchOpsIncGen", ":MLIRTorchTypesIncGen", - "@llvm-project//mlir:IR", "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", - ] + ], ) cc_library( @@ -111,10 +111,10 @@ cc_library( srcs = [ "lib/Dialect/Torch/IR/TorchDialect.cpp", "lib/Dialect/Torch/IR/TorchOps.cpp", - "lib/Dialect/Torch/IR/TorchTypes.cpp", "lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp", + "lib/Dialect/Torch/IR/TorchTypes.cpp", "lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp", - "lib/Dialect/Torch/IR/UtilsForODSGenerated.h" + "lib/Dialect/Torch/IR/UtilsForODSGenerated.h", ], hdrs = glob([ "include/torch-mlir/Dialect/Torch/IR/*.h", @@ -124,24 +124,24 @@ cc_library( ":MLIRTorchOpsIncGen", ":MLIRTorchTypesIncGen", ":TorchMLIRTorchDialectUtils", - "@llvm-project//mlir:IR", "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:FuncDialect" - ] + ], ) # Torch Dialect/Transforms td_library( name = "TorchMLIRTorchPassesTdFiles", srcs = [ - "include/torch-mlir/Dialect/Torch/Transforms/Passes.td" + "include/torch-mlir/Dialect/Torch/Transforms/Passes.td", ], includes = ["include"], deps = [ "@llvm-project//mlir:OpBaseTdFiles", - ] + ], ) gentbl_cc_library( @@ -151,27 +151,28 @@ gentbl_cc_library( ( ["-gen-pass-decls"], "include/torch-mlir/Dialect/Torch/Transforms/Passes.h.inc", - ) + ), ], - td_file = "include/torch-mlir/Dialect/Torch/Transforms/Passes.td", tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/torch-mlir/Dialect/Torch/Transforms/Passes.td", deps = [ ":TorchMLIRTorchPassesTdFiles", "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) - - cc_library( name = "TorchMLIRTorchPasses", srcs = [ "lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp", "lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp", "lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp", + "lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp", "lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp", "lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp", + "lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp", "lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp", + "lib/Dialect/Torch/Transforms/PassDetail.h", "lib/Dialect/Torch/Transforms/Passes.cpp", "lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp", "lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp", @@ -180,8 +181,6 @@ cc_library( "lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp", "lib/Dialect/Torch/Transforms/ShapeLibrary.cpp", 
"lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp", - "lib/Dialect/Torch/Transforms/PassDetail.h", - "lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp", ], hdrs = [ "include/torch-mlir/Dialect/Torch/Transforms/Passes.h", @@ -190,12 +189,12 @@ cc_library( deps = [ ":TorchMLIRTorchDialect", ":TorchMLIRTorchPassesIncGen", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:IR", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:Parser" - ] + ], ) # TorchConversion diaelct @@ -209,12 +208,12 @@ td_library( deps = [ ":MLIRTorchOpsIncGenTdFiles", "@llvm-project//mlir:AttrTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:CastInterfacesTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles" - ] + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], ) gentbl_cc_library( @@ -227,28 +226,28 @@ gentbl_cc_library( ), ( ["-gen-op-defs"], - "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" + "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc", ), ( [ "-gen-dialect-decls", "-dialect=torch_c", ], - "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc" + "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc", ), ( [ "-gen-dialect-defs", "-dialect=torch_c", ], - "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.cpp.inc" + "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td", deps = [ - ":MLIRTorchConversionOpsTdFiles" - ] + ":MLIRTorchConversionOpsTdFiles", + ], ) cc_library( @@ -266,17 +265,17 @@ cc_library( ":MLIRTorchConversionOpsIncGen", ":TorchMLIRTorchDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface" - ] + "@llvm-project//mlir:InferTypeOpInterface", + ], ) # Conversion td_library( name = "TorchMLIRConversionPassesTdFiles", - includes = ["include"], srcs = [ - "include/torch-mlir/Conversion/Passes.td" - ] + "include/torch-mlir/Conversion/Passes.td", + ], + includes = ["include"], ) gentbl_cc_library( @@ -284,28 +283,33 @@ gentbl_cc_library( strip_include_prefix = "include", tbl_outs = [ ( - ["-gen-pass-decls"], + [ + "-gen-pass-decls", + "-DTORCH_MLIR_ENABLE_MHLO", + "-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32", + ], "include/torch-mlir/Conversion/Passes.h.inc", - ) + ), ], - td_file = "include/torch-mlir/Conversion/Passes.td", tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/torch-mlir/Conversion/Passes.td", deps = [ ":TorchMLIRConversionPassesTdFiles", "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) # TorchConversion transforms td_library( name = "TorchMLIRTorchConversionPassesTdFiles", srcs = [ - "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td" + "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td", ], deps = [ "@llvm-project//mlir:OpBaseTdFiles", - ] + ], ) + gentbl_cc_library( name = "TorchMLIRTorchConversionPassesIncGen", strip_include_prefix = "include", @@ -313,42 +317,44 @@ gentbl_cc_library( ( ["-gen-pass-decls"], "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc", - ) + ), ], - td_file = 
"include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td", tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td", deps = [ ":TorchMLIRTorchConversionPassesTdFiles", "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) cc_library( name = "TorchMLIRConversionUtils", srcs = [ - "lib/Conversion/Utils/Utils.cpp" + "lib/Conversion/Utils/Utils.cpp", ], hdrs = [ - "include/torch-mlir/Conversion/Utils/Utils.h" + "include/torch-mlir/Conversion/Utils/Utils.h", ], strip_include_prefix = "include", deps = [ ":TorchMLIRTorchDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Transforms", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:LinalgDialect" - ] + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:Transforms", + ], ) cc_library( name = "TorchMLIRTorchToLinalg", srcs = [ + "lib/Conversion/PassDetail.h", "lib/Conversion/TorchToLinalg/DataMovement.cpp", "lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp", "lib/Conversion/TorchToLinalg/Linear.cpp", "lib/Conversion/TorchToLinalg/Pooling.cpp", + "lib/Conversion/TorchToLinalg/PopulatePatterns.h", "lib/Conversion/TorchToLinalg/Random.cpp", "lib/Conversion/TorchToLinalg/Reduction.cpp", "lib/Conversion/TorchToLinalg/TensorConstructors.cpp", @@ -357,203 +363,230 @@ cc_library( "lib/Conversion/TorchToLinalg/Uncategorized.cpp", "lib/Conversion/TorchToLinalg/Utils.cpp", "lib/Conversion/TorchToLinalg/Utils.h", - "lib/Conversion/TorchToLinalg/PopulatePatterns.h", - "lib/Conversion/PassDetail.h", ], hdrs = [ - "include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" + "include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h", ], strip_include_prefix = "include", deps = [ + ":TorchMLIRConversionPassesIncGen", ":TorchMLIRConversionUtils", ":TorchMLIRTorchBackendTypeConversion", - ":TorchMLIRTorchDialect", - ":TorchMLIRConversionPassesIncGen", ":TorchMLIRTorchConversionDialect", - "@llvm-project//mlir:Pass", + ":TorchMLIRTorchDialect", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TransformUtils" - ] + "@llvm-project//mlir:TransformUtils", + ], ) cc_library( name = "TorchMLIRTorchToSCF", srcs = [ - "lib/Conversion/TorchToSCF/TorchToSCF.cpp", "lib/Conversion/PassDetail.h", + "lib/Conversion/TorchToSCF/TorchToSCF.cpp", ], hdrs = [ - "include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" + "include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h", ], strip_include_prefix = "include", deps = [ + ":TorchMLIRConversionPassesIncGen", ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", - ":TorchMLIRConversionPassesIncGen", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TransformUtils" - ] + "@llvm-project//mlir:TransformUtils", + ], ) cc_library( - name = "TorchMLIRTorchToStd", + name = "TorchMLIRTorchToArith", srcs = [ - "lib/Conversion/TorchToStd/TorchToStd.cpp", - "lib/Conversion/PassDetail.h" + "lib/Conversion/PassDetail.h", + 
"lib/Conversion/TorchToArith/TorchToArith.cpp", ], hdrs = [ - "include/torch-mlir/Conversion/TorchToStd/TorchToStd.h" + "include/torch-mlir/Conversion/TorchToArith/TorchToArith.h", ], strip_include_prefix = "include", deps = [ - ":TorchMLIRTorchBackendTypeConversion", - ":TorchMLIRTorchConversionDialect", ":TorchMLIRConversionPassesIncGen", ":TorchMLIRConversionUtils", - "@llvm-project//mlir:Dialect" - ] + ":TorchMLIRTorchBackendTypeConversion", + ":TorchMLIRTorchConversionDialect", + "@llvm-project//mlir:Dialect", + ], ) cc_library( name = "TorchMLIRTorchToTMTensor", srcs = [ - "lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp", "lib/Conversion/PassDetail.h", + "lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp", ], hdrs = [ - "include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" + "include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h", ], strip_include_prefix = "include", deps = [ + ":TorchMLIRConversionPassesIncGen", + ":TorchMLIRConversionUtils", + ":TorchMLIRTMTensorDialect", ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", + "@llvm-project//mlir:LinalgDialect", + ], +) + +cc_library( + name = "TorchMLIRTorchToMhlo", + srcs = [ + "lib/Conversion/PassDetail.h", + "lib/Conversion/TorchToMhlo/Basic.cpp", + "lib/Conversion/TorchToMhlo/Gather.cpp", + "lib/Conversion/TorchToMhlo/Linear.cpp", + "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp", + "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h", + "lib/Conversion/TorchToMhlo/Pooling.cpp", + "lib/Conversion/TorchToMhlo/PopulatePatterns.h", + "lib/Conversion/TorchToMhlo/Reduction.cpp", + "lib/Conversion/TorchToMhlo/TorchToMhlo.cpp", + "lib/Conversion/TorchToMhlo/ViewLike.cpp", + ], + hdrs = [ + "include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h", + ], + copts = ['-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32'], + strip_include_prefix = "include", + deps = [ ":TorchMLIRConversionPassesIncGen", - ":TorchMLIRTMTensorDialect", ":TorchMLIRConversionUtils", - "@llvm-project//mlir:LinalgDialect" - ] + ":TorchMLIRTorchBackendTypeConversion", + ":TorchMLIRTorchConversionDialect", + "@llvm-project//mlir:Dialect", + "@mlir-hlo//:mlir_hlo", + ], ) cc_library( name = "TorchMLIRConversionPasses", srcs = [ - "lib/Conversion/Passes.cpp" + "lib/Conversion/Passes.cpp", ], hdrs = [ - "include/torch-mlir/Conversion/Passes.h" + "include/torch-mlir/Conversion/Passes.h", ], strip_include_prefix = "include", deps = [ + ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", + ":TorchMLIRTorchToMhlo", ":TorchMLIRTorchToSCF", - ":TorchMLIRTorchToStd", + ":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTosa", - ":TorchMLIRTorchToTMTensor" - ] + ], ) - cc_library( name = "TorchMLIRTorchConversionPasses", srcs = [ - "lib/Dialect/TorchConversion/Transforms/Passes.cpp", "lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp", - "lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp", + "lib/Dialect/TorchConversion/Transforms/PassDetail.h", + "lib/Dialect/TorchConversion/Transforms/Passes.cpp", "lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp", "lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp", - "lib/Dialect/TorchConversion/Transforms/PassDetail.h" ], hdrs = [ "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h", ], strip_include_prefix = "include", deps = [ - ":TorchMLIRTorchConversionPassesIncGen", ":TorchMLIRTorchBackendTypeConversion", + ":TorchMLIRTorchConversionDialect", + 
         ":TorchMLIRTorchConversionPassesIncGen",
         ":TorchMLIRTorchDialect",
         ":TorchMLIRTorchPasses",
-        ":TorchMLIRTorchConversionDialect",
+        ":TorchMLIRTorchToArith",
         ":TorchMLIRTorchToLinalg",
+        ":TorchMLIRTorchToMhlo",
         ":TorchMLIRTorchToSCF",
-        ":TorchMLIRTorchToStd",
-        ":TorchMLIRTorchToTosa",
         ":TorchMLIRTorchToTMTensor",
+        ":TorchMLIRTorchToTosa",
+        "@llvm-project//mlir:ConversionPasses",
         "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:LinalgTransforms",
-        "@llvm-project//mlir:TosaDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefTransforms",
-        "@llvm-project//mlir:ConversionPasses",
-    ]
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:TosaDialect",
+    ],
 )
-
 cc_library(
     name = "TorchMLIRTorchToTosa",
     srcs = [
+        "lib/Conversion/PassDetail.h",
         "lib/Conversion/TorchToTosa/TorchToTosa.cpp",
         "lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp",
         "lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp",
-        "lib/Conversion/PassDetail.h",
-    ] ,
+    ],
     hdrs = [
         "include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h",
         "include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h",
-        "include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
+        "include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h",
     ],
     strip_include_prefix = "include",
     deps = [
+        ":TorchMLIRConversionPassesIncGen",
         ":TorchMLIRTorchBackendTypeConversion",
         ":TorchMLIRTorchConversionDialect",
-        ":TorchMLIRConversionPassesIncGen",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:TosaDialect"
-    ]
+        "@llvm-project//mlir:TosaDialect",
+    ],
 )

 # Dialects.TorchConversion

 cc_library(
     name = "TorchMLIRTorchBackendTypeConversion",
     srcs = [
-        "lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp"
+        "lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp",
     ],
     hdrs = [
-        "include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+        "include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h",
     ],
     strip_include_prefix = "include",
     deps = [
         ":TorchMLIRTorchConversionDialect",
-        "@llvm-project//mlir:FuncTransforms"
-    ]
+        "@llvm-project//mlir:FuncTransforms",
+    ],
 )

 # External dialects

 td_library(
     name = "TorchMLIRTMTensorOpsTdFiles",
     srcs = [
-        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td",
         "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td",
+        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td",
+        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td",
         "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td",
-        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
     ],
     includes = ["externals/llvm-external-projects/torch-mlir-dialects/include"],
     deps = [
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
-        "@llvm-project//mlir:ControlFlowInterfacesTdFiles"
-    ]
+    ],
 )

 gentbl_cc_library(
@@ -574,14 +607,14 @@ gentbl_cc_library(
         ),
         (
             ["-gen-type-interface-defs"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorTypeInterfaces.cpp.inc",
-        )
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorTypeInterfaces.cpp.inc",
+        ),
     ],
-    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td",
     tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td",
     deps = [
         ":TorchMLIRTMTensorOpsTdFiles",
-    ]
+    ],
 )

 gentbl_cc_library(
@@ -590,18 +623,18 @@ gentbl_cc_library(
     tbl_outs = [
         (
             ["-gen-op-interface-decls"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc",
         ),
         (
             ["-gen-op-interface-defs"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc"
-        )
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc",
+        ),
     ],
-    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td",
     tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td",
     deps = [
         ":TorchMLIRTMTensorOpsTdFiles",
-    ]
+    ],
 )

 gentbl_cc_library(
@@ -610,64 +643,64 @@ gentbl_cc_library(
     tbl_outs = [
         (
             ["-gen-op-decls"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc",
         ),
         (
             ["-gen-op-defs"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc",
         ),
         (
             ["-gen-typedef-decls"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorTypes.h.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorTypes.h.inc",
         ),
         (
             [
                 "-gen-dialect-decls",
-                "-dialect=tm_tensor"
+                "-dialect=tm_tensor",
             ],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc",
         ),
         (
             [
                 "-gen-dialect-defs",
-                "-dialect=tm_tensor"
+                "-dialect=tm_tensor",
             ],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc",
         ),
     ],
-    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td",
     tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td",
     deps = [
-        ":TorchMLIRTMTensorOpsTdFiles"
-    ]
+        ":TorchMLIRTMTensorOpsTdFiles",
+    ],
 )

 cc_library(
     name = "TorchMLIRTMTensorDialect",
     srcs = [
+        "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp",
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorDialect.cpp",
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorInterfaces.cpp",
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp",
-        "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp"
     ],
     hdrs = [
-        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h",
         "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h",
         "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h",
-        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
+        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h",
+        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h",
     ],
     strip_include_prefix = "externals/llvm-external-projects/torch-mlir-dialects/include",
     deps = [
+        ":TorchMLIRTMTensorInterfacesIncGen",
         ":TorchMLIRTMTensorOpsIncGen",
         ":TorchMLIRTMTensorScalarLoopOpInterfaceIncGen",
-        ":TorchMLIRTMTensorInterfacesIncGen",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:DialectUtils",
-        "@llvm-project//mlir:ViewLikeInterface",
-        "@llvm-project//mlir:ControlFlowInterfaces",
-        "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:LinalgDialect",
-    ]
+        "@llvm-project//mlir:ViewLikeInterface",
+    ],
 )

 td_library(
@@ -677,8 +710,8 @@ td_library(
     ],
     deps = [
         "@llvm-project//mlir:OpBaseTdFiles",
-        "@llvm-project//mlir:PassBaseTdFiles"
-    ]
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
 )

 gentbl_cc_library(
@@ -687,64 +720,64 @@ gentbl_cc_library(
     tbl_outs = [
         (
             ["-gen-pass-decls"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc",
         ),
         (
             ["-gen-pass-capi-header"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.cpi.inc"
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.cpi.inc",
        ),
         (
             ["-gen-pass-capi-impl"],
-            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.cpi.cpp.inc"
-        )
+            "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.cpi.cpp.inc",
+        ),
     ],
-    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td",
     tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td",
     deps = [
-        ":TorchMLIRTMTensorTransformsPassesTdFiles"
-    ]
+        ":TorchMLIRTMTensorTransformsPassesTdFiles",
+    ],
 )

 cc_library(
     name = "TorchMLIRTMTensorPasses",
-    strip_include_prefix = "externals/llvm-external-projects/torch-mlir-dialects/include",
     srcs = [
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp",
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp",
         "externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Passes.cpp",
     ],
     hdrs = [
+        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h",
         "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h",
-        "externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
     ],
+    strip_include_prefix = "externals/llvm-external-projects/torch-mlir-dialects/include",
     deps = [
-        ":TorchMLIRTMTensorTransformsPassesIncGen",
         ":TorchMLIRTMTensorDialect",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:LinalgDialect",
-        "@llvm-project//mlir:Transforms",
+        ":TorchMLIRTMTensorTransformsPassesIncGen",
         "@llvm-project//mlir:BufferizationTransforms",
         "@llvm-project//mlir:FuncTransforms",
-        "@llvm-project//mlir:LinalgTransforms"
-    ]
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Transforms",
+    ],
 )

 # RefBackend

 filegroup(
     name = "TorchMLIRRefBackendPassesDetails",
     srcs = [
-        "lib/RefBackend/PassDetail.h"
-    ]
+        "lib/RefBackend/PassDetail.h",
+    ],
 )

 td_library(
     name = "TorchMLIRRefBackendPassTdFiles",
     srcs = [
-        "include/torch-mlir/RefBackend/Passes.td"
+        "include/torch-mlir/RefBackend/Passes.td",
     ],
     deps = [
         "@llvm-project//mlir:OpBaseTdFiles",
-    ]
+    ],
 )

 gentbl_cc_library(
@@ -754,73 +787,73 @@ gentbl_cc_library(
         (
             ["-gen-pass-decls"],
             "include/torch-mlir/RefBackend/Passes.h.inc",
-        )
+        ),
     ],
-    td_file = "include/torch-mlir/RefBackend/Passes.td",
     tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/torch-mlir/RefBackend/Passes.td",
     deps = [
         ":TorchMLIRRefBackendPassTdFiles",
         "@llvm-project//mlir:PassBaseTdFiles",
-    ]
+    ],
 )

 cc_library(
     name = "TorchMLIRRefBackendPass",
     srcs = [
-        "lib/RefBackend/RefBackend.cpp"
+        "lib/RefBackend/RefBackend.cpp",
     ] + [":TorchMLIRRefBackendPassesDetails"],
     hdrs = [
-        "include/torch-mlir/RefBackend/Passes.h"
+        "include/torch-mlir/RefBackend/Passes.h",
     ],
     strip_include_prefix = "include",
     deps = [
         ":TorchMLIRRefBackendPassIncGen",
-        ":TorchMLIRTorchConversionDialect",
         ":TorchMLIRTorchBackendTypeConversion",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:MemRefDialect",
+        ":TorchMLIRTorchConversionDialect",
         "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:MathTransforms",
-    ]
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:Pass",
+    ],
 )

 cc_library(
     name = "TorchMLIRInitAll",
     srcs = [
-        "lib/InitAll.cpp"
+        "lib/InitAll.cpp",
     ],
     hdrs = [
-        "include/torch-mlir/InitAll.h"
+        "include/torch-mlir/InitAll.h",
     ],
     strip_include_prefix = "include",
     deps = [
-        ":TorchMLIRTorchPasses",
-        ":TorchMLIRTorchConversionDialect",
-        ":TorchMLIRTorchDialect",
-        ":TorchMLIRTorchConversionPasses",
-        ":TorchMLIRTMTensorDialect",
-        ":TorchMLIRTMTensorPasses",
         ":TorchMLIRConversionPasses",
         ":TorchMLIRRefBackendPass",
+        ":TorchMLIRTMTensorDialect",
+        ":TorchMLIRTMTensorPasses",
+        ":TorchMLIRTorchConversionDialect",
+        ":TorchMLIRTorchConversionPasses",
+        ":TorchMLIRTorchDialect",
+        ":TorchMLIRTorchPasses",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:DialectUtils",
-        "@llvm-project//mlir:IR"
-    ]
+        "@llvm-project//mlir:IR",
+    ],
 )

 # tools

 cc_binary(
     name = "torch-mlir-opt",
     srcs = [
-        "tools/torch-mlir-opt/torch-mlir-opt.cpp"
+        "tools/torch-mlir-opt/torch-mlir-opt.cpp",
     ],
     deps = [
         ":TorchMLIRInitAll",
         ":TorchMLIRTorchDialect",
         ":TorchMLIRTorchPasses",
         "@llvm-project//mlir:AllPassesAndDialects",
-        "@llvm-project//mlir:MlirOptLib"
-    ]
+        "@llvm-project//mlir:MlirOptLib",
+    ],
 )
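
A quick way to exercise the reorganized overlay locally is to build and run the torch-mlir-opt tool declared by the cc_binary target above. This is a minimal sketch, assuming the Bazel workspace under utils/bazel maps this overlay in as the @torch-mlir external repository (the repository name and workspace path follow this repo's layout and are not introduced by this diff):

    # Build the opt tool from the cc_binary target above.
    cd utils/bazel
    bazel build @torch-mlir//:torch-mlir-opt

    # Smoke-test the dialect/pass registration pulled in via TorchMLIRInitAll.
    bazel run @torch-mlir//:torch-mlir-opt -- --show-dialects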