Skip to content

Commit

Permalink
Plumb GPU conv and matmul vectorization through flow.dispatch.workgro…
Browse files Browse the repository at this point in the history
…ups (#4999)

This commit adds all necessary plumbing to connect 2-D convolution
and matmul vectorization inside flow.dispatch.workgroups. This includes:

* Let 2-D convolution be recognized as a root op during dispatch
  region formation.
* Add a pass to concretize the abstract tiling and distribution
  during flow dispatch region formation. It substitutes symbolic
  ops with concrete values from CodeGen policy.
* Recognize `hal.interface.workgroup.id` ops when folding GPU
  processor ID uses.
* Recognize `hal.interface.binding.subspan` when vectorizing
  `memref`s for better memory access.

Along the way, the old path is refactored to have a better structure
for further cleaning up:

* `LinalgTileAndDistributePass` is split out of `LinalgTileAndFusePass`.
  It will be used for tiling and distribution among workgroups in the
  old path. `LinalgTileAndFusePass` will be for tiling and vectorization
  in a single workgroup (and it will be renamed later).
* A few tests are updated to use the new path.
  • Loading branch information
antiagainst authored Mar 11, 2021
1 parent 565c97c commit 7116f0c
Show file tree
Hide file tree
Showing 29 changed files with 1,509 additions and 334 deletions.
8 changes: 4 additions & 4 deletions iree/compiler/Conversion/Common/LaunchConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
Expand All @@ -38,6 +39,7 @@ namespace iree_compiler {
/// Name of the StrAttr that can be used to get the key to access the tile size
/// information.
static const char kLaunchInfoKey[] = "launch_info_key";
static const char kRootOpKey[] = "is_root_op";

static Optional<StringRef> getKey(Operation *op) {
StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
Expand Down Expand Up @@ -82,8 +84,7 @@ ArrayRef<int64_t> LaunchConfig::getTileSizes(Operation *op,

Operation *LaunchConfig::getRootOperation(ArrayRef<Operation *> ops) {
for (auto op : ops) {
auto key = getKey(op);
if (key && key.getValue() == rootOperationKey) return op;
if (op->getAttrOfType<UnitAttr>(kRootOpKey)) return op;
}
return nullptr;
}
Expand Down Expand Up @@ -114,8 +115,7 @@ void LaunchConfig::setNumSubgroups(ArrayRef<int64_t> vNumSubgroups) {
}

void LaunchConfig::setRootOperation(Operation *op) {
Optional<StringRef> key = getKey(op);
if (key) rootOperationKey = *key;
op->setAttr(kRootOpKey, UnitAttr::get(op->getContext()));
}

void LaunchConfig::setSameConfig(Operation *source, Operation *target) {
Expand Down
4 changes: 0 additions & 4 deletions iree/compiler/Conversion/Common/LaunchConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ class LaunchConfig {
/// these attributes.
llvm::StringMap<TileSizesListType> tileSizes;

/// Key used for tagging the root operation. The launch config does not track
/// the root operation itself, but rather the key used for the root operation.
StringRef rootOperationKey = "";

/// Workgroup size to use.
std::array<int64_t, 3> workgroupSize = {1, 1, 1};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} {
// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>)
// CHECK: linalg.matmul
//CHECK-SAME: ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%4 : memref<2x4xf32, #[[MAP3]]>)
3 changes: 3 additions & 0 deletions iree/compiler/Conversion/LinalgToSPIRV/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ cc_library(
cc_library(
name = "LinalgToSPIRV",
srcs = [
"ConcretizeTileAmongWorkgroupsPass.cpp",
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"CooperativeMatrixAnalysis.cpp",
"FoldGPUProcessorIDUses.cpp",
"KernelDispatchUtils.cpp",
"LinalgTileAndDistributePass.cpp",
"LinalgTileAndFusePass.cpp",
"MatMulVectorizationTest.cpp",
"Passes.cpp",
Expand All @@ -61,6 +63,7 @@ cc_library(
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToVector",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
Expand Down
3 changes: 3 additions & 0 deletions iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ iree_cc_library(
"Passes.h"
"Utils.h"
SRCS
"ConcretizeTileAmongWorkgroupsPass.cpp"
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"CooperativeMatrixAnalysis.cpp"
"FoldGPUProcessorIDUses.cpp"
"KernelDispatchUtils.cpp"
"LinalgTileAndDistributePass.cpp"
"LinalgTileAndFusePass.cpp"
"MatMulVectorizationTest.cpp"
"Passes.cpp"
Expand Down Expand Up @@ -74,6 +76,7 @@ iree_cc_library(
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToVector
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
Expand Down
Loading

0 comments on commit 7116f0c

Please sign in to comment.