diff --git a/iree/compiler/Conversion/Common/LaunchConfig.cpp b/iree/compiler/Conversion/Common/LaunchConfig.cpp
index 88b99f533d2f..246257494f15 100644
--- a/iree/compiler/Conversion/Common/LaunchConfig.cpp
+++ b/iree/compiler/Conversion/Common/LaunchConfig.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
@@ -38,6 +39,7 @@ namespace iree_compiler {
 /// Name of the StrAttr that can be used to get the key to access the tile size
 /// information.
 static const char kLaunchInfoKey[] = "launch_info_key";
+static const char kRootOpKey[] = "is_root_op";
 
 static Optional<StringRef> getKey(Operation *op) {
   StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
@@ -82,8 +84,7 @@ ArrayRef<int64_t> LaunchConfig::getTileSizes(Operation *op,
 
 Operation *LaunchConfig::getRootOperation(ArrayRef<Operation *> ops) {
   for (auto op : ops) {
-    auto key = getKey(op);
-    if (key && key.getValue() == rootOperationKey) return op;
+    if (op->getAttrOfType<UnitAttr>(kRootOpKey)) return op;
   }
   return nullptr;
 }
@@ -114,8 +115,7 @@ void LaunchConfig::setNumSubgroups(ArrayRef<int64_t> vNumSubgroups) {
 }
 
 void LaunchConfig::setRootOperation(Operation *op) {
-  Optional<StringRef> key = getKey(op);
-  if (key) rootOperationKey = *key;
+  op->setAttr(kRootOpKey, UnitAttr::get(op->getContext()));
 }
 
 void LaunchConfig::setSameConfig(Operation *source, Operation *target) {
diff --git a/iree/compiler/Conversion/Common/LaunchConfig.h b/iree/compiler/Conversion/Common/LaunchConfig.h
index 7f6f3366c788..1367f6054f05 100644
--- a/iree/compiler/Conversion/Common/LaunchConfig.h
+++ b/iree/compiler/Conversion/Common/LaunchConfig.h
@@ -127,10 +127,6 @@ class LaunchConfig {
   /// these attributes.
   llvm::StringMap<TileSizesListType> tileSizes;
 
-  /// Key used for tagging the root operation. The launch config does not track
-  /// the root operation itself, but rather the key used for the root operation.
-  StringRef rootOperationKey = "";
-
   /// Workgroup size to use.
std::array workgroupSize = {1, 1, 1}; diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir index 4868d059698f..a430f6334824 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir @@ -84,4 +84,5 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} { // CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]] // CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]> // CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]> -// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>) +// CHECK: linalg.matmul +//CHECK-SAME: ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD index b224b665d763..2aab244cd670 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD +++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD @@ -34,11 +34,13 @@ cc_library( cc_library( name = "LinalgToSPIRV", srcs = [ + "ConcretizeTileAmongWorkgroupsPass.cpp", "ConvertToGPUPass.cpp", "ConvertToSPIRVPass.cpp", "CooperativeMatrixAnalysis.cpp", "FoldGPUProcessorIDUses.cpp", "KernelDispatchUtils.cpp", + "LinalgTileAndDistributePass.cpp", "LinalgTileAndFusePass.cpp", "MatMulVectorizationTest.cpp", "Passes.cpp", @@ -61,6 +63,7 @@ cc_library( "//iree/compiler/Conversion/HLOToHLO", "//iree/compiler/Conversion/HLOToLinalg", "//iree/compiler/Conversion/LinalgToVector", + "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/HAL/IR", "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/IREE/IR", diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt index e51da222e2eb..7046ccca4ae4 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt +++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt @@ -32,11 +32,13 @@ iree_cc_library( "Passes.h" "Utils.h" SRCS + "ConcretizeTileAmongWorkgroupsPass.cpp" "ConvertToGPUPass.cpp" "ConvertToSPIRVPass.cpp" "CooperativeMatrixAnalysis.cpp" "FoldGPUProcessorIDUses.cpp" "KernelDispatchUtils.cpp" + "LinalgTileAndDistributePass.cpp" "LinalgTileAndFusePass.cpp" "MatMulVectorizationTest.cpp" "Passes.cpp" @@ -74,6 +76,7 @@ iree_cc_library( iree::compiler::Conversion::HLOToHLO iree::compiler::Conversion::HLOToLinalg iree::compiler::Conversion::LinalgToVector + iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::IREE::IR diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp new file mode 100644 index 000000000000..6973241ddae7 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -0,0 +1,565 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- ConcretizeTileAmongWorkgroupsPass.cpp ------------------------------===//
+//
+// This pass concretizes hal.interface.workgroup ops by replacing them with
+// constant values from the chosen tiling and distribution scheme.
+//
+// During dispatch region formation in IREE Flow transformations, ops are tiled
+// and distributed in an abstract way by using symbolic hal.interface.workgroup
+// ops. That is because the same source region is compiled towards different
+// target backends and each target backend could use different tiling and
+// distribution schemes. However, after HAL interface materialization, the
+// hal.executable.target is just meant for one target backend. We need to
+// concretize the tiling and distribution in order to inject static information
+// for further compilation.
+//
+// This pass performs the concretization in two modes:
+//
+// 1) Partially static: where we have a concrete tiling and distribution scheme
+//    *but not* a fully static original problem size (e.g., due to dynamic
+//    shapes). Under such circumstances, we can only replace ops like
+//    hal.interface.workgroup.size and still need to compute the number
+//    of workgroups using symbolic values.
+// 2) Fully static: where we have a concrete tiling and distribution scheme
+//    *and* the fully static original problem size. Under such circumstances,
+//    we can fully deduce the number of workgroups to dispatch and replace
+//    hal.interface.workgroup.count ops with constant values too.
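+//
+// As a small, hypothetical illustration (numbers made up for exposition):
+// with a 128x128 workload tiled by 16x4, the fully static mode rewrites
+//
+//   %size_x = hal.interface.workgroup.size[0] : index
+//   %count_x = hal.interface.workgroup.count[0] : index
+//
+// into
+//
+//   %size_x = constant 16 : index
+//   %count_x = constant 8 : index  // ceildiv(128, 16)
+//
+// whereas the partially static mode can only rewrite %size_x this way.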
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-concretize-tile-among-workgroups"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+constexpr unsigned kWorkgroupDimCount = 3;
+
+int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
+
+static size_t getNumOuterParallelDims(linalg::LinalgOp op) {
+  ArrayRef<Attribute> iterators = op.iterator_types().getValue();
+  auto parallels = iterators.take_while(
+      [](Attribute attr) { return linalg::isParallelIteratorType(attr); });
+  return parallels.size();
+}
+
+/// Returns the root Linalg op that dictates tiling and distribution policy.
+linalg::LinalgOp getRootLinalgOp(FuncOp funcOp) {
+  SmallVector<linalg::LinalgOp, 4> linalgOps;
+  SmallVector<Operation *, 4> tiledLoops;
+  if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) return {};
+
+  SPIRVCodegenOptions options;
+  options.enableVectorization = true;
+  options.usingLinalgOnTensors = true;
+  linalg::Aliases aliases;
+  linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+  Optional<LaunchConfig> launchConfigOpt = initGPULaunchConfig(
+      funcOp.getContext(), dependenceGraph, options, linalgOps);
+  if (!launchConfigOpt) return {};
+
+  LaunchConfig &launchConfig = *launchConfigOpt;
+  Operation *rootOp =
+      launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range(
+          linalgOps, [](linalg::LinalgOp op) { return op.getOperation(); })));
+
+  // Clean up internal markers that are set during launch configuration
+  // preparation.
+  launchConfig.finalize(funcOp);
+
+  return rootOp;
+}
+
+/// Assuming the given `rootOp` is the tiled root Linalg op, returns the
+/// original input/output types for all tiles.
+///
+/// Note: After the abstract tiling and distribution in Flow dispatch region
+/// creation, the anchoring root op is already in a loop nest and works on a
+/// tile. The full type for all tiles in the IR is not explicit anymore after
+/// the HAL interface is materialized. So we go through the IR use chain to
+/// figure it out. Otherwise we would need to make even more assumptions in
+/// the following.
+// TODO(antiagainst): This is quite fragile. We need a better way to pass the
+// information down from the upper layer, which readily has it. Probably via
+// linalg.tile op.
+LogicalResult getInputOutputTypesForAllTiles(
+    linalg::LinalgOp rootOp, SmallVectorImpl<Type> &inputTypes,
+    SmallVectorImpl<Type> &outputTypes) {
+  for (Value inputBuffer : rootOp.getInputBuffers()) {
+    auto subviewOp = inputBuffer.getDefiningOp<SubViewOp>();
+    if (!subviewOp) return failure();
+    inputTypes.push_back(subviewOp.getViewSource().getType());
+  }
+
+  for (Value outputBuffer : rootOp.getOutputBuffers()) {
+    auto subviewOp = outputBuffer.getDefiningOp<SubViewOp>();
+    if (!subviewOp) return failure();
+    outputTypes.push_back(subviewOp.getViewSource().getType());
+  }
+
+  return success();
+}
+
+/// Assuming the given `rootOp` is the tiled root Linalg op, returns the
+/// tile sizes for distributing to workgroups and the workgroup size for the
+/// generated kernel.
+///
+/// TODO(antiagainst): This pass can be shared between CPU and GPU. But the
+/// following query scopes it to GPU for now.
+llvm::Optional<std::pair<ArrayRef<int64_t>, ArrayRef<int64_t>>>
+getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef<Type> inputTypes,
+                            ArrayRef<Type> outputTypes) {
+  // Build necessary structures to query the tile sizes for distributing to
+  // workgroups.
+  linalg::Aliases aliases;
+  SmallVector<linalg::LinalgOp, 4> linalgOps;
+  auto ops = rootOp->getBlock()->getOps<linalg::LinalgOp>();
+  linalgOps.assign(ops.begin(), ops.end());
+  linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+  SPIRVCodegenOptions options;
+
+  // NOTE: Launch configuration expects the original input/output types to
+  // decide the configuration. But we have already tiled the Linalg ops here.
+  // Use attributes to send them over for now.
+  const char inputTypeAttrName[] = "iree.codegen.original_input_types";
+  const char outputTypeAttrName[] = "iree.codegen.original_output_types";
+  if (!inputTypes.empty()) {
+    rootOp->setAttr(inputTypeAttrName,
+                    Builder(rootOp).getTypeArrayAttr(inputTypes));
+  }
+  if (!outputTypes.empty()) {
+    rootOp->setAttr(outputTypeAttrName,
+                    Builder(rootOp).getTypeArrayAttr(outputTypes));
+  }
+
+  Optional<LaunchConfig> launchConfig = initGPULaunchConfig(
+      rootOp->getContext(), dependenceGraph, options, linalgOps);
+  if (!launchConfig) {
+    rootOp->emitError("unable to find launch configuration");
+    return llvm::None;
+  }
+
+  ArrayRef<int64_t> tileSize = launchConfig->getTileSizes(rootOp, 0);
+  ArrayRef<int64_t> workgroupSize = launchConfig->getWorkgroupSize();
+
+  // Clean up internal markers that are set during launch configuration
+  // preparation.
+  launchConfig->finalize(rootOp->getParentOfType<FuncOp>());
+
+  return std::make_pair(tileSize, workgroupSize);
+}
+
+/// Replaces hal.interface.workgroup.size op with the constant value chosen
+/// from the tiling scheme.
+class ConcretizeWorkgroupSizeOp final
+    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
+ public:
+  ConcretizeWorkgroupSizeOp(MLIRContext *context,
+                            SmallVector<int64_t, 4> workloadSize,
+                            SmallVector<int64_t, 4> tileSize,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(std::move(workloadSize)),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    unsigned dimIndex = op.dimension().getZExtValue();
+
+    if (dimIndex < kWorkgroupDimCount && tileSize[dimIndex] != 0) {
+      rewriter.replaceOpWithNewOp<ConstantOp>(
+          op, rewriter.getIndexAttr(tileSize[dimIndex]));
+      return success();
+    }
+
+    return failure();
+  }
+
+ private:
+  SmallVector<int64_t, 4> workloadSize;
+  SmallVector<int64_t, 4> tileSize;
+};
+
+/// Replaces hal.interface.workgroup.count op with the constant value chosen
+/// from the tiling scheme.
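+///
+/// For instance (hypothetical numbers), with a workload of 128 and a tile
+/// size of 16 along dimension 0, `hal.interface.workgroup.count[0]` becomes
+/// `constant 8 : index`, i.e., ceildiv(128, 16).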
+class ConcretizeWorkgroupCountOp final
+    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupCountOp> {
+ public:
+  ConcretizeWorkgroupCountOp(MLIRContext *context,
+                             SmallVector<int64_t, 4> workloadSize,
+                             SmallVector<int64_t, 4> tileSize,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(workloadSize),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op,
+                                PatternRewriter &rewriter) const override {
+    unsigned dimIndex = op.dimension().getZExtValue();
+
+    if (dimIndex >= kWorkgroupDimCount) return failure();
+
+    int64_t dimSize = workloadSize[dimIndex];
+    int64_t dimTile = tileSize[dimIndex];
+
+    if (dimSize == ShapedType::kDynamicSize || dimTile == 0) return failure();
+
+    int64_t count = ceilDiv(dimSize, dimTile);
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getIndexAttr(count));
+
+    return success();
+  }
+
+ private:
+  SmallVector<int64_t, 4> workloadSize;
+  SmallVector<int64_t, 4> tileSize;
+};
+
+// Canonicalizes away a trip-one scf.for loop by inlining its body and removing
+// the loop.
+//
+// This pattern is needed because in Flow abstract tiling and distribution we
+// will create scf.for loops that distribute workload cyclically. After
+// concretizing hal.interface.workgroup.* ops, these scf.for loops still
+// remain, and they will be of the form:
+//
+//   %lb = mul %workgroup_id_{x|y|z}, %cst_tile_size_{x|y|z}
+//   scf.for %iv = %lb to %cst_workload_size_{x|y|z}
+//                 step %cst_workload_size_{x|y|z} { ... }
+//
+// Such scf.for loops can be inlined if %lb is smaller than the upper bound.
+class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
+ public:
+  RemoveTripOneLoop(MLIRContext *context, SmallVector<int64_t, 4> workloadSize,
+                    SmallVector<int64_t, 4> tileSize,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit),
+        workloadSize(workloadSize),
+        tileSize(std::move(tileSize)) {}
+
+  LogicalResult matchAndRewrite(scf::ForOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get constant upper bound and step values.
+    IntegerAttr ub, step;
+    if (!matchPattern(op.upperBound(), m_Constant(&ub)) ||
+        !matchPattern(op.step(), m_Constant(&step))) {
+      return failure();
+    }
+
+    // Require that they are the same.
+    if (ub != step) return failure();
+
+    // Now make sure the lower bound is smaller than the upper bound. The
+    // lower bound should be the workgroup ID multiplied by some constant.
+
+    auto mulOp = op.lowerBound().getDefiningOp<AffineApplyOp>();
+    if (!mulOp || mulOp.mapOperands().size() != 2) return failure();
+
+    AffineExpr lhs, rhs;
+    bindSymbols(op.getContext(), lhs, rhs);
+    auto mulMap = AffineMap::get(0, 2, lhs * rhs);
+    if (mulOp.getAffineMap() != mulMap) return failure();
+
+    auto mulLhs = mulOp.mapOperands().front();
+    auto mulRhs = mulOp.mapOperands().back();
+
+    auto idOp = mulLhs.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>();
+    IntegerAttr multiplier;
+    if (!idOp || !matchPattern(mulRhs, m_Constant(&multiplier)))
+      return failure();
+
+    // We just need to make sure the max value of the workgroup ID multiplied
+    // by the multiplier is smaller than the upper bound to guarantee one trip.
+    unsigned dimIndex = idOp.dimension().getZExtValue();
+    int64_t dimSize = workloadSize[dimIndex];
+    int64_t dimTile = tileSize[dimIndex];
+
+    if (dimSize == ShapedType::kDynamicSize) return failure();
+
+    int64_t count = ceilDiv(dimSize, dimTile);
+    assert(count > 0 && "expected at least one tile!");
+
+    // ID should be in range [0, count).
+    if ((count - 1) * multiplier.getInt() >= ub.getInt()) {
+      // Dead loop. It can actually be removed entirely. But we aren't
+      // expecting it to happen here.
+      // Do not canonicalize for such cases.
+      return failure();
+    }
+
+    SmallVector<Value, 4> blockArgs;
+    blockArgs.reserve(op.getNumIterOperands() + 1);
+    blockArgs.push_back(op.lowerBound());
+    llvm::append_range(blockArgs, op.getIterOperands());
+
+    Block *block = &op.getLoopBody().front();
+    Operation *terminator = block->getTerminator();
+    ValueRange results = terminator->getOperands();
+    rewriter.mergeBlockBefore(block, op, blockArgs);
+    rewriter.replaceOp(op, results);
+    rewriter.eraseOp(terminator);
+
+    return success();
+  }
+
+ private:
+  SmallVector<int64_t, 4> workloadSize;
+  SmallVector<int64_t, 4> tileSize;
+};
+
+/// Concretizes hal.interface.workgroup.* ops with constants from the chosen
+/// tiling scheme when possible and performs loop canonicalization afterwards.
+class ConcretizeTileAmongWorkgroupsPass
+    : public PassWrapper<ConcretizeTileAmongWorkgroupsPass,
+                         OperationPass<IREE::HAL::ExecutableTargetOp>> {
+ public:
+  ConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options)
+      : options(options) {}
+  ConcretizeTileAmongWorkgroupsPass(
+      const ConcretizeTileAmongWorkgroupsPass &that)
+      : options(that.options) {
+    inlineTripOneLoops = that.inlineTripOneLoops;
+  }
+
+  void runOnOperation() override {
+    IREE::HAL::ExecutableTargetOp targetOp = getOperation();
+    ModuleOp module = targetOp.getInnerModule();
+
+    for (FuncOp funcOp : module.getOps<FuncOp>()) {
+      if (!funcOp.isPublic()) continue;
+      if (failed(runOnFunction(funcOp))) return signalPassFailure();
+    }
+  }
+
+ private:
+  LogicalResult runOnFunction(FuncOp funcOp) {
+    MLIRContext &context = getContext();
+
+    // 1. Get the root op first. We need it to figure out the original problem
+    // size, which then affects the tiling and distribution policy.
+
+    linalg::LinalgOp rootOp = getRootLinalgOp(funcOp);
+    if (!rootOp) {
+      LLVM_DEBUG(llvm::dbgs() << "unable to find root Linalg op\n");
+      // This can happen for ops that are not abstractly tiled during dispatch
+      // region formation. So don't trigger pass failure.
+      return success();
+    }
+    LLVM_DEBUG(llvm::dbgs() << "Root op: " << rootOp << "\n");
+
+    size_t numTilableDims = getNumOuterParallelDims(rootOp);
+
+    // 2. Figure out the original problem size.
+
+    SmallVector<Type, 4> inputTypes, outputTypes;
+    SmallVector<int64_t, 4> workloadSize;
+    if (succeeded(
+            getInputOutputTypesForAllTiles(rootOp, inputTypes, outputTypes))) {
+      if (outputTypes.size() != 1) {
+        return rootOp.emitError("only support ops with one result right now");
+      }
+
+      // Flow/HAL processor id/size/count ops' indices follow the reverse
+      // order of the shape dimensions.
+      workloadSize = llvm::to_vector<4>(llvm::reverse(
+          outputTypes.front().cast<ShapedType>().getShape().take_front(
+              numTilableDims)));
+    } else {
+      // This can happen for dynamic shapes.
+      LLVM_DEBUG(llvm::dbgs()
+                 << "unable to find input/output type for all tiles");
+
+      inputTypes.clear();
+      outputTypes.clear();
+
+      workloadSize.assign(numTilableDims, ShapedType::kDynamicSize);
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Queried workload size: ";
+      llvm::interleaveComma(workloadSize, llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    // 3. Query the scheme for tiling among workgroups.
+
+    SmallVector<int64_t, 4> tileSize;
+    SmallVector<int64_t, 4> workgroupSize;
+
+    // Try to use the configuration from the command line first for testing.
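+    // (For example, the concretize_tile_among_workgroups.mlir test below
+    // drives this with -iree-spirv-tile-size=16,4,4
+    // -iree-spirv-workgroup-size=4,4,1.)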
+    tileSize.assign(options.tileSizes.begin(), options.tileSizes.end());
+    tileSize.resize(numTilableDims, 0);
+    workgroupSize.assign(options.workgroupSize.begin(),
+                         options.workgroupSize.end());
+    if (tileSize.empty() || workgroupSize.empty()) {
+      auto sizes =
+          getTileSizeAndWorkgroupSize(rootOp, inputTypes, outputTypes);
+      if (sizes) {
+        // The tile sizes are specified against the original dimension order
+        // of the workload shape. But Flow/HAL processor id/size/count ops are
+        // created using the reverse order.
+        tileSize = llvm::to_vector<4>(
+            llvm::reverse(sizes->first.take_front(numTilableDims)));
+        workgroupSize = llvm::to_vector<4>(sizes->second);
+      } else {
+        return funcOp.emitError("failed to query tile size and workgroup size");
+      }
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Queried tile size: ";
+      llvm::interleaveComma(tileSize, llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+
+    // 4. Replace hal.interface.workgroup symbolic ops with constant values.
+
+    {
+      OwningRewritePatternList patterns;
+      patterns.insert<ConcretizeWorkgroupSizeOp, ConcretizeWorkgroupCountOp>(
+          &context, workloadSize, tileSize);
+
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "--- After concretizing hal.interface.workgroup ops ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+
+    // 5. Set the entry point region for computing the number of workgroups
+    // to dispatch. The region has symbolic arguments representing the
+    // workload. So there are two modes here (see the comments at the
+    // beginning of this file).
+
+    {
+      SmallVector<int64_t, 3> numWorkgroups;
+      for (auto pair : llvm::zip(workloadSize, tileSize)) {
+        auto workload = std::get<0>(pair);
+        auto tile = std::get<1>(pair);
+        if (workload == ShapedType::kDynamicSize || tile == 0) {
+          numWorkgroups.push_back(ShapedType::kDynamicSize);
+        } else {
+          numWorkgroups.push_back(ceilDiv(workload, tile));
+        }
+      }
+
+      numWorkgroups.resize(kWorkgroupDimCount, 1);
+
+      // If all dimensions are known constant, then we can set the number of
+      // workgroups directly. Otherwise, we need to generate the IR for
+      // computing it using symbolic values.
+      if (llvm::none_of(numWorkgroups, [](int64_t dim) {
+            return dim == ShapedType::kDynamicSize;
+          })) {
+        OpBuilder builder(&context);
+        WorkgroupCountRegionBuilder regionBuilder =
+            [&](OpBuilder &builder, Location loc, std::array<Value, 3>) {
+              std::array<Value, 3> returnValues;
+              for (unsigned i = 0; i < kWorkgroupDimCount; ++i) {
+                returnValues[i] =
+                    builder.create<ConstantIndexOp>(loc, numWorkgroups[i]);
+              }
+              return returnValues;
+            };
+        if (failed(
+                defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
+          return funcOp.emitError(
+              "failed to set entry point region for number of workgroups");
+        }
+      } else {
+        if (failed(materializeStaticLaunchInformation(funcOp, tileSize))) {
+          return funcOp.emitOpError(
+              "failed to materialize static launch information");
+        }
+      }
+    }
+
+    if (failed(updateWorkGroupSize(funcOp, workgroupSize))) {
+      return funcOp.emitOpError("failed to set workgroup size on function");
+    }
+
+    // 6. Canonicalization and clean up.
+
+    if (inlineTripOneLoops) {
+      OwningRewritePatternList patterns;
+      patterns.insert<RemoveTripOneLoop>(&context, workloadSize, tileSize);
+
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    }
+
+    return success();
+  }
+
+ private:
+  SPIRVCodegenOptions options;
+
+  // TODO(#5034): Investigate whether there is a better way to prove
+  // tileability and canonicalize affine.min ops, without matching against
+  // the specific pattern involving loops.
+  Option<bool> inlineTripOneLoops{
+      *this, "inline-trip-one-loops",
+      llvm::cl::desc(
+          "Inline a loop's body if it can be proven to just have one trip"),
+      llvm::cl::init(true)};
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options) {
+  return std::make_unique<ConcretizeTileAmongWorkgroupsPass>(options);
+}
+
+static PassRegistration<ConcretizeTileAmongWorkgroupsPass> pass(
+    "iree-spirv-concretize-tile-among-workgroups",
+    "Replace hal.interface.workgroup.* ops with constant values from chosen "
+    "tiling and distribution scheme",
+    [] {
+      SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions();
+      return std::make_unique<ConcretizeTileAmongWorkgroupsPass>(options);
+    });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
index d43044f52e40..323832d9e707 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
@@ -20,6 +20,7 @@
 
 #include "iree/compiler/Conversion/Common/Attributes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -89,26 +90,34 @@ int dimensionToIndex(StringRef dimension) {
   return StringSwitch<int>(dimension).Case("x", 0).Case("y", 1).Case("z", 2);
 }
 
-/// Gets the block processor ID's upper bound. This queries the workgroup count
-/// function.
-Optional<int64_t> getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) {
-  auto funcOp = blockIDOp->getParentOfType<FuncOp>();
+IREE::HAL::ReturnOp getEntryPointReturnOp(Operation *op) {
+  auto funcOp = op->getParentOfType<FuncOp>();
   auto targetOp =
       funcOp.getOperation()->getParentOfType<IREE::HAL::ExecutableTargetOp>();
 
-  IREE::HAL::ExecutableEntryPointOp entryPointOp = nullptr;
+
+  IREE::HAL::ExecutableEntryPointOp entryPointOp;
   for (auto op : targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
     if (op.sym_name() == funcOp.getName()) {
      entryPointOp = op;
      break;
    }
  }
-  if (!entryPointOp) return llvm::None;
+  if (!entryPointOp) return {};
 
   Operation *terminator = entryPointOp.getBlock()->getTerminator();
   auto retOp = dyn_cast<IREE::HAL::ReturnOp>(terminator);
-  if (!retOp || retOp.getNumOperands() != 3) return llvm::None;
+  if (!retOp || retOp.getNumOperands() != 3) return {};
+
   LLVM_DEBUG(llvm::dbgs() << "workgroup count function return op: " << retOp
                           << "\n");
+  return retOp;
+}
+
+/// Gets the block processor ID's upper bound. This queries the workgroup count
+/// function.
+Optional<int64_t> getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) {
+  auto retOp = getEntryPointReturnOp(blockIDOp);
+  if (!retOp) return llvm::None;
 
   int index = dimensionToIndex(blockIDOp.dimension());
   IntegerAttr attr;
@@ -118,6 +127,19 @@ Optional<int64_t> getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) {
   return attr.getInt();
 }
 
+Optional<int64_t> getProcessorIDUpperBound(
+    IREE::HAL::InterfaceWorkgroupIDOp blockIDOp) {
+  auto retOp = getEntryPointReturnOp(blockIDOp);
+  if (!retOp) return llvm::None;
+
+  int index = blockIDOp.dimensionAttr().getInt();
+  IntegerAttr attr;
+  if (!matchPattern(retOp.getOperand(index), m_Constant(&attr)))
+    return llvm::None;
+
+  return attr.getInt();
+}
+
 /// Gets the thread processor ID's upper bound. This queries the SPIR-V entry
 /// point ABI.
 Optional<int64_t> getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) {
@@ -165,6 +187,9 @@ struct FoldAffineMinOverProcessorID : OpRewritePattern<AffineMinOp> {
     Optional<int64_t> ub;
     if (auto blockIDOp = dyn_cast<gpu::BlockIdOp>(symbolOp)) {
       ub = getProcessorIDUpperBound(blockIDOp);
+    } else if (auto blockIDOp =
+                   dyn_cast<IREE::HAL::InterfaceWorkgroupIDOp>(symbolOp)) {
+      ub = getProcessorIDUpperBound(blockIDOp);
     } else if (auto threadIDOp = dyn_cast<gpu::ThreadIdOp>(symbolOp)) {
       ub = getProcessorIDUpperBound(threadIDOp);
     }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 0ef687798f0e..097a0f8f1983 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -121,8 +121,23 @@ static LogicalResult getMaliSpecificConfig(
     std::array<int64_t, 3> &numSubgroups) {
   if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
 
-  auto lhsType = op.inputs()[0].getType().cast<ShapedType>();
-  auto rhsType = op.inputs()[1].getType().cast<ShapedType>();
+  ShapedType lhsType, rhsType;
+  // NOTE: Special treatment to let the flow.dispatch.workgroups path be able
+  // to query launch configurations.
+  if (auto inputTypeAttr =
+          op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types")) {
+    lhsType = inputTypeAttr.getValue()[0]
+                  .cast<TypeAttr>()
+                  .getValue()
+                  .cast<ShapedType>();
+    rhsType = inputTypeAttr.getValue()[1]
+                  .cast<TypeAttr>()
+                  .getValue()
+                  .cast<ShapedType>();
+  } else {
+    lhsType = op.inputs()[0].getType().cast<ShapedType>();
+    rhsType = op.inputs()[1].getType().cast<ShapedType>();
+  }
   assert(lhsType.getElementType() == rhsType.getElementType());
   if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
   // Get a vector of best tile size ordered from best to worst.
@@ -292,8 +307,21 @@ static LogicalResult getTargetSpecificConfig(
     std::array<int64_t, 3> &numSubgroups) {
   if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
 
-  auto lhsType = op.inputs()[0].getType().cast<ShapedType>();
-  auto rhsType = op.inputs()[1].getType().cast<ShapedType>();
+  ShapedType lhsType, rhsType;
+  if (auto inputTypeAttr =
+          op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types")) {
+    lhsType = inputTypeAttr.getValue()[0]
+                  .cast<TypeAttr>()
+                  .getValue()
+                  .cast<ShapedType>();
+    rhsType = inputTypeAttr.getValue()[1]
+                  .cast<TypeAttr>()
+                  .getValue()
+                  .cast<ShapedType>();
+  } else {
+    lhsType = op.inputs()[0].getType().cast<ShapedType>();
+    rhsType = op.inputs()[1].getType().cast<ShapedType>();
+  }
   assert(lhsType.getElementType() == rhsType.getElementType());
   // If the shape size is unknown, fall back to the non-vectorized path.
   if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
@@ -375,14 +403,34 @@ template <typename ConvOpTy>
 static LogicalResult getMaliSpecificConfig(ConvOpTy op,
                                            TileSizesListType &tileSizes,
                                            LaunchConfigInfo &config) {
-  auto inputType = op.getInput(1).getType().template cast<ShapedType>();
-  auto outputType = op.getOutputBufferTypes()[0].template cast<ShapedType>();
+  Operation *operation = op.getOperation();
+  if (!isa<linalg::ConvInputNHWCFilterHWCFOp>(operation)) return failure();
+
+  ShapedType inputType, outputType;
+
+  // NOTE: Special treatment to let the flow.dispatch.workgroups path be able
+  // to query launch configurations.
+  if (auto outputTypeAttr = operation->getAttrOfType<ArrayAttr>(
+          "iree.codegen.original_output_types")) {
+    auto inputTypeAttr = operation->getAttrOfType<ArrayAttr>(
+        "iree.codegen.original_input_types");
+    inputType = inputTypeAttr.getValue()[0]
+                    .template cast<TypeAttr>()
+                    .getValue()
+                    .template cast<ShapedType>();
+    outputType = outputTypeAttr.getValue()[0]
+                     .template cast<TypeAttr>()
+                     .getValue()
+                     .template cast<ShapedType>();
+    LLVM_DEBUG(llvm::dbgs() << "conv input types: " << inputType << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "conv output types: " << outputType << "\n");
+  } else {
+    inputType = op.getInputs().front().getType().template cast<ShapedType>();
+    outputType = op.getOutputBufferTypes()[0].template cast<ShapedType>();
+  }
+
   if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
     return failure();
-  // Only support NHWC conv.
-  if (!isa<linalg::ConvInputNHWCFilterHWCFOp>(op.getOperation())) {
-    return failure();
-  }
 
   bool isInputTilable =
       inputType.getDimSize(3) % 4 == 0 || inputType.getDimSize(3) < 4;
@@ -478,8 +526,29 @@ GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
 
 static LogicalResult getMaliSpecificConfig(
     linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
     LaunchConfigInfo &config) {
-  auto inputType = op.getInput(0).getType().cast<ShapedType>();
-  auto outputType = op.getOutputBufferTypes()[0].cast<ShapedType>();
+  ShapedType inputType, outputType;
+
+  // NOTE: Special treatment to let the flow.dispatch.workgroups path be able
+  // to query launch configurations.
+  if (auto outputTypeAttr =
+          op->getAttrOfType<ArrayAttr>("iree.codegen.original_output_types")) {
+    auto inputTypeAttr =
+        op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types");
+    inputType = inputTypeAttr.getValue()[0]
+                    .template cast<TypeAttr>()
+                    .getValue()
+                    .template cast<ShapedType>();
+    outputType = outputTypeAttr.getValue()[0]
+                     .template cast<TypeAttr>()
+                     .getValue()
+                     .template cast<ShapedType>();
+    LLVM_DEBUG(llvm::dbgs() << "dwconv input types: " << inputType << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "dwconv output types: " << outputType << "\n");
+  } else {
+    inputType = op.getInput(0).getType().cast<ShapedType>();
+    outputType = op.getOutputBufferTypes()[0].cast<ShapedType>();
+  }
+
   if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
     return failure();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp
new file mode 100644
index 000000000000..3c6d5e35e2bc
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp
@@ -0,0 +1,155 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgTileAndDistributePass.cpp ------------------------------------===//
+//
+// This pass tiles and distributes Linalg operations among multiple workgroups.
+//
+// NOTE: Deprecated. This pass is used for the first-level tiling in the Linalg
+// on buffers path, which is expected to go away soon.
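+//
+// As a rough sketch (shapes and names hypothetical), a linalg.matmul on
+// buffers gets tiled so that each workgroup computes one tile addressed by
+// its GPU block ID:
+//
+//   %id_x = "gpu.block_id"() {dimension = "x"} : () -> index
+//   %id_y = "gpu.block_id"() {dimension = "y"} : () -> index
+//   // ... derive tile subviews %lhs_tile/%rhs_tile/%acc_tile from the IDs ...
+//   linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+//     ins(%lhs_tile, %rhs_tile : ...) outs(%acc_tile : ...)
+//
+// No residual scf.for loops are needed because the distribution method is
+// CyclicNumProcsEqNumIters, i.e., the number of workgroups is assumed to
+// equal the number of tiles.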
+// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h" +#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" +#include "iree/compiler/Conversion/Common/Attributes.h" +#include "iree/compiler/Conversion/Common/Transforms.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h" +#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-linalg-to-spirv-tile-and-distribute" + +namespace mlir { +namespace iree_compiler { +namespace { + +/// Returns the distribution options for operations when targeting workgroups. +linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() { + linalg::LinalgLoopDistributionOptions options; + + options.procInfo = [](OpBuilder &builder, Location loc, + ArrayRef parallelLoopRanges) { + return getGPUProcessorIdsAndCounts( + builder, loc, parallelLoopRanges.size()); + }; + options.distributionMethod.assign( + 3, linalg::DistributionMethod::CyclicNumProcsEqNumIters); + + return options; +} + +class LinalgTileAndDistributePass + : public PassWrapper> { + public: + LinalgTileAndDistributePass(const SPIRVCodegenOptions &options) + : options(options) {} + LinalgTileAndDistributePass(const LinalgTileAndDistributePass &that) + : options(that.options) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + IREE::HAL::ExecutableTargetOp targetOp = getOperation(); + ModuleOp module = targetOp.getInnerModule(); + + for (FuncOp funcOp : module.getOps()) { + if (!isEntryPoint(funcOp)) continue; + + SmallVector linalgOps; + SmallVector tiledLoops; + + if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) { + return signalPassFailure(); + } + + linalg::Aliases aliases; + linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + Optional launchConfigOpt = + initGPULaunchConfig(context, dependenceGraph, options, linalgOps); + if (!launchConfigOpt) { + funcOp.emitError("unable to find launch configuration"); + return signalPassFailure(); + } + LaunchConfig &launchConfig = *launchConfigOpt; + + LLVM_DEBUG({ + llvm::dbgs() + << "\n--- IREE Linalg tile and distribute configuration ---\n"; + llvm::dbgs() << "@func " << funcOp.getName() + << ": # workgroup sizes: ["; + interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs()); + llvm::dbgs() << "]\n"; + for (auto op : linalgOps) { + llvm::dbgs() << "\t" << op.getOperation()->getName() << " : "; + TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op); + llvm::dbgs() << "{"; + std::string sep = ""; + for (auto &level : enumerate(tileSizes)) { + llvm::dbgs() << sep << level.index() << " : ["; + sep = ", "; + interleaveComma(level.value(), llvm::dbgs()); + llvm::dbgs() << "]"; + } + llvm::dbgs() << "}\n"; + } + }); + + TileAndFuseOptions tileAndFuseOptions = { + getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; 
+      if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
+                                            launchConfig,
+                                            tileAndFuseOptions)) ||
+          failed(
+              updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
+        return signalPassFailure();
+      }
+    }
+  }
+
+ private:
+  SPIRVCodegenOptions options;
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options) {
+  return std::make_unique<LinalgTileAndDistributePass>(options);
+}
+
+static PassRegistration<LinalgTileAndDistributePass> pass(
+    "iree-codegen-spirv-linalg-tile-and-distribute",
+    "Tile and distribute Linalg operations on buffers", [] {
+      SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions();
+      return std::make_unique<LinalgTileAndDistributePass>(options);
+    });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 3aae4847324f..36d6ee2262c8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -14,7 +14,8 @@
 
 //===- LinalgTileAndFusePass.cpp - Tile and fuse Linalg on Buffers --------===//
 //
-// Implements a pass to tile and fuse linalg operations on buffers.
+// This pass tiles and vectorizes Linalg ops on buffers within a single
+// workgroup.
 //
 //===----------------------------------------------------------------------===//
 
@@ -32,6 +33,8 @@
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -72,33 +75,6 @@ static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
       markers, Identifier::get(replaceMarker, context));
 }
 
-/// Returns the distribution options for operations when targeting workgroups.
-static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
-  linalg::LinalgLoopDistributionOptions options;
-
-  options.procInfo = [](OpBuilder &builder, Location loc,
-                        ArrayRef<Range> parallelLoopRanges) {
-    return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
-        builder, loc, parallelLoopRanges.size());
-  };
-  options.distributionMethod.assign(
-      3, linalg::DistributionMethod::CyclicNumProcsEqNumIters);
-
-  return options;
-}
-
-/// Applies canonicalization over index calculation inside the given `funcOp`.
-static void applyIndexCalculationCanonicalization(FuncOp funcOp) { - MLIRContext *context = funcOp.getContext(); - OwningRewritePatternList canonicalizationPatterns; - DimOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - AddIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - SubIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - SignedDivIOp::getCanonicalizationPatterns(canonicalizationPatterns, context); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canonicalizationPatterns)); -} - //===----------------------------------------------------------------------===// // Main pass //===----------------------------------------------------------------------===// @@ -429,8 +405,6 @@ void LinalgTileAndFusePass::runOnOperation() { IREE::HAL::ExecutableTargetOp targetOp = getOperation(); ModuleOp module = targetOp.getInnerModule(); - LLVM_DEBUG( - llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";); for (FuncOp funcOp : module.getOps()) { if (!isEntryPoint(funcOp)) continue; @@ -452,6 +426,7 @@ void LinalgTileAndFusePass::runOnOperation() { LaunchConfig &launchConfig = *launchConfigOpt; LLVM_DEBUG({ + llvm::dbgs() << "\n--- IREE Linalg tile and fuse configuration ---\n"; llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: ["; interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs()); llvm::dbgs() << "]\n"; @@ -470,55 +445,6 @@ void LinalgTileAndFusePass::runOnOperation() { } }); - if (!options.usingLinalgOnTensors) { - TileAndFuseOptions tileAndFuseOptions = { - getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; - if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph, - launchConfig, - tileAndFuseOptions)) || - failed( - updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) { - return signalPassFailure(); - } - } else { - // Find the root operation for the dispatch region and get the tile sizes. - Operation *rootOperation = - launchConfig.getRootOperation(llvm::to_vector<4>(llvm::map_range( - linalgOps, - [](linalg::LinalgOp op) { return op.getOperation(); }))); - if (!rootOperation) { - launchConfig.finalize(funcOp); - return; - } - - ArrayRef rootOperationTileSizes = - launchConfig.getTileSizes(rootOperation, 0); - if (rootOperationTileSizes.empty()) { - launchConfig.finalize(funcOp); - return; - } - - // Only use the tile sizes for parallel loops of the root operation. 
-      rootOperationTileSizes = rootOperationTileSizes.take_front(
-          getNumOuterParallelLoops(rootOperation));
-
-      SmallVector<int64_t, 4> workloadPerWorkgroup =
-          llvm::to_vector<4>(llvm::reverse(rootOperationTileSizes));
-      if (failed(materializeStaticLaunchInformation(funcOp,
-                                                    workloadPerWorkgroup)) ||
-          failed(
-              updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
-        funcOp.emitOpError("failed to materialize static launch information");
-        return signalPassFailure();
-      }
-    }
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After first level of tiling and distribution ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
     if (options.useWorkgroupMemory) {
       // The promotion patterns are put separate from the tiling patterns to
       // make sure that the allocated scratchspace memory is constant sizes
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 0ea3a74f1285..0dfe6cde7228 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -83,9 +83,15 @@ static void addLinalgToSPIRVPasses(OpPassManager &pm,
   //   - The Linalg op is kept untouched.
   //
   //===--------------------------------------------------------------------===//
-  if (!options.usingLinalgOnTensors) {
+  if (options.usingLinalgOnTensors) {
+    // flow.dispatch.workgroups performed abstract tiling and distribution.
+    // Make them concrete now since we know the target and its settings.
+    pm.addPass(createConcretizeTileAmongWorkgroupsPass(options));
+  } else {
     pm.addPass(createSplitDispatchFunctionPass());
+    pm.addPass(createTileAndDistributeAmongWorkgroupsPass(options));
   }
+
   pm.addPass(createLinalgTileAndFusePass(options));
   if (options.vectorizeMemref) {
     pm.nest<ModuleOp>().addNestedPass<FuncOp>(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 1c29576e9433..bce9f2efd7c0 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -81,6 +81,16 @@ createFoldProcessorIDUsesPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 createMaterializeEntryPointsPass();
 
+/// Creates a pass to concretize hal.interface.workgroup.* ops with the
+/// concrete tiling and distribution scheme.
+std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
+createConcretizeTileAmongWorkgroupsPass(const SPIRVCodegenOptions &options);
+
+/// Tiles and distributes Linalg operations on buffers among multiple
+/// workgroups.
+std::unique_ptr> +createTileAndDistributeAmongWorkgroupsPass(const SPIRVCodegenOptions &options); + //===----------------------------------------------------------------------===// // Pipelines //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp index 65febab28f3e..671512a8b5fa 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp @@ -126,6 +126,7 @@ class MemRefUsageAnalysis { void analyzeFunc(FuncOp funcOp); void analyzeAlloc(AllocOp allocOp); void analyzePlaceholder(IREE::PlaceholderOp placeholderOp); + void analyzeInterfaceBinding(IREE::HAL::InterfaceBindingSubspanOp bindingOp); llvm::DenseMap vectorization_size; llvm::DenseSet transferOps; }; @@ -136,6 +137,8 @@ MemRefUsageAnalysis::MemRefUsageAnalysis(mlir::Operation *op) { if (auto alloc = dyn_cast(op)) analyzeAlloc(alloc); if (auto placeholder = dyn_cast(op)) analyzePlaceholder(placeholder); + if (auto bindingOp = dyn_cast(op)) + analyzeInterfaceBinding(bindingOp); }); } @@ -159,6 +162,15 @@ void MemRefUsageAnalysis::analyzePlaceholder( } } +void MemRefUsageAnalysis::analyzeInterfaceBinding( + IREE::HAL::InterfaceBindingSubspanOp bindingOp) { + SmallVector vectorUses; + if (unsigned vectorSize = isMemRefAndVectorizable(bindingOp, vectorUses)) { + vectorization_size.insert(std::make_pair(bindingOp, vectorSize)); + transferOps.insert(vectorUses.begin(), vectorUses.end()); + } +} + void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) { SmallVector vectorUses; if (unsigned vectorSize = isMemRefAndVectorizable(allocOp, vectorUses)) { @@ -363,6 +375,26 @@ class ProcessPlaceHolder final } }; +class ProcessInterfaceBinding final + : public MemRefConversionPattern { + public: + using MemRefConversionPattern< + IREE::HAL::InterfaceBindingSubspanOp>::MemRefConversionPattern; + + LogicalResult matchAndRewrite( + IREE::HAL::InterfaceBindingSubspanOp bindingOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto memrefType = bindingOp.getType().dyn_cast(); + if (!memrefType) return failure(); + auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.getResult()); + if (!vecMemRef) return failure(); + rewriter.replaceOpWithNewOp( + bindingOp, *vecMemRef, bindingOp.binding(), bindingOp.byte_offset(), + bindingOp.byte_length()); + return success(); + } +}; + class VectorizeMemRefPass final : public PassWrapper> { void runOnOperation() override; @@ -410,8 +442,8 @@ void VectorizeMemRefPass::runOnOperation() { OwningRewritePatternList patterns; patterns.insert(context, - *memrefUsageAnalysis); + ProcessAlloc, ProcessPlaceHolder, ProcessInterfaceBinding>( + context, *memrefUsageAnalysis); ConversionTarget target(*context); target.addDynamicallyLegalOp([&](FuncOp op) { @@ -426,6 +458,10 @@ void VectorizeMemRefPass::runOnOperation() { [&](IREE::PlaceholderOp placeholder) { return !memrefUsageAnalysis->vectorizeMemRef(placeholder); }); + target.addDynamicallyLegalOp( + [&](IREE::HAL::InterfaceBindingSubspanOp bindingOp) { + return !memrefUsageAnalysis->vectorizeMemRef(bindingOp); + }); target.markUnknownOpDynamicallyLegal([&](Operation *op) { if (isa(op)) return !memrefUsageAnalysis->transferConvert(op); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir index 
508f72e6ad12..c597271a70c2 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir new file mode 100644 index 000000000000..d863939ede90 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir @@ -0,0 +1,225 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-spirv-tile-size=16,4,4 -iree-spirv-workgroup-size=4,4,1 %s | IreeFileCheck %s + +hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @conv2d_static_shape attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv2d_static_shape() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = 
affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + linalg.fill(%18, %cst) : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// Check that for a fully static shaped dispatch region, we can: +// 1) Generate static constant workgroup counts, +// 2) Replace hal.interface.workgroup.{size|count} ops with constants, +// 3) Canonicalize loops and subview ops. 
+
+// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (9, d0 * -2 + 225)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + 32)>
+// CHECK: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (4, -d0 + 112)>
+
+// CHECK: hal.executable.entry_point @conv2d_static_shape
+// CHECK: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[C28_0:.+]] = constant 28 : index
+// CHECK: %[[C28_1:.+]] = constant 28 : index
+// CHECK: hal.return %[[C2]], %[[C28_0]], %[[C28_1]] : index, index, index
+
+// CHECK: func @conv2d_static_shape()
+// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
+
+// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index
+
+// CHECK: %[[Z_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Z]], %c4]
+// CHECK: %[[Y_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4]
+// CHECK: %[[X_MUL_16:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16]
+
+// CHECK: %[[Z_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Z_MUL_4]])
+// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP1]](%[[Z_MUL_4]])[%c4]
+// CHECK: %[[Y_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Y_MUL_4]])
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP1]](%[[Y_MUL_4]])[%c4]
+
+// CHECK: %[[INPUT:.+]] = subview %{{.+}}[0, %[[Z_OFFSET]], %[[Y_OFFSET]], 0] [1, %[[Z_SIZE]], %[[Y_SIZE]], 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}>
+
+// CHECK: %[[X_SIZE:.+]] = affine.min #[[MAP2]](%[[X_MUL_16]])[%c16]
+
+// CHECK: %[[FILTER:.+]] = subview %{{.+}}[0, 0, 0, %[[X_MUL_16]]] [3, 3, 16, %[[X_SIZE]]] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}>
+
+// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP3]](%[[Z_MUL_4]])[%c4]
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP3]](%[[Y_MUL_4]])[%c4]
+// CHECK: %[[OUTPUT:.+]] = subview %{{.+}}[0, %[[Z_MUL_4]], %[[Y_MUL_4]], %[[X_MUL_16]]] [1, %[[Z_SIZE]], %[[Y_SIZE]], %[[X_SIZE]]] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}>
+
+// CHECK: linalg.fill(%[[OUTPUT]], %{{.+}})
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, is_root_op, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[FILTER]] : memref<1x?x?x16xf32, {{.+}}>, memref<3x3x16x?xf32, {{.+}}>) outs(%[[OUTPUT]] : memref<1x?x?x?xf32, {{.+}}>)
+
+// -----
+
+hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @matmul_dynamic_shape attributes {
+      interface = @legacy_io, ordinal = 0 : i32,
+      signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} {
+      func @matmul_dynamic_shape() {
+        %cst = constant 0.000000e+00 : f32
+        %c0 = constant 0 : index
+        %0 = hal.interface.load.constant offset = 0 : index
+        %1 = hal.interface.load.constant offset = 1 : index
+        %2 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<?x?xf32>
+        %3 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<?x?xf32>
+        %4 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<?x?xf32>
+        %5 = hal.interface.load.constant offset = 2 : index
+        %6 = hal.interface.load.constant offset = 3 : index
+        %7 = hal.interface.load.constant offset = 4 : index
+        %8 = hal.interface.load.constant offset = 5 : index
+        %9 = hal.interface.load.constant offset = 6 : index
+        %10 = hal.interface.load.constant offset = 7 : index
+        %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %12 = shapex.tie_shape %2, %11 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %14 = shapex.tie_shape %3, %13 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %16 = shapex.tie_shape %4, %15 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %17 to %5 step %18 {
+          %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %19 to %8 step %20 {
+            %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y]
+            %22 = subview %12[%arg0, 0] [%21, %6] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x]
+            %24 = subview %14[0, %arg1] [%7, %23] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y]
+            %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x]
+            %27 = subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            linalg.fill(%27, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, f32
+            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%27 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
+          }
+        }
+        return
+      }
+      hal.interface @legacy_io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// Check that for a fully dynamic shaped dispatch region, we can:
+// 1) Generate symbolic workgroup counts,
+// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants.
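+//
+// With a 16x4x1 workgroup tile (matching the local_size below), the entry
+// point's count region should compute ceildiv(workload0, 16) workgroups along
+// X and ceildiv(workload1, 4) along Y, with Z pinned to one; the captures
+// below verify exactly that arithmetic.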
+
+// CHECK: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (4, -d0 + s0)>
+// CHECK: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (16, -d0 + s0)>
+
+// CHECK: hal.executable.entry_point @matmul_dynamic_shape
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index):
+// CHECK: %c1 = constant 1 : index
+// CHECK: %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]]
+// CHECK: %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]]
+// CHECK: hal.return %[[SIZE0]], %[[SIZE1]], %c1
+
+// CHECK: func @matmul_dynamic_shape()
+// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
+
+// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index
+// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index
+// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index
+// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index
+// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index
+// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index
+
+// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
+// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
+
+// CHECK: %[[Y_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4]
+// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_Y]], %c4]
+// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]]
+// CHECK: %[[X_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16]
+// CHECK: %[[X_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_X]], %c16]
+// CHECK: scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]]
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]], %c4]
+// CHECK: %[[A_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]], %c16]
+// CHECK: %[[B_TILE:.+]] = subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]], %c4]
+// CHECK: %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]], %c16]
+// CHECK: %[[C_TILE:.+]] = subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, {{.+}}>
+// CHECK: linalg.fill(%[[C_TILE]], %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, {{.+}}>, f32
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op} ins(%[[A_TILE]], %[[B_TILE]] : memref<?x?xf32, {{.+}}>, memref<?x?xf32, {{.+}}>) outs(%[[C_TILE]] : memref<?x?xf32, {{.+}}>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
index fe9bbabefc44..8adf600f5a96 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
@@ -34,6 +34,40 @@ hal.executable @fold_block_id attributes {sym_visibility = "private"} {

// -----

+hal.executable @fold_interface_workgroup_id attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io {
+  }
+  hal.executable.target @vulkan,
filter="vulkan*" { + hal.executable.entry_point @fold_interface_workgroup_id attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = () -> ()} { + ^bb0(%arg0 : index, %arg1 : index, %arg2 : index): + %x = constant 112: index + %y = constant 42: index + %z = constant 1: index + hal.return %x, %y, %z: index, index, index + } + module { + func @fold_interface_workgroup_id() -> (index, index, index) { + %0 = hal.interface.workgroup.id[0] : index + %1 = hal.interface.workgroup.id[1] : index + %2 = hal.interface.workgroup.id[2] : index + %3 = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%0] + %4 = affine.min affine_map<()[s0] -> (8, s0 * -1 + s0 * -1 + s0 * -1 + 131)>()[%2] + %5 = affine.min affine_map<()[s0] -> (11, s0 + 15)>()[%3] + return %3, %4, %5: index, index, index + } + } + } +} +// CHECK-LABEL: func @fold_interface_workgroup_id() +// CHECK-DAG: %[[C3:.+]] = constant 3 +// CHECK-DAG: %[[C8:.+]] = constant 8 +// CHECK-DAG: %[[C11:.+]] = constant 11 +// CHECK-DAG: return %[[C3]], %[[C8]], %[[C11]] + +// ----- + hal.executable @fold_thread_id attributes {sym_visibility = "private"} { hal.interface @legacy_io { } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir index d8da2427ec4b..f5a2b68858ba 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s // TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. hal.executable @conv_no_padding attributes {sym_visibility = "private"} { @@ -439,148 +439,3 @@ hal.executable @three_op_fusion attributes {sym_visibility = "private"} { // CHECK-SAME: ) // CHECK-SAME: outs(%[[SV_RET0]] // CHECK-SAME: ) - -// ----- - -// TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default. 
-hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @conv_tiled_and_vectorized attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, - !flow.dispatch.output<1x112x112x32xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> - } { - func @conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x112x112x32xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x225x225x16xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x16x32xf32> - linalg.fill(%0, %cst) : memref<1x112x112x32xf32>, f32 - linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} - ins (%1, %2: memref<1x225x225x16xf32>, memref<3x3x16x32xf32>) - outs (%0: memref<1x112x112x32xf32>) - return - } - - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK-COUNT-4: vector.transfer_write - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 -// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) - -// CHECK-COUNT-16: vector.contract - -// CHECK-COUNT-3: scf.yield - -// For linalg.conv_2d_input_nhwc_filter_hwcf -// CHECK-COUNT-4: vector.transfer_write - -// ----- - -hal.executable @depthwise_conv2d_2452x2423_valid_stride_2 attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan_spirv, filter="vulkan*" { - hal.executable.entry_point @depthwise_conv2d_2452x2423_valid_stride_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<2x4x5x2xf32>, tensor<2x4x2x3xf32>) -> tensor<2x2x1x6xf32>} - module attributes {spv.target_env = 
#spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { - func @depthwise_conv2d_2452x2423_valid_stride_2() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x2x1x2x3xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4x5x2xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<2x4x2x3xf32> - linalg.fill(%0, %cst) : memref<2x2x1x2x3xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<2x4x5x2xf32>, memref<2x4x2x3xf32>) outs(%0 : memref<2x2x1x2x3xf32>) - return - } - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv2d_2452x2423_valid_stride_2() -// CHECK: linalg.fill -// CHECK: linalg.generic -// CHECK-NOT: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf - -// ----- - -hal.executable @conv_tiled_and_vectorized attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @depthwise_conv_tiled_and_vectorized attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input<1x113x113x96xf32>, !flow.dispatch.input<3x3x96xf32>, - !flow.dispatch.output<1x56x56x96xf32>) -> ()} - module attributes { - spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}> - } { - func @depthwise_conv_tiled_and_vectorized() { - %cst = constant 0.000000e+00 : f32 - %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x56x56x96xf32> - %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x113x113x96xf32> - %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x96xf32> - linalg.fill(%0, %cst) : memref<1x56x56x96xf32>, f32 - linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%1, %2 : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%0 : memref<1x56x56x96xf32>) - return - } - } - } -} - -// CHECK-LABEL: func @depthwise_conv_tiled_and_vectorized() - -// For linalg.fill -// CHECK: vector.transfer_write - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_read - -// check tiling loop along filter height/width and input channel -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) -// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 -// CHECK-SAME: -> (vector<4xf32>) - - -// CHECK: vector.fma - -// CHECK-COUNT-2: 
scf.yield - -// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc -// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir index 8151a358e43f..c55b112f32fa 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s +// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-codegen-spirv-experimental-linalg-on-tensors -cse -canonicalize -split-input-file %s | IreeFileCheck %s hal.executable @matmul_tensors attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir index 88251f44d56a..b07f50ad08ff 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -61,67 +61,3 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { // CHECK: scf.yield // CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]] // CHECK: return - -// ----- - -hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} { - hal.interface @legacy_io { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - hal.executable.target @vulkan, filter="dylib*" { - hal.executable.entry_point @matmul_static_shape_f16 attributes { - interface = @legacy_io, ordinal = 0 : i32, - signature = (!flow.dispatch.input, !flow.dispatch.input, - !flow.dispatch.output) -> ()} - module attributes { - spv.target_env = - #spv.target_env<#spv.vce, - ARM:IntegratedGPU, - {max_compute_shared_memory_size = 32768 : i32, - max_compute_workgroup_invocations = 512 : i32, - max_compute_workgroup_size = dense<512> : vector<3xi32>, - subgroup_size = 16 : i32}>} { - func @matmul_static_shape_f16() - attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} { - %arg0 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16> - %arg1 = iree.placeholder for "interface buffer" - {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16> - %ret0 = iree.placeholder for "interface buffer" - {binding 
= @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16> - %cst = constant 0.000000e+00 : f16 - linalg.fill(%ret0, %cst) : memref<4096x4096xf16>, f16 - linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>) - outs(%ret0 : memref<4096x4096xf16>) - return - } - func private @matmul_static_shape__num_workgroups__ - (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>, - !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index) - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" - hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" - } - } - } -} -// CHECK-LABEL: func @matmul_static_shape_f16 -// CHECK-COUNT-16: vector.transfer_write -// CHECK-COUNT-16: vector.transfer_read -// CHECK: %[[FOR_RES:.+]]:16 = scf.for -// CHECK-COUNT-16: vector.transfer_read -// CHECK-COUNT-64: vector.contract -// CHECK: scf.yield -// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]] -// CHECK: return diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir index 625b2b5ee41c..ee9eef28624c 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir @@ -1,5 +1,5 @@ -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s -// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { hal.interface @legacy_io { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir index f5a6a5410268..0af82c633e63 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir @@ -132,3 +132,24 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } + +// ----- + +func @vectorize_binding_subspan() { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + // CHECK: hal.interface.binding.subspan @legacy_io::@arg0[%c0] + // CHECK-SAME: memref<4096x1024xvector<4xf32>> + // CHECK: hal.interface.binding.subspan @legacy_io::@ret0[%c0] + // 
CHECK-SAME: memref<4096x1024xvector<4xf32>> + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf32> + %1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf32> + %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<32x8xf32> + vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf32>, memref<4096x4096xf32> + return +} + +hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" +} diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir new file mode 100644 index 000000000000..165cd52394b8 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir @@ -0,0 +1,176 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @conv_static_shape_f32 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c32 = constant 32 : index + %c112 = constant 112 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x225x225x16xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x16x32xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x112x112x32xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %3 to %c112 step %4 { + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %5 to %c112 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * 
s1)>()[%workgroup_id_x, %workgroup_size_x] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %7 to %c32 step %8 { + %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z] + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y] + %13 = subview %0[0, %9, %11, 0] [1, %10, %12, 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>> + %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x] + %15 = subview %1[0, 0, 0, %arg2] [3, 3, 16, %14] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>> + %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z] + %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y] + %18 = subview %2[0, %arg0, %arg1, %arg2] [1, %16, %17, %14] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>> + linalg.fill(%18, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>, f32 + linalg.conv_2d_input_nhwc_filter_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : memref<1x?x?x16xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 810000 + s0 + d1 * 3600 + d2 * 16 + d3)>>, memref<3x3x16x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 512 + d2 * 32 + d3)>>) outs(%18 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 401408 + s0 + d1 * 3584 + d2 * 32 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @conv_static_shape_f32() + +// For linalg.fill +// CHECK-COUNT-4: vector.transfer_write + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4 +// CHECK-SAME: -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) + +// CHECK-COUNT-16: vector.contract + +// CHECK-COUNT-3: scf.yield + +// For linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-COUNT-4: vector.transfer_write + +// ----- + +hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "private"} { + hal.interface @legacy_io { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + 
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @depthwise_conv_static_shape_f32() { + %cst = constant 0.000000e+00 : f32 + %c96 = constant 96 : index + %c56 = constant 56 : index + %c0 = constant 0 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<1x113x113x96xf32> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<3x3x1x96xf32> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<1x56x56x96xf32> + %3 = linalg.reshape %1 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>] : memref<3x3x1x96xf32> into memref<864xf32> + %4 = linalg.reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : memref<864xf32> into memref<3x3x96xf32> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z] + %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z] + scf.for %arg0 = %5 to %c56 step %6 { + %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y] + %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y] + scf.for %arg1 = %7 to %c56 step %8 { + %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x] + %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x] + scf.for %arg2 = %9 to %c96 step %10 { + %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg0)[%workgroup_size_z] + %13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %14 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 113)>(%arg1)[%workgroup_size_y] + %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x] + %16 = subview %0[0, %11, %13, %arg2] [1, %12, %14, %15] [1, 1, 1, 1] : memref<1x113x113x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>> + %17 = subview %4[0, 0, %arg2] [3, 3, %15] [1, 1, 1] : memref<3x3x96xf32> to memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>> + %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg0)[%workgroup_size_z] + %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg1)[%workgroup_size_y] + %20 = subview %2[0, %arg0, %arg1, %arg2] [1, %18, %19, %15] [1, 1, 1, 1] : memref<1x56x56x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>> + linalg.fill(%20, 
%cst) {__internal_linalg_transform__ = "workgroup"} : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>, f32 + linalg.depthwise_conv_2d_input_nhwc_filter_hwc {__internal_linalg_transform__ = "workgroup", strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>) + } + } + } + return + } + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + } + } +} + +// CHECK-LABEL: func @depthwise_conv_static_shape_f32() + +// For linalg.fill +// CHECK: vector.transfer_write + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_read + +// check tiling loop along filter height/width and input channel +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) +// CHECK: scf.for %{{.+}} = %c0 to %c3 step %c1 +// CHECK-SAME: -> (vector<4xf32>) + + +// CHECK: vector.fma + +// CHECK-COUNT-2: scf.yield + +// For linalg.depthwise_conv_2d_input_nhwc_filter_hwc +// CHECK: vector.transfer_write diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir new file mode 100644 index 000000000000..74d43514e568 --- /dev/null +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir @@ -0,0 +1,61 @@ +// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-codegen-linalg-tile-and-fuse))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s + +hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} { + hal.interface @legacy_io attributes {sym_visibility = "private"} { + hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" + hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" + } + hal.executable.target @vulkan_spirv, filter="vulkan*" { + hal.executable.entry_point @matmul_static_shape_f16 attributes { + interface = @legacy_io, ordinal = 0 : i32, + signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()} + module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { + func @matmul_static_shape_f16() { + %cst = constant 0.000000e+00 : f16 + %c0 = constant 0 : index + %c4096 = constant 4096 : index + %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<4096x4096xf16> + %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<4096x4096xf16> + %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<4096x4096xf16> + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_id_x = 
hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %3 to %c4096 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %5 to %c4096 step %6 {
+            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y]
+            %8 = subview %0[%arg0, 0] [%7, 4096] [1, 1] : memref<4096x4096xf16> to memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
+            %10 = subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf16> to memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            %11 = subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf16> to memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            linalg.fill(%10, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, f16
+            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %11 : memref<?x4096xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%10 : memref<?x?xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+          }
+        }
+        return
+      }
+      hal.interface @legacy_io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @matmul_static_shape_f16
+// CHECK-COUNT-16: vector.transfer_write
+// CHECK-COUNT-16: vector.transfer_read
+// CHECK: %[[FOR_RES:.+]]:16 = scf.for
+// CHECK-COUNT-16: vector.transfer_read
+// CHECK-COUNT-64: vector.contract
+// CHECK: scf.yield
+// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]]
+// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 5da02d9d228c..c96ce203934f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-codegen-linalg-tile-and-fuse,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s
// TODO(GH-4901): Convert these tests back to use dynamic shapes when linalg on tensors becomes default.
hal.executable @matmul_tile attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 9e0f52459371..cb1c9eb71028 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -65,6 +65,7 @@ inline void registerLinalgToSPIRVPasses() {
  // LinalgToSPIRV
  createConvertToGPUPass(SPIRVCodegenOptions());
  createFoldProcessorIDUsesPass();
+  createTileAndDistributeAmongWorkgroupsPass(SPIRVCodegenOptions());
  createLinalgTileAndFusePass(SPIRVCodegenOptions());
  createSplitDispatchFunctionPass();
  createVectorToGPUPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 739a3c1617c4..f84b74350f17 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -20,8 +20,10 @@
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -178,7 +180,9 @@ static SmallVector<Value, 4> convertToWorkload(OpBuilder &b, Location loc,
 /// linalg.init_tensor operations.
 static bool isRootOp(Operation *op) {
-  return isa<linalg::BatchMatmulOp, linalg::MatmulOp>(op);
+  return isa<linalg::BatchMatmulOp, linalg::ConvInputNHWCFilterHWCFOp,
+             linalg::DepthwiseConvInputNHWCFilterHWCOp, linalg::MatmulOp>(op);
 }

 static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
@@ -524,6 +528,14 @@ struct TileAndDistributeOnTensorsPattern
     SmallVector<Value, 4> count = llvm::to_vector<4>(
         llvm::map_range(linalgOp.createLoopRanges(rewriter, loc),
                         [](Range r) { return r.size; }));
+    // NOTE: Special treatment for convolution ops, which have more than three
+    // parallel dimensions. We want to ignore the batch dimension and tile
+    // along the next three.
+    // TODO(#5048): figure out a better way to avoid this special case.
+    if (isa<linalg::ConvInputNHWCFilterHWCFOp,
+            linalg::DepthwiseConvInputNHWCFilterHWCOp>(op)) {
+      count.erase(count.begin());
+    }
     count.resize(getNumTilableLoops(op));
     auto workload = convertToWorkload(rewriter, loc, count);
@@ -770,10 +782,9 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
 static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
     [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
       auto numParallelDims = parallelLoopRanges.size();
+      SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
-      for (size_t dim = 0;
-           dim < std::min(numParallelDims, kNumMaxParallelDims);
-           ++dim) {
+      for (size_t dim = 0; dim < numParallelDims; ++dim) {
        procInfo[numParallelDims - dim - 1] = {
            buildFlowWorkgroupInfoOp(builder, dim),
@@ -787,21 +798,34 @@
   auto tileSizeFn = [&](OpBuilder &builder,
                         Operation *op) -> SmallVector<Value, 4> {
+    auto numParallelDims = getNumOuterParallelLoops(cast<linalg::LinalgOp>(op));
     auto numTiledLoops = getNumTilableLoops(cast<linalg::LinalgOp>(op));
-    SmallVector<Value, 4> useTileSizes(numTiledLoops);
+
+    // Default to zero to skip tiling.
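+    // (A tile size of zero tells the Linalg tiling logic to leave that loop
+    // untiled; we start from all zeros and only overwrite the loops that are
+    // actually distributed below.)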
+    auto zero = builder.create<ConstantIndexOp>(op->getLoc(), 0);
+    SmallVector<Value, 4> useTileSizes(numParallelDims, zero);
+
     if (!clLinalgOnTensorsTileSizes.empty()) {
       SmallVector<int64_t, 4> tileSizes(clLinalgOnTensorsTileSizes.begin(),
                                         clLinalgOnTensorsTileSizes.end());
-      useTileSizes.resize(std::min(tileSizes.size(), numTiledLoops));
+      useTileSizes.resize(std::min(tileSizes.size(), numParallelDims));
       return llvm::to_vector<4>(llvm::map_range(
           ArrayRef<int64_t>(tileSizes).take_front(
-              std::min(tileSizes.size(), numTiledLoops)),
+              std::min(tileSizes.size(), numParallelDims)),
           [&](int64_t t) -> Value {
             return builder.create<ConstantIndexOp>(op->getLoc(), t);
           }));
     }
+
+    // NOTE: Special treatment for convolution ops, which have more than three
+    // parallel dimensions. We want to ignore the batch dimension and tile
+    // along the next three. That means setting the first position to zero.
+    // TODO(#5048): figure out a better way to avoid this special case.
+    bool isConvOp = isa<linalg::ConvInputNHWCFilterHWCFOp,
+                        linalg::DepthwiseConvInputNHWCFilterHWCOp>(op);
+
     for (size_t dim = 0; dim < numTiledLoops; ++dim) {
-      useTileSizes[numTiledLoops - dim - 1] =
+      useTileSizes[(isConvOp ? numParallelDims : numTiledLoops) - dim - 1] =
           buildFlowWorkgroupInfoOp(builder, dim);
     }
     return useTileSizes;
@@ -858,6 +882,19 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
     return signalPassFailure();
   }

+  // Run necessary canonicalization patterns before destructive updates.
+  {
+    OwningRewritePatternList patterns;
+    // This is needed because tiling and distribution may create
+    // subtensor_insert ops whose source operands come from tensor.cast ops.
+    // Those tensor.cast ops cast tensors to a more dynamic shape in order to
+    // guarantee that types match during the transformation. Later, during
+    // destructive-update rewriting, such subtensor_insert ops are turned into
+    // flow dispatch output store ops.
+    SubTensorInsertOp::getCanonicalizationPatterns(patterns, context);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+
  // Rewrite destructive updates and ensure no store to the full output
  // remains.
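+  // Concretely (a sketch only; the %-names are illustrative): after tiling,
+  // an iteration may produce
+  //   %t = subtensor_insert %tile into %init[%i, %j] [4, 16] [1, 1]
+  //       : tensor<4x16xf32> into tensor<?x?xf32>
+  // and the rewrite replaces the store of the whole %init with a flow
+  // dispatch output store of just %tile at offsets [%i, %j].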
if (funcOp diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index 70fe28f58566..ac5225c94c90 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir @@ -194,3 +194,33 @@ func @tensor5(%A: tensor, %B: tensor, %C: tensor) // CHECK: return %[[D]], %[[origCC]] return %D, %CC: tensor, tensor } + +func @conv2d(%input: tensor<1x225x225x16xf32>, %filter: tensor<3x3x16x32xf32>) -> tensor<1x112x112x32xf32> { + %0 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> + %cst = constant 0.000000e+00 : f32 + %1 = linalg.fill(%0, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32> + %2 = linalg.conv_2d_input_nhwc_filter_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : tensor<1x225x225x16xf32>, tensor<3x3x16x32xf32>) + outs(%1 : tensor<1x112x112x32xf32>) + -> tensor<1x112x112x32xf32> + return %2 : tensor<1x112x112x32xf32> +} + +// CHECK-LABEL: func @conv2d +// CHECK: scf.for +// CHECK: scf.for +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf + +func @depthwise_conv2d(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> { + %cst = constant 0.000000e+00 : f32 + %1 = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32> + %2 = linalg.fill(%1, %cst) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32> + %4 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) outs(%2 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> + return %4 : tensor<1x56x56x96xf32> +} + +// CHECK-LABEL: func @depthwise_conv2d +// CHECK: scf.for +// CHECK: scf.for +// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD index 689361b25a31..876cc74c4727 100644 --- a/iree/test/e2e/vulkan_specific/BUILD +++ b/iree/test/e2e/vulkan_specific/BUILD @@ -74,6 +74,8 @@ iree_check_single_backend_test_suite( "vectorized_conv.mlir", ], compiler_flags = [ + "-iree-flow-dispatch-linalg-on-tensors", + "-iree-codegen-spirv-experimental-linalg-on-tensors", "-iree-spirv-enable-vectorization", "-iree-vulkan-target-triple=valhall-g77-unknown-android10", ], diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt index f476cdaaf052..c5ec28c9189e 100644 --- a/iree/test/e2e/vulkan_specific/CMakeLists.txt +++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt @@ -66,6 +66,8 @@ iree_check_single_backend_test_suite( DRIVER "vulkan" COMPILER_FLAGS + "-iree-flow-dispatch-linalg-on-tensors" + "-iree-codegen-spirv-experimental-linalg-on-tensors" "-iree-spirv-enable-vectorization" "-iree-vulkan-target-triple=valhall-g77-unknown-android10" )