diff --git a/build_tools/bazel/deep_copy.bzl b/build_tools/bazel/deep_copy.bzl index a7c1c29f6211..32ce6e6ea464 100644 --- a/build_tools/bazel/deep_copy.bzl +++ b/build_tools/bazel/deep_copy.bzl @@ -43,6 +43,7 @@ def _deep_copy_recursion_depth_1(x): def deep_copy(x): """Returns a copy of the argument, making a deep copy if it is a container. + Args: x: (object) value to copy. If it is a container with nested containers as elements, the maximum nesting depth is restricted to three (e.g., diff --git a/build_tools/bazel/enforce_glob.bzl b/build_tools/bazel/enforce_glob.bzl new file mode 100644 index 000000000000..1ee7c62284ba --- /dev/null +++ b/build_tools/bazel/enforce_glob.bzl @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A utility to enforce that a list matches a glob expression. We use this primarily to enable the error-checking capabilities of globs of test files in IREE while still allowing our Bazel to CMake conversion to avoid creating CMake globs (which are discouraged for collecting source files; see https://cmake.org/cmake/help/latest/command/file.html#glob) and to avoid depending on any information outside of the BUILD file. """ def enforce_glob(files, **kwargs): """A utility to enforce that a list matches a glob expression. Note that the comparison is done in an order-independent fashion. Args: files: a list that is expected to contain the same files as the specified glob expression. **kwargs: keyword arguments forwarded to the glob. Returns: files, the input argument, unchanged. """ glob_result = native.glob(**kwargs) # glob returns a sorted list. if sorted(files) != glob_result: glob_result_dict = {k: None for k in glob_result} result_dict = {k: None for k in files} missing = [k for k in glob_result if k not in result_dict] extra = [k for k in files if k not in glob_result_dict] expected_formatted = "\n".join(['"{}",'.format(file) for file in glob_result]) fail(("Error in enforce_glob." + "\nExpected {}." + "\nGot {}." + "\nMissing {}."
+ + "\nExtra {}" + + "\nPaste this into the first enforce_glob argument:" + + "\n{}").format( + glob_result, + files, + missing, + extra, + expected_formatted, + )) + return files diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index b6d32c23eb78..6a5e852697dd 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py @@ -251,6 +251,9 @@ def filegroup(self, name, **kwargs): def sh_binary(self, name, **kwargs): self._convert_unimplemented_function("sh_binary", name) + def enforce_glob(self, files, **kwargs): + return files + def glob(self, include, exclude=None, exclude_directories=1): if exclude_directories != 1: self._convert_unimplemented_function("glob", "with exclude_directories") diff --git a/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile b/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile index 8ef012a21356..7251f8ba64d4 100644 --- a/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile +++ b/build_tools/docker/cmake-bazel-frontends-nvidia/Dockerfile @@ -26,4 +26,6 @@ FROM gcr.io/iree-oss/cmake-bazel-frontends-vulkan@sha256:3656d9c3a08770f8371ad3a6f777b979c954656dae1cdbdfadfccdd7ab713f87 AS final RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-460=460.32.03-0ubuntu0.18.04.1 + && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libnvidia-gl-460=460.39-0ubuntu0.18.04.1 \ + libnvidia-compute-460=460.39-0ubuntu0.18.04.1 diff --git a/build_tools/docker/cmake-python-nvidia/Dockerfile b/build_tools/docker/cmake-python-nvidia/Dockerfile index ef2803bca003..488e0de86859 100644 --- a/build_tools/docker/cmake-python-nvidia/Dockerfile +++ b/build_tools/docker/cmake-python-nvidia/Dockerfile @@ -28,4 +28,6 @@ FROM gcr.io/iree-oss/cmake-python-vulkan@sha256:6722f69c6300749f6bd4b141fc653244990381a6b0111f9c361061adcd65c07c AS final RUN apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-460=460.32.03-0ubuntu0.18.04.1 + && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libnvidia-gl-460=460.39-0ubuntu0.18.04.1 \ + libnvidia-compute-460=460.39-0ubuntu0.18.04.1 diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt index d119c1817b32..6267a6feafd9 100644 --- a/build_tools/docker/prod_digests.txt +++ b/build_tools/docker/prod_digests.txt @@ -9,8 +9,8 @@ gcr.io/iree-oss/vulkan@sha256:5812ee64806a7f3df0739ccf0930c27cabce346901488eceb1 gcr.io/iree-oss/rbe-toolchain@sha256:d69c260b98a97ad430d34c4591fb2399e00888750f5d47ede00c1e6f3e774e5a gcr.io/iree-oss/cmake-python-vulkan@sha256:6722f69c6300749f6bd4b141fc653244990381a6b0111f9c361061adcd65c07c gcr.io/iree-oss/cmake-python-swiftshader@sha256:0be2b0c735a038365e7cad31f6b440805dd4e231e166a114ef22914a5469cbc8 -gcr.io/iree-oss/cmake-python-nvidia@sha256:2d823e6fc528d0da5296ae35c0aad8008306381e3ac923679c4052adedd029d7 +gcr.io/iree-oss/cmake-python-nvidia@sha256:0c931cac303791af85c5a717b418997cac9f3319717f59f7a70ac777edfa7b33 gcr.io/iree-oss/cmake-bazel-frontends@sha256:2a9c65cea8b061696c8217b78e9718ceb804b9245bb3a661ff8a56f832aeeb0a gcr.io/iree-oss/cmake-bazel-frontends-vulkan@sha256:3656d9c3a08770f8371ad3a6f777b979c954656dae1cdbdfadfccdd7ab713f87 -gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:e71d77c2e8b886e99ffce2f50c73c5ae5e316bde23b54f0e6090bb04ce2d1bdf +gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:bc1502af0679d301feb27477c2476006282b5f956cecb650b2716c7a2876e722 
gcr.io/iree-oss/cmake-bazel-frontends-swiftshader@sha256:1466d3658f872a0675b3ac605c2e577def880c7334738d677d5ef2f1c1291646 diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh index bd98aab86c00..440ea340446b 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build_kokoro.sh @@ -38,7 +38,7 @@ docker_setup docker run "${DOCKER_RUN_ARGS[@]?}" \ --gpus all \ - gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:e71d77c2e8b886e99ffce2f50c73c5ae5e316bde23b54f0e6090bb04ce2d1bdf \ + gcr.io/iree-oss/cmake-bazel-frontends-nvidia@sha256:bc1502af0679d301feb27477c2476006282b5f956cecb650b2716c7a2876e722 \ build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh # Kokoro will rsync this entire directory back to the executor orchestrating the diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh index 0b20fd6d1140..46513ed81c20 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh @@ -38,7 +38,7 @@ docker_setup docker run "${DOCKER_RUN_ARGS[@]?}" \ --gpus all \ - gcr.io/iree-oss/cmake-python-nvidia@sha256:2d823e6fc528d0da5296ae35c0aad8008306381e3ac923679c4052adedd029d7 \ + gcr.io/iree-oss/cmake-python-nvidia@sha256:0c931cac303791af85c5a717b418997cac9f3319717f59f7a70ac777edfa7b33 \ build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build.sh # Kokoro will rsync this entire directory back to the executor orchestrating the diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD index 1ed07b949005..60d2030d023b 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_to_mhlo.mlir", + "lower_global_tensors.mlir", + "lower_global_tensors_complex.mlir", + "lower_global_tensors_invalid.mlir", + "propagate_resource_casts.mlir", + "strip_metadata.mlir", + "verify_fully_converted.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD index e9acbed21709..18de285136e7 100644 --- a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
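(For reference: the enforce_glob utility introduced above checks an explicit file list against the glob result and fails with a paste-able replacement list when they diverge, while the bazel_to_cmake converter treats it as the identity function, so the explicit list flows straight into the generated CMakeLists.txt. A minimal Python sketch of the same check, outside Bazel — the function name is illustrative, and glob_result stands in for the sorted list that native.glob() returns:

    def check_globbed_sources(files, glob_result):
        """Mirrors enforce_glob: files must match the glob, order-independently."""
        if sorted(files) == glob_result:
            return files
        missing = [f for f in glob_result if f not in files]  # globbed but not listed
        extra = [f for f in files if f not in glob_result]  # listed but not globbed
        suggestion = "\n".join('"{}",'.format(f) for f in glob_result)
        raise ValueError(
            ("enforce_glob mismatch.\nMissing {}.\nExtra {}.\nPaste this into the "
             "first enforce_glob argument:\n{}").format(missing, extra, suggestion))

So a BUILD file that listed only ["add.mlir"] in a directory that also contains multi_add.mlir would fail with Missing ['multi_add.mlir'], which is exactly the situation the converted test suites below are guarding against.)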
load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_metadata.mlir", + "strip_metadata.mlir", + "verify_fully_converted.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-opt-tflite", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD index 78398351c8f6..7f2f19fb1936 100644 --- a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "add.mlir", + "multi_add.mlir", + ], + include = ["*.mlir"], + ), data = glob(["*.tflite"]) + [ "//iree_tf_compiler:iree-import-tflite", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD index a2b8e05dad1d..da1fa3f9119e 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/test/BUILD @@ -15,6 +15,7 @@ # Tests for lowering MLIR in various dialects to IREE interpreter bytecode. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "tf_strings_to_strings.mlir", + "tf_to_tf_strings.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD index 24e3f79e5a8f..abf320c12d91 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/ir/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//build_tools/bazel:tblgen.bzl", "gentbl") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "base.td", + "ops.td", + ], + include = ["*.td"], + ), ) gentbl( diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD index 6b37940aeb03..305e762ef240 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_tf_tensorlist_to_tensorlist.mlir", + "convert_tf_to_tf_tensorlist.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD index c0745fc4c2b8..80b2c552de44 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//build_tools/bazel:tblgen.bzl", "gentbl") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ exports_files(["tf_tensorlist_base.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "tf_tensorlist_base.td", + "tf_tensorlist_ops.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD index 6b37940aeb03..b700c9f6e652 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/ir/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@iree//iree:lit_test.bzl", "iree_lit_test_suite") +load("@iree//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree_tf_compiler:iree-tf-opt", "@iree//iree/tools:IreeFileCheck", diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD index b780b112a316..95904a789942 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/BUILD +++ b/iree/compiler/Bindings/TFLite/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "materialize_shape_support.mlir", + "wrap_entry_points.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt index 96f36fe84237..365af3dbd2db 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Bindings/TFLite/Transforms/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "materialize_shape_support.mlir" + "wrap_entry_points.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/Common/LaunchConfig.cpp b/iree/compiler/Conversion/Common/LaunchConfig.cpp index 88b99f533d2f..246257494f15 100644 --- a/iree/compiler/Conversion/Common/LaunchConfig.cpp +++ b/iree/compiler/Conversion/Common/LaunchConfig.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -38,6 +39,7 @@ namespace iree_compiler { /// Name of the StrAttr that can be used to get the key to access the tile size /// information. static const char kLaunchInfoKey[] = "launch_info_key"; +static const char kRootOpKey[] = "is_root_op"; static Optional getKey(Operation *op) { StringAttr attr = op->getAttrOfType(kLaunchInfoKey); @@ -82,8 +84,7 @@ ArrayRef LaunchConfig::getTileSizes(Operation *op, Operation *LaunchConfig::getRootOperation(ArrayRef ops) { for (auto op : ops) { - auto key = getKey(op); - if (key && key.getValue() == rootOperationKey) return op; + if (op->getAttrOfType(kRootOpKey)) return op; } return nullptr; } @@ -114,8 +115,7 @@ void LaunchConfig::setNumSubgroups(ArrayRef vNumSubgroups) { } void LaunchConfig::setRootOperation(Operation *op) { - Optional key = getKey(op); - if (key) rootOperationKey = *key; + op->setAttr(kRootOpKey, UnitAttr::get(op->getContext())); } void LaunchConfig::setSameConfig(Operation *source, Operation *target) { diff --git a/iree/compiler/Conversion/Common/LaunchConfig.h b/iree/compiler/Conversion/Common/LaunchConfig.h index 7f6f3366c788..1367f6054f05 100644 --- a/iree/compiler/Conversion/Common/LaunchConfig.h +++ b/iree/compiler/Conversion/Common/LaunchConfig.h @@ -127,10 +127,6 @@ class LaunchConfig { /// these attributes. llvm::StringMap tileSizes; - /// Key used for tagging the root operation. The launch config does not track - /// the root operation itself, but rather the key used for the root operation. - StringRef rootOperationKey = ""; - /// Workgroup size to use. 
std::array workgroupSize = {1, 1, 1}; diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp index 0a61ea2c5d49..3f1f3b0e453b 100644 --- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp +++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp @@ -407,7 +407,7 @@ static SmallVector extractFromI64ArrayAttr(Attribute attr) { } LogicalResult convertInterfaceLoadTensorOp( - OpBuilder &b, IREE::Flow::DispatchInputLoadOp loadOp, + OpBuilder &b, IREE::Flow::DispatchTensorLoadOp loadOp, BlockAndValueMapping &bvm) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(loadOp); @@ -449,7 +449,7 @@ static Operation *getInsertionPointForReplacementStoreOp( /// LinalgOp, create the subview operation that can be used by the op itself to /// store the result into directly. This avoids an extra allocation + copies. LogicalResult preProcessInterfaceStoreTensorOp( - OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp, + OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp, BlockAndValueMapping &bvm) { // Find the insertion point for the subview. SmallVector operandsOfSubviewOp; @@ -491,7 +491,7 @@ LogicalResult preProcessLinalgOps(OpBuilder &b, linalg::LinalgOp op, } LogicalResult convertInterfaceStoreTensorOp( - OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp, + OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp, BlockAndValueMapping &bvm) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(storeOp); @@ -570,7 +570,7 @@ void LinalgBufferizePass::runOnFunction() { transferShapeOpsToMemref(b, op.getResult(), baseBuffer.getResult(), bvm); }); if (funcOp - .walk([&](IREE::Flow::DispatchOutputStoreOp op) -> WalkResult { + .walk([&](IREE::Flow::DispatchTensorStoreOp op) -> WalkResult { return preProcessInterfaceStoreTensorOp(b, op, bvm); }) .wasInterrupted()) { @@ -596,12 +596,12 @@ void LinalgBufferizePass::runOnFunction() { auto conversionDispatch = [&](Operation *op) -> WalkResult { return TypeSwitch(op) - .Case( - [&](IREE::Flow::DispatchInputLoadOp loadOp) { + .Case( + [&](IREE::Flow::DispatchTensorLoadOp loadOp) { return convertInterfaceLoadTensorOp(b, loadOp, bvm); }) - .Case( - [&](IREE::Flow::DispatchOutputStoreOp storeOp) { + .Case( + [&](IREE::Flow::DispatchTensorStoreOp storeOp) { return convertInterfaceStoreTensorOp(b, storeOp, bvm); }) .Case([&](linalg::LinalgOp linalgOp) { diff --git a/iree/compiler/Conversion/Common/test/BUILD b/iree/compiler/Conversion/Common/test/BUILD index 1e3b7bba341b..0dbfe82569f8 100644 --- a/iree/compiler/Conversion/Common/test/BUILD +++ b/iree/compiler/Conversion/Common/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "linalg_bufferize.mlir", + "linalg_rewrite_destructive_updates.mlir", + "remove_dead_allocs.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/Common/test/CMakeLists.txt b/iree/compiler/Conversion/Common/test/CMakeLists.txt index 899242936781..5684d46df19c 100644 --- a/iree/compiler/Conversion/Common/test/CMakeLists.txt +++ b/iree/compiler/Conversion/Common/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "linalg_bufferize.mlir" + "linalg_rewrite_destructive_updates.mlir" + "remove_dead_allocs.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir index 7257acbd8f2b..e5f547306003 100644 --- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir +++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir @@ -6,20 +6,20 @@ func @tile_from_tensor_load() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> - %8 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %9, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + 
flow.dispatch.tensor.store %9, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -56,25 +56,25 @@ func @tile_from_pointwise_lhs() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> %shape = linalg.init_tensor [1, 3] : tensor<1x3xf32> %8 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<1x3xf32>) outs(%shape : tensor<1x3xf32>) { ^bb0(%arg2: f32, %s: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<1x3xf32> - %9 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %9 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %10 = linalg.matmul ins(%8, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + flow.dispatch.tensor.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -115,17 +115,17 @@ func @tile_from_pointwise_outs() { %c4 = constant 4 : index %c1 = constant 1 : index %c3 = constant 3 : index - %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : 
!flow.dispatch.tensor %4 = hal.interface.workgroup.id[0] : index %5 = hal.interface.workgroup.id[1] : index scf.for %arg0 = %5 to %c2 step %c2 { scf.for %arg1 = %4 to %c4 step %c4 { - %6 = flow.dispatch.input.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x3xf32> - %7 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<3x1xf32> - %8 = flow.dispatch.input.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.input -> tensor<1x1xf32> + %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x3xf32> + %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<1x1xf32> %shape = linalg.init_tensor [1, 1] : tensor<1x1xf32> %9 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<1x1xf32>) outs(%shape : tensor<1x1xf32>) { @@ -133,7 +133,7 @@ func @tile_from_pointwise_outs() { linalg.yield %arg2 : f32 } -> tensor<1x1xf32> %10 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32> - flow.dispatch.output.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.output + flow.dispatch.tensor.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor } } return @@ -167,10 +167,10 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @bufferize_dynamic() { %c0 = constant 0 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %4 = hal.interface.load.constant offset = 0 : index %5 = hal.interface.load.constant offset = 1 : index %6 = hal.interface.load.constant offset = 2 : index @@ -180,13 +180,13 @@ func @bufferize_dynamic() { %10 = hal.interface.load.constant offset = 6 : index %11 = hal.interface.load.constant offset = 7 : index %12 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?]> - %13 = flow.dispatch.tie_shape %0, %12 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %13 = flow.dispatch.tie_shape %0, %12 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %14 = shapex.make_ranked_shape %6, %7 : (index, index) -> !shapex.ranked_shape<[?,?]> - %15 = flow.dispatch.tie_shape %1, %14 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %15 = flow.dispatch.tie_shape %1, %14 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> 
!flow.dispatch.tensor %16 = shapex.make_ranked_shape %8, %9 : (index, index) -> !shapex.ranked_shape<[?,?]> - %17 = flow.dispatch.tie_shape %2, %16 : (!flow.dispatch.input, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.input + %17 = flow.dispatch.tie_shape %2, %16 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %18 = shapex.make_ranked_shape %10, %11 : (index, index) -> !shapex.ranked_shape<[?,?]> - %19 = flow.dispatch.tie_shape %3, %18 : (!flow.dispatch.output, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.output + %19 = flow.dispatch.tie_shape %3, %18 : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor %workgroup_size_x = hal.interface.workgroup.size[0] : index %workgroup_size_y = hal.interface.workgroup.size[1] : index %workgroup_id_x = hal.interface.workgroup.id[0] : index @@ -200,14 +200,14 @@ func @bufferize_dynamic() { %23 = muli %workgroup_size_x, %workgroup_count_x : index scf.for %arg1 = %22 to %7 step %23 { %24 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%4, %workgroup_size_y] - %25 = flow.dispatch.input.load %13, offsets = [%arg0, %c0], sizes = [%24, %5], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %25 = flow.dispatch.tensor.load %13, offsets = [%arg0, %c0], sizes = [%24, %5], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%7, %workgroup_size_x] - %27 = flow.dispatch.input.load %15, offsets = [%c0, %arg1], sizes = [%6, %26], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %27 = flow.dispatch.tensor.load %15, offsets = [%c0, %arg1], sizes = [%6, %26], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %28 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %8] %29 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %9] - %30 = flow.dispatch.input.load %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : !flow.dispatch.input -> tensor + %30 = flow.dispatch.tensor.load %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor %31 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%25, %27 : tensor, tensor) outs(%30 : tensor) -> tensor - flow.dispatch.output.store %31, %19, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : tensor -> !flow.dispatch.output + flow.dispatch.tensor.store %31, %19, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : tensor -> !flow.dispatch.tensor } } return @@ -271,13 +271,13 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { // %c0 = constant 0 : index // %c2 = constant 2 : index // %c1 = constant 1 : index -// %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<2x3xf32> -// %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<3x4xf32> -// %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input<2x4xf32> -// %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<2x4xf32> -// %4 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> -// %5 = flow.dispatch.input.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> -// %6 = flow.dispatch.input.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : 
!flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> +// %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor +// %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor +// %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor +// %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor +// %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x3xf32> +// %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> +// %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x1xf32> // %7 = vector.transfer_read %4[%c0, %c0], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> // %8 = vector.transfer_read %4[%c0, %c1], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> // %9 = vector.transfer_read %4[%c0, %c2], %cst {masked = [false, false]} : tensor<2x3xf32>, vector<1x1xf32> @@ -297,7 +297,7 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { // %23 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %12, %15, %22 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32> // %24 = vector.transfer_write %20, %6[%c0, %c0] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> // %25 = vector.transfer_write %23, %24[%c1, %c0] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> -// flow.dispatch.output.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.output<2x4xf32> +// flow.dispatch.tensor.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor // return // } // } @@ -310,11 +310,11 @@ func @reshape_simple() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> - flow.dispatch.output.store %3, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %3, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -338,9 +338,9 @@ func @reshape_fused_source() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0], 
sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> %4 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %5 = linalg.generic { @@ -351,7 +351,7 @@ func @reshape_fused_source() { %6 = addi %arg0, %arg0 : i32 linalg.yield %6 : i32 } -> tensor<3x4xi32> - flow.dispatch.output.store %5, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %5, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -378,10 +378,10 @@ func @reshape_fused_source_and_copyout() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<12xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x4xi32> - %2 = hal.interface.binding.subspan @legacy_io::@ret1[%c0] : !flow.dispatch.output<3x4xi32> - %3 = flow.dispatch.input.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.input<12xi32> -> tensor<12xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret1[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor -> tensor<12xi32> %4 = linalg.tensor_reshape %3 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32> %5 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %6 = linalg.generic { @@ -392,8 +392,8 @@ func @reshape_fused_source_and_copyout() { %7 = addi %arg0, %arg0 : i32 linalg.yield %7 : i32 } -> tensor<3x4xi32> - flow.dispatch.output.store %6, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> - flow.dispatch.output.store %4, %2, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.output<3x4xi32> + flow.dispatch.tensor.store %6, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor + flow.dispatch.tensor.store %4, %2, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -425,9 +425,9 @@ func @reshape_fused_target() { %c3 = constant 3 : index %c4 = constant 4 : index %c12 = constant 12 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<3x4xi32> - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<12xi32> - %2 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : !flow.dispatch.input<3x4xi32> -> tensor<3x4xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan 
@legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x4xi32> %3 = linalg.init_tensor [3, 4] : tensor<3x4xi32> %4 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -438,7 +438,7 @@ func @reshape_fused_target() { linalg.yield %5 : i32 } -> tensor<3x4xi32> %5 = linalg.tensor_reshape %4 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<3x4xi32> into tensor<12xi32> - flow.dispatch.output.store %5, %1, offsets = [%c0], sizes = [%c12], strides = [%c1] : tensor<12xi32> -> !flow.dispatch.output<12xi32> + flow.dispatch.tensor.store %5, %1, offsets = [%c0], sizes = [%c12], strides = [%c1] : tensor<12xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -467,10 +467,10 @@ func @dot_general_lowering() { %c0 = constant 0 : index %c2 = constant 2 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<1x1x2xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<2x3xf32> - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<1x3xf32> - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input<1x1x2xf32> -> tensor<1x1x2xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor<1x1x2xf32> %4 = linalg.tensor_reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : tensor<1x1x2xf32> into tensor<1x2xf32> %workgroup_size_x = hal.interface.workgroup.size[0] : index %workgroup_size_y = hal.interface.workgroup.size[1] : index @@ -487,11 +487,11 @@ func @dot_general_lowering() { %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_y] %10 = subtensor %4[%arg0, 0] [%9, 2] [1, 1] : tensor<1x2xf32> to tensor %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_x] - %12 = flow.dispatch.input.load %1, offsets = [%c0, %arg1], sizes = [%c2, %11], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x?xf32> + %12 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c2, %11], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x?xf32> %13 = linalg.init_tensor [%9, %11] : tensor %14 = linalg.fill(%13, %cst) : tensor, f32 -> tensor %15 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%10, %12 : tensor, tensor<2x?xf32>) outs(%14 : tensor) -> tensor - flow.dispatch.output.store %15, %2, offsets = [%arg0, %arg1], sizes = [%9, %11], strides = [%c1, %c1] : tensor -> !flow.dispatch.output<1x3xf32> + flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%9, %11], strides = [%c1, %c1] : tensor -> !flow.dispatch.tensor } } return @@ -520,11 +520,11 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @slice_whole_stride_dispatch_0() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %2 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : 
!flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %2 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor %3 = subtensor %2[1, 0] [1, 4] [1, 1] : tensor to tensor<1x4xi32> - flow.dispatch.output.store %3, %1 : tensor<1x4xi32> -> !flow.dispatch.output + flow.dispatch.tensor.store %3, %1 : tensor<1x4xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -542,15 +542,15 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @subtensor_insert() { %c0 = constant 0 : index %c1 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor - %4 = flow.dispatch.input.load %1 : !flow.dispatch.input -> tensor + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor + %4 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor -> tensor %5 = dim %3, %c0 : tensor %6 = dim %3, %c1 : tensor %7 = subtensor_insert %3 into %4[3, 4] [%5, %6] [1, 1] : tensor into tensor - flow.dispatch.output.store %7, %2 : tensor -> !flow.dispatch.output + flow.dispatch.tensor.store %7, %2 : tensor -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { @@ -574,13 +574,13 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { func @tensor_extract() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input - %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<3x9xi32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %2 = linalg.init_tensor [3, 9] : tensor<3x9xi32> - %3 = flow.dispatch.input.load %0 : !flow.dispatch.input -> tensor + %3 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor %4 = tensor.extract %3[] : tensor - %5 = linalg.fill(%2, %4) : tensor<3x9xi32>, i32 -> tensor<3x9xi32> - flow.dispatch.output.store %5, %1 : tensor<3x9xi32> -> !flow.dispatch.output<3x9xi32> + %5 = linalg.fill(%2, %4) : tensor<3x9xi32>, i32 -> tensor<3x9xi32> + flow.dispatch.tensor.store %5, %1 : tensor<3x9xi32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { diff --git a/iree/compiler/Conversion/HLOToHLO/test/BUILD b/iree/compiler/Conversion/HLOToHLO/test/BUILD index 1e3b7bba341b..ab0ba6be34c1 100644 --- a/iree/compiler/Conversion/HLOToHLO/test/BUILD +++ b/iree/compiler/Conversion/HLOToHLO/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conv1x12dot.mlir", + "f32Tof16.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt index 420c06116ccb..50035da4b3e0 100644 --- a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt +++ b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conv1x12dot.mlir" + "f32Tof16.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/HLOToLinalg/test/BUILD b/iree/compiler/Conversion/HLOToLinalg/test/BUILD index 1e3b7bba341b..4b1bc253c185 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/BUILD +++ b/iree/compiler/Conversion/HLOToLinalg/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,28 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_ops.mlir", + "concatenate.mlir", + "decompose_hlo_clamp.mlir", + "depthwise_conv.mlir", + "dot.mlir", + "dynamic_shape.mlir", + "exp.mlir", + "fusion.mlir", + "linalg_tensor_to_buffer.mlir", + "pad_tensor.mlir", + "pad_tensor_to_tensor.mlir", + "pipeline_test.mlir", + "reduce.mlir", + "reduce_window.mlir", + "subtensor.mlir", + "subtensor_insert.mlir", + "torch_index_select.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt index 927972fd2126..2f24788e795a 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt +++ b/iree/compiler/Conversion/HLOToLinalg/test/CMakeLists.txt @@ -10,12 +10,27 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_ops.mlir" + "concatenate.mlir" + "decompose_hlo_clamp.mlir" + "depthwise_conv.mlir" + "dot.mlir" + "dynamic_shape.mlir" + "exp.mlir" + "fusion.mlir" + "linalg_tensor_to_buffer.mlir" + "pad_tensor.mlir" + "pad_tensor_to_tensor.mlir" + "pipeline_test.mlir" + "reduce.mlir" + "reduce_window.mlir" + "subtensor.mlir" + "subtensor_insert.mlir" + "torch_index_select.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LLVMToLLVM/BUILD b/iree/compiler/Conversion/LLVMToLLVM/BUILD deleted file mode 100644 index 165664750db2..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "LLVMToLLVM", - srcs = [ - "FastExpConversion.cpp", - ], - hdrs = [ - "Passes.h", - ], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt deleted file mode 100644 index 583a8fbb3f90..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Conversion/LLVMToLLVM/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_cc_library( - NAME - LLVMToLLVM - HDRS - "Passes.h" - SRCS - "FastExpConversion.cpp" - DEPS - LLVMSupport - MLIRIR - MLIRLLVMIR - MLIRPass - MLIRTransforms - PUBLIC -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp deleted file mode 100644 index df8451793afd..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace { - -// Fast polynomial approximation of exp(x) using its reduced range exp(y) -// where y is in the range [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) -// = x - k * ln(2), exp(x) = exp(y) * 2^k. 
exp(y) is computed with 4th degree -// polyomial: exp(y) = c0 + c1 * y + c2 * y^2 + c3 * y^3 + c4 * y^4 -struct FastExpConversionPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LLVM::ExpOp op, - PatternRewriter &rewriter) const override { - constexpr float ln2Const = 0.693147181f; - constexpr float ln2InvConst = 1.44269504f; - - // Least squares polynomial fit computed : - // cValues = np.polyfit(np.linspace(0, math.log(2), 10000), np.exp(x), 4) - constexpr float cValues[5] = {0.05924867f, 0.15514645f, 0.50308552f, - 0.99968939f, 1.00000721531f}; - auto loc = op.getLoc(); - Value x = op.getOperand(); - - auto floatType = Float32Type::get(rewriter.getContext()); - auto i32Type = IntegerType::get(rewriter.getContext(), 32); - - Value ln2 = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(ln2Const)); - Value ln2Inv = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(ln2InvConst)); - - // Compute reduced range input y = x - floor(x / ln(2)) * ln(2) - Value xL2Inv = rewriter.create(loc, floatType, x, ln2Inv); - Value kF32 = rewriter.create(loc, floatType, xL2Inv); - Value kLn2 = rewriter.create(loc, floatType, kF32, ln2); - Value y = rewriter.create(loc, floatType, x, kLn2); - - SmallVector PConst(5); - for (int i = 0; i < 5; ++i) { - PConst[i] = rewriter.create( - loc, floatType, rewriter.getF32FloatAttr(cValues[i])); - } - // Evaluate exp(y) = sum(c[i] * y**i, i) - Value expY = rewriter.create(loc, floatType, y, PConst[0]); - expY = rewriter.create(loc, floatType, expY, PConst[1]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[2]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[3]); - expY = rewriter.create(loc, floatType, expY, y); - expY = rewriter.create(loc, floatType, expY, PConst[4]); - - // Compute exp2(k) with integer bitshift: - // exp2(k) = f32_bitcast((127 + k) << 23) - Value fPBias = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(127)); - Value k = rewriter.create(loc, i32Type, kF32); - Value kPlusfPBias = rewriter.create(loc, i32Type, k, fPBias); - Value shiftConst = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(23)); - Value twoPowkI = - rewriter.create(loc, i32Type, kPlusfPBias, shiftConst); - Value twoPowk = rewriter.create(loc, floatType, twoPowkI); - expY = rewriter.create(loc, floatType, expY, twoPowk); - rewriter.replaceOp(op, {expY}); - // TODO(ataei): Handle overflow and underflow cases (e.g |k| > 128). 
- return success(); - } -}; - -struct FastExpConversionPass - : public PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override; -}; - -} // namespace - -void populateFastExpConversionPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { - patterns.insert(context); -} - -void FastExpConversionPass::runOnOperation() { - auto moduleOp = getOperation(); - auto context = moduleOp.getContext(); - OwningRewritePatternList patterns; - populateFastExpConversionPatterns(patterns, context); - (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); -} - -std::unique_ptr> -createFastExpApproximationConversionPass() { - return std::make_unique(); -} - -static PassRegistration> pass( - "iree-codegen-linalg-to-llvm-fast-exp-conversion-pass", - "Convert llvm.intr.exp into its fast polynomial approximation version", - [] { return std::make_unique(); }); - -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Conversion/LLVMToLLVM/Passes.h b/iree/compiler/Conversion/LLVMToLLVM/Passes.h deleted file mode 100644 index e40fe90842b8..000000000000 --- a/iree/compiler/Conversion/LLVMToLLVM/Passes.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_ -#define IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace iree_compiler { - -// Creates a pass to rewrite llvm.intr.exp using its reduced range polynomial -// approximation. 
-std::unique_ptr<OperationPass<ModuleOp>>
-createFastExpApproximationConversionPass();
-
-}  // namespace iree_compiler
-}  // namespace mlir
-#endif  // IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index ce8cd2b22d31..cbbeb4588e18 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -40,7 +40,6 @@ cc_library(
         "//iree/compiler/Conversion/Common",
         "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/HLOToLinalg",
-        "//iree/compiler/Conversion/LLVMToLLVM",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/IR:HALDialect",
         "//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 32bb64d19482..cc3cccd89083 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -51,7 +51,6 @@ iree_cc_library(
     iree::compiler::Conversion::Common
     iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::HLOToLinalg
-    iree::compiler::Conversion::LLVMToLLVM
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::IR::HALDialect
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 140a7e320547..37e1bc0a695d 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -642,6 +642,13 @@ void ConvertToLLVMPass::runOnOperation() {
         std::move(vectorToLoopsPatterns));
   }
 
+  // Math dialect elementary functions -> polynomial form.
+  {
+    OwningRewritePatternList mathPatterns;
+    populateMathPolynomialApproximationPatterns(mathPatterns, &getContext());
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(mathPatterns));
+  }
+
   auto module = getOperation();
   LLVMTypeConverter converter(&getContext());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index c080d5190ddf..2fa6794b5dbc 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -17,7 +17,6 @@
 #include "iree/compiler/Conversion/Common/Attributes.h"
 #include "iree/compiler/Conversion/Common/Passes.h"
 #include "iree/compiler/Conversion/HLOToHLO/Passes.h"
-#include "iree/compiler/Conversion/LLVMToLLVM/Passes.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -40,12 +39,6 @@ static llvm::cl::opt<bool> convImg2ColConversion(
         "linag.matmul"),
     llvm::cl::init(false));
 
-static llvm::cl::opt<bool> fastExpConversion(
-    "iree-codegen-linalg-to-llvm-fast-exp",
-    llvm::cl::desc("If true convert llvm.intr.exp into its range reduced "
-                   "polynomial approximation."),
-    llvm::cl::init(false));
-
 void addLinalgToLLVMPasses(OpPassManager &passManager) {
   // Distribute linalg op among a 3d grid of parallel threads. Tile each
   // workgroup thread memory then vectorize the linalg op.
@@ -85,11 +78,6 @@ void addLinalgToLLVMPasses(OpPassManager &passManager) {
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
-
-  // Approximate llvm.intr.exp with a 4-th order ploynmial in range[0, ln2].
- if (fastExpConversion) { - nestedModulePM.addPass(createFastExpApproximationConversionPass()); - } } void buildLLVMTransformPassPipeline(OpPassManager &passManager) { diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD index 1e3b7bba341b..7df28f0a92f5 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD +++ b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,20 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conv_img2col.mlir", + "hal_interface_bindings.mlir", + "hal_interface_constants.mlir", + "hal_interface_workgroup_info.mlir", + "linalg_vectorize.mlir", + "materialize_launch_configuration.mlir", + "matmul_vectorization.mlir", + "plan_conv_loop_order.mlir", + "tile_and_distribute.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt index 725edceb656d..661c4ab205c5 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt @@ -10,12 +10,19 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conv_img2col.mlir" + "hal_interface_bindings.mlir" + "hal_interface_constants.mlir" + "hal_interface_workgroup_info.mlir" + "linalg_vectorize.mlir" + "materialize_launch_configuration.mlir" + "matmul_vectorization.mlir" + "plan_conv_loop_order.mlir" + "tile_and_distribute.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir index 0cccf12665ad..8512b0c5de88 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir @@ -4,9 +4,9 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index -// CHECK: %[[I0:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> -// CHECK: %[[I1:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> -// CHECK: %[[I2:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> +// CHECK: %[[I0:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<2x3xf32> +// CHECK: %[[I1:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<3x1xf32> +// CHECK: %[[I2:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor -> tensor<2x1xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C0]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C1]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> // CHECK: %[[V2:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C2]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32> @@ -26,21 +26,21 @@ // CHECK: %[[D5:.*]] = vector.contract {{.*}} %[[V5]], %[[V8]], %[[D4]] : vector<1x1xf32>, 
vector<1x1xf32> into vector<1x1xf32> // CHECK: %[[W0:.*]] = vector.transfer_write %[[D2]], %[[I2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> // CHECK: %[[W1:.*]] = vector.transfer_write %[[D5]], %[[W0]][%[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32> -// CHECK: flow.dispatch.output.store %[[W1]] +// CHECK: flow.dispatch.tensor.store %[[W1]] func @tensor_dispatch_0() { %c0 = constant 0 : index %c3 = constant 3 : index %c1 = constant 1 : index %c2 = constant 1 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<2x3xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<3x4xf32> - %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input<2x4xf32> - %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<2x4xf32> - %4 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32> - %5 = flow.dispatch.input.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32> - %6 = flow.dispatch.input.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor + %3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor + %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x3xf32> + %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<3x1xf32> + %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor -> tensor<2x1xf32> %7 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%4, %5 : tensor<2x3xf32>, tensor<3x1xf32>) outs(%6 : tensor<2x1xf32>) -> tensor<2x1xf32> - flow.dispatch.output.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.output<2x4xf32> + flow.dispatch.tensor.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor return } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir index e2759c52151e..5c6018f3ef22 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir @@ -9,8 +9,8 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_tensors attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_tensors() { %c0 = constant 0 : index @@ -97,8 +97,8 @@ hal.executable @add attributes {sym_visibility = "private"} { 
hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @add attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @add() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir index 20979c8e0b86..5260df30969c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir @@ -10,8 +10,8 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_128x128x128 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<128x128xf32>, !flow.dispatch.input<128x128xf32>, - !flow.dispatch.output<128x128xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) { linalg.matmul ins(%arg0, %arg1 : memref<128x128xf32>, memref<128x128xf32>) outs(%arg2 : memref<128x128xf32>) @@ -91,8 +91,8 @@ hal.executable @dynamic_matmul_i8_i8_i32 attributes {sym_visibility = "private"} hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_i8_i8_i32_128x128x128 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<128x128xi8>, !flow.dispatch.input<128x128xi8>, - !flow.dispatch.output<128x128xi32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @matmul_i8_i8_i32_128x128x128(%arg0 : memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { linalg.matmul_i8_i8_i32 ins(%arg0, %arg1 : memref<128x128xi8>, memref<128x128xi8>) outs(%arg2 : memref<128x128xi32>) diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir index 4868d059698f..ed84b770f19a 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir @@ -10,8 +10,8 @@ // hal.executable.target @llvm_aot, filter="dylib*" { // hal.executable.entry_point @dynamic_matmul attributes { // interface = @legacy_io, ordinal = 0 : i32, -// signature = (!flow.dispatch.input, !flow.dispatch.input, -// !flow.dispatch.output) -> ()} +// signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, +// !flow.dispatch.tensor) -> ()} // module { // func @dynamic_matmul(%lhs: memref, %rhs: memref, %result: memref) { // linalg.matmul ins(%lhs, %rhs : memref, memref) outs(%result : memref) @@ -58,8 +58,8 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @static_matmul attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<16x4xf32>, !flow.dispatch.input<4x8xf32>, - !flow.dispatch.output<16x8xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @static_matmul(%lhs: memref<16x4xf32>, %rhs: memref<4x8xf32>, %result: 
memref<16x8xf32>) {
     linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
@@ -84,4 +84,5 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} {
 // CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
 // CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
 // CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>)
+// CHECK: linalg.matmul
+//CHECK-SAME: ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>)
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
index 675ed9130dc0..a8cd50f554f3 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
@@ -15,6 +15,7 @@
 # Tests for common transforms.
 
 load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -24,7 +25,13 @@ package(
 
 iree_lit_test_suite(
     name = "lit",
-    srcs = glob(["*.mlir"]),
+    srcs = enforce_glob(
+        [
+            "convert_to_nvvm.mlir",
+            "pipeline_test.mlir",
+        ],
+        include = ["*.mlir"],
+    ),
     data = [
         "//iree/tools:IreeFileCheck",
         "//iree/tools:iree-opt",
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
index 3779e3cabe7d..2d850ad956ec 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
@@ -10,12 +10,12 @@
 
 iree_add_all_subdirs()
 
-file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
 iree_lit_test_suite(
   NAME
     lit
   SRCS
-    "${_GLOB_X_MLIR}"
+    "convert_to_nvvm.mlir"
+    "pipeline_test.mlir"
   DATA
     iree::tools::IreeFileCheck
    iree::tools::iree-opt
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
index 170dee689660..9d6a7fd3a645 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
@@ -1,6 +1,6 @@
 // RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" %s | IreeFileCheck %s
 
-// Verify that a simple element wise op gets lowered succefully all the way to 
+// Verify that a simple elementwise op gets lowered successfully all the way to
 // nvvm/llvm dialect.
hal.executable @simpleMath_ex_dispatch_0 { @@ -9,22 +9,22 @@ hal.executable @simpleMath_ex_dispatch_0 { hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" } hal.executable.target @cuda, filter="cuda" { - hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<16xf32>, !flow.dispatch.input<16xf32>, !flow.dispatch.output<16xf32>) -> ()} + hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @add_dispatch_0() { %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<16xf32> - %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<16xf32> - %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<16xf32> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor %3 = linalg.init_tensor [16] : tensor<16xf32> - %4 = flow.dispatch.input.load %0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %5 = flow.dispatch.input.load %1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> + %4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor -> tensor<16xf32> + %5 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor -> tensor<16xf32> %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors %7 = addf %arg0, %arg1 : f32 linalg.yield %7 : f32 } -> tensor<16xf32> - flow.dispatch.output.store %6, %2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> + flow.dispatch.tensor.store %6, %2 : tensor<16xf32> -> !flow.dispatch.tensor return } hal.interface @legacy_io attributes {sym_visibility = "private"} { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD index b224b665d763..2aab244cd670 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD @@ -34,11 +34,13 @@ cc_library( cc_library( name = "LinalgToSPIRV", srcs = [ + "ConcretizeTileAmongWorkgroupsPass.cpp", "ConvertToGPUPass.cpp", "ConvertToSPIRVPass.cpp", "CooperativeMatrixAnalysis.cpp", "FoldGPUProcessorIDUses.cpp", "KernelDispatchUtils.cpp", + "LinalgTileAndDistributePass.cpp", "LinalgTileAndFusePass.cpp", "MatMulVectorizationTest.cpp", "Passes.cpp", @@ -61,6 +63,7 @@ cc_library( "//iree/compiler/Conversion/HLOToHLO", "//iree/compiler/Conversion/HLOToLinalg", "//iree/compiler/Conversion/LinalgToVector", + "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/HAL/IR", "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/IREE/IR", diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt index e51da222e2eb..7046ccca4ae4 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt @@ -32,11 +32,13 @@ iree_cc_library( "Passes.h" "Utils.h" SRCS + "ConcretizeTileAmongWorkgroupsPass.cpp" "ConvertToGPUPass.cpp" "ConvertToSPIRVPass.cpp" 
"CooperativeMatrixAnalysis.cpp" "FoldGPUProcessorIDUses.cpp" "KernelDispatchUtils.cpp" + "LinalgTileAndDistributePass.cpp" "LinalgTileAndFusePass.cpp" "MatMulVectorizationTest.cpp" "Passes.cpp" @@ -74,6 +76,7 @@ iree_cc_library( iree::compiler::Conversion::HLOToHLO iree::compiler::Conversion::HLOToLinalg iree::compiler::Conversion::LinalgToVector + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::IREE::IR diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp new file mode 100644 index 000000000000..5a68092a010f --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -0,0 +1,568 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- ConcretizeTileAmongWorkgroupsPass.cpp ------------------------------===// +// +// This pass concretizes hal.interface.workgroup ops by replacing them with +// constant values from the chosen tiling and distribution scheme. +// +// During dispatch region formation in IREE Flow transformations, ops are tiled +// and distributed in an abstract way by using symbolic hal.interface.workgroup +// ops. That is because the same source region is compiled towards different +// target backends and each target backend could use different tiling and +// distribution schemes. However, after HAL interface materialization, the +// hal.executable.target is just meant for one target backend. We need to +// concretize the tiling and distribution in order to inject static information +// for further compilation. +// +// This pass performs the conretization in two modes: +// +// 1) Partically static: where have a concrete tiling and distirbution sheme +// *but not* a full static original problem size (e.g., due to dynamic +// shapes). Under such circumstances, we can only replace ops like +// hal.interface.workgroup.size ops and still need to compute the number +// of workgroups using symbolic values. +// 2) Fully static: where we have a concrete tiling and distribution scheme +// *and* the full static original problem size. Under such circumstances, +// we can fully deduce the number of workgroups to dispatch and replace +// hal.interface.workgroup.count ops with constant values too. 
+// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Conversion/Common/LaunchConfig.h" +#include "iree/compiler/Conversion/Common/Transforms.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-spirv-concretize-tile-among-workgroups" + +namespace mlir { +namespace iree_compiler { + +namespace { + +constexpr unsigned kWorkgroupDimCount = 3; + +int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } + +static size_t getNumOuterParallelDims(linalg::LinalgOp op) { + ArrayRef iterators = op.iterator_types().getValue(); + auto parallels = iterators.take_while( + [](Attribute attr) { return linalg::isParallelIteratorType(attr); }); + return parallels.size(); +} + +/// Returns the root Linalg op that dictates tiling and distribution policy. +linalg::LinalgOp getRootLinalgOp(FuncOp funcOp) { + SmallVector linalgOps; + SmallVector tiledLoops; + if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) return {}; + + SPIRVCodegenOptions options; + options.enableVectorization = true; + options.usingLinalgOnTensors = true; + linalg::Aliases aliases; + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + Optional launchConfigOpt = initGPULaunchConfig( + funcOp.getContext(), dependenceGraph, options, linalgOps); + if (!launchConfigOpt) return {}; + + LaunchConfig &launchConfig = *launchConfigOpt; + Operation *rootOp = + launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range( + linalgOps, [](linalg::LinalgOp op) { return op.getOperation(); }))); + + // Clean up internal markers that are set during launch configuration + // preparation. + launchConfig.finalize(funcOp); + + return rootOp; +} + +/// Assuming the given `rootOp` is the tiled root Linalg op, returns the +/// original input/output types for all tiles. +/// +/// Note: After the abstract tiling and distribution in Flow dispatch region +/// creation, the anchoring root op is already in a loop nest and works on a +/// tile. The full type for all tiles in the IR is not explicit anymore after +/// HAL interface is materialized. So go through the IR use chain to figure it +/// out. Otherwise we need to make even more assumptions in the following. +// TODO(antiagainst): This is quite fragile. We need a better way to pass the +// information down from the upper layer, which readily has it. Probably via +// linalg.tile op. 
+LogicalResult getInputOutputTypesForAllTiles(
+    linalg::LinalgOp rootOp, SmallVectorImpl<Type> &inputTypes,
+    SmallVectorImpl<Type> &outputTypes) {
+  for (Value inputBuffer : rootOp.getInputBuffers()) {
+    auto subviewOp = inputBuffer.getDefiningOp<SubViewOp>();
+    if (!subviewOp) return failure();
+    inputTypes.push_back(subviewOp.getViewSource().getType());
+  }
+
+  for (Value outputBuffer : rootOp.getOutputBuffers()) {
+    auto subviewOp = outputBuffer.getDefiningOp<SubViewOp>();
+    if (!subviewOp) return failure();
+    outputTypes.push_back(subviewOp.getViewSource().getType());
+  }
+
+  return success();
+}
+
+/// Assuming the given `rootOp` is the tiled root Linalg op, returns the
+/// tile sizes for distributing to workgroups and the workgroup size for the
+/// generated kernel.
+///
+/// TODO(antiagainst): This pass can be shared between CPU and GPU. But the
+/// following query scopes it to GPU for now.
+llvm::Optional<
+    std::pair<llvm::SmallVector<int64_t, 4>, llvm::SmallVector<int64_t, 4>>>
+getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef<Type> inputTypes,
+                            ArrayRef<Type> outputTypes) {
+  // Build necessary structures to query the tile sizes for distributing to
+  // workgroups.
+  linalg::Aliases aliases;
+  SmallVector<linalg::LinalgOp, 4> linalgOps;
+  auto ops = rootOp->getBlock()->getOps<linalg::LinalgOp>();
+  linalgOps.assign(ops.begin(), ops.end());
+  linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+  SPIRVCodegenOptions options;
+
+  // NOTE: Launch configuration expects the original input/output type to
+  // decide the configuration. But we have already tiled the Linalg ops here.
+  // Use an attribute to send it over for now.
+  const char inputTypeAttrName[] = "iree.codegen.original_input_types";
+  const char outputTypeAttrName[] = "iree.codegen.original_output_types";
+  if (!inputTypes.empty()) {
+    rootOp->setAttr(inputTypeAttrName,
+                    Builder(rootOp).getTypeArrayAttr(inputTypes));
+  }
+  if (!outputTypes.empty()) {
+    rootOp->setAttr(outputTypeAttrName,
+                    Builder(rootOp).getTypeArrayAttr(outputTypes));
+  }
+
+  Optional<LaunchConfig> launchConfig = initGPULaunchConfig(
+      rootOp->getContext(), dependenceGraph, options, linalgOps);
+  if (!launchConfig) {
+    rootOp->emitError("unable to find launch configuration");
+    return llvm::None;
+  }
+
+  ArrayRef<int64_t> tileSize = launchConfig->getTileSizes(rootOp, 0);
+  ArrayRef<int64_t> workgroupSize = launchConfig->getWorkgroupSize();
+
+  // Clean up internal markers that are set during launch configuration
+  // preparation.
+  launchConfig->finalize(rootOp->getParentOfType<FuncOp>());
+
+  return std::make_pair(llvm::to_vector<4>(tileSize),
+                        llvm::to_vector<4>(workgroupSize));
+}
+
+/// Replaces hal.interface.workgroup.size op with the constant value chosen
+/// from the tiling scheme.
+class ConcretizeWorkgroupSizeOp final
+    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
+ public:
+  ConcretizeWorkgroupSizeOp(MLIRContext *context,
+                            SmallVector<int64_t, 4> workloadSize,
+                            SmallVector<int64_t, 4> tileSize,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(std::move(workloadSize)),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    unsigned dimIndex = op.dimension().getZExtValue();
+
+    if (dimIndex < kWorkgroupDimCount && tileSize[dimIndex] != 0) {
+      rewriter.replaceOpWithNewOp<ConstantOp>(
+          op, rewriter.getIndexAttr(tileSize[dimIndex]));
+      return success();
+    }
+
+    return failure();
+  }
+
+ private:
+  SmallVector<int64_t, 4> workloadSize;
+  SmallVector<int64_t, 4> tileSize;
+};
+
+/// Replaces hal.interface.workgroup.count op with the constant value chosen
+/// from the tiling scheme.
+class ConcretizeWorkgroupCountOp final
+    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupCountOp> {
+ public:
+  ConcretizeWorkgroupCountOp(MLIRContext *context,
+                             SmallVector<int64_t, 4> workloadSize,
+                             SmallVector<int64_t, 4> tileSize,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(workloadSize),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op,
+                                PatternRewriter &rewriter) const override {
+    unsigned dimIndex = op.dimension().getZExtValue();
+
+    if (dimIndex >= kWorkgroupDimCount) return failure();
+
+    int64_t dimSize = workloadSize[dimIndex];
+    int64_t dimTile = tileSize[dimIndex];
+
+    if (dimSize == ShapedType::kDynamicSize || dimTile == 0) return failure();
+
+    int64_t count = ceilDiv(dimSize, dimTile);
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getIndexAttr(count));
+
+    return success();
+  }
+
+ private:
+  SmallVector<int64_t, 4> workloadSize;
+  SmallVector<int64_t, 4> tileSize;
+};
+
+// Canonicalizes away a trip-one scf.for loop by inlining its body and removing
+// the loop.
+//
+// This pattern is needed because in Flow abstract tiling and distribution we
+// will create scf.for loops that distribute workload cyclically. After
+// concretizing hal.interface.workgroup.* ops, these scf.for loops still
+// remain, and they will be of the form:
+//
+//   %lb = mul %workgroup_id_{x|y|z}, %cst_tile_size_{x|y|z}
+//   scf.for %iv = %lb to %cst_workload_size_{x|y|z}
+//       step %cst_workload_size_{x|y|z} { ... }
+//
+// Such scf.for loops can be inlined if %lb is smaller than the upper bound.
+class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
+ public:
+  RemoveTripOneLoop(MLIRContext *context, SmallVector<int64_t, 4> workloadSize,
+                    SmallVector<int64_t, 4> tileSize,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(workloadSize),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(scf::ForOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get constant upper bound and step values.
+    IntegerAttr ub, step;
+    if (!matchPattern(op.upperBound(), m_Constant(&ub)) ||
+        !matchPattern(op.step(), m_Constant(&step))) {
+      return failure();
+    }
+
+    // Require that they are the same.
+    if (ub != step) return failure();
+
+    // Now make sure the lower bound is smaller than the upper bound. The lower
+    // bound should be the workgroup ID multiplied by some constant.
+
+    auto mulOp = op.lowerBound().getDefiningOp<AffineApplyOp>();
+    if (!mulOp || mulOp.mapOperands().size() != 2) return failure();
+
+    AffineExpr lhs, rhs;
+    bindSymbols(op.getContext(), lhs, rhs);
+    auto mulMap = AffineMap::get(0, 2, lhs * rhs);
+    if (mulOp.getAffineMap() != mulMap) return failure();
+
+    auto mulLhs = mulOp.mapOperands().front();
+    auto mulRhs = mulOp.mapOperands().back();
+
+    auto idOp = mulLhs.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>();
+    IntegerAttr multiplier;
+    if (!idOp || !matchPattern(mulRhs, m_Constant(&multiplier)))
+      return failure();
+
+    // We just need to make sure the max value of the workgroup ID multiplied
+    // by the multiplier is smaller than the upper bound to guarantee one trip.
+    unsigned dimIndex = idOp.dimension().getZExtValue();
+    int64_t dimSize = workloadSize[dimIndex];
+    int64_t dimTile = tileSize[dimIndex];
+
+    if (dimSize == ShapedType::kDynamicSize) return failure();
+
+    int64_t count = ceilDiv(dimSize, dimTile);
+    assert(count > 0 && "expected at least one tile!");
+
+    // ID should be in range [0, count).
+    if ((count - 1) * multiplier.getInt() >= ub.getInt()) {
+      // Dead loop. It can actually be removed entirely. But we aren't expecting
+      // it to happen here.
Do not canonicalize for such case. + return failure(); + } + + SmallVector blockArgs; + blockArgs.reserve(op.getNumIterOperands() + 1); + blockArgs.push_back(op.lowerBound()); + llvm::append_range(blockArgs, op.getIterOperands()); + + Block *block = &op.getLoopBody().front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); + + return success(); + } + + private: + SmallVector workloadSize; + SmallVector tileSize; +}; + +/// Concretizes hal.interface.workgroup.* ops with constants from the chosen +/// tiling sheme when possible and perform loop canonicalization afterwards. +class ConcretizeTileAmongWorkgroupsPass + : public PassWrapper> { + public: + ConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options) + : options(options) {} + ConcretizeTileAmongWorkgroupsPass( + const ConcretizeTileAmongWorkgroupsPass &that) + : options(that.options) { + inlineTripOneLoops = that.inlineTripOneLoops; + } + + void runOnOperation() override { + IREE::HAL::ExecutableTargetOp targetOp = getOperation(); + ModuleOp module = targetOp.getInnerModule(); + + for (FuncOp funcOp : module.getOps()) { + if (!funcOp.isPublic()) continue; + if (failed(runOnFunction(funcOp))) return signalPassFailure(); + } + } + + private: + LogicalResult runOnFunction(FuncOp funcOp) { + MLIRContext &context = getContext(); + + // 1. Get the root op first. We need it to figure out the original problem + // size, which then affects the tiling and distribution policy. + + linalg::LinalgOp rootOp = getRootLinalgOp(funcOp); + if (!rootOp) { + LLVM_DEBUG(llvm::dbgs() << "unable to find root Linalg op\n"); + // It can happen for ops that are not abstractly tiled during dispatch + // region formation. So don't trigger pass failure. + return success(); + } + LLVM_DEBUG(llvm::dbgs() << "Root op: " << rootOp << "\n"); + + size_t numTilableDims = getNumOuterParallelDims(rootOp); + + // 2. Figure out the original problem size. + + SmallVector inputTypes, outputTypes; + SmallVector workloadSize; + if (succeeded( + getInputOutputTypesForAllTiles(rootOp, inputTypes, outputTypes))) { + if (outputTypes.size() != 1) { + return rootOp.emitError("only support ops with one result right now"); + } + + // Flow/HAL processor id/size/count ops' indices follow the reverse order + // of the shape dimensions. + workloadSize = llvm::to_vector<4>(llvm::reverse( + outputTypes.front().cast().getShape().take_front( + numTilableDims))); + } else { + // This can happen for dynamic shapes. + LLVM_DEBUG(llvm::dbgs() + << "unable to find input/output type for all tiles"); + + inputTypes.clear(); + outputTypes.clear(); + + workloadSize.assign(numTilableDims, ShapedType::kDynamicSize); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Queried workload size: "; + llvm::interleaveComma(workloadSize, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + // 3. Query the scheme for tiling among workgroups. + + SmallVector tileSize; + SmallVector workgroupSize; + + // Try to use configuration from the command-line first for testing. 
+ tileSize.assign(options.tileSizes.begin(), options.tileSizes.end()); + tileSize.resize(numTilableDims, 0); + workgroupSize.assign(options.workgroupSize.begin(), + options.workgroupSize.end()); + if (tileSize.empty() || workgroupSize.empty()) { + auto sizes = getTileSizeAndWorkgroupSize(rootOp, inputTypes, outputTypes); + if (sizes) { + // The tile sizes are specified against the original dimension order of + // the workload shape. But Flow/HAL processor id/size/count ops' are + // created using the reverse order. + tileSize = sizes->first; + tileSize.resize(numTilableDims); + tileSize = llvm::to_vector<4>(llvm::reverse(tileSize)); + workgroupSize = sizes->second; + } else { + return funcOp.emitError("failed to query tile size and workgroup size"); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "Queried tile size: "; + llvm::interleaveComma(tileSize, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + // 4. Replace hal.interface.workgroup symbolic ops with constant values. + + { + OwningRewritePatternList patterns; + patterns.insert( + &context, workloadSize, tileSize); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + LLVM_DEBUG({ + llvm::dbgs() + << "--- After concretizing hal.interface.workgroup ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + + // 5. Set the entry point region for computing the number of workgroups + // to dispatch. The region has symbolic arguments representing the workload. + // So two modes here (see comments at the begining of this file). + + { + SmallVector numWorkgroups; + for (auto pair : llvm::zip(workloadSize, tileSize)) { + auto workload = std::get<0>(pair); + auto tile = std::get<1>(pair); + if (workload == ShapedType::kDynamicSize || tile == 0) { + numWorkgroups.push_back(ShapedType::kDynamicSize); + } else { + numWorkgroups.push_back(ceilDiv(workload, tile)); + } + } + + numWorkgroups.resize(kWorkgroupDimCount, 1); + + // If all dimensions are known constant, then we can set the number of + // workgroups directly. Otherwise, we need to generate the IR for + // computing it using symbolic values. + if (llvm::none_of(numWorkgroups, [](int64_t dim) { + return dim == ShapedType::kDynamicSize; + })) { + OpBuilder builder(&context); + WorkgroupCountRegionBuilder regionBuilder = + [&](OpBuilder &builder, Location loc, std::array) { + std::array returnValues; + for (unsigned i = 0; i < kWorkgroupDimCount; ++i) { + returnValues[i] = + builder.create(loc, numWorkgroups[i]); + } + return returnValues; + }; + if (failed( + defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) { + return funcOp.emitError( + "failed to set entry point region for number of workgroups"); + } + } else { + if (failed(materializeStaticLaunchInformation(funcOp, tileSize))) { + return funcOp.emitOpError( + "failed to materialize static launch information"); + } + } + } + + if (failed(updateWorkGroupSize(funcOp, workgroupSize))) { + return funcOp.emitOpError("failed to set workgroup size on function"); + } + + // 6. Canonicalization and clean up. + + if (inlineTripOneLoops) { + OwningRewritePatternList patterns; + patterns.insert(&context, workloadSize, tileSize); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + return success(); + } + + private: + SPIRVCodegenOptions options; + + // TODO(#5034): Investigate whether there is a better way to prove tileability + // and canonicalize affine.min ops, without matching against the specific + // pattern involving loops. 
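The trip-one reasoning that `RemoveTripOneLoop` relies on (and that the TODO above would like to prove more directly) can be stated in scalar form. A small sketch, with all names illustrative:

```cpp
#include <cassert>
#include <cstdint>

// Scalar form of the check RemoveTripOneLoop performs: the loop
//   scf.for %iv = %id * tile to ub step step
// runs exactly once when ub == step and even the largest lower bound
// stays below ub.
bool hasSingleTrip(int64_t workload, int64_t tile, int64_t ub, int64_t step) {
  if (ub != step) return false;  // one step must cover the whole range
  int64_t count = (workload + tile - 1) / tile;  // number of workgroup IDs
  assert(count > 0 && "expected at least one tile");
  // IDs are in [0, count), so the largest lower bound is (count - 1) * tile;
  // if it reaches ub, the loop would instead be dead for the last workgroups.
  return (count - 1) * tile < ub;
}
```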
+ Option inlineTripOneLoops{ + *this, "inline-trip-one-loops", + llvm::cl::desc( + "Inline a loop's body if it can be proven to just have one trip"), + llvm::cl::init(true)}; +}; + +} // namespace + +std::unique_ptr> +createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options) { + return std::make_unique(options); +} + +static PassRegistration pass( + "iree-spirv-concretize-tile-among-workgroups", + "Replace hal.interface.workgroup.* ops with constant values from chosen " + "tiling and distribution scheme", + [] { + SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions(); + return std::make_unique(options); + }); + +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp index d43044f52e40..323832d9e707 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp @@ -20,6 +20,7 @@ #include "iree/compiler/Conversion/Common/Attributes.h" #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -89,26 +90,34 @@ int dimensionToIndex(StringRef dimension) { return StringSwitch(dimension).Case("x", 0).Case("y", 1).Case("z", 2); } -/// Gets the block processor ID's upper bound. This queries the workgroup count -/// function. -Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { - auto funcOp = blockIDOp->getParentOfType(); +IREE::HAL::ReturnOp getEntryPointReturnOp(Operation *op) { + auto funcOp = op->getParentOfType(); auto targetOp = funcOp.getOperation()->getParentOfType(); - IREE::HAL::ExecutableEntryPointOp entryPointOp = nullptr; + + IREE::HAL::ExecutableEntryPointOp entryPointOp; for (auto op : targetOp.getOps()) { if (op.sym_name() == funcOp.getName()) { entryPointOp = op; break; } } - if (!entryPointOp) return llvm::None; + if (!entryPointOp) return {}; Operation *terminator = entryPointOp.getBlock()->getTerminator(); auto retOp = dyn_cast(terminator); - if (!retOp || retOp.getNumOperands() != 3) return llvm::None; + if (!retOp || retOp.getNumOperands() != 3) return {}; + LLVM_DEBUG(llvm::dbgs() << "workgroup count function return op: " << retOp << "\n"); + return retOp; +} + +/// Gets the block processor ID's upper bound. This queries the workgroup count +/// function. +Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { + auto retOp = getEntryPointReturnOp(blockIDOp); + if (!retOp) return llvm::None; int index = dimensionToIndex(blockIDOp.dimension()); IntegerAttr attr; @@ -118,6 +127,19 @@ Optional getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) { return attr.getInt(); } +Optional getProcessorIDUpperBound( + IREE::HAL::InterfaceWorkgroupIDOp blockIDOp) { + auto retOp = getEntryPointReturnOp(blockIDOp); + if (!retOp) return llvm::None; + + int index = blockIDOp.dimensionAttr().getInt(); + IntegerAttr attr; + if (!matchPattern(retOp.getOperand(index), m_Constant(&attr))) + return llvm::None; + + return attr.getInt(); +} + /// Gets the thread processor ID's upper bound. This queries the SPIR-V entry /// point ABI. 
Optional getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) { @@ -165,6 +187,9 @@ struct FoldAffineMinOverProcessorID : OpRewritePattern { Optional ub; if (auto blockIDOp = dyn_cast(symbolOp)) { ub = getProcessorIDUpperBound(blockIDOp); + } else if (auto blockIDOp = + dyn_cast(symbolOp)) { + ub = getProcessorIDUpperBound(blockIDOp); } else if (auto threadIDOp = dyn_cast(symbolOp)) { ub = getProcessorIDUpperBound(threadIDOp); } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp index 0ef687798f0e..097a0f8f1983 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp @@ -121,8 +121,23 @@ static LogicalResult getMaliSpecificConfig( std::array &numSubgroups) { if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure(); - auto lhsType = op.inputs()[0].getType().cast(); - auto rhsType = op.inputs()[1].getType().cast(); + ShapedType lhsType, rhsType; + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. + if (auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types")) { + lhsType = inputTypeAttr.getValue()[0] + .cast() + .getValue() + .cast(); + rhsType = inputTypeAttr.getValue()[1] + .cast() + .getValue() + .cast(); + } else { + lhsType = op.inputs()[0].getType().cast(); + rhsType = op.inputs()[1].getType().cast(); + } assert(lhsType.getElementType() == rhsType.getElementType()); if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure(); // Get a vector of best tile size ordered from best to worst. @@ -292,8 +307,21 @@ static LogicalResult getTargetSpecificConfig( std::array &numSubgroups) { if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure(); - auto lhsType = op.inputs()[0].getType().cast(); - auto rhsType = op.inputs()[1].getType().cast(); + ShapedType lhsType, rhsType; + if (auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types")) { + lhsType = inputTypeAttr.getValue()[0] + .cast() + .getValue() + .cast(); + rhsType = inputTypeAttr.getValue()[1] + .cast() + .getValue() + .cast(); + } else { + lhsType = op.inputs()[0].getType().cast(); + rhsType = op.inputs()[1].getType().cast(); + } assert(lhsType.getElementType() == rhsType.getElementType()); // If the shape size is unknonw fall back to none vectorized path. if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure(); @@ -375,14 +403,34 @@ template static LogicalResult getMaliSpecificConfig(ConvOpTy op, TileSizesListType &tileSizes, LaunchConfigInfo &config) { - auto inputType = op.getInput(1).getType().template cast(); - auto outputType = op.getOutputBufferTypes()[0].template cast(); + Operation *operation = op.getOperation(); + if (!isa(operation)) return failure(); + + ShapedType inputType, outputType; + + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. 
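The "special treatment" in these hunks relies on a small attribute side channel: the original, untiled types are stashed on the op as a type-array attribute by ConcretizeTileAmongWorkgroupsPass and read back here when computing launch configurations. A condensed sketch of both halves (the attribute key is from the diff; the helper functions are illustrative only):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

constexpr char kOriginalInputTypes[] = "iree.codegen.original_input_types";

// Producer side (ConcretizeTileAmongWorkgroupsPass): stash the untiled types.
void stashOriginalInputTypes(Operation *op, ArrayRef<Type> types) {
  op->setAttr(kOriginalInputTypes, Builder(op).getTypeArrayAttr(types));
}

// Consumer side (KernelDispatchUtils): read them back, or fall back to the
// (already tiled) operand type when the attribute is absent.
ShapedType getOriginalInputType(Operation *op, unsigned index,
                                ShapedType fallback) {
  auto attr = op->getAttrOfType<ArrayAttr>(kOriginalInputTypes);
  if (!attr) return fallback;
  return attr.getValue()[index].cast<TypeAttr>().getValue().cast<ShapedType>();
}
```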
+ if (auto outputTypeAttr = operation->getAttrOfType( + "iree.codegen.original_output_types")) { + auto inputTypeAttr = operation->getAttrOfType( + "iree.codegen.original_input_types"); + inputType = inputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + outputType = outputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + LLVM_DEBUG(llvm::dbgs() << "conv input types: " << inputType << "\n"); + LLVM_DEBUG(llvm::dbgs() << "conv output types: " << outputType << "\n"); + } else { + inputType = op.getInputs().front().getType().template cast(); + outputType = op.getOutputBufferTypes()[0].template cast(); + } + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) return failure(); - // Only support NHWC conv. - if (!isa(op.getOperation())) { - return failure(); - } bool isInputTilable = inputType.getDimSize(3) % 4 == 0 || inputType.getDimSize(3) < 4; @@ -478,8 +526,29 @@ GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp) static LogicalResult getMaliSpecificConfig( linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes, LaunchConfigInfo &config) { - auto inputType = op.getInput(0).getType().cast(); - auto outputType = op.getOutputBufferTypes()[0].cast(); + ShapedType inputType, outputType; + + // NOTE: Special treatment to let the flow.dispatch.workgroups path to be able + // to query launch configurations. + if (auto outputTypeAttr = + op->getAttrOfType("iree.codegen.original_output_types")) { + auto inputTypeAttr = + op->getAttrOfType("iree.codegen.original_input_types"); + inputType = inputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + outputType = outputTypeAttr.getValue()[0] + .template cast() + .getValue() + .template cast(); + LLVM_DEBUG(llvm::dbgs() << "dwconv input types: " << inputType << "\n"); + LLVM_DEBUG(llvm::dbgs() << "dwconv output types: " << outputType << "\n"); + } else { + inputType = op.getInput(0).getType().cast(); + outputType = op.getOutputBufferTypes()[0].cast(); + } + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) return failure(); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp new file mode 100644 index 000000000000..3c6d5e35e2bc --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp @@ -0,0 +1,155 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- LinalgTileAndDistributePass.cpp ------------------------------------===// +// +// This pass tiles and distributes linalg operations among multiple workgroups. +// +// NOTE: Deprecated. This pass is used for the first-level tiling in the Linalg +// on buffers path, which is expected to go away soon. 
+// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h" +#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" +#include "iree/compiler/Conversion/Common/Attributes.h" +#include "iree/compiler/Conversion/Common/Transforms.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-linalg-to-spirv-tile-and-distribute" + +namespace mlir { +namespace iree_compiler { +namespace { + +/// Returns the distribution options for operations when targeting workgroups. +linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() { + linalg::LinalgLoopDistributionOptions options; + + options.procInfo = [](OpBuilder &builder, Location loc, + ArrayRef parallelLoopRanges) { + return getGPUProcessorIdsAndCounts( + builder, loc, parallelLoopRanges.size()); + }; + options.distributionMethod.assign( + 3, linalg::DistributionMethod::CyclicNumProcsEqNumIters); + + return options; +} + +class LinalgTileAndDistributePass + : public PassWrapper> { + public: + LinalgTileAndDistributePass(const SPIRVCodegenOptions &options) + : options(options) {} + LinalgTileAndDistributePass(const LinalgTileAndDistributePass &that) + : options(that.options) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + IREE::HAL::ExecutableTargetOp targetOp = getOperation(); + ModuleOp module = targetOp.getInnerModule(); + + for (FuncOp funcOp : module.getOps()) { + if (!isEntryPoint(funcOp)) continue; + + SmallVector linalgOps; + SmallVector tiledLoops; + + if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) { + return signalPassFailure(); + } + + linalg::Aliases aliases; + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + Optional launchConfigOpt = + initGPULaunchConfig(context, dependenceGraph, options, linalgOps); + if (!launchConfigOpt) { + funcOp.emitError("unable to find launch configuration"); + return signalPassFailure(); + } + LaunchConfig &launchConfig = *launchConfigOpt; + + LLVM_DEBUG({ + llvm::dbgs() + << "\n--- IREE Linalg tile and distribute configuration ---\n"; + llvm::dbgs() << "@func " << funcOp.getName() + << ": # workgroup sizes: ["; + interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs()); + llvm::dbgs() << "]\n"; + for (auto op : linalgOps) { + llvm::dbgs() << "\t" << op.getOperation()->getName() << " : "; + TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op); + llvm::dbgs() << "{"; + std::string sep = ""; + for (auto &level : enumerate(tileSizes)) { + llvm::dbgs() << sep << level.index() << " : ["; + sep = ", "; + interleaveComma(level.value(), llvm::dbgs()); + llvm::dbgs() << "]"; + } + llvm::dbgs() << "}\n"; + } + }); + + TileAndFuseOptions tileAndFuseOptions = { + getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; 
+      if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
+                                            launchConfig,
+                                            tileAndFuseOptions)) ||
+          failed(
+              updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
+        return signalPassFailure();
+      }
+    }
+  }
+
+ private:
+  SPIRVCodegenOptions options;
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options) {
+  return std::make_unique<LinalgTileAndDistributePass>(options);
+}
+
+static PassRegistration<LinalgTileAndDistributePass> pass(
+    "iree-codegen-spirv-linalg-tile-and-distribute",
+    "Tile and distribute Linalg operations on buffers", [] {
+      SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions();
+      return std::make_unique<LinalgTileAndDistributePass>(options);
+    });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 3aae4847324f..36d6ee2262c8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -14,7 +14,8 @@
 
 //===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===//
 //
-// Implements a pass to tile and fuse linalg operations on buffers.
+// This pass tiles and vectorizes Linalg ops on buffers within a single
+// workgroup.
 //
 //===----------------------------------------------------------------------===//
 
@@ -32,6 +33,8 @@
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -72,33 +75,6 @@ static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
       markers, Identifier::get(replaceMarker, context));
 }
 
-/// Returns the distribution options for operations when targeting workgroups.
-static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
-  linalg::LinalgLoopDistributionOptions options;
-
-  options.procInfo = [](OpBuilder &builder, Location loc,
-                        ArrayRef<Range> parallelLoopRanges) {
-    return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
-        builder, loc, parallelLoopRanges.size());
-  };
-  options.distributionMethod.assign(
-      3, linalg::DistributionMethod::CyclicNumProcsEqNumIters);
-
-  return options;
-}
-
-/// Applies canonicalization over index calculation inside the given `funcOp`.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 3aae4847324f..36d6ee2262c8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -14,7 +14,8 @@
 //===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===//
 //
-// Implements a pass to tile and fuse linalg operations on buffers.
+// This pass tiles and vectorizes Linalg ops on buffers within a single
+// workgroup.
 //
 //===----------------------------------------------------------------------===//
@@ -32,6 +33,8 @@
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -72,33 +75,6 @@ static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
       markers, Identifier::get(replaceMarker, context));
 }
 
-/// Returns the distribution options for operations when targeting workgroups.
-static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
-  linalg::LinalgLoopDistributionOptions options;
-
-  options.procInfo = [](OpBuilder &builder, Location loc,
-                        ArrayRef<Range> parallelLoopRanges) {
-    return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
-        builder, loc, parallelLoopRanges.size());
-  };
-  options.distributionMethod.assign(
-      3, linalg::DistributionMethod::CyclicNumProcsEqNumIters);
-
-  return options;
-}
-
-/// Applies canonicalization over index calculation inside the given `funcOp`.
-static void applyIndexCalculationCanonicalization(FuncOp funcOp) {
-  MLIRContext *context = funcOp.getContext();
-  OwningRewritePatternList canonicalizationPatterns;
-  DimOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  AddIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  SubIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  SignedDivIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  (void)applyPatternsAndFoldGreedily(funcOp,
-                                     std::move(canonicalizationPatterns));
-}
-
 //===----------------------------------------------------------------------===//
 // Main pass
 //===----------------------------------------------------------------------===//
@@ -429,8 +405,6 @@ void LinalgTileAndFusePass::runOnOperation() {
   IREE::HAL::ExecutableTargetOp targetOp = getOperation();
   ModuleOp module = targetOp.getInnerModule();
 
-  LLVM_DEBUG(
-      llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";);
   for (FuncOp funcOp : module.getOps<FuncOp>()) {
     if (!isEntryPoint(funcOp)) continue;
 
@@ -452,6 +426,7 @@ void LinalgTileAndFusePass::runOnOperation() {
     LaunchConfig &launchConfig = *launchConfigOpt;
 
     LLVM_DEBUG({
+      llvm::dbgs() << "\n--- IREE Linalg tile and fuse configuration ---\n";
      llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
      interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
      llvm::dbgs() << "]\n";
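The dump above prints each op's launch configuration as indexed tile-size levels. A minimal sketch of reading those levels back out of a `LaunchConfig`, under the assumption (a convention of this codebase, not something this hunk states) that level 0 is the workgroup tile consumed by the distribution step:

```cpp
// Sketch only: consuming the per-level tile sizes the debug dump prints.
// LaunchConfig and TileSizesListTypeRef are this codebase's types.
static void inspectWorkgroupTiles(ArrayRef<linalg::LinalgOp> linalgOps,
                                  LaunchConfig &launchConfig) {
  for (linalg::LinalgOp op : linalgOps) {
    TileSizesListTypeRef levels = launchConfig.getTileSizes(op);
    if (levels.empty()) continue;
    // Level 0 drives the workgroup-level distribution; deeper levels, when
    // present, are consumed by later tiling and vectorization stages.
    ArrayRef<int64_t> workgroupTile = levels.front();
    llvm::dbgs() << op.getOperation()->getName() << " uses a "
                 << workgroupTile.size() << "-D workgroup tile\n";
  }
}
```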
@@ -470,55 +445,6 @@ void LinalgTileAndFusePass::runOnOperation() {
       }
     });
 
-    if (!options.usingLinalgOnTensors) {
-      TileAndFuseOptions tileAndFuseOptions = {
-          getWorkgroupDistributionOptions(), allocateWorkgroupMemory};
-      if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
-                                            launchConfig,
-                                            tileAndFuseOptions)) ||
-          failed(
-              updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
-        return signalPassFailure();
-      }
-    } else {
-      // Find the root operation for the dispatch region and get the tile sizes.
-      Operation *rootOperation =
-          launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range(
-              linalgOps,
-              [](linalg::LinalgOp op) { return op.getOperation(); })));
-      if (!rootOperation) {
-        launchConfig.finalize(funcOp);
-        return;
-      }
-
-      ArrayRef<int64_t> rootOperationTileSizes =
-          launchConfig.getTileSizes(rootOperation, 0);
-      if (rootOperationTileSizes.empty()) {
-        launchConfig.finalize(funcOp);
-        return;
-      }
-
-      // Only use the tile sizes for parallel loops of the root operation.
-      rootOperationTileSizes = rootOperationTileSizes.take_front(
-          getNumOuterParallelLoops(rootOperation));
-
-      SmallVector<int64_t, 4> workloadPerWorkgroup =
-          llvm::to_vector<4>(llvm::reverse(rootOperationTileSizes));
-      if (failed(materializeStaticLaunchInformation(funcOp,
-                                                    workloadPerWorkgroup)) ||
-          failed(
-              updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
-        funcOp.emitOpError("failed to materialize static launch information");
-        return signalPassFailure();
-      }
-    }
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After first level of tiling and distribution ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
     if (options.useWorkgroupMemory) {
       // The promotion patterns are put separate from the tiling patterns to
       // make sure that the allocated scratchspace memory is constant sizes
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 0ea3a74f1285..0dfe6cde7228 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -83,9 +83,15 @@ static void addLinalgToSPIRVPasses(OpPassManager &pm,
   //   - The Linalg op is kept untouched.
   //
   //===--------------------------------------------------------------------===//
-  if (!options.usingLinalgOnTensors) {
+  if (options.usingLinalgOnTensors) {
+    // flow.dispatch.workgroups performed abstract tiling and distribution.
+    // Make that tiling concrete now that the target and its settings are
+    // known.
+    pm.addPass(createConcretizeTileAmongWorkgroupsPass(options));
+  } else {
     pm.addPass(createSplitDispatchFunctionPass());
+    pm.addPass(createTileAndDistributeAmongWorkgroupsPass(options));
   }
+
   pm.addPass(createLinalgTileAndFusePass(options));
   if (options.vectorizeMemref) {
     pm.nest().addNestedPass(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 1c29576e9433..bce9f2efd7c0 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -81,6 +81,16 @@ createFoldProcessorIDUsesPass();
 std::unique_ptr>
 createMaterializeEntryPointsPass();
 
+/// Creates a pass to concretize hal.interface.workgroup.* ops with the chosen
+/// tiling and distribution scheme.
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options);
+
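The comment above is terse about what "concretize" means; the new concretize_tile_among_workgroups.mlir test later in this patch shows the effect: `hal.interface.workgroup.size` queries collapse to the constant tile sizes chosen for the target, and the entry point gains a region computing workgroup counts as workload ceildiv tile size. A hypothetical skeleton of the size substitution (the accessor names are assumptions for illustration, not the pass's actual internals):

```cpp
// Hypothetical sketch of the core rewrite behind "concretizing"
// hal.interface.workgroup.size; accessor names are assumed.
static void concretizeWorkgroupSizeOps(FuncOp funcOp,
                                       ArrayRef<int64_t> tileSizes) {
  funcOp.walk([&](IREE::HAL::InterfaceWorkgroupSizeOp sizeOp) {
    OpBuilder builder(sizeOp);
    // The abstract per-workgroup extent becomes the concrete tile size the
    // launch configuration picked for this dimension.
    uint64_t dim = sizeOp.dimension().getZExtValue();  // assumed accessor
    Value cst = builder.create<ConstantIndexOp>(sizeOp.getLoc(),
                                                tileSizes[dim]);
    sizeOp.getResult().replaceAllUsesWith(cst);
    sizeOp.erase();
  });
}
```

On the count side, the test expects maps like `affine_map<()[s0] -> (s0 ceildiv 16)>` applied to the workload, one workgroup per tile; `hal.interface.workgroup.count` stays symbolic when shapes are dynamic.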
+/// Tiles and distributes Linalg operations on buffers among multiple
+/// workgroups.
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options);
+
 //===----------------------------------------------------------------------===//
 // Pipelines
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
index 65febab28f3e..671512a8b5fa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -126,6 +126,7 @@ class MemRefUsageAnalysis {
   void analyzeFunc(FuncOp funcOp);
   void analyzeAlloc(AllocOp allocOp);
   void analyzePlaceholder(IREE::PlaceholderOp placeholderOp);
+  void analyzeInterfaceBinding(IREE::HAL::InterfaceBindingSubspanOp bindingOp);
   llvm::DenseMap<Operation *, unsigned> vectorization_size;
   llvm::DenseSet<Operation *> transferOps;
 };
@@ -136,6 +137,8 @@ MemRefUsageAnalysis::MemRefUsageAnalysis(mlir::Operation *op) {
     if (auto alloc = dyn_cast<AllocOp>(op)) analyzeAlloc(alloc);
     if (auto placeholder = dyn_cast<IREE::PlaceholderOp>(op))
       analyzePlaceholder(placeholder);
+    if (auto bindingOp = dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(op))
+      analyzeInterfaceBinding(bindingOp);
   });
 }
@@ -159,6 +162,15 @@ void MemRefUsageAnalysis::analyzePlaceholder(
   }
 }
 
+void MemRefUsageAnalysis::analyzeInterfaceBinding(
+    IREE::HAL::InterfaceBindingSubspanOp bindingOp) {
+  SmallVector<Operation *, 4> vectorUses;
+  if (unsigned vectorSize = isMemRefAndVectorizable(bindingOp, vectorUses)) {
+    vectorization_size.insert(std::make_pair(bindingOp, vectorSize));
+    transferOps.insert(vectorUses.begin(), vectorUses.end());
+  }
+}
+
 void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) {
   SmallVector<Operation *, 4> vectorUses;
   if (unsigned vectorSize = isMemRefAndVectorizable(allocOp, vectorUses)) {
@@ -363,6 +375,26 @@ class ProcessPlaceHolder final
   }
 };
 
+class ProcessInterfaceBinding final
+    : public MemRefConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
+ public:
+  using MemRefConversionPattern<
+      IREE::HAL::InterfaceBindingSubspanOp>::MemRefConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      IREE::HAL::InterfaceBindingSubspanOp bindingOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = bindingOp.getType().dyn_cast<MemRefType>();
+    if (!memrefType) return failure();
+    auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.getResult());
+    if (!vecMemRef) return failure();
+    rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
+        bindingOp, *vecMemRef, bindingOp.binding(), bindingOp.byte_offset(),
+        bindingOp.byte_length());
+    return success();
+  }
+};
+
 class VectorizeMemRefPass final
     : public PassWrapper<VectorizeMemRefPass, OperationPass<ModuleOp>> {
   void runOnOperation() override;
@@ -410,8 +442,8 @@ void VectorizeMemRefPass::runOnOperation() {
   OwningRewritePatternList patterns;
   patterns.insert<ProcessFuncArg, ProcessTransferRead, ProcessTransferWrite,
-                  ProcessAlloc, ProcessPlaceHolder>(context,
-                                                    *memrefUsageAnalysis);
+                  ProcessAlloc, ProcessPlaceHolder, ProcessInterfaceBinding>(
+      context, *memrefUsageAnalysis);
 
   ConversionTarget target(*context);
   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@@ -426,6 +458,10 @@ void VectorizeMemRefPass::runOnOperation() {
       [&](IREE::PlaceholderOp placeholder) {
         return !memrefUsageAnalysis->vectorizeMemRef(placeholder);
       });
+  target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp>(
+      [&](IREE::HAL::InterfaceBindingSubspanOp bindingOp) {
+        return !memrefUsageAnalysis->vectorizeMemRef(bindingOp);
+      });
   target.markUnknownOpDynamicallyLegal([&](Operation *op) {
     if (isa<vector::TransferReadOp, vector::TransferWriteOp>(op))
       return !memrefUsageAnalysis->transferConvert(op);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
index 1e3b7bba341b..761d17431f4a 100644
--- 
a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,32 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "batch_matmul_vectorization.mlir", + "concretize_tile_among_workgroups.mlir", + "convert_to_gpu.mlir", + "convert_to_spirv.mlir", + "dead_alloc.mlir", + "fold-gpu-procid-uses.mlir", + "forop_canonicalization.mlir", + "linalg_tile_and_fuse.mlir", + "materialize_launch_configuration.mlir", + "materialize_launch_configuration2.mlir", + "matmul_fused_vectorization.mlir", + "matmul_vectorization.mlir", + "matmul_vectorization_licm.mlir", + "memref_vecrotization.mlir", + "pipeline_test.mlir", + "pipeline_test_cooperative_mat.mlir", + "split_dispatch_function.mlir", + "tile_and_vectorize_conv.mlir", + "tile_and_vectorize_matmul.mlir", + "vector_to_gpu.mlir", + "workgroup_memory_promotion.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt index 8dbfbf23eff0..9c7c5001e05f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt @@ -10,12 +10,31 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "batch_matmul_vectorization.mlir" + "concretize_tile_among_workgroups.mlir" + "convert_to_gpu.mlir" + "convert_to_spirv.mlir" + "dead_alloc.mlir" + "fold-gpu-procid-uses.mlir" + "forop_canonicalization.mlir" + "linalg_tile_and_fuse.mlir" + "materialize_launch_configuration.mlir" + "materialize_launch_configuration2.mlir" + "matmul_fused_vectorization.mlir" + "matmul_vectorization.mlir" + "matmul_vectorization_licm.mlir" + "memref_vecrotization.mlir" + "pipeline_test.mlir" + "pipeline_test_cooperative_mat.mlir" + "split_dispatch_function.mlir" + "tile_and_vectorize_conv.mlir" + "tile_and_vectorize_matmul.mlir" + "vector_to_gpu.mlir" + "workgroup_memory_promotion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir index 508f72e6ad12..dd66e0b4c8cd 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -9,8 +9,8 @@ hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private" hal.executable.target @vulkan, 
filter="dylib*" { hal.executable.entry_point @batch_matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv2d_static_shape() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + 
linalg.fill(%18, %cst) : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// Check that for a fully static shaped dispatch region, we can: +// 1) Generate static constant workgroup counts, +// 2) Replace hal.interface.workgroup.{size|count} ops with constants, +// 3) Canonicalize loops and subview ops. + +// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (9, d0 * -2 + 225)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + 32)> +// CHECK: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (4, -d0 + 112)> + +// CHECK: hal.executable.entry_point @conv2d_static_shape +// CHECK: %[[C2:.+]] = constant 2 : index +// CHECK: %[[C28_0:.+]] = constant 28 : index +// CHECK: %[[C28_1:.+]] = constant 28 : index +// CHECK: hal.return %[[C2]], %[[C28_0]], %[[C28_1]] : index, index, index + +// CHECK: func @conv2d_static_shape() +// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>} + +// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index +// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index +// CHECK: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index + +// CHECK: %[[Z_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Z]], %c4] +// CHECK: %[[Y_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4] +// CHECK: %[[X_MUL_16:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16] + +// CHECK: %[[Z_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Z_MUL_4]]) +// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP1]](%[[Z_MUL_4]])[%c4] +// CHECK: %[[Y_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Y_MUL_4]]) +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP1]](%[[Y_MUL_4]])[%c4] + +// CHECK: %[[INPUT:.+]] = subview %{{.+}}[0, %[[Z_OFFSET]], %[[Y_OFFSET]], 0] [1, %[[Z_SIZE]], %[[Y_SIZE]], 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}> + +// CHECK: %[[X_SIZE:.+]] = affine.min #[[MAP2]](%[[X_MUL_16]])[%c16] + +// CHECK: %[[FILTER:.+]] = subview %{{.+}}[0, 0, 0, %[[X_MUL_16]]] [3, 3, 16, %[[X_SIZE]]] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}> + +// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP3]](%[[Z_MUL_4]])[%c4] +// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP3]](%[[Y_MUL_4]])[%c4] +// CHECK: %[[OUTPUT:.+]] = subview %{{.+}}[0, %[[Z_MUL_4]], %[[Y_MUL_4]], %[[X_MUL_16]]] [1, %[[Z_SIZE]], %[[Y_SIZE]], %[[X_SIZE]]] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}> + +// CHECK: linalg.fill(%[[OUTPUT]], %{{.+}}) +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, is_root_op, strides = dense<2> : tensor<2xi64>} 
ins(%[[INPUT]], %[[FILTER]] : memref<1x?x?x16xf32, {{.+}}>, memref<3x3x16x?xf32, {{.+}}>) outs(%[[OUTPUT]] : memref<1x?x?x?xf32, {{.+}}>)
+
+// -----
+
+hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @matmul_dynamic_shape attributes {
+      interface = @legacy_io, ordinal = 0 : i32,
+      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} {
+      func @matmul_dynamic_shape() {
+        %cst = constant 0.000000e+00 : f32
+        %c0 = constant 0 : index
+        %0 = hal.interface.load.constant offset = 0 : index
+        %1 = hal.interface.load.constant offset = 1 : index
+        %2 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<?x?xf32>
+        %3 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<?x?xf32>
+        %4 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<?x?xf32>
+        %5 = hal.interface.load.constant offset = 2 : index
+        %6 = hal.interface.load.constant offset = 3 : index
+        %7 = hal.interface.load.constant offset = 4 : index
+        %8 = hal.interface.load.constant offset = 5 : index
+        %9 = hal.interface.load.constant offset = 6 : index
+        %10 = hal.interface.load.constant offset = 7 : index
+        %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %12 = shapex.tie_shape %2, %11 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %14 = shapex.tie_shape %3, %13 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %16 = shapex.tie_shape %4, %15 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %17 to %5 step %18 {
+          %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %19 to %8 step %20 {
+            %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y]
+            %22 = subview %12[%arg0, 0] [%21, %6] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x]
+            %24 = subview %14[0, %arg1] [%7, %23] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y]
+            %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x]
+            %27 = subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            linalg.fill(%27, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, f32
+            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%27 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
+          }
+        }
+        return
+      }
+      hal.interface @legacy_io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// Check that for a fully dynamic shaped dispatch region, we can:
+// 1) Generate symbolic workgroup counts,
+// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants.
+
+// CHECK: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (4, -d0 + s0)>
+// CHECK: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (16, -d0 + s0)>
+
+// CHECK: hal.executable.entry_point @matmul_dynamic_shape
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index):
+// CHECK: %c1 = constant 1 : index
+// CHECK: %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]]
+// CHECK: %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]]
+// CHECK: hal.return %[[SIZE0]], %[[SIZE1]], %c1
+
+// CHECK: func @matmul_dynamic_shape()
+// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
+
+// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index
+// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index
+// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index
+// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index
+// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index
+// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index
+
+// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
+// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
+
+// CHECK: %[[Y_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4]
+// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_Y]], %c4]
+// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]]
+// CHECK: %[[X_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16]
+// CHECK: %[[X_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_X]], %c16]
+// CHECK: scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]]
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]], %c4]
+// CHECK: %[[A_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]], %c16]
+// CHECK: %[[B_TILE:.+]] = subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]], %c4]
+// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]], %c16]
+// CHECK: %[[C_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: 
linalg.fill(%[[C_TILE]], %cst) {__internal_linalg_transform__ = "workgroup"} : memref, f32 +// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op} ins(%[[A_TILE]], %[[B_TILE]] : memref, memref) outs(%[[C_TILE]] : memref) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir index 15e1e9281c89..141435794077 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-convert-to-gpu))" -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.target(iree-codegen-convert-to-gpu))' -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Enable this test when linalg on tensors becomes default. // #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> @@ -11,8 +11,8 @@ // hal.executable.target @vulkan, filter="vulkan*" { // hal.executable.entry_point @parallel_4D attributes { // interface = @legacy_io, ordinal = 0 : i32, -// signature = (!flow.dispatch.input, !flow.dispatch.input, -// !flow.dispatch.output) -> ()} +// signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, +// !flow.dispatch.tensor) -> ()} // module attributes { // spv.target_env = // #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output<40xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @reduce_sum() { %arg0 = iree.placeholder for "interace buffer" @@ -296,8 +296,8 @@ hal.executable @matmul attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @matmul attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -427,8 +427,8 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @conv_no_padding attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -603,7 +603,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vulkan, filter="vulkan*" { - hal.executable.entry_point @pooling_nhwc_max attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<2x16x16x6xf32>, !flow.dispatch.input<1x3x4x2xf32>, !flow.dispatch.output<2x14x13x5xf32>) -> ()} { + hal.executable.entry_point @pooling_nhwc_max 
attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} { ^bb0(%arg0: index, %arg1: index, %arg2: index): // no predecessors %c4 = constant 4 : index %c1 = constant 1 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir index fe9bbabefc44..8adf600f5a96 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir @@ -34,6 +34,40 @@ hal.executable @fold_block_id attributes {sym_visibility = "private"} { // ----- +hal.executable @fold_interface_workgroup_id attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + } + hal.executable.target @vulkan, filter="vulkan*" { + hal.executable.entry_point @fold_interface_workgroup_id attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = () -> ()} { + ^bb0(%arg0 : index, %arg1 : index, %arg2 : index): + %x = constant 112: index + %y = constant 42: index + %z = constant 1: index + hal.return %x, %y, %z: index, index, index + } + module { + func @fold_interface_workgroup_id() -> (index, index, index) { + %0 = hal.interface.workgroup.id[0] : index + %1 = hal.interface.workgroup.id[1] : index + %2 = hal.interface.workgroup.id[2] : index + %3 = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%0] + %4 = affine.min affine_map<()[s0] -> (8, s0 * -1 + s0 * -1 + s0 * -1 + 131)>()[%2] + %5 = affine.min affine_map<()[s0] -> (11, s0 + 15)>()[%3] + return %3, %4, %5: index, index, index + } + } + } +} +// CHECK-LABEL: func @fold_interface_workgroup_id() +// CHECK-DAG: %[[C3:.+]] = constant 3 +// CHECK-DAG: %[[C8:.+]] = constant 8 +// CHECK-DAG: %[[C11:.+]] = constant 11 +// CHECK-DAG: return %[[C3]], %[[C8]], %[[C11]] + +// ----- + hal.executable @fold_thread_id attributes {sym_visibility = "private"} { hal.interface @legacy_io { } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir index d8da2427ec4b..1c0e06490836 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse))' -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. 
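These tests lean on the `__internal_linalg_transform__ = "workgroup"` string attribute: the distribute pass tags the ops it tiles, and the tile-and-fuse pass only fires on ops carrying the tag (that is what the `getLinalgMatchAndReplaceMarker` helper kept in LinalgTileAndFusePass.cpp above builds). A sketch of the mechanism, using the same `LinalgTransformationFilter` API that file already uses; the "vectorize" replacement marker here is an illustrative choice, not a name this patch introduces:

```cpp
// Sketch: the marker plumbing staged Linalg passes use to hand off ops.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Identifier.h"

using namespace mlir;

static linalg::LinalgTransformationFilter workgroupStageFilter(
    MLIRContext *context) {
  // Fire only on ops tagged __internal_linalg_transform__ = "workgroup" and,
  // after rewriting, re-tag the result so no pattern fires twice.
  return linalg::LinalgTransformationFilter(
      {Identifier::get("workgroup", context)},
      Identifier::get("vectorize", context));
}
```
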
hal.executable @conv_no_padding attributes {sym_visibility = "private"} { @@ -10,8 +10,8 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<3x4x6x14xf32>, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<1x3x4x2xf32>, - !flow.dispatch.output<2x14x13x5xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<3x3x16x32xf32>, - !flow.dispatch.output<1x112x112x32xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> - } { - func @conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x112x112x32xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x225x225x16xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x16x32xf32> - linalg.fill(%0, %cst) : memref<1x112x112x32xf32>, f32 - linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%1, %2: memref<1x225x225x16xf32>, memref<3x3x16x32xf32>) - outs (%0: memref<1x112x112x32xf32>) - return - } - - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK-COUNT-4: vector.transfer_write - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, 
vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) - -// CHECK-COUNT-16: vector.contract - -// CHECK-COUNT-3: scf.yield - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_write - -// ----- - -hal.executable @depthwise_conv2d_2452x2423_valid_stride_2 attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan_spirv, filter="vulkan*" { - hal.executable.entry_point @depthwise_conv2d_2452x2423_valid_stride_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<2x4x5x2xf32>, tensor<2x4x2x3xf32>) -> tensor<2x2x1x6xf32>} - module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { - func @depthwise_conv2d_2452x2423_valid_stride_2() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x2x1x2x3xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4x5x2xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<2x4x2x3xf32> - linalg.fill(%0, %cst) : memref<2x2x1x2x3xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<2x4x5x2xf32>, memref<2x4x2x3xf32>) outs(%0 : memref<2x2x1x2x3xf32>) - return - } - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv2d_2452x2423_valid_stride_2() -// CHECK: linalg.fill -// CHECK: linalg.generic -// CHECK-NOT: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf - -// ----- - -hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @depthwise_conv_tiled_and_vectorized attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x113x113x96xf32>, !flow.dispatch.input<3x3x96xf32>, - !flow.dispatch.output<1x56x56x96xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, 
subgroup_size = 16 : i32}> - } { - func @depthwise_conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x56x56x96xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x113x113x96xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x96xf32> - linalg.fill(%0, %cst) : memref<1x56x56x96xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%0 : memref<1x56x56x96xf32>) - return - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK: vector.transfer_write - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) - - -// CHECK: vector.fma - -// CHECK-COUNT-2: scf.yield - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir index 8151a358e43f..1e2054461613 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s +// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -9,8 +9,8 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_tensors attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { func @matmul_tensors() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir index 30b13fab2dee..f988e5ade311 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir @@ -9,8 +9,8 @@ hal.executable @add attributes {sym_visibility = "private"} { hal.executable.target @vulkan_spirv, filter="vulkan*" { 
hal.executable.entry_point @add attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { func @add() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir index 88251f44d56a..19e148b12277 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -9,8 +9,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} - module attributes { - spv.target_env = - #spv.target_env<#spv.vce, - ARM:IntegratedGPU, - {max_compute_shared_memory_size = 32768 : i32, - max_compute_workgroup_invocations = 512 : i32, - max_compute_workgroup_size = dense<512> : vector<3xi32>, - subgroup_size = 16 : i32}>} { - func @matmul_static_shape_f16() - attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} { - %arg0 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16> - %arg1 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16> - %ret0 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16> - %cst = constant 0.000000e+00 : f16 - linalg.fill(%ret0, %cst) : memref<4096x4096xf16>, f16 - linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>) - outs(%ret0 : memref<4096x4096xf16>) - return - } - func private @matmul_static_shape__num_workgroups__ - (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>, - !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index) - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", 
access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @matmul_static_shape_f16 -// CHECK-COUNT-16: vector.transfer_write -// CHECK-COUNT-16: vector.transfer_read -// CHECK: %[[FOR_RES:.+]]:16 = scf.for -// CHECK-COUNT-16: vector.transfer_read -// CHECK-COUNT-64: vector.contract -// CHECK: scf.yield -// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]] -// CHECK: return diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir index 625b2b5ee41c..9f3ee9cc8132 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir @@ -1,5 +1,5 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -10,8 +10,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce> + // CHECK: hal.interface.binding.subspan @legacy_io::@ret0[%c0] + // CHECK-SAME: memref<4096x1024xvector<4xf32>> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf32> + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf32> + %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<32x8xf32> + vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf32>, memref<4096x4096xf32> + return +} + +hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +} diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir index aa408c18e898..95fa2a6f163f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir @@ -9,8 +9,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} 
{ hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce %4 = alloc() : memref<1024x256xf32> linalg.fill(%4, %cst) : memref<1024x256xf32>, f32 - linalg.matmul ins(%1, %2 : memref<1024x512xf32>, memref<512x256xf32>) + linalg.matmul ins(%1, %2 : memref<1024x512xf32>, memref<512x256xf32>) outs(%4 : memref<1024x256xf32>) - linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%4, %3 : memref<1024x256xf32>, memref<1024x256xf32>) + iterator_types = ["parallel", "parallel"]} + ins(%4, %3 : memref<1024x256xf32>, memref<1024x256xf32>) outs(%0 : memref<1024x256xf32>) { ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors %5 = addf %arg0, %arg1 : f32 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir index 2e3382f34090..d928cc6e3828 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir @@ -10,8 +10,8 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.input<3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv1d_ops // CHECK: linalg.fill @@ -55,8 +55,8 @@ hal.executable @kernel_fusable_fill_conv2d_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv2d_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv2d_ops // CHECK: linalg.fill @@ -102,8 +102,8 @@ hal.executable @kernel_fusable_fill_conv3d_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv3d_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - 
signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_conv3d_ops // CHECK: linalg.fill @@ -149,8 +149,8 @@ hal.executable @kernel_fusable_fill_matmul_ops attributes {sym_visiblity = "priv hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<512x?xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_matmul_ops // CHECK: linalg.fill @@ -196,8 +196,8 @@ hal.executable @kernel_fusable_pooling attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_pooling attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_pooling() // CHECK: linalg.fill @@ -236,8 +236,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 // CHECK: hal.executable.entry_point @kernel_dispatch_1 // CHECK: module attributes {hal.entry_point_schedule = [@kernel_dispatch_0, @kernel_dispatch_1]} @@ -303,8 +303,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 // CHECK: hal.executable.entry_point @kernel_dispatch_1 // CHECK: hal.executable.entry_point @kernel_dispatch_2 @@ -385,8 +385,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x3x3x512xf32>, !flow.dispatch.input<3x3x512x1xf32>, - !flow.dispatch.output<1x1x1x512xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} // CHECK-NOT: hal.entry_point_schedule module { // CHECK-LABEL: @kernel() @@ -427,8 +427,8 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<3x512x1xf32>, - !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { // expected-error @+1 {{cannot separate 
Linalg/Parallel ops into multiple kernels}} func @kernel() { @@ -465,7 +465,7 @@ hal.executable @subview_interleaved attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @subview_interleaved attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<18x12xf32>, !flow.dispatch.output<12x4xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @subview_interleaved() { %cst = constant 0.000000e+00 : f32 @@ -516,8 +516,8 @@ hal.executable @reshape_interleaved attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @reshape_interleaved attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<2x4xf32>, !flow.dispatch.output<1x2x4xf32>, - !flow.dispatch.output<2x4xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module { func @reshape_interleaved() { %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x4xf32> @@ -576,8 +576,8 @@ hal.executable @predict_ex_dispatch_0 attributes {sym_visiblity = "private"} { hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @predict_ex_dispatch_0 attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x512x1xf32>, !flow.dispatch.input<4x8x16xf32>, - !flow.dispatch.output<4x8x16xf32>, !flow.dispatch.output<4x8x16xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @predict_ex_dispatch_0() { %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x512x1xf32> @@ -637,8 +637,8 @@ hal.executable @kernel_fusable_fill_matmul_generic_ops attributes {sym_visiblity hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_generic_ops attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input<512x?xf32>, - !flow.dispatch.input, !flow.dispatch.output) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { // CHECK: func @kernel_fusable_fill_matmul_generic_ops // CHECK: linalg.fill diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir new file mode 100644 index 000000000000..18c360113a14 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir @@ -0,0 +1,176 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @conv_static_shape_f32 attributes { + interface = 
@legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + linalg.fill(%18, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 
1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @conv_static_shape_f32() + +// For linalg.fill +// CHECK-COUNT-4: vector.transfer_write + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) + +// CHECK-COUNT-16: vector.contract + +// CHECK-COUNT-3: scf.yield + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_write + +// ----- + +hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @depthwise_conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c96 = constant 96 : index + %c56 = constant 56 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x113x113x96xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x1x96xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x56x56x96xf32> + %3 = linalg.reshape %1 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>] : memref<3x3x1x96xf32> into memref<864xf32> + %4 = linalg.reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : memref<864xf32> into memref<3x3x96xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %5 to %c56 step 
%6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %7 to %c56 step %8 { + %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %9 to %c96 step %10 { + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg0)[%workgroup_size_z] + %13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %14 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg1)[%workgroup_size_y] + %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x] + %16 = subview %0[0, %11, %13, %arg2] [1, %12, %14, %15] [1, 1, 1, 1] : memref<1x113x113x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>> + %17 = subview %4[0, 0, %arg2] [3, 3, %15] [1, 1, 1] : memref<3x3x96xf32> to memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>> + %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg0)[%workgroup_size_z] + %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg1)[%workgroup_size_y] + %20 = subview %2[0, %arg0, %arg1, %arg2] [1, %18, %19, %15] [1, 1, 1, 1] : memref<1x56x56x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>> + linalg.fill(%20, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>, f32 + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "workgroup", strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @depthwise_conv_static_shape_f32() + +// For linalg.fill +// CHECK: vector.transfer_write + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) + + +// CHECK: vector.fma + +// CHECK-COUNT-2: scf.yield + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir new file mode 100644 index 000000000000..3f17d95e8809 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir @@ -0,0 +1,61 @@ +// RUN: iree-opt 
-split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} { + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @matmul_static_shape_f16 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @matmul_static_shape_f16() { + %cst = constant 0.000000e+00 : f16 + %c0 = constant 0 : index + %c4096 = constant 4096 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf16> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<4096x4096xf16> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf16> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg0 = %3 to %c4096 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg1 = %5 to %c4096 step %6 { + %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y] + %8 = subview %0[%arg0, 0] [%7, 4096] [1, 1] : memref<4096x4096xf16> to memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> + %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x] + %10 = subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf16> to memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> + %11 = subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf16> to memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> + linalg.fill(%10, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, f16 + linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %11 : memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%10 : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @matmul_static_shape_f16 +//
CHECK-COUNT-16: vector.transfer_write +// CHECK-COUNT-16: vector.transfer_read +// CHECK: %[[FOR_RES:.+]]:16 = scf.for +// CHECK-COUNT-16: vector.transfer_read +// CHECK-COUNT-64: vector.contract +// CHECK: scf.yield +// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]] +// CHECK: return diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir index 5da02d9d228c..fe0fb852028e 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. hal.executable @matmul_tile attributes {sym_visibility = "private"} { @@ -10,8 +10,8 @@ hal.executable @matmul_tile attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_tile attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<25x50xf32>, !flow.dispatch.input<50x75xf32>, - !flow.dispatch.output<25x75xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, @@ -68,8 +68,8 @@ hal.executable @conv_no_padding_tile attributes {sym_visibility = "private"} { hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding_tile attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<3x4x6x14xf32>, !flow.dispatch.input<2x16x16x6xf32>, - !flow.dispatch.output<2x13x11x14xf32>) -> ()} + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, + !flow.dispatch.tensor) -> ()} module attributes { spv.target_env = #spv.target_env<#spv.vce, diff --git a/iree/compiler/Conversion/LinalgToVector/test/BUILD b/iree/compiler/Conversion/LinalgToVector/test/BUILD index e124cf534051..2997186606bf 100644 --- a/iree/compiler/Conversion/LinalgToVector/test/BUILD +++ b/iree/compiler/Conversion/LinalgToVector/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "vectorize_linalg_conv.mlir", + "vectorize_linalg_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt index 05e4f22cc2ec..b958f27656d3 100644 --- a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "vectorize_linalg_conv.mlir" + "vectorize_linalg_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h index 9e0f52459371..cb1c9eb71028 100644 --- a/iree/compiler/Conversion/init_conversions.h +++ b/iree/compiler/Conversion/init_conversions.h @@ -65,6 +65,7 @@ inline void registerLinalgToSPIRVPasses() { // LinalgToSPIRV createConvertToGPUPass(SPIRVCodegenOptions()); createFoldProcessorIDUsesPass(); + createTileAndDistributeAmongWorkgroupsPass(SPIRVCodegenOptions()); createLinalgTileAndFusePass(SPIRVCodegenOptions()); createSplitDispatchFunctionPass(); createVectorToGPUPass(); diff --git a/iree/compiler/Dialect/Flow/Analysis/test/BUILD b/iree/compiler/Dialect/Flow/Analysis/test/BUILD index b780b112a316..8f34c74559e2 100644 --- a/iree/compiler/Dialect/Flow/Analysis/test/BUILD +++ b/iree/compiler/Dialect/Flow/Analysis/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["dispatchability.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt index c10e0a2033ed..26aafe0ab62f 100644 --- a/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Analysis/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "dispatchability.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp index 974fad9854a0..1a988b4ea0fd 100644 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp +++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp @@ -53,7 +53,7 @@ struct DynamicUpdateSliceOpLowering rewriter.getIndexType()); })); rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.update(), op.operand(), startIndices); + op, op.operand(), startIndices, op.update()); return success(); } }; diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD deleted file mode 100644 index b780b112a316..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:lit_test.bzl", "iree_lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = glob(["*.mlir"]), - data = [ - "//iree/tools:IreeFileCheck", - "//iree/tools:iree-opt", - ], -) diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt deleted file mode 100644 index 609b1785dc6e..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Dialect/Flow/Conversion/HLOToFlow/test/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. 
# -################################################################################ - -iree_add_all_subdirs() - -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -iree_lit_test_suite( - NAME - lit - SRCS - "${_GLOB_X_MLIR}" - DATA - iree::tools::IreeFileCheck - iree::tools::iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD deleted file mode 100644 index b780b112a316..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:lit_test.bzl", "iree_lit_test_suite") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = glob(["*.mlir"]), - data = [ - "//iree/tools:IreeFileCheck", - "//iree/tools:iree-opt", - ], -) diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt deleted file mode 100644 index 637de90aa834..000000000000 --- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Dialect/Flow/Conversion/StandardToFlow/test/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. 
# -################################################################################ - -iree_add_all_subdirs() - -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -iree_lit_test_suite( - NAME - lit - SRCS - "${_GLOB_X_MLIR}" - DATA - iree::tools::IreeFileCheck - iree::tools::iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD index 825e6b37e283..9668dd208784 100644 --- a/iree/compiler/Dialect/Flow/IR/BUILD +++ b/iree/compiler/Dialect/Flow/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,14 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "FlowBase.td", + "FlowInterfaces.td", + "FlowOps.td", + ], + include = ["*.td"], + ), ) cc_library( @@ -31,8 +39,8 @@ cc_library( srcs = [ "FlowDialect.cpp", "FlowEnums.cpp.inc", + "FlowInterfaces.cpp.inc", "FlowOpFolders.cpp", - "FlowOpInterface.cpp.inc", "FlowOpUtils.cpp", "FlowOps.cpp", "FlowOps.cpp.inc", @@ -41,7 +49,7 @@ cc_library( hdrs = [ "FlowDialect.h", "FlowEnums.h.inc", - "FlowOpInterface.h.inc", + "FlowInterfaces.h.inc", "FlowOpUtils.h", "FlowOps.h", "FlowOps.h.inc", @@ -49,7 +57,7 @@ cc_library( ], deps = [ ":FlowEnumsGen", - ":FlowOpInterfaceGen", + ":FlowInterfacesGen", ":FlowOpsGen", "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Dialect/Shape/IR", @@ -84,13 +92,13 @@ gentbl( ) gentbl( - name = "FlowOpInterfaceGen", + name = "FlowInterfacesGen", tbl_outs = [ - ("-gen-op-interface-decls", "FlowOpInterface.h.inc"), - ("-gen-op-interface-defs", "FlowOpInterface.cpp.inc"), + ("-gen-op-interface-decls", "FlowInterfaces.h.inc"), + ("-gen-op-interface-defs", "FlowInterfaces.cpp.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "FlowBase.td", + td_file = "FlowInterfaces.td", td_srcs = [ ":td_files", "//iree/compiler/Dialect/IREE/IR:td_files", diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt index 1cff61a9f2d4..3f7e8b68608a 100644 --- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt @@ -10,14 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR HDRS "FlowDialect.h" "FlowEnums.h.inc" - "FlowOpInterface.h.inc" + "FlowInterfaces.h.inc" "FlowOpUtils.h" "FlowOps.h" "FlowOps.h.inc" @@ -25,8 +24,8 @@ iree_cc_library( SRCS "FlowDialect.cpp" "FlowEnums.cpp.inc" + "FlowInterfaces.cpp.inc" "FlowOpFolders.cpp" - "FlowOpInterface.cpp.inc" "FlowOpUtils.cpp" "FlowOps.cpp" "FlowOps.cpp.inc" @@ -58,12 +57,12 @@ iree_tablegen_library( iree_tablegen_library( NAME - FlowOpInterfaceGen + FlowInterfacesGen TD_FILE - "FlowBase.td" + "FlowInterfaces.td" OUTS - -gen-op-interface-decls FlowOpInterface.h.inc - -gen-op-interface-defs FlowOpInterface.cpp.inc + -gen-op-interface-decls FlowInterfaces.h.inc + -gen-op-interface-defs FlowInterfaces.cpp.inc ) iree_tablegen_library( diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td index 50a9bc4a2230..c2a1f515329d 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowBase.td +++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td 
@@ -79,37 +79,6 @@ class FLOW_Op<string mnemonic, list<OpTrait> traits = []> : let printer = [{ return print$cppClass(p, *this); }]; } -def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { - let description = [{ - Interface for ops that can be used within a stream. - - Some ops can exist both within a stream and outside of a stream. This allows - optimizations to place certain ops such that they are performed in a - synchronous (outside of a stream) or asynchronous (inside of a stream) - fashion. - - The goal of the stream forming process is to move as many operations that - can be used within a stream into one and only using non-streamed ops as a - last resort. Ops that are isStreamOnly may force the creation of single-op - command buffers and synchronous dispatches. - }]; - - let methods = [ - InterfaceMethod< - [{Returns true if the op is transfer operation (as defined by the HAL).}], - "bool", "isTransfer", (ins) - >, - InterfaceMethod< - [{Returns true if the op *can* be used within a stream.}], - "bool", "isUsableInStream", (ins) - >, - InterfaceMethod< - [{Returns true if the op *must* be used within a stream.}], - "bool", "isStreamOnly", (ins) - >, - ]; -} - //===----------------------------------------------------------------------===// // Flow dialect types //===----------------------------------------------------------------------===// @@ -117,6 +86,7 @@ def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { def FLOW_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>; def FLOW_Dim : TypeAlias<Index>; +def FLOW_ShapeDynamicDims : Variadic<FLOW_Dim>; def FLOW_Tensor : TypeAlias<AnyRankedTensor>; @@ -131,29 +101,58 @@ def FLOW_Workload : AnyTypeOf<[Index]> { }]; } -def FLOW_DispatchInput : DialectType< +def FLOW_TensorAccessCanReadPred : CPred<[{ + $_self.cast<IREE::Flow::DispatchTensorType>().getAccess() == + IREE::Flow::TensorAccess::ReadOnly || + $_self.cast<IREE::Flow::DispatchTensorType>().getAccess() == + IREE::Flow::TensorAccess::ReadWrite +}]>; + +def FLOW_TensorAccessCanWritePred : CPred<[{ + $_self.cast<IREE::Flow::DispatchTensorType>().getAccess() == + IREE::Flow::TensorAccess::WriteOnly || + $_self.cast<IREE::Flow::DispatchTensorType>().getAccess() == + IREE::Flow::TensorAccess::ReadWrite +}]>; + +def FLOW_DispatchTensor : DialectType< FLOW_Dialect, - CPred<"$_self.isa<IREE::Flow::DispatchInputType>()">, - "dispatch.input"> { + CPred<"$_self.isa<IREE::Flow::DispatchTensorType>()">, + "dispatch.tensor"> { let description = [{ - A placeholder for a dispatch region input operand. This can be used to query - the metadata about the input (such as its shape) as well as load from the - backing tensor representation. + A placeholder for a dispatch region input/output operand. This can be used + to query the metadata about the tensor (such as its shape) as well as both + load and store from the backing tensor representation. }]; } -def FLOW_DispatchOutput : DialectType< +def FLOW_ReadableDispatchTensor : DialectType< FLOW_Dialect, - CPred<"$_self.isa<IREE::Flow::DispatchOutputType>()">, - "dispatch.output"> { + And<[ + CPred<"$_self.isa<IREE::Flow::DispatchTensorType>()">, + FLOW_TensorAccessCanReadPred, + ]>, + "dispatch.tensor"> { let description = [{ - A placeholder for a dispatch region output result. This can be used to - query the metadata about the output (such as its shape) as well as store - into the backing tensor representation. + A placeholder for a dispatch region input operand. This can be used + to query the metadata about the tensor (such as its shape) as well as load + from the backing tensor representation.
}]; } -def FLOW_DispatchIO : AnyTypeOf<[FLOW_DispatchInput, FLOW_DispatchOutput]>; +def FLOW_WritableDispatchTensor : DialectType< + FLOW_Dialect, + And<[ + CPred<"$_self.isa<IREE::Flow::DispatchTensorType>()">, + FLOW_TensorAccessCanWritePred, + ]>, + "dispatch.tensor"> { + let description = [{ + A placeholder for a dispatch region output operand. This can be used + to query the metadata about the tensor (such as its shape) as well as store + to the backing tensor representation. + }]; +} // Use no padding and clamp the window to the valid area, possibly stopping // early prior to having covered all data. diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp index 6923e5a573f8..be5f5969d4c8 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp @@ -30,7 +30,7 @@ namespace iree_compiler { namespace IREE { namespace Flow { -#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.cpp.inc" +#include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.cpp.inc" namespace { @@ -48,7 +48,7 @@ struct FlowFolderInterface : public DialectFoldInterface { FlowDialect::FlowDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get<FlowDialect>()) { addInterfaces<FlowFolderInterface>(); - addTypes<DispatchInputType, DispatchOutputType>(); + addTypes<DispatchTensorType>(); #define GET_OP_LIST addOperations< @@ -70,10 +70,8 @@ Operation *FlowDialect::materializeConstant(OpBuilder &builder, Attribute value, Type FlowDialect::parseType(DialectAsmParser &parser) const { llvm::StringRef spec = parser.getFullSymbolSpec(); - if (succeeded(parser.parseOptionalKeyword("dispatch.input"))) { - return DispatchInputType::parse(parser); - } else if (succeeded(parser.parseOptionalKeyword("dispatch.output"))) { - return DispatchOutputType::parse(parser); + if (succeeded(parser.parseOptionalKeyword("dispatch.tensor"))) { + return DispatchTensorType::parse(parser); } parser.emitError(parser.getCurrentLocation()) << "unknown Flow type: " << spec; @@ -81,10 +79,8 @@ Type FlowDialect::parseType(DialectAsmParser &parser) const { void FlowDialect::printType(Type type, DialectAsmPrinter &p) const { - if (auto inputType = type.dyn_cast<DispatchInputType>()) { + if (auto inputType = type.dyn_cast<DispatchTensorType>()) { IREE::Flow::printType(inputType, p); - } else if (auto outputType = type.dyn_cast<DispatchOutputType>()) { - IREE::Flow::printType(outputType, p); } else { llvm_unreachable("unknown Flow type"); } diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.h b/iree/compiler/Dialect/Flow/IR/FlowDialect.h index 4e916c953ae7..9622041f413d 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowDialect.h +++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.h @@ -24,7 +24,7 @@ namespace iree_compiler { namespace IREE { namespace Flow { -#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.h.inc" +#include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.h.inc" class FlowDialect : public Dialect { public: diff --git a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td new file mode 100644 index 000000000000..dba5aeafd140 --- /dev/null +++ b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td @@ -0,0 +1,116 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_FLOW_INTERFACES +#define IREE_DIALECT_FLOW_INTERFACES + +include "iree/compiler/Dialect/IREE/IR/IREEBase.td" + +//===----------------------------------------------------------------------===// +// IREE::Flow::ClosureOpInterface +//===----------------------------------------------------------------------===// + +def FLOW_ClosureOpInterface : OpInterface<"ClosureOpInterface"> { + let description = [{ + Interface for ops that follow the Flow dialect closure semantics (explicit + captures, dynamic-shape awareness, and normal operand/result SSA behavior). + + Implementing this interface enables optimizations that perform manipulation + across the closure capture boundary (outside of the op <-> regions within + the op). + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the body region of the closure (may have multiple blocks). + }], + /*retTy=*/"Region &", + /*methodName=*/"getClosureBodyRegion", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return this->getOperation()->getRegion(0); + }] + >, + InterfaceMethod< + /*desc=*/[{Returns all closure operand values.}], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getClosureOperands", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{Returns all closure result values.}], + /*retTy=*/"Operation::result_range", + /*methodName=*/"getClosureResults", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Returns true if the given operation can exist in the closure. + Not all operations that a closure can contain are guaranteed to be folded + into the closure, such as when the operation may have side-effects. + }], + /*retTy=*/"bool", + /*methodName=*/"canClosureContainOp", + /*args=*/(ins "Operation *":$op) + >, + InterfaceMethod< + /*desc=*/[{ + Clones the op while removing specified operands and results. + The body of the op will be transferred to the new op and the entry block + will have its arguments removed. + + The returned op will be free-standing. Callers must insert it into a block + where desired (most often just replacing the current op). + }], + /*retTy=*/"ClosureOpInterface", + /*methodName=*/"cloneReplacementExcludingOperandsAndResults", + /*args=*/(ins "ArrayRef<unsigned>":$excludedOperandIndices, + "ArrayRef<unsigned>":$excludedResultIndices) + >, + ]; +} + +//===----------------------------------------------------------------------===// +// IREE::Flow::StreamableOpInterface +//===----------------------------------------------------------------------===// + +def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> { + let description = [{ + Interface for ops that can be used within a stream. + + Some ops can exist both within a stream and outside of a stream. This allows + optimizations to place certain ops such that they are performed in a + synchronous (outside of a stream) or asynchronous (inside of a stream) + fashion. + + The goal of the stream-forming process is to move as many operations as + possible into streams, using non-streamed ops only as a last resort.
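[Reading aid, not part of the patch: TableGen emits a C++ `ClosureOpInterface` class for the declaration above, and passes consume it roughly as sketched here. The method names are taken from the interface declaration in this file; the surrounding walk and function names are illustrative assumptions only.]

    // Hypothetical usage sketch (assumes the generated ClosureOpInterface).
    void visitClosures(mlir::FuncOp funcOp) {
      funcOp.walk([](mlir::Operation *op) {
        if (auto closure = llvm::dyn_cast<IREE::Flow::ClosureOpInterface>(op)) {
          // The captured operands and the region that uses them.
          mlir::Region &body = closure.getClosureBodyRegion();
          for (mlir::Value operand : closure.getClosureOperands()) {
            (void)operand;  // e.g. decide whether a capture can be inlined/dropped
          }
          (void)body;
        }
      });
    }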
+ }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns true if the op is a transfer operation (as defined by the HAL). + }], + /*retTy=*/"bool", + /*methodName=*/"isTransfer", + /*args=*/(ins) + >, + ]; +} + +#endif // IREE_DIALECT_FLOW_INTERFACES diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index fde4c282af81..acb5c45ff20c 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -18,12 +18,15 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/IREE/IR/IREETypes.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" @@ -40,34 +43,150 @@ namespace iree_compiler { namespace IREE { namespace Flow { +//===----------------------------------------------------------------------===// +// Folding utilities +//===----------------------------------------------------------------------===// + +// Returns a new set of dynamic dimensions for a shape-carrying op when a type +// is being changed. This attempts to reuse the existing dimension values if +// they are available and will drop/insert new ones as required. +static SmallVector<Value, 4> refreshDimsOnTypeChange( + Operation *op, Type oldType, Type newType, ValueRange oldDims, + PatternRewriter &rewriter) { + if (oldType == newType) return llvm::to_vector<4>(oldDims); + + // Build an expanded list of all the dims - constants will be nullptr. + // This lets us map back the new types without worrying about whether some + // subset become static or dynamic. + auto oldShapedType = oldType.cast<ShapedType>(); + SmallVector<Value, 4> allOldDims(oldShapedType.getRank()); + for (unsigned i = 0; i < oldShapedType.getRank(); ++i) { + if (oldShapedType.isDynamicDim(i)) { + allOldDims[i] = oldDims.front(); + oldDims = oldDims.drop_front(); + } + } + + auto newShapedType = newType.cast<ShapedType>(); + SmallVector<Value, 4> newDims; + for (unsigned i = 0; i < newShapedType.getRank(); ++i) { + if (newShapedType.isDynamicDim(i)) { + auto oldValue = allOldDims[i]; + if (oldValue) { + // Old value valid; reuse. + newDims.push_back(oldValue); + } else { + // Dimension has changed to be dynamic; insert a constant to use. + // This sometimes happens during folding of casts and usually is cleaned + // up pretty quickly. + newDims.push_back(rewriter.createOrFold<ConstantIndexOp>( + op->getLoc(), oldShapedType.getDimSize(i))); + } + } + } + return newDims; +} + //===----------------------------------------------------------------------===// // Streams //===----------------------------------------------------------------------===// namespace { -// Optimizes stream fragment arguments by: -// - Removing any that are not used in the body -// - Deduping arguments that refer to the same Value -struct DceStreamFragment : public OpRewritePattern<ExStreamFragmentOp> { - using OpRewritePattern::OpRewritePattern; +// Returns true if the given |value| is used again after |updateOp| consumes it.
+static bool hasUsersInStreamAfterUpdate(Value value, Operation *updateOp) { + for (auto user : value.getUsers()) { + if (user == updateOp) continue; + if (user->isBeforeInBlock(updateOp)) continue; + return true; + } + return false; +} + +/// Inserts clones into the stream as required by tied results. +/// This is required to preserve the immutable tensor semantics required by the +/// SSA use-def chain. +/// +/// Example: +/// %0 = flow.dispatch +/// // %0 will be updated in-place and renamed %1: +/// %1 = flow.dispatch %0 -> %0 +/// // The original value of %0 (aka %1) is required but is not valid! +/// %2 = flow.dispatch %0 +/// -> +/// %0 = flow.dispatch +/// // Capture the value of %0 before it is modified: +/// %clone = flow.tensor.clone %0 +/// // Update %0 in-place and rename to %1, safe as %0 now has one use: +/// %1 = flow.dispatch %0 -> %0 +/// // Use the cloned %0 value: +/// %2 = flow.dispatch %clone +struct InsertImmutabilityPreservingStreamClones + : public OpRewritePattern<ExStreamFragmentOp> { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExStreamFragmentOp op, PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/0); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); + bool didClone = + insertTiedClones(cast<TiedOpInterface>(op.getOperation()), rewriter); + for (auto &block : op.getClosureBodyRegion()) { + for (auto &innerOp : block) { + if (auto tiedOp = dyn_cast<TiedOpInterface>(innerOp)) { + didClone |= insertTiedClones(tiedOp, rewriter); + } + } } - return success(); + return didClone ? success() : failure(); + } + + bool insertTiedClones(TiedOpInterface tiedOp, + PatternRewriter &rewriter) const { + bool didClone = false; + for (unsigned resultIndex = 0; resultIndex < tiedOp->getNumResults(); + ++resultIndex) { + auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(resultIndex); + if (!tiedOperandIndex.hasValue()) continue; + auto tiedOperand = tiedOp->getOperand(tiedOperandIndex.getValue()); + if (hasUsersInStreamAfterUpdate(tiedOperand, tiedOp)) { + rewriter.setInsertionPointAfterValue(tiedOperand); + auto clonedOperand = rewriter.createOrFold<TensorCloneOp>( + tiedOperand.getLoc(), tiedOperand); + SmallPtrSet<Operation *, 2> excludedOps; + excludedOps.insert(tiedOp.getOperation()); + excludedOps.insert(clonedOperand.getDefiningOp()); + tiedOperand.replaceAllUsesExcept(clonedOperand, excludedOps); + didClone = true; + } + } + return didClone; + } +}; + +/// Ties the results of streams to their operands when the stream operations are +/// tied throughout the entire body. +struct TieStreamResults : public OpRewritePattern<ExStreamFragmentOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExStreamFragmentOp op, + PatternRewriter &rewriter) const override { + assert(op.getRegion().getBlocks().size() == 1 && + "only one stream block supported"); + bool didModify = false; + op.walk([&](IREE::Flow::ReturnOp returnOp) { + for (auto result : llvm::enumerate(returnOp.getOperands())) { + if (op.getTiedResultOperandIndex(result.index()).hasValue()) { + continue; // Already tied.
+ } + auto baseValue = + IREE::TiedOpInterface::findTiedBaseValue(result.value()); + if (auto blockArg = baseValue.dyn_cast<BlockArgument>()) { + unsigned operandIndex = blockArg.getArgNumber(); + op.setTiedResultOperandIndex(result.index(), operandIndex); + didModify = true; + } + } + }); + return didModify ? success() : failure(); } }; @@ -75,7 +194,9 @@ struct DceStreamFragment : public OpRewritePattern<ExStreamFragmentOp> { void ExStreamFragmentOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert<DceStreamFragment>(context); + results.insert<ClosureOptimizationPattern<ExStreamFragmentOp>>(context); + results.insert<InsertImmutabilityPreservingStreamClones>(context); + results.insert<TieStreamResults>(context); } //===----------------------------------------------------------------------===// @@ -223,88 +344,32 @@ void VariableStoreIndirectOp::getCanonicalizationPatterns( // Dispatch ops //===----------------------------------------------------------------------===// -namespace { - -struct DceDispatchRegion : public OpRewritePattern<DispatchRegionOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DispatchRegionOp op, - PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/1); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); - } - return success(); - } -}; - -} // namespace - void DispatchRegionOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert<DceDispatchRegion>(context); + results.insert<ClosureOptimizationPattern<DispatchRegionOp>>(context); } -namespace { - -struct DceDispatchWorkgroups : public OpRewritePattern<DispatchWorkgroupsOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DispatchWorkgroupsOp op, - PatternRewriter &rewriter) const override { - if (op.body().empty()) return failure(); - ClosureOpDce dce(op, op.body().front(), - /*variadicOffset=*/op.workgroup_count().size()); - if (!dce.needsOptimization()) return failure(); - - bool newOperation = dce.needsNewOperation(); - if (!newOperation) { - rewriter.startRootUpdate(op); - dce.optimize(rewriter); - rewriter.finalizeRootUpdate(op); - } else { - dce.optimize(rewriter, /*eraseOriginal=*/false); - rewriter.eraseOp(op); - } - return success(); - } -}; - -} // namespace - void DispatchWorkgroupsOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - // TODO(benvanik): add DCE support; it's tricky here because the ClosureOpDce - // assumes 1:1 operands/results between the outer op and inner region, but we - // don't do that there. - // results.insert<DceDispatchWorkgroups>(context); + results.insert<ClosureOptimizationPattern<DispatchWorkgroupsOp>>(context); } //===----------------------------------------------------------------------===// -// flow.dispatch.input.load +// flow.dispatch.tensor.load //===----------------------------------------------------------------------===// namespace { // Some linalg patterns, due to being upstream, tend to introduce `dim` ops. // These generally fold with upstream patterns when tensors are involved, but -// when DispatchInputLoadOp's are involved (with dispatch tensor types), +// when DispatchTensorLoadOp's are involved (with dispatch tensor types), // then this starts to break down, which causes the `dim` ops to survive // arbitrarily late into the pipeline. Often, they keep alive -// DispatchInputLoadOp's that would otherwise be dead!
+// DispatchTensorLoadOp's that would otherwise be dead! // // To fix this, we convert the `std.dim` ops to `flow.dispatch.shape` ops. // ``` -// dim(flow.dispatch.input.load(%x), %const) // -> // shapex.ranked_dim(flow.dispatch.shape(%x), %const) // `` @@ -314,15 +379,14 @@ struct ConvertDimOfDispatchInputLoadToDispatchShape LogicalResult matchAndRewrite(DimOp op, PatternRewriter &rewriter) const override { - auto dispatchInputLoad = - op.memrefOrTensor().getDefiningOp<DispatchInputLoadOp>(); - if (!dispatchInputLoad) return failure(); + auto loadOp = op.memrefOrTensor().getDefiningOp<DispatchTensorLoadOp>(); + if (!loadOp) return failure(); Optional<int64_t> constantIndex = op.getConstantIndex(); if (!constantIndex.hasValue()) return failure(); - auto rankedShape = rewriter.create<DispatchShapeOp>( - op.getLoc(), dispatchInputLoad.source()); + auto rankedShape = + rewriter.create<DispatchShapeOp>(op.getLoc(), loadOp.source()); rewriter.replaceOpWithNewOp<Shape::RankedDimOp>(op, rankedShape, *constantIndex); return success(); @@ -331,7 +395,7 @@ struct ConvertDimOfDispatchInputLoadToDispatchShape } // namespace -void DispatchInputLoadOp::getCanonicalizationPatterns( +void DispatchTensorLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert<ConvertDimOfDispatchInputLoadToDispatchShape>(context); } @@ -444,14 +508,6 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) { // No-op. return source(); } - - // Skip intermediate reshapes. - if (auto definingOp = - dyn_cast_or_null<TensorReshapeOp>(source().getDefiningOp())) { - setOperand(definingOp.getOperand()); - return result(); - } - return {}; } @@ -504,10 +560,21 @@ OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) { OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) { if (operands[0]) { + // Constants always fold. return operands[0]; } - // TODO(benvanik): fold if clone device placements differ. - return operand(); + + // TODO(benvanik): elide clones when safe to do so. Right now clone is + // load-bearing to work around our lack of cross-stream scheduling. Clones are + // inserted to avoid mutating function arguments and any logic we perform here + // (without *also* checking all the conditions that may insert a clone) will + // just fight. + // + // Once the clones are not load-bearing we can remove them in all the normal + // cases (one user, no intervening uses between clone and consumers of + // operands, etc). + + return {}; } // Slices tensor from start to (start + length) exclusively at dim. @@ -606,12 +673,15 @@ static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target, } OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) { - auto indices = operands.drop_front(2); + auto targetIndex = getODSOperandIndexAndLength(0).first; + auto startIndices = getODSOperandIndexAndLength(2); + auto updateIndex = getODSOperandIndexAndLength(3).first; + auto indices = operands.slice(startIndices.first, startIndices.second); bool allIndicesConstant = llvm::count(indices, nullptr) == 0; - if (operands[0] && operands[1] && allIndicesConstant) { + if (operands[updateIndex] && operands[targetIndex] && allIndicesConstant) { // Fully constant arguments so we can perform the update here. - return tensorUpdate(operands[0].cast<ElementsAttr>(), - operands[1].cast<ElementsAttr>(), indices); + return tensorUpdate(operands[updateIndex].cast<ElementsAttr>(), + operands[targetIndex].cast<ElementsAttr>(), indices); } else { // Replace the entire tensor when the sizes match.
auto updateType = update().getType().cast<ShapedType>(); @@ -624,6 +694,43 @@ OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) { return {}; } +namespace { + +// When the target tensor is a result of a tensor.cast operation, the op needs +// to be updated to use the source of the cast as the target tensor. +struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorUpdateOp updateOp, + PatternRewriter &rewriter) const override { + auto targetCastOp = updateOp.target().getDefiningOp<tensor::CastOp>(); + auto updateCastOp = updateOp.update().getDefiningOp<tensor::CastOp>(); + if (!targetCastOp && !updateCastOp) return failure(); + auto target = (targetCastOp ? targetCastOp.source() : updateOp.target()); + auto update = (updateCastOp ? updateCastOp.source() : updateOp.update()); + auto newOp = rewriter.create<TensorUpdateOp>( + updateOp.getLoc(), target.getType(), target, + refreshDimsOnTypeChange(updateOp, updateOp.target().getType(), + target.getType(), updateOp.target_dims(), + rewriter), + updateOp.start_indices(), update, + refreshDimsOnTypeChange(updateOp, updateOp.update().getType(), + update.getType(), updateOp.update_dims(), + rewriter), + updateOp.tied_operandsAttr()); + rewriter.replaceOpWithNewOp<tensor::CastOp>( + updateOp, updateOp.getResult().getType(), newOp.getResult()); + return success(); + } +}; + +} // namespace + +void TensorUpdateOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<FoldTensorUpdateOpWithCasts>(context); +} + } // namespace Flow } // namespace IREE } // namespace iree_compiler diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp index eecec9acebe6..d188f7bbc3f2 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp @@ -14,131 +14,297 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace iree_compiler { namespace IREE { namespace Flow { -Operation *cloneWithNewResultTypes(Operation *op, TypeRange newResultTypes) { - OperationState state(op->getLoc(), op->getName()); - state.addOperands(op->getOperands()); - state.addTypes(newResultTypes); - state.addSuccessors(op->getSuccessors()); - state.addAttributes(op->getAttrs()); - for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { - state.addRegion(); +//------------------------------------------------------------------------------ +// Closure optimization +//------------------------------------------------------------------------------ + +void excludeClosureOperandsAndResults( + SmallVector<Value, 4> &operandValues, + ArrayRef<unsigned> excludedOperandIndices, + SmallVector<Type, 4> &resultTypes, + ArrayRef<unsigned> excludedResultIndices) { + SmallVector<Value, 4> oldOperandValues = operandValues; + operandValues.clear(); + for (auto it : llvm::enumerate(oldOperandValues)) { + if (!llvm::count(excludedOperandIndices, it.index())) { + operandValues.push_back(it.value()); + } } - Operation *newOp = Operation::create(state); - for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { - newOp->getRegion(i).takeBody(op->getRegion(i)); + SmallVector<Type, 4> oldResultTypes = resultTypes; + resultTypes.clear(); + for (auto it : llvm::enumerate(oldResultTypes)) { + if (!llvm::count(excludedResultIndices, it.index())) { + resultTypes.push_back(it.value()); + } } - return newOp; } -//------------------------------------------------------------------------------ -// ClosureOpDce
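[Reading aid, not part of the patch: the `excludeClosureOperandsAndResults` overloads above and below filter parallel operand/type/dim lists by excluded index. A self-contained sketch of the same filtering logic, with the MLIR types swapped for standard C++ containers so it runs standalone; all names here are illustrative.]

    #include <algorithm>
    #include <vector>

    // Keep only the elements whose index is not listed in `excluded`; this
    // mirrors the index-based filtering used by the closure helpers.
    template <typename T>
    std::vector<T> filterByIndex(const std::vector<T> &values,
                                 const std::vector<unsigned> &excluded) {
      std::vector<T> kept;
      for (unsigned i = 0; i < values.size(); ++i) {
        if (std::find(excluded.begin(), excluded.end(), i) == excluded.end()) {
          kept.push_back(values[i]);
        }
      }
      return kept;
    }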
-//------------------------------------------------------------------------------ +void excludeClosureOperandsAndResults( + SmallVector &operandValues, SmallVector &operandDims, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, SmallVector &resultDims, + ArrayRef excludedResultIndices) { + SmallVector oldOperandValues = operandValues; + SmallVector oldOperandDims = operandDims; + operandValues.clear(); + operandDims.clear(); + auto remainingOperandDims = llvm::makeArrayRef(oldOperandDims); + for (auto it : llvm::enumerate(oldOperandValues)) { + unsigned numDynamicDims = 0; + auto type = it.value().getType(); + if (auto shapedType = type.dyn_cast()) { + numDynamicDims = shapedType.getNumDynamicDims(); + } + if (!llvm::count(excludedOperandIndices, it.index())) { + operandValues.push_back(it.value()); + for (auto dim : remainingOperandDims.take_front(numDynamicDims)) { + operandDims.push_back(dim); + } + } + remainingOperandDims = remainingOperandDims.drop_front(numDynamicDims); + } + + SmallVector oldResultTypes = resultTypes; + SmallVector oldResultDims = resultDims; + resultTypes.clear(); + resultDims.clear(); + auto remainingResultDims = llvm::makeArrayRef(oldResultDims); + for (auto it : llvm::enumerate(oldResultTypes)) { + unsigned numDynamicDims = 0; + auto type = it.value(); + if (auto shapedType = type.dyn_cast()) { + numDynamicDims = shapedType.getNumDynamicDims(); + } + if (!llvm::count(excludedResultIndices, it.index())) { + resultTypes.push_back(type); + for (auto dim : remainingResultDims.take_front(numDynamicDims)) { + resultDims.push_back(dim); + } + } + remainingResultDims = remainingResultDims.drop_front(numDynamicDims); + } +} -ClosureOpDce::ClosureOpDce(Operation *closureOp, Block &entryBlock, - unsigned variadicOffset) - : closureOp(closureOp), - entryBlock(entryBlock), - variadicOffset(variadicOffset), - blockArgReplacements(entryBlock.getNumArguments()) { - assert(closureOp->getNumOperands() == - entryBlock.getNumArguments() + variadicOffset); +void eraseRegionResults(Region &region, + ArrayRef excludedResultIndices) { + region.walk([&](IREE::Flow::ReturnOp terminator) { + llvm::SmallVector newReturns; + for (auto it : llvm::enumerate(terminator.getOperands())) { + if (!llvm::count(excludedResultIndices, it.index())) { + newReturns.push_back(it.value()); + } + } + terminator.getOperation()->setOperands(newReturns); + }); +} + +// Returns true if |constantOp| represents a (logically) small constant value. +// +// "Small" is relative and there's a risk that we'll bloat the closures by +// duplicating a bunch of constants; however, what we are able to save by not +// doing that usually wins. Think of the number of bytes used on instructions to +// allocate/place/copy, set up function call arguments, compute the address, +// dereference the memory, etc vs. a constant immediate value of 16 bytes - +// after all, there are single x64 instructions that approach 15 bytes :) +// +// This is also still at a fairly high level (flow dialect): once the closures +// are expanded out in lower dialects things like CSE have a chance to once +// again get at the constants and dedupe them if they survive. +static bool isConstantSmall(ConstantOp constantOp) { + // We could tune this/take it as a configuration setting. + // The current value is chosen based on what is known to be reasonable to + // inline into command buffers way down in the HAL, which is not great but at + // least better than either allocating independent buffers for 4-byte + // constants or inlining megabytes.
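The size heuristic described above reduces to simple arithmetic; here is a standalone sketch in plain C++ (names hypothetical; the 256-byte cap appears just below):

```
#include <cstdint>

constexpr std::uint64_t kMaxInlinedConstantBytes = 256;

// Mirrors the comment above: splats are always cheap to inline (one value
// regardless of tensor size); dense constants are inlined only when their
// estimated payload fits under the cap.
bool isSmallEnoughToInline(std::uint64_t numElements,
                           std::uint64_t elementBitWidth, bool isSplat) {
  if (isSplat) return true;
  std::uint64_t estimatedBytes = (numElements * elementBitWidth) / 8;
  return estimatedBytes <= kMaxInlinedConstantBytes;
}
```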
+ static constexpr int kMaxInlinedConstantBytes = 256; + + auto constantValueAttr = constantOp.getValue(); + auto constantType = constantOp.getType(); + if (constantValueAttr.isa()) { + // Splats are always small and can often have special handling when we + // know they are a splat - which is why it's so important we inline them + // here so we know when they are used that's the case. + return true; + } else if (auto denseAttr = constantValueAttr.dyn_cast()) { + // Smallish constants are worth moving inside. + auto shapedType = constantType.cast(); + uint64_t estimatedByteLength = + (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / 8; + return denseAttr.isSplat() || + estimatedByteLength <= kMaxInlinedConstantBytes; + } else if (constantType.isIntOrIndexOrFloat()) { + // Primitives can always go in. + return true; + } + + return false; +} + +// Returns true if the given value should be inlined into the closure region. +// This is non-recursive and only holds for this value. Recursively cloning +// trees is hard and it'd be better to model that differently such as by having +// a wrapper region for immutable blobs that can be inlined that this then +// returns true for. +static bool shouldInlineIntoClosure(Value value) { + auto definingOp = value.getDefiningOp(); + if (auto constantOp = dyn_cast(definingOp)) { + // Constants are perfect! + return isConstantSmall(constantOp); + } else if (auto variableLoadOp = + dyn_cast(definingOp)) { + // If the variable is immutable then we can inline the reference to it. + auto variableOp = + SymbolTable::lookupNearestSymbolFrom( + definingOp, variableLoadOp.variable()); + return !variableOp.is_mutable(); + } + return false; +} + +// Inlines operands of the closure into the entry block as appropriate. +// The closure operands and block arguments will remain untouched but all uses +// will be replaced with the newly cloned values. +// +// Note that if multiple operands reference the same value it will get cloned +// multiple times. That's fine, as anything we can inline here is something we +// should also be able to CSE and that happens later on anyway. +static void inlineClosureOperands(ClosureOpInterface &closureOp, + Block &entryBlock) { + auto builder = OpBuilder::atBlockBegin(&entryBlock); + for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { + auto outerValue = opArg.value(); + auto *sourceOp = outerValue.getDefiningOp(); + if (!sourceOp) continue; // can't clone block arguments into closures + if (closureOp.canClosureContainOp(sourceOp) && + shouldInlineIntoClosure(outerValue)) { + // Clone the op (with regions). + auto *clonedOp = builder.clone(*sourceOp); + + // Ensure we are using the right result in the case of ops with multiple + // results. If we only end up using a single result then canonicalization + // should take care of removing the unneeded ones. + int resultIndex = + std::distance(sourceOp->result_begin(), + std::find(sourceOp->result_begin(), + sourceOp->result_end(), outerValue)); + auto newValue = clonedOp->getResult(resultIndex); + + // Replace all of the uses inside of the closure. + auto innerValue = entryBlock.getArgument(opArg.index()); + innerValue.replaceAllUsesWith(newValue); + } + } +} + +bool optimizeClosureLikeOp(ClosureOpInterface &closureOp, + PatternRewriter *rewriter) { + // NOTE: the block is transferred to the new op; we can update it in place. + Block &entryBlock = closureOp.getClosureBodyRegion().front(); + + // Find constants/metadata/etc that we can clone into the closure. 
+ // By doing this first we potentially create some dead operands that we can + // then elide below. When we do inline things the operands will be changed + // such that the following work is guaranteed to happen and thus our op will + // be rebuilt. + inlineClosureOperands(closureOp, entryBlock); // Build data structure for unused operand elision. - for (auto it : llvm::enumerate(entryBlock.getArguments())) { - BlockArgument blockArg = it.value(); - Value opArg = closureOp->getOperand(it.index() + variadicOffset); - if (blockArg.getUses().empty()) { + SmallVector elidedOperands; + llvm::SmallMapVector argToBlockMap; + SmallVector, 8> blockArgReplacements( + entryBlock.getNumArguments()); + for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { + auto blockArg = entryBlock.getArgument(opArg.index()); + if (blockArg.use_empty()) { // Not used - Drop. - needsOperandElision = true; - blockArgReplacements[it.index()] = BlockArgument(); + elidedOperands.push_back(opArg.index()); + blockArgReplacements[opArg.index()] = BlockArgument(); continue; } - auto existingIt = argToBlockMap.find(opArg); + auto existingIt = argToBlockMap.find(opArg.value()); if (existingIt == argToBlockMap.end()) { // Not found - Record for deduping. - argToBlockMap.insert(std::make_pair(opArg, blockArg)); + argToBlockMap.insert(std::make_pair(opArg.value(), blockArg)); } else { // Found - Replace. - needsOperandElision = true; - blockArgReplacements[it.index()] = existingIt->second; + elidedOperands.push_back(opArg.index()); + blockArgReplacements[opArg.index()] = existingIt->second; } } // Check for unused results. - for (auto result : closureOp->getResults()) { - if (result.getUses().empty()) { - needsResultElision = true; - break; + SmallVector preservedResults; + SmallVector elidedResults; + for (auto result : llvm::enumerate(closureOp.getClosureResults())) { + if (result.value().use_empty()) { + elidedResults.push_back(result.index()); + } else { + preservedResults.push_back(result.value()); } } -} -void ClosureOpDce::elideUnusedOperands(OpBuilder &builder) { - llvm::SmallVector newOperands( - closureOp->operand_begin(), closureOp->operand_begin() + variadicOffset); + if (elidedOperands.empty() && elidedResults.empty()) { + // No optimization required. + return false; + } + + if (elidedResults.size() == closureOp.getClosureResults().size()) { + // The op is completely unused - delete it. + if (rewriter) { + rewriter->eraseOp(closureOp); + } else { + closureOp.erase(); + } + closureOp = {}; + return true; + } + + // Replace duplicate block arguments. unsigned blockArgIndex = 0; - for (auto it : llvm::enumerate(blockArgReplacements)) { - llvm::Optional replacement = it.value(); - Value currentOpArg = closureOp->getOperand(it.index() + variadicOffset); + for (auto replacement : blockArgReplacements) { if (!replacement) { // No change. - newOperands.push_back(currentOpArg); blockArgIndex++; - continue; } else if (!replacement.getValue()) { - // Drop. - entryBlock.eraseArgument(blockArgIndex); - continue; + // Dropped. } else { - // Replace. - BlockArgument currentBlockArg = entryBlock.getArgument(blockArgIndex); - currentBlockArg.replaceAllUsesWith(*replacement); - entryBlock.eraseArgument(blockArgIndex); + // Replaced. + entryBlock.getArgument(blockArgIndex).replaceAllUsesWith(*replacement); } } - closureOp->setOperands(newOperands); -} - -void ClosureOpDce::elideUnusedResults(OpBuilder &builder, bool eraseOriginal) { - // Determine the result signature transform needed. 
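The entry point being assembled here, optimizeClosureLikeOp, is also usable outside the pattern rewriter (it takes an optional PatternRewriter); a plausible standalone driver, with the helper name assumed:

```
#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"

// Optimizes every closure op in a module in place. optimizeClosureLikeOp may
// replace or erase the op, updating the reference it is handed.
static void optimizeAllClosureOps(mlir::ModuleOp moduleOp) {
  moduleOp.walk([](mlir::iree_compiler::IREE::Flow::ClosureOpInterface op) {
    auto closureOp = op;
    mlir::iree_compiler::IREE::Flow::optimizeClosureLikeOp(closureOp);
  });
}
```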
- llvm::SmallVector resultIndexMap; - llvm::SmallVector newResultTypes; - for (auto it : llvm::enumerate(closureOp->getResults())) { - if (!it.value().getUses().empty()) { - newResultTypes.push_back(it.value().getType()); - resultIndexMap.push_back(it.index()); - } + // Clone the op with the elidable operands and results removed. + OpBuilder builder(closureOp); + auto newOp = closureOp.cloneReplacementExcludingOperandsAndResults( + elidedOperands, elidedResults); + if (rewriter) { + rewriter->insert(newOp); + } else { + builder.insert(newOp); } - // Re-allocate the op. - builder.setInsertionPoint(closureOp); - Operation *newOp = - builder.insert(cloneWithNewResultTypes(closureOp, newResultTypes)); - - // Remap all returns. - llvm::SmallVector newReturns(resultIndexMap.size()); - newOp->walk([&](IREE::Flow::ReturnOp terminator) { - for (unsigned i = 0, e = resultIndexMap.size(); i < e; ++i) { - newReturns[i] = terminator.getOperand(resultIndexMap[i]); - } - terminator.getOperation()->setOperands(newReturns); - }); + // Replace original uses of the closure results. + for (auto oldNewResult : + llvm::zip(preservedResults, newOp.getClosureResults())) { + std::get<0>(oldNewResult).replaceAllUsesWith(std::get<1>(oldNewResult)); + } - // Replace original uses. - for (unsigned i = 0, e = resultIndexMap.size(); i < e; ++i) { - closureOp->getResult(resultIndexMap[i]) - .replaceAllUsesWith(newOp->getResult(i)); + // Erase the original op. + if (rewriter) { + rewriter->eraseOp(closureOp); + } else { + closureOp.erase(); } - if (eraseOriginal) closureOp->erase(); + closureOp = newOp; + return true; } } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h index 38b44d0be86a..2e2da8ae1dd3 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h +++ b/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h @@ -12,56 +12,69 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { namespace iree_compiler { namespace IREE { namespace Flow { -// Clones an operation with new result types. -// The original operation will be erased and a new operation constructed -// in its place. -Operation *cloneWithNewResultTypes(Operation *op, TypeRange newResultTypes); +//------------------------------------------------------------------------------ +// Closure optimization +//------------------------------------------------------------------------------ -// Utility class to optimize a "closure" op, which maintains a variadic -// list of operands corresponding to entry block arguments. -class ClosureOpDce { - public: - ClosureOpDce(Operation *closureOp, Block &entryBlock, - unsigned variadicOffset); +// Modifies in-place the operand results vectors for a closure operation. +// |excludedOperandIndices| and |excludedResultIndices| are sets containing the +// operands and results in the lists to remove. 
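+// operands and results in the lists to remove.

A plausible call shape for these helpers (the flattened diff strips template arguments; SmallVector<Value, 4> and SmallVector<Type, 4> are assumed here, matching the call sites later in this change, and the snippet assumes it sits inside namespace IREE::Flow):

```
// |closureOp| is any IREE::Flow::ClosureOpInterface value. Gather its operand
// values and result types, then drop operand 1 and result 0 before rebuilding.
llvm::SmallVector<mlir::Value, 4> newOperands =
    llvm::to_vector<4>(closureOp.getClosureOperands());
llvm::SmallVector<mlir::Type, 4> newResultTypes =
    llvm::to_vector<4>(closureOp.getClosureResults().getTypes());
excludeClosureOperandsAndResults(newOperands, /*excludedOperandIndices=*/{1},
                                 newResultTypes, /*excludedResultIndices=*/{0});
```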
+void excludeClosureOperandsAndResults(SmallVector &operandValues, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, + ArrayRef excludedResultIndices); +void excludeClosureOperandsAndResults(SmallVector &operandValues, + SmallVector &operandDims, + ArrayRef excludedOperandIndices, + SmallVector &resultTypes, + SmallVector &resultDims, + ArrayRef excludedResultIndices); - bool needsOptimization() { return needsOperandElision || needsResultElision; } +// Erases the given result indices from terminators in the given region. +void eraseRegionResults(Region ®ion, + ArrayRef excludedResultIndices); - // Whether the operation needs to be replaced. - bool needsNewOperation() { return needsResultElision; } +// Optimizes closure |closureOp| to remove duplicate operands and unused +// results. The op may be mutated, destroyed, or replaced with a new one. If an +// optional |rewriter| is provided then it will be notified of the operations +// performed on the op. Returns true if the op was optimized. +bool optimizeClosureLikeOp(ClosureOpInterface &closureOp, + PatternRewriter *rewriter = nullptr); +template +inline bool optimizeClosureOp(T &op, PatternRewriter *rewriter = nullptr) { + auto closureOp = cast(op.getOperation()); + bool didOptimize = optimizeClosureLikeOp(closureOp, rewriter); + op = dyn_cast_or_null(closureOp.getOperation()); + return didOptimize; +} - // Performs the optimization. If the optional eraseOriginal=false and - // needsNewOperation(), then the original will not be erased, leaving that - // to the caller (which is needed in some pattern rewriting scenarios). - // TODO(laurenzo): Fix OpBuilder upstream so that this eraseOriginal - // workaround is not required to write a safe rewriter pattern that uses this - // utility. - Operation *optimize(OpBuilder &builder, bool eraseOriginal = true) { - if (needsResultElision) elideUnusedResults(builder, eraseOriginal); - if (needsOperandElision) elideUnusedOperands(builder); - return closureOp; - } - - private: - void elideUnusedOperands(OpBuilder &builder); - void elideUnusedResults(OpBuilder &builder, bool eraseOriginal); +// A pattern that optimizes the given region-containing op T (CSE, DCE, etc). +// Duplicate operands will be combined and unused operands and results will be +// removed. +// +// T must implement the IREE::Flow::ClosureOpInterface. +template +struct ClosureOptimizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - Operation *closureOp; - Block &entryBlock; - unsigned variadicOffset; - llvm::SmallVector, 8> blockArgReplacements; - llvm::SmallMapVector argToBlockMap; - bool needsOperandElision = false; - bool needsResultElision = false; + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + auto closureOp = cast(op.getOperation()); + return optimizeClosureLikeOp(closureOp, &rewriter) ? 
success() : failure(); + } }; } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index 3bca4d5a86d7..34aeeba4512b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -16,9 +16,10 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -37,6 +38,10 @@ namespace iree_compiler { namespace IREE { namespace Flow { +//===----------------------------------------------------------------------===// +// Op utilities used within the Flow dialect +//===----------------------------------------------------------------------===// + // Returns true if the given |accessType| is compatible with the |variableType|. // For example, this will return true if the variable type is a tensor // and the access is tensor<4xf32>. @@ -44,6 +49,270 @@ static bool isVariableTypeCompatible(Type variableType, Type accessType) { return succeeded(mlir::verifyCompatibleShape(variableType, accessType)); } +// Verifies that |dynamicDims| contains the appropriate number of dims for all +// of the dynamic dimensions in |values|. +static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values, + ValueRange dynamicDims) { + unsigned requiredCount = 0; + for (auto value : values) { + if (auto shapedType = value.getType().dyn_cast()) { + requiredCount += shapedType.getNumDynamicDims(); + } + } + if (dynamicDims.size() != requiredCount) { + return op->emitOpError() + << "value set has " << requiredCount + << " dynamic dimensions but only " << dynamicDims.size() + << " dimension values are attached"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// +// type{%dim0, %dim1} +// %arg0 + +static ParseResult parseTiedResult( + OpAsmParser &parser, Type &resultType, + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { + if (failed(parser.parseType(resultType))) return failure(); + if (auto shapedType = resultType.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + resultDims.append(dynamicDims); + } + } + tiedOperands = parser.getBuilder().getIndexArrayAttr({0}); + return success(); +} + +static void printTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, + ValueRange resultDims, ArrayAttr tiedOperands) { + p.printType(resultType); + if (auto shapedType = resultType.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (resultDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + resultDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + } + } +} + +//===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// +// (type, type{%dim0, %dim1}, 
type) -> (type{%dim2}, %operand4) + +static ParseResult parseShapedOperandList( + OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl &dims) { + do { + Type type; + if (failed(parser.parseType(type))) return failure(); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + dims.append(dynamicDims); + } + } + types.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + return success(); +} + +// Ties the |tiedResult| parsed operand back to a previously parsed operand. +// The type and any dynamic dimensions of the operand will be used for the +// result values and the operand index will be appended to |tiedOperandIndices|. +static ParseResult tieOperand( + OpAsmParser::OperandType tiedResult, OpAsmParser &parser, + ArrayRef operands, TypeRange operandTypes, + ArrayRef operandDims, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultDims, + SmallVectorImpl &tiedOperandIndices) { + int64_t operandIndex = TiedOpInterface::kUntiedIndex; + for (int64_t i = 0; i < operands.size(); ++i) { + if (operands[i].name == tiedResult.name) { + operandIndex = i; + break; + } + } + if (operandIndex == TiedOpInterface::kUntiedIndex) { + return parser.emitError(tiedResult.location, + "tied operand not found for result reference ") + << tiedResult.name; + } + + auto resultType = operandTypes[operandIndex]; + resultTypes.push_back(resultType); + tiedOperandIndices.push_back(operandIndex); + + auto shapedType = resultType.dyn_cast(); + if (shapedType) { + unsigned dimsIndex = 0; + for (unsigned i = 0; i < operandIndex; ++i) { + if (auto shapedType = operandTypes[i].dyn_cast()) { + dimsIndex += shapedType.getNumDynamicDims(); + } + } + resultDims.append(llvm::to_vector<4>( + operandDims.slice(dimsIndex, shapedType.getNumDynamicDims()))); + } + + return success(); +} + +static ParseResult parseShapedResultList( + OpAsmParser &parser, ArrayRef operands, + TypeRange operandTypes, ArrayRef operandDims, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { + SmallVector tiedOperandIndices; + do { + OpAsmParser::OperandType tiedResult; + auto res = parser.parseOptionalOperand(tiedResult); + if (res.hasValue() && succeeded(res.getValue())) { + if (failed(tieOperand(tiedResult, parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperandIndices))) { + return failure(); + } + } else { + Type type; + if (failed(parser.parseType(type))) return failure(); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + SmallVector dynamicDims; + if (failed(parser.parseLBrace()) || + failed(parser.parseOperandList(dynamicDims, + shapedType.getNumDynamicDims(), + OpAsmParser::Delimiter::None)) || + failed(parser.parseRBrace())) { + return failure(); + } + resultDims.append(dynamicDims); + } + } + resultTypes.push_back(type); + tiedOperandIndices.push_back(TiedOpInterface::kUntiedIndex); + } + } while (succeeded(parser.parseOptionalComma())); + if (!tiedOperandIndices.empty()) { + tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); + } + return success(); +} + +static ParseResult parseShapedFunctionType( + OpAsmParser &parser, ArrayRef operands, + SmallVectorImpl &operandTypes, + SmallVectorImpl &operandDims, + SmallVectorImpl 
&resultTypes, + SmallVectorImpl &resultDims, + ArrayAttr &tiedOperands) { + if (failed(parser.parseLParen())) return failure(); + if (failed(parser.parseOptionalRParen())) { + if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) || + failed(parser.parseRParen())) { + return failure(); + } + } + if (failed(parser.parseArrow())) return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parseShapedResultList(parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperands)) || + failed(parser.parseRParen())) { + return failure(); + } + } else { + if (failed(parseShapedResultList(parser, operands, operandTypes, + operandDims, resultTypes, resultDims, + tiedOperands))) { + return failure(); + } + } + return success(); +} + +static void printShapedFunctionType(OpAsmPrinter &p, Operation *op, + ValueRange operands, TypeRange operandTypes, + OperandRange operandDims, + TypeRange resultTypes, + OperandRange resultDims, + ArrayAttr tiedOperands) { + p << "("; + llvm::interleaveComma(operandTypes, p, [&](Type type) { + p.printType(type); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (operandDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + operandDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + operandDims = operandDims.drop_front(shapedType.getNumDynamicDims()); + } + } + }); + p << ") -> "; + if (resultTypes.size() != 1) p << "("; + auto tiedOp = cast(op); + for (unsigned i = 0; i < resultTypes.size(); ++i) { + auto tiedOperand = tiedOp.getTiedResultOperandIndex(i); + if (tiedOperand.hasValue()) { + p.printOperand(op->getOperand(tiedOperand.getValue())); + } else { + auto type = resultTypes[i]; + p.printType(type); + if (auto shapedType = type.dyn_cast()) { + if (!shapedType.hasStaticShape()) { + if (resultDims.empty()) { + p << "{<>}"; + return; + } + p << "{"; + llvm::interleaveComma( + resultDims.take_front(shapedType.getNumDynamicDims()), p, + [&](Value value) { p.printOperand(value); }); + p << "}"; + resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); + } + } + } + if (i < resultTypes.size() - 1) p << ", "; + } + if (resultTypes.size() != 1) p << ")"; +} + //===----------------------------------------------------------------------===// // flow.variable //===----------------------------------------------------------------------===// @@ -329,10 +598,24 @@ DispatchRegionOp::formFromAnchorOp(Value workload, Operation *anchorOp, return std::make_pair(drOp, newAnchorOp); } -void DispatchRegionOp::dceOperandsAndResults(DispatchRegionOp &op) { - OpBuilder builder(op.getContext()); - ClosureOpDce dce(op, op.body().front(), /*variadicOffset=*/1); - op = llvm::cast(dce.optimize(builder)); +// Clones an operation with new result types. +// The original operation will be erased and a new operation constructed +// in its place. 
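For context, a hedged sketch of how the now file-local helper below is meant to be called (caller name assumed; per the comment, erasing the original is the caller's job):

```
// Rebuilds |op| with a pruned result list; the clone takes over the regions.
// Assumes it lives in the same file as the static cloneWithNewResultTypes.
static mlir::Operation *replaceWithPrunedResults(
    mlir::Operation *op, mlir::TypeRange newResultTypes) {
  mlir::OpBuilder builder(op);  // insertion point: directly before |op|
  mlir::Operation *newOp =
      builder.insert(cloneWithNewResultTypes(op, newResultTypes));
  op->erase();
  return newOp;
}
```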
+static Operation *cloneWithNewResultTypes(Operation *op, + TypeRange newResultTypes) { + OperationState state(op->getLoc(), op->getName()); + state.addOperands(op->getOperands()); + state.addTypes(newResultTypes); + state.addSuccessors(op->getSuccessors()); + state.addAttributes(op->getAttrs()); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + state.addRegion(); + } + Operation *newOp = Operation::create(state); + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + newOp->getRegion(i).takeBody(op->getRegion(i)); + } + return newOp; } ResultRange DispatchRegionOp::appendResults(DispatchRegionOp &self, @@ -489,10 +772,9 @@ void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { // Print the result types, if any. if (op.getNumResults() > 0) { - p << " -> "; - if (op.getNumResults() > 1) p << "("; + p << " -> ("; interleaveComma(op.getResultTypes(), p); - if (op.getNumResults() > 1) p << ")"; + p << ")"; } p.printRegion(op.body(), /*printEntryBlockArgs=*/false); @@ -500,22 +782,81 @@ void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { /*elidedAttrs=*/{}); } +Operation::operand_range DispatchRegionOp::getClosureOperands() { + return args(); +} + +Operation::result_range DispatchRegionOp::getClosureResults() { + return results(); +} + +// TODO(#4897): allow non-splat constants - current paths can't handle them. +static bool canDispatchRegionContainOpIssue4897(Operation *op) { + if (auto constantOp = dyn_cast(op)) { + auto constantValueAttr = constantOp.getValue(); + auto constantType = constantOp.getType(); + if (constantValueAttr.isa()) { + return true; + } else if (auto denseAttr = + constantValueAttr.dyn_cast()) { + return denseAttr.isSplat(); + } else if (constantType.isIntOrIndexOrFloat()) { + return true; + } + } + return false; +} + +bool DispatchRegionOp::canClosureContainOp(Operation *op) { + return canDispatchRegionContainOpIssue4897(op); +} + +ClosureOpInterface +DispatchRegionOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newOperandsValues = llvm::to_vector<4>(args()); + excludeClosureOperandsAndResults(newOperandsValues, excludedOperandIndices, + newResultTypes, excludedResultIndices); + auto newOp = OpBuilder(getContext()) + .create(getLoc(), newResultTypes, + workload(), newOperandsValues, + getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + eraseRegionResults(newBody, excludedResultIndices); + newBody.front().eraseArguments(excludedOperandIndices); + return newOp; +} + //===----------------------------------------------------------------------===// // flow.dispatch.workgroups //===----------------------------------------------------------------------===// void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, ValueRange workgroupCount, - TypeRange resultTypes, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(workgroupCount); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); - state.addAttribute( - "operand_segment_sizes", - builder.getI32VectorAttr({static_cast(workgroupCount.size()), - static_cast(operands.size())})); + 
state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(workgroupCount.size()), + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); auto *body = state.addRegion(); assert(body->begin() == body->end()); @@ -523,17 +864,36 @@ void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state, OpBuilder::InsertionGuard g(builder); builder.createBlock(body); // createBlock implicitly moves IP, RAII away... } - for (auto operand : operands) { - Type type = operand.getType(); + + llvm::BitVector operandAliases(llvm::size(operands), false); + llvm::BitVector resultAliases(llvm::size(resultTypes), false); + for (unsigned resultIndex = 0; resultIndex < tiedOperands.size(); + ++resultIndex) { + int64_t tiedOperandIndex = tiedOperands[resultIndex]; + if (tiedOperandIndex != TiedOpInterface::kUntiedIndex) { + operandAliases[tiedOperandIndex] = true; + resultAliases[resultIndex] = true; + } + } + + for (auto operand : llvm::enumerate(operands)) { + Type type = operand.value().getType(); if (auto tensorType = type.dyn_cast()) { - type = DispatchInputType::get(tensorType); + type = DispatchTensorType::get(operandAliases[operand.index()] + ? TensorAccess::ReadWrite + : TensorAccess::ReadOnly, + tensorType); } body->addArgument(type); } - for (auto resultType : resultTypes) { - Type type = resultType; + for (auto resultType : llvm::enumerate(resultTypes)) { + if (resultAliases[resultType.index()]) { + // Already handled by an aliased operand. + continue; + } + Type type = resultType.value(); if (auto tensorType = type.dyn_cast()) { - type = DispatchOutputType::get(tensorType); + type = DispatchTensorType::get(TensorAccess::WriteOnly, tensorType); } body->addArgument(type); } @@ -544,8 +904,6 @@ static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser, TypeRange operandTypes, TypeRange resultTypes, Region &body) { - auto loc = parser.getCurrentLocation(); - SmallVector regionArgs; SmallVector regionArgTypes; if (failed(parser.parseLParen())) { @@ -565,12 +923,6 @@ static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser, return failure(); } } - - if (regionArgs.size() != operandTypes.size() + resultTypes.size()) { - return parser.emitError(loc, - "region operand list required required to match " - "count of dispatch op operands + results"); - } return parser.parseRegion(body, regionArgs, regionArgTypes, /*enableNameShadowing=*/true); } @@ -581,7 +933,7 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op, p << "("; interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p << arg; - p << " : "; + p << ": "; p << arg.getType(); }); p << ")"; @@ -589,93 +941,76 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op, /*printBlockTerminators=*/true); } -// TODO(benvanik): remove after https://bugs.llvm.org/show_bug.cgi?id=48478 -// The parser/printer are modified autogenerated values to work around the bug. 
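verifyOpDynamicDims, used by the verifiers below, is pure bookkeeping: every value contributes one required dim operand per dynamic dimension. A standalone plain-C++ sketch, assuming -1 marks a dynamic dimension:

```
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true if |numDimOperands| matches the number of dynamic dimensions
// across all |shapes| (one attached dim value required per dynamic dim).
bool dynamicDimsMatch(const std::vector<std::vector<std::int64_t>> &shapes,
                      std::size_t numDimOperands) {
  std::size_t required = 0;
  for (const auto &shape : shapes)
    for (auto dim : shape)
      if (dim == -1) ++required;
  return required == numDimOperands;
}
```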
- -static ::mlir::ParseResult parseDispatchWorkgroupsOp( - ::mlir::OpAsmParser &parser, ::mlir::OperationState *result) { - ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> - workgroup_countOperands; - ::llvm::SMLoc workgroup_countOperandsLoc; - (void)workgroup_countOperandsLoc; - ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> operandsOperands; - ::llvm::SMLoc operandsOperandsLoc; - (void)operandsOperandsLoc; - ::llvm::ArrayRef<::mlir::Type> operandsTypes; - ::llvm::ArrayRef<::mlir::Type> resultsTypes; - std::unique_ptr<::mlir::Region> bodyRegion = - std::make_unique<::mlir::Region>(); - if (parser.parseLSquare()) return ::mlir::failure(); - - workgroup_countOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(workgroup_countOperands)) - return ::mlir::failure(); - if (parser.parseRSquare()) return ::mlir::failure(); - if (parser.parseLParen()) return ::mlir::failure(); - - operandsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(operandsOperands)) return ::mlir::failure(); - if (parser.parseRParen()) return ::mlir::failure(); - if (parser.parseColon()) return ::mlir::failure(); - - ::mlir::FunctionType operands__results_functionType; - if (parser.parseType(operands__results_functionType)) - return ::mlir::failure(); - operandsTypes = operands__results_functionType.getInputs(); - resultsTypes = operands__results_functionType.getResults(); - if (parser.parseOptionalAttrDictWithKeyword(result->attributes)) - return ::mlir::failure(); - if (parser.parseEqual()) return ::mlir::failure(); - { - if (parseDispatchWorkgroupBody(parser, operandsTypes, resultsTypes, - *bodyRegion)) - return ::mlir::failure(); - } - ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType(); - result->addTypes(resultsTypes); - if (parser.resolveOperands(workgroup_countOperands, odsBuildableType0, - result->operands)) - return ::mlir::failure(); - if (parser.resolveOperands(operandsOperands, operandsTypes, - operandsOperandsLoc, result->operands)) - return ::mlir::failure(); - result->addRegion(std::move(bodyRegion)); - result->addAttribute( - "operand_segment_sizes", - parser.getBuilder().getI32VectorAttr( - {static_cast(workgroup_countOperands.size()), - static_cast(operandsOperands.size())})); - return ::mlir::success(); -} - -static void printDispatchWorkgroupsOp(::mlir::OpAsmPrinter &p, - DispatchWorkgroupsOp &op) { - p << "flow.dispatch.workgroups"; - p << "["; - p << op.workgroup_count(); - p << "]"; - p << ' ' << "("; - p << op.operands(); - p << ")"; - p << ' ' << ":"; - p << ' '; - p.printFunctionalType(op.operands().getTypes(), op.results().getTypes()); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{ - "operand_segment_sizes", - }); - p << ' ' << "="; - p << ' '; - printDispatchWorkgroupBody(p, op, op.operands().getTypes(), - op.results().getTypes(), op.body()); -} - static LogicalResult verifyDispatchWorkgroupsOp(DispatchWorkgroupsOp op) { if (op.workgroup_count().empty()) { return op.emitOpError() << "at least one workgroup dimension is required"; } + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } return success(); } +Value DispatchWorkgroupsOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} + +Value DispatchWorkgroupsOp::buildResultRankedShape(unsigned idx, + OpBuilder 
&builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); +} + +Operation::operand_range DispatchWorkgroupsOp::getClosureOperands() { + return operands(); +} + +Operation::result_range DispatchWorkgroupsOp::getClosureResults() { + return results(); +} + +bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) { + return canDispatchRegionContainOpIssue4897(op); +} + +ClosureOpInterface +DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newResultDims = llvm::to_vector<4>(result_dims()); + SmallVector newOperandsValues = llvm::to_vector<4>(operands()); + SmallVector newOperandDims = llvm::to_vector<4>(operand_dims()); + excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, + excludedOperandIndices, newResultTypes, + newResultDims, excludedResultIndices); + SmallVector newTiedOperandIndices = + llvm::to_vector<4>(getTiedResultOperandIndices()); + excludeTiedOperandAndResultIndices( + excludedOperandIndices, excludedResultIndices, newTiedOperandIndices); + auto newOp = OpBuilder(getContext()) + .create( + getLoc(), workgroup_count(), newResultTypes, + newResultDims, newOperandsValues, newOperandDims, + newTiedOperandIndices, getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + newBody.front().eraseArguments(excludedOperandIndices); + unsigned baseResultIndex = newBody.front().getNumArguments(); + newBody.front().eraseArguments(llvm::to_vector<4>(llvm::map_range( + excludedResultIndices, + [&](unsigned index) { return baseResultIndex + index; }))); + return newOp; +} + +std::pair +DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() { + return getODSOperandIndexAndLength(1); +} + //===----------------------------------------------------------------------===// // flow.dispatch.workgroup.* //===----------------------------------------------------------------------===// @@ -842,7 +1177,9 @@ static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) { void DispatchOp::build(OpBuilder &builder, OperationState &state, DispatchEntryOp entryPoint, ValueRange workgroupCount, - TypeRange results, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { StringRef executableOpSymName = entryPoint->getParentOp() @@ -854,13 +1191,22 @@ void DispatchOp::build(OpBuilder &builder, OperationState &state, {builder.getSymbolRefAttr(entryPoint)})); state.addOperands(workgroupCount); - state.addTypes(results); + state.addTypes(resultTypes); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); - state.addAttribute( - "operand_segment_sizes", - builder.getI32VectorAttr({static_cast(workgroupCount.size()), - static_cast(operands.size())})); + state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(workgroupCount.size()), + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); } StringRef DispatchOp::executable() { 
return entry_point().getRootReference(); } @@ -874,121 +1220,297 @@ static LogicalResult verifyDispatchOp(DispatchOp op) { if (op.workgroup_count().empty()) { return op.emitOpError() << "at least one workgroup dimension is required"; } + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } + return success(); +} + +Value DispatchOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} + +Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); +} + +std::pair DispatchOp::getTiedOperandsIndexAndLength() { + return getODSOperandIndexAndLength(1); +} + +//===----------------------------------------------------------------------===// +// flow.tensor.* +//===----------------------------------------------------------------------===// + +Value TensorReshapeOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorReshapeOp::buildResultRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +Value TensorLoadOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorLoadOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return {}; +} + +Value TensorStoreOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); +} + +Value TensorStoreOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), target_dims(), + builder); +} + +Value TensorSplatOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return {}; +} + +Value TensorSplatOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +Value TensorCloneOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), operand(), operand_dims(), + builder); +} + +Value TensorCloneOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), operand_dims(), + builder); +} + +Value TensorSliceOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(), + builder); +} + +Value TensorSliceOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(), + builder); +} + +//===----------------------------------------------------------------------===// +// flow.tensor.update +//===----------------------------------------------------------------------===// + +void TensorUpdateOp::build(OpBuilder &builder, OperationState &state, + Value target, ValueRange startIndices, + Value update) { + auto targetDims = + Shape::buildOrFindDynamicDimsForValue(state.location, target, builder); + auto updateDims = + Shape::buildOrFindDynamicDimsForValue(state.location, update, builder); + 
build(builder, state, target.getType(), target, targetDims, startIndices, + update, updateDims, builder.getIndexArrayAttr({0})); +} + +static LogicalResult verifyTensorUpdateOp(TensorUpdateOp op) { + if (failed(verifyOpDynamicDims(op, {op.update()}, op.update_dims())) || + failed(verifyOpDynamicDims(op, {op.target()}, op.target_dims()))) { + return failure(); + } return success(); } +Value TensorUpdateOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + switch (idx) { + case 0: + return Shape::buildRankedShapeForValue(getLoc(), update(), update_dims(), + builder); + case 2: + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); + default: + llvm_unreachable("unshaped operand"); + } +} + +Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) { + return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(), + builder); +} + +Value TensorUpdateOp::getTiedResult(unsigned resultIndex) { + return IREE::TiedOpInterface::findTiedBaseValue(target()); +} + +::llvm::Optional TensorUpdateOp::getTiedResultOperandIndex( + unsigned resultIndex) { + return 0; // target +} + +SmallVector TensorUpdateOp::getTiedResultOperandIndices() { + return {0}; // target +} + //===----------------------------------------------------------------------===// // flow.ex.stream.fragment //===----------------------------------------------------------------------===// void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state, - ArrayRef resultTypes, ValueRange operands, + TypeRange resultTypes, ValueRange resultDims, + ValueRange operands, ValueRange operandDims, + ArrayRef tiedOperands, ArrayRef attributes) { state.addTypes(resultTypes); state.addOperands(operands); + state.addOperands(operandDims); + state.addOperands(resultDims); state.addAttributes(attributes); + state.attributes.erase(TiedOpInterface::getStorageAttrName()); + state.addAttribute(TiedOpInterface::getStorageAttrName(), + builder.getIndexArrayAttr(tiedOperands)); + state.attributes.erase("operand_segment_sizes"); + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr({ + static_cast(operands.size()), + static_cast(operandDims.size()), + static_cast(resultDims.size()), + })); state.addRegion(); } -ParseResult parseExStreamFragmentOp(OpAsmParser &parser, - OperationState *result) { +static LogicalResult verifyExStreamFragmentOp(ExStreamFragmentOp op) { + if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) || + failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) { + return failure(); + } + return success(); +} + +static ParseResult parseStreamFragmentBody(OpAsmParser &parser, + TypeRange operandTypes, + TypeRange resultTypes, + ArrayAttr tiedOperands, + Region &body) { + auto loc = parser.getCurrentLocation(); + SmallVector regionArgs; SmallVector regionArgTypes; if (failed(parser.parseLParen())) { return failure(); } if (failed(parser.parseOptionalRParen())) { - SmallVector regionOperands; - auto argsLoc = parser.getCurrentLocation(); do { // Reserve entries in the lists. 
regionArgs.emplace_back(); - regionOperands.emplace_back(); regionArgTypes.emplace_back(); if (failed(parser.parseRegionArgument(regionArgs.back())) || - failed(parser.parseEqual()) || - failed(parser.parseOperand(regionOperands.back())) || failed(parser.parseColonType(regionArgTypes.back()))) { return failure(); } } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen()) || - failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, - result->operands))) { + if (failed(parser.parseRParen())) { return failure(); } } - // Parse (optional) results. - if (failed(parser.parseOptionalArrowTypeList(result->types))) { - return failure(); + SmallVector regionResultTypes; + if (failed(parser.parseArrowTypeList(regionResultTypes))) return failure(); + + if (regionArgs.size() != operandTypes.size()) { + return parser.emitError(loc, "region operand list mismatch"); + } + if (regionResultTypes.size() != resultTypes.size()) { + return parser.emitError(loc, "region result list mismatch"); } - // Parse region body. - Region *body = result->addRegion(); - if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || - failed(parser.parseOptionalAttrDict(result->attributes))) { - return failure(); + return parser.parseRegion(body, regionArgs, regionArgTypes, + /*enableNameShadowing=*/true); +} + +static void printStreamFragmentBody(OpAsmPrinter &p, Operation *op, + TypeRange operandTypes, + TypeRange resultTypes, + ArrayAttr tiedOperands, Region &body) { + p << "("; + llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { + p << arg; + p << ": "; + p << arg.getType(); + }); + p << ") -> "; + if (resultTypes.size() != 1) p << "("; + for (unsigned i = 0; i < resultTypes.size(); ++i) { + p.printType(resultTypes[i]); + if (i < resultTypes.size() - 1) p << ", "; } - return success(); + if (resultTypes.size() != 1) p << ")"; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); } -void printExStreamFragmentOp(OpAsmPrinter &p, ExStreamFragmentOp op) { - p << op.getOperationName(); +Value ExStreamFragmentOp::buildOperandRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(), + operand_dims(), builder); +} - // Print the data argument remapping. - p << "("; - interleaveComma(llvm::zip(op.body().getArguments(), op.args()), p, - [&](std::tuple it) { - p << std::get<0>(it) << " = " << std::get<1>(it); - p << " : "; - p << std::get<1>(it).getType(); - }); - p << ")"; +Value ExStreamFragmentOp::buildResultRankedShape(unsigned idx, + OpBuilder &builder) { + return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(), + result_dims(), builder); +} - // Print the result types, if any. 
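excludeTiedOperandAndResultIndices, called from the cloneReplacementExcludingOperandsAndResults methods in this change but defined elsewhere, has to re-index the surviving tied operands. A plain-C++ sketch of the assumed semantics, covering only the operand side (the real helper presumably also drops entries for excluded results):

```
#include <algorithm>
#include <cstdint>
#include <vector>

constexpr std::int64_t kUntiedIndex = -1;  // assumed TiedOpInterface value

// A result tied to a removed operand becomes untied; surviving tied indices
// shift down past every removed operand with a smaller index.
void remapTiedIndicesAfterExclusion(
    std::vector<std::int64_t> &tiedOperandIndices,
    const std::vector<unsigned> &excludedOperandIndices) {
  for (auto &index : tiedOperandIndices) {
    if (index == kUntiedIndex) continue;
    if (std::count(excludedOperandIndices.begin(),
                   excludedOperandIndices.end(),
                   static_cast<unsigned>(index))) {
      index = kUntiedIndex;
      continue;
    }
    index -= std::count_if(
        excludedOperandIndices.begin(), excludedOperandIndices.end(),
        [index](unsigned excluded) {
          return static_cast<std::int64_t>(excluded) < index;
        });
  }
}
```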
- if (op.getNumResults() > 0) { - p << " -> "; - if (op.getNumResults() > 1) p << "("; - interleaveComma(op.getResultTypes(), p); - if (op.getNumResults() > 1) p << ")"; - } +Operation::operand_range ExStreamFragmentOp::getClosureOperands() { + return operands(); +} - p.printRegion(op.body(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{}); +Operation::result_range ExStreamFragmentOp::getClosureResults() { + return results(); } -//===----------------------------------------------------------------------===// -// flow.tensor.update -//===----------------------------------------------------------------------===// +bool ExStreamFragmentOp::canClosureContainOp(Operation *op) { + // NOTE: we widen support on new stream ops only - the legacy path isn't worth + // upgrading to support more. + if (auto constantOp = dyn_cast(op)) { + return constantOp.getType().isIntOrIndexOrFloat(); + } + return false; +} -namespace { -// When the target tensor is a result of a tensor.cast operation, the op needs -// to be updated to use the source of the cast as the target tensor. -struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TensorUpdateOp updateOp, - PatternRewriter &rewriter) const override { - auto targetCastOp = updateOp.target().getDefiningOp(); - auto updateCastOp = updateOp.update().getDefiningOp(); - if (!targetCastOp && !updateCastOp) return failure(); - auto target = (targetCastOp ? targetCastOp.source() : updateOp.target()); - auto update = (updateCastOp ? updateCastOp.source() : updateOp.update()); - auto newOp = rewriter.create( - updateOp.getLoc(), target.getType(), update, target, - updateOp.start_indices()); - rewriter.replaceOpWithNewOp( - updateOp, updateOp.getResult().getType(), newOp.getResult()); - return success(); - } -}; -} // namespace - -void TensorUpdateOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); +ClosureOpInterface +ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults( + ArrayRef excludedOperandIndices, + ArrayRef excludedResultIndices) { + SmallVector newResultTypes = llvm::to_vector<4>(getResultTypes()); + SmallVector newResultDims = llvm::to_vector<4>(result_dims()); + SmallVector newOperandsValues = llvm::to_vector<4>(operands()); + SmallVector newOperandDims = llvm::to_vector<4>(operand_dims()); + excludeClosureOperandsAndResults(newOperandsValues, newOperandDims, + excludedOperandIndices, newResultTypes, + newResultDims, excludedResultIndices); + SmallVector newTiedOperandIndices = + llvm::to_vector<4>(getTiedResultOperandIndices()); + excludeTiedOperandAndResultIndices( + excludedOperandIndices, excludedResultIndices, newTiedOperandIndices); + auto newOp = OpBuilder(getContext()) + .create( + getLoc(), newResultTypes, newResultDims, + newOperandsValues, newOperandDims, newTiedOperandIndices, + getOperation()->getAttrs()); + auto &newBody = newOp.getClosureBodyRegion(); + newBody.takeBody(getClosureBodyRegion()); + eraseRegionResults(newBody, excludedResultIndices); + newBody.front().eraseArguments(excludedOperandIndices); + return newOp; } } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h index 0fda24fd1844..fae825805725 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.h +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h @@ -20,6 +20,7 @@ #include 
"iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/IREE/IR/IREETraits.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td index 947c7b1e37c7..a0f43bb9471b 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -16,6 +16,9 @@ #define IREE_DIALECT_FLOW_OPS include "iree/compiler/Dialect/Flow/IR/FlowBase.td" +include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td" +include "iree/compiler/Dialect/IREE/IR/IREEInterfaces.td" +include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -166,6 +169,7 @@ def FLOW_VariableStoreIndirectOp : FLOW_Op<"variable.store.indirect"> { def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ IsolatedFromAbove, + DeclareOpInterfaceMethods, ]> { let summary = [{partitioned region representing a dispatched workload}]; let description = [{ @@ -202,11 +206,6 @@ def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ formFromAnchorOp(Value workload, Operation *anchorOp, OpBuilder &builder); - /// Performs an in-place DCE optimization on unused operands and results. - /// Note that this may or may not re-allocate the op. If so, the reference - /// will be updated. - static void dceOperandsAndResults(DispatchRegionOp &op); - // Appends results to the dispatch region. This will re-allocate the // DispatchRegionOp itself but preserve the contained body block. 
// Returns a ResultRange for the new dispatch region op's results @@ -241,7 +240,12 @@ def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ IsolatedFromAbove, AttrSizedOperandSegments, - SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp"> + SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; let description = [{ @@ -281,7 +285,10 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ let arguments = (ins Variadic:$workgroup_count, - Variadic:$operands + Variadic:$operands, + FLOW_ShapeDynamicDims:$operand_dims, + FLOW_ShapeDynamicDims:$result_dims, + OptionalAttr:$tied_operands ); let results = (outs Variadic:$results @@ -289,23 +296,27 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [ let regions = (region AnyRegion:$body); - // TODO(benvanik): use after https://bugs.llvm.org/show_bug.cgi?id=48478 - // let assemblyFormat = [{ - // `[` $workgroup_count `]` - // `(` $operands `)` `:` - // functional-type($operands, $results) - // attr-dict-with-keyword - // `=` - // custom(type_ref($operands), - // type_ref($results), - // $body) - // }]; + let assemblyFormat = [{ + `[` $workgroup_count `]` `` + `(` $operands `)` `:` + custom(ref($operands), + type($operands), $operand_dims, + type($results), $result_dims, + $tied_operands) + attr-dict-with-keyword + `=` `\n` ` ` ` ` ` ` + custom(ref(type($operands)), + ref(type($results)), + $body) + }]; let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "ValueRange":$workgroupCount, - "TypeRange":$resultTypes, "ValueRange":$operands, + "TypeRange":$resultTypes, "ValueRange":$resultDims, + "ValueRange":$operands, "ValueRange":$operandDims, + "ArrayRef":$tiedOperands, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -464,7 +475,7 @@ def FLOW_DispatchShapeOp : FLOW_PureOp<"dispatch.shape", [ }]; let arguments = (ins - FLOW_DispatchIO:$source + FLOW_DispatchTensor:$source ); let results = (outs Shape_RankedShape:$result @@ -484,11 +495,11 @@ def FLOW_DispatchTieShapeOp : FLOW_PureOp<"dispatch.tie_shape"> { }]; let arguments = (ins - FLOW_DispatchIO:$operand, + FLOW_DispatchTensor:$operand, Shape_RankedShape:$shape ); let results = (outs - FLOW_DispatchIO:$result + FLOW_DispatchTensor:$result ); // TODO(benvanik): figure out a way to make this look like shapex.tie_shape. 
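Putting the new DispatchWorkgroupsOp builder signature together, a hedged call-site sketch (helper name and values hypothetical): tying result 0 to operand 0 makes the region receive a single readwrite dispatch tensor argument for the pair, per the build() logic earlier in this change.

```
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"

static mlir::Operation *emitTiedDispatch(mlir::OpBuilder &builder,
                                         mlir::Location loc,
                                         mlir::ValueRange workgroupCount,
                                         mlir::Value operand,
                                         mlir::ValueRange operandDims) {
  int64_t tiedOperands[] = {0};  // result 0 aliases operand 0
  auto op =
      builder.create<mlir::iree_compiler::IREE::Flow::DispatchWorkgroupsOp>(
          loc, workgroupCount,
          /*resultTypes=*/mlir::TypeRange{operand.getType()},
          /*resultDims=*/operandDims,
          /*operands=*/mlir::ValueRange{operand},
          /*operandDims=*/operandDims,
          /*tiedOperands=*/tiedOperands);
  return op.getOperation();
}
```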
@@ -500,7 +511,7 @@ def FLOW_DispatchTieShapeOp : FLOW_PureOp<"dispatch.tie_shape"> { let hasCanonicalizer = 1; } -def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ +def FLOW_DispatchTensorLoadOp : FLOW_PureOp<"dispatch.tensor.load", [ SameVariadicOperandSize, ]> { let summary = [{loads a tensor from a dispatch input placeholder}]; @@ -511,7 +522,7 @@ def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ }]; let arguments = (ins - FLOW_DispatchInput:$source, + FLOW_ReadableDispatchTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides @@ -531,7 +542,7 @@ def FLOW_DispatchInputLoadOp : FLOW_PureOp<"dispatch.input.load", [ let hasCanonicalizer = 1; } -def FLOW_DispatchOutputStoreOp : FLOW_Op<"dispatch.output.store", [ +def FLOW_DispatchTensorStoreOp : FLOW_Op<"dispatch.tensor.store", [ SameVariadicOperandSize, ]> { let summary = [{stores a tensor into a dispatch output placeholder}]; @@ -543,7 +554,7 @@ def FLOW_DispatchOutputStoreOp : FLOW_Op<"dispatch.output.store", [ let arguments = (ins AnyRankedTensor:$value, - FLOW_DispatchOutput:$target, + FLOW_WritableDispatchTensor:$target, Variadic:$offsets, Variadic:$sizes, Variadic:$strides @@ -652,6 +663,10 @@ def FLOW_DispatchEntryOp : FLOW_Op<"dispatch.entry", [ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ AttrSizedOperandSegments, FLOW_StreamableOp, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{a dispatch of workgroups across an n-dimension grid}]; let description = [{ @@ -663,7 +678,10 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ let arguments = (ins Variadic:$workgroup_count, SymbolRefAttr:$entry_point, - Variadic:$operands + Variadic:$operands, + FLOW_ShapeDynamicDims:$operand_dims, + FLOW_ShapeDynamicDims:$result_dims, + OptionalAttr:$tied_operands ); let results = (outs Variadic:$results @@ -673,7 +691,9 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ let builders = [ OpBuilder<(ins "DispatchEntryOp":$entryPoint, "ValueRange":$workgroupCount, - "TypeRange":$resultTypes, CArg<"ValueRange", "{}">:$operands, + "TypeRange":$resultTypes, "ValueRange":$resultDims, + "ValueRange":$operands, "ValueRange":$operandDims, + "ArrayRef":$tiedOperands, CArg<"ArrayRef", "{}">:$attributes)>, ]; @@ -683,14 +703,15 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ // StreamableOpInterface: bool isTransfer() { return false; } - bool isUsableInStream() { return true; } - bool isStreamOnly() { return true; } }]; let assemblyFormat = [{ - $entry_point `[` $workgroup_count `]` + $entry_point `[` $workgroup_count `]` `` `(` $operands `)` attr-dict `:` - functional-type($operands, $results) + custom(ref($operands), + type($operands), $operand_dims, + type($results), $result_dims, + $tied_operands) }]; let verifier = [{ return verifyDispatchOp(*this); }]; @@ -703,6 +724,8 @@ def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [ def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [ FLOW_StreamableOp, AllElementTypesMatch<["source", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{reshapes a tensor}]; let description = [{ @@ -710,21 +733,36 @@ def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [ }]; let arguments = (ins - FLOW_Tensor:$source - // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes. 
+ FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$source `:` type($source) `->` type($result) attr-dict"; + let assemblyFormat = [{ + $source `:` + type($source) (`{` $source_dims^ `}`)? `->` + type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins + "Type":$result_type, "Value":$source, "ValueRange":$target_dims), + [{ + build($_builder, $_state, + result_type, + source, + Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder), + target_dims); + }]>, + ]; let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a shape manipulation. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize away if resulting ops don't care. @@ -735,6 +773,8 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ TypesMatchWith<"value type matches element type of target operand", "source", "result", "$_self.cast().getElementType()">, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{loads a value from a tensor element}]; let description = [{ @@ -743,6 +783,7 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ let arguments = (ins FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, Variadic:$indices ); let results = (outs @@ -750,9 +791,23 @@ def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load", [ ); let assemblyFormat = [{ - $source (`[` $indices^ `]`)? `:` type($source) attr-dict-with-keyword + $source (`[` $indices^ `]`)? `:` + type($source) (`{` $source_dims^ `}`)? + attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Type":$result_type, "Value":$source, CArg<"ValueRange", "{}">:$indices), + [{ + build($_builder, $_state, + result_type, + source, + Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder), + indices); + }]>, + ]; + // TODO(benvanik): canonicalize to slice+load if dims are known. let hasFolder = 1; } @@ -762,6 +817,8 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ TypesMatchWith<"value type matches element type of target operand", "target", "value", "$_self.cast().getElementType()">, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{stores a value into a tensor element}]; let description = [{ @@ -771,6 +828,7 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ let arguments = (ins AnyTypeOf<[FLOW_PrimitiveType, AnyVector]>:$value, FLOW_Tensor:$target, + FLOW_ShapeDynamicDims:$target_dims, Variadic:$indices ); let results = (outs @@ -778,10 +836,24 @@ def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ ); let assemblyFormat = [{ - $value `,` $target (`[` $indices^ `]`)? `:` type($target) + $value `,` $target (`[` $indices^ `]`)? `:` + type($target) (`{` $target_dims^ `}`)? 
attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Value":$value, "Value":$target, CArg<"ValueRange", "{}">:$indices), + [{ + build($_builder, $_state, + target.getType(), + value, + target, + Shape::buildOrFindDynamicDimsForValue($_state.location, target, $_builder), + indices); + }]>, + ]; + let hasFolder = 1; } @@ -790,6 +862,7 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ TypesMatchWith<"value type matches element type of result", "result", "value", "$_self.cast().getElementType()">, + DeclareOpInterfaceMethods, ]> { let summary = [{splats a value into a shaped tensor}]; let description = [{ @@ -797,21 +870,21 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ }]; let arguments = (ins - FLOW_PrimitiveType:$value - // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes. + FLOW_PrimitiveType:$value, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$value `:` type($result) attr-dict-with-keyword"; + let assemblyFormat = [{ + $value `:` type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword + }]; let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.fill. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize splat+slice to smaller splat. @@ -820,7 +893,8 @@ def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ FLOW_StreamableOp, - SameOperandsAndResultType, + AllTypesMatch<["operand", "result"]>, + DeclareOpInterfaceMethods, ]> { let summary = [{performs a full tensor clone operation}]; let description = [{ @@ -828,20 +902,31 @@ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ }]; let arguments = (ins - FLOW_Tensor:$operand + FLOW_Tensor:$operand, + FLOW_ShapeDynamicDims:$operand_dims ); let results = (outs FLOW_Tensor:$result ); - let assemblyFormat = "$operand `:` type($result) attr-dict"; + let assemblyFormat = [{ + $operand `:` type($result) (`{` $operand_dims^ `}`)? + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins "Value":$operand), + [{ + build($_builder, $_state, + operand.getType(), + operand, + Shape::buildOrFindDynamicDimsForValue($_state.location, operand, $_builder)); + }]>, + ]; let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.copy. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize away entirely in most cases. @@ -852,7 +937,8 @@ def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ FLOW_StreamableOp, AllRanksMatch<["source", "result"]>, AllElementTypesMatch<["source", "result"]>, - SameVariadicOperandSize, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, ]> { let summary = [{slices out a subregion of a tensor}]; let description = [{ @@ -861,25 +947,25 @@ def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ let arguments = (ins FLOW_Tensor:$source, + FLOW_ShapeDynamicDims:$source_dims, Variadic:$start_indices, - Variadic:$lengths - // TODO(benvanik): strides. 
+ Variadic:$lengths, + FLOW_ShapeDynamicDims:$result_dims ); let results = (outs FLOW_Tensor:$result ); let assemblyFormat = [{ - $source `[` $start_indices `for` $lengths `]` `:` type($source) `->` - type($result) attr-dict + $source `[` $start_indices `for` $lengths `]` `:` + type($source) (`{` $source_dims^ `}`)? `->` + type($result) (`{` $result_dims^ `}`)? + attr-dict-with-keyword }]; let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.slice. - bool isStreamOnly() { return true; } }]; // TODO(benvanik): canonicalize multiple slices (traverse upward through ssa). @@ -891,6 +977,13 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ AllRanksMatch<["update", "target", "result"]>, AllTypesMatch<["target", "result"]>, AllElementTypesMatch<["update", "target", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{updates a tensor with the contents of another tensor}]; let description = [{ @@ -899,28 +992,38 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [ }]; let arguments = (ins - FLOW_Tensor:$update, FLOW_Tensor:$target, - Variadic:$start_indices + FLOW_ShapeDynamicDims:$target_dims, + Variadic:$start_indices, + FLOW_Tensor:$update, + FLOW_ShapeDynamicDims:$update_dims, + OptionalAttr:$tied_operands ); let results = (outs FLOW_Tensor:$result ); let assemblyFormat = [{ - $update `,` $target `[` $start_indices `]` `:` type($update) `->` - type($result) attr-dict + $update `,` $target `[` $start_indices `]` `:` + type($update) (`{` $update_dims^ `}`)? `->` + custom(type($result), $target_dims, $tied_operands) + attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Value":$target, + "ValueRange":$start_indices, + "Value":$update)>, + ]; + let extraClassDeclaration = [{ // StreamableOpInterface: bool isTransfer() { return true; } - bool isUsableInStream() { return true; } - // TODO(benvanik): allow out of stream to act as a hal.buffer.copy. - bool isStreamOnly() { return true; } }]; - // TODO(benvanik): canonicalize contiguous updates/across slices. + let verifier = [{ return verifyTensorUpdateOp(*this); }]; + let hasCanonicalizer = 1; let hasFolder = 1; } @@ -949,15 +1052,23 @@ def FLOW_TensorTraceOp : FLOW_Op<"tensor.trace", []> { // TODO(benvanik): replace with real segmented stream ops. def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [ IsolatedFromAbove, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{experimental op for defining formed stream regions}]; let description = [{ Represents a region where all of the dispatches are meant to target the - same execution stream. This will be replaced with a segmented verison. + same execution stream. This will be replaced with a segmented version in the + future that stitches the stream segments together. 
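+
+    For example, a fragment that wraps a single dispatch might look like this
+    (the executable name and types here are illustrative only):
+
+      %0 = flow.ex.stream.fragment(%cst, %input) :
+          (index, tensor<4xf32>) -> tensor<4xf32> =
+          (%arg0: index, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+        %1 = flow.dispatch @ex::@entry[%arg0](%arg1) : (tensor<4xf32>) -> tensor<4xf32>
+        flow.return %1 : tensor<4xf32>
+      }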
  }];

  let arguments = (ins
-    Variadic<AnyType>:$args
+    Variadic<AnyType>:$operands,
+    FLOW_ShapeDynamicDims:$operand_dims,
+    FLOW_ShapeDynamicDims:$result_dims,
+    OptionalAttr:$tied_operands
  );
  let results = (outs
    Variadic<AnyType>:$results
@@ -965,12 +1076,31 @@ def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [

  let regions = (region AnyRegion:$body);

+  let assemblyFormat = [{
+    `(` $operands `)` `:`
+    custom<ShapedFunctionType>(ref($operands),
+                               type($operands), $operand_dims,
+                               type($results), $result_dims,
+                               $tied_operands)
+    attr-dict-with-keyword
+    `=` `\n` ` ` ` ` ` `
+    custom<StreamFragmentBody>(ref(type($operands)),
+                               ref(type($results)),
+                               ref($tied_operands),
+                               $body)
+  }];
+
  let skipDefaultBuilders = 1;
  let builders = [
-    OpBuilder<(ins "ArrayRef<Type>":$resultTypes, "ValueRange":$args,
+    OpBuilder<(ins
+      "TypeRange":$resultTypes, "ValueRange":$resultDims,
+      "ValueRange":$operands, "ValueRange":$operandDims,
+      "ArrayRef<int64_t>":$tiedOperands,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
  ];
+
+  let verifier = [{ return verifyExStreamFragmentOp(*this); }];
+
+  let hasCanonicalizer = 1;
}
diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
index b9639178c50e..c916a84b64d1 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
@@ -25,6 +25,25 @@ namespace Flow {
 // Object types
 //===----------------------------------------------------------------------===//

+// static
+DispatchTensorType DispatchTensorType::get(TensorAccess access,
+                                           ArrayRef<int64_t> shape,
+                                           Type elementType) {
+  return Base::get(elementType.getContext(), static_cast<uint32_t>(access),
+                   shape, elementType);
+}
+
+// static
+DispatchTensorType DispatchTensorType::get(TensorAccess access,
+                                           TensorType tensorType) {
+  return DispatchTensorType::get(access, tensorType.getShape(),
+                                 tensorType.getElementType());
+}
+
+TensorAccess DispatchTensorType::getAccess() const {
+  return static_cast<TensorAccess>(static_cast<ImplType *>(impl)->access);
+}
+
 Type DispatchTensorType::getElementType() const {
   return static_cast<ImplType *>(impl)->elementType;
 }
@@ -79,8 +98,8 @@ bool DispatchTensorType::hasStaticShape(ArrayRef<int64_t> shape) const {
 }

 LogicalResult DispatchTensorType::verify(
-    function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> shape,
-    Type elementType) {
+    function_ref<InFlightDiagnostic()> emitError, uint32_t access,
+    ArrayRef<int64_t> shape, Type elementType) {
   if (!isValidElementType(elementType)) {
     return emitError() << "dispatch tensor elements must be int or float type";
   }
@@ -93,17 +112,38 @@ LogicalResult DispatchTensorType::verify(

 template <typename T>
 static T parseShapedType(DialectAsmParser &parser) {
+  StringRef accessStr;
   SmallVector<int64_t, 4> shape;
   Type elementType;
-  if (failed(parser.parseLess()) ||
+  if (failed(parser.parseLess()) || failed(parser.parseKeyword(&accessStr)) ||
+      failed(parser.parseColon()) ||
       failed(parser.parseDimensionList(shape, /*allowDynamic=*/true)) ||
       failed(parser.parseType(elementType)) || failed(parser.parseGreater())) {
     return {};
   }
-  return T::get(shape, elementType);
+  auto access = llvm::StringSwitch<TensorAccess>(accessStr)
+                    .Case("readonly", TensorAccess::ReadOnly)
+                    .Case("readwrite", TensorAccess::ReadWrite)
+                    .Case("writeonly", TensorAccess::WriteOnly)
+                    .Default(TensorAccess::ReadOnly);
+  return T::get(access, shape, elementType);
 }

 static void printShapedType(DispatchTensorType &type, DialectAsmPrinter &p) {
+  switch (type.getAccess()) {
+    case TensorAccess::ReadOnly:
+      p << "readonly";
+      break;
+    case TensorAccess::ReadWrite:
+      p << "readwrite";
+      break;
+    case TensorAccess::WriteOnly:
+      p << "writeonly";
+      break;
+    default:
+      llvm_unreachable("unhandled access");
+  }
+  p << ":";
   for (int64_t dim : type.getShape()) {
     if (ShapedType::isDynamic(dim)) {
       p << '?';
@@ -116,61 +156,12 @@ static void printShapedType(DispatchTensorType &type, DialectAsmPrinter &p) {
 }

 // static
-DispatchInputType DispatchInputType::get(ArrayRef<int64_t> shape,
-                                         Type elementType) {
-  return Base::get(elementType.getContext(), shape, elementType);
-}
-
-// static
-DispatchInputType DispatchInputType::getChecked(ArrayRef<int64_t> shape,
-                                                Type elementType,
-                                                Location location) {
-  return Base::getChecked(location, shape, elementType);
-}
-
-// static
-DispatchInputType DispatchInputType::get(TensorType tensorType) {
-  return DispatchInputType::get(tensorType.getShape(),
-                                tensorType.getElementType());
-}
-
-// static
-DispatchInputType DispatchInputType::parse(DialectAsmParser &parser) {
-  return parseShapedType<DispatchInputType>(parser);
-}
-
-void printType(DispatchInputType &type, DialectAsmPrinter &p) {
-  p << "dispatch.input<";
-  printShapedType(type, p);
-  p << '>';
-}
-
-// static
-DispatchOutputType DispatchOutputType::get(ArrayRef<int64_t> shape,
-                                           Type elementType) {
-  return Base::get(elementType.getContext(), shape, elementType);
-}
-
-// static
-DispatchOutputType DispatchOutputType::getChecked(ArrayRef<int64_t> shape,
-                                                  Type elementType,
-                                                  Location location) {
-  return Base::getChecked(location, shape, elementType);
-}
-
-// static
-DispatchOutputType DispatchOutputType::get(TensorType tensorType) {
-  return DispatchOutputType::get(tensorType.getShape(),
-                                 tensorType.getElementType());
-}
-
-// static
-DispatchOutputType DispatchOutputType::parse(DialectAsmParser &parser) {
-  return parseShapedType<DispatchOutputType>(parser);
+DispatchTensorType DispatchTensorType::parse(DialectAsmParser &parser) {
+  return parseShapedType<DispatchTensorType>(parser);
 }

-void printType(DispatchOutputType &type, DialectAsmPrinter &p) {
-  p << "dispatch.output<";
+void printType(DispatchTensorType &type, DialectAsmPrinter &p) {
+  p << "dispatch.tensor<";
   printShapedType(type, p);
   p << '>';
 }
diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/iree/compiler/Dialect/Flow/IR/FlowTypes.h
index 59abc6f231b3..8a2fdf149715 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowTypes.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.h
@@ -43,15 +43,36 @@ namespace detail {
 struct DispatchTensorTypeStorage;
 }  // namespace detail

+enum class TensorAccess : uint32_t {
+  ReadOnly,
+  ReadWrite,
+  WriteOnly,
+};
+
 // Blatantly ripped from ShapedType, because the closed type system means that
 // we can't extend it and reuse all of this.
-class DispatchTensorType : public Type {
+class DispatchTensorType
+    : public Type::TypeBase<DispatchTensorType, Type,
+                            detail::DispatchTensorTypeStorage> {
 public:
  using ImplType = detail::DispatchTensorTypeStorage;

  static constexpr int64_t kDynamicSize = -1;

-  using Type::Type;
+  using Base::Base;
+
+  /// Get or create a new DispatchTensorType of the provided shape and
+  /// element type. Assumes the arguments define a well-formed
+  /// DispatchTensorType.
+  static DispatchTensorType get(TensorAccess access, ArrayRef<int64_t> shape,
+                                Type elementType);
+
+  static DispatchTensorType get(TensorAccess access, TensorType tensorType);
+
+  static DispatchTensorType parse(DialectAsmParser &parser);
+
+  /// Returns the allowed operations on the tensor.
+  TensorAccess getAccess() const;

   /// Return the element type.
   Type getElementType() const;
@@ -97,17 +118,15 @@ class DispatchTensorType : public Type {
   /// dimensions, given its `index` within the shape.
   unsigned getDynamicDimIndex(unsigned index) const;

-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
   /// Whether the given dimension size indicates a dynamic dimension.
   static constexpr bool isDynamic(int64_t dSize) {
     return dSize == kDynamicSize;
   }

-  /// Verify the construction of a vector type.
+  /// Verify the construction of a tensor type.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType);
+                              uint32_t access, ArrayRef<int64_t> shape,
+                              Type elementType);

   /// Returns true if the given type can be used as an element of a vector type.
   /// In particular, vectors can consist of integer or float primitives.
@@ -124,100 +143,41 @@ class DispatchTensorType : public Type {
   }
 };

-class DispatchInputType
-    : public Type::TypeBase<DispatchInputType, DispatchTensorType,
-                            detail::DispatchTensorTypeStorage> {
- public:
-  using Base::Base;
-
-  /// Get or create a new DispatchInputType of the provided shape and element
-  /// type. Assumes the arguments define a well-formed DispatchInputType.
-  static DispatchInputType get(ArrayRef<int64_t> shape, Type elementType);
-
-  /// Get or create a new DispatchInputType of the provided shape and element
-  /// type declared at the given, potentially unknown, location. If the
-  /// DispatchInputType defined by the arguments would be ill-formed, emit
-  /// errors and return nullptr-wrapping type.
-  static DispatchInputType getChecked(ArrayRef<int64_t> shape, Type elementType,
-                                      Location location);
-  static DispatchInputType getChecked(
-      function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> shape,
-      Type elementType) {
-    return Base::getChecked(emitError, elementType.getContext(), shape,
-                            elementType);
-  }
-
-  static DispatchInputType get(TensorType tensorType);
-
-  static DispatchInputType parse(DialectAsmParser &parser);
-};
-
-void printType(DispatchInputType &type, DialectAsmPrinter &p);
-
-class DispatchOutputType
-    : public Type::TypeBase<DispatchOutputType, DispatchTensorType,
-                            detail::DispatchTensorTypeStorage> {
- public:
-  using Base::Base;
-
-  /// Get or create a new DispatchOutputType of the provided shape and element
-  /// type. Assumes the arguments define a well-formed DispatchOutputType.
-  static DispatchOutputType get(ArrayRef<int64_t> shape, Type elementType);
-
-  /// Get or create a new DispatchOutputType of the provided shape and element
-  /// type declared at the given, potentially unknown, location. If the
-  /// DispatchOutputType defined by the arguments would be ill-formed, emit
-  /// errors and return nullptr-wrapping type.
-  static DispatchOutputType getChecked(ArrayRef<int64_t> shape,
-                                       Type elementType, Location location);
-  static DispatchOutputType getChecked(
-      function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> shape,
-      Type elementType) {
-    return Base::getChecked(emitError, elementType.getContext(), shape,
-                            elementType);
-  }
-
-  static DispatchOutputType get(TensorType tensorType);
-
-  static DispatchOutputType parse(DialectAsmParser &parser);
-};
-
-void printType(DispatchOutputType &type, DialectAsmPrinter &p);
-
-inline bool DispatchTensorType::classof(Type type) {
-  return type.isa<DispatchInputType, DispatchOutputType>();
-}
+void printType(DispatchTensorType &type, DialectAsmPrinter &p);

 namespace detail {

 struct DispatchTensorTypeStorage : public TypeStorage {
-  DispatchTensorTypeStorage(unsigned shapeSize, Type elementTy,
+  DispatchTensorTypeStorage(uint32_t access, unsigned shapeSize, Type elementTy,
                             const int64_t *shapeElements)
-      : shapeElements(shapeElements),
+      : access(access),
+        shapeElements(shapeElements),
         shapeSize(shapeSize),
         elementType(elementTy) {}

   /// The hash key used for uniquing.
-  using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
+  using KeyTy = std::tuple<uint32_t, ArrayRef<int64_t>, Type>;

   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(getShape(), elementType);
+    return key == KeyTy(access, getShape(), elementType);
   }

   /// Construction.
   static DispatchTensorTypeStorage *construct(TypeStorageAllocator &allocator,
                                               const KeyTy &key) {
     // Copy the shape into the bump pointer.
-    ArrayRef<int64_t> shape = allocator.copyInto(key.first);
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<1>(key));

     // Initialize the memory using placement new.
     return new (allocator.allocate<DispatchTensorTypeStorage>())
-        DispatchTensorTypeStorage(shape.size(), key.second, shape.data());
+        DispatchTensorTypeStorage(std::get<0>(key), shape.size(),
+                                  std::get<2>(key), shape.data());
   }

   ArrayRef<int64_t> getShape() const {
     return ArrayRef<int64_t>(shapeElements, shapeSize);
   }

+  uint32_t access;
   const int64_t *shapeElements;
   unsigned shapeSize;
   Type elementType;
diff --git a/iree/compiler/Dialect/Flow/IR/test/BUILD b/iree/compiler/Dialect/Flow/IR/test/BUILD
index b780b112a316..b2f3b8d3239a 100644
--- a/iree/compiler/Dialect/Flow/IR/test/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/test/BUILD
@@ -13,6 +13,7 @@
 # limitations under the License.

 load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")

 package(
     default_visibility = ["//visibility:public"],
@@ -22,7 +23,24 @@ package(

 iree_lit_test_suite(
     name = "lit",
-    srcs = glob(["*.mlir"]),
+    srcs = enforce_glob(
+        [
+            "dispatch_ops.mlir",
+            "dispatch_region_folding.mlir",
+            "dispatch_regions.mlir",
+            "dispatch_workgroups.mlir",
+            "dispatch_workgroups_folding.mlir",
+            "executable_ops.mlir",
+            "stream_folding.mlir",
+            "stream_ops.mlir",
+            "tensor_folding.mlir",
+            "tensor_ops.mlir",
+            "types.mlir",
+            "variable_folding.mlir",
+            "variable_ops.mlir",
+        ],
+        include = ["*.mlir"],
+    ),
     data = [
         "//iree/tools:IreeFileCheck",
         "//iree/tools:iree-opt",
diff --git a/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt
index 28f5faeabaa1..a1b98324a6bc 100644
--- a/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/IR/test/CMakeLists.txt
@@ -10,12 +10,23 @@

 iree_add_all_subdirs()

-file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
 iree_lit_test_suite(
   NAME
     lit
   SRCS
-    "${_GLOB_X_MLIR}"
+    "dispatch_ops.mlir"
+    "dispatch_region_folding.mlir"
+    "dispatch_regions.mlir"
+    "dispatch_workgroups.mlir"
+    "dispatch_workgroups_folding.mlir"
+    "executable_ops.mlir"
+    "stream_folding.mlir"
+    "stream_ops.mlir"
+    "tensor_folding.mlir"
+    "tensor_ops.mlir"
+    "types.mlir"
+    "variable_folding.mlir"
+    "variable_ops.mlir"
   DATA
     iree::tools::IreeFileCheck
     iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
index 2325c63bfac5..7cb22d0673c3 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -1,10 +1,8 @@
-// Tests printing and parsing of dispatch ops.
- // RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s flow.executable @ex0 { module { - func @dispatch_fn(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + func @dispatch_fn(%cst : index, %arg0 : tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } } @@ -15,7 +13,33 @@ flow.executable @ex0 { func @dispatch(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[CST:.+]] = constant %cst = constant 4 : index - // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex0::@dispatch_fn[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0) : (index, tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0) : (index, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } + +// ----- + +// CHECK-LABEL: @inplaceDispatch +func @inplaceDispatch(%arg0 : tensor<4xf32>, %arg1 : tensor<8xf32>) -> (tensor<4xf32>, tensor<8xf32>) { + // CHECK: %[[CST:.+]] = constant + %cst = constant 4 : index + // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4xf32>, tensor<8xf32>) -> (%arg0, %arg1) + %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4xf32>, tensor<8xf32>) -> (%arg0, %arg1) + return %0, %1 : tensor<4xf32>, tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: @inplaceDynamicDispatch +func @inplaceDynamicDispatch(%arg0 : tensor<4x?xf32>, %arg1 : tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) { + // CHECK-DAG: %[[CST:.+]] = constant 4 + %cst = constant 4 : index + // CHECK-DAG: %[[DIM0:.+]] = constant 100 + %dim0 = constant 100 : index + // CHECK-DAG: %[[DIM1:.+]] = constant 200 + %dim1 = constant 200 : index + // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0, %arg1) + %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) + return %0, %1 : tensor<4x?xf32>, tensor<8x?xf32> +} diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir index 8ab0c0f869b5..fb678b8894ed 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_region_folding.mlir @@ -5,7 +5,7 @@ func @dceOperandsAndResults(%arg0 : tensor) -> (tensor) { // CHECK: %[[WORKLOAD:.+]] = constant 5 %workload = constant 5 : index // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD]] : index] - // CHECK-SAME: (%[[CA1:.+]] = %arg0 : tensor) -> tensor + // CHECK-SAME: (%[[CA1:.+]] = %arg0 : tensor) -> (tensor) // CHECK: %[[DR0:.+]] = addf %[[CA1]], %[[CA1]] // CHECK: flow.return %[[DR0]] : tensor %ret0, %ret1 = flow.dispatch.region[%workload : index]( diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir index 8e7171ebf9ca..d06d20c4ddfb 100644 --- a/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_regions.mlir @@ -37,7 +37,7 @@ func @multipleArgs(%arg0 : tensor, %arg1 : tensor) { // CHECK-LABEL: @singleResult func @singleResult(%arg0 : tensor) -> tensor { // CHECK-NEXT: %[[WORKLOAD:.+]] = "some.shape" - // CHECK-NEXT: %1 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = 
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: %1 = flow.dispatch.region[%[[WORKLOAD]] : index](%arg1 = %arg0 : tensor<?xf32>) -> (tensor<?xf32>) {
   // CHECK-NEXT:   flow.return %arg1 : tensor<?xf32>
   // CHECK-NEXT: }
   %workload = "some.shape"(%arg0) : (tensor<?xf32>) -> index
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
index a0ab78448d05..054292c0c503 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
@@ -6,19 +6,20 @@ func @complexWorkgroupsUsage(
     %arg0 : tensor<?x4xf32>,
     // CHECK-SAME: %[[ARG1:.+]]: index
     %arg1 : index) -> tensor<4x?xf32> {
+  %c128 = constant 128 : index
   // CHECK-DAG: %[[WORKGROUP_COUNT_X:.+]] = constant 100
   %x = constant 100 : index
   // CHECK-DAG: %[[WORKGROUP_COUNT_Y:.+]] = constant 50
   %y = constant 50 : index
   // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[
   // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]]
-  // CHECK-SAME: ] (%[[ARG0]], %[[ARG1]])
-  // CHECK-SAME: : (tensor<?x4xf32>, index) -> tensor<4x?xf32> = (
-  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>, index) -> (tensor<4x?xf32>) =
-  // CHECK-SAME: %[[INNER_ARG0:.+]] : !flow.dispatch.input<?x4xf32>
-  // CHECK-SAME: %[[INNER_ARG1:.+]] : index
-  // CHECK-SAME: %[[INNER_RET0:.+]] : !flow.dispatch.output<4x?xf32>
-  (%arg0_capture : !flow.dispatch.input<?x4xf32>, %arg1_capture : index, %ret0 : !flow.dispatch.output<4x?xf32>) {
+  // CHECK-SAME: ](%[[ARG0]], %[[ARG1]])
+  // CHECK-SAME: : (tensor<?x4xf32>{%c128}, index) -> tensor<4x?xf32>{%c128} =
+  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>{%c128}, index) -> tensor<4x?xf32>{%c128} =
+  // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor<readonly:?x4xf32>
+  // CHECK-SAME:  %[[INNER_ARG1:.+]]: index
+  // CHECK-SAME:  %[[INNER_RET0:.+]]: !flow.dispatch.tensor<writeonly:4x?xf32>) {
+  (%arg0_capture: !flow.dispatch.tensor<readonly:?x4xf32>, %arg1_capture: index, %ret0: !flow.dispatch.tensor<writeonly:4x?xf32>) {

   // Query symbolic workgroup info:
@@ -40,8 +41,8 @@ func @complexWorkgroupsUsage(

   // Query shapes directly from IO (static dims will fold):

-  // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_ARG0]] : !flow.dispatch.input<?x4xf32> -> !shapex.ranked_shape<[?,4]>
-  %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.input<?x4xf32> -> !shapex.ranked_shape<[?,4]>
+  // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_ARG0]] : !flow.dispatch.tensor<readonly:?x4xf32> -> !shapex.ranked_shape<[?,4]>
+  %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.tensor<readonly:?x4xf32> -> !shapex.ranked_shape<[?,4]>
   // CHECK-DAG: %[[ARG0_DIM0:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][0] : !shapex.ranked_shape<[?,4]> -> index
   %arg0_dim0 = shapex.ranked_dim %arg0_shape[0] : !shapex.ranked_shape<[?,4]> -> index
   // CHECK-DAG: %[[ARG0_DIM1:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][1] : !shapex.ranked_shape<[?,4]> -> index
@@ -49,8 +50,8 @@ func @complexWorkgroupsUsage(
   // CHECK-NEXT: "test.sink"(%[[ARG0_DIM0]], %[[ARG0_DIM1]])
   "test.sink"(%arg0_dim0, %arg0_dim1) : (index, index) -> ()

-  // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_RET0]] : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]>
-  %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]>
+  // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[INNER_RET0]] : !flow.dispatch.tensor<writeonly:4x?xf32> -> !shapex.ranked_shape<[4,?]>
+  %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.tensor<writeonly:4x?xf32> -> !shapex.ranked_shape<[4,?]>
   // CHECK-DAG: %[[RET0_DIM0:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][0] : !shapex.ranked_shape<[4,?]> -> index
   %ret0_dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[4,?]> -> index
   // CHECK-DAG: %[[RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1] : !shapex.ranked_shape<[4,?]> -> index
   %ret0_dim1 = shapex.ranked_dim %ret0_shape[1] : !shapex.ranked_shape<[4,?]> -> index
@@ -60,8 +61,8 @@ func @complexWorkgroupsUsage(

   // Load tensors (optional offsets/sizes/strides):

-  // CHECK: %[[ARG0_VALUE:.+]] = flow.dispatch.input.load %[[INNER_ARG0]] : !flow.dispatch.input<?x4xf32> -> tensor<?x4xf32>
-  %arg0_value = flow.dispatch.input.load %arg0_capture : !flow.dispatch.input<?x4xf32> -> tensor<?x4xf32>
+  // CHECK: %[[ARG0_VALUE:.+]] = flow.dispatch.tensor.load %[[INNER_ARG0]] : !flow.dispatch.tensor<readonly:?x4xf32> -> tensor<?x4xf32>
+  %arg0_value = flow.dispatch.tensor.load %arg0_capture : !flow.dispatch.tensor<readonly:?x4xf32> -> tensor<?x4xf32>
   // CHECK-NEXT: %[[ARG0_SHAPE_INDIRECT:.+]] = shapex.get_ranked_shape %[[ARG0_VALUE]] : tensor<?x4xf32> -> !shapex.ranked_shape<[?,4]>
   %arg0_shape_indirect = shapex.get_ranked_shape %arg0_value : tensor<?x4xf32> -> !shapex.ranked_shape<[?,4]>
@@ -72,8 +73,8 @@ func @complexWorkgroupsUsage(

   // Store tensors (optional offsets/sizes/strides):

-  // CHECK: flow.dispatch.output.store %[[RET0_VALUE]], %[[INNER_RET0]] : tensor<4x?xf32> -> !flow.dispatch.output<4x?xf32>
-  flow.dispatch.output.store %ret0_value, %ret0 : tensor<4x?xf32> -> !flow.dispatch.output<4x?xf32>
+  // CHECK: flow.dispatch.tensor.store %[[RET0_VALUE]], %[[INNER_RET0]] : tensor<4x?xf32> -> !flow.dispatch.tensor<writeonly:4x?xf32>
+  flow.dispatch.tensor.store %ret0_value, %ret0 : tensor<4x?xf32> -> !flow.dispatch.tensor<writeonly:4x?xf32>

   // CHECK-NEXT: flow.return
   flow.return
@@ -81,3 +82,35 @@ func @complexWorkgroupsUsage(
   // CHECK: return %[[OUTER_RET0]] : tensor<4x?xf32>
   return %0 : tensor<4x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @inplaceDispatch
+func @inplaceDispatch(
+    // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+    %arg0: tensor<?xf32>,
+    // CHECK-SAME: %[[ARG1:.+]]: index
+    %arg1: index) -> tensor<?xf32> {
+  %c128 = constant 128 : index
+  // CHECK-DAG: %[[WORKGROUP_COUNT_X:.+]] = constant 100
+  %x = constant 100 : index
+  // CHECK-DAG: %[[WORKGROUP_COUNT_Y:.+]] = constant 50
+  %y = constant 50 : index
+  // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[
+  // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]]
+  // CHECK-SAME: ](%[[ARG0]], %[[ARG1]])
+  // CHECK-SAME: : (tensor<?xf32>{%c128}, index) -> %arg0 =
+  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?xf32>{%c128}, index) -> %arg0 =
+  // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor<readwrite:?xf32>
+  // CHECK-SAME:  %[[INNER_ARG1:.+]]: index) {
+  (%arg0_capture: !flow.dispatch.tensor<readwrite:?xf32>, %arg1_capture: index) {
+    // CHECK: %[[VALUE:.+]] = flow.dispatch.tensor.load %[[INNER_ARG0]] : !flow.dispatch.tensor<readwrite:?xf32> -> tensor<?xf32>
+    %t = flow.dispatch.tensor.load %arg0_capture : !flow.dispatch.tensor<readwrite:?xf32> -> tensor<?xf32>
+    // CHECK: flow.dispatch.tensor.store %[[VALUE]], %[[INNER_ARG0]] : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:?xf32>
+    flow.dispatch.tensor.store %t, %arg0_capture : tensor<?xf32> -> !flow.dispatch.tensor<readwrite:?xf32>
+    // CHECK-NEXT: flow.return
+    flow.return
+  }
+  // CHECK: return %[[OUTER_RET0]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
index 697ea9eb64b2..7cedb3cee5de 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
@@ -2,19 +2,20 @@

 // CHECK-LABEL: @workgroupStaticShapeDims
 func @workgroupStaticShapeDims(%arg0 : tensor<?x4xf32>) -> tensor<4x?xf32> {
+  %c128 = constant 128 : index
   %x = constant 100 : index
   %y = constant 50 : index
   // CHECK: flow.dispatch.workgroups
-  %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<?x4xf32>) -> (tensor<4x?xf32>) = (
-      // CHECK-SAME: = (%[[ARG0:.+]] : !flow.dispatch.input<?x4xf32>
-      %arg0_capture : !flow.dispatch.input<?x4xf32>,
-      // CHECK-SAME: %[[RET0:.+]] : !flow.dispatch.output<4x?xf32>)
-      %ret0 : !flow.dispatch.output<4x?xf32>
+  %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<?x4xf32>{%c128}) -> (tensor<4x?xf32>{%c128}) = (
+      // CHECK-NEXT: (%[[ARG0:.+]]: !flow.dispatch.tensor<readonly:?x4xf32>,
+      %arg0_capture: !flow.dispatch.tensor<readonly:?x4xf32>,
+      // CHECK-SAME:  %[[RET0:.+]]: !flow.dispatch.tensor<writeonly:4x?xf32>)
+      %ret0: !flow.dispatch.tensor<writeonly:4x?xf32>
   ) {
     // CHECK: %[[DIM_4:.+]] = constant 4 : index

     // CHECK: %[[ARG0_SHAPE:.+]] = flow.dispatch.shape %[[ARG0]]
-    %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.input<?x4xf32> -> !shapex.ranked_shape<[?,4]>
+    %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.tensor<readonly:?x4xf32> -> !shapex.ranked_shape<[?,4]>
     // CHECK: %[[ARG0_DIM0:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][0]
     %arg0_dim0 = shapex.ranked_dim %arg0_shape[0] : !shapex.ranked_shape<[?,4]> -> index
     %arg0_dim1 = shapex.ranked_dim %arg0_shape[1] : !shapex.ranked_shape<[?,4]> -> index
@@ -22,7 +23,7 @@ func @workgroupStaticShapeDims(%arg0 : tensor<?x4xf32>) -> tensor<4x?xf32> {
     "test.sink"(%arg0_dim0, %arg0_dim1) : (index, index) -> ()

     // CHECK: %[[RET0_SHAPE:.+]] = flow.dispatch.shape %[[RET0]]
-    %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]>
+    %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.tensor<writeonly:4x?xf32> -> !shapex.ranked_shape<[4,?]>
     %ret0_dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[4,?]> -> index
     // CHECK: %[[RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1]
     %ret0_dim1 = shapex.ranked_dim %ret0_shape[1] : !shapex.ranked_shape<[4,?]> -> index
@@ -38,12 +39,13 @@

 // CHECK-LABEL: @workgroupRankFolding
 func @workgroupRankFolding(%arg0 : tensor<?x4xf32>) -> tensor<4x?xf32> {
+  %c128 = constant 128 : index
   %x = constant 100 : index
   %y = constant 50 : index
   // CHECK: flow.dispatch.workgroups
-  %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<?x4xf32>) -> (tensor<4x?xf32>) = (
-      %arg0_capture : !flow.dispatch.input<?x4xf32>,
-      %ret0 : !flow.dispatch.output<4x?xf32>
+  %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<?x4xf32>{%c128}) -> (tensor<4x?xf32>{%c128}) = (
+      %arg0_capture: !flow.dispatch.tensor<readonly:?x4xf32>,
+      %ret0: !flow.dispatch.tensor<writeonly:4x?xf32>
   ) {
     // CHECK: %[[RANK:.+]] = constant 2 : index
     %workgroup_rank = flow.dispatch.workgroup.rank : index
@@ -57,12 +59,12 @@ func @workgroupRankFolding(%arg0 : tensor<?x4xf32>) -> tensor<4x?xf32> {

 // -----

 // CHECK-LABEL: @convertDimOfDispatchInputLoadToDispatchShape
-// CHECK-SAME: %[[ARG:.*]]: !flow.dispatch.input<?xf32>) {
-func @convertDimOfDispatchInputLoadToDispatchShape(%arg0: !flow.dispatch.input<?xf32>) {
+// CHECK-SAME: %[[ARG:.*]]: !flow.dispatch.tensor<readonly:?xf32>) {
+func @convertDimOfDispatchInputLoadToDispatchShape(%arg0: !flow.dispatch.tensor<readonly:?xf32>) {
   // CHECK-NEXT: %[[RANKED_SHAPE:.*]] = flow.dispatch.shape %[[ARG]]
   // CHECK-NEXT: %[[DIM:.*]] = shapex.ranked_dim %[[RANKED_SHAPE]][0]
   // CHECK-NEXT: "test.sink"(%[[DIM]]) : (index) -> ()
-  %tensor = flow.dispatch.input.load %arg0 : !flow.dispatch.input<?xf32> -> tensor<?xf32>
+  %tensor = flow.dispatch.tensor.load %arg0 : !flow.dispatch.tensor<readonly:?xf32> -> tensor<?xf32>
   %c0 = constant 0 : index
   %dim = dim %tensor, %c0 : tensor<?xf32>
   "test.sink"(%dim) : (index) -> ()
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index
028a583ab78b..81ca6be3ebca 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir @@ -1,71 +1,107 @@ -// Tests folding and canonicalization of stream ops. - // RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s -flow.executable @dispatch_0 { - flow.dispatch.entry @rgn_dispatch_0 - module { - func @rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } -} - -// CHECK-LABEL: func @fragmentDedupCaptures -// CHECK-SAME: %[[A0:.+]]: tensor<4xf32> -func @fragmentDedupCaptures(%arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK: %[[CST:.+]] = constant 4 +// CHECK-LABEL: func @inlineConstant +func @inlineConstant() -> index { %cst = constant 4 : index - // Should dedup %cst in arg list. - // CHECK: flow.ex.stream.fragment(%arg1 = %[[CST]] : index, %arg2 = %[[A0]] : tensor<4xf32>) - %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %cst : index, %arg3 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // Both referreants of the constant should use the deduped arg. - // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] - // CHECK: flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] - %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg3) : (tensor<4xf32>) -> tensor<4xf32> - %2 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg2] (%1) : (tensor<4xf32>) -> tensor<4xf32> - flow.return %2, %2 : tensor<4xf32>, tensor<4xf32> + // CHECK: flow.ex.stream.fragment() + %0 = flow.ex.stream.fragment(%cst) : (index) -> index = + (%arg0: index) -> index { + // CHECK: %[[C:.+]] = constant 4 : index + // CHECK-NEXT: return %[[C]] + flow.return %arg0 : index } - return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> + return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedCapture -func @removeUnusedCapture() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - %cst = constant 4 : index +// CHECK-SAME: (%[[ARG:.+]]: index) +func @removeUnusedCapture(%arg: index) -> index { %unused = constant 5 : index - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - %0 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %unused : index) -> (index) { - flow.return %arg0 : index + // CHECK: flow.ex.stream.fragment(%[[ARG]]) + %0 = flow.ex.stream.fragment(%arg, %unused) : (index, index) -> index = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: index) -> index { + (%arg0: index, %arg1: index) -> index { + // CHECK-NEXT: %[[T:.+]] = addi %[[INNER_ARG]], %[[INNER_ARG]] + %t = addi %arg0, %arg0 : index + // CHECK-NEXT: flow.return %[[T]] + flow.return %t : index } return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedDupCapture -func @removeUnusedDupCapture() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - %cst = constant 4 : index - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - %0 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %cst : index) -> (index) { - flow.return %arg1 : index +// CHECK-SAME: (%[[ARG:.+]]: index) +func @removeUnusedDupCapture(%arg: index) -> index { + // CHECK: flow.ex.stream.fragment(%[[ARG]]) + %0 = flow.ex.stream.fragment(%arg, %arg) : (index, index) -> index = + (%arg0: index, %arg1: index) -> index { + %t = addi %arg0, %arg0 : index + flow.return %t : index } return %0 : index } // ----- + // CHECK-LABEL: func @removeUnusedResult -func @removeUnusedResult() -> (index) { - // CHECK: %[[CST:.+]] = constant 4 - 
%cst = constant 4 : index - // Note that the unused result should also cascade to elide the newly - // unused operand. - // CHECK: flow.ex.stream.fragment(%arg0 = %[[CST]] : index) - // CHECK-SAME: -> index - %0:2 = flow.ex.stream.fragment(%arg0 = %cst : index, %arg1 = %cst : index) -> (index, index) { - flow.return %arg1, %arg0 : index, index +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index) +func @removeUnusedResult(%arg0: index, %arg1: index) -> index { + // CHECK: flow.ex.stream.fragment(%[[ARG1]]) + %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : (index, index) -> (index, index) = + (%unused: index, %arg1: index) -> (index, index) { + %t = addi %arg1, %arg1 : index + flow.return %t, %unused : index, index } - return %0 : index + return %0#0 : index +} + +// ----- + +// CHECK-LABEL: func @removeUnusedDynamicResult +// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x?xf32>, %[[DIM0:.+]]: index, +// CHECK-SAME: %[[ARG1:.+]]: tensor<8x?xf32>, %[[DIM1:.+]]: index) +func @removeUnusedDynamicResult(%arg0: tensor<4x?xf32>, %dim0: index, + %arg1: tensor<8x?xf32>, %dim1: index) -> tensor<8x?xf32> { + // CHECK: flow.ex.stream.fragment(%[[ARG1]]) : + %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : + // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]] = + (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32> + (%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) { + // CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32> + flow.return %unused, %arg1 : tensor<4x?xf32>, tensor<8x?xf32> + } + return %0#1 : tensor<8x?xf32> +} + +// ----- + +// Testing inserted clones: a clone here is required as %stream_target is used +// after it is updated. 
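+// Without the clone, the dispatch below that consumes the original
+// %stream_target alongside the updated tensor would observe the updated
+// value instead of the pre-update one.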
+ +// CHECK-LABEL: @dynamicUpdateSliceImmutability +func @dynamicUpdateSliceImmutability( + %target: tensor<2x4xi32>, %update: tensor<1x1xi32>) -> tensor<2x4xi32> { + // CHECK: %[[RET:.+]] = flow.ex.stream.fragment + %ret = flow.ex.stream.fragment(%target, %update) : + (tensor<2x4xi32>, tensor<1x1xi32>) -> tensor<2x4xi32> = + // CHECK-NEXT: (%[[TARGET:.+]]: tensor<2x4xi32>, %[[UPDATE:.+]]: tensor<1x1xi32>) + (%stream_target: tensor<2x4xi32>, %stream_update: tensor<1x1xi32>) -> tensor<2x4xi32> { + %start0 = constant 0 : index + %start1 = constant 1 : index + %workload = constant 8 : index + // CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32> + // CHECK-NEXT: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] + %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32> + // CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]]) + %t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + // CHECK-NEXT: flow.return %[[RETURN]] + flow.return %t1 : tensor<2x4xi32> + } + // CHECK: return %[[RET]] + return %ret : tensor<2x4xi32> } diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir index 457ea939c208..1dab1a9f4331 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir @@ -16,8 +16,10 @@ flow.executable @dispatch_0 { func @fragment(%arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK: %[[WORKLOAD:.+]] = constant %cst = constant 4 : index - // CHECK: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0:2 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK: %0:2 = flow.ex.stream.fragment(%[[WORKLOAD]], %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + // CHECK-NEXT: (%arg1: index, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = flow.ex.stream.fragment(%cst, %arg0) : (index, tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + (%arg1 : index, %arg2 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK-NEXT: flow.dispatch %1 = flow.dispatch @dispatch_0::@rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index 502c4184d4aa..cd9cbe9518dd 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir @@ -9,6 +9,8 @@ func @reshapeNoOp(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @reshapeNoOpScalar func @reshapeNoOpScalar(%arg0 : tensor) -> tensor { // CHECK-NEXT: return %arg0 : tensor @@ -16,15 +18,6 @@ func @reshapeNoOpScalar(%arg0 : tensor) -> tensor { return %0 : tensor } -// CHECK-LABEL: @reshapeTransitive -func @reshapeTransitive(%arg0 : tensor<4x4xf32>) -> tensor<8x2xf32> { - %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<2x8xf32> - // CHECK-NEXT: %[[T:.+]] = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<8x2xf32> - %1 = flow.tensor.reshape %0 : tensor<2x8xf32> -> tensor<8x2xf32> - // CHECK-NEXT: return %[[T]] : tensor<8x2xf32> - return %1 : tensor<8x2xf32> -} - 
// ----- // CHECK-LABEL: @loadConst @@ -38,6 +31,8 @@ func @loadConst() -> i32 { return %2 : i32 } +// ----- + // CHECK-LABEL: @loadConstScalar func @loadConstScalar() -> i32 { %0 = constant dense<4> : tensor @@ -63,6 +58,8 @@ func @storeConst() -> tensor<2x2xi32> { return %1 : tensor<2x2xi32> } +// ----- + // CHECK-LABEL: @storeConstScalar func @storeConstScalar() -> tensor { %0 = constant dense<0> : tensor @@ -84,6 +81,8 @@ func @splatConst() -> tensor<4xi32> { return %1 : tensor<4xi32> } +// ----- + // CHECK-LABEL: @splatConstScalar func @splatConstScalar() -> tensor { %0 = constant 4 : i32 @@ -104,13 +103,6 @@ func @cloneConst() -> tensor<4xi32> { return %1 : tensor<4xi32> } -// CHECK-LABEL: @cloneDynamic -func @cloneDynamic(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = flow.tensor.clone %arg0 : tensor<4xi32> - // CHECK-NEXT: return %arg0 - return %0 : tensor<4xi32> -} - // ----- // CHECK-LABEL: @sliceConst0D @@ -122,6 +114,8 @@ func @sliceConst0D() -> tensor { return %1 : tensor } +// ----- + // CHECK-LABEL: @sliceConst1D func @sliceConst1D() -> tensor<1xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -133,6 +127,8 @@ func @sliceConst1D() -> tensor<1xi32> { return %1 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @sliceConst1DZeroLength func @sliceConst1DZeroLength() -> tensor<0xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -143,6 +139,8 @@ func @sliceConst1DZeroLength() -> tensor<0xi32> { return %1 : tensor<0xi32> } +// ----- + // CHECK-LABEL: @sliceConst2D func @sliceConst2D() -> tensor<1x2xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -157,6 +155,8 @@ func @sliceConst2D() -> tensor<1x2xi32> { return %1 : tensor<1x2xi32> } +// ----- + // CHECK-LABEL: @sliceConst2DZeroLength1 func @sliceConst2DZeroLength1() -> tensor<1x0xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -168,6 +168,8 @@ func @sliceConst2DZeroLength1() -> tensor<1x0xi32> { return %1 : tensor<1x0xi32> } +// ----- + // CHECK-LABEL: @sliceConst2DZeroLength01 func @sliceConst2DZeroLength01() -> tensor<0x0xi32> { %0 = constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32> @@ -178,6 +180,8 @@ func @sliceConst2DZeroLength01() -> tensor<0x0xi32> { return %1 : tensor<0x0xi32> } +// ----- + // CHECK-LABEL: @sliceConst3D func @sliceConst3D() -> tensor<1x2x3xi32> { %0 = constant dense<[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[9, 10, 11], [12, 13, 14], [15, 16, 17]]]> : tensor<2x3x3xi32> @@ -205,6 +209,8 @@ func @updateConst0D() -> tensor { return %2 : tensor } +// ----- + // CHECK-LABEL: @updateConst1D func @updateConst1D() -> tensor<1xi32> { %0 = constant dense<0> : tensor<1xi32> @@ -216,6 +222,8 @@ func @updateConst1D() -> tensor<1xi32> { return %2 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @updateConst1DUpdateZeroSize func @updateConst1DUpdateZeroSize() -> tensor<1xi32> { %0 = constant dense<> : tensor<0xi32> @@ -227,6 +235,8 @@ func @updateConst1DUpdateZeroSize() -> tensor<1xi32> { return %2 : tensor<1xi32> } +// ----- + // CHECK-LABEL: @updateConst2DUpdate1x1 func @updateConst2DUpdate1x1() -> tensor<3x4xi32> { %0 = constant dense<[[12]]> : tensor<1x1xi32> @@ -240,6 +250,8 @@ func @updateConst2DUpdate1x1() -> tensor<3x4xi32> { return %2 : tensor<3x4xi32> } +// ----- + // CHECK-LABEL: @updateConst2DUpdate2x2 func @updateConst2DUpdate2x2() -> tensor<3x4xi32> { %0 = constant dense<[[12, 13], [14, 15]]> : tensor<2x2xi32> @@ -253,6 +265,8 @@ func @updateConst2DUpdate2x2() -> tensor<3x4xi32> { return %2 : tensor<3x4xi32> } +// ----- + // CHECK-LABEL: @updateConst3DUpdate1x2x3 
func @updateConst3DUpdate1x2x3() -> tensor<2x3x3xi32> { %0 = constant dense<[[[18, 19, 20], [21, 22, 23]]]> : tensor<1x2x3xi32> @@ -268,6 +282,8 @@ func @updateConst3DUpdate1x2x3() -> tensor<2x3x3xi32> { return %2 : tensor<2x3x3xi32> } +// ----- + // CHECK-LABEL: @updateConst3DUpdate2x3x2 func @updateConst3DUpdate2x3x2() -> tensor<2x3x3xi32> { %0 = constant dense<[[[18, 19], [20, 21], [22, 23]], [[24, 25], [26, 27], [28, 29]]]> : tensor<2x3x2xi32> @@ -283,6 +299,8 @@ func @updateConst3DUpdate2x3x2() -> tensor<2x3x3xi32> { return %2 : tensor<2x3x3xi32> } +// ----- + // CHECK-LABEL: @updateReplace func @updateReplace(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi32> { %c0 = constant 0 : index @@ -291,6 +309,8 @@ func @updateReplace(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi3 return %0 : tensor<4xi32> } +// ----- + // CHECK-LABEL: @propogateStaticShapeOfTarget func @propogateStaticShapeOfTarget(%arg0 : tensor, %arg1 : f32) -> tensor { %c21 = constant 21 : index @@ -305,11 +325,13 @@ func @propogateStaticShapeOfTarget(%arg0 : tensor, %arg1 : f32) -> tens } : tensor // CHECK: %[[UPDATED:.+]] = flow.tensor.update %{{.+}}, %[[TARGET]] // CHECK: %[[RESULT:.+]] = tensor.cast %[[UPDATED]] : tensor<21x42xf32> to tensor - %1 = flow.tensor.update %arg0, %0[%c2, %c4] : tensor -> tensor + %1 = flow.tensor.update %arg0, %0[%c2, %c4] : tensor{%c21, %c42} -> tensor{%c21, %c42} // CHECK: return %[[RESULT]] return %1 : tensor } +// ----- + // CHECK-LABEL: @propogateStaticShapeOfUpdate func @propogateStaticShapeOfUpdate(%arg0 : tensor, %arg1 : f32) -> tensor { %c21 = constant 21 : index @@ -323,7 +345,7 @@ func @propogateStaticShapeOfUpdate(%arg0 : tensor, %arg1 : f32) -> tens tensor.yield %arg1 : f32 } : tensor // CHECK: %[[RESULT:.+]] = flow.tensor.update %[[UPDATE]] - %1 = flow.tensor.update %0, %arg0[%c2, %c4] : tensor -> tensor + %1 = flow.tensor.update %0, %arg0[%c2, %c4] : tensor{%c21, %c42} -> tensor{%c21, %c42} // CHECK: return %[[RESULT]] return %1 : tensor } diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir index 1e45299ff38c..84fef0802390 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir @@ -16,6 +16,15 @@ func @tensorReshapeScalar(%arg0 : tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorReshapeDynamic +func @tensorReshapeDynamic(%arg0 : tensor) -> tensor { + %c4 = constant 4 : index + %c8 = constant 8 : index + // CHECK: %0 = flow.tensor.reshape %arg0 : tensor{%c4} -> tensor{%c8} + %0 = flow.tensor.reshape %arg0 : tensor{%c4} -> tensor{%c8} + return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorLoad @@ -32,6 +41,14 @@ func @tensorLoadScalar(%arg0 : tensor) -> f32 { return %0 : f32 } +// CHECK-LABEL: @tensorLoadDynamic +func @tensorLoadDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index) -> f32 { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor{%c4} + %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor{%c4} + return %0 : f32 +} + // ----- // CHECK-LABEL: @tensorStore @@ -48,6 +65,14 @@ func @tensorStoreScalar(%arg0 : f32, %arg1 : tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @tensorStoreDynamic +func @tensorStoreDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index, %arg3 : f32) -> tensor { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor{%c4} + %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor{%c4} + 
return %0 : tensor +} + // ----- // CHECK-LABEL: @tensorSplat @@ -64,6 +89,14 @@ func @tensorSplatScalar(%arg0 : f32) -> tensor<f32> { return %0 : tensor<f32> } +// CHECK-LABEL: @tensorSplatDynamic +func @tensorSplatDynamic(%arg0 : f32) -> tensor<?xf32> { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.splat %arg0 : tensor<?xf32>{%c4} + %0 = flow.tensor.splat %arg0 : tensor<?xf32>{%c4} + return %0 : tensor<?xf32> +} + // ----- // CHECK-LABEL: @tensorClone @@ -80,6 +113,14 @@ func @tensorCloneScalar(%arg0 : tensor<f32>) -> tensor<f32> { return %0 : tensor<f32> } +// CHECK-LABEL: @tensorCloneDynamic +func @tensorCloneDynamic(%arg0 : tensor<?xf32>) -> tensor<?xf32> { + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.clone %arg0 : tensor<?xf32>{%c4} + %0 = flow.tensor.clone %arg0 : tensor<?xf32>{%c4} + return %0 : tensor<?xf32> +} + // ----- // CHECK-LABEL: @tensorSlice @@ -89,6 +130,15 @@ func @tensorSlice(%arg0 : tensor<4x4xf32>, %arg1 : index, %arg2 : index) -> tens return %0 : tensor<2x2xf32> } +// CHECK-LABEL: @tensorSliceDynamic +func @tensorSliceDynamic(%arg0 : tensor<?x4xf32>, %arg1 : index, %arg2 : index) -> tensor<?x2xf32> { + %c2 = constant 2 : index + %c4 = constant 4 : index + // CHECK: %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor<?x4xf32>{%c4} -> tensor<?x2xf32>{%c2} + %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor<?x4xf32>{%c4} -> tensor<?x2xf32>{%c2} + return %0 : tensor<?x2xf32> +} + // ----- // CHECK-LABEL: @tensorUpdate @@ -97,3 +147,13 @@ func @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, %arg2 : ind %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32> return %0 : tensor<4x4xf32> } + +// CHECK-LABEL: @tensorUpdateDynamic +func @tensorUpdateDynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x4xf32>, %arg2 : index, %arg3 : index) -> tensor<?x4xf32> { + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3} + %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3} + return %0 : tensor<?x4xf32> +} diff --git a/iree/compiler/Dialect/Flow/IR/test/types.mlir b/iree/compiler/Dialect/Flow/IR/test/types.mlir index 110a09c1fd77..f64620b82a6e 100644 --- a/iree/compiler/Dialect/Flow/IR/test/types.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/types.mlir @@ -2,26 +2,26 @@ // CHECK-LABEL: @dispatchTypes func @dispatchTypes( - // CHECK-SAME: %arg0: !flow.dispatch.input<f32> - %arg0 : !flow.dispatch.input<f32>, - // CHECK-SAME: %arg1: !flow.dispatch.input<4x4xf32> - %arg1 : !flow.dispatch.input<4x4xf32>, - // CHECK-SAME: %arg2: !flow.dispatch.input<1x2x3x4x5x6xf32> - %arg2 : !flow.dispatch.input<1x2x3x4x5x6xf32>, - // CHECK-SAME: %arg3: !flow.dispatch.input<?xf32> - %arg3 : !flow.dispatch.input<?xf32>, - // CHECK-SAME: %arg4: !flow.dispatch.input<1x?x3xf32> - %arg4 : !flow.dispatch.input<1x?x3xf32>, - // CHECK-SAME: %arg5: !flow.dispatch.output<f32> - %arg5 : !flow.dispatch.output<f32>, - // CHECK-SAME: %arg6: !flow.dispatch.output<4x4xf32> - %arg6 : !flow.dispatch.output<4x4xf32>, - // CHECK-SAME: %arg7: !flow.dispatch.output<1x2x3x4x5x6xf32> - %arg7 : !flow.dispatch.output<1x2x3x4x5x6xf32>, - // CHECK-SAME: %arg8: !flow.dispatch.output<?xf32> - %arg8 : !flow.dispatch.output<?xf32>, - // CHECK-SAME: %arg9: !flow.dispatch.output<1x?x3xf32> - %arg9 : !flow.dispatch.output<1x?x3xf32> + // CHECK-SAME: %arg0: !flow.dispatch.tensor<readonly:f32> + %arg0: !flow.dispatch.tensor<readonly:f32>, + // CHECK-SAME: %arg1: !flow.dispatch.tensor<readonly:4x4xf32> + %arg1: !flow.dispatch.tensor<readonly:4x4xf32>, + // CHECK-SAME: %arg2: !flow.dispatch.tensor<readonly:1x2x3x4x5x6xf32> + %arg2: !flow.dispatch.tensor<readonly:1x2x3x4x5x6xf32>, + // CHECK-SAME: %arg3: !flow.dispatch.tensor<readonly:?xf32> + %arg3: !flow.dispatch.tensor<readonly:?xf32>, + // CHECK-SAME: %arg4: !flow.dispatch.tensor<readonly:1x?x3xf32> + %arg4: !flow.dispatch.tensor<readonly:1x?x3xf32>, + // CHECK-SAME: %arg5: !flow.dispatch.tensor<writeonly:f32> + %arg5: !flow.dispatch.tensor<writeonly:f32>, + // CHECK-SAME: %arg6: !flow.dispatch.tensor<writeonly:4x4xf32> + %arg6: !flow.dispatch.tensor<writeonly:4x4xf32>, + // CHECK-SAME: %arg7: !flow.dispatch.tensor<writeonly:1x2x3x4x5x6xf32> + %arg7: !flow.dispatch.tensor<writeonly:1x2x3x4x5x6xf32>, + // CHECK-SAME: %arg8: !flow.dispatch.tensor<writeonly:?xf32> + %arg8: !flow.dispatch.tensor<writeonly:?xf32>, + // CHECK-SAME: %arg9: !flow.dispatch.tensor<writeonly:1x?x3xf32> + %arg9: !flow.dispatch.tensor<writeonly:1x?x3xf32> ) { return } diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD index fe910da963bd..720ec1de2abf 100644 --- a/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD @@ -46,7 +46,6 @@ cc_library( "OutlineLargeConstants.cpp", "Passes.cpp", "PrePostPartitioningConversion.cpp", - "RematerializeDispatchConstants.cpp", "StripAndSplatConstantVariables.cpp", ], hdrs = [ diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 61a31fa3c015..97c556d81ef1 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -41,7 +41,6 @@ iree_cc_library( "OutlineLargeConstants.cpp" "Passes.cpp" "PrePostPartitioningConversion.cpp" - "RematerializeDispatchConstants.cpp" "StripAndSplatConstantVariables.cpp" DEPS LLVMSupport diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp index 352ae67fadb1..711a0cd28977 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp @@ -69,7 +69,8 @@ class CreateBenchmarkFuncs auto dummyWorkload = blockBuilder.create<ConstantIndexOp>(loc, 0); auto dispatchOp = blockBuilder.create<DispatchOp>( loc, dispatchEntryOp, ValueRange{dummyWorkload}, - funcType.getResults(), args); + funcType.getResults(), ValueRange{}, args, ValueRange{}, + ArrayRef<int64_t>{}); blockBuilder.create<mlir::ReturnOp>(loc, dispatchOp.getResults()); } } diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp index 0835445df6f5..57f3d4343069 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp @@ -221,7 +221,7 @@ static Value isADestructiveUpdatePattern(Value tensor, static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); - auto loadOp = op.source().getDefiningOp<IREE::Flow::DispatchInputLoadOp>(); + auto loadOp = op.source().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); if (!loadOp) { BlockArgument val = op.source().dyn_cast<BlockArgument>(); while (val) { @@ -230,7 +230,7 @@ static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { if (!forOp) return failure(); unsigned idx = val.getArgNumber() - 1; // accounting for IV arg.
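To make the types.mlir change above concrete: the former pair of !flow.dispatch.input/!flow.dispatch.output types collapses into a single !flow.dispatch.tensor type whose access mode (readonly/writeonly) is carried inline. A minimal sketch with an invented function name and shape, assuming the load/store forms shown in this patch's tests:

```mlir
// Hypothetical dispatch body: the access mode lives on the type, so a copy
// kernel reads through a readonly binding and writes through a writeonly one.
func @copy_dispatch(%in: !flow.dispatch.tensor<readonly:4x4xf32>,
                    %out: !flow.dispatch.tensor<writeonly:4x4xf32>) {
  %t = flow.dispatch.tensor.load %in : !flow.dispatch.tensor<readonly:4x4xf32> -> tensor<4x4xf32>
  flow.dispatch.tensor.store %t, %out : tensor<4x4xf32> -> !flow.dispatch.tensor<writeonly:4x4xf32>
  return
}
```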
Value iterOperand = *(forOp.getIterOperands().begin() + idx); - loadOp = iterOperand.getDefiningOp<IREE::Flow::DispatchInputLoadOp>(); + loadOp = iterOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); val = iterOperand.dyn_cast<BlockArgument>(); } } @@ -243,7 +243,7 @@ static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) { sizes.push_back(r.size); strides.push_back(r.stride); } - Value loaded = b.create<IREE::Flow::DispatchInputLoadOp>( + Value loaded = b.create<IREE::Flow::DispatchTensorLoadOp>( op.getLoc(), op.getResult().getType(), loadOp.source(), offsets, sizes, strides); op.getResult().replaceAllUsesWith(loaded); @@ -277,7 +277,7 @@ static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b, sizes.push_back(r.size); strides.push_back(r.stride); } - b.create<IREE::Flow::DispatchOutputStoreOp>(op.getLoc(), op.source(), target, + b.create<IREE::Flow::DispatchTensorStoreOp>(op.getLoc(), op.source(), target, offsets, sizes, strides); return success(); } @@ -418,7 +418,7 @@ static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, Value v, // Reload the value produced inplace right after the inplace update. OpBuilder::InsertionGuard g(b); b.setInsertionPointAfter(outermostProducingOp); - Value newLoad = b.create<IREE::Flow::DispatchInputLoadOp>( + Value newLoad = b.create<IREE::Flow::DispatchTensorLoadOp>( outermostProducingOp->getLoc(), v.getType(), target); // TODO(nicolasvasilache): this brutally replaces all uses by the result of // this load. In practice we may want more recompute and we may have lost @@ -435,8 +435,8 @@ // consecutive ops". Probably better to wait until core alias analysis is // upstreamed. // TODO(nicolasvasilache): interfaces. -static bool hasInterleavedAliases(IREE::Flow::DispatchInputLoadOp loadOp, - IREE::Flow::DispatchOutputStoreOp storeOp) { +static bool hasInterleavedAliases(IREE::Flow::DispatchTensorLoadOp loadOp, + IREE::Flow::DispatchTensorStoreOp storeOp) { Block *bLoad = loadOp.getOperation()->getBlock(); Block *bStore = loadOp.getOperation()->getBlock(); if (!isa<IREE::Flow::DispatchWorkgroupsOp>(bLoad->getParentOp()) || @@ -461,7 +461,7 @@ LogicalResult rewriteLinalgDestructiveUpdates( // For each tensor store op, look for destructive updates and replace the // destructive pattern by a custom inplace update pattern. bool fail = dispatchOp - .walk([&](IREE::Flow::DispatchOutputStoreOp op) { + .walk([&](IREE::Flow::DispatchTensorStoreOp op) { if (failed(rewriteDestructiveUpdateInPlace(b, op.value(), op.target()))) { return WalkResult::interrupt(); @@ -472,8 +472,8 @@ if (fail) return failure(); // For each tensor store op, redundant load/store optimization. - dispatchOp.walk([&](IREE::Flow::DispatchOutputStoreOp storeOp) { - auto loadOp = dyn_cast_or_null<IREE::Flow::DispatchInputLoadOp>( + dispatchOp.walk([&](IREE::Flow::DispatchTensorStoreOp storeOp) { + auto loadOp = dyn_cast_or_null<IREE::Flow::DispatchTensorLoadOp>( storeOp.value().getDefiningOp()); // Bail if there exists an interleaved aliasing. if (!loadOp || hasInterleavedAliases(loadOp, storeOp)) { diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h index 106d3c981df3..7c747a074122 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h @@ -52,27 +52,30 @@ namespace Flow { // // The following destructive update patterns are rewritten.
// -// Coming from an `Flow::DispatchInputLoadOp` +// Coming from a `Flow::DispatchTensorLoadOp` // ========================================== // ``` -// %0 = flow.dispatch.input.load %a : !flow.dispatch.input<...> -> tensor<...> +// %0 = flow.dispatch.tensor.load %a : !flow.dispatch.tensor<readonly:...> -> +// tensor<...> // ... // %1 = destructive_update(%0) // ... -// use_of(%1) // e.g. flow.dispatch.output.store %1, %b : -// // tensor<...> -> !flow.dispatch.output<...> +// use_of(%1) // e.g. flow.dispatch.tensor.store %1, %b : +// // tensor<...> -> !flow.dispatch.tensor<writeonly:...> // ``` // is rewritten into: // ``` -// %0 = flow.dispatch.input.load %a : !flow.dispatch.input<...> -> tensor<...> +// %0 = flow.dispatch.tensor.load %a : !flow.dispatch.tensor<readonly:...> -> +// tensor<...> // ... -// inplace_update(%0, %out) // e.g. flow.dispatch.output.store %subtensor, %b, +// inplace_update(%0, %out) // e.g. flow.dispatch.tensor.store %subtensor, %b, // // offsets = ..., sizes = ..., strides = ... : -// // tensor<...> -> !flow.dispatch.output<...> +// // tensor<...> -> +// !flow.dispatch.tensor<writeonly:...> // %2 = flow.dispatch.output.load %b // ... -// use_of(%2) // e.g. flow.dispatch.output.store %2, %b : -// // tensor<...> -> !flow.dispatch.output<...> +// use_of(%2) // e.g. flow.dispatch.tensor.store %2, %b : +// // tensor<...> -> !flow.dispatch.tensor<writeonly:...> // ``` // // This is a typical pattern that appears after tiling Linalg ops on tensors @@ -84,8 +87,8 @@ namespace Flow { // ``` // %2 = flow.dispatch.output.load %b // ... -// flow.dispatch.output.store %2, %b : -// tensor<...> -> !flow.dispatch.output<...> +// flow.dispatch.tensor.store %2, %b : +// tensor<...> -> !flow.dispatch.tensor<writeonly:...> // ``` // is elided. LogicalResult rewriteLinalgDestructiveUpdates( diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 1a2fbccec1f3..8f8a428f0742 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -16,11 +16,14 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" #include "iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -177,7 +180,9 @@ static SmallVector<Value, 4> convertToWorkload(OpBuilder &b, Location loc, /// linalg.init_tensor operations.
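The comment block above states the rewrite abstractly; the following sketch instantiates it on a single 1-D tile. All names, shapes, and index values are invented for illustration, and the syntax follows the forms used elsewhere in this patch, so treat it as an approximation rather than exact output:

```mlir
// Before: a tile is inserted into the full tensor %0 and the whole tensor is
// stored back -- a destructive update of %0.
%0 = flow.dispatch.tensor.load %a : !flow.dispatch.tensor<readonly:8xf32> -> tensor<8xf32>
%1 = subtensor_insert %tile into %0[%i] [%sz] [1] : tensor<?xf32> into tensor<8xf32>
flow.dispatch.tensor.store %1, %b : tensor<8xf32> -> !flow.dispatch.tensor<writeonly:8xf32>

// After: only the updated slice is stored, in place, with explicit
// offsets/sizes/strides; the redundant full-tensor store is then elided.
flow.dispatch.tensor.store %tile, %b, offsets = [%i], sizes = [%sz], strides = [%c1]
    : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:8xf32>
```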
static bool isRootOp(Operation *op) { - return isa(op); + return isa(op); } static bool isAlwaysClonedIntoDispatchOp(Operation *op) { @@ -219,7 +224,10 @@ static std::pair<IREE::Flow::DispatchWorkgroupsOp, Operation *> buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc, ArrayRef<Value> count, Operation *op) { auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>( - loc, count, op->getResultTypes(), ValueRange{}); + loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{}, + /*operands=*/ValueRange{}, + /*operand_dims=*/ValueRange{}, + /*tied_operands=*/ArrayRef<int64_t>{}); Region &region = dispatchOp.body(); Block *block = &region.front(); Operation *clonedOp; @@ -230,7 +238,7 @@ buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc, for (auto it : llvm::zip(clonedOp->getResults(), dispatchOp.body().getArguments().take_back( clonedOp->getNumResults()))) { - rewriter.create<IREE::Flow::DispatchOutputStoreOp>( + rewriter.create<IREE::Flow::DispatchTensorStoreOp>( loc, std::get<0>(it), std::get<1>(it), llvm::None, llvm::None, llvm::None); } @@ -380,18 +388,18 @@ static LogicalResult legalizeDispatchWorkgroupOperands( // Replace valuesDefinedAbove by new BB args (including the op's operands). for (Value operand : valuesDefinedAbove) { if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) { - block.addArgument(IREE::Flow::DispatchInputType::get( - rt.getShape(), rt.getElementType())); + block.addArgument(IREE::Flow::DispatchTensorType::get( + TensorAccess::ReadOnly, rt.getShape(), rt.getElementType())); } else { block.addArgument(operand.getType()); } Value bbArg = block.getArguments().back(); Value repl = bbArg; - if (bbArg.getType().isa<IREE::Flow::DispatchInputType>()) { - repl = b.create<IREE::Flow::DispatchInputLoadOp>(loc, operand.getType(), - bbArg); - } else if (bbArg.getType().isa<IREE::Flow::DispatchOutputType>()) { + if (bbArg.getType().isa<IREE::Flow::DispatchTensorType>()) { + repl = b.create<IREE::Flow::DispatchTensorLoadOp>(loc, operand.getType(), + bbArg); + } else if (bbArg.getType().isa<IREE::Flow::DispatchTensorType>()) { // TODO(nicolasvasilache): do something useful. continue; } @@ -401,7 +409,7 @@ // The only existing arguments are for the outputs. Just need to add a new // argument for the outputs and remap the value to use the new argument. for (auto ba : block.getArguments().take_front(numOldBBArgs)) { - assert(ba.getType().isa<IREE::Flow::DispatchOutputType>()); + assert(ba.getType().isa<IREE::Flow::DispatchTensorType>()); map.map(ba, block.addArgument(ba.getType())); } @@ -446,8 +454,22 @@ block.eraseArguments( llvm::to_vector<4>(llvm::seq<unsigned>(0, numOldBBArgs))); + // Gather the dynamic dimensions for all operands. + SmallVector<Value, 4> operandDynamicDims; + OpBuilder builder(dispatchOp); + for (Value operand : valuesDefinedAbove) { + if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) { + for (unsigned i = 0; i < rt.getRank(); ++i) { + if (!rt.isDynamicDim(i)) continue; + auto dim = builder.createOrFold<DimOp>(dispatchOp.getLoc(), operand, i); + operandDynamicDims.push_back(dim); + } + } + } + // Set the values captured from above as the new operands. dispatchOp.operandsMutable().assign(llvm::to_vector<4>(valuesDefinedAbove)); + dispatchOp.operand_dimsMutable().assign(operandDynamicDims); return success(); } @@ -507,10 +529,25 @@ struct TileAndDistributeOnTensorsPattern SmallVector<Value, 4> count = llvm::to_vector<4>( llvm::map_range(linalgOp.createLoopRanges(rewriter, loc), [](Range r) { return r.size; })); + // NOTE: Special treatment for convolution ops, which have more than 3 parallel + // dimensions. We want to ignore the batch dimension and tile along the + // next three. + // TODO(#5048): figure out a better way to avoid this special case.
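The operand_dims gathered above pair every dynamic dimension of a captured tensor with an index SSA value at the dispatch boundary. A hedged sketch of the IR this produces (names, shapes, and workload values invented; the {...} dimension syntax follows the flow.tensor.* tests earlier in this patch, so the exact assembly is an approximation):

```mlir
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = dim %A, %c0 : tensor<?x?xf32>
%d1 = dim %A, %c1 : tensor<?x?xf32>
// Each ? in an operand or result type is bound to one index value.
%r = flow.dispatch.workgroups[%x, %y, %z](%A)
    : (tensor<?x?xf32>{%d0, %d1}) -> tensor<?x?xf32>{%d0, %d1} =
    (%arg: !flow.dispatch.tensor<readonly:?x?xf32>,
     %ret: !flow.dispatch.tensor<writeonly:?x?xf32>) {
  %t = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
  flow.dispatch.tensor.store %t, %ret : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
  flow.return
}
```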
+ if (isa(op)) { + count.erase(count.begin()); + } count.resize(getNumTilableLoops(op)); auto workload = convertToWorkload(rewriter, loc, count); - // Note: DispatchOutputStoreOp generated by the + // Capture dynamic result dimensions. + SmallVector<Value, 4> resultDynamicDims; + for (auto result : linalgOp.outputs()) { + resultDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + linalgOp.getLoc(), result, rewriter)); + } + + // Note: DispatchTensorStoreOp generated by the // `buildOperandLessFlowDispatchWorkgroupOp` is an abstraction jump that // consumes the SSA value produced by `clonedOp` but it does not comply with // the semantics of DispatchWorkgroupsOp which explicitly states: "behavior // is undefined if multiple workgroups store to the same regions of the // output tensors". Similarly to sequentialized SPMD loops, the semantics // is valid assuming a sequential ordering of execution. After destructive // update rewrites, the abstraction gap disappears. - auto en = - buildOperandLessFlowDispatchWorkgroupOp(rewriter, loc, workload, op); - linalg::LinalgOp clonedLinalgOp = cast<linalg::LinalgOp>(en.second); + auto en = buildOperandLessFlowDispatchWorkgroupOp(rewriter, loc, workload, + linalgOp); IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first; + linalg::LinalgOp clonedLinalgOp = cast<linalg::LinalgOp>(en.second); + dispatchOp.result_dimsMutable().assign(resultDynamicDims); // Scoped within DispatchWorkgroupOp. OpBuilder::InsertionGuard g(rewriter); @@ -546,7 +584,7 @@ tiledLinalgOp.op.getOperation()->removeAttr(kRootOpAttr); rewriter.replaceOpWithIf( - op, dispatchOp.getOperation()->getResults(), + op, dispatchOp.getResults(), [&](OpOperand &operand) { return !isa(operand.getOwner()); }); return success(); } @@ -554,7 +592,7 @@ /// The workload is computed based on the problem size. For a given operation, /// return the problem size. -static Optional<SmallVector<Value, 4>> getProblemSize(PatternRewriter &rewriter, +static Optional<SmallVector<Value, 4>> getResultShape(PatternRewriter &rewriter, Operation *op) { Location loc = op->getLoc(); auto getShapeOfShapedTypeVal = [&](Value v) -> SmallVector<Value, 4> { @@ -599,11 +637,13 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern { })) { return failure(); } + // The workgroup count is based on the result shape. if (op->getNumResults() != 1) return failure(); - Optional<SmallVector<Value, 4>> countOpt = getProblemSize(rewriter, op); - if (!countOpt) return failure(); - SmallVector<Value, 4> count = *countOpt; + Optional<SmallVector<Value, 4>> resultShapeOpt = + getResultShape(rewriter, op); + if (!resultShapeOpt) return failure(); + SmallVector<Value, 4> resultShape = *resultShapeOpt; // TODO(ravishankarm): For now the Flow -> HAL conversion only handles // workload count of 3, though it should be generalized. For now making sure @@ -611,6 +651,7 @@ // workloads for all higher dimensions greater than or equal to // kNumMaxParallelDims. Location loc = op->getLoc(); + SmallVector<Value, 4> count = resultShape; if (count.size() > kNumMaxParallelDims) { unsigned numSymbols = 0; AffineExpr expr = rewriter.getAffineSymbolExpr(numSymbols++); @@ -625,9 +666,22 @@ ArrayRef<Value>(count).take_back(kNumMaxParallelDims)); } auto workload = convertToWorkload(rewriter, loc, count); + + // Capture dynamic result dimensions.
+ assert(op->getNumResults() == 1 && "currently assuming a single result"); + auto resultType = op->getResult(0).getType().cast<ShapedType>(); + SmallVector<Value, 4> resultDynamicDims; + for (unsigned i = 0; i < resultType.getRank(); ++i) { + if (resultType.isDynamicDim(i)) { + resultDynamicDims.push_back(resultShape[i]); + } + } + auto en = buildOperandLessFlowDispatchWorkgroupOp(rewriter, op->getLoc(), workload, op); IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first; + dispatchOp.result_dimsMutable().assign(resultDynamicDims); + rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(), [&](OpOperand &operand) { Operation *user = operand.getOwner(); @@ -729,10 +783,9 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = { [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) { auto numParallelDims = parallelLoopRanges.size(); + SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims); - for (size_t dim = 0; - dim < std::min(numParallelDims, kNumMaxParallelDims); - ++dim) { + for (size_t dim = 0; dim < numParallelDims; ++dim) { procInfo[numParallelDims - dim - 1] = { buildFlowWorkgroupInfoOp(builder, dim), @@ -746,21 +799,34 @@ auto tileSizeFn = [&](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> { + auto numParallelDims = getNumOuterParallelLoops(cast<linalg::LinalgOp>(op)); auto numTiledLoops = getNumTilableLoops(cast<linalg::LinalgOp>(op)); - SmallVector<Value, 4> useTileSizes(numTiledLoops); + + // Default to zero to skip tiling. + auto zero = builder.create<ConstantIndexOp>(op->getLoc(), 0); + SmallVector<Value, 4> useTileSizes(numParallelDims, zero); + if (!clLinalgOnTensorsTileSizes.empty()) { SmallVector<int64_t, 4> tileSizes(clLinalgOnTensorsTileSizes.begin(), clLinalgOnTensorsTileSizes.end()); - useTileSizes.resize(std::min(tileSizes.size(), numTiledLoops)); + useTileSizes.resize(std::min(tileSizes.size(), numParallelDims)); return llvm::to_vector<4>(llvm::map_range( ArrayRef<int64_t>(tileSizes).take_front( - std::min(tileSizes.size(), numTiledLoops)), + std::min(tileSizes.size(), numParallelDims)), [&](int64_t t) -> Value { return builder.create<ConstantIndexOp>(op->getLoc(), t); })); } + + // NOTE: Special treatment for convolution ops, which have more than 3 + // parallel dimensions. We want to ignore the batch dimension and tile + // along the next three. That means setting the first position to zero. + // TODO(#5048): figure out a better way to avoid this special case. + bool isConvOp = isa(op); + for (size_t dim = 0; dim < numTiledLoops; ++dim) { - useTileSizes[numTiledLoops - dim - 1] = + useTileSizes[(isConvOp ? numParallelDims : numTiledLoops) - dim - 1] = buildFlowWorkgroupInfoOp(builder, dim); } return useTileSizes; @@ -817,6 +883,19 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { return signalPassFailure(); } + // Run necessary canonicalization patterns before destructive updates. + { + OwningRewritePatternList patterns; + // This is needed because tiling and distribution may create + // subtensor_insert ops whose source operands come from tensor.cast ops. + // Those tensor.cast ops cast tensors into a more dynamic shape in order + // to guarantee that types match during the transformation. Later, in the + // destructive update rewrite, subtensor_insert ops will be turned into + // flow dispatch output store ops. + SubTensorInsertOp::getCanonicalizationPatterns(patterns, context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + // Rewrite destructive updates and ensure no remaining store remains to the // full output.
if (funcOp diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp index 4a86c0d92c9d..85fbbdc06e0a 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/Utils/TypeConversion.h" @@ -43,8 +44,8 @@ namespace Flow { // Returns true if the given op can be used within a stream. static bool isStreamableOp(Operation *op) { - if (auto streamableOp = dyn_cast(op)) { - return streamableOp.isUsableInStream(); + if (isa(op)) { + return true; } if (llvm::isa(op)) { return true; @@ -61,21 +62,6 @@ static inline bool usefulStreamWork(ArrayRef currentStreamOps) { return llvm::any_of(currentStreamOps, usefulStreamOp); } -// Expand any compound types to primitive types in the stream fragment. -static void expandFragmentToPrimitiveTypes(ExStreamFragmentOp fragmentOp) { - auto loc = fragmentOp.getLoc(); - Block *entryBlock = &fragmentOp.body().front(); - auto &typeExpander = Shape::getShapeToPrimitiveTypeExpander(); - OpBuilder expandBuilder(fragmentOp.getContext()); - (void)typeExpander.expandBlockSignature(loc, entryBlock, expandBuilder); - SmallVector origFragmentArgs(fragmentOp.args()); - SmallVector newFragmentArgs; - expandBuilder.setInsertionPoint(fragmentOp); - (void)typeExpander.expandSourceValuesToTarget(loc, origFragmentArgs, - newFragmentArgs, expandBuilder); - fragmentOp.getOperation()->setOperands(newFragmentArgs); -} - // Temporary hack to get the experimental stream ops constructed. 
In the future // this will run an analysis to identify compatible dispatches across the entire // function CFG, create the streams, and then thread the streams through the CFG @@ -130,10 +116,11 @@ class FormStreamsPass : public PassWrapper { llvm::SmallSetVector streamOpSet{streamOps.begin(), streamOps.end()}; SmallVector fragmentOperands; + SmallVector fragmentOperandDims; SmallVector fragmentResults; + SmallVector fragmentResultDims; SmallVector fragmentResultTypes; - SmallVector tieShapeOps; - SmallVector outsideTieShapeOperands; + SmallVector fragmentTiedOperands; for (auto *op : streamOps) { for (auto operand : op->getOperands()) { if (std::find(fragmentOperands.begin(), fragmentOperands.end(), @@ -141,16 +128,10 @@ class FormStreamsPass : public PassWrapper { if (!operand.getDefiningOp() || !streamOpSet.count(operand.getDefiningOp())) { fragmentOperands.push_back(operand); - - auto operandDefiningOp = operand.getDefiningOp(); - if (operandDefiningOp && - llvm::isa(operandDefiningOp)) { - tieShapeOps.push_back(operand.getDefiningOp()); - auto definingOp = - dyn_cast(operand.getDefiningOp()); - for (auto arg : definingOp.getOperands()) { - outsideTieShapeOperands.push_back(arg); - } + if (operand.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + fragmentLoc, operand, blockBuilder); + fragmentOperandDims.append(dynamicDims); } } } @@ -167,32 +148,28 @@ class FormStreamsPass : public PassWrapper { if (!onlyStreamUses) { fragmentResults.push_back(result); fragmentResultTypes.push_back(result.getType()); + if (result.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + fragmentLoc, result, blockBuilder); + fragmentResultDims.append(dynamicDims); + } } } } - // TODO(Tao Peng): pass args(operand and shape) which need by outside - // tie_shape into fragment body, and ignore the tie_shape arg passed into - // the fragment, it will not be used, and will be deleted by canonicalizer - // later. - outsideTieShapeOperands.append(fragmentOperands.begin(), - fragmentOperands.end()); - fragmentOperands = outsideTieShapeOperands; - // Create the fragment and clone in all of the ops. 
auto fragmentOp = blockBuilder.create( - fragmentLoc, fragmentResultTypes, fragmentOperands); + fragmentLoc, fragmentResultTypes, fragmentResultDims, fragmentOperands, + fragmentOperandDims, fragmentTiedOperands); auto *entryBlock = new Block(); fragmentOp.body().getBlocks().push_back(entryBlock); - entryBlock->addArguments(llvm::to_vector<8>(fragmentOp.getOperandTypes())); + entryBlock->addArguments(TypeRange(fragmentOp.operands())); BlockAndValueMapping mapping; - for (auto arg : entryBlock->getArguments()) { + for (unsigned i = 0; i < fragmentOperands.size(); ++i) { + auto arg = entryBlock->getArgument(i); mapping.map(fragmentOperands[arg.getArgNumber()], arg); } OpBuilder fragmentBuilder = OpBuilder::atBlockEnd(entryBlock); - for (auto *op : tieShapeOps) { - fragmentBuilder.clone(*op, mapping); - } for (auto *op : streamOps) { fragmentBuilder.clone(*op, mapping); } @@ -201,20 +178,18 @@ class FormStreamsPass : public PassWrapper { llvm::to_vector<8>(llvm::map_range(fragmentResults, [&](Value value) { return mapping.lookup(value); }))); - for (auto resultOldNew : - llvm::zip(fragmentResults, fragmentOp.getResults())) { + for (auto resultOldNew : llvm::zip(fragmentResults, fragmentOp.results())) { auto oldValue = std::get<0>(resultOldNew); auto newValue = std::get<1>(resultOldNew); oldValue.replaceAllUsesWith(newValue); } // Erase the ops from the block now that we've cloned them. + // Note the backwards order as the ops may have dependencies on each other + // and we have to erase the consumers before the producers. for (auto *op : llvm::reverse(streamOps)) { op->erase(); } - - // Expand any shape types to corresponding primitives. - expandFragmentToPrimitiveTypes(fragmentOp); } }; diff --git a/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp b/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp index 62e9449b4c95..7e2aca509c5d 100644 --- a/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/HoistUnstreamableOps.cpp @@ -29,8 +29,8 @@ namespace IREE { namespace Flow { static bool isStreamableOp(Operation *op) { - if (auto streamableOp = dyn_cast(op)) { - return streamableOp.isUsableInStream(); + if (isa(op)) { + return true; } if (llvm::isa(op)) { return true; diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp index dadac5d560a5..3d6e70144348 100644 --- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp @@ -248,7 +248,6 @@ LogicalResult fuseInputs(DispatchRegion &dispatchRegion, // makes this simple check safe. // The dispatch region must be optimized to remove unused arguments // resulting from this fusion. - DispatchRegionOp::dceOperandsAndResults(dispatchRegion.op); if (nextOp->use_empty()) { nextOp->erase(); } @@ -386,8 +385,11 @@ LogicalResult processBlock(Block &block, OpDispatchPolicy &policy) { if (failed(fuseInputs(*dispatchRegion, policy))) return failure(); // Ensure all unused operands and results are dce'd. - DispatchRegionOp::dceOperandsAndResults(dispatchRegion->op); - hoistDispatchRegionMetadataOps(*dispatchRegion, policy); + // Note that this may delete the op itself if it is unused. 
+ optimizeClosureOp(dispatchRegion->op); + if (dispatchRegion->op) { + hoistDispatchRegionMetadataOps(*dispatchRegion, policy); + } } return success(); } diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp index ddda1be83c9a..81ec219d4073 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp @@ -17,6 +17,7 @@ #include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h" +#include "iree/compiler/Dialect/Shape/IR/Builders.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "iree/compiler/Dialect/Shape/Utils/TypeConversion.h" @@ -76,10 +77,26 @@ LogicalResult convertToDispatchOp(DispatchRegionOp regionOp, getTensorTypeArgs(newArgs)); } + SmallVector operandDynamicDims; + for (auto operand : regionOp.args()) { + if (operand.getType().isa()) { + operandDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), operand, builder)); + } + } + SmallVector resultDynamicDims; + for (auto result : regionOp.results()) { + if (result.getType().isa()) { + resultDynamicDims.append(Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), result, builder)); + } + } + // Create the dispatch op to the executable function. auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, ValueRange{regionOp.workload()}, - outlinedFuncOp.getType().getResults(), newArgs); + outlinedFuncOp.getType().getResults(), resultDynamicDims, newArgs, + operandDynamicDims, ArrayRef{}); if (traceDispatchTensors) { std::string str = "Output for " + std::string(outlinedFuncOp.getName()); diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp index 4b704c6788d6..da32fcbcea90 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions2.cpp @@ -37,28 +37,6 @@ namespace IREE { namespace Flow { namespace { -// Expands dynamic dimensions of a shaped type to their individual values. -// Will walk the shape IR hierarchy to resolve the dimensions as possible. -static void insertDynamicShapeDimOperands(DispatchWorkgroupsOp regionOp, - Value value, - SmallVectorImpl &newOperands, - OpBuilder &builder) { - auto shapedType = value.getType().cast(); - if (shapedType.hasStaticShape()) return; - - // NOTE: this may insert new shape values at |builder|, which is prior to our - // dispatch operation. All new values that are built can only depend on SSA - // values that are defined prior to the region op. - auto shapeValue = Shape::buildOrFindRankedShapeForValue( - regionOp.getLoc(), value, builder.getIndexType(), builder); - for (int dim = 0, e = shapedType.getRank(); dim < e; ++dim) { - if (shapedType.isDynamicDim(dim)) { - newOperands.push_back(builder.create( - regionOp.getLoc(), shapeValue, dim)); - } - } -} - // Converts a dispatch region op into a dispatch op to the outlined region. static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp, ExecutableOp executableOp, @@ -67,25 +45,39 @@ static LogicalResult convertToDispatchOp(DispatchWorkgroupsOp regionOp, OpBuilder builder(regionOp); // Perform shape to primitive type expansion. 
+ // NOTE: this may insert new shape values at |builder|, which is prior to + // our dispatch operation. All new values that are built can only depend + // on SSA values that are defined prior to the region op. SmallVector newOperands; + SmallVector operandDynamicDims; + SmallVector resultDynamicDims; for (auto operand : regionOp.operands()) { newOperands.push_back(operand); } for (auto operand : regionOp.operands()) { - if (operand.getType().isa()) { - insertDynamicShapeDimOperands(regionOp, operand, newOperands, builder); + if (operand.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), operand, builder); + operandDynamicDims.append(dynamicDims); + newOperands.append(dynamicDims); } } for (auto result : regionOp.results()) { - if (result.getType().isa()) { - insertDynamicShapeDimOperands(regionOp, result, newOperands, builder); + if (result.getType().isa()) { + auto dynamicDims = Shape::buildOrFindDynamicDimsForValue( + regionOp.getLoc(), result, builder); + resultDynamicDims.append(dynamicDims); + newOperands.append(dynamicDims); } } // Create the dispatch op to the executable function. + // Note that we copy the tied operand indices from the workgroups op - it + // lines up 1:1 with the dispatch once we've outlined things. auto dispatchOp = builder.create( regionOp.getLoc(), entryPointOp, regionOp.workgroup_count(), - regionOp.getResultTypes(), newOperands); + regionOp.getResultTypes(), resultDynamicDims, newOperands, + operandDynamicDims, regionOp.getTiedResultOperandIndices()); // Replace uses of the existing results with the new results. for (int i = 0; i < regionOp.getNumResults(); ++i) { @@ -112,13 +104,9 @@ static FuncOp createWorkgroupFunc(Location loc, StringRef functionName, SmallVector operandTypes; int64_t totalDynamicDims = 0; for (auto &operand : region.getArguments()) { - if (auto inputType = operand.getType().dyn_cast()) { - operandTypes.push_back(inputType); - totalDynamicDims += inputType.getNumDynamicDims(); - } else if (auto outputType = - operand.getType().dyn_cast()) { - operandTypes.push_back(outputType); - totalDynamicDims += outputType.getNumDynamicDims(); + if (auto tensorType = operand.getType().dyn_cast()) { + operandTypes.push_back(tensorType); + totalDynamicDims += tensorType.getNumDynamicDims(); } else { // Pass-through. operandTypes.push_back(operand.getType()); diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index f5a60db96162..0e56a132d9cc 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -195,8 +195,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass(createCanonicalizerPass()); addHLOToLinalgOnTensorsPasses(passManager, clEnableLinalgOnTensorsDispatch); passManager.addNestedPass(createDispatchLinalgOnTensorsPass()); - passManager.addPass(createCanonicalizerPass()); - passManager.addPass(createCSEPass()); + passManager.addNestedPass(createCanonicalizerPass()); } // First perform module-level analysis that following passes will use to query @@ -211,12 +210,6 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass( IREE::Flow::createFoldCompatibleDispatchRegionsPass()); - // Note that as we are rematerializing things here it's critical we do not run - // the canonicalizer/CSE between now and when we outline - otherwise it'll - // undo all of our work! 
- passManager.addNestedPass( - IREE::Flow::createRematerializeDispatchConstantsPass()); - // Outline the dispatch regions into their own functions wrapped in // executables. This separates sequencer functions performing dispatches from // dispatchees. @@ -258,15 +251,16 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) { passManager.addNestedPass(createCanonicalizerPass()); passManager.addNestedPass(IREE::Flow::createFormStreamsPass()); - // Forming streams involves a fair amount of subgraph stitching, which can - // cause duplication. Run CSE to collapse. - passManager.addNestedPass(createCanonicalizerPass()); - passManager.addNestedPass(createCSEPass()); // Prior to leaving the pipeline we need to clean things up for following // layers. These transforms may be undone by subsequent CSE/folding passes. passManager.addPass(createOutlineLargeConstantsPass()); + // Forming streams involves a fair amount of subgraph stitching, which can + // cause duplication. Run CSE to collapse. + passManager.addNestedPass(createCanonicalizerPass()); + passManager.addNestedPass(createCSEPass()); + // Symbol DCE any remaining variables/functions that are now no longer // required. passManager.addPass(createSymbolDCEPass()); diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index e7101f7cdeba..0d2c28b1e8b4 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -124,10 +124,6 @@ std::unique_ptr> createIdentifyDispatchRegions2Pass(); std::unique_ptr> createFoldCompatibleDispatchRegionsPass(); -// Rematerializes small previously-CSE'd constants into dispatch regions. -std::unique_ptr> -createRematerializeDispatchConstantsPass(); - // Outlines dispatch regions into executables. std::unique_ptr> createOutlineDispatchRegionsPass(); std::unique_ptr> createOutlineDispatchRegions2Pass(); @@ -200,7 +196,6 @@ inline void registerFlowPasses() { createIdentifyDispatchRegionsPass(); createIdentifyDispatchRegions2Pass(); createFoldCompatibleDispatchRegionsPass(); - createRematerializeDispatchConstantsPass(); createOutlineDispatchRegionsPass(); createCreateBenchmarkFuncs(); createOutlineLargeConstantsPass(); diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp deleted file mode 100644 index 43a184240e8a..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" -#include "llvm/Support/Debug.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace Flow { - -namespace { - -// Returns true if the constant value is a splat constant and can be -// rematerialized in a dispatch region. -bool isSplatConstant(ConstantOp constantOp) { - if (constantOp.getValue().isa()) { - // Splats are always small and can be much better handled by broadcasting - // within the dispatch regions. - return true; - } else if (auto value = constantOp.getValue().dyn_cast()) { - return value.isSplat(); - } else if (constantOp.getType().isIntOrFloat()) { - return true; - } - - // Assume anything unshaped is small. This may not always be true in custom - // dialects but is in std for now. - return false; -} - -// Returns true if the dispatch region is allowed to have constants inside. -// Certain regions that may get replaced or turned into kernel imports shouldn't -// have the constants moved into them as they'll just get lost. -bool canDispatchRegionContainConstants(DispatchRegionOp dispatchRegionOp) { - for (auto &block : dispatchRegionOp.body()) { - for (auto &op : block) { - // TODO(b/144530470): replace with tablegen attributes/interfaces. - if (isa(&op) || isa(&op)) { - // These two generally result in a lot of generated code so we try to - // keep constants out such that can dedupe more. We may still want to - // allow some parameters in (shapes/etc). - return false; - } - } - } - return true; -} - -// Recursively clones the given |sourceOp| and returns the newly cloned op. -Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - // Note that we dedupe required operands in the case of multiple arguments - // coming from the same source operation. - SmallPtrSet operandOps; - for (auto operand : sourceOp->getOperands()) { - operandOps.insert(operand.getDefiningOp()); - } - for (auto *operandOp : operandOps) { - recursivelyCloneOp(operandOp, builder, mapping); - } - return builder.clone(*sourceOp, *mapping); -} - -// Clones the |sourceValue| op tree into |targetBlock|. -// |mapping| is used to lookup existing values that may be present in the block -// such as block arguments or already cloned ancestor ops. |mapping| will be -// updated as the tree is cloned. -Value cloneOpTreeIntoBlock(Value sourceValue, Block *targetBlock, - BlockAndValueMapping *mapping) { - // If the op has already been cloned we can just reuse that. - // This happens if multiple arguments reference the same trees. - if (auto existingValue = mapping->lookupOrNull(sourceValue)) { - return existingValue; - } - - OpBuilder builder = OpBuilder::atBlockEnd(targetBlock); - builder.setInsertionPointToStart(targetBlock); - auto *sourceOp = sourceValue.getDefiningOp(); - auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping); - - // Return only the result matching our source value (in the case of multiple - // results). 
- int resultIndex = std::distance( - sourceOp->result_begin(), - std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue)); - return clonedOp->getResult(resultIndex); -} - -// Modify the second operand of the SegmentSize attribute -// TODO(ataei): Remove this once we have flow.dispatch.workgroups only here. -template -void modifyOperandSegmentSizeAttr(DispatchOpType dispatchOp, int32_t argCount); - -template <> -void modifyOperandSegmentSizeAttr(DispatchRegionOp dispatchRegionOp, - int32_t argCount) {} - -template <> -void modifyOperandSegmentSizeAttr(DispatchWorkgroupsOp dispatchWorkgroupsOp, - int32_t argCount) { - dispatchWorkgroupsOp.getOperation()->setAttr( - DispatchWorkgroupsOp::getOperandSegmentSizeAttr(), - DenseIntElementsAttr::get( - VectorType::get( - 2, IntegerType::get(dispatchWorkgroupsOp.getContext(), 32)), - ArrayRef( - {static_cast( - dispatchWorkgroupsOp.workgroup_count().size()), - static_cast(dispatchWorkgroupsOp.operands().size() - - argCount)}))); -} - -// Inlines use of the given |value| from outside of a dispatch region to inside -// of it and removes the argument. Supports multiple arguments that reference -// |value| and will clone the entire value tree. -template -LogicalResult inlineDispatchRegionOperandsUsingValue(DispatchOpType dispatchOp, - ValueRange args, - Value value) { - // Find all args that are using this value. - SmallVector argIndices; - for (auto arg : llvm::enumerate(args)) { - if (arg.value() == value) { - argIndices.push_back(arg.index()); - } - } - if (argIndices.empty()) { - // Not used? Wasteful call! - return success(); - } - - // Clone the value (and the ops required to create it) into the entry block. - auto &entryBlock = dispatchOp.body().getBlocks().front(); - BlockAndValueMapping mapping; - auto clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping); - - // Replace all uses of the inner operand with the new value. - for (unsigned argIndex : argIndices) { - entryBlock.getArgument(argIndex).replaceAllUsesWith(clonedValue); - } - - // Remove the dispatch region args and the block args that have been - // replaced. - for (unsigned argIndex : llvm::reverse(argIndices)) { - dispatchOp.getOperation()->eraseOperand( - dispatchOp.mapArgOperandToOpOperand(argIndex)); - entryBlock.eraseArgument(argIndex); - } - modifyOperandSegmentSizeAttr(dispatchOp, argIndices.size()); - - return success(); -} - -// Rematerializes a constant inside of all dispatch regions that use it. -// Afterward the constant is only removed if there are no other uses within -// the non-dispatch block (such as by sequencer ops). -LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) { - Value constantValue = constantOp.getResult(); - SmallVector usingRegionOps; - for (auto *user : constantValue.getUsers()) { - if (auto dispatchRegionOp = dyn_cast(user)) { - // Ensure this isn't just the workload and is used as an arg. 
- if (std::find(dispatchRegionOp.args().begin(), - dispatchRegionOp.args().end(), - constantValue) != dispatchRegionOp.args().end()) { - if (canDispatchRegionContainConstants(dispatchRegionOp)) { - usingRegionOps.push_back(dispatchRegionOp); - } - } - } - } - for (auto &dispatchRegionOp : usingRegionOps) { - if (failed(inlineDispatchRegionOperandsUsingValue( - dispatchRegionOp, dispatchRegionOp.args(), constantValue))) { - return failure(); - } - } - return success(); -} - -LogicalResult rematerializeConstantInDispatchWorkgroupsRegions( - ConstantOp constantOp) { - Value constantValue = constantOp.getResult(); - for (auto *user : constantValue.getUsers()) { - if (auto dispatchWorkgroupsOp = dyn_cast(user)) { - if (failed(inlineDispatchRegionOperandsUsingValue( - dispatchWorkgroupsOp, dispatchWorkgroupsOp.operands(), - constantValue))) { - return failure(); - } - } - } - return success(); -} - -} // namespace - -// Finds constant arguments to dispatch regions that are too small to be worth -// putting into constant pools. This prevents things like a CSE'd scalar -// constant of 0.0 being passed by reference to a bunch of regions. Later -// backend-specific passes running on the dispatch regions may also be able to -// improve their constant propagation chances by having the full constant value -// available. -// -// Note that this currently only operates at the block level. Constants that are -// pushed across branches are assumed to have been rematerialized within blocks -// already, but if that isn't the case then this pass can be extended to do -// that. -class RematerializeDispatchConstantsPass - : public PassWrapper { - public: - void runOnFunction() override { - for (auto &block : getFunction()) { - SmallVector smallConstantOps; - for (auto constantOp : block.getOps()) { - if (isSplatConstant(constantOp)) { - smallConstantOps.push_back(constantOp); - } - } - // Note: we iterate in reverse so that the rematerialized constants appear - // in the same order they did originally (as insertion is at the top). - for (auto constantOp : llvm::reverse(smallConstantOps)) { - if (failed(rematerializeConstantInDispatchRegions(constantOp))) { - return signalPassFailure(); - } - if (failed( - rematerializeConstantInDispatchWorkgroupsRegions(constantOp))) { - return signalPassFailure(); - } - // Remove if there are no other uses within the block. - if (constantOp.use_empty()) { - constantOp.erase(); - } - } - } - } -}; - -std::unique_ptr> -createRematerializeDispatchConstantsPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "iree-flow-rematerialize-dispatch-constants", - "Rematerializes small previously-CSE'd constants into dispatch regions"); - -} // namespace Flow -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD index b780b112a316..3903093842b6 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,37 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "create_benchmark_funcs.mlir", + "deduplicate_executables.mlir", + "dispatch_linalg_on_tensors.mlir", + "dispatch_linalg_on_tensors_dynamic.mlir", + "expand_variable_dynamic_dims.mlir", + "fold_compatible_dispatch_regions.mlir", + "form_streams.mlir", + "hlo_to_hlo_preprocessing.mlir", + "hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir", + "hlo_to_hlo_preprocessing_extract_pad_from_conv.mlir", + "hoist_unstreamable_ops.mlir", + "identify_dispatch_regions.mlir", + "identify_dispatch_regions2_enable_matmul_fusion.mlir", + "identify_dispatch_regions2_hlo.mlir", + "identify_dispatch_regions2_linalg.mlir", + "identify_dispatch_regions2_shapes.mlir", + "identify_dispatch_regions2_std_fusion.mlir", + "inject_dispatch_tracing.mlir", + "legalize_input_types.mlir", + "materialize_and_merge_exported_reflection.mlir", + "materialize_exported_reflection.mlir", + "merge_exported_reflection.mlir", + "outline_dispatch_regions2.mlir", + "outline_large_constants.mlir", + "strip_and_splat_constant_variables.mlir", + "transformation.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index ddcc3cde95d4..1b9ea4dca2f7 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -10,12 +10,36 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "create_benchmark_funcs.mlir" + "deduplicate_executables.mlir" + "dispatch_linalg_on_tensors.mlir" + "dispatch_linalg_on_tensors_dynamic.mlir" + "expand_variable_dynamic_dims.mlir" + "fold_compatible_dispatch_regions.mlir" + "form_streams.mlir" + "hlo_to_hlo_preprocessing.mlir" + "hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir" + "hlo_to_hlo_preprocessing_extract_pad_from_conv.mlir" + "hoist_unstreamable_ops.mlir" + "identify_dispatch_regions.mlir" + "identify_dispatch_regions2_enable_matmul_fusion.mlir" + "identify_dispatch_regions2_hlo.mlir" + "identify_dispatch_regions2_linalg.mlir" + "identify_dispatch_regions2_shapes.mlir" + "identify_dispatch_regions2_std_fusion.mlir" + "inject_dispatch_tracing.mlir" + "legalize_input_types.mlir" + "materialize_and_merge_exported_reflection.mlir" + "materialize_exported_reflection.mlir" + "merge_exported_reflection.mlir" + "outline_dispatch_regions2.mlir" + "outline_large_constants.mlir" + "strip_and_splat_constant_variables.mlir" + "transformation.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir index 6d785ebbd693..804ea79a74f0 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/create_benchmark_funcs.mlir @@ -12,8 +12,8 @@ module { // CHECK: func @two_dispatch_ex_dispatch_0_entry // CHECK: %{{.+}} = flow.variable.load @[[IN0_0]] : tensor<5x3xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN0_1]] : 
tensor<3x5xf32> -// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> { -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> = +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32> // CHECK: flow.return %[[DISPATCH_RES]] : tensor<5x5xf32> // CHECK: return %[[RES]] : tensor<5x5xf32> // @@ -23,7 +23,7 @@ module { // CHECK: %{{.+}} = flow.variable.load @[[IN1_0]] : tensor<3x5xf32> // CHECK: %{{.+}} = flow.variable.load @[[IN1_1]] : tensor<5x5xf32> // CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32> -// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}] (%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}}](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32> // CHECK: flow.return %[[DISPATCH_RES]] : tensor<3x5xf32> // CHECK: return %[[RES]] : tensor<3x5xf32> // @@ -32,7 +32,7 @@ module { // CHECK: func @two_dispatch_dummy_args() // CHECK: %{{.+}} = flow.variable.load @[[MAIN_IN_0]] : tensor<5x3xf32> // CHECK: %{{.+}} = flow.variable.load @[[MAIN_IN_1]] : tensor<3x5xf32> -// CHECK: flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>, tensor<3x5xf32>) { +// CHECK: flow.ex.stream.fragment({{.+}}) -> (tensor<5x5xf32>, tensor<3x5xf32>) = // CHECK: %[[DISPATCH_RES1:.+]] = flow.dispatch // CHECK: %[[DISPATCH_RES2:.+]] = flow.dispatch // CHECK: flow.return %[[DISPATCH_RES1]], %[[DISPATCH_RES2]] : tensor<5x5xf32>, tensor<3x5xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir index 60d38175b83c..d5eaee6cf93d 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir @@ -13,7 +13,7 @@ flow.executable @single_executable_ex_0 { // CHECK-LABEL: func @single_executable func @single_executable(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @single_executable_ex_0::@single_executable_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -53,11 +53,11 @@ flow.executable @duplicate_executables_ex_2 { // CHECK-LABEL: func @duplicate_executables func @duplicate_executables(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch 
@duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %1 = flow.dispatch @duplicate_executables_ex_0::@duplicate_executables_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %1 = flow.dispatch @duplicate_executables_ex_1::@duplicate_executables_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %2 = flow.dispatch @duplicate_executables_ex_2::@duplicate_executables_entry_2[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -87,9 +87,9 @@ flow.executable @same_ops_diff_operands_ex_1 { // CHECK-LABEL: func @same_ops_diff_operands func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4] (%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -129,13 +129,13 @@ flow.executable @multiple_entry_points_ex_1 { // CHECK-LABEL: func @multiple_entry_points func @multiple_entry_points(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %0 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %1 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %2 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %2 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: %3 = flow.dispatch @multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: %3 = flow.dispatch 
@multiple_entry_points_ex_0::@multiple_entry_points_0_entry_1[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> %3 = flow.dispatch @multiple_entry_points_ex_1::@multiple_entry_points_1_entry_1[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -165,9 +165,9 @@ flow.executable @different_types_int_ex { // CHECK-LABEL: func @different_types func @different_types(%arg0: tensor<4xf32>) -> tensor<4xi1> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> + // CHECK: %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1> %0 = flow.dispatch @different_types_float_ex::@different_types_float_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> - // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> + // CHECK: %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xi1> %1 = flow.dispatch @different_types_int_ex::@different_types_int_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -222,11 +222,11 @@ flow.executable @nested_ops_ex_2 { // CHECK-LABEL: func @nested_ops func @nested_ops(%arg0: tensor<1x4xi32>) -> tensor<1xi32> { %c4 = constant 4 : index - // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %0 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %1 = flow.dispatch @nested_ops_ex_0::@nested_ops_entry_0[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> - // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> + // CHECK: %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4] (%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index a4506b7c640d..4bbb5536f7b3 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -17,12 +17,12 @@ func @tensor() -> tensor<2x4xf32> { %C = iree.unfoldable_constant dense<1000.0> : tensor<2x4xf32> // %[[C2]] will be handled by a later RematerializeDispatchConstants - // CHECK: flow.dispatch.workgroups[%[[C4wg]], %[[C2wg]], %[[C1wg]]] (%[[outerA]], %[[outerB]], %[[outerC]]) : + // CHECK: flow.dispatch.workgroups[%[[C4wg]], %[[C2wg]], %[[C1wg]]](%[[outerA]], %[[outerB]], %[[outerC]]) : // CHECK-SAME: (tensor<2x3xf32>, tensor<3x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> = - // CHECK-SAME: (%[[A:[0-9a-z]*]] : !flow.dispatch.input<2x3xf32>, - // CHECK-SAME: %[[B:[0-9a-z]*]] : !flow.dispatch.input<3x4xf32>, - // CHECK-SAME: %[[C:[0-9a-z]*]] : !flow.dispatch.input<2x4xf32>, - // 
CHECK-SAME: %[[OUT:[0-9a-z]*]] : !flow.dispatch.output<2x4xf32>) {
+ // CHECK-NEXT: (%[[A:[0-9a-z]*]]: !flow.dispatch.tensor<readonly:2x3xf32>,
+ // CHECK-SAME: %[[B:[0-9a-z]*]]: !flow.dispatch.tensor<readonly:3x4xf32>,
+ // CHECK-SAME: %[[C:[0-9a-z]*]]: !flow.dispatch.tensor<readonly:2x4xf32>,
+ // CHECK-SAME: %[[OUT:[0-9a-z]*]]: !flow.dispatch.tensor<writeonly:2x4xf32>) {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
@@ -38,23 +38,23 @@ func @tensor() -> tensor<2x4xf32> {
// CHECK-NEXT: scf.for %[[J:.*]] = %[[bix_scaled]] to %[[C4]] step %[[bdx_scaled]] {
// Canonicalizations not yet powerful enough here.
// CHECK-NEXT: %[[MIN_I:.*]] = affine.min{{.*}}(%[[I]])
- // CHECK-NEXT: %[[AA:.*]] = flow.dispatch.input.load %[[A]],
+ // CHECK-NEXT: %[[AA:.*]] = flow.dispatch.tensor.load %[[A]],
// CHECK-SAME: offsets = [%[[I]], %[[C0]]], sizes = [%[[MIN_I]], %[[C3]]], strides = [%[[C1]], %[[C1]]] :
- // CHECK-SAME: !flow.dispatch.input<2x3xf32> -> tensor<?x3xf32>
+ // CHECK-SAME: !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<?x3xf32>
//
// Canonicalizations not yet powerful enough here.
// CHECK-NEXT: %[[MIN_J:.*]] = affine.min{{.*}}(%[[J]])
- // CHECK-NEXT: %[[BB:.*]] = flow.dispatch.input.load %[[B]],
+ // CHECK-NEXT: %[[BB:.*]] = flow.dispatch.tensor.load %[[B]],
// CHECK-SAME: offsets = [%[[C0]], %[[J]]], sizes = [%[[C3]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] :
- // CHECK-SAME: !flow.dispatch.input<3x4xf32> -> tensor<3x?xf32>
- // CHECK-NEXT: %[[CC:.*]] = flow.dispatch.input.load %[[C]],
+ // CHECK-SAME: !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x?xf32>
+ // CHECK-NEXT: %[[CC:.*]] = flow.dispatch.tensor.load %[[C]],
// CHECK-SAME: offsets = [%[[I]], %[[J]]], sizes = [%[[MIN_I]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] :
- // CHECK-SAME: !flow.dispatch.input<2x4xf32> -> tensor<?x?xf32>
+ // CHECK-SAME: !flow.dispatch.tensor<readonly:2x4xf32> -> tensor<?x?xf32>
// CHECK-NEXT: %[[RES:.*]] = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[AA]], %[[BB]] :
// CHECK-SAME: tensor<?x3xf32>, tensor<3x?xf32>) outs(%[[CC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
- // CHECK-NEXT: flow.dispatch.output.store %[[RES]], %[[OUT]],
+ // CHECK-NEXT: flow.dispatch.tensor.store %[[RES]], %[[OUT]],
// CHECK-SAME: offsets = [%[[I]], %[[J]]], sizes = [%[[MIN_I]], %[[MIN_J]]], strides = [%[[C1]], %[[C1]]] :
- // CHECK-SAME: tensor<?x?xf32> -> !flow.dispatch.output<2x4xf32>
+ // CHECK-SAME: tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:2x4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: flow.return
@@ -63,6 +63,8 @@ func @tensor() -> tensor<2x4xf32> {
return %E : tensor<2x4xf32>
}
+// -----
+
// CHECK-LABEL: func @tensor2
func @tensor2(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
-> tensor<?x?xf32> attributes {iree.module.export}
@@ -89,6 +91,8 @@ func @tensor2(%A: tensor, %B: tensor, %C: tensor)
return %D: tensor<?x?xf32>
}
+// -----
+
// CHECK-LABEL: func @tensor3
func @tensor3(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
-> tensor<?x?xf32> attributes {iree.module.export}
@@ -115,6 +119,7 @@ func @tensor3(%A: tensor, %B: tensor, %C: tensor)
return %D: tensor<?x?xf32>
}
+// -----
// CHECK-LABEL: func @tensor4
func @tensor4(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
@@ -143,6 +148,8 @@ func @tensor4(%A: tensor, %B: tensor, %C: tensor)
return %D: tensor<?x?xf32>
}
+// -----
+
// CHECK: func @tensor5
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -155,13 +162,13 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor)
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[D0:.+]] = dim %[[ARG2]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = dim %[[ARG2]], %[[C1]]
- // CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]] (%[[ARG2]])
- // CHECK-SAME: %[[ARG3:.+]] : !flow.dispatch.input<?x?xf32>
- // CHECK-SAME: %[[ARG4:.+]] : !flow.dispatch.output<?x?xf32>
- // CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG3]]
+ // CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]](%[[ARG2]])
+ // CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+ // CHECK-SAME: %[[ARG4:.+]]: !flow.dispatch.tensor<writeonly:?x?xf32>
+ // CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG3]]
// CHECK: %[[STOREVAL:.+]] = linalg.generic
// CHECK-SAME: outs(%[[LOAD]] : tensor<?x?xf32>)
- // CHECK: flow.dispatch.output.store %[[STOREVAL]], %[[ARG4]]
+ // CHECK: flow.dispatch.tensor.store %[[STOREVAL]], %[[ARG4]]
// linalg.generic is fused inside the dispatch region and becomes a noop but
// there is still a use.
@@ -187,3 +194,33 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor)
// CHECK: return %[[D]], %[[origCC]]
return %D, %CC: tensor<?x?xf32>, tensor<?x?xf32>
}
+
+func @conv2d(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>) -> tensor<1x112x112x32xf32> {
+ %0 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+ %cst = constant 0.000000e+00 : f32
+ %1 = linalg.fill(%0, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
+ %2 = linalg.conv_2d_input_nhwc_filter_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x225x225x16xf32>, tensor<3x3x16x32xf32>)
+ outs(%1 : tensor<1x112x112x32xf32>)
+ -> tensor<1x112x112x32xf32>
+ return %2 : tensor<1x112x112x32xf32>
+}
+
+// CHECK-LABEL: func @conv2d
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+
+func @depthwise_conv2d(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
+ %cst = constant 0.000000e+00 : f32
+ %1 = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
+ %2 = linalg.fill(%1, %cst) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32>
+ %4 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) outs(%2 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+ return %4 : tensor<1x56x56x96xf32>
+}
+
+// CHECK-LABEL: func @depthwise_conv2d
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir
index 7b32d8cba903..e32e64940582 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir
@@ -14,10 +14,10 @@ func @tensor(%arg0 : tensor, %arg1 : tensor,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: flow.dispatch.workgroups
// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : !flow.dispatch.input<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : !flow.dispatch.input<?x?xf32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input<?x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output<?x?xf32>
+// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
// CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
@@ -33,16 +33,16 @@ func @tensor(%arg0 : tensor, %arg1 : tensor,
// CHECK: %[[STEP_X:.+]] = affine.apply #[[MULMAP]]()[%[[WGCOUNT_X]], %[[WGSIZE_X]]]
// CHECK: scf.for %[[ARG8:.+]] = %[[OFFSET_X]]
// CHECK-SAME: to %{{.+}} step %[[STEP_X]]
-// CHECK: %[[LHS:.+]] = flow.dispatch.input.load %[[ARG3]]
+// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[ARG3]]
// CHECK-SAME: offsets = [%[[ARG7]], %[[C0]]]
-// CHECK: %[[RHS:.+]] = flow.dispatch.input.load %[[ARG4]]
+// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[ARG4]]
// CHECK-SAME: offsets = [%[[C0]], %[[ARG8]]]
-// CHECK: %[[INIT:.+]] = flow.dispatch.input.load %[[ARG5]]
+// CHECK: %[[INIT:.+]] = flow.dispatch.tensor.load %[[ARG5]]
// CHECK-SAME: offsets = [%[[ARG7]], %[[ARG8]]]
// CHECK: %[[RESULT:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
-// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG6]]
+// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]]
// CHECK-SAME: offsets = [%[[ARG7]], %[[ARG8]]]
// -----
@@ -74,19 +74,19 @@ func @generic_op(%A: tensor, %B: tensor) -> tensor {
// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: flow.dispatch.workgroups
-// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]] (%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]])
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] : !flow.dispatch.input
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : !flow.dispatch.input
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : index
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : index
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output
-// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.input.load %[[ARG2]] : !flow.dispatch.input
+// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]])
+// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor
+// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG2]] : !flow.dispatch.tensor
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
-// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.input.load %[[ARG3]] : !flow.dispatch.input
+// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.tensor.load %[[ARG3]] : !flow.dispatch.tensor
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LOAD2]], %[[LOAD3]] : tensor, tensor)
// CHECK-SAME: outs(%[[INIT]] : tensor)
-// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG6]]
+// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]]
// -----
@@ -111,22 +111,22 @@ func @fuse_fill_with_producer(%A : tensor, %B : tensor) -> ten
// CHECK: %[[N:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: flow.dispatch.workgroups[%[[N]], %[[M]], %[[C1]]]
// CHECK-SAME: (%[[M]], %[[N]], %[[ARG0]], %[[ARG1]])
-// CHECK-SAME: (%[[ARG2:[a-zA-Z0-9_]+]] : index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : !flow.dispatch.input
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.output) {
+// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]:
!flow.dispatch.tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 // CHECK: scf.for // CHECK: scf.for -// CHECK-DAG: %[[LHS_TILE:.+]] = flow.dispatch.input.load %[[ARG4]] -// CHECK-DAG: %[[RHS_TILE:.+]] = flow.dispatch.input.load %[[ARG5]] +// CHECK-DAG: %[[LHS_TILE:.+]] = flow.dispatch.tensor.load %[[ARG4]] +// CHECK-DAG: %[[RHS_TILE:.+]] = flow.dispatch.tensor.load %[[ARG5]] // CHECK-DAG: %[[INIT_TILE:.+]] = linalg.init_tensor // CHECK: %[[FILL_TILE:.+]] = linalg.fill(%[[INIT_TILE]], %[[ZERO]]) // CHECK: %[[RESULT_TILE:.+]] = linalg.matmul // CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor, tensor) // CHECK-SAME: outs(%[[FILL_TILE]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT_TILE]], %[[ARG6]] +// CHECK: flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG6]] // CHECK: flow.return // CHECK: } @@ -166,17 +166,17 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] : index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.output) { +// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // CHECK: %[[ONE:.+]] = constant 1.0 -// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.input.load %[[ARG2]] +// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[ARG2]] // CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: ins(%[[INPUT]] : tensor) // CHECK-SAME: outs(%[[INIT]] : tensor) -// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG5]] +// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG5]] // CHECK: flow.return // CHECK: } // CHECK: flow.dispatch.workgroups[%[[N]], %[[M]], %[[C1]]] @@ -190,23 +190,23 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor -// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] : !flow.dispatch.input -// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]] : !flow.dispatch.output) { +// NOCHECK-SAME: (%[[ARG2:[a-zA-Z0-9_]+]]: index +// NOCHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// NOCHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor +// NOCHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor) { // NOCHECK: %[[ZERO:.+]] = constant 0.0 // NOCHECK: scf.for // NOCHECK: scf.for -// NOCHECK-DAG: %[[LHS_TILE_2:.+]] = flow.dispatch.input.load %[[ARG6]] -// NOCHECK-DAG: %[[RHS_TILE_2:.+]] = flow.dispatch.input.load %[[ARG5]] +// NOCHECK-DAG: %[[LHS_TILE_2:.+]] = flow.dispatch.tensor.load %[[ARG6]] +// NOCHECK-DAG: %[[RHS_TILE_2:.+]] = flow.dispatch.tensor.load %[[ARG5]] // NOCHECK-DAG: %[[INIT_TILE_2:.+]] = linalg.init_tensor // NOCHECK: %[[FILL_TILE:.+]] = linalg.fill(%[[INIT_TILE]], %[[ZERO]]) // NOCHECK: %[[RESULT_TILE_2:.++]]] = linalg.matmul // NOCHECK-SAME: ins(%[[LHS_TILE_2]], %[[RHS_TILE_2]] : tensor, tensor) // NOCHECK: outs(%[[FILL_TILE_2]] : tensor) -// NOCHECK: flow.dispatch.output.store %[[RESULT_TILE_2]], %[[ARG7]] +// NOCHECK: flow.dispatch.tensor.store %[[RESULT_TILE_2]], %[[ARG7]] // NOCHECK: flow.return // NOCHECK: } @@ -228,22 +228,22 @@ func @dot_general_lower() attributes {iree.module.export} { } // CHECK-LABEL: func @dot_general_lower // CHECK: 
flow.dispatch.workgroups[%{{.+}}, %{{.+}}, %{{.+}}]
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] : !flow.dispatch.input<1x1x2xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] : !flow.dispatch.input<2x3xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] : !flow.dispatch.output<1x3xf32>
+// CHECK-NEXT: %[[ARG0:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x1x2xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:2x3xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:1x3xf32>
// CHECK-DAG: %[[ZERO:.+]] = constant 0.0
-// CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG0]]
+// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG0]]
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]]
// CHECK: scf.for
// CHECK: scf.for
// CHECK-DAG: %[[LHS:.+]] = subtensor %[[RESHAPE]]
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.input.load %[[ARG1]]
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[ARG1]]
// CHECK: %[[INIT:.+]] = linalg.init
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: %[[RESULT:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<?x2xf32>, tensor<2x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?xf32>)
-// CHECK: flow.dispatch.output.store %[[RESULT]], %[[ARG2]]
+// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG2]]
// -----
@@ -262,12 +262,12 @@ func @reshapeop(%arg0: tensor) -> tensor
// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[WORKLOAD:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
-// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]] (%[[ARG0]])
-// CHECK-SAME: %[[ARG1:.+]] : !flow.dispatch.input<?x?xf32>
-// CHECK-SAME: %[[ARG2:.+]] : !flow.dispatch.output<?xf32>
-// CHECK: %[[LOAD:.+]] = flow.dispatch.input.load %[[ARG1]]
+// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]](%[[ARG0]])
+// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.tensor<writeonly:?xf32>
+// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG1]]
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] [#[[MAP1]]]
-// CHECK: flow.dispatch.output.store %[[RESHAPE]], %[[ARG2]]
+// CHECK: flow.dispatch.tensor.store %[[RESHAPE]], %[[ARG2]]
// -----
@@ -355,8 +355,8 @@ func @always_fuse_reshape
// -----
-func @pad_test(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : index,
- %arg3 : index, %arg4 : index, %arg5 : index ) -> tensor<?x?xf32> {
+func @pad_test(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index,
+ %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor.extract %arg1[] : tensor<f32>
@@ -366,7 +366,7 @@ func @pad_test(%arg0 : tensor, %arg1 : tensor, %arg2 : index,
%4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
%5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
%6 = linalg.fill(%5, %0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
- %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32> -> tensor<?x?xf32>
+ %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
return %7 : tensor<?x?xf32>
}
@@ -390,6 +390,6 @@ func @pad_test(%arg0 : tensor, %arg1 : tensor, %arg2 : index,
// CHECK-DAG: %[[VAL:.+]] = tensor.extract
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
// CHECK: %[[RETURN:.+]] = linalg.fill(%[[INIT]], %[[VAL]])
-// CHECK: flow.dispatch.output.store %[[RETURN]]
+// CHECK: flow.dispatch.tensor.store %[[RETURN]]
// CHECK-NEXT: flow.return
// CHECK: flow.tensor.update %[[ARG0]], %[[RESULT]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir index 292b28052dd3..d7ec93eaf468 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir @@ -11,7 +11,7 @@ func @noFolding(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @noFolding // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %0 = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } @@ -38,7 +38,7 @@ func @elementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @elementwiseOps // CHECK: %[[WORKLOAD0:.+]] = constant 4 -// CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { +// CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> // CHECK-NEXT: %2 = mhlo.subtract %1, %arg1 : tensor<4xf32> // CHECK-NEXT: %3 = mhlo.multiply %arg1, %2 : tensor<4xf32> @@ -69,17 +69,17 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: func @interleavedDot // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index -// CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.add %arg1, %arg1 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 16 : index -// CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index](%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index](%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 16 : index -// CHECK-NEXT: %[[R2:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index](%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[R2:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index](%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.multiply %arg1, %arg2 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir index 8796e2431222..630dd616ed6c 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir @@ -1,22 +1,18 @@ -// RUN: iree-opt -split-input-file -iree-flow-form-streams %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-flow-form-streams -cse -canonicalize %s | IreeFileCheck %s // CHECK-LABEL: func 
@outsideTieShape func @outsideTieShape(%arg0: tensor {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?]> {iree.reflection = {}}) -> (tensor {iree.reflection = {}}) attributes {iree.module.export} { - // CHECK: %[[WORKLOAD0:.+]] = constant 0 : index %c0 = constant 0 : index - // CHECK-NEXT: %0 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?]> - %2 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?]> - // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 1 : index - %c1 = constant 1 : index - // CHECK-NEXT: %1 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor, %arg3 = %1 : index, %arg4 = %[[WORKLOAD0]] : index, %arg5 = %0 : tensor) -> tensor { - // CHECK-NEXT: %3 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?]> - // CHECK-NEXT: %4 = shapex.tie_shape %arg2, %3 : tensor, !shapex.ranked_shape<[?]> - // CHECK-NEXT: %5 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%arg4] (%arg4, %4) : (index, tensor) -> tensor - // CHECK-NEXT: flow.return %5 : tensor + // CHECK-DAG: %[[DIM:.+]] = shapex.ranked_dim %arg1[0] + %dim = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index + // CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%[[DIM]], %arg0) : (index, tensor{%[[DIM]]}) -> tensor{%[[DIM]]} = + // CHECK-NEXT: (%[[INNER_DIM:.+]]: index, %[[CAPTURE:.+]]: tensor) -> tensor { + // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 0 : index + // CHECK-NEXT: %[[INNER_RET:.+]] = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%[[WORKLOAD0]]](%[[INNER_DIM]], %[[CAPTURE]]) : (index, tensor{%[[INNER_DIM]]}) -> tensor{%[[INNER_DIM]]} + // CHECK-NEXT: flow.return %[[INNER_RET]] : tensor // CHECK-NEXT: } - %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0] (%c0, %2) : (index, tensor) -> tensor - // CHECK-NEXT: return %2 : tensor + %15 = flow.dispatch @main_ex_dispatch_1::@main_ex_dispatch_1[%c0](%dim, %arg0) : (index, tensor{%dim}) -> (tensor{%dim}) + // CHECK-NEXT: return %[[RET]] : tensor return %15 : tensor } @@ -35,15 +31,16 @@ flow.executable @outerOps_ex_dispatch_0 { } // CHECK-LABEL: func @outerOps func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index - %cst = constant 4 : index - // CHECK-NEXT: %0 = addf %arg0, %arg0 : tensor<4xf32> + // CHECK: %0 = addf %arg0, %arg0 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %1 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + %cst = constant 4 : index + // CHECK-NEXT: %1 = flow.ex.stream.fragment(%0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%[[WORKLOAD]]](%[[INNER_ARG]]) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst] (%0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = flow.dispatch @outerOps_ex_dispatch_0::@outerOps_rgn_dispatch_0[%cst](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: %2 = addf %1, %1 : tensor<4xf32> %2 = addf %1, %1 : tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> @@ -54,15 +51,16 @@ 
func @outerOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func @nondependentOuterOps( func @nondependentOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index // CHECK-NEXT: %[[ADD1:.+]] = addf %arg0, %arg0 : tensor<4xf32> %add1 = addf %arg0, %arg0 : tensor<4xf32> - // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %[[ADD1]] : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%arg1] (%arg2, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst] (%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%arg1] (%[[D1]], %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst] (%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[S:.+]] = flow.ex.stream.fragment(%arg0, %[[ADD1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1[%[[WORKLOAD]]](%arg1, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%cst](%arg0, %arg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2[%[[WORKLOAD]]](%[[D1]], %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%cst](%d1, %add1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %[[D2]] : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[ADD2:.+]] = addf %[[S]], %arg0 : tensor<4xf32> @@ -86,20 +84,23 @@ flow.executable @interleavedOuterOps_ex_dispatch_0 { } // CHECK-LABEL: func @interleavedOuterOps( func @interleavedOuterOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = addf %0, %0 : tensor<4xf32> %1 = addf %0, %0 : tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%arg1] 
(%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%1) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%[[WORKLOAD2]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst] (%1) : (tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @interleavedOuterOps_ex_dispatch_0::@interleavedOuterOps_rgn_dispatch_0[%cst](%1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -116,13 +117,14 @@ flow.executable @independentOps_ex_dispatch_0 { } // CHECK-LABEL: func @independentOps( func @independentOps(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%arg1] (%arg2) - %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%arg1] (%arg2) - %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0:2 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) + %0 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-DAG: = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%[[WORKLOAD]]](%arg1) + %1 = flow.dispatch @independentOps_ex_dispatch_0::@independentOps_rgn_dispatch_1[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %{{.+}}, %{{.+}} // CHECK-NEXT: } // CHECK-NEXT: return %{{.+}}, %{{.+}} @@ -166,17 +168,18 @@ flow.executable @interleavedDot_ex_dispatch_2 { } // CHECK-LABEL: func @interleavedDot( func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index %cst = constant 16 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> = + // CHECK-NEXT: 
(%arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index + // CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst] (%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst] (%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst] (%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %0 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_rgn_dispatch_0[%cst](%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> + %1 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_rgn_dispatch_1[%cst](%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_rgn_dispatch_2[%cst](%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: return %0 : tensor<4x4xf32> return %2 : tensor<4x4xf32> } @@ -207,20 +210,23 @@ flow.executable @caller_ex_dispatch_1 { } // CHECK-LABEL: func @caller( func @caller(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%[[WORKLOAD1]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @caller_ex_dispatch_0::@caller_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>, %arg3 = %1 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %3 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%arg1] (%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %2 = flow.ex.stream.fragment(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 4 : index + // CHECK-NEXT: %3 = flow.dispatch 
@caller_ex_dispatch_1::@caller_rgn_dispatch_1[%[[WORKLOAD2]]](%arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %3 : tensor<4xf32> // CHECK-NEXT: } - %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst] (%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = flow.dispatch @caller_ex_dispatch_1::@caller_rgn_dispatch_1[%cst](%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> return %2 : tensor<4xf32> } @@ -235,62 +241,55 @@ flow.executable @callee_ex_dispatch_0 { } // CHECK-LABEL: func @callee( func @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index %cst = constant 4 : index - // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = + // CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index + // CHECK-NEXT: %1 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } - %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @callee_ex_dispatch_0::@callee_rgn_dispatch_0[%cst](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> return %0 : tensor<4xf32> } +// ----- + // CHECK-LABEL: @simple_unary -// CHECK-SAME: %[[A0:[^:[:space:]]+]]: tensor -// CHECK-SAME: %[[A1:[^:[:space:]]+]]: !shapex.ranked_shape<[?,?]> +// CHECK-SAME: %[[A0:.+]]: tensor +// CHECK-SAME: %[[A1:.+]]: !shapex.ranked_shape<[?,?]> func @simple_unary(%arg0: tensor, %arg1: !shapex.ranked_shape<[?,?]>) -> (tensor, !shapex.ranked_shape<[?,?]>) { - %0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index - %1 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index - %2 = muli %0, %1 : index - %4 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index - %5 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index - %3 = shapex.tie_shape %arg0, %arg1 : tensor, !shapex.ranked_shape<[?,?]> + // CHECK-DAG: %[[DIM0:.+]] = shapex.ranked_dim %arg1[0] + %dim0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,?]> -> index + // CHECK-DAG: %[[DIM1:.+]] = shapex.ranked_dim %arg1[1] + %dim1 = shapex.ranked_dim %arg1[1] : !shapex.ranked_shape<[?,?]> -> index + // CHECK: %[[SZ:.+]] = muli + %2 = muli %dim0, %dim1 : index // Verify that the fragment captures the tie_shapes and marshals the indices // in as loose index values (not as ranked_shape types). 
- // CHECK: %[[S:.+]] = flow.ex.stream.fragment - // CHECK-SAME: %[[STREAM_A0:[^:[:space:]]+]] = %[[A0]] : tensor, - // CHECK-SAME: %[[STREAM_A1:[^:[:space:]]+]] = %[[UNUSED0:[^:[:space:]]+]] : index, - // CHECK-SAME: %[[STREAM_A2:[^:[:space:]]+]] = %[[UNUSED1:[^:[:space:]]+]] : index, - // CHECK-SAME: { - // CHECK: %[[STREAM_RS0:.+]] = shapex.make_ranked_shape %[[STREAM_A1]], %[[STREAM_A2]] - // CHECK: %[[STREAM_R0:.+]] = shapex.tie_shape %[[STREAM_A0]], %[[STREAM_RS0]] - // CHECK: %[[STREAM_R1:.+]] = flow.dispatch @simple_unary_ex_dispatch_0 - // CHECK: %[[STREAM_R2:.+]] = shapex.tie_shape %[[STREAM_R1]], %[[STREAM_RS0]] - // CHECK: return %[[STREAM_R2]] + // CHECK: %[[S:.+]] = flow.ex.stream.fragment(%[[SZ]], %[[A0]], %[[DIM0]], %[[DIM1]]) : (index, tensor{%[[DIM0]], %[[DIM1]]}, index, index) -> tensor{%[[DIM0]], %[[DIM1]]} = + // CHECK: (%arg2: index, %arg3: tensor, %arg4: index, %arg5: index) -> tensor { + // CHECK: %[[STREAM_RET:.+]] = flow.dispatch @simple_unary_ex_dispatch_0{{.+}}[%arg2](%arg3, %arg4, %arg5) : (tensor{%arg4, %arg5}, index, index) -> tensor{%arg4, %arg5} + // CHECK: return %[[STREAM_RET]] // CHECK: } - %6 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2] (%3, %4, %5) : (tensor, index, index) -> tensor - %7 = shapex.tie_shape %6, %arg1 : tensor, !shapex.ranked_shape<[?,?]> - return %7, %arg1 : tensor, !shapex.ranked_shape<[?,?]> + %3 = flow.dispatch @simple_unary_ex_dispatch_0::@simple_unary_ex_dispatch_0[%2](%arg0, %dim0, %dim1) : (tensor{%dim0, %dim1}, index, index) -> tensor{%dim0, %dim1} + return %3, %arg1 : tensor, !shapex.ranked_shape<[?,?]> } - // ----- // CHECK-LABEL: @bad_input_ordering func @bad_input_ordering() -> (tensor, tensor) { - // CHECK: %[[W:.+]] = constant 1 : index - %workload = constant 1 : index // CHECK: %[[S:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK: flow.return - // CHECK-NEXT: } - %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload] () : () -> tensor - // CHECK: %[[C2:.+]] = constant 2 : i32 - %c2 = constant 2 : i32 + // CHECK: = constant 1 : index + // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + %workload = constant 1 : index + %0 = flow.dispatch @dispatch_1::@dispatch_1[%workload]() : () -> tensor + // CHECK: %[[C2:.+]] = constant 2 : i32 // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload] (%c2) : (i32) -> tensor + %c2 = constant 2 : i32 + %1 = flow.dispatch @dispatch_2::@dispatch_2[%workload](%c2) : (i32) -> tensor + // CHECK: flow.return return %0, %1 : tensor, tensor } @@ -298,24 +297,20 @@ func @bad_input_ordering() -> (tensor, tensor) { // CHECK-LABEL: @interstream_readback func @interstream_readback() -> (tensor, tensor, tensor<2xf32>) { - // CHECK: %[[W:.+]] = constant 1 : index %w = constant 1 : index // CHECK: %[[S1:.+]]:2 = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - // Could be returned in either order - // CHECK-NEXT: flow.return - // CHECK-NEXT: } - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] () : () -> tensor - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] () : () -> tensor + // CHECK: %[[W:.+]] = constant 1 : index + // CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 + // Could be returned in either order + 
// CHECK-NEXT: flow.return + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w]() : () -> tensor + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w]() : () -> tensor // CHECK: %[[READBACK:.+]] = flow.tensor.load %[[S1]] %readback = flow.tensor.load %d1 : tensor // CHECK: %[[S2:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { // CHECK-DAG: %[[D3:.+]] = flow.dispatch @dispatch_3::@dispatch_3 // CHECK: flow.return %[[D3]] - // CHECK-NEXT: } %d3 = flow.dispatch @dispatch_3::@dispatch_3[%w] (%readback) : (i32) -> tensor<2xf32> // CHECK: return %[[S1]]# // CHECK-SAME: %[[S1]]# @@ -325,31 +320,29 @@ func @interstream_readback() -> (tensor, tensor, tensor<2xf32>) { } // ----- + // CHECK-LABEL: @ordering func @ordering(%w : index) -> (tensor, tensor, tensor) { - // CHECK: %[[C1:.+]] = constant 1 %c1 = constant 1 : i32 // CHECK: %[[S1:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 - // CHECK-NEXT: flow.return %[[D1]] - // CHECK-NEXT: } - %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w] (%c1) : (i32) -> (tensor) + // CHECK: %[[C1:.+]] = constant 1 + // CHECK-DAG: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1 + // CHECK-NEXT: flow.return %[[D1]] + %d1 = flow.dispatch @dispatch_1::@dispatch_1[%w](%c1) : (i32) -> tensor // CHECK: %[[SE_USER:.+]] = iree.do_not_optimize(%[[S1]]) %side_effecting_user = iree.do_not_optimize(%d1) : tensor - // CHECK: %[[C2:.+]] = constant 2 %c2 = constant 2 : i32 // CHECK: %[[S2:.+]] = flow.ex.stream.fragment - // CHECK-SAME: { - // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 - // CHECK-NEXT: flow.return %[[D2]] - // CHECK-NEXT: } - %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w] (%c2) : (i32) -> (tensor) + // CHECK: %[[C2:.+]] = constant 2 + // CHECK-DAG: %[[D2:.+]] = flow.dispatch @dispatch_2::@dispatch_2 + // CHECK-NEXT: flow.return %[[D2]] + %d2 = flow.dispatch @dispatch_2::@dispatch_2[%w](%c2) : (i32) -> tensor // CHECK: return %[[S1]], %[[S2]], %[[SE_USER]] return %d1, %d2, %side_effecting_user : tensor, tensor, tensor } // ----- + // CHECK-LABEL: @metadata_only func @metadata_only(%t: tensor) -> (tensor, !shapex.ranked_shape<[?]>) { // CHECK-NOT: flow.ex.stream.fragment diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir index 1872945f6633..811e3e44d672 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/hoist_unstreamable_ops.mlir @@ -8,24 +8,24 @@ func @constants() { // CHECK-DAG: constant 4 : index // CHECK-DAG: constant 5 : index // CHECK-DAG: constant 6 : index - // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]] () : () -> tensor - // CHECK: flow.dispatch @dispatch5::@dispatch5[%[[W]]] () : () -> tensor + // CHECK: flow.dispatch @dispatch0::@dispatch0[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch1::@dispatch1[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch2::@dispatch2[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch3::@dispatch3[%[[W]]]() : () -> tensor + // CHECK: flow.dispatch @dispatch4::@dispatch4[%[[W]]]() : () -> tensor + // CHECK: 
flow.dispatch @dispatch5::@dispatch5[%[[W]]]() : () -> tensor %w = constant 1 : index - %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor) + %d0 = flow.dispatch @dispatch0::@dispatch0[%w]() : () -> tensor %c2 = constant 2 : index - %d1 = flow.dispatch @dispatch1::@dispatch1[%w] () : () -> (tensor) + %d1 = flow.dispatch @dispatch1::@dispatch1[%w]() : () -> tensor %c3 = constant 3 : index - %d2 = flow.dispatch @dispatch2::@dispatch2[%w] () : () -> (tensor) + %d2 = flow.dispatch @dispatch2::@dispatch2[%w]() : () -> tensor %c4 = constant 4 : index - %d3 = flow.dispatch @dispatch3::@dispatch3[%w] () : () -> (tensor) + %d3 = flow.dispatch @dispatch3::@dispatch3[%w]() : () -> tensor %c5 = constant 5 : index - %d4 = flow.dispatch @dispatch4::@dispatch4[%w] () : () -> (tensor) + %d4 = flow.dispatch @dispatch4::@dispatch4[%w]() : () -> tensor %c6 = constant 6 : index - %d5 = flow.dispatch @dispatch5::@dispatch5[%w] () : () -> (tensor) + %d5 = flow.dispatch @dispatch5::@dispatch5[%w]() : () -> tensor return } @@ -38,16 +38,12 @@ func @dynamic_tensor(%input: tensor, %shape: !shapex.ranked_shape<[?,?] // CHECK-DAG: %[[W:.+]] = constant 1 // CHECK-DAG: %[[DIM0:.+]] shapex.ranked_dim %[[SHAPE]][0] // CHECK-DAG: %[[DIM1:.+]] shapex.ranked_dim %[[SHAPE]][1] - // CHECK: %[[TIE0:.+]] = shapex.tie_shape %[[INPUT]], %[[SHAPE]] // CHECK: %[[D:.+]] = flow.dispatch - // CHECK: %[[TIE1:.+]] = shapex.tie_shape %[[D]], %[[SHAPE]] %w = constant 1 : index - %tie0 = shapex.tie_shape %input, %shape : tensor, !shapex.ranked_shape<[?,?]> %dim0 = shapex.ranked_dim %shape[0] : !shapex.ranked_shape<[?,?]> -> index %dim1 = shapex.ranked_dim %shape[1] : !shapex.ranked_shape<[?,?]> -> index - %d = flow.dispatch @dispatch::@dispatch[%w] (%tie0, %dim0, %dim1) : (tensor, index, index) -> tensor - %tie1 = shapex.tie_shape %d, %shape : tensor, !shapex.ranked_shape<[?,?]> - return %tie1, %shape : tensor, !shapex.ranked_shape<[?,?]> + %d = flow.dispatch @dispatch::@dispatch[%w](%input, %dim0, %dim1) : (tensor{%dim0, %dim1}, index, index) -> tensor{%dim0, %dim1} + return %d, %shape : tensor, !shapex.ranked_shape<[?,?]> } // ----- @@ -87,7 +83,7 @@ func @dependencies_with_dispatch() { %c2 = constant 2 : index %ct3 = constant dense<3> : tensor // CHECK: flow.dispatch - %d0 = flow.dispatch @dispatch0::@dispatch0[%w] () : () -> (tensor) + %d0 = flow.dispatch @dispatch0::@dispatch0[%w]() : () -> tensor // CHECK: addi %add0 = addi %d0, %ct3 : tensor return diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir index 933f3b181ae2..a54bb7fd7674 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions.mlir @@ -13,7 +13,7 @@ func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -29,7 +29,7 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // 
CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> @@ -49,7 +49,7 @@ func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %0 = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = mhlo.subtract %1, %arg1 : tensor<4xf32> @@ -73,21 +73,21 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK: %[[WORKLOAD2:.+]] = constant 16 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.add %arg1, %arg1 : tensor<4x4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD1]] : index] - // CHECK-SAME: (%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %[[R0]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %1 = "mhlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK: %[[R2:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD2]] : index] - // CHECK-SAME: (%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-SAME: (%arg1 = %[[R1]] : tensor<4x4xf32>, %arg2 = %arg0 : tensor<4x4xf32>) -> (tensor<4x4xf32>) { // CHECK-NEXT: %3 = mhlo.multiply %arg1, %arg2 : tensor<4x4xf32> %2 = mhlo.multiply %1, %arg0 : tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> @@ -103,7 +103,7 @@ func @caller(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: %2 = call @callee(%1) : (tensor<4xf32>) -> tensor<4xf32> @@ -120,7 +120,7 @@ func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.multiply %arg1, %arg1 : tensor<4xf32> %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -138,7 +138,7 @@ func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]] = flow.dispatch.region // 
CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor<f32>) -> tensor<4xf32> + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor<f32>) -> (tensor<4xf32>) // CHECK-NEXT: = "mhlo.reduce"(%arg1, %arg2) %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1 : tensor<f32>, %arg2 : tensor<f32>): diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir index 5ed4cea47ed6..ae8744462327 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-consumer-only-fusion %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-consumer-only-fusion -canonicalize %s | IreeFileCheck %s func @simpleDotAddMul (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>, diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir index 3389a5aef4e2..eb1e8e468b4a 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_hlo.mlir @@ -5,7 +5,7 @@ func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4xf32> %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -55,7 +55,7 @@ func @callee(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[R0:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = mhlo.multiply %arg1, %arg1 : tensor<4xf32> %0 = mhlo.multiply %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> @@ -87,8 +87,9 @@ func @single_reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>, %arg2 = %[[INITIAL]] : tensor<f32>) -> tensor<4xf32> - // CHECK-NEXT: = "mhlo.reduce"(%arg1, %arg2) + // CHECK-SAME: (%arg1 = %arg0 : tensor<4x8xf32>) -> (tensor<4xf32>) + // CHECK-NEXT: %[[CST_0:.+]] = constant dense<0.0 + // CHECK-NEXT: = "mhlo.reduce"(%arg1, %[[CST_0]]) %1 = "mhlo.reduce"(%arg0, %0) ( { ^bb0(%arg1 : tensor<f32>, %arg2 : tensor<f32>): %2 = mhlo.add %arg1, %arg2 : tensor<f32> @@ -110,8 +111,10 @@ func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tens // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 4 : index // CHECK: %[[RESULT:.+]]:2 = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD0]] : index] - // CHECK-SAME: (%arg2 = %arg0 : 
tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>, %arg4 = %[[INITIALA]] : tensor<f32>, %arg5 = %[[INITIALB]] : tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>) - // CHECK-NEXT: = "mhlo.reduce"(%arg2, %arg3, %arg4, %arg5) + // CHECK-SAME: (%arg2 = %arg0 : tensor<4x8xf32>, %arg3 = %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) + // CHECK-NEXT: %[[CST_0:.+]] = constant dense<0.0 + // CHECK-NEXT: %[[CST_1:.+]] = constant dense<1.0 + // CHECK-NEXT: = "mhlo.reduce"(%arg2, %arg3, %[[CST_0]], %[[CST_1]]) %2, %3 = "mhlo.reduce"(%arg0, %arg1, %0, %1) ( { ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>): %4 = mhlo.add %arg0_lhs, %arg0_rhs : tensor<f32> @@ -127,7 +130,7 @@ func @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tens // ----- -// CHECK-LABEL: @clone_broadcas +// CHECK-LABEL: @clone_broadcast func @clone_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { %splatCst = constant dense<1.0> : tensor<f32> // CHECK: flow.dispatch.region @@ -140,7 +143,7 @@ func @clone_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor< // CHECK: mhlo.add %0 = "mhlo.broadcast"(%splatCst) {broadcast_sizes = dense<[4, 4]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x4xf32> %1 = "mhlo.add"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "mhlo.dot"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = "mhlo.dot"(%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> %3 = "mhlo.add"(%0, %2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %3: tensor<4x4xf32> } diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir index ef4584275090..1290461ccfa8 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_linalg.mlir @@ -25,29 +25,27 @@ func @constant_capture(%arg0 : tensor<10x20xf32>) -> tensor<10x20xf32> { } // CHECK: func @constant_capture // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32> -// CHECK-DAG: %[[CST1:.+]] = constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST2:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> // CHECK-DAG: %[[CST3:.+]] = constant dense<[1.000000e+00, 2.000000e+00, // CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, // CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> // CHECK: %[[RESULT:.+]] = flow.dispatch.region[%{{.+}} : index]( // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST2]] -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST3]] -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] = %[[CST1]] -// CHECK-SAME: ) -> tensor<10x20xf32> { -// CHECK: %[[T0:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST3]] +// CHECK-SAME: ) -> (tensor<10x20xf32>) { +// CHECK-DAG: %[[CST1:.+]] = constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> +// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf32> // CHECK: %[[RETURN:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG3]] +// CHECK-SAME: ins(%[[ARG1]], %[[CST2]], %[[ARG2]] // CHECK-SAME: ) outs(%[[T0]] : tensor<10x20xf32>) { // CHECK-NEXT: ^{{[a-zA-Z0-9]+}}( +// CHECK-SAME: %[[ARG3:.[a-zA-Z0-9_]+]]: f32, +// CHECK-SAME: 
%[[ARG4:.[a-zA-Z0-9_]+]]: f32, // CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG7:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG8:.[a-zA-Z0-9_]+]]: f32) -// CHECK: %[[T0:.+]] = addf %[[ARG5]], %[[ARG4]] -// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG6]] -// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG7]] +// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32) +// CHECK: %[[T0:.+]] = addf %[[ARG3]], %[[CST1]] +// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG4]] +// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG5]] // CHECK: linalg.yield %[[T2]] // CHECK: } // CHECK: flow.return %[[RETURN]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir index fd7566cb537a..4212eaf00e76 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_shapes.mlir @@ -15,7 +15,7 @@ func @singleDispatchWithShapes(%arg0 : tensor<?x4xf32>, // make generic. // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[UNUSED_WORKLOAD:.+]] : index]( // CHECK-SAME: %[[CA2:.+]] = %[[A2]] : !shapex.ranked_shape<[?,4]>, - // CHECK-SAME: %[[CA0:.+]] = %[[TS0]] : tensor<?x4xf32>, + // CHECK-SAME: %[[CA0:.+]] = %{{.+}} : tensor<?x4xf32>, // CHECK-SAME: %[[CA1:.+]] = %[[A1]] : !shapex.ranked_shape<[?,4]>) // Dispatch region should contain captured tie_shapes. // CHECK: %[[R1:.+]] = shapex.tie_shape %[[CA0]], %[[CA1]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir index 36568aba73f7..b7219bd2ad3b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_std_fusion.mlir @@ -13,7 +13,7 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 // CHECK-NEXT: %[[R1:.+]] = flow.dispatch.region // CHECK-SAME: [%[[WORKLOAD]] : index] - // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { + // CHECK-SAME: (%arg1 = %arg0 : tensor<4xf32>) -> (tensor<4xf32>) { // CHECK-NEXT: %1 = addf %arg1, %arg1 : tensor<4xf32> %0 = addf %arg0, %arg0 : tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir index df5a57fce9dd..855df02f3182 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir @@ -5,8 +5,8 @@ func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK-NEXT: return %[[RET0]] return %0 : tensor<4xf32> @@ -20,13 +20,13 @@ func 
@multiDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> { %c4 = constant 4 : index // CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4] (%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> - %0 = flow.dispatch @ex::@entry0[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32> // CHECK: flow.tensor.trace {key = "ex::entry1 inputs"} %[[RET0]] : tensor<4xf32> - // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4] (%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32> - %1 = flow.dispatch @ex::@entry1[%c4] (%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4](%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32> + %1 = flow.dispatch @ex::@entry1[%c4](%0) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.tensor.trace {key = "ex::entry1 outputs"} %[[RET1]] : tensor<4xf32> // CHECK: return %[[RET1]] diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir index a58a4f7a96b2..81e0e639d0c2 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions2.mlir @@ -5,13 +5,13 @@ // CHECK-SAME: signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, // CHECK-SAME: workgroup_rank = 2 : index} // CHECK: func @staticShapeDispatch_dispatch_0( -// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.input<8x4xf32>, -// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.output<4x8xf32>) { -// CHECK-DAG: %[[ARG_VALUE:.+]] = flow.dispatch.input.load %[[ARG]] : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> -// CHECK-DAG: %[[ARG_SHAPE:.+]] = flow.dispatch.shape %[[ARG]] : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> -// CHECK-DAG: %[[RET_SHAPE:.+]] = flow.dispatch.shape %[[RET]] : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> +// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.tensor<readonly:8x4xf32>, +// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.tensor<writeonly:4x8xf32>) { +// CHECK-DAG: %[[ARG_VALUE:.+]] = flow.dispatch.tensor.load %[[ARG]] : !flow.dispatch.tensor<readonly:8x4xf32> -> tensor<8x4xf32> +// CHECK-DAG: %[[ARG_SHAPE:.+]] = flow.dispatch.shape %[[ARG]] : !flow.dispatch.tensor<readonly:8x4xf32> -> !shapex.ranked_shape<[8,4]> +// CHECK-DAG: %[[RET_SHAPE:.+]] = flow.dispatch.shape %[[RET]] : !flow.dispatch.tensor<writeonly:4x8xf32> -> !shapex.ranked_shape<[4,8]> // CHECK-NEXT: %[[RET_VALUE:.+]] = "test.sink"(%[[ARG_VALUE]], %[[ARG_SHAPE]], %[[RET_SHAPE]]) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32> -// CHECK-NEXT: flow.dispatch.output.store %[[RET_VALUE]], %[[RET]] : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> +// CHECK-NEXT: flow.dispatch.tensor.store %[[RET_VALUE]], %[[RET]] : tensor<4x8xf32> -> !flow.dispatch.tensor<writeonly:4x8xf32> // CHECK-NEXT: return // CHECK-NEXT: } @@ -24,15 +24,15 @@ func @staticShapeDispatch(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: %[[RET:.+]] = flow.dispatch @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> - %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> 
(tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> + %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> tensor<4x8xf32> = ( + %arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32> ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor<readonly:8x4xf32> -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor<readonly:8x4xf32> -> !shapex.ranked_shape<[8,4]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor<writeonly:4x8xf32> -> !shapex.ranked_shape<[4,8]> %ret_value = "test.sink"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor<writeonly:4x8xf32> flow.return } // CHECK-NEXT: return %[[RET]] @@ -62,28 +62,28 @@ func @dispatchFnMuli(%arg0 : tensor<8x4xf32>) -> tensor<8x4xf32> { %y = constant 50 : index // CHECK: %[[RET0:.+]] = flow.dispatch @dispatchFnMuli_dispatch_0::@dispatchFnMuli_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> + // CHECK-SAME: ](%[[ARG0]]) : (tensor<8x4xf32>) -> tensor<4x8xf32> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32> ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor<readonly:8x4xf32> -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor<readonly:8x4xf32> -> !shapex.ranked_shape<[8,4]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor<writeonly:4x8xf32> -> !shapex.ranked_shape<[4,8]> %ret_value = "test.sink1"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor<writeonly:4x8xf32> flow.return } // CHECK: %[[RET1:.+]] = flow.dispatch @dispatchFnMuli_dispatch_1::@dispatchFnMuli_dispatch_1[ // CHECK-SAME: %[[Y]], %[[X]] - // CHECK-SAME: ] (%[[RET0]]) : (tensor<4x8xf32>) -> tensor<8x4xf32> + // CHECK-SAME: ](%[[RET0]]) : (tensor<4x8xf32>) -> tensor<8x4xf32> %1 = flow.dispatch.workgroups[%y, %x](%0) : (tensor<4x8xf32>) -> (tensor<8x4xf32>) = ( - %arg : !flow.dispatch.input<4x8xf32>, %ret : !flow.dispatch.output<8x4xf32> + %arg: !flow.dispatch.tensor<readonly:4x8xf32>, %ret: !flow.dispatch.tensor<writeonly:8x4xf32> ) { - %arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<4x8xf32> -> tensor<8x4xf32> - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<4x8xf32> -> !shapex.ranked_shape<[4,8]> - %ret_shape = 
flow.dispatch.shape %ret : !flow.dispatch.output<8x4xf32> -> !shapex.ranked_shape<[8,4]> + %arg_value = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor<readonly:4x8xf32> -> tensor<8x4xf32> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor<readonly:4x8xf32> -> !shapex.ranked_shape<[4,8]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor<writeonly:8x4xf32> -> !shapex.ranked_shape<[8,4]> %ret_value = "test.sink2"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[4,8]>, !shapex.ranked_shape<[8,4]>) -> (tensor<8x4xf32>) - flow.dispatch.output.store %ret_value, %ret : tensor<8x4xf32> -> !flow.dispatch.output<8x4xf32> + flow.dispatch.tensor.store %ret_value, %ret : tensor<8x4xf32> -> !flow.dispatch.tensor<writeonly:8x4xf32> flow.return } // CHECK-NEXT: return %[[RET1]] @@ -100,7 +100,7 @@ func @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32> ) { flow.return } @@ -115,7 +115,7 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %y = constant 50 : index // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0 %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( - %arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> + %arg: !flow.dispatch.tensor<readonly:8x4xf32>, %ret: !flow.dispatch.tensor<writeonly:4x8xf32> ) { flow.return } @@ -129,8 +129,8 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { // CHECK-SAME: signature = (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32>, // CHECK-SAME: workgroup_rank = 2 : index} // CHECK: func @dynamicShapeDispatch_dispatch_0( -// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.input<7x?x24x?xf32>, -// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.output<?x?x1024xf32>, +// CHECK-SAME: %[[ARG:.+]]: !flow.dispatch.tensor<readonly:7x?x24x?xf32>, +// CHECK-SAME: %[[RET:.+]]: !flow.dispatch.tensor<writeonly:?x?x1024xf32>, // CHECK-SAME: %[[IN_ARG_DIM1:.+]]: index, %[[IN_ARG_DIM3:.+]]: index, %[[IN_RET_DIM0:.+]]: index, %[[IN_RET_DIM1:.+]]: index) { // CHECK: %[[IN_ARG_SHAPE:.+]] = shapex.make_ranked_shape %[[IN_ARG_DIM1]], %[[IN_ARG_DIM3]] @@ -147,10 +147,10 @@ func @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { // CHECK-DAG: %[[RET_DIM1:.+]] = shapex.ranked_dim %[[RET_SHAPE]][1] // CHECK-NEXT: "test.sink_shape_ret"(%[[RET_DIM0]], %[[RET_DIM1]]) -// CHECK: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG_SHAPED]] +// CHECK: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG_SHAPED]] // CHECK-NEXT: %[[ARG_TILE_SHAPE:.+]] = shapex.get_ranked_shape %[[ARG_TILE]] // CHECK-NEXT: %[[RET_TILE:.+]] = "test.tile_math"(%[[ARG_TILE]], %[[ARG_TILE_SHAPE]], %[[RET_SHAPE]]) -// CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET_SHAPED]] +// CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET_SHAPED]] // CHECK: return // CHECK-NEXT: } @@ -169,46 +169,40 @@ func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> // CHECK-DAG: %[[Y:.+]] = constant 512 %y = constant 512 : index // CHECK-NEXT: %[[ARG0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM1]], %[[ARG0_DIM3]] - %arg0_shape = shapex.make_ranked_shape %dim1, %dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - // CHECK-NEXT: %[[ARG0_SHAPED:.+]] = shapex.tie_shape %[[ARG0]], %[[ARG0_SHAPE]] - %arg0_shaped = shapex.tie_shape %arg0, %arg0_shape : tensor<7x?x24x?xf32>, 
!shapex.ranked_shape<[7,?,24,?]> - // CHECK-NEXT: %[[RET0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM3]], %[[ARG0_DIM1]] - %ret0_shape = shapex.make_ranked_shape %dim3, %dim1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> // CHECK-DAG: %[[IN_ARG0_DIM1:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][1] // CHECK-DAG: %[[IN_ARG0_DIM3:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][3] + // CHECK-NEXT: %[[RET0_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG0_DIM3]], %[[ARG0_DIM1]] // CHECK-DAG: %[[IN_RET0_DIM0:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][0] // CHECK-DAG: %[[IN_RET0_DIM1:.+]] = shapex.ranked_dim %[[RET0_SHAPE]][1] // CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[ // CHECK-SAME: %[[X]], %[[Y]] - // CHECK-SAME: ] (%[[ARG0_SHAPED]], %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]) - // CHECK-SAME: : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor<?x?x1024xf32> - %ret0 = flow.dispatch.workgroups[%x, %y](%arg0_shaped) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> = ( - %arg : !flow.dispatch.input<7x?x24x?xf32>, %ret : !flow.dispatch.output<?x?x1024xf32> + // CHECK-SAME: ](%arg0, %[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]], %[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]) + // CHECK-SAME: : (tensor<7x?x24x?xf32>{%[[IN_ARG0_DIM1]], %[[IN_ARG0_DIM3]]}, index, index, index, index) -> tensor<?x?x1024xf32>{%[[IN_RET0_DIM0]], %[[IN_RET0_DIM1]]} + %ret0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1} = ( + %arg: !flow.dispatch.tensor<readonly:7x?x24x?xf32>, %ret: !flow.dispatch.tensor<writeonly:?x?x1024xf32> ) { %workgroup_rank = flow.dispatch.workgroup.rank : index - %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> + %arg_shape = flow.dispatch.shape %arg : !flow.dispatch.tensor<readonly:7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> %arg_dim1 = shapex.ranked_dim %arg_shape[1] : !shapex.ranked_shape<[7,?,24,?]> -> index %arg_dim3 = shapex.ranked_dim %arg_shape[3] : !shapex.ranked_shape<[7,?,24,?]> -> index "test.sink_shape_arg"(%arg_dim1, %arg_dim3) : (index, index) -> () - %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<?x?x1024xf32> -> !shapex.ranked_shape<[?,?,1024]> + %ret_shape = flow.dispatch.shape %ret : !flow.dispatch.tensor<writeonly:?x?x1024xf32> -> !shapex.ranked_shape<[?,?,1024]> %ret_dim0 = shapex.ranked_dim %ret_shape[0] : !shapex.ranked_shape<[?,?,1024]> -> index %ret_dim1 = shapex.ranked_dim %ret_shape[1] : !shapex.ranked_shape<[?,?,1024]> -> index "test.sink_shape_ret"(%ret_dim0, %ret_dim1) : (index, index) -> () - %arg_tile = flow.dispatch.input.load %arg : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> + %arg_tile = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor<readonly:7x?x24x?xf32> -> tensor<7x?x24x?xf32> %arg_tile_shape = shapex.get_ranked_shape %arg_tile : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> %ret_tile = "test.tile_math"(%arg_tile, %arg_tile_shape, %ret_shape) : (tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> (tensor<?x?x1024xf32>) - flow.dispatch.output.store %ret_tile, %ret : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32> + flow.dispatch.tensor.store %ret_tile, %ret : tensor<?x?x1024xf32> -> !flow.dispatch.tensor<writeonly:?x?x1024xf32> flow.return } - // CHECK-NEXT: %[[RET0_SHAPED:.+]] = shapex.tie_shape %[[RET0]], %[[RET0_SHAPE]] - %ret0_shaped = shapex.tie_shape %ret0, %ret0_shape : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]> - // CHECK-NEXT: return %[[RET0_SHAPED]] - return %ret0_shaped : tensor<?x?x1024xf32> + // CHECK-NEXT: return %[[RET0]] + return %ret0 : tensor<?x?x1024xf32> } diff --git 
a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir deleted file mode 100644 index 60c843acac26..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions_ranked_dynamic.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-outline-dispatch-regions -canonicalize %s | IreeFileCheck %s -// NOTE: Most of the common cases for outlining are tested via -// transformation.mlir; however, this test performs some specific tests -// of corner cases that are easier to access at this level. - -// CHECK-LABEL: @dynamicRankedShapeModule -// Verify that the outlined function properly expands shape dims -// Note that all but the entry shape ties/ops are removed. -// CHECK: flow.executable @dynamicRankedShape_ex_dispatch_0 -// CHECK: func @dynamicRankedShape_ex_dispatch_0(%[[EXARG0:.+]]: tensor<7x?x24x?xf32>, %[[EXARG1:.+]]: index, %[[EXARG2:.+]]: index) -> tensor<?x?x1024xf32> { -// CHECK-DAG: %[[EXSHAPE0:.+]] = shapex.make_ranked_shape %[[EXARG1]], %[[EXARG2]] -// CHECK-DAG: %[[EXT0:.+]] = shapex.tie_shape %[[EXARG0]], %[[EXSHAPE0]] -// CHECK-DAG: %[[EXT1:.+]] = "some_kind_of_sum"(%[[EXT0]]) -// CHECK-DAG: return %[[EXT1]] -// Verify that the generated flow.dispatch op properly inputs individual shape dims -// CHECK: func @dynamicRankedShape(%[[ARG0:.+]]: tensor<7x?x24x?xf32>) -// CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK-DAG: %[[C3:.*]] = constant 3 : index -// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]] -// CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C3]] -// CHECK-DAG: %[[WORKLOAD0:.+]] = constant 1024 : index -// CHECK-DAG: %[[DISPATCH:.+]] = flow.dispatch @dynamicRankedShape_ex_dispatch_0::@dynamicRankedShape_ex_dispatch_0[%[[WORKLOAD0]]] (%[[ARG0]], %[[D1]], %[[D3]]) : (tensor<7x?x24x?xf32>, index, index) -// CHECK-DAG: return %[[DISPATCH]] -module @dynamicRankedShapeModule { -func @dynamicRankedShape(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> { - %c1 = constant 1 : index - %c3 = constant 3 : index - %dim1 = dim %arg0, %c1 : tensor<7x?x24x?xf32> - %dim3 = dim %arg0, %c3 : tensor<7x?x24x?xf32> - %workload0 = constant 1024 : index - %shape0 = shapex.make_ranked_shape %dim1, %dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - %1 = flow.dispatch.region[%workload0 : index](%arg1 = %arg0 : tensor<7x?x24x?xf32>, %arg2 = %shape0 : !shapex.ranked_shape<[7,?,24,?]>) -> tensor<?x?x1024xf32> { - %2 = shapex.tie_shape %arg1, %arg2 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> - // Simulate a custom op that shuffles the input in a weird way. 
- %3 = "some_kind_of_sum"(%2) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> - %4 = shapex.ranked_dim %arg2[1] : !shapex.ranked_shape<[7,?,24,?]> -> index - %5 = shapex.ranked_dim %arg2[3] : !shapex.ranked_shape<[7,?,24,?]> -> index - %6 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> - %7 = shapex.tie_shape %3, %6 : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]> - flow.return %7 : tensor<?x?x1024xf32> - } - return %1 : tensor<?x?x1024xf32> -} -} diff --git a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir b/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir deleted file mode 100644 index bfc6b7c898ac..000000000000 --- a/iree/compiler/Dialect/Flow/Transforms/test/rematerialize_dispatch_constants.mlir +++ /dev/null @@ -1,153 +0,0 @@ -// RUN: iree-opt -split-input-file -iree-flow-rematerialize-dispatch-constants %s | IreeFileCheck %s - -// CHECK-LABEL: func @rematerializeSmall -func @rematerializeSmall(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK: %[[WORKLOAD0:.+]] = constant 16 : index - %cst = constant 16 : index - %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %[[REMAT_SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> - // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SMALL]] : tensor<4x4xf32> - %3 = mhlo.add %arg1, %arg2 : tensor<4x4xf32> - flow.return %3 : tensor<4x4xf32> - } - return %0 : tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: func @rematerializeSplat -func @rematerializeSplat(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - %cst = constant 16 : index - %large = constant dense<8> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>) -> tensor<1025xi8> { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-NEXT: %[[REMAT_SPLAT:.+]] = constant dense<8> : tensor<1025xi8> - // CHECK-NEXT: %1 = mhlo.add %arg1, %[[REMAT_SPLAT]] : tensor<1025xi8> - %3 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - flow.return %3 : tensor<1025xi8> - } - return %0 : tensor<1025xi8> -} - -// ----- - -// CHECK-LABEL: func @noRematerializeLarge -func @noRematerializeLarge(%arg0 : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - // CHECK-DAG: %[[CST:.+]] = constant dense<{{.+}}> : tensor<1025xi8> - %cst = constant 16 : index - %large = constant 
dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,
229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,0]> : tensor<1025xi8> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %[[CST]] : tensor<1025xi8>) -> tensor<1025xi8> { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<1025xi8>, %arg2 = %large : tensor<1025xi8>) -> tensor<1025xi8> { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - %3 = mhlo.add %arg1, %arg2 : tensor<1025xi8> - flow.return %3 : tensor<1025xi8> - } - return %0 : tensor<1025xi8> -} - -// ----- - -// CHECK-LABEL: func @noRematerializeIntoDot -func @noRematerializeIntoDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-DAG: %[[WORKLOAD0:.+]] = constant 16 : index - // CHECK-DAG: %[[SMALL:.+]] = constant dense<1.230000e+00> : tensor<4x4xf32> - %cst = constant 16 : index - %small = constant dense<1.23> : tensor<4x4xf32> - // CHECK-NEXT: %[[R0:.+]] = flow.dispatch.region[%[[WORKLOAD0]] : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %[[SMALL]] : tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = flow.dispatch.region[%cst : index](%arg1 = %arg0 : tensor<4x4xf32>, %arg2 = %small : tensor<4x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %1 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %3 = "mhlo.dot"(%arg1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - flow.return %3 : tensor<4x4xf32> - } - return %0 : tensor<4x4xf32> -} - -// ----- - -func @constant_capture(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32> { - %c200 = constant 200 : index - %cst = constant 1.000000e+00 : f32 - %cst_0 = constant dense<2.000000e+00> : tensor<10x20xf32> - %cst_1 = constant dense< - [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, - 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> - : tensor<10xf32> - %0 = flow.dispatch.region[%c200 : index] - (%arg1 = %arg0 : tensor<10x20xf32>, %arg2 = %cst_0 : tensor<10x20xf32>, - %arg3 = %cst_1 : tensor<10xf32>, %arg4 = %cst : f32) -> tensor<10x20xf32> { - %1 = linalg.init_tensor [10, 20] : tensor<10x20xf32> - %2 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg2, %arg3 - : tensor<10x20xf32>, tensor<10x20xf32>, tensor<10xf32>) - outs(%1 : tensor<10x20xf32>) { - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32): // no predecessors - %3 = addf %arg5, %arg4 : f32 - %4 = mulf %3, %arg6 : f32 - %5 = addf %4, %arg7 : f32 - linalg.yield %5 : f32 - } -> tensor<10x20xf32> - flow.return %2 : tensor<10x20xf32> - } - return %0 : tensor<10x20xf32> -} - -// CHECK-LABEL: func @constant_capture -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<10x20xf32> -// CHECK: %[[CST:.+]] = constant dense<[1.000000e+00, 2.000000e+00, -// CHECK-SAME: 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, -// CHECK-SAME: 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01]> -// CHECK: flow.dispatch.region -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] = %[[ARG0]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST]] -// CHECK-DAG: %[[CST_0:.+]] = constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST_1:.+]] = constant dense<2.000000e+00> : tensor<10x20xf32> -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG1]], %[[CST_1]], %[[ARG2]] -// CHECK-SAME: ) { -// CHECK: ^{{[a-zA-Z0-9_]+}}( -// CHECK-SAME: 
%[[ARG3:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG4:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG5:.[a-zA-Z0-9_]+]]: f32, -// CHECK-SAME: %[[ARG6:.[a-zA-Z0-9_]+]]: f32) -// CHECK: %[[T0:.+]] = addf %[[ARG3]], %[[CST_0]] -// CHECK: %[[T1:.+]] = mulf %[[T0]], %[[ARG4]] -// CHECK: %[[T2:.+]] = addf %[[T1]], %[[ARG5]] -// CHECK: linalg.yield %[[T2]] -// CHECK: } -// CHECK: flow.return %[[RESULT]] - -// ----- - -func @rematerialize_dispatch_workgroups(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { - %cst_0 = constant 0.0 : f32 - %c2 = constant 1 : index - %0 = flow.dispatch.workgroups[%c2, %c2, %c2] (%cst_0, %arg0, %arg1) : (f32, tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> = (%arg2 : f32, %arg3 : !flow.dispatch.input<8x8xf32>, %arg4 : !flow.dispatch.input<8x8xf32>, %arg5 : !flow.dispatch.output<8x8xf32>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c8 = constant 8 : index - %1 = linalg.init_tensor [8, 8] : tensor<8x8xf32> - %2 = linalg.fill(%1, %arg2) : tensor<8x8xf32>, f32 -> tensor<8x8xf32> - %3 = flow.dispatch.input.load %arg3, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : !flow.dispatch.input<8x8xf32> -> tensor<8x8xf32> - %4 = flow.dispatch.input.load %arg4, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : !flow.dispatch.input<8x8xf32> -> tensor<8x8xf32> - %5 = linalg.matmul ins(%3, %4 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%2 : tensor<8x8xf32>) -> tensor<8x8xf32> - flow.dispatch.output.store %5, %arg5, offsets = [%c0, %c0], sizes = [%c8, %c8], strides = [%c1, %c1] : tensor<8x8xf32> -> !flow.dispatch.output<8x8xf32> - flow.return - } - return %0: tensor<8x8xf32> -} - -// CHECK: func @rematerialize_dispatch_workgroups(%[[ARG1:.+]]: tensor<8x8xf32>, %[[ARG2:.+]]: tensor<8x8xf32>) -// CHECK: %[[CONST1:.+]] = constant 1 : index -// CHECK: flow.dispatch.workgroups[%[[CONST1]], %[[CONST1]], %[[CONST1]]] (%[[ARG1]], %[[ARG2]]) -// CHECK: %[[CONST0:.+]] = constant 0.000000e+00 : f32 -// CHECK: %[[INIT_TENSOR:.+]] = linalg.init_tensor [8, 8] : tensor<8x8xf32> -// CHECK: linalg.fill(%[[INIT_TENSOR]], %[[CONST0]]) : tensor<8x8xf32>, f32 -> tensor<8x8xf32> diff --git a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir index b13d93f4103c..5067a91cf1b1 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/transformation.mlir @@ -8,31 +8,6 @@ func @empty() { // ----- -func @simpleMath(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: flow.executable @simpleMath_ex_dispatch_0 attributes {sym_visibility = "private"} { -// CHECK-NEXT: flow.dispatch.entry @simpleMath_ex_dispatch_0 -// CHECK-NEXT: module { -// CHECK-NEXT: func @simpleMath_ex_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> -// CHECK-NEXT: return %0 : tensor<4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: func @simpleMath(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @simpleMath_ex_dispatch_0::@simpleMath_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> -// CHECK-NEXT: flow.return %1 : tensor<4xf32> -// CHECK-NEXT: } 
-// CHECK-NEXT: return %0 : tensor<4xf32> -// CHECK-NEXT: } - -// ----- - func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { %0 = addf %arg0, %arg0 : tensor<4xf32> %1 = subf %0, %arg0 : tensor<4xf32> @@ -52,9 +27,10 @@ func @stdElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @stdElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @stdElementwiseOps_ex_dispatch_0::@stdElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -81,9 +57,10 @@ func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @hloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%arg1] (%arg2) : (tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @hloElementwiseOps_ex_dispatch_0::@hloElementwiseOps_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -126,11 +103,12 @@ func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 16 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %2 = flow.dispatch @interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%arg1] (%1, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%arg1] (%2, %arg2) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32> = +// CHECK-NEXT: (%arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 16 : index +// CHECK-NEXT: %1 = flow.dispatch @interleavedDot_ex_dispatch_0::@interleavedDot_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %2 = flow.dispatch 
@interleavedDot_ex_dispatch_1::@interleavedDot_ex_dispatch_1[%[[WORKLOAD]]](%1, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %3 = flow.dispatch @interleavedDot_ex_dispatch_2::@interleavedDot_ex_dispatch_2[%[[WORKLOAD]]](%2, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: flow.return %3 : tensor<4x4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4x4xf32> @@ -163,9 +141,10 @@ func @reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %[[WORKLOAD0:.+]] = constant 4 : index -// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg1 = %[[WORKLOAD0]] : index, %arg2 = %arg0 : tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%arg1] (%arg2) : (tensor<4x8xf32>) -> tensor<4xf32> +// CHECK-NEXT: %0 = flow.ex.stream.fragment(%arg0) : (tensor<4x8xf32>) -> tensor<4xf32> = +// CHECK-NEXT: (%arg1: tensor<4x8xf32>) -> tensor<4xf32> { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 4 : index +// CHECK-NEXT: %1 = flow.dispatch @reduction_ex_dispatch_0::@reduction_ex_dispatch_0[%[[WORKLOAD]]](%arg1) : (tensor<4x8xf32>) -> tensor<4xf32> // CHECK-NEXT: flow.return %1 : tensor<4xf32> // CHECK-NEXT: } // CHECK-NEXT: return %0 : tensor<4xf32> @@ -189,15 +168,17 @@ func @dynamicUpdateSlice(%operand : tensor<2x4xi32>, %update : tensor<1x1xi32>, // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: func @dynamicUpdateSlice(%arg0: tensor<2x4xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<2x4xi32> { -// CHECK-DAG: %[[WORKLOAD0:.+]] = constant 8 : index -// CHECK-DAG: %[[ARG2_LOAD:.+]] = flow.tensor.load %arg2 : tensor<i32> -// CHECK-DAG: %[[ARG2_INDEX:.+]] = index_cast %[[ARG2_LOAD]] : i32 to index -// CHECK-DAG: %[[ARG3_LOAD:.+]] = flow.tensor.load %arg3 : tensor<i32> -// CHECK-DAG: %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index -// CHECK-NEXT: %4 = flow.ex.stream.fragment(%arg4 = %arg1 : tensor<1x1xi32>, %arg5 = %arg0 : tensor<2x4xi32>, %arg6 = %[[ARG2_INDEX]] : index, %arg7 = %[[ARG3_INDEX]] : index, %arg8 = %[[WORKLOAD0]] : index) -> tensor<2x4xi32> { -// CHECK-NEXT: %5 = flow.tensor.update %arg4, %arg5[%arg6, %arg7] : tensor<1x1xi32> -> tensor<2x4xi32> -// CHECK-NEXT: %6 = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%arg8] (%arg5, %5) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK-NEXT: flow.return %6 : tensor<2x4xi32> +// CHECK-DAG: %[[ARG2_LOAD:.+]] = flow.tensor.load %arg2 : tensor<i32> +// CHECK-DAG: %[[ARG2_INDEX:.+]] = index_cast %[[ARG2_LOAD]] : i32 to index +// CHECK-DAG: %[[ARG3_LOAD:.+]] = flow.tensor.load %arg3 : tensor<i32> +// CHECK-DAG: %[[ARG3_INDEX:.+]] = index_cast %[[ARG3_LOAD]] : i32 to index +// CHECK-NEXT: %[[RET:.+]] = flow.ex.stream.fragment(%arg0, %[[ARG2_INDEX]], %[[ARG3_INDEX]], %arg1) : (tensor<2x4xi32>, index, index, tensor<1x1xi32>) -> tensor<2x4xi32> = +// CHECK-NEXT: (%arg4: tensor<2x4xi32>, %arg5: index, %arg6: index, %arg7: tensor<1x1xi32>) -> tensor<2x4xi32> { +// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 8 : index +// CHECK-NEXT: %[[ARG4_CLONE:.+]] = flow.tensor.clone %arg4 : tensor<2x4xi32> +// CHECK-NEXT: %[[T0:.+]] = flow.tensor.update %arg7, %arg4[%arg5, %arg6] : tensor<1x1xi32> -> tensor<2x4xi32> +// CHECK-NEXT: %[[T1:.+]] = flow.dispatch @dynamicUpdateSlice_ex_dispatch_0::@dynamicUpdateSlice_ex_dispatch_0[%[[WORKLOAD]]](%[[ARG4_CLONE]], %[[T0]]) : 
(tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK-NEXT: flow.return %[[T1]] : tensor<2x4xi32> // CHECK-NEXT: } -// CHECK-NEXT: return %4 : tensor<2x4xi32> +// CHECK-NEXT: return %[[RET]] : tensor<2x4xi32> // CHECK-NEXT: } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index af4c7c38901f..22ee0c2de3b9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp @@ -66,6 +66,20 @@ static bool isNoOp(Operation *op) { return isa(op); } // identity. static bool isIdentityOp(Operation *op) { return isa(op); } +// HACK: until we are doing buffer allocation we need to pin tied buffers all +// the way up the stream from the outputs. +static void propagateTiedBuffer(BufferSet &bufferSet, Value streamValue, + BufferRange bufferRange) { + Value baseValue = streamValue; + while (auto definingOp = dyn_cast_or_null<IREE::TiedOpInterface>( + baseValue.getDefiningOp())) { + auto tiedValue = definingOp.getTiedResultOperand(baseValue); + if (!tiedValue) break; + baseValue = tiedValue; + bufferSet.rangeMap[baseValue] = bufferRange; + } +} + // Allocates a buffer for the given stream output value. // |streamValue| is the Value used within the stream region and // |externalValue| is the returned value from the stream region in the parent @@ -110,17 +124,36 @@ static Value allocateOutputBuffer(Value streamValue, Value externalValue, static void allocateOutputBuffers(IREE::Flow::ExStreamFragmentOp streamOp, BufferSet &bufferSet, ConversionPatternRewriter &rewriter) { + auto tiedStreamOp = cast<IREE::TiedOpInterface>(streamOp.getOperation()); + auto &entryBlock = streamOp.body().front(); + // Allocate output buffers and replace the original uses with the buffers. auto returnOp = cast<IREE::Flow::ReturnOp>(streamOp.body().front().back()); for (auto result : llvm::enumerate(streamOp.getResults())) { auto streamValue = returnOp.getOperand(result.index()); auto externalValue = result.value(); - auto buffer = allocateOutputBuffer(streamValue, externalValue, - bufferSet.allocator, rewriter); - auto bufferRange = BufferRange{buffer}; + + // Tied results reuse their operand buffer. + BufferRange bufferRange; + auto tiedOperandIndex = + tiedStreamOp.getTiedResultOperandIndex(result.index()); + if (tiedOperandIndex.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " -- REUSING TIED OPERAND(" + << tiedOperandIndex.getValue() << ") BUFFER FOR STREAM RESULT(" + << result.index() << "): " << streamOp << "\n"); + auto operand = entryBlock.getArgument(tiedOperandIndex.getValue()); + bufferRange = bufferSet.rangeMap[operand]; + } else { + auto buffer = allocateOutputBuffer(streamValue, externalValue, + bufferSet.allocator, rewriter); + bufferRange = BufferRange{buffer}; + } + assert(bufferRange.buffer); + bufferSet.outputBuffers.push_back(bufferRange.buffer); bufferSet.rangeMap[externalValue] = bufferRange; bufferSet.rangeMap[streamValue] = bufferRange; - bufferSet.outputBuffers.push_back(buffer); + propagateTiedBuffer(bufferSet, streamValue, bufferRange); } } @@ -172,8 +205,10 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, // Pull outputs that terminate on identities to operands. 
for (auto &op : llvm::reverse(streamOp.body().front())) { if (isIdentityOp(&op)) { - auto result = op.getResult(0); auto operand = op.getOperand(0); + auto operand = op.getOperand(0); + auto result = op.getResult(0); + if (!operand.getType().isa<TensorType>()) continue; + if (!result.getType().isa<TensorType>()) continue; if (bufferSet.rangeMap[result].buffer && !bufferSet.rangeMap[operand].buffer) { LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY RESULT->OPERAND: " @@ -190,6 +225,8 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, if (isIdentityOp(&op)) { auto operand = op.getOperand(0); auto result = op.getResult(0); + if (!operand.getType().isa<TensorType>()) continue; + if (!result.getType().isa<TensorType>()) continue; if (bufferSet.rangeMap[operand].buffer && !bufferSet.rangeMap[result].buffer) { LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY OPERAND->RESULT: " @@ -217,19 +254,41 @@ static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, } for (auto &op : streamOp.body().front()) { if (isNoOp(&op) || isIdentityOp(&op)) continue; + auto tiedOp = dyn_cast<IREE::TiedOpInterface>(op); for (auto it : llvm::enumerate(op.getResults())) { auto result = it.value(); + if (!result.getType().isa<TensorType>()) continue; + // If the result is an output buffer we can just use that directly. if (bufferSet.rangeMap[result].buffer) { LLVM_DEBUG(llvm::dbgs() << " -- SKIP ALREADY SET BUFFER RESULT(" << it.index() << "): " << op << "\n"); continue; } + + // Tied results reuse their operand buffer. + if (tiedOp) { + auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(it.index()); + if (tiedOperandIndex.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " -- REUSING TIED OPERAND(" + << tiedOperandIndex.getValue() << ") BUFFER FOR RESULT(" + << it.index() << "): " << op << "\n"); + auto operand = op.getOperand(tiedOperandIndex.getValue()); + auto operandBufferRange = bufferSet.rangeMap[operand]; + assert(operandBufferRange.buffer); + bufferSet.rangeMap[result] = operandBufferRange; + continue; + } + } + LLVM_DEBUG(llvm::dbgs() << " -- ALLOCATE BUFFER FOR RESULT(" << it.index() << "): " << op << "\n"); auto buffer = allocateTransientBuffer(result, bufferSet.allocator, rewriter); - bufferSet.rangeMap[result] = BufferRange{buffer}; + auto bufferRange = BufferRange{buffer}; + bufferSet.rangeMap[result] = bufferRange; + propagateTiedBuffer(bufferSet, result, bufferRange); } } while (propagateIdentityBuffers()) { @@ -445,13 +504,44 @@ static LogicalResult recordDispatch(Value device, Value commandBuffer, return success(); } +static LogicalResult recordTensorClone(Value device, Value commandBuffer, + IREE::Flow::TensorCloneOp &cloneOp, + BufferSet &bufferSet, + ConversionPatternRewriter &rewriter) { + auto &operandBuffer = bufferSet.rangeMap[cloneOp.operand()]; + auto &resultBuffer = bufferSet.rangeMap[cloneOp.result()]; + + auto operand = IREE::HAL::TensorRewriteAdaptor::getChecked( + cloneOp.getLoc(), cloneOp.operand(), operandBuffer.buffer, rewriter); + auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( + cloneOp.getLoc(), cloneOp.result(), resultBuffer.buffer, rewriter); + if (!operand.hasValue() || !result.hasValue()) { + return cloneOp.emitOpError() + << "cannot create adaptors for tensor clone operands/results"; + } + + auto zeroOffset = + rewriter.createOrFold<mlir::ConstantIndexOp>(cloneOp.getLoc(), 0); + auto byteLength = operand->getByteLength(); + if (!byteLength) return failure(); + + rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( + cloneOp.getLoc(), commandBuffer, operand->getBuffer(), zeroOffset, + result->getBuffer(), zeroOffset, byteLength); + + // Full barriers for now as we aren't scheduling 
things. + // TODO(benvanik): don't add at the end of the command buffer (we could + // also do a canonicalization step that removed trailing barriers). + recordFullExecutionBarrier(commandBuffer, cloneOp.getLoc(), rewriter); + return success(); +} + static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, IREE::Flow::TensorUpdateOp &updateOp, BufferSet &bufferSet, ConversionPatternRewriter &rewriter) { auto &updateBuffer = bufferSet.rangeMap[updateOp.update()]; auto &targetBuffer = bufferSet.rangeMap[updateOp.target()]; - auto &resultBuffer = bufferSet.rangeMap[updateOp.result()]; // TODO(benvanik): use something other than the BufferRange::buffer? // This may require us to subview the buffer first. @@ -459,9 +549,7 @@ static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, updateOp.getLoc(), updateOp.update(), updateBuffer.buffer, rewriter); auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( updateOp.getLoc(), updateOp.target(), targetBuffer.buffer, rewriter); - auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( - updateOp.getLoc(), updateOp.result(), resultBuffer.buffer, rewriter); - if (!update.hasValue() || !target.hasValue() || !result.hasValue()) { + if (!update.hasValue() || !target.hasValue()) { return updateOp.emitOpError() << "cannot create adaptors for tensor update operands/results"; } @@ -479,18 +567,10 @@ static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, target->computeRange(startIndices, *update->getShapeDims()); if (!targetRange) return failure(); - // TODO(benvanik): actual buffer allocation so we aren't doing this copy. - auto targetByteLength = target->getByteLength(); - if (!targetByteLength) return failure(); - - rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( - updateOp.getLoc(), commandBuffer, target->getBuffer(), zeroOffset, - result->getBuffer(), zeroOffset, targetByteLength); // TODO(benvanik): slice left/mid/right, but really just don't do this. - recordFullExecutionBarrier(commandBuffer, updateOp.getLoc(), rewriter); rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( updateOp.getLoc(), commandBuffer, update->getBuffer(), zeroOffset, - result->getBuffer(), targetRange->offset, targetRange->length); + target->getBuffer(), targetRange->offset, targetRange->length); // Full barriers for now as we aren't scheduling things. // TODO(benvanik): don't add at the end of the command buffer (we could // also do a canonicalization step that removed trailing barriers). @@ -509,6 +589,11 @@ static LogicalResult recordStreamCommands(Value device, Value commandBuffer, rewriter))) { return failure(); } + } else if (auto cloneOp = dyn_cast<IREE::Flow::TensorCloneOp>(op)) { + if (failed(recordTensorClone(device, commandBuffer, cloneOp, bufferSet, + rewriter))) { + return failure(); + } } else if (auto updateOp = dyn_cast<IREE::Flow::TensorUpdateOp>(op)) { if (failed(recordTensorUpdate(device, commandBuffer, updateOp, bufferSet, rewriter))) { @@ -519,6 +604,10 @@ static LogicalResult recordStreamCommands(Value device, Value commandBuffer, } else if (isNoOp(&op) || isIdentityOp(&op)) { // No work to perform. For identity ops, all buffers have been pushed // to "real" ops. + } else if (isa(op)) { + // HACK: all this code is going away soon. 
+ auto newOp = rewriter.clone(op); + op.replaceAllUsesWith(newOp); } else { return op.emitOpError() << "unexpected in stream"; } @@ -532,8 +621,11 @@ class ExStreamFragmentOpConversion using OpConversionPattern< IREE::Flow::ExStreamFragmentOp>::OpConversionPattern; LogicalResult matchAndRewrite( - IREE::Flow::ExStreamFragmentOp streamOp, llvm::ArrayRef<Value> operands, + IREE::Flow::ExStreamFragmentOp streamOp, ArrayRef<Value> newOperands, ConversionPatternRewriter &rewriter) const override { + IREE::Flow::ExStreamFragmentOp::Adaptor adaptor( + newOperands, streamOp->getAttrDictionary()); + // TODO(benvanik): choose buffer mode/category based on stream commands. auto mode = IREE::HAL::CommandBufferModeBitfield::OneShot; auto category = IREE::HAL::CommandCategoryBitfield::Dispatch | @@ -551,13 +643,13 @@ class ExStreamFragmentOpConversion // Remap non-tensor operands (such as workloads). auto &entryBlock = streamOp.body().front(); - for (int i = 0; i < operands.size(); ++i) { + for (int i = 0; i < adaptor.operands().size(); ++i) { if (streamOp.getOperand(i).getType().isa<TensorType>()) { bufferSet.rangeMap[entryBlock.getArgument(i)] = - BufferRange{operands[i]}; + BufferRange{adaptor.operands()[i]}; } else { rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), - operands[i]); + adaptor.operands()[i]); } } @@ -591,10 +683,10 @@ class ExStreamFragmentOpConversion // It's annoying, but we need to do this replacement at the very end as // otherwise we lose access to the original values (which we need for // shape information). - for (int i = 0; i < operands.size(); ++i) { - if (operands[i].getType().isa()) { + for (int i = 0; i < adaptor.operands().size(); ++i) { + if (adaptor.operands()[i].getType().isa()) { rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), - operands[i]); + adaptor.operands()[i]); } } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp index b59988f8d44a..591a1b19d8ce 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp @@ -90,7 +90,8 @@ class TensorLoadOpConversion LogicalResult matchAndRewrite( IREE::Flow::TensorLoadOp loadOp, llvm::ArrayRef<Value> newOperands, ConversionPatternRewriter &rewriter) const override { - IREE::Flow::TensorLoadOp::Adaptor operands(newOperands); + IREE::Flow::TensorLoadOp::Adaptor operands(newOperands, + loadOp->getAttrDictionary()); auto source = IREE::HAL::TensorRewriteAdaptor::getChecked( loadOp.getLoc(), loadOp.source(), operands.source(), rewriter); if (!source.hasValue()) { @@ -117,7 +118,8 @@ class TensorStoreOpConversion LogicalResult matchAndRewrite( IREE::Flow::TensorStoreOp storeOp, llvm::ArrayRef<Value> newOperands, ConversionPatternRewriter &rewriter) const override { - IREE::Flow::TensorStoreOp::Adaptor operands(newOperands); + IREE::Flow::TensorStoreOp::Adaptor operands(newOperands, + storeOp->getAttrDictionary()); auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( storeOp.getLoc(), storeOp.target(), operands.target(), rewriter); diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD index b780b112a316..e55486242ff6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_ops.mlir", + "stream_ops.mlir", + "tensor_ops.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt index 17b455554a40..ceb1dcf6a018 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" + "stream_ops.mlir" + "tensor_ops.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index abc7d9f8e8b8..54d17cd6a8c4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir @@ -24,16 +24,17 @@ func @multipleDispatches(%arg0: tensor<128xf32>) -> tensor<128xf32> { // CHECK: %[[TMP_BUF:.+]] = hal.allocator.allocate {{.+}}, "DeviceVisible|DeviceLocal", "Transfer|Dispatch" // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> { + %0 = flow.ex.stream.fragment(%cst, %arg0) : (index, tensor<128xf32>) -> tensor<128xf32> = + (%arg1: index, %arg2: tensor<128xf32>) -> tensor<128xf32> { // CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set = %c0, bindings = [%c0 = (%arg0, %c0, %c512), %c1 = (%[[TMP_BUF]], %c0, %c512)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz // CHECK: hal.command_buffer.execution_barrier - %1 = flow.dispatch @ex0::@entry0[%arg1] (%arg2) : (tensor<128xf32>) -> tensor<128xf32> + %1 = flow.dispatch @ex0::@entry0[%arg1](%arg2) : (tensor<128xf32>) -> tensor<128xf32> // CHECK: hal.command_buffer.push_descriptor_set // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz // CHECK: hal.command_buffer.execution_barrier - %2 = flow.dispatch @ex0::@entry0[%arg1] (%1) : (tensor<128xf32>) -> tensor<128xf32> + %2 = flow.dispatch @ex0::@entry0[%arg1](%1) : (tensor<128xf32>) -> tensor<128xf32> flow.return %2 : tensor<128xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -52,13 +53,13 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate // CHECK: %[[CMD:.+]] = hal.command_buffer.create // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %0 = flow.ex.stream.fragment(%arg2 = %arg0 : tensor<1x1x10xf32>, %arg3 = %arg1 : tensor<5x1x10xf32>, %arg4 = %c4 : index, %arg5 = %c1 : index) -> tensor<5x1x10xf32> { - // TODO(laurenzo): Update these checks to be more precise. 
The regexes can - // match too much, masking issues. + %0 = flow.ex.stream.fragment(%arg0, %arg1, %c4, %c1) : (tensor<1x1x10xf32>, tensor<5x1x10xf32>, index, index) -> tensor<5x1x10xf32> = + (%arg2: tensor<1x1x10xf32>, %arg3: tensor<5x1x10xf32>, %arg4: index, %arg5: index) -> tensor<5x1x10xf32> { // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %c0, %[[RET_BUF]], %c0, %c200 // CHECK: hal.command_buffer.execution_barrier + %clone = flow.tensor.clone %arg3 : tensor<5x1x10xf32> // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[UBUF]], %c0, %[[RET_BUF]], %c204, %c40 - %1 = flow.tensor.update %arg2, %arg3[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32> + %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32> flow.return %1 : tensor<5x1x10xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -95,16 +96,12 @@ func @dispatchWithShapeTies(%arg0: tensor, %bs : index) -> tensor, %arg3 = %bs : index) -> tensor { - %1 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,128]> - %2 = shapex.tie_shape %arg2, %1 : tensor, !shapex.ranked_shape<[?,128]> - %3 = flow.dispatch @ex0::@entry0[%arg1] (%2, %arg3) : (tensor, index) -> tensor - %4 = shapex.tie_shape %3, %1 : tensor, !shapex.ranked_shape<[?,128]> - %5 = flow.dispatch @ex0::@entry0[%arg1] (%4, %arg3) : (tensor, index) -> tensor - %6 = shapex.tie_shape %5, %1 : tensor, !shapex.ranked_shape<[?,128]> - %7 = flow.dispatch @ex0::@entry0[%arg1] (%6, %arg3) : (tensor, index) -> tensor - %8 = shapex.tie_shape %7, %1 : tensor, !shapex.ranked_shape<[?,128]> - flow.return %8 : tensor + %0 = flow.ex.stream.fragment(%cst, %arg0, %bs) : (index, tensor{%cst}, index) -> tensor{%cst} = + (%arg1: index, %arg2: tensor, %arg3: index) -> tensor { + %3 = flow.dispatch @ex0::@entry0[%arg1](%arg2, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + %5 = flow.dispatch @ex0::@entry0[%arg1](%3, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + %7 = flow.dispatch @ex0::@entry0[%arg1](%5, %arg3) : (tensor{%arg3}, index) -> tensor{%arg3} + flow.return %7 : tensor } return %0 : tensor } @@ -120,7 +117,7 @@ hal.executable @ex attributes {sym_visibility = "private"} { hal.executable.entry_point @entry attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<7x4x24xf32>, !flow.dispatch.output<4x7x1024xf32>) -> () + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () } module {} } @@ -132,14 +129,11 @@ func @static_tiled_dispatch(%arg0: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> { %c512 = constant 512 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %1 = flow.ex.stream.fragment( - %arg3 = %arg0 : tensor<7x4x24xf32>, - %arg6 = %c1024 : index, - %arg7 = %c512 : index - ) -> tensor<4x7x1024xf32> { + %1 = flow.ex.stream.fragment(%arg0, %c1024, %c512) : (tensor<7x4x24xf32>, index, index) -> tensor<4x7x1024xf32> = + (%arg3: tensor<7x4x24xf32>, %arg6: index, %arg7: index) -> tensor<4x7x1024xf32> { // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %c2688), %c1 = (%buffer, %c0, %c114688)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz - %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> + %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3) : (tensor<7x4x24xf32>) -> 
tensor<4x7x1024xf32> flow.return %0 : tensor<4x7x1024xf32> } // CHECK: hal.command_buffer.end %[[CMD]] @@ -157,7 +151,7 @@ hal.executable @ex attributes {sym_visibility = "private"} { hal.executable.entry_point @entry attributes { interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output, index, index, index, index) -> () + signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () } module {} } @@ -169,16 +163,8 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i %c512 = constant 512 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %2 = flow.ex.stream.fragment( - %arg3 = %arg0 : tensor<7x?x24x?xf32>, - %arg4 = %arg1 : index, - %arg5 = %arg2 : index, - %arg6 = %c1024 : index, - %arg7 = %c512 : index - ) -> tensor { - %3 = shapex.make_ranked_shape %arg4, %arg5 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> - %4 = shapex.make_ranked_shape %arg5, %arg4 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> - %5 = shapex.tie_shape %arg3, %3 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> + %2 = flow.ex.stream.fragment(%arg0, %arg1, %arg2, %c1024, %c512) : (tensor<7x?x24x?xf32>{%arg1, %arg2}, index, index, index, index) -> tensor{%arg2, %arg1} = + (%arg3: tensor<7x?x24x?xf32>, %arg4: index, %arg5: index, %arg6: index, %arg7: index) -> tensor { // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32 // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %9), %c1 = (%buffer, %c0, %12)] @@ -199,9 +185,8 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i // CHECK: hal.command_buffer.dispatch.symbol %[[CMD_INNER]], @ex::@tgt::@entry, workgroup_xyz = // CHECK-SAME: [%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]] - %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%5, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor - %7 = shapex.tie_shape %6, %4 : tensor, !shapex.ranked_shape<[?,?,1024]> - flow.return %7 : tensor + %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>{%arg4, %arg5}, index, index, index, index) -> tensor{%arg5, %arg4} + flow.return %6 : tensor } // CHECK: hal.command_buffer.end %[[CMD]] return %2 : tensor diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD index b780b112a316..f63519c37096 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["constant_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt index 910ac9537da0..063d8b0269a4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD index b780b112a316..94f1dc8bd15e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,20 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "allocator_ops.mlir", + "buffer_ops.mlir", + "buffer_view_ops.mlir", + "command_buffer_ops.mlir", + "constant_ops.mlir", + "control_flow_ops.mlir", + "device_ops.mlir", + "executable_ops.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt index 7844d24c0128..fa16efe8eca9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/CMakeLists.txt @@ -10,12 +10,19 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "allocator_ops.mlir" + "buffer_ops.mlir" + "buffer_view_ops.mlir" + "command_buffer_ops.mlir" + "constant_ops.mlir" + "control_flow_ops.mlir" + "device_ops.mlir" + "executable_ops.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD index b780b112a316..a9b6fa17b0c6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["shape_constants.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt index cee582d2149d..a6025156f003 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "shape_constants.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD index b780b112a316..51722bc38906 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["structural_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt index ed79fde93dc2..aec20652b290 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD index d759f0e3ea2e..37671708432b 100644 --- a/iree/compiler/Dialect/HAL/IR/BUILD +++ b/iree/compiler/Dialect/HAL/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["HALBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "HALBase.td", + "HALOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 06072e3697e2..d548abe427ff 100644 --- a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git 
a/iree/compiler/Dialect/HAL/IR/test/BUILD b/iree/compiler/Dialect/HAL/IR/test/BUILD index b780b112a316..a4e7c0ba2d58 100644 --- a/iree/compiler/Dialect/HAL/IR/test/BUILD +++ b/iree/compiler/Dialect/HAL/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,29 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "allocator_ops.mlir", + "attributes.mlir", + "buffer_folding.mlir", + "buffer_ops.mlir", + "buffer_view_folding.mlir", + "buffer_view_ops.mlir", + "command_buffer_folding.mlir", + "command_buffer_ops.mlir", + "constant_folding.mlir", + "constant_ops.mlir", + "descriptor_set_ops.mlir", + "device_ops.mlir", + "executable_ops.mlir", + "experimental_ops.mlir", + "interface_ops.mlir", + "semaphore_ops.mlir", + "variable_folding.mlir", + "variable_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt index 578bd7d49161..172247e7f333 100644 --- a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt @@ -10,12 +10,28 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "allocator_ops.mlir" + "attributes.mlir" + "buffer_folding.mlir" + "buffer_ops.mlir" + "buffer_view_folding.mlir" + "buffer_view_ops.mlir" + "command_buffer_folding.mlir" + "command_buffer_ops.mlir" + "constant_folding.mlir" + "constant_ops.mlir" + "descriptor_set_ops.mlir" + "device_ops.mlir" + "executable_ops.mlir" + "experimental_ops.mlir" + "interface_ops.mlir" + "semaphore_ops.mlir" + "variable_folding.mlir" + "variable_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD index 326145d04b6c..793f1d4e084f 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["smoketest.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt index c461923966d9..229736f77813 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir index ceb0560dd88d..26c079652628 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir @@ -1,33 +1,24 @@ // RUN: iree-opt -split-input-file -iree-hal-transformation-pipeline -iree-hal-target-backends=cuda %s | IreeFileCheck %s - #map = affine_map<(d0) -> (d0)> -module { - flow.executable @add_dispatch_0 attributes {sym_visibility = "private"} { - flow.dispatch.entry @add_dispatch_0 attributes {signature = (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>, workgroup_rank = 3 : index} - module { - func @add_dispatch_0(%arg0: !flow.dispatch.input<16xf32>, %arg1: !flow.dispatch.input<16xf32>, %arg2: !flow.dispatch.output<16xf32>) { - %0 = linalg.init_tensor [16] : tensor<16xf32> - %1 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %2 = flow.dispatch.input.load %arg1 : !flow.dispatch.input<16xf32> -> tensor<16xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors - %4 = addf %arg3, %arg4 : f32 - linalg.yield %4 : f32 - } -> tensor<16xf32> - flow.dispatch.output.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.output<16xf32> - return - } - } +flow.executable @add_dispatch_0 { + flow.dispatch.entry @add_dispatch_0 attributes { + signature = (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>, + workgroup_rank = 3 : index } - func @add(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> attributes {iree.module.export, iree.reflection = {f = "I13!B4!d16B4!d16R7!B4!d16", fv = "1"}} { - %c1 = constant 1 : index - %c16 = constant 16 : index - %0 = flow.ex.stream.fragment(%arg2 = %c16 : index, %arg3 = %c1 : index, %arg4 = %arg0 : tensor<16xf32>, %arg5 = %arg1 : tensor<16xf32>) -> tensor<16xf32> { - %1 = flow.dispatch @add_dispatch_0::@add_dispatch_0[%arg2, %arg3, %arg3] (%arg4, %arg5) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> - flow.return %1 : tensor<16xf32> + module { + func @add_dispatch_0(%arg0: !flow.dispatch.tensor, %arg1: !flow.dispatch.tensor, %arg2: !flow.dispatch.tensor) { + %0 = linalg.init_tensor [16] : tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0 : !flow.dispatch.tensor -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1 : !flow.dispatch.tensor -> tensor<16xf32> + %3 = 
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %4 = addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.tensor.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.tensor + return } - return %0 : tensor<16xf32> } } diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD index bb9c0e0808d7..8d49ad95b14e 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "binary_op.mlir", + "matmul_op.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt index 766c6e249488..c43c1cdf01f5 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "binary_op.mlir" + "matmul_op.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD index bb9c0e0808d7..fa884d612e5a 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "i1_types.mlir", + "linking.mlir", + "smoketest.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt index 8aba5262972f..4a238ee67584 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "i1_types.mlir" + "linking.mlir" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir index df5ed384dd7f..16e8a4583354 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir @@ -5,8 +5,9 @@ func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> { %c4 = constant 4 : index // CHECK: %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 0, 1, 0]> : tensor<4xi8> %cst = constant dense<[true, false, true, false]> : tensor<4xi1> - %0 = flow.ex.stream.fragment(%arg1 = %c4 : index, %arg2 = %arg0 : tensor<4xi1>, %arg3 = %cst : tensor<4xi1>) -> tensor<4xi1> { - %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1] (%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = flow.ex.stream.fragment(%c4, %arg0, %cst) : (index, tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>) = + (%arg1: index, %arg2: tensor<4xi1>, %arg3: tensor<4xi1>) -> (tensor<4xi1>) { + %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1](%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>) flow.return %1 : tensor<4xi1> } return %0 : tensor<4xi1> diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD index bb9c0e0808d7..0e3700209bf6 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["smoketest.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt index 29cc1273de8f..c2c0cf180b0c 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 404969b11c06..69d4da0b2a61 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -70,29 +70,43 @@ static llvm::Optional declareInterfaceIO( // NOTE: we assume right now that all entry points have the same signature. // TODO(benvanik): replace when we have descriptor sets in the HAL IR. auto anyFuncOp = entryFuncOps.front(); - int binding = 0; + int nextBindingOrdinal = 0; int pushConstantCount = 0; - int argOrdinal = 0; - int retOrdinal = 0; for (auto inputType : llvm::enumerate(anyFuncOp.getType().getInputs())) { if (inputType.value().isa()) { + int bindingOrdinal = nextBindingOrdinal++; auto bindingName = "arg" + std::to_string(inputType.index()); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, - IREE::HAL::DescriptorType::StorageBuffer, - IREE::HAL::MemoryAccessBitfield::Read); - } else if (inputType.value().isa()) { - auto bindingName = "arg" + std::to_string(argOrdinal++); - interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::Read); - } else if (inputType.value().isa()) { - auto bindingName = "ret" + std::to_string(retOrdinal++); + } else if (auto tensorType = + inputType.value() + .dyn_cast()) { + StringRef prefix; + IREE::HAL::MemoryAccessBitfield memoryAccess = + IREE::HAL::MemoryAccessBitfield::None; + switch (tensorType.getAccess()) { + case IREE::Flow::TensorAccess::ReadOnly: + prefix = "ro"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::Read; + break; + case IREE::Flow::TensorAccess::ReadWrite: + prefix = "rw"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::Read | + IREE::HAL::MemoryAccessBitfield::Write; + break; + case IREE::Flow::TensorAccess::WriteOnly: + prefix = "wo"; + memoryAccess = IREE::HAL::MemoryAccessBitfield::DiscardWrite; + break; + } + int bindingOrdinal = nextBindingOrdinal++; + std::string bindingName = + std::string(prefix) + std::to_string(bindingOrdinal); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, - IREE::HAL::DescriptorType::StorageBuffer, - IREE::HAL::MemoryAccessBitfield::DiscardWrite); + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, + 
IREE::HAL::DescriptorType::StorageBuffer, memoryAccess); } else if (auto indexType = inputType.value().dyn_cast()) { ++pushConstantCount; } else if (auto integerType = inputType.value().dyn_cast()) { @@ -113,10 +127,11 @@ static llvm::Optional declareInterfaceIO( } } for (auto outputType : llvm::enumerate(anyFuncOp.getType().getResults())) { + int bindingOrdinal = nextBindingOrdinal++; auto bindingName = "ret" + std::to_string(outputType.index()); if (outputType.value().isa()) { interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/binding++, + interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::DiscardWrite); } else { diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 29f64a461a0f..a0dc69f0ce26 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -137,6 +137,9 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // if we serialized things. passManager.addPass(createSymbolDCEPass()); } + + // Final cleanup of IR; cleans up things left behind by CSE/DCE above. + passManager.addNestedPass(createCanonicalizerPass()); } void buildHALTransformPassPipeline(OpPassManager &passManager, diff --git a/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/iree/compiler/Dialect/HAL/Transforms/test/BUILD index b780b112a316..f902b2514044 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/BUILD +++ b/iree/compiler/Dialect/HAL/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,23 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "benchmark_batch_dispatches.mlir", + "cse_variable_loads.mlir", + "identify_constant_pools.mlir", + "inline_device_switches.mlir", + "materialize_constant_pool_buffers.mlir", + "materialize_interfaces.mlir", + "materialize_resource_caches.mlir", + "memoize_device_queries.mlir", + "pack_constant_pool_storage.mlir", + "propagate_constant_workgroup_info.mlir", + "public_abi_generation.mlir", + "resolve_entry_point_ordinals.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index b8831ee32130..676ac1c612a1 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -10,12 +10,22 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "benchmark_batch_dispatches.mlir" + "cse_variable_loads.mlir" + "identify_constant_pools.mlir" + "inline_device_switches.mlir" + "materialize_constant_pool_buffers.mlir" + "materialize_interfaces.mlir" + "materialize_resource_caches.mlir" + "memoize_device_queries.mlir" + "pack_constant_pool_storage.mlir" + "propagate_constant_workgroup_info.mlir" + "public_abi_generation.mlir" + "resolve_entry_point_ordinals.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir 
b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir index 1f98f6e700a6..b90afde7bafc 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir @@ -112,15 +112,15 @@ flow.executable @shaped_dispatch { // CHECK-LABEL: hal.executable @static_tiled_dispatch // CHECK-NEXT: hal.interface @legacy_io { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } flow.executable @static_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, @@ -129,17 +129,17 @@ flow.executable @static_tiled_dispatch { // CHECK-NEXT: module { module { // CHECK-NEXT: func @entry() { - func @entry(%arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32>) { + func @entry(%arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor) { // CHECK-NEXT: %c0 = constant 0 : index - // CHECK-NEXT: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<8x4xf32> - // CHECK-NEXT: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<4x8xf32> + // CHECK-NEXT: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : !flow.dispatch.tensor + // CHECK-NEXT: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@wo1[%c0] : !flow.dispatch.tensor - // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG]] - %arg_tile = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> + // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG]] + %arg_tile = flow.dispatch.tensor.load %arg : !flow.dispatch.tensor -> tensor<8x4xf32> // CHECK-NEXT: %[[RET_TILE:.+]] = "test.sink"(%[[ARG_TILE]]) %ret_tile = "test.sink"(%arg_tile) : (tensor<8x4xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET]] - flow.dispatch.output.store %ret_tile, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> + // CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET]] + flow.dispatch.tensor.store %ret_tile, %ret : tensor<4x8xf32> -> !flow.dispatch.tensor return } } @@ -149,15 +149,15 @@ flow.executable @static_tiled_dispatch { // CHECK-LABEL: hal.executable @dynamic_tiled_dispatch // CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 4 : i32} { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // 
CHECK-NEXT: } flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output, index, index, index, index) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<7x?x24x?xf32>) -> tensor, @@ -168,10 +168,10 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: func @entry() { func @entry( // CHECK-NEXT: %c0 = constant 0 : index - // CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<7x?x24x?xf32> - %arg: !flow.dispatch.input<7x?x24x?xf32>, - // CHECK-DAG: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output - %ret: !flow.dispatch.output, + // CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : !flow.dispatch.tensor + %arg: !flow.dispatch.tensor, + // CHECK-DAG: %[[RET:.+]] = hal.interface.binding.subspan @legacy_io::@wo1[%c0] : !flow.dispatch.tensor + %ret: !flow.dispatch.tensor, // CHECK-DAG: %[[ARG_DIM1:.+]] = hal.interface.load.constant offset = 0 : index %arg_dim1: index, // CHECK-DAG: %[[ARG_DIM3:.+]] = hal.interface.load.constant offset = 1 : index @@ -184,17 +184,17 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: %[[ARG_SHAPE:.+]] = shapex.make_ranked_shape %[[ARG_DIM1]], %[[ARG_DIM3]] %arg_shape = shapex.make_ranked_shape %arg_dim1, %arg_dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> // CHECK-NEXT: %[[ARG_SHAPED:.+]] = flow.dispatch.tie_shape %[[ARG]], %[[ARG_SHAPE]] - %arg_shaped = flow.dispatch.tie_shape %arg, %arg_shape : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32> + %arg_shaped = flow.dispatch.tie_shape %arg, %arg_shape : (!flow.dispatch.tensor, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.tensor // CHECK-NEXT: %[[RET_SHAPE:.+]] = shapex.make_ranked_shape %[[RET_DIM0]], %[[RET_DIM1]] %ret_shape = shapex.make_ranked_shape %ret_dim0, %ret_dim1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> // CHECK-NEXT: %[[RET_SHAPED:.+]] = flow.dispatch.tie_shape %[[RET]], %[[RET_SHAPE]] - %ret_shaped = flow.dispatch.tie_shape %ret, %ret_shape : (!flow.dispatch.output, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output - // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.input.load %[[ARG_SHAPED]] - %arg_tile = flow.dispatch.input.load %arg_shaped : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> + %ret_shaped = flow.dispatch.tie_shape %ret, %ret_shape : (!flow.dispatch.tensor, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.tensor + // CHECK-NEXT: %[[ARG_TILE:.+]] = flow.dispatch.tensor.load %[[ARG_SHAPED]] + %arg_tile = flow.dispatch.tensor.load %arg_shaped : !flow.dispatch.tensor -> tensor<7x?x24x?xf32> // CHECK-NEXT: %[[RET_TILE:.+]] = "test.tile_math"(%[[ARG_TILE]]) %ret_tile = "test.tile_math"(%arg_tile) : (tensor<7x?x24x?xf32>) -> tensor - // CHECK-NEXT: flow.dispatch.output.store %[[RET_TILE]], %[[RET_SHAPED]] - flow.dispatch.output.store %ret_tile, %ret_shaped : tensor -> !flow.dispatch.output + // CHECK-NEXT: flow.dispatch.tensor.store %[[RET_TILE]], %[[RET_SHAPED]] + flow.dispatch.tensor.store %ret_tile, %ret_shaped : tensor -> !flow.dispatch.tensor 
return } } @@ -204,15 +204,15 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-LABEL: hal.executable @workgroup_infos // CHECK-NEXT: hal.interface @legacy_io { -// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" -// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +// CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" +// CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } flow.executable @workgroup_infos { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, // CHECK-SAME: ordinal = 0 : i32, - // CHECK-SAME: signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> () + // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, @@ -221,7 +221,7 @@ flow.executable @workgroup_infos { // CHECK-NEXT: module { module { // CHECK-NEXT: func @entry() { - func @entry(%arg: !flow.dispatch.input<8x4xf32>, %ret: !flow.dispatch.output<4x8xf32>) { + func @entry(%arg: !flow.dispatch.tensor, %ret: !flow.dispatch.tensor) { // CHECK-DAG: %[[WORKGROUP_ID_X:.+]] = hal.interface.workgroup.id[0] : index %id_x = flow.dispatch.workgroup.id[0] : index // CHECK-DAG: %[[WORKGROUP_ID_Y:.+]] = hal.interface.workgroup.id[1] : index diff --git a/iree/compiler/Dialect/IREE/Conversion/test/BUILD b/iree/compiler/Dialect/IREE/Conversion/test/BUILD index b780b112a316..21bed501e8cf 100644 --- a/iree/compiler/Dialect/IREE/Conversion/test/BUILD +++ b/iree/compiler/Dialect/IREE/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["convert_flow_to_hal.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt index 203b15530bf5..5dc210ff6429 100644 --- a/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_flow_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/IREE/IR/BUILD b/iree/compiler/Dialect/IREE/IR/BUILD index e04be16fe782..455f4613bfae 100644 --- a/iree/compiler/Dialect/IREE/IR/BUILD +++ b/iree/compiler/Dialect/IREE/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,25 +26,35 @@ exports_files(["IREEBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "IREEBase.td", + "IREEInterfaces.td", + "IREEOps.td", + ], + include = ["*.td"], + ), ) cc_library( name = "IR", srcs = [ "IREEDialect.cpp", + "IREEOpInterfaces.cpp.inc", "IREEOps.cpp", "IREEOps.cpp.inc", "IREETypes.cpp", ], hdrs = [ "IREEDialect.h", + "IREEOpInterfaces.h.inc", "IREEOps.h", "IREEOps.h.inc", "IREETraits.h", "IREETypes.h", ], deps = [ + ":IREEInterfacesGen", ":IREEOpsGen", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -54,6 +65,20 @@ cc_library( ], ) +gentbl( + name = "IREEInterfacesGen", + tbl_outs = [ + ("-gen-op-interface-decls", "IREEOpInterfaces.h.inc"), + ("-gen-op-interface-defs", "IREEOpInterfaces.cpp.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "IREEInterfaces.td", + td_srcs = [ + ":td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + gentbl( name = "IREEOpsGen", tbl_outs = [ diff --git a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt index 14f21a4c7294..772034ba6054 100644 --- a/iree/compiler/Dialect/IREE/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/CMakeLists.txt @@ -10,18 +10,19 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR HDRS "IREEDialect.h" + "IREEOpInterfaces.h.inc" "IREEOps.h" "IREEOps.h.inc" "IREETraits.h" "IREETypes.h" SRCS "IREEDialect.cpp" + "IREEOpInterfaces.cpp.inc" "IREEOps.cpp" "IREEOps.cpp.inc" "IREETypes.cpp" @@ -35,6 +36,16 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + IREEInterfacesGen + TD_FILE + "IREEInterfaces.td" + OUTS + -gen-op-interface-decls IREEOpInterfaces.h.inc + -gen-op-interface-defs IREEOpInterfaces.cpp.inc +) + iree_tablegen_library( NAME IREEOpsGen diff --git a/iree/compiler/Dialect/IREE/IR/IREEBase.td b/iree/compiler/Dialect/IREE/IR/IREEBase.td index 0d2b1bbba123..c4906dab4fa4 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEBase.td +++ 
b/iree/compiler/Dialect/IREE/IR/IREEBase.td
@@ -53,6 +53,11 @@ class IREE_IndexAttrBase :
 }
 def IREE_IndexAttr : IREE_IndexAttrBase<"size_t">;

+def IREE_TiedOpStorageAttr :
+    TypedArrayAttrBase<IREE_IndexAttr, "64-bit integer array attribute"> {
+  let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
+}
+
 //===----------------------------------------------------------------------===//
 // Status codes
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td b/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td
new file mode 100644
index 000000000000..2e2a65e3df79
--- /dev/null
+++ b/iree/compiler/Dialect/IREE/IR/IREEInterfaces.td
@@ -0,0 +1,166 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_DIALECT_IREE_INTERFACES
+#define IREE_DIALECT_IREE_INTERFACES
+
+include "iree/compiler/Dialect/IREE/IR/IREEBase.td"
+
+//===----------------------------------------------------------------------===//
+// IREE::TiedOpInterface
+//===----------------------------------------------------------------------===//
+
+def IREE_TiedOpInterface : OpInterface<"TiedOpInterface"> {
+  let description = [{
+    An operation that "ties" one or more results to its operands indicating
+    that the result is directly related to the operand in an
+    operation-defined way. Results are still SSA values distinct from the
+    operands and the tie is strictly a relationship relevant to
+    transformations and not something that modifies IR definitions.
+
+    Example:
+    An operation on tensors that wants to indicate that the storage for a
+    result should alias the storage for an operand, performing an "in-place"
+    operation. Since tensors are still used there is no hard requirement that
+    uses of the result SSA value alias the operand; a copy may still be
+    introduced.
+
+    See: flow.dispatch.workgroups
+
+    Example:
+    An operation on buffers that wants to encode activity on the buffer in IR
+    (such as a barrier, a transfer operation, etc) such that the SSA use-def
+    chain is representing the state of the buffer at various points in time
+    but that the underlying buffers are all tied together.
+
+    See: hal.stream.barrier
+
+    The default implementations use an attribute on the op to store the
+    relationship:
+      `OptionalAttr<IREE_TiedOpStorageAttr>:$tied_operands`
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the set of operands that results may be tied to as an
+        (index, length) pair ala getODSOperandIndexAndLength.
+
+        By default assumes all operands may be tied. If an op treats some
+        operands as special then the op can override this and specify only
+        the ones it will tie. For example, a cond_branch that has a condition
+        operand as well as the successor operands would return only the range
+        of successor operands.
+      }],
+      /*retTy=*/"std::pair<unsigned, unsigned>",
+      /*methodName=*/"getTiedOperandsIndexAndLength",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return {0, $_op.getNumOperands()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Walks up the SSA use-def chain to find the first defined value
+        reachable from the given value by traversing tied ops. The returned
+        value may be in another block if that block dominates the one the
+        result is defined in.

+        Note that the returned value may be a block argument and have no
+        defining op, and the search will not continue past branches.
+        If the result is untied then the result itself is returned.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getTiedResult",
+      /*args=*/(ins "unsigned":$resultIndex),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return IREE::TiedOpInterface::findTiedBaseValue(
+            $_op.getResult(resultIndex));
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the operand tied to the given result of the op or nullptr if
+        none.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getTiedResultOperand",
+      /*args=*/(ins "Value":$result),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        auto resultIndex = result.cast<OpResult>().getResultNumber();
+        auto operandIndex =
+            IREE::detail::getTiedResultOperandIndex($_op, resultIndex);
+        return operandIndex.hasValue() ?
+            $_op.getOperand(operandIndex.getValue()) :
+            nullptr;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the operand index tied to the given result index, if any.
+      }],
+      /*retTy=*/"::llvm::Optional<unsigned>",
+      /*methodName=*/"getTiedResultOperandIndex",
+      /*args=*/(ins "unsigned":$resultIndex),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return IREE::detail::getTiedResultOperandIndex($_op, resultIndex);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Sets the operand index tied to the given result index, if any.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setTiedResultOperandIndex",
+      /*args=*/(ins "unsigned":$resultIndex,
+                    "::llvm::Optional<unsigned>":$operandIndex),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return IREE::detail::setTiedResultOperandIndex($_op, resultIndex,
+                                                       operandIndex);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns an array containing the tied result operand indices with -1
+        indicating that a result is not tied.
+      }],
+      /*retTy=*/"SmallVector<int64_t, 4>",
+      /*methodName=*/"getTiedResultOperandIndices",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return IREE::detail::getTiedResultOperandIndices($_op);
+      }]
+    >,
+  ];
+
+  let extraClassDeclaration = [{
+    static StringRef getStorageAttrName() { return "tied_operands"; }
+
+    // Indicates that a result is not tied to any operand.
+    static constexpr int64_t kUntiedIndex = -1;
+
+    // Walks the SSA use-def chain to find the first defined value reachable
+    // from the given value by traversing tied ops. Note that the returned
+    // value may be a block argument and have no defining op.
+    static Value findTiedBaseValue(Value derivedValue);
+  }];
+
+  let verify = [{
+    return IREE::detail::verifyTiedOp($_op);
+  }];
+}
+
+#endif  // IREE_DIALECT_IREE_INTERFACES
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp
index 745697cac353..c278c5370eb6 100644
--- a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp
+++ b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/TypeSupport.h"
+#include "mlir/Interfaces/CastInterfaces.h"

 namespace mlir {
 namespace iree_compiler {
@@ -103,6 +104,124 @@ PtrType PtrType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 Type PtrType::getTargetType() { return getImpl()->targetType; }

+//===----------------------------------------------------------------------===//
+// TiedOpInterface
+//===----------------------------------------------------------------------===//
+
+llvm::Optional<unsigned> detail::getTiedResultOperandIndex(
+    Operation *op, unsigned resultIndex) {
+  auto storageAttr =
+      op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
+  if (!storageAttr) return llvm::None;
+  auto valueAttrs = storageAttr.getValue();
+  if (valueAttrs.empty()) return llvm::None;
+  int64_t value = valueAttrs[resultIndex].cast<IntegerAttr>().getInt();
+  if (value == TiedOpInterface::kUntiedIndex) return llvm::None;
+  auto tiedOp = cast<TiedOpInterface>(op);
+  unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+  return tiedOperandsOffset + static_cast<unsigned>(value);
+}
+
+void detail::setTiedResultOperandIndex(Operation *op, unsigned resultIndex,
+                                       llvm::Optional<unsigned> operandIndex) {
+  auto indices = getTiedResultOperandIndices(op);
+  if (indices.empty()) {
+    indices.resize(op->getNumResults(), TiedOpInterface::kUntiedIndex);
+  }
+  indices[resultIndex] = operandIndex.hasValue()
+                             ? operandIndex.getValue()
+                             : TiedOpInterface::kUntiedIndex;
+  auto indexType = IndexType::get(op->getContext());
+  op->setAttr(TiedOpInterface::getStorageAttrName(),
+              ArrayAttr::get(op->getContext(),
+                             llvm::to_vector<8>(llvm::map_range(
+                                 indices, [&](int64_t v) -> Attribute {
+                                   return IntegerAttr::get(indexType, v);
+                                 }))));
+}
+
+SmallVector<int64_t, 4> detail::getTiedResultOperandIndices(Operation *op) {
+  SmallVector<int64_t, 4> indices;
+  auto storageAttr =
+      op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
+  if (!storageAttr) return indices;
+  auto valueAttrs = storageAttr.getValue();
+  if (valueAttrs.empty()) return indices;
+  auto tiedOp = cast<TiedOpInterface>(op);
+  unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+  indices.resize(op->getNumResults());
+  for (unsigned i = 0; i < valueAttrs.size(); ++i) {
+    int64_t index = valueAttrs[i].cast<IntegerAttr>().getInt();
+    indices[i] = index != TiedOpInterface::kUntiedIndex
+                     ? tiedOperandsOffset + index
+                     : TiedOpInterface::kUntiedIndex;
+  }
+  return indices;
+}
+
+Value TiedOpInterface::findTiedBaseValue(Value derivedValue) {
+  Value baseValue = derivedValue;
+  while (auto definingOp = dyn_cast_or_null<TiedOpInterface>(
+             baseValue.getDefiningOp())) {
+    auto tiedValue = definingOp.getTiedResultOperand(baseValue);
+    if (!tiedValue) break;
+    baseValue = tiedValue;
+  }
+  return baseValue;
+}
+
+LogicalResult detail::verifyTiedOp(TiedOpInterface tiedOp) {
+  unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+  auto storageAttr = tiedOp->getAttrOfType<ArrayAttr>(
+      TiedOpInterface::getStorageAttrName());
+  if (!storageAttr || storageAttr.getValue().empty()) {
+    return success();
+  }
+  auto tiedOperandIndices = storageAttr.getValue();
+  if (tiedOperandIndices.size() != tiedOp->getNumResults()) {
+    return tiedOp.emitError("op results/tied operand indices mismatch");
+  }
+  for (unsigned resultIndex = 0; resultIndex < tiedOp->getNumResults();
+       ++resultIndex) {
+    int64_t tiedOperandIndex =
+        tiedOperandIndices[resultIndex].cast<IntegerAttr>().getInt();
+    if (tiedOperandIndex < 0) continue;
+    auto operandType =
+        tiedOp->getOperand(tiedOperandsOffset + tiedOperandIndex).getType();
+    auto resultType = tiedOp->getResult(resultIndex).getType();
+    if (operandType != resultType) {
+      return tiedOp.emitError(
+                 "tied operand and result type mismatch; operand has ")
+             << operandType << " and result has " << resultType;
+    }
+  }
+  return success();
+}
+
+void excludeTiedOperandAndResultIndices(
+    ArrayRef<unsigned> excludedOperandIndices,
+    ArrayRef<unsigned> excludedResultIndices,
+    SmallVector<int64_t, 4> &tiedOperandIndices) {
+  SmallVector<int64_t, 4> oldTiedOperandIndices = tiedOperandIndices;
+  tiedOperandIndices.clear();
+  for (auto it : llvm::enumerate(oldTiedOperandIndices)) {
+    unsigned resultIndex = it.index();
+    if (llvm::count(excludedResultIndices, resultIndex)) {
+      continue;  // result removed
+    }
+    int64_t tiedOperandIndex = it.value();
+    if (tiedOperandIndex != TiedOpInterface::kUntiedIndex) {
+      if (llvm::count(excludedOperandIndices, tiedOperandIndex)) {
+        tiedOperandIndex = TiedOpInterface::kUntiedIndex;  // operand removed
+      }
+    }
+    tiedOperandIndices.push_back(tiedOperandIndex);
+  }
+}
+
+// At the end so it can use functions above:
+#include "iree/compiler/Dialect/IREE/IR/IREEOpInterfaces.cpp.inc"
+
 }  // namespace IREE
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h
index 3b55df45480b..73e81305bdec 100644
--- a/iree/compiler/Dialect/IREE/IR/IREETypes.h
+++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"

@@ -25,10 +26,14 @@ namespace mlir {
 namespace iree_compiler {
 namespace IREE {

+class TiedOpInterface;
+
 namespace detail {
+
 struct ListTypeStorage;
 struct PtrTypeStorage;
 struct RankedShapeTypeStorage;
+
 }  // namespace detail

 // Status code table mapping to iree::StatusCode in the runtime.
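Taken together, these helpers define the whole storage scheme: ties live in an integer array attribute named `tied_operands`, one entry per result, each holding either `kUntiedIndex` (-1) or an operand index relative to `getTiedOperandsIndexAndLength()`. A minimal sketch of publishing and querying a tie through the interface — `op` and `tieAndQuery` are hypothetical, and the sketch assumes the default interface implementations (tied-operands offset of 0) and an op with at least two operands and one result:

    // Ties result 0 of |op| to its operand 1, then reads it back.
    static void tieAndQuery(Operation *op) {
      auto tiedOp = cast<IREE::TiedOpInterface>(op);

      // With the default implementations this writes: tied_operands = [1]
      tiedOp.setTiedResultOperandIndex(/*resultIndex=*/0, /*operandIndex=*/1u);

      // Returns llvm::None for untied results.
      if (auto operandIndex = tiedOp.getTiedResultOperandIndex(0)) {
        Value aliasedStorage = op->getOperand(*operandIndex);
        (void)aliasedStorage;  // The result is expected to reuse this storage.
      }

      // Hop across chains of tied ops back to the first defined value.
      Value base = IREE::TiedOpInterface::findTiedBaseValue(op->getResult(0));
      (void)base;
    }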
@@ -111,6 +116,24 @@ class MutableByteBufferType using Base::Base; }; +namespace detail { +llvm::Optional<unsigned> getTiedResultOperandIndex(Operation *op, + unsigned resultIndex); +void setTiedResultOperandIndex(Operation *op, unsigned resultIndex, + llvm::Optional<unsigned> operandIndex); +SmallVector<int64_t, 4> getTiedResultOperandIndices(Operation *op); +LogicalResult verifyTiedOp(TiedOpInterface tiedOp); +} // namespace detail + +// Resets or removes the indices in |tiedOperandIndices| based on the given +// exclusion lists. +void excludeTiedOperandAndResultIndices( + ArrayRef<unsigned> excludedOperandIndices, + ArrayRef<unsigned> excludedResultIndices, + SmallVector<int64_t, 4> &tiedOperandIndices); + +#include "iree/compiler/Dialect/IREE/IR/IREEOpInterfaces.h.inc" + } // namespace IREE } // namespace iree_compiler diff --git a/iree/compiler/Dialect/IREE/IR/test/BUILD b/iree/compiler/Dialect/IREE/IR/test/BUILD index b780b112a316..aacb0e0f1def 100644 --- a/iree/compiler/Dialect/IREE/IR/test/BUILD +++ b/iree/compiler/Dialect/IREE/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "byte_buffer_ops.mlir", + "do_not_optimize.mlir", + "parse_print.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt index 4b02562ac5b7..34fbf727c1c8 100644 --- a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "byte_buffer_ops.mlir" + "do_not_optimize.mlir" + "parse_print.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/IREE/Transforms/test/BUILD b/iree/compiler/Dialect/IREE/Transforms/test/BUILD index b780b112a316..061dc7743d8f 100644 --- a/iree/compiler/Dialect/IREE/Transforms/test/BUILD +++ b/iree/compiler/Dialect/IREE/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License.
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["drop_compiler_hints.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt index 1de97a594fc5..3638b362f24d 100644 --- a/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/Transforms/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "drop_compiler_hints.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Check/IR/BUILD b/iree/compiler/Dialect/Modules/Check/IR/BUILD index fd35c89b731e..5ca666407258 100644 --- a/iree/compiler/Dialect/Modules/Check/IR/BUILD +++ b/iree/compiler/Dialect/Modules/Check/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["CheckOps.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt index 508b269a4400..6eb8f4b5bdd2 100644 --- a/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Check/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Modules/Check/test/BUILD b/iree/compiler/Dialect/Modules/Check/test/BUILD index bb9c0e0808d7..d931dda42da7 100644 --- a/iree/compiler/Dialect/Modules/Check/test/BUILD +++ b/iree/compiler/Dialect/Modules/Check/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "canonicalize.mlir", + "ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt index 90487d98c828..263842a4761b 100644 --- a/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Check/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "canonicalize.mlir" + "ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD b/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD index bb9c0e0808d7..22b5c8ff4c51 100644 --- a/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["convert_to_hal.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt index 3d350faefc84..1d7774acf461 100644 --- a/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/Strings/IR/BUILD b/iree/compiler/Dialect/Modules/Strings/IR/BUILD index 6967ede925cc..e5a5fa414514 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["Ops.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt index d310fe234796..636318f6cdb3 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git 
a/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD b/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD index b780b112a316..bfb77fa33c7f 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD +++ b/iree/compiler/Dialect/Modules/Strings/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["strings_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt index bcb5d45b497a..47e93c248172 100644 --- a/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/Strings/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "strings_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/TensorList/BUILD b/iree/compiler/Dialect/Modules/TensorList/BUILD index 97fbca4dce3e..57e8c589a51a 100644 --- a/iree/compiler/Dialect/Modules/TensorList/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/BUILD @@ -20,11 +20,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -filegroup( - name = "td_files", - srcs = glob(["*.td"]), -) - cc_embed_data( name = "tensorlist_imports", srcs = ["tensorlist.imports.mlir"], diff --git a/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt index b100495ab132..795076a278bc 100644 --- a/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_embed_data( NAME tensorlist_imports diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD index b780b112a316..c7f1ae864218 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_hal_to_vm.mlir", + "convert_to_hal.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt index 42faab43624b..7b2121058027 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_hal_to_vm.mlir" + "convert_to_hal.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/BUILD b/iree/compiler/Dialect/Modules/TensorList/IR/BUILD index 02b533055129..ac7a903ad033 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["TensorListBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "TensorListBase.td", + "TensorListOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt index e50e0113ec43..b3f56e7c8658 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD b/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD index b780b112a316..50dd4d0d66b4 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD +++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt index c75d60e1c64a..f344660c0c92 100644 --- a/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Sequence/IR/BUILD b/iree/compiler/Dialect/Sequence/IR/BUILD index 430d094390d6..a670f1acc117 100644 --- a/iree/compiler/Dialect/Sequence/IR/BUILD +++ b/iree/compiler/Dialect/Sequence/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "SequenceBase.td", + "SequenceOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt b/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt index cad8c6b252a0..2ce24bd7b4f7 100644 --- a/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Sequence/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Sequence/IR/test/BUILD b/iree/compiler/Dialect/Sequence/IR/test/BUILD index bb9c0e0808d7..7eb3cff6c0e7 100644 --- a/iree/compiler/Dialect/Sequence/IR/test/BUILD +++ b/iree/compiler/Dialect/Sequence/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "sequence_map.mlir", + "sequence_of.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt index 80344af8af68..e4132bbb2f57 100644 --- a/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Sequence/IR/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "sequence_map.mlir" + "sequence_of.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Conversion/test/BUILD b/iree/compiler/Dialect/Shape/Conversion/test/BUILD index bb9c0e0808d7..c861d29b2047 100644 --- a/iree/compiler/Dialect/Shape/Conversion/test/BUILD +++ b/iree/compiler/Dialect/Shape/Conversion/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["shape_to_shapex.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt index b945e8f56445..d22775fc9683 100644 --- a/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Conversion/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "shape_to_shapex.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD index 96fe609303f5..59d237540531 100644 --- a/iree/compiler/Dialect/Shape/IR/BUILD +++ b/iree/compiler/Dialect/Shape/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,14 @@ exports_files(["ShapeBase.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "ShapeBase.td", + "ShapeInterfaces.td", + "ShapeOps.td", + ], + include = ["*.td"], + ), ) cc_library( @@ -35,6 +43,7 @@ cc_library( "Folders.cpp", "ShapeDialect.cpp", "ShapeInterface.cpp", + "ShapeInterfaces.cpp.inc", "ShapeOps.cpp", "ShapeOps.cpp.inc", "ShapeTypes.cpp", @@ -43,11 +52,13 @@ cc_library( "Builders.h", "ShapeDialect.h", "ShapeInterface.h", + "ShapeInterfaces.h.inc", "ShapeOps.h", "ShapeOps.h.inc", "ShapeTypes.h", ], deps = [ + ":ShapeInterfacesGen", ":ShapeOpsGen", "//iree/compiler/Dialect/IREE/IR", "//iree/compiler/Utils", @@ -64,6 +75,22 @@ cc_library( ], ) +gentbl( + name = "ShapeInterfacesGen", + 
tbl_outs = [ + ("-gen-op-interface-decls", "ShapeInterfaces.h.inc"), + ("-gen-op-interface-defs", "ShapeInterfaces.cpp.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ShapeInterfaces.td", + td_srcs = [ + ":td_files", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + ], +) + gentbl( name = "ShapeOpsGen", tbl_outs = [ diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp index 78ae3e333717..14a34862f99a 100644 --- a/iree/compiler/Dialect/Shape/IR/Builders.cpp +++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Shape/IR/Builders.h" +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -23,27 +24,68 @@ namespace mlir { namespace iree_compiler { namespace Shape { -namespace { +static Value getRankedShapeFromOpResult(Operation *op, Value resultValue, + OpBuilder &builder) { + if (!op) return nullptr; + if (auto carryingOp = dyn_cast<ShapeCarryingInterface>(op)) { + return carryingOp.buildResultValueRankedShape(resultValue, builder); + } else { + return nullptr; + } +} -Value getRankedShapeFromOp(Operation *op) { - auto tieOp = llvm::dyn_cast_or_null<TieShapeOp>(op); - if (!tieOp) return nullptr; - auto shape = tieOp.shape(); - if (!shape.getType().isa<RankedShapeType>()) return nullptr; - return shape; +static Value getRankedShapeFromOpOperand(Operation *op, unsigned idx, + OpBuilder &builder) { + auto carryingOp = dyn_cast_or_null<ShapeCarryingInterface>(op); + if (!carryingOp) { + auto value = op->getOperand(idx); + auto definingOp = value.getDefiningOp(); + if (!definingOp) return nullptr; + return getRankedShapeFromOpResult(definingOp, value, builder); + } + return carryingOp.buildOperandRankedShape(idx, builder); } -Value findRankedShapeFromUse(Value value) { - Value rs = getRankedShapeFromOp(value.getDefiningOp()); +static Value findRankedShapeFromUse(Value value, OpBuilder &builder) { + Value rs = getRankedShapeFromOpResult(value.getDefiningOp(), value, builder); if (rs) return rs; for (auto &use : value.getUses()) { - rs = getRankedShapeFromOp(use.getOwner()); + rs = getRankedShapeFromOpOperand(use.getOwner(), use.getOperandNumber(), + builder); if (rs) return rs; } return nullptr; } -} // namespace +Value buildRankedShapeForValue(Location loc, Value shapedValue, + ValueRange dynamicDims, OpBuilder &builder) { + auto shapedType = shapedValue.getType().dyn_cast<ShapedType>(); + assert(shapedType && "only valid to call on shaped types"); + return builder.createOrFold<Shape::MakeRankedShapeOp>( + loc, Shape::RankedShapeType::get(shapedType), dynamicDims); +} + +// Slices out a range of |dynamicDims| corresponding to the value at |index|.
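+// For example, given flat values (tensor<?x4xf32>, tensor<?x?xf32>) and flat +// dynamic dims (%d0, %d1, %d2), index 0 selects (%d0) and index 1 selects +// (%d1, %d2).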
+static ValueRange sliceDynamicDims(unsigned index, ValueRange values, + ValueRange dynamicDims) { + auto valueType = values[index].getType().dyn_cast<ShapedType>(); + assert(valueType && "must be a shaped type to get dims"); + unsigned dimsIndex = 0; + for (unsigned i = 0; i < index; ++i) { + if (auto shapedType = values[i].getType().dyn_cast<ShapedType>()) { + dimsIndex += shapedType.getNumDynamicDims(); + } + } + return dynamicDims.slice(dimsIndex, valueType.getNumDynamicDims()); +} + +Value buildRankedShapeForValueInList(Location loc, unsigned index, + ValueRange flatValues, + ValueRange flatDynamicDims, + OpBuilder &builder) { + auto dynamicDims = sliceDynamicDims(index, flatValues, flatDynamicDims); + return buildRankedShapeForValue(loc, flatValues[index], dynamicDims, builder); +} Value buildCastInputsToResultShape(Location loc, RankedShapeType resultShapeType, @@ -121,9 +163,10 @@ Value buildDegenerateBroadcastRankedShape( } } -LogicalResult getRankedDimsFromRankedShape( - Location loc, Value rsValue, bool createIntermediateOps, - SmallVectorImpl<Value> &outDims, ConversionPatternRewriter &rewriter) { +LogicalResult getRankedDimsFromRankedShape(Location loc, Value rsValue, + bool createIntermediateOps, + SmallVectorImpl<Value> &outDims, + OpBuilder &builder) { Operation *op = rsValue.getDefiningOp(); if (op && (llvm::isa<MakeRankedShapeOp>(op) || llvm::isa<ConstRankedShapeOp>(op))) { @@ -134,28 +177,26 @@ LogicalResult getRankedDimsFromRankedShape( if (dynamicDimIndex >= op->getNumOperands()) { return emitError(loc, "mismatched dynamic dimensions"); } - Value remappedValue = - rewriter.getRemappedValue(op->getOperand(dynamicDimIndex++)); - if (!remappedValue) { + Value dimValue = op->getOperand(dynamicDimIndex++); + if (!dimValue) { return emitError( loc, "unable to find remapped value for ranked dim value"); } - outDims.push_back(remappedValue); + outDims.push_back(dimValue); } else { outDims.push_back( - rewriter.create<ConstantIndexOp>(loc, rsType.getStaticDim(i))); + builder.create<ConstantIndexOp>(loc, rsType.getStaticDim(i))); } } - return success(); } else if (createIntermediateOps) { - auto dimsOp = rewriter.create<RankedDimsOp>(loc, rsValue); + auto dimsOp = builder.create<RankedDimsOp>(loc, rsValue); outDims.resize(dimsOp.result().size()); std::copy(dimsOp.result().begin(), dimsOp.result().end(), outDims.begin()); - return success(); } else { return emitError(loc, "could not resolve ranked dimensions from metadata ops"); } + return success(); } Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, @@ -174,7 +215,7 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, // Dynamic - walk the uses to find a tie_shape op (either this op or an // immediate use). - Value rs = findRankedShapeFromUse(value); + Value rs = findRankedShapeFromUse(value, builder); if (!rs) { builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) << "dynamically shaped value is missing a shape association via " @@ -192,6 +233,49 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, return rs; } +SmallVector<Value, 4> buildOrFindDynamicDimsForValue(Location loc, Value value, + OpBuilder &builder) { + auto valueSt = value.getType().dyn_cast<ShapedType>(); + if (!valueSt) { + builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) + << "cannot construct shape for non shaped value: " << value.getType(); + return {}; + } + + // Bail if all dimensions are static. + if (valueSt.hasStaticShape()) { + return {}; + } + + // Dynamic - walk the uses to find a tie_shape op (either this op or an + // immediate use).
+ SmallVector<Value, 4> result; + Value rs = findRankedShapeFromUse(value, builder); + if (rs) { + auto rsType = rs.getType().dyn_cast<RankedShapeType>(); + if (!rsType) { + builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error) + << "dynamically shaped value is not ranked (which is not yet " + << "supported)"; + return {}; + } + for (unsigned i = 0; i < rsType.getRank(); ++i) { + if (rsType.isDimDynamic(i)) { + result.push_back(builder.createOrFold<RankedDimOp>(loc, rs, i)); + } + } + } else { + // No tie information - insert std.dim ops that may later be used and + // hopefully converted to ranked shape types. + for (unsigned i = 0; i < valueSt.getRank(); ++i) { + if (valueSt.isDynamicDim(i)) { + result.push_back(builder.createOrFold<DimOp>(loc, value, i)); + } + } + } + return result; +} + } // namespace Shape } // namespace iree_compiler } // namespace mlir diff --git a/iree/compiler/Dialect/Shape/IR/Builders.h b/iree/compiler/Dialect/Shape/IR/Builders.h index 825c69ac0282..9951272a398a 100644 --- a/iree/compiler/Dialect/Shape/IR/Builders.h +++ b/iree/compiler/Dialect/Shape/IR/Builders.h @@ -25,6 +25,18 @@ namespace mlir { namespace iree_compiler { namespace Shape { +// Builds a ranked_shape for the given |shapedValue| with zero or more dynamic +// dims with the values taken from |dynamicDims|. +Value buildRankedShapeForValue(Location loc, Value shapedValue, + ValueRange dynamicDims, OpBuilder &builder); + +// As with buildRankedShapeForValue but by selecting out the appropriate dims +// from a flattened set of values and dynamic dims. +Value buildRankedShapeForValueInList(Location loc, unsigned index, + ValueRange flatValues, + ValueRange flatDynamicDims, + OpBuilder &builder); + // Given an arbitrary list of inputs, builds IR to obtain their shapes and // cast them to a given !shapex.ranked_shape. Statically verifiable invariants // will be checked within this call and runtime code will be emitted to verify @@ -60,6 +72,12 @@ Value buildDegenerateBroadcastRankedShape( Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, OpBuilder &builder); +// Returns dimension values for each dynamic dimension of the given |value|. +// |value| must be a ShapedType and may optionally have a ranked_shape tied. +// The returned value range will be empty if the shape is fully static. +SmallVector<Value, 4> buildOrFindDynamicDimsForValue(Location loc, Value value, + OpBuilder &builder); + // Given a RankedShapeType'd value |rsValue|, populate values for all // dimensions. If |createIntermediateOps|, then if the dims cannot be resolved // by walking the IR, then a RankedDimsOp is created.
If false, then the dims @@ -68,7 +86,7 @@ Value buildOrFindRankedShapeForValue(Location loc, Value value, Type dimType, LogicalResult getRankedDimsFromRankedShape(Location loc, Value rsValue, bool createIntermediateOps, SmallVectorImpl<Value> &outDims, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); } // namespace Shape } // namespace iree_compiler diff --git a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt index 84ff59ddaff2..b00454146a0f 100644 --- a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR @@ -18,6 +17,7 @@ iree_cc_library( "Builders.h" "ShapeDialect.h" "ShapeInterface.h" + "ShapeInterfaces.h.inc" "ShapeOps.h" "ShapeOps.h.inc" "ShapeTypes.h" @@ -26,6 +26,7 @@ iree_cc_library( "Folders.cpp" "ShapeDialect.cpp" "ShapeInterface.cpp" + "ShapeInterfaces.cpp.inc" "ShapeOps.cpp" "ShapeOps.cpp.inc" "ShapeTypes.cpp" @@ -45,6 +46,16 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + ShapeInterfacesGen + TD_FILE + "ShapeInterfaces.td" + OUTS + -gen-op-interface-decls ShapeInterfaces.h.inc + -gen-op-interface-defs ShapeInterfaces.cpp.inc +) + iree_tablegen_library( NAME ShapeOpsGen diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp index fc157937292b..a6349ee9b854 100644 --- a/iree/compiler/Dialect/Shape/IR/Folders.cpp +++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp @@ -52,17 +52,17 @@ LogicalResult safeCastCompatibleShapePattern( return failure(); } -LogicalResult elideTiedGetRankedShapePattern(GetRankedShapeOp op, - GetRankedShapeOp::Adaptor operands, - PatternRewriter &rewriter) { - // If the immediate predecessor is a TieShapeOp, then this op can be - // erased in favor of the input to the tie op. - auto tieOp = dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp()); - if (!tieOp) { - return rewriter.notifyMatchFailure(op, "no associated tie_shape op"); +LogicalResult elideShapeCarryingGetRankedShapePattern( + GetRankedShapeOp op, GetRankedShapeOp::Adaptor operands, + PatternRewriter &rewriter) { + auto carryingOp = dyn_cast_or_null<ShapeCarryingInterface>( + operands.operand().getDefiningOp()); + if (!carryingOp) { + return rewriter.notifyMatchFailure(op, + "no associated dynamic-shape aware op"); } - - rewriter.replaceOp(op, tieOp.shape()); + rewriter.replaceOp( + op, carryingOp.buildResultValueRankedShape(operands.operand(), rewriter)); return success(); } @@ -225,6 +225,42 @@ LogicalResult elideDuplicateTieShapePattern(TieShapeOp op, return success(); } +// Removes tie_shape ops when the operand is produced by a shape-aware op. +LogicalResult elideShapeCarryingOperandTieShapePattern( + TieShapeOp op, TieShapeOp::Adaptor operands, PatternRewriter &rewriter) { + auto definingOp = operands.operand().getDefiningOp(); + if (!definingOp) return failure(); + if (isa<TieShapeOp>(definingOp)) { + return failure(); // ignore tie-shape handled above + } else if (isa<ShapeCarryingInterface>(definingOp)) { + rewriter.replaceOp(op, operands.operand()); + return success(); + } else { + return failure(); + } +} + +// Reroutes uses of tie_shape ops by ops that are shape-aware or dim ops.
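+// Shape-aware consumers are rewired to use the tie_shape operand directly and +// std.dim ops with a constant index are replaced by a shapex.ranked_dim on the +// tied shape.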
+LogicalResult elideTieShapeUsagePattern(TieShapeOp op, + TieShapeOp::Adaptor operands, + PatternRewriter &rewriter) { + bool didAnything = false; + for (auto &use : llvm::make_early_inc_range(op.result().getUses())) { + if (auto carryingOp = dyn_cast<ShapeCarryingInterface>(use.getOwner())) { + carryingOp->setOperand(use.getOperandNumber(), operands.operand()); + didAnything = true; + } else if (auto dimOp = dyn_cast<DimOp>(use.getOwner())) { + auto index = dimOp.getConstantIndex(); + if (index.hasValue()) { + rewriter.replaceOpWithNewOp<RankedDimOp>(dimOp, op.shape(), + index.getValue()); + didAnything = true; + } + } + } + return didAnything ? success() : failure(); +} + //===----------------------------------------------------------------------===// // shapex.tie_shape //===----------------------------------------------------------------------===// @@ -232,6 +268,9 @@ LogicalResult elideDuplicateTieShapePattern(TieShapeOp op, void TieShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { insertGreedyPattern(patterns, context, elideDuplicateTieShapePattern); + insertGreedyPattern(patterns, context, + elideShapeCarryingOperandTieShapePattern); + insertGreedyPattern(patterns, context, elideTieShapeUsagePattern); } //===----------------------------------------------------------------------===// @@ -249,7 +288,8 @@ void CastCompatibleShapeOp::getCanonicalizationPatterns( void GetRankedShapeOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - insertGreedyPattern(patterns, context, elideTiedGetRankedShapePattern); + insertGreedyPattern(patterns, context, + elideShapeCarryingGetRankedShapePattern); insertGreedyPattern(patterns, context, elideDuplicateGetRankedShapePattern); insertGreedyPattern(patterns, context, elideStaticGetRankedShapePattern); } @@ -358,7 +398,11 @@ void populateFoldConversionPatterns(MLIRContext *context, insertConversionPattern(patterns, context, elideDuplicateGetRankedShapePattern); insertConversionPattern(patterns, context, elideDuplicateTieShapePattern); - insertConversionPattern(patterns, context, elideTiedGetRankedShapePattern); + insertConversionPattern(patterns, context, + elideShapeCarryingOperandTieShapePattern); + insertConversionPattern(patterns, context, elideTieShapeUsagePattern); + insertConversionPattern(patterns, context, + elideShapeCarryingGetRankedShapePattern); insertConversionPattern(patterns, context, expandRankedShapeDimsPattern); insertConversionPattern(patterns, context, identityMakeRankedShapePattern); insertConversionPattern(patterns, context, elideStaticGetRankedShapePattern); diff --git a/iree/compiler/Dialect/Shape/IR/ShapeBase.td b/iree/compiler/Dialect/Shape/IR/ShapeBase.td index 180c7aca47fb..d577ad3aff76 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeBase.td +++ b/iree/compiler/Dialect/Shape/IR/ShapeBase.td @@ -18,7 +18,7 @@ include "mlir/IR/OpBase.td" //===----------------------------------------------------------------------===// -// Shape dialect. +// Shape dialect //===----------------------------------------------------------------------===// // TODO(b/143787186): rename when old dialects are removed.
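With the interface in place, canonicalization no longer special-cases shapex.tie_shape: anything implementing ShapeCarryingInterface can materialize a ranked_shape on demand. A minimal consumer-side sketch (illustrative only; assumes a `value` and `builder` in scope):

    if (auto carryingOp = dyn_cast_or_null<ShapeCarryingInterface>(
            value.getDefiningOp())) {
      // Builds (or reuses) a !shapex.ranked_shape for the producing result.
      Value rs = carryingOp.buildResultValueRankedShape(value, builder);
    }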
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp index 7586c5fb3d03..71511ec5b369 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp +++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp @@ -32,6 +32,8 @@ namespace mlir { namespace iree_compiler { +#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.cpp.inc" + // Used to control inlining behavior. struct ShapeInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h b/iree/compiler/Dialect/Shape/IR/ShapeDialect.h index 6465384204c2..f865f84ce6ee 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.h @@ -21,6 +21,8 @@ namespace mlir { namespace iree_compiler { +#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.h.inc" + class ShapeDialect : public Dialect { public: explicit ShapeDialect(MLIRContext* context); diff --git a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td new file mode 100644 index 000000000000..c14a11961f6f --- /dev/null +++ b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td @@ -0,0 +1,58 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_DIALECT_SHAPE_INTERFACES +#define IREE_DIALECT_SHAPE_INTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Op interfaces +//===----------------------------------------------------------------------===// + +def Shape_ShapeCarryingOpInterface : OpInterface<"ShapeCarryingInterface"> { + let description = [{ + Interface for ops that interact with dynamically shaped inputs and outputs. + Such ops are able to materialize RankedShapes on demand for any operand or + result that derives from ShapedType. 
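+    For example, shapex.tie_shape implements this interface by returning its +    shape operand for both operand and result queries (see ShapeOps.td below).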
+ }]; + + let methods = [ + StaticInterfaceMethod< + /*desc=*/[{Returns a RankedShape for the given shaped result value.}], + /*retTy=*/"Value", + /*methodName=*/"buildResultValueRankedShape", + /*args=*/(ins "Value":$result, "OpBuilder &":$builder), + /*methodBody=*/[{ + auto carryingOp = dyn_cast<ShapeCarryingInterface>(result.getDefiningOp()); + return carryingOp.buildResultRankedShape( + result.cast<OpResult>().getResultNumber(), builder); + }] + >, + InterfaceMethod< + /*desc=*/[{Returns a RankedShape for the given shaped operand index.}], + /*retTy=*/"Value", + /*methodName=*/"buildOperandRankedShape", + /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder) + >, + InterfaceMethod< + /*desc=*/[{Returns a RankedShape for the given shaped result index.}], + /*retTy=*/"Value", + /*methodName=*/"buildResultRankedShape", + /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder) + >, + ]; +} + +#endif // IREE_DIALECT_SHAPE_INTERFACES diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.h b/iree/compiler/Dialect/Shape/IR/ShapeOps.h index ffd50296186e..ee4c30bf02a7 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeOps.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.h @@ -15,6 +15,7 @@ #ifndef IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_ #define IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_ +#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" #include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td index 617e526fd5e7..164f66744f81 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td +++ b/iree/compiler/Dialect/Shape/IR/ShapeOps.td @@ -16,6 +16,7 @@ #define IREE_DIALECT_SHAPE_OPS include "iree/compiler/Dialect/Shape/IR/ShapeBase.td" +include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -39,8 +40,9 @@ class Shape_PureOp<string mnemonic, list<OpTrait> traits = []> : //===----------------------------------------------------------------------===// def Shape_TieShapeOp : Shape_PureOp<"tie_shape", [ + Shape_ShapeCarryingOpInterface, AllTypesMatch<["operand", "result"]>, - DeclareOpInterfaceMethods<ViewLikeOpInterface> + DeclareOpInterfaceMethods<ViewLikeOpInterface>, ]> { let summary = "Ties a tensor and a shape together."; let description = [{ @@ -57,6 +59,12 @@ def Shape_TieShapeOp : Shape_PureOp<"tie_shape", [ let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)"; + let extraClassDeclaration = [{ + // ShapeCarryingInterface: + Value buildOperandRankedShape(unsigned idx, OpBuilder &builder) { return shape(); } + Value buildResultRankedShape(unsigned idx, OpBuilder &builder) { return shape(); } + }]; + let verifier = [{ return verify$cppClass(*this); }]; let hasCanonicalizer = 1; diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp b/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp index 2dbec24a5024..ce549223c627 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp +++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp @@ -79,6 +79,10 @@ RankedShapeType RankedShapeType::getChecked( return Base::getChecked(emitError, context, dims); } +RankedShapeType RankedShapeType::get(ShapedType shapedType) { + return Base::get(shapedType.getContext(), shapedType.getShape()); +} + LogicalResult RankedShapeType::verify( function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> dims) { for (auto dim : dims) { diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h index
b48aafb6da01..e8d43309b1c6 100644 --- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h +++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h @@ -47,6 +47,9 @@ class RankedShapeType : public Type::TypeBase emitError, MLIRContext *context, ArrayRef<int64_t> dims); + // Derives a RankedShapeType from a ShapedType. + static RankedShapeType get(ShapedType shapedType); + // Verifies construction invariants and issues errors/warnings. static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> dims); diff --git a/iree/compiler/Dialect/Shape/IR/test/BUILD b/iree/compiler/Dialect/Shape/IR/test/BUILD index b780b112a316..64deb152c406 100644 --- a/iree/compiler/Dialect/Shape/IR/test/BUILD +++ b/iree/compiler/Dialect/Shape/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "canonicalize.mlir", + "op_verification.mlir", + "parse_print.mlir", + "ranked_shape_type.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt index 43c4a40d08a2..fa8db2b4d34f 100644 --- a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "canonicalize.mlir" + "op_verification.mlir" + "parse_print.mlir" + "ranked_shape_type.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir index dfef6a2c6a72..73070967dc46 100644 --- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir +++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir @@ -123,8 +123,8 @@ func @elideDuplicateTieShapePattern_match(%arg0 : tensor<?xf32>, %arg1 : !shapex func @elideDuplicateTieShapePattern_different_shapes(%arg0 : tensor<?xf32>, %arg1 : !shapex.ranked_shape<[?]>, %arg2 : !shapex.ranked_shape<[?]>) -> (tensor<?xf32>) { %0 = shapex.tie_shape %arg0, %arg1 : tensor<?xf32>, !shapex.ranked_shape<[?]> %1 = shapex.tie_shape %0, %arg2 : tensor<?xf32>, !shapex.ranked_shape<[?]> - // CHECK: %[[T:.+]] = shapex.tie_shape %[[ARGT]], %[[ARGRS1]] - // CHECK: shapex.tie_shape %[[T]], %[[ARGRS2]] + // CHECK: %[[T:.+]] = shapex.tie_shape %[[ARGT]], %[[ARGRS2]] + // CHECK: return %[[T]] return %1 : tensor<?xf32> } diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD index bb9c0e0808d7..24fd2874ad5e 100644 --- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD +++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License.
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["custom_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt index e0eb941b7a31..3c6404e90d64 100644 --- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "custom_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD index bb9c0e0808d7..24fd2874ad5e 100644 --- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD +++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["custom_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt index 44744ecb6f31..2fd4f816907b 100644 --- a/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Plugins/XLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "custom_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Shape/Transforms/test/BUILD b/iree/compiler/Dialect/Shape/Transforms/test/BUILD index bb9c0e0808d7..eb8a1ecfe445 100644 --- a/iree/compiler/Dialect/Shape/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Shape/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "convert_hlo_to_shape_dialect.mlir", + "expand_function_dynamic_dims.mlir", + "expand_function_ranked_shape_dims.mlir", + "hoist_shape_calculations.mlir", + "materialize_shape_calculations.mlir", + "tie_dynamic_shapes.mlir", + "tie_dynamic_shapes_no_recurse.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt index 4acab9242282..6c0229e3efdf 100644 --- a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "convert_hlo_to_shape_dialect.mlir" + "expand_function_dynamic_dims.mlir" + "expand_function_ranked_shape_dims.mlir" + "hoist_shape_calculations.mlir" + "materialize_shape_calculations.mlir" + "tie_dynamic_shapes.mlir" + "tie_dynamic_shapes_no_recurse.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Analysis/test/BUILD b/iree/compiler/Dialect/VM/Analysis/test/BUILD index b780b112a316..a411dd6ffd53 100644 --- a/iree/compiler/Dialect/VM/Analysis/test/BUILD +++ b/iree/compiler/Dialect/VM/Analysis/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "register_allocation.mlir", + "value_liveness.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt index cb58a80932b4..2dd28bfa3b50 100644 --- a/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Analysis/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "register_allocation.mlir" + "value_liveness.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD index bb9c0e0808d7..26a05ad4bb48 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "byte_buffer_ops.mlir", + "hint_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt index c11e1256d462..81db6954bef8 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "byte_buffer_ops.mlir" + "hint_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD index b780b112a316..cd7c2f57edba 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_ops.mlir", + "assignment_ops.mlir", + "comparison_ops.mlir", + "const_ops.mlir", + "control_flow_ops.mlir", + "func_attrs.mlir", + "structural_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt index 17e1e27ff562..38aadb785895 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_ops.mlir" + "assignment_ops.mlir" + "comparison_ops.mlir" + "const_ops.mlir" + "control_flow_ops.mlir" + "func_attrs.mlir" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index a788be853c31..9664c6789c06 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -103,6 +103,21 @@ class ConstOpConversion : public OpRewritePattern { } }; +template +class ConstZeroOpConversion : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstZeroOpTy constZeroOp, + PatternRewriter &rewriter) const final { + auto type = constZeroOp.getType(); + IntegerAttr value = rewriter.getIntegerAttr(type, 0); + + rewriter.replaceOpWithNewOp(constZeroOp, type, value); + return success(); + } +}; + 
template <typename SrcOpTy> class GlobalLoadOpConversion : public OpConversionPattern<SrcOpTy> { using OpConversionPattern<SrcOpTy>::OpConversionPattern; @@ -190,6 +205,7 @@ void populateVMToCPatterns(MLIRContext *context, // Constants patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(context); + patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context); // Conditional assignment ops patterns.insert>(context, @@ -253,6 +269,7 @@ void populateVMToCPatterns(MLIRContext *context, // ExtI64: Constants patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(context); + patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(context); // ExtI64: Conditional assignment ops patterns.insert>(context, diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir index 451bf7d81904..b17486cc2839 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir @@ -1,7 +1,18 @@ // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s -// CHECK: vm.module @module { -vm.module @module { + +vm.module @my_module { + // CHECK-LABEL: vm.func @const_i32_zero + vm.func @const_i32_zero() -> i32 { + // CHECK: %zero = "emitc.const"() {value = 0 : i32} : () -> i32 + %zero = vm.const.i32.zero : i32 + vm.return %zero : i32 + } +} + +// ----- + +vm.module @my_module { // CHECK-LABEL: vm.func @const_i32 vm.func @const_i32() { // CHECK-NEXT: %0 = "emitc.const"() {value = 0 : i32} : () -> i32 @@ -10,7 +21,6 @@ vm.module @module { %1 = vm.const.i32 2 : i32 // CHECK-NEXT: %2 = "emitc.const"() {value = -2 : i32} : () -> i32 %2 = vm.const.i32 -2 : i32 - // CHECK-NEXT: vm.return vm.return } } diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir index 89b205ad8a1a..0161dc4bc29e 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir @@ -1,7 +1,18 @@ // RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s -// CHECK: vm.module @module { -vm.module @module { + +vm.module @my_module { + // CHECK-LABEL: vm.func @const_i64_zero + vm.func @const_i64_zero() -> i64 { + // CHECK: %zero = "emitc.const"() {value = 0 : i64} : () -> i64 + %zero = vm.const.i64.zero : i64 + vm.return %zero : i64 + } +} + +// ----- + +vm.module @my_module { // CHECK-LABEL: vm.func @const_i64 vm.func @const_i64() { // CHECK-NEXT: %0 = "emitc.const"() {value = 0 : i64} : () -> i64 @@ -10,7 +21,6 @@ vm.module @module { %1 = vm.const.i64 2 : i64 // CHECK-NEXT: %2 = "emitc.const"() {value = -2 : i64} : () -> i64 %2 = vm.const.i64 -2 : i64 - // CHECK-NEXT: vm.return vm.return } } diff --git a/iree/compiler/Dialect/VM/IR/BUILD b/iree/compiler/Dialect/VM/IR/BUILD index 5cc616239764..b07c0c5fb013 100644 --- a/iree/compiler/Dialect/VM/IR/BUILD +++ b/iree/compiler/Dialect/VM/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["VMOps.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VMBase.td", + "VMOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/VM/IR/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/CMakeLists.txt index 0b1a1784ae6f..e9395ad58104 100644
--- a/iree/compiler/Dialect/VM/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/VM/IR/test/BUILD b/iree/compiler/Dialect/VM/IR/test/BUILD index b780b112a316..76cf62ebf88f 100644 --- a/iree/compiler/Dialect/VM/IR/test/BUILD +++ b/iree/compiler/Dialect/VM/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,31 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "arithmetic_folding.mlir", + "arithmetic_ops.mlir", + "assignment_folding.mlir", + "assignment_ops.mlir", + "comparison_folding.mlir", + "comparison_ops.mlir", + "const_folding.mlir", + "const_ops.mlir", + "control_flow_folding.mlir", + "control_flow_ops.mlir", + "conversion_folding.mlir", + "conversion_ops.mlir", + "debug_folding.mlir", + "debug_ops.mlir", + "global_folding.mlir", + "global_ops.mlir", + "list_op_verification.mlir", + "list_ops.mlir", + "shift_ops.mlir", + "structural_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt index 66b2364af296..2604211f754e 100644 --- a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt @@ -10,12 +10,30 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "arithmetic_folding.mlir" + "arithmetic_ops.mlir" + "assignment_folding.mlir" + "assignment_ops.mlir" + "comparison_folding.mlir" + "comparison_ops.mlir" + "const_folding.mlir" + "const_ops.mlir" + "control_flow_folding.mlir" + "control_flow_ops.mlir" + "conversion_folding.mlir" + "conversion_ops.mlir" + "debug_folding.mlir" + "debug_ops.mlir" + "global_folding.mlir" + "global_ops.mlir" + "list_op_verification.mlir" + "list_ops.mlir" + "shift_ops.mlir" + "structural_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD b/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD index 71d362403828..3e005972478b 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD +++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
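The conversions above and below all follow one mechanical recipe: each glob(["*.td"]) or glob(["*.mlir"]) in a BUILD file becomes an explicit, sorted file list wrapped in enforce_glob, and the paired CMakeLists.txt drops its file(GLOB ...) in favor of the literal names. A minimal sketch of the Bazel-side pattern, with made-up file names:

load("//iree:lit_test.bzl", "iree_lit_test_suite")
load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")

iree_lit_test_suite(
    name = "lit",
    srcs = enforce_glob(
        # This list must match the *.mlir files on disk; a stale list fails
        # the build instead of silently skipping tests.
        [
            "bar.mlir",
            "foo.mlir",
        ],
        include = ["*.mlir"],
    ),
    data = ["//iree/tools:IreeFileCheck"],
)

Adding or deleting a test file without updating the list now surfaces as a load-time error rather than a silently drifting test suite.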
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_encoding.mlir", + "module_encoding_smoke.mlir", + "reflection_attrs.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-translate", diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt index 56dcc5835ca9..287b67d8bbfd 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_encoding.mlir" + "module_encoding_smoke.mlir" + "reflection_attrs.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-translate diff --git a/iree/compiler/Dialect/VM/Transforms/test/BUILD b/iree/compiler/Dialect/VM/Transforms/test/BUILD index b780b112a316..6ebc889abf13 100644 --- a/iree/compiler/Dialect/VM/Transforms/test/BUILD +++ b/iree/compiler/Dialect/VM/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,16 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "global_initialization.mlir", + "hoist_inlined_rodata.mlir", + "mark_public_symbols_exported.mlir", + "ordinal_allocation.mlir", + "sink_defining_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt index 8c85b404f1cf..dcae639e29bc 100644 --- a/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt @@ -10,12 +10,15 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "global_initialization.mlir" + "hoist_inlined_rodata.mlir" + "mark_public_symbols_exported.mlir" + "ordinal_allocation.mlir" + "sink_defining_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD index bb9c0e0808d7..40dc09d52b4e 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["interface_ops.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt index 096adae6ff75..42cc9b2bc342 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "interface_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp index 6c2f44266bb4..52b35432d0c3 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp @@ -598,7 +598,9 @@ struct DynamicSliceOpConversion }; struct CompareOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + CompareOpConversion(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + /*benefit=*/9999) {} LogicalResult matchAndRewrite( mhlo::CompareOp srcOp, ArrayRef rawOperands, @@ -874,11 +876,6 @@ struct ConvertOpConversion : public OpConversionPattern { void populateHLOToVMLAPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &typeConverter) { - // We rely on some additional HLO->std patterns and assume they - // have been run already. In case they haven't we provide them here (useful - // for standalone conversion testing). - mhlo::PopulateMhloToStdPatterns(&patterns, context); - // mhlo.convolution. populateHLOConvToVMLAPatterns(context, patterns, typeConverter); @@ -995,8 +992,11 @@ void populateHLOToVMLAPatterns(MLIRContext *context, // runtime. patterns.insert(context); - // TODO(benvanik): add missing ops: - // - ConvOp + // We rely on some additional HLO->std patterns and assume they + // have been run already. In case they haven't we provide them here (useful + // for standalone conversion testing). We run them last so that other patterns + // have a chance to handle the HLO conversions first. + mhlo::PopulateMhloToStdPatterns(&patterns, context); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD index bb9c0e0808d7..3cbab8b24a2d 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,25 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "broadcast_in_dim.mlir", + "concatenate.mlir", + "conv.mlir", + "convert.mlir", + "dynamic_slice.mlir", + "fft.mlir", + "math_ops.mlir", + "reduce.mlir", + "reduce_window.mlir", + "reshape.mlir", + "scatter.mlir", + "slice.mlir", + "sort.mlir", + "transpose.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt index 8e0086f90cf6..fbd6949426ce 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt @@ -10,12 +10,24 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "broadcast_in_dim.mlir" + "concatenate.mlir" + "conv.mlir" + "convert.mlir" + "dynamic_slice.mlir" + "fft.mlir" + "math_ops.mlir" + "reduce.mlir" + "reduce_window.mlir" + "reshape.mlir" + "scatter.mlir" + "slice.mlir" + "sort.mlir" + "transpose.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD index bb9c0e0808d7..80a7bd483590 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "comparison_ops.mlir", + "constant_ops.mlir", + "math_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt index 4b5dad785f34..ed4a341301fe 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "comparison_ops.mlir" + "constant_ops.mlir" + "math_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD index bb9c0e0808d7..368d1027565b 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "constant_ops.mlir", + "conversion.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt index 06fa3fd0ae00..e60fbfeca624 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "constant_ops.mlir" + "conversion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/IR/BUILD b/iree/compiler/Dialect/VMLA/IR/BUILD index 40e8973cc3e0..365d43e150e4 100644 --- a/iree/compiler/Dialect/VMLA/IR/BUILD +++ b/iree/compiler/Dialect/VMLA/IR/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -25,7 +26,13 @@ exports_files(["VMLAOps.td"]) filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VMLABase.td", + "VMLAOps.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt index c35d99cf7576..c33ab9f72501 100644 --- a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/VMLA/IR/test/BUILD b/iree/compiler/Dialect/VMLA/IR/test/BUILD index bb9c0e0808d7..228bc609241e 100644 --- a/iree/compiler/Dialect/VMLA/IR/test/BUILD +++ b/iree/compiler/Dialect/VMLA/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "buffer_ops.mlir", + "conv_reduction_ops.mlir", + "general_ops.mlir", + "shape_structure_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt index c4a637329d89..a392ee0f426c 100644 --- a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "buffer_ops.mlir" + "conv_reduction_ops.mlir" + "general_ops.mlir" + "shape_structure_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD b/iree/compiler/Dialect/VMLA/Transforms/test/BUILD index bb9c0e0808d7..8e124ba2e665 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD +++ b/iree/compiler/Dialect/VMLA/Transforms/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,14 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "pre_conversion_lowering.mlir", + "transformation.mlir", + "unroll_reductions.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt index 94ba2219ef08..4a84f8595d0c 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt @@ -10,12 +10,13 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "pre_conversion_lowering.mlir" + "transformation.mlir" + "unroll_reductions.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Vulkan/IR/BUILD b/iree/compiler/Dialect/Vulkan/IR/BUILD index 260521d56de0..9b9220b008be 100644 --- a/iree/compiler/Dialect/Vulkan/IR/BUILD +++ b/iree/compiler/Dialect/Vulkan/IR/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,13 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + [ + "VulkanAttributes.td", + "VulkanBase.td", + ], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt index 06bdf83818f6..0a03e5f98852 100644 --- a/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME IR diff --git a/iree/compiler/Dialect/Vulkan/IR/test/BUILD b/iree/compiler/Dialect/Vulkan/IR/test/BUILD index bb9c0e0808d7..37d4ac5967c0 100644 --- a/iree/compiler/Dialect/Vulkan/IR/test/BUILD +++ b/iree/compiler/Dialect/Vulkan/IR/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,7 +23,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["target_env.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt index ddec28513220..3d87e77b7e82 100644 --- a/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "target_env.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD index 29aa3a5b0bf8..0b3152403c8f 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD +++ b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD @@ -14,6 +14,7 @@ load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -31,7 +32,10 @@ endif() iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["target_env_conversion.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt index 2dd596b2b2cb..52386a414e2d 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt @@ -14,12 +14,11 @@ endif() iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "target_env_conversion.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Translation/test/BUILD b/iree/compiler/Translation/test/BUILD index cac3e25116fd..ed1d5ed68d44 100644 --- a/iree/compiler/Translation/test/BUILD +++ 
b/iree/compiler/Translation/test/BUILD @@ -15,6 +15,7 @@ # Tests for common transforms. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "do_not_optimize.mlir", + "smoketest.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-opt", diff --git a/iree/compiler/Translation/test/CMakeLists.txt b/iree/compiler/Translation/test/CMakeLists.txt index 1e21a2f0f37d..ef4fb17852e9 100644 --- a/iree/compiler/Translation/test/CMakeLists.txt +++ b/iree/compiler/Translation/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "do_not_optimize.mlir" + "smoketest.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD index 162109cf2862..a9205903ab4b 100644 --- a/iree/modules/check/test/BUILD +++ b/iree/modules/check/test/BUILD @@ -14,6 +14,7 @@ load("//build_tools/bazel:iree_check_test.bzl", "iree_check_test_suite") load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,13 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "failure.mlir", + "success.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-check-module", diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt index 48a8afe6cde8..943942155d25 100644 --- a/iree/modules/check/test/CMakeLists.txt +++ b/iree/modules/check/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "failure.mlir" + "success.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-check-module diff --git a/iree/samples/custom_modules/dialect/BUILD b/iree/samples/custom_modules/dialect/BUILD index fa0bc78cb04c..19eac520704c 100644 --- a/iree/samples/custom_modules/dialect/BUILD +++ b/iree/samples/custom_modules/dialect/BUILD @@ -14,6 +14,7 @@ load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data") load("//build_tools/bazel:tblgen.bzl", "gentbl") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -23,7 +24,10 @@ package( filegroup( name = "td_files", - srcs = glob(["*.td"]), + srcs = enforce_glob( + ["custom_ops.td"], + include = ["*.td"], + ), ) cc_library( diff --git a/iree/samples/custom_modules/dialect/CMakeLists.txt b/iree/samples/custom_modules/dialect/CMakeLists.txt index 01ea23d55de1..758d9fa3a1d5 100644 --- a/iree/samples/custom_modules/dialect/CMakeLists.txt +++ b/iree/samples/custom_modules/dialect/CMakeLists.txt @@ -10,7 +10,6 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_TD LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.td) iree_cc_library( NAME dialect diff --git a/iree/samples/custom_modules/dialect/test/BUILD b/iree/samples/custom_modules/dialect/test/BUILD index d0637066cfe4..98133e326e9a 
100644 --- a/iree/samples/custom_modules/dialect/test/BUILD +++ b/iree/samples/custom_modules/dialect/test/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -22,18 +23,15 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "conversion.mlir", + "custom_ops.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/samples/custom_modules/dialect:custom-opt", "//iree/tools:IreeFileCheck", ], ) -# load("//iree:lit_test.bzl", "iree_lit_test_suite") -# iree_lit_test_suite( -# name = "lit", -# srcs = glob(["*.mlir"]), -# data = [ -# "//iree/tools:IreeFileCheck", -# "//iree/samples/custom_modules/dialect:custom-opt", -# ], -# ) diff --git a/iree/samples/custom_modules/dialect/test/CMakeLists.txt b/iree/samples/custom_modules/dialect/test/CMakeLists.txt index f3b9b6c25b06..e8aaaa50cbfe 100644 --- a/iree/samples/custom_modules/dialect/test/CMakeLists.txt +++ b/iree/samples/custom_modules/dialect/test/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "conversion.mlir" + "custom_ops.mlir" DATA iree::samples::custom_modules::dialect::custom-opt iree::tools::IreeFileCheck diff --git a/iree/test/e2e/hackability/BUILD b/iree/test/e2e/hackability/BUILD index 361d3dfd8f8e..ffaf264c6c51 100644 --- a/iree/test/e2e/hackability/BUILD +++ b/iree/test/e2e/hackability/BUILD @@ -17,6 +17,7 @@ # those tests in iree/tools/test/ load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -26,7 +27,10 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + ["flow_partitioned.mlir"], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-run-mlir", diff --git a/iree/test/e2e/hackability/CMakeLists.txt b/iree/test/e2e/hackability/CMakeLists.txt index e03613370fa8..928042ac1fbb 100644 --- a/iree/test/e2e/hackability/CMakeLists.txt +++ b/iree/test/e2e/hackability/CMakeLists.txt @@ -10,12 +10,11 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "flow_partitioned.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-run-mlir diff --git a/iree/test/e2e/hackability/flow_partitioned.mlir b/iree/test/e2e/hackability/flow_partitioned.mlir index d9b99f4a0871..7c5513e2f7ce 100644 --- a/iree/test/e2e/hackability/flow_partitioned.mlir +++ b/iree/test/e2e/hackability/flow_partitioned.mlir @@ -15,7 +15,7 @@ flow.executable @ex0 { func @staticShapedFn() -> tensor<4xf32> { %input = iree.unfoldable_constant dense<[-1.0, 2.0, -3.0, 4.0]> : tensor<4xf32> %workload = constant 4 : index - %0 = flow.dispatch @ex0::@dispatch0[%workload] (%input) : (tensor<4xf32>) -> tensor<4xf32> + %0 = flow.dispatch @ex0::@dispatch0[%workload](%input) : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } // CHECK: 4xf32=-2 4 -6 8 diff --git a/iree/test/e2e/linalg_tensor_ops/BUILD b/iree/test/e2e/linalg_tensor_ops/BUILD index c580c9785072..df5113f3b8d2 100644 --- a/iree/test/e2e/linalg_tensor_ops/BUILD +++ 
b/iree/test/e2e/linalg_tensor_ops/BUILD @@ -18,6 +18,7 @@ # written using the IREE Check framework and should always pass on the reference VMLA backend. # See https://google.github.io/iree/TestingGuide#iree-core-end-to-end-tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -26,9 +27,17 @@ package( licenses = ["notice"], # Apache 2.0 ) +ALL_SRCS = enforce_glob( + [ + "add.mlir", + "matmul.mlir", + ], + include = ["*.mlir"], +) + iree_check_single_backend_test_suite( name = "check_llvm-ir_llvm", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", "-iree-codegen-llvm-experimental-linalg-on-tensors", @@ -39,7 +48,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", "-iree-codegen-spirv-experimental-linalg-on-tensors", diff --git a/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt b/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt index 642e8e2aa45b..f60e17c17766 100644 --- a/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt +++ b/iree/test/e2e/linalg_tensor_ops/CMakeLists.txt @@ -10,12 +10,12 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_llvm-ir_llvm SRCS - "${_GLOB_X_MLIR}" + "add.mlir" + "matmul.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER @@ -25,12 +25,12 @@ iree_check_single_backend_test_suite( "-iree-codegen-llvm-experimental-linalg-on-tensors" ) -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "add.mlir" + "matmul.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER diff --git a/iree/test/e2e/llvm_specific/BUILD b/iree/test/e2e/llvm_specific/BUILD index dffd9477264c..9c7d633670ab 100644 --- a/iree/test/e2e/llvm_specific/BUILD +++ b/iree/test/e2e/llvm_specific/BUILD @@ -34,15 +34,3 @@ iree_check_single_backend_test_suite( driver = "dylib", target_backend = "dylib-llvm-aot", ) - -iree_check_single_backend_test_suite( - name = "check_llvm-aot-exponential_fast", - srcs = [ - "exponential.mlir", - ], - compiler_flags = [ - "-iree-codegen-linalg-to-llvm-fast-exp=true", - ], - driver = "dylib", - target_backend = "dylib-llvm-aot", -) diff --git a/iree/test/e2e/llvm_specific/CMakeLists.txt b/iree/test/e2e/llvm_specific/CMakeLists.txt index 66272613afab..b7f1add319bc 100644 --- a/iree/test/e2e/llvm_specific/CMakeLists.txt +++ b/iree/test/e2e/llvm_specific/CMakeLists.txt @@ -23,17 +23,4 @@ iree_check_single_backend_test_suite( "-iree-codegen-linalg-to-llvm-conv-img2col-conversion=true" ) -iree_check_single_backend_test_suite( - NAME - check_llvm-aot-exponential_fast - SRCS - "exponential.mlir" - TARGET_BACKEND - "dylib-llvm-aot" - DRIVER - "dylib" - COMPILER_FLAGS - "-iree-codegen-linalg-to-llvm-fast-exp=true" -) - ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/test/e2e/llvm_specific/exponential.mlir b/iree/test/e2e/llvm_specific/exponential.mlir deleted file mode 100644 index 6e9132698a41..000000000000 --- a/iree/test/e2e/llvm_specific/exponential.mlir +++ /dev/null @@ -1,27 +0,0 @@ -func @tensor() attributes { iree.module.export } { - %input = 
iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32> - %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func @scalar() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<1.0> : tensor<f32> - %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32> - check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f32>) : tensor<f32> - return -} - -func @double() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<1.0> : tensor<f64> - %result = "mhlo.exponential"(%input) : (tensor<f64>) -> tensor<f64> - check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f64>) : tensor<f64> - return -} - -func @negative() attributes { iree.module.export } { - %input = iree.unfoldable_constant dense<-1.0> : tensor<f32> - %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32> - check.expect_almost_eq_const(%result, dense<0.367879> : tensor<f32>) : tensor<f32> - return -} diff --git a/iree/test/e2e/models/BUILD b/iree/test/e2e/models/BUILD index 1acc650f87d6..0d3b5fc0c9cd 100644 --- a/iree/test/e2e/models/BUILD +++ b/iree/test/e2e/models/BUILD @@ -14,6 +14,7 @@ # Tests for end-to-end IREE support of entire models or their close derivatives. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//iree:lit_test.bzl", "iree_lit_test_suite") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") @@ -28,8 +29,18 @@ CHECK_FRAMEWORK_TESTS = ["bert_encoder_unrolled_fake_weights.mlir"] iree_lit_test_suite( name = "lit", size = "medium", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "collatz.mlir", + "edge_detection.mlir", + "fragment_000.mlir", + "fullyconnected.mlir", + "mnist_fake_weights.mlir", + "resnet_fake_weights.mlir", + "unidirectional_lstm.mlir", + ], + include = + ["*.mlir"], exclude = CHECK_FRAMEWORK_TESTS, ), data = [ diff --git a/iree/test/e2e/models/CMakeLists.txt b/iree/test/e2e/models/CMakeLists.txt index 03ca925612c5..4fd5d28b364f 100644 --- a/iree/test/e2e/models/CMakeLists.txt +++ b/iree/test/e2e/models/CMakeLists.txt @@ -10,14 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_BERT_ENCODER_UNROLLED_FAKE_WEIGHTS_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS bert_encoder_unrolled_fake_weights.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_BERT_ENCODER_UNROLLED_FAKE_WEIGHTS_MLIR}) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "collatz.mlir" + "edge_detection.mlir" + "fragment_000.mlir" + "fullyconnected.mlir" + "mnist_fake_weights.mlir" + "resnet_fake_weights.mlir" + "unidirectional_lstm.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-run-mlir diff --git a/iree/test/e2e/models/unidirectional_lstm.mlir b/iree/test/e2e/models/unidirectional_lstm.mlir index 4ad0c7935b9e..e72556852eea 100644 --- a/iree/test/e2e/models/unidirectional_lstm.mlir +++ b/iree/test/e2e/models/unidirectional_lstm.mlir @@ -1,8 +1,8 @@ // An example LSTM exported from a python reference model with dummy weights.
-// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]" -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0 1 0 3 4]" -function-input="1x5x2x2xf32=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") +// RUN: iree-run-mlir %s -iree-hal-target-backends=vmla -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]" +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | IreeFileCheck %s --implicit-check-not="[" --implicit-check-not="]") // Exported via the XLA HLO Importer // The resulting MLIR was modified by hand by changing all large constants to be diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD index 8e23925ae6fd..74f00d8755eb 100644 --- a/iree/test/e2e/regression/BUILD +++ b/iree/test/e2e/regression/BUILD @@ -16,6 +16,7 @@ # These should focus on support by IREE itself, not for issues with specific runner tools. Place # those tests in iree/tools/test/ +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//iree:lit_test.bzl", "iree_lit_test_suite") package( @@ -26,8 +27,25 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "dynamic_abs.mlir", + "dynamic_add.mlir", + "dynamic_compare_and_select.mlir", + "dynamic_dot.mlir", + "dynamic_dot_general.mlir", + "dynamic_torch_index_select_high_rank.mlir", + "dynamic_torch_index_select_negative.mlir", + "dynamic_torch_index_select_scalar.mlir", + "dynamic_torch_index_select_vector.mlir", + "executable_benchmark.mlir", + "globals.mlir", + "scalar.mlir", + "trace_dispatch_tensors.mlir", + "unused_args.mlir", + ], + include = + ["*.mlir"], # Disabled temporarily. 
See GH Issue #4733 exclude = [ "dynamic_linalg_matmul_on_tensors.mlir", diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt index dbc8f464baf7..bd998946879f 100644 --- a/iree/test/e2e/regression/CMakeLists.txt +++ b/iree/test/e2e/regression/CMakeLists.txt @@ -10,20 +10,24 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_0_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_0.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_0_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_1_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_1.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_1_MLIR}) -file(GLOB _GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_2_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS dynamic_linalg_matmul_on_tensors_fuse_2.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_DYNAMIC_LINALG_MATMUL_ON_TENSORS_FUSE_2_MLIR}) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "dynamic_abs.mlir" + "dynamic_add.mlir" + "dynamic_compare_and_select.mlir" + "dynamic_dot.mlir" + "dynamic_dot_general.mlir" + "dynamic_torch_index_select_high_rank.mlir" + "dynamic_torch_index_select_negative.mlir" + "dynamic_torch_index_select_scalar.mlir" + "dynamic_torch_index_select_vector.mlir" + "executable_benchmark.mlir" + "globals.mlir" + "scalar.mlir" + "trace_dispatch_tensors.mlir" + "unused_args.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-benchmark-module diff --git a/iree/test/e2e/structural/BUILD b/iree/test/e2e/structural/BUILD index 9ffecf240ada..09585a71f698 100644 --- a/iree/test/e2e/structural/BUILD +++ b/iree/test/e2e/structural/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -23,7 +24,15 @@ package( # TODO(#2395): Enable all the tests for both LLVM and SPIR-V. 
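For suites that deliberately leave files out — the check-framework model test above, or the temporarily disabled regression tests — the exclusions are forwarded to the underlying glob, so the explicit list is checked against exactly the files that remain. A sketch with hypothetical file names:

DISABLED_TESTS = ["flaky_case.mlir"]

iree_lit_test_suite(
    name = "lit",
    srcs = enforce_glob(
        # flaky_case.mlir is excluded below and so must not be listed here.
        ["working_case.mlir"],
        include = ["*.mlir"],
        exclude = DISABLED_TESTS,
    ),
    data = ["//iree/tools:IreeFileCheck"],
)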
iree_check_single_backend_test_suite( name = "check_vmla_vmla", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "fused_dispatch_region.mlir", + "gather_add.mlir", + "gather_concat.mlir", + "matmul_add.mlir", + ], + include = ["*.mlir"], + ), driver = "vmla", target_backend = "vmla", ) diff --git a/iree/test/e2e/structural/CMakeLists.txt b/iree/test/e2e/structural/CMakeLists.txt index d8be71771d13..37a2f9042a52 100644 --- a/iree/test/e2e/structural/CMakeLists.txt +++ b/iree/test/e2e/structural/CMakeLists.txt @@ -10,12 +10,14 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vmla_vmla SRCS - "${_GLOB_X_MLIR}" + "fused_dispatch_region.mlir" + "gather_add.mlir" + "gather_concat.mlir" + "matmul_add.mlir" TARGET_BACKEND "vmla" DRIVER diff --git a/iree/test/e2e/tosa_ops/BUILD b/iree/test/e2e/tosa_ops/BUILD index 5eecc23ad5c0..cb3a05392252 100644 --- a/iree/test/e2e/tosa_ops/BUILD +++ b/iree/test/e2e/tosa_ops/BUILD @@ -18,6 +18,7 @@ # written using the IREE Check framework. # See https://google.github.io/iree/developing-iree/testing-guide#iree-core-end-to-end-tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -26,16 +27,50 @@ package( licenses = ["notice"], # Apache 2.0 ) +ALL_SRCS = enforce_glob( + [ + "abs.mlir", + "add.mlir", + "bitwise_and.mlir", + "bitwise_or.mlir", + "bitwise_xor.mlir", + "ceil.mlir", + "clamp.mlir", + "const.mlir", + "exp.mlir", + "floor.mlir", + "greater.mlir", + "greater_equal.mlir", + "if.mlir", + "log.mlir", + "logical_left_shift.mlir", + "logical_right_shift.mlir", + "maximum.mlir", + "minimum.mlir", + "mul.mlir", + "negate.mlir", + "reluN.mlir", + "reshape.mlir", + "rsqrt.mlir", + "select.mlir", + "sub.mlir", + "tanh.mlir", + "transpose.mlir", + "while.mlir", + ], + include = ["*.mlir"], +) + iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, driver = "vulkan", target_backend = "vulkan-spirv", ) iree_check_single_backend_test_suite( name = "check_dylib-llvm-aot_dylib", - srcs = glob(["*.mlir"]), + srcs = ALL_SRCS, driver = "dylib", target_backend = "dylib-llvm-aot", ) diff --git a/iree/test/e2e/tosa_ops/CMakeLists.txt b/iree/test/e2e/tosa_ops/CMakeLists.txt index 71543d37f47e..3948a95a51dd 100644 --- a/iree/test/e2e/tosa_ops/CMakeLists.txt +++ b/iree/test/e2e/tosa_ops/CMakeLists.txt @@ -10,24 +10,76 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "abs.mlir" + "add.mlir" + "bitwise_and.mlir" + "bitwise_or.mlir" + "bitwise_xor.mlir" + "ceil.mlir" + "clamp.mlir" + "const.mlir" + "exp.mlir" + "floor.mlir" + "greater.mlir" + "greater_equal.mlir" + "if.mlir" + "log.mlir" + "logical_left_shift.mlir" + "logical_right_shift.mlir" + "maximum.mlir" + "minimum.mlir" + "mul.mlir" + "negate.mlir" + "reluN.mlir" + "reshape.mlir" + "rsqrt.mlir" + "select.mlir" + "sub.mlir" + "tanh.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER "vulkan" ) -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_dylib-llvm-aot_dylib SRCS - "${_GLOB_X_MLIR}" + 
"abs.mlir" + "add.mlir" + "bitwise_and.mlir" + "bitwise_or.mlir" + "bitwise_xor.mlir" + "ceil.mlir" + "clamp.mlir" + "const.mlir" + "exp.mlir" + "floor.mlir" + "greater.mlir" + "greater_equal.mlir" + "if.mlir" + "log.mlir" + "logical_left_shift.mlir" + "logical_right_shift.mlir" + "maximum.mlir" + "minimum.mlir" + "mul.mlir" + "negate.mlir" + "reluN.mlir" + "reshape.mlir" + "rsqrt.mlir" + "select.mlir" + "sub.mlir" + "tanh.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD index b3408042ebf1..876cc74c4727 100644 --- a/iree/test/e2e/vulkan_specific/BUILD +++ b/iree/test/e2e/vulkan_specific/BUILD @@ -15,6 +15,7 @@ # Tests for end-to-end IREE support specific to the vulkan-spirv lowering. # TODO(ravishankarm): Reorganize these tests. +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") package( @@ -25,8 +26,17 @@ package( iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", - srcs = glob( - ["*.mlir"], + srcs = enforce_glob( + [ + "compare.mlir", + "conv.mlir", + "dot_general.mlir", + "log_plus_one.mlir", + "pw_add_multiwg.mlir", + "reduce.mlir", + "vectorized_conv.mlir", + ], + include = ["*.mlir"], exclude = [ "gemm.mlir", ], @@ -64,6 +74,8 @@ iree_check_single_backend_test_suite( "vectorized_conv.mlir", ], compiler_flags = [ + "-iree-flow-dispatch-linalg-on-tensors", + "-iree-codegen-spirv-experimental-linalg-on-tensors", "-iree-spirv-enable-vectorization", "-iree-vulkan-target-triple=valhall-g77-unknown-android10", ], diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt index 53d6be831941..c5ec28c9189e 100644 --- a/iree/test/e2e/vulkan_specific/CMakeLists.txt +++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt @@ -10,14 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) -file(GLOB _GLOB_GEMM_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS gemm.mlir) -list(REMOVE_ITEM _GLOB_X_MLIR ${_GLOB_GEMM_MLIR}) iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS - "${_GLOB_X_MLIR}" + "compare.mlir" + "conv.mlir" + "dot_general.mlir" + "log_plus_one.mlir" + "pw_add_multiwg.mlir" + "reduce.mlir" + "vectorized_conv.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER @@ -63,6 +66,8 @@ iree_check_single_backend_test_suite( DRIVER "vulkan" COMPILER_FLAGS + "-iree-flow-dispatch-linalg-on-tensors" + "-iree-codegen-spirv-experimental-linalg-on-tensors" "-iree-spirv-enable-vectorization" "-iree-vulkan-target-triple=valhall-g77-unknown-android10" ) diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD index 4967da1ef81b..377b40c5a9e0 100644 --- a/iree/test/e2e/xla_ops/BUILD +++ b/iree/test/e2e/xla_ops/BUILD @@ -19,6 +19,7 @@ # See https://google.github.io/iree/TestingGuide#iree-core-end-to-end-tests. 
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -28,7 +29,57 @@ package( iree_check_single_backend_test_suite( name = "check_vmla_vmla", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "abs.mlir", + "add.mlir", + "batch_norm_inference.mlir", + "broadcast.mlir", + "broadcast_add.mlir", + "broadcast_in_dim.mlir", + "clamp.mlir", + "compare.mlir", + "concatenate.mlir", + "constant.mlir", + "convert.mlir", + "convolution.mlir", + "cosine.mlir", + "divide.mlir", + "dot.mlir", + "dot_general.mlir", + "exponential.mlir", + "exponential_minus_one.mlir", + "finite.mlir", + "floor.mlir", + "gather.mlir", + "iota.mlir", + "log.mlir", + "log_plus_one.mlir", + "maximum.mlir", + "minimum.mlir", + "multiply.mlir", + "negate.mlir", + "pad.mlir", + "reduce.mlir", + "reduce_window.mlir", + "remainder.mlir", + "reshape.mlir", + "reverse.mlir", + "round.mlir", + "rsqrt.mlir", + "select.mlir", + "sine.mlir", + "slice.mlir", + "sort.mlir", + "sqrt.mlir", + "subtract.mlir", + "tanh.mlir", + "torch_index_select.mlir", + "transpose.mlir", + "while.mlir", + ], + include = ["*.mlir"], + ), driver = "vmla", target_backend = "vmla", ) @@ -185,7 +236,7 @@ iree_check_single_backend_test_suite( # https://github.com/google/iree/issues/4079 # "torch_index_select.mlir", "transpose.mlir", - # "while.mlir", + "while.mlir", ], compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", @@ -246,7 +297,7 @@ iree_check_single_backend_test_suite( # https://github.com/google/iree/issues/4079 # "torch_index_select.mlir", "transpose.mlir", - # "while.mlir", + "while.mlir", ], compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt index 90ed6bd88079..b5bdab60e4cb 100644 --- a/iree/test/e2e/xla_ops/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/CMakeLists.txt @@ -10,12 +10,56 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_check_single_backend_test_suite( NAME check_vmla_vmla SRCS - "${_GLOB_X_MLIR}" + "abs.mlir" + "add.mlir" + "batch_norm_inference.mlir" + "broadcast.mlir" + "broadcast_add.mlir" + "broadcast_in_dim.mlir" + "clamp.mlir" + "compare.mlir" + "concatenate.mlir" + "constant.mlir" + "convert.mlir" + "convolution.mlir" + "cosine.mlir" + "divide.mlir" + "dot.mlir" + "dot_general.mlir" + "exponential.mlir" + "exponential_minus_one.mlir" + "finite.mlir" + "floor.mlir" + "gather.mlir" + "iota.mlir" + "log.mlir" + "log_plus_one.mlir" + "maximum.mlir" + "minimum.mlir" + "multiply.mlir" + "negate.mlir" + "pad.mlir" + "reduce.mlir" + "reduce_window.mlir" + "remainder.mlir" + "reshape.mlir" + "reverse.mlir" + "round.mlir" + "rsqrt.mlir" + "select.mlir" + "sine.mlir" + "slice.mlir" + "sort.mlir" + "sqrt.mlir" + "subtract.mlir" + "tanh.mlir" + "torch_index_select.mlir" + "transpose.mlir" + "while.mlir" TARGET_BACKEND "vmla" DRIVER @@ -166,6 +210,7 @@ iree_check_single_backend_test_suite( "subtract.mlir" "tanh.mlir" "transpose.mlir" + "while.mlir" TARGET_BACKEND "dylib-llvm-aot" DRIVER @@ -213,6 +258,7 @@ iree_check_single_backend_test_suite( "subtract.mlir" "tanh.mlir" "transpose.mlir" + "while.mlir" TARGET_BACKEND "vulkan-spirv" DRIVER diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD index 0978d1c7b113..bcabe82a6d9b 100644 --- a/iree/tools/test/BUILD +++ b/iree/tools/test/BUILD @@ 
-15,6 +15,7 @@ # Smoke tests for the execution of tool binaries. load("//iree:lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") package( default_visibility = ["//visibility:public"], @@ -24,7 +25,18 @@ package( iree_lit_test_suite( name = "lit", - srcs = glob(["*.mlir"]), + srcs = enforce_glob( + [ + "iree-benchmark-module.mlir", + "iree-run-mlir.mlir", + "iree-run-module.mlir", + "multiple_args.mlir", + "multiple_exported_functions.mlir", + "repeated_return.mlir", + "scalars.mlir", + ], + include = ["*.mlir"], + ), data = [ "//iree/tools:IreeFileCheck", "//iree/tools:iree-benchmark-module", diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt index 1342a45f7b87..bf6581f0422c 100644 --- a/iree/tools/test/CMakeLists.txt +++ b/iree/tools/test/CMakeLists.txt @@ -10,12 +10,17 @@ iree_add_all_subdirs() -file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir) iree_lit_test_suite( NAME lit SRCS - "${_GLOB_X_MLIR}" + "iree-benchmark-module.mlir" + "iree-run-mlir.mlir" + "iree-run-module.mlir" + "multiple_args.mlir" + "multiple_exported_functions.mlir" + "repeated_return.mlir" + "scalars.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-benchmark-module diff --git a/iree/vm/bytecode_dispatch_test.cc b/iree/vm/bytecode_dispatch_test.cc index bbb6fdf14d8f..b371bb16e2b0 100644 --- a/iree/vm/bytecode_dispatch_test.cc +++ b/iree/vm/bytecode_dispatch_test.cc @@ -131,6 +131,7 @@ TEST_P(VMBytecodeDispatchTest, Check) { } } else { if (expect_failure) { + iree_status_ignore(status); GTEST_SUCCEED(); } else { GTEST_FAIL() << "Function expected success but failed with error: "