Skip to content

Commit

Permalink
[spirv] Add support to tile and distribute linalg_ext.scatter (iree-o…
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst authored Aug 26, 2021
1 parent 5aaeabd commit 3336357
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 15 deletions.
1 change: 1 addition & 0 deletions iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) {
pm.addPass(createCanonicalizerPass());

pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass());
pm.addNestedPass<FuncOp>(linalg_ext::createLinalgExtToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.addPass(createLowerAffinePass());
pm.addPass(createCanonicalizerPass());
Expand Down
32 changes: 18 additions & 14 deletions iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
Expand Down Expand Up @@ -173,20 +175,22 @@ static void populateTilingToInvocationPatterns(MLIRContext *context,
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(invocationDistributionOptions);

patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::FillOp>,
linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
linalg::LinalgTilingPattern<linalg::GenericOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));
patterns
.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::FillOp>,
linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
linalg::LinalgTilingPattern<linalg::GenericOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>,
linalg_ext::TiledOpInterfaceTilingPattern<linalg_ext::ScatterOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));

patterns.insert<linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>>(
Expand Down
1 change: 1 addition & 0 deletions iree/compiler/Codegen/SPIRV/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"pipeline_matmul_cooperative_matrix.mlir",
"pipeline_matmul_vectorization.mlir",
"remove_one_trip_tiled_loop.mlir",
"tile_and_distribute_scatter.mlir",
"tile_and_vectorize.mlir",
"tile_and_vectorize_batch_matmul.mlir",
"tile_and_vectorize_conv.mlir",
Expand Down
1 change: 1 addition & 0 deletions iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
"pipeline_matmul_cooperative_matrix.mlir"
"pipeline_matmul_vectorization.mlir"
"remove_one_trip_tiled_loop.mlir"
"tile_and_distribute_scatter.mlir"
"tile_and_vectorize.mlir"
"tile_and_vectorize_batch_matmul.mlir"
"tile_and_vectorize_conv.mlir"
Expand Down
77 changes: 77 additions & 0 deletions iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute))))' %s | IreeFileCheck %s

// Test input: a statically shaped linalg_ext.scatter that is already nested in
// loops strided by the workgroup id/count and carries the "workgroup"
// transform marker. The FileCheck patterns at the end of this file expect the
// pass to add a loop over gpu.thread_id / gpu.block_dim and to retag the op
// with the "vectorize" marker.
hal.executable @static_scatter_update_slice attributes {sym_visibility = "private"} {
// Three storage-buffer bindings: 0 and 1 are read-only, 2 is read-write.
hal.interface @io {
hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @s0b2_rw_external, set=0, binding=2, type="StorageBuffer", access="Read|Write"
}

hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb"> {
// Entry point metadata: a 16x1x1 workgroup and a per-workgroup workload of
// [16, 1]. NOTE(review): the meaning of passPipeline = 5 is not visible in
// this file; confirm against the pipeline enum before changing it.
hal.executable.entry_point @static_scatter_update_slice attributes {
interface = @io, ordinal = 0 : index,
translation.info = {passPipeline = 5 : i32, workloadPerWorkgroup = [16, 1]},
workgroup_size = [16 : index, 1 : index, 1 : index]
}

builtin.module {
builtin.func @static_scatter_update_slice() {
%c40 = constant 40 : index
%c500 = constant 500 : index
%c0 = constant 0 : index
// Buffers bound to the scatter operands. The FileCheck patterns below refer
// to subviews of them as UPDATE (%0, 40x500), INDEX (%1, 40x1) and
// TARGET (%2, 100x500).
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<40x500xi32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<40x1xi32>
%2 = hal.interface.binding.subspan @io::@s0b2_rw_external[%c0] : memref<100x500xi32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
// Outer loop: rows [0, 40), starting at workgroup id y and stepping by the
// workgroup count along y.
scf.for %arg0 = %workgroup_id_y to %c40 step %workgroup_count_y {
// Column start and step are the workgroup id/count along x scaled by the
// tile width of 16.
%3 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
%4 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
// Inner loop: columns [0, 500) in steps of 16 * workgroup count x.
scf.for %arg1 = %3 to %c500 step %4 {
// Clamp the tile width to the columns that remain: min(16, 500 - %arg1).
%5 = affine.min affine_map<(d0) -> (16, -d0 + 500)>(%arg1)
// 1 x %5 slice of the first input at (row %arg0, column %arg1); the cast
// erases the static leading dimension.
%6 = memref.subview %0[%arg0, %arg1] [1, %5] [1, 1] : memref<40x500xi32> to memref<1x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>
%7 = memref.cast %6 : memref<1x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>> to memref<?x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>
// 1 x 1 slice of the second input for the same row, cast the same way.
%8 = memref.subview %1[%arg0, 0] [1, 1] [1, 1] : memref<40x1xi32> to memref<1x1xi32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>
%9 = memref.cast %8 : memref<1x1xi32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>> to memref<?x1xi32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>
// The output slice keeps all 100 rows and only the current column tile.
%10 = memref.subview %2[0, %arg1] [100, %5] [1, 1] : memref<100x500xi32> to memref<100x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>
// Op under test. It carries the "workgroup" marker and three tile-size lists
// in lowering.config; the expected output below uses [1, 1] subviews, which
// matches the third list (presumably the per-invocation sizes; TODO confirm).
// The region yields its first block argument unchanged.
// NOTE(review): presumably that argument is the update value, making this a
// plain overwrite; confirm against the linalg_ext.scatter op definition.
linalg_ext.scatter {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[1, 16], [], [1, 1]]}} ins(%7, %9 : memref<?x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>, memref<?x1xi32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>) outs(%10 : memref<100x?xi32, affine_map<(d0, d1)[s0] -> (d0 * 500 + s0 + d1)>>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
linalg_ext.yield %arg2 : i32
}
}
}
return
}
// Second declaration of the @io interface, nested inside the module; it lists
// the same three bindings as the declaration at the executable level.
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @s0b2_rw_external, set=0, binding=2, type="StorageBuffer", access="Read|Write"
}
}
}
}

// CHECK-LABEL: func @static_scatter_update_slice()
// CHECK: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external
// CHECK: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external
// CHECK: %[[ARG2:.+]] = hal.interface.binding.subspan @io::@s0b2_rw_external
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[WG_UPDATE:.+]] = memref.subview %[[ARG0]]
// CHECK: %[[WG_INDEX:.+]] = memref.subview %[[ARG1]]
// CHECK: %[[WG_TARGET:.+]] = memref.subview %[[ARG2]]
// CHECK: %[[TID_X:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
// CHECK: %[[DIM_X:.+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
// CHECK: %[[TID_Y:.+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
// CHECK: scf.for %[[IV:.+]] = %[[TID_X]] to %{{.+}} step %[[DIM_X]]
// CHECK: %[[T_UPDATE:.+]] = memref.subview %[[WG_UPDATE]][%[[TID_Y]], %[[IV]]] [1, 1] [1, 1]
// CHECK: %[[T_UPDATE_CAST:.+]] = memref.cast %[[T_UPDATE]]
// CHECK: %[[T_INDEX:.+]] = memref.cast %[[WG_INDEX]]
// CHECK: %[[T_TARGET:.+]] = memref.subview %[[WG_TARGET]][0, %[[IV]]] [100, 1] [1, 1]
// CHECK: %[[T_TARGET_CAST:.+]] = memref.cast %[[T_TARGET]]
// CHECK: linalg_ext.scatter
// CHECK-SAME: __internal_linalg_transform__ = "vectorize"
// CHECK-SAME: ins(%[[T_UPDATE_CAST]], %[[T_INDEX]]
// CHECK-SAME: outs(%[[T_TARGET_CAST]]
2 changes: 1 addition & 1 deletion iree/test/e2e/xla_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ iree_check_single_backend_test_suite(
"reshape.mlir",
"reverse.mlir",
"rsqrt.mlir",
"scatter.mlir",
"select.mlir",
"sine.mlir",
"slice.mlir",
Expand All @@ -355,7 +356,6 @@ iree_check_single_backend_test_suite(
exclude = [
"bitcast_convert.mlir",
"round.mlir",
"scatter.mlir", # TODO(GH-6388): Enable the test.
"scatter_dynamic.mlir", # TODO(GH-6388): Enable the test.
"sort.mlir",
],
Expand Down
1 change: 1 addition & 0 deletions iree/test/e2e/xla_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ iree_check_single_backend_test_suite(
"reshape.mlir"
"reverse.mlir"
"rsqrt.mlir"
"scatter.mlir"
"select.mlir"
"sine.mlir"
"slice.mlir"
Expand Down

0 comments on commit 3336357

Please sign in to comment.