Skip to content

Commit

Permalink
Adding TiedOpInterface and wiring it through the Flow dialect.
Browse files Browse the repository at this point in the history
This allows for results of operations to be tied back to their
operands in storage but not in time. This allows for in-place
operations to be defined on tensors that carry enough metadata
to be able to correctly form streams, materialize HAL
interfaces, and allocate buffers.

Example:
```mlir
%t = flow.dispatch @foo[...](%input) : (tensor<4xf32>) -> %input
```

This syntax also combines with the shape-carrying op interface
to make it possible to also indicate that an input and a result
share type and shape information:
```mlir
%t = flow.dispatch @foo[...](%input) : (tensor<?xf32>{%dim}) -> %input
```
which is effectively:
```mlir
%t = flow.dispatch @foo[...](%input) : (tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
```
but with the extra bit that result 0 is tied to operand 0.

Here the result %t of the dispatch aliases the storage for %input,
making %input a read-write/mutable binding in the resulting HAL
executable. %t is a distinct SSA value from %input, though, and
represents the value of the storage backing %input after the
dispatch has completed. By keeping the SSA use-def chains correct
with respect to time they are still meaningful for analysis and
nothing at this level (and the beginning of the HAL transformations)
needs to perform alias analysis, while still giving us all of the
information required to induce aliasing during later allocation
passes.
  • Loading branch information
benvanik committed Mar 11, 2021
1 parent 9467a90 commit 54749ce
Show file tree
Hide file tree
Showing 63 changed files with 1,700 additions and 813 deletions.
16 changes: 8 additions & 8 deletions iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}

LogicalResult convertInterfaceLoadTensorOp(
OpBuilder &b, IREE::Flow::DispatchInputLoadOp loadOp,
OpBuilder &b, IREE::Flow::DispatchTensorLoadOp loadOp,
BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(loadOp);
Expand Down Expand Up @@ -449,7 +449,7 @@ static Operation *getInsertionPointForReplacementStoreOp(
/// LinalgOp, create the subview operation that can be used by the op itself to
/// store the result into directly. This avoids an extra allocation + copies.
LogicalResult preProcessInterfaceStoreTensorOp(
OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp,
OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp,
BlockAndValueMapping &bvm) {
// Find the insertion point for the subview.
SmallVector<Value, 4> operandsOfSubviewOp;
Expand Down Expand Up @@ -491,7 +491,7 @@ LogicalResult preProcessLinalgOps(OpBuilder &b, linalg::LinalgOp op,
}

LogicalResult convertInterfaceStoreTensorOp(
OpBuilder &b, IREE::Flow::DispatchOutputStoreOp storeOp,
OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp,
BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(storeOp);
Expand Down Expand Up @@ -570,7 +570,7 @@ void LinalgBufferizePass::runOnFunction() {
transferShapeOpsToMemref(b, op.getResult(), baseBuffer.getResult(), bvm);
});
if (funcOp
.walk([&](IREE::Flow::DispatchOutputStoreOp op) -> WalkResult {
.walk([&](IREE::Flow::DispatchTensorStoreOp op) -> WalkResult {
return preProcessInterfaceStoreTensorOp(b, op, bvm);
})
.wasInterrupted()) {
Expand All @@ -596,12 +596,12 @@ void LinalgBufferizePass::runOnFunction() {

auto conversionDispatch = [&](Operation *op) -> WalkResult {
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<IREE::Flow::DispatchInputLoadOp>(
[&](IREE::Flow::DispatchInputLoadOp loadOp) {
.Case<IREE::Flow::DispatchTensorLoadOp>(
[&](IREE::Flow::DispatchTensorLoadOp loadOp) {
return convertInterfaceLoadTensorOp(b, loadOp, bvm);
})
.Case<IREE::Flow::DispatchOutputStoreOp>(
[&](IREE::Flow::DispatchOutputStoreOp storeOp) {
.Case<IREE::Flow::DispatchTensorStoreOp>(
[&](IREE::Flow::DispatchTensorStoreOp storeOp) {
return convertInterfaceStoreTensorOp(b, storeOp, bvm);
})
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
Expand Down
166 changes: 83 additions & 83 deletions iree/compiler/Conversion/Common/test/linalg_bufferize.mlir

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions iree/compiler/Conversion/LinalgToLLVM/test/linalg_vectorize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK: %[[I0:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32>
// CHECK: %[[I1:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32>
// CHECK: %[[I2:.*]] = flow.dispatch.input.load {{.*}} : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32>
// CHECK: %[[I0:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x3xf32>
// CHECK: %[[I1:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x1xf32>
// CHECK: %[[I2:.*]] = flow.dispatch.tensor.load {{.*}} : !flow.dispatch.tensor<readonly:2x4xf32> -> tensor<2x1xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C0]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C1]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[I0]][%[[C0]], %[[C2]]], {{.*}} : tensor<2x3xf32>, vector<1x1xf32>
Expand All @@ -26,21 +26,21 @@
// CHECK: %[[D5:.*]] = vector.contract {{.*}} %[[V5]], %[[V8]], %[[D4]] : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
// CHECK: %[[W0:.*]] = vector.transfer_write %[[D2]], %[[I2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32>
// CHECK: %[[W1:.*]] = vector.transfer_write %[[D5]], %[[W0]][%[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1xf32>, tensor<2x1xf32>
// CHECK: flow.dispatch.output.store %[[W1]]
// CHECK: flow.dispatch.tensor.store %[[W1]]

func @tensor_dispatch_0() {
%c0 = constant 0 : index
%c3 = constant 3 : index
%c1 = constant 1 : index
%c2 = constant 1 : index
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<2x3xf32>
%1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<3x4xf32>
%2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.input<2x4xf32>
%3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<2x4xf32>
%4 = flow.dispatch.input.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.input<2x3xf32> -> tensor<2x3xf32>
%5 = flow.dispatch.input.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.input<3x4xf32> -> tensor<3x1xf32>
%6 = flow.dispatch.input.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.input<2x4xf32> -> tensor<2x1xf32>
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
%1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor<readonly:3x4xf32>
%2 = hal.interface.binding.subspan @legacy_io::@arg2[%c0] : !flow.dispatch.tensor<readonly:2x4xf32>
%3 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:2x4xf32>
%4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c2, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x3xf32>
%5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x1xf32>
%6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x4xf32> -> tensor<2x1xf32>
%7 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%4, %5 : tensor<2x3xf32>, tensor<3x1xf32>) outs(%6 : tensor<2x1xf32>) -> tensor<2x1xf32>
flow.dispatch.output.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.output<2x4xf32>
flow.dispatch.tensor.store %7, %3, offsets = [%c0, %c0], sizes = [%c2, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor<writeonly:2x4xf32>
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} {
hal.executable.target @llvm_aot, filter="dylib*" {
hal.executable.entry_point @matmul_tensors attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?x?xf32>,
!flow.dispatch.output<?x?xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
!flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
module {
func @matmul_tensors() {
%c0 = constant 0 : index
Expand Down Expand Up @@ -97,8 +97,8 @@ hal.executable @add attributes {sym_visibility = "private"} {
hal.executable.target @llvm_aot, filter="dylib*" {
hal.executable.entry_point @add attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?xf32>,
!flow.dispatch.output<?x?xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?xf32>,
!flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
module {
func @add() {
%c0 = constant 0 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} {
hal.executable.target @llvm_aot, filter="dylib*" {
hal.executable.entry_point @matmul_128x128x128 attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<128x128xf32>, !flow.dispatch.input<128x128xf32>,
!flow.dispatch.output<128x128xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:128x128xf32>, !flow.dispatch.tensor<readonly:128x128xf32>,
!flow.dispatch.tensor<writeonly:128x128xf32>) -> ()}
module {
func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
linalg.matmul ins(%arg0, %arg1 : memref<128x128xf32>, memref<128x128xf32>) outs(%arg2 : memref<128x128xf32>)
Expand Down Expand Up @@ -91,8 +91,8 @@ hal.executable @dynamic_matmul_i8_i8_i32 attributes {sym_visibility = "private"}
hal.executable.target @llvm_aot, filter="dylib*" {
hal.executable.entry_point @matmul_i8_i8_i32_128x128x128 attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<128x128xi8>, !flow.dispatch.input<128x128xi8>,
!flow.dispatch.output<128x128xi32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:128x128xi8>, !flow.dispatch.tensor<readonly:128x128xi8>,
!flow.dispatch.tensor<writeonly:128x128xi32>) -> ()}
module {
func @matmul_i8_i8_i32_128x128x128(%arg0 : memref<128x128xi8>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) {
linalg.matmul_i8_i8_i32 ins(%arg0, %arg1 : memref<128x128xi8>, memref<128x128xi8>) outs(%arg2 : memref<128x128xi32>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// hal.executable.target @llvm_aot, filter="dylib*" {
// hal.executable.entry_point @dynamic_matmul attributes {
// interface = @legacy_io, ordinal = 0 : i32,
// signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?x?xf32>,
// !flow.dispatch.output<?x?xf32>) -> ()}
// signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
// !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
// module {
// func @dynamic_matmul(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) {
// linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%result : memref<?x?xf32>)
Expand Down Expand Up @@ -58,8 +58,8 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} {
hal.executable.target @llvm_aot, filter="dylib*" {
hal.executable.entry_point @static_matmul attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<16x4xf32>, !flow.dispatch.input<4x8xf32>,
!flow.dispatch.output<16x8xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:16x4xf32>, !flow.dispatch.tensor<readonly:4x8xf32>,
!flow.dispatch.tensor<writeonly:16x8xf32>) -> ()}
module {
func @static_matmul(%lhs: memref<16x4xf32>, %rhs: memref<4x8xf32>, %result: memref<16x8xf32>) {
linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
Expand Down
16 changes: 8 additions & 8 deletions iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" %s | IreeFileCheck %s

// Verify that a simple element wise op gets lowered succefully all the way to
// Verify that a simple element wise op gets lowered succefully all the way to
// nvvm/llvm dialect.

hal.executable @simpleMath_ex_dispatch_0 {
Expand All @@ -9,22 +9,22 @@ hal.executable @simpleMath_ex_dispatch_0 {
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
}
hal.executable.target @cuda, filter="cuda" {
hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<16xf32>, !flow.dispatch.input<16xf32>, !flow.dispatch.output<16xf32>) -> ()}
hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor<readonly:16xf32>, !flow.dispatch.tensor<readonly:16xf32>, !flow.dispatch.tensor<writeonly:16xf32>) -> ()}
module {
func @add_dispatch_0() {
%c0 = constant 0 : index
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<16xf32>
%1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<16xf32>
%2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<16xf32>
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:16xf32>
%1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.tensor<readonly:16xf32>
%2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:16xf32>
%3 = linalg.init_tensor [16] : tensor<16xf32>
%4 = flow.dispatch.input.load %0 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
%5 = flow.dispatch.input.load %1 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
%4 = flow.dispatch.tensor.load %0 : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32>
%5 = flow.dispatch.tensor.load %1 : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
%7 = addf %arg0, %arg1 : f32
linalg.yield %7 : f32
} -> tensor<16xf32>
flow.dispatch.output.store %6, %2 : tensor<16xf32> -> !flow.dispatch.output<16xf32>
flow.dispatch.tensor.store %6, %2 : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:16xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"
hal.executable.target @vulkan, filter="dylib*" {
hal.executable.entry_point @batch_matmul_static_shape attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?x?xf32>,
!flow.dispatch.output<?x?xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
!flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
Expand Down Expand Up @@ -299,8 +299,8 @@ hal.executable @batch_matmul_fused_fillop attributes {sym_visibility = "private"
hal.executable.target @vulkan, filter="dylib*" {
hal.executable.entry_point @batch_matmul_fused_fillop attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<?x?xf32>, !flow.dispatch.input<?x?xf32>,
!flow.dispatch.output<?x?xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
!flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} {
hal.executable.target @vulkan_spirv, filter="vulkan*" {
hal.executable.entry_point @conv2d_static_shape attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:1x225x225x16xf32>, !flow.dispatch.tensor<readonly:3x3x16x32xf32>, !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) -> ()}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @conv2d_static_shape() {
%cst = constant 0.000000e+00 : f32
Expand Down Expand Up @@ -120,7 +120,7 @@ hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
hal.executable.target @vulkan_spirv, filter="vulkan*" {
hal.executable.entry_point @matmul_dynamic_shape attributes {
interface = @legacy_io, ordinal = 0 : i32,
signature = (!flow.dispatch.input<1x225x225x16xf32>, !flow.dispatch.input<3x3x16x32xf32>, !flow.dispatch.output<1x112x112x32xf32>) -> ()}
signature = (!flow.dispatch.tensor<readonly:1x225x225x16xf32>, !flow.dispatch.tensor<readonly:3x3x16x32xf32>, !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) -> ()}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @matmul_dynamic_shape() {
%cst = constant 0.000000e+00 : f32
Expand Down
Loading

0 comments on commit 54749ce

Please sign in to comment.